From e7151c194ca5999a3e64dcd2236b2d450042da8f Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 9 Dec 2020 00:24:35 +0800 Subject: [PATCH] reconstruct onnx --- mindspore/lite/src/ops/constant.h | 36 + mindspore/lite/src/ops/primitive_c.cc | 5 +- mindspore/lite/src/ops/primitive_c.h | 2 +- mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 11 + .../parser/caffe/caffe_model_parser.cc | 13 +- .../parser/onnx/onnx_adder_parser.cc | 27 +- .../converter/parser/onnx/onnx_adder_parser.h | 3 +- .../parser/onnx/onnx_argmax_parser.cc | 29 +- .../parser/onnx/onnx_argmax_parser.h | 2 +- .../onnx/onnx_arithmetic_operation_parser.cc | 858 +++++------ .../onnx/onnx_arithmetic_operation_parser.h | 58 +- .../parser/onnx/onnx_batchnorm_parser.cc | 30 +- .../parser/onnx/onnx_batchnorm_parser.h | 2 +- .../parser/onnx/onnx_biasadd_parser.cc | 29 +- .../parser/onnx/onnx_biasadd_parser.h | 2 +- .../converter/parser/onnx/onnx_cast_parser.cc | 29 +- .../converter/parser/onnx/onnx_cast_parser.h | 2 +- .../converter/parser/onnx/onnx_clip_parser.cc | 41 +- .../converter/parser/onnx/onnx_clip_parser.h | 2 +- .../parser/onnx/onnx_concat_parser.cc | 30 +- .../parser/onnx/onnx_concat_parser.h | 2 +- .../onnx/onnx_constant_of_shape_parser.cc | 35 +- .../onnx/onnx_constant_of_shape_parser.h | 2 +- .../parser/onnx/onnx_constant_parser.cc | 78 +- .../parser/onnx/onnx_constant_parser.h | 3 +- .../converter/parser/onnx/onnx_conv_parser.cc | 61 +- .../converter/parser/onnx/onnx_conv_parser.h | 4 +- .../parser/onnx/onnx_deconv_parser.cc | 63 +- .../parser/onnx/onnx_deconv_parser.h | 4 +- .../parser/onnx/onnx_depth_to_space_parser.cc | 30 +- .../parser/onnx/onnx_depth_to_space_parser.h | 2 +- .../parser/onnx/onnx_dropout_parser.cc | 30 +- .../parser/onnx/onnx_dropout_parser.h | 2 +- .../converter/parser/onnx/onnx_elu_parser.cc | 29 +- .../converter/parser/onnx/onnx_elu_parser.h | 2 +- .../parser/onnx/onnx_expand_parser.cc | 31 +- .../parser/onnx/onnx_expand_parser.h | 2 +- .../parser/onnx/onnx_flatten_parser.cc | 30 +- .../parser/onnx/onnx_flatten_parser.h | 2 +- .../parser/onnx/onnx_gather_parser.cc | 29 +- .../parser/onnx/onnx_gather_parser.h | 2 +- .../converter/parser/onnx/onnx_gemm_parser.cc | 55 + ...nnx_tensor_parser.h => onnx_gemm_parser.h} | 23 +- .../onnx/onnx_given_tensor_fill_parser.cc | 127 ++ .../onnx/onnx_given_tensor_fill_parser.h | 39 + .../parser/onnx/onnx_identity_parser.cc | 29 +- .../parser/onnx/onnx_identity_parser.h | 2 +- .../parser/onnx/onnx_instance_norm_parser.cc | 30 +- .../parser/onnx/onnx_instance_norm_parser.h | 2 +- .../parser/onnx/onnx_lp_norm_parser.cc | 30 +- .../parser/onnx/onnx_lp_norm_parser.h | 2 +- .../converter/parser/onnx/onnx_lrn_parser.cc | 30 +- .../converter/parser/onnx/onnx_lrn_parser.h | 2 +- .../converter/parser/onnx/onnx_lstm_parser.cc | 28 +- .../converter/parser/onnx/onnx_lstm_parser.h | 2 +- .../parser/onnx/onnx_matmul_parser.cc | 31 +- .../parser/onnx/onnx_matmul_parser.h | 2 +- .../parser/onnx/onnx_model_parser.cc | 1256 +++++++++-------- .../converter/parser/onnx/onnx_model_parser.h | 88 +- .../converter/parser/onnx/onnx_node_parser.cc | 2 + .../converter/parser/onnx/onnx_node_parser.h | 4 +- .../onnx/onnx_non_max_suppression_parser.cc | 29 +- .../onnx/onnx_non_max_suppression_parser.h | 2 +- .../parser/onnx/onnx_onehot_parser.cc | 29 +- .../parser/onnx/onnx_onehot_parser.h | 2 +- .../converter/parser/onnx/onnx_pad_parser.cc | 28 +- .../converter/parser/onnx/onnx_pad_parser.h | 2 +- .../converter/parser/onnx/onnx_pool_parser.cc | 32 +- .../converter/parser/onnx/onnx_pool_parser.h | 2 +- .../parser/onnx/onnx_quantize_parser.cc | 31 +- .../parser/onnx/onnx_quantize_parser.h | 2 +- .../parser/onnx/onnx_range_parser.cc | 29 +- .../converter/parser/onnx/onnx_range_parser.h | 2 +- .../parser/onnx/onnx_reduce_parser.cc | 31 +- .../parser/onnx/onnx_reduce_parser.h | 2 +- .../converter/parser/onnx/onnx_relu_parser.cc | 61 +- .../converter/parser/onnx/onnx_relu_parser.h | 4 +- .../parser/onnx/onnx_reshape_parser.cc | 45 +- .../parser/onnx/onnx_reshape_parser.h | 2 +- .../parser/onnx/onnx_resize_parser.cc | 29 +- .../parser/onnx/onnx_resize_parser.h | 2 +- .../parser/onnx/onnx_shape_parser.cc | 29 +- .../converter/parser/onnx/onnx_shape_parser.h | 2 +- .../parser/onnx/onnx_sigmoid_parser.cc | 29 +- .../parser/onnx/onnx_sigmoid_parser.h | 2 +- .../parser/onnx/onnx_slice_parser.cc | 152 +- .../converter/parser/onnx/onnx_slice_parser.h | 5 +- .../parser/onnx/onnx_softmax_parser.cc | 29 +- .../parser/onnx/onnx_softmax_parser.h | 2 +- .../parser/onnx/onnx_space_to_depth_parser.cc | 29 +- .../parser/onnx/onnx_space_to_depth_parser.h | 2 +- .../parser/onnx/onnx_split_parser.cc | 29 +- .../converter/parser/onnx/onnx_split_parser.h | 2 +- .../parser/onnx/onnx_squeeze_parser.cc | 29 +- .../parser/onnx/onnx_squeeze_parser.h | 2 +- .../converter/parser/onnx/onnx_tile_parser.cc | 28 +- .../converter/parser/onnx/onnx_tile_parser.h | 2 +- .../converter/parser/onnx/onnx_topk_parser.cc | 28 +- .../converter/parser/onnx/onnx_topk_parser.h | 2 +- .../parser/onnx/onnx_transpose_parser.cc | 29 +- .../parser/onnx/onnx_transpose_parser.h | 2 +- .../parser/onnx/onnx_unsqueeze_parser.cc | 29 +- .../parser/onnx/onnx_unsqueeze_parser.h | 2 +- .../parser/onnx/onnx_upsample_parser.cc | 31 +- .../parser/onnx/onnx_upsample_parser.h | 2 +- .../parser/tflite/tflite_model_parser.cc | 8 +- .../graph/onnx_inputs_adjust_pass.cc | 508 +++++++ .../optimizer/graph/onnx_inputs_adjust_pass.h | 49 + 110 files changed, 2612 insertions(+), 2257 deletions(-) create mode 100644 mindspore/lite/src/ops/constant.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc rename mindspore/lite/tools/converter/parser/onnx/{onnx_tensor_parser.h => onnx_gemm_parser.h} (55%) create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h create mode 100644 mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc create mode 100644 mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h diff --git a/mindspore/lite/src/ops/constant.h b/mindspore/lite/src/ops/constant.h new file mode 100644 index 00000000000..659331c6503 --- /dev/null +++ b/mindspore/lite/src/ops/constant.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef PRIMITIVE_WRITEABLE +#ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Constant : public PrimitiveC { + public: + Constant() = default; + ~Constant() = default; + MS_DECLARE_PARENT(Constant, PrimitiveC); + explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ +#endif diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index c3fc2841b4d..585d1dc9a14 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -149,6 +149,7 @@ #include "src/ops/oneslike.h" #include "src/ops/unsorted_segment_sum.h" #include "src/ops/reciprocal.h" +#include "src/ops/constant.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -182,7 +183,7 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector CastToInt(const ValuePtr value) { +std::vector CastToInt(const ValuePtr &value) { if (value == nullptr) { MS_LOG(WARNING) << "valueptr is nullptr."; return {}; @@ -891,6 +892,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Dequant(primitive); case schema::PrimitiveType_Reciprocal: return new (std::nothrow) Reciprocal(primitive); + case schema::PrimitiveType_Constant: + return new (std::nothrow) Constant(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 77678578408..51f96b34665 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -57,7 +57,7 @@ static std::map kActivationTypeMap{ {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, {"Tanh", schema::ActivationType_TANH}, {"Logistic", schema::ActivationType_SIGMOID}}; -std::vector CastToInt(const ValuePtr value); +std::vector CastToInt(const ValuePtr &value); class PrimitiveC : public mindspore::Primitive { public: // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 0f070bd21e2..90628ce5929 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -205,6 +205,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc + ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc ) endif() ### train diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 279c5c9719e..bb48b0b2f7b 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/infershape_pass.cc ../optimizer/graph/slice_prepose_pass.cc ../optimizer/graph/mindir_adjust_pass.cc + ../optimizer/graph/onnx_inputs_adjust_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index e23592d8a15..96e6861ac4e 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -36,6 +36,7 @@ #include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" +#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" #include "tools/optimizer/graph/update_conv2d_param_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" @@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver } } + // onnx pre adjustment + if (config->fmk == converter::FmkType_ONNX) { + auto onnx_adjust_pass = std::make_shared(); + if (!onnx_adjust_pass->Run(old_graph)) { + MS_LOG(ERROR) << "onnx adjust failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + } + // for now - trainning is not supporting fuse operations if (!config->trainModel) { // remove quantdtype when awaretraining diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index a917fae399f..fa8fda86880 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() { auto primitive_c = node_parser->ParseLitePrimitive(layer, weight); if (primitive_c == nullptr) { MS_LOG(ERROR) << "parse node " << layer.name() << " failed."; + status = RET_ERROR; continue; } @@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() { status = ConvertBottom(layer, &input_nodes); if (status != RET_OK) { MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; + continue; } // build weights @@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() { status = ConvertBlobs(weight, &const_parameters); if (status != RET_OK) { MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; + continue; } // build cnode @@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() { status = ConvertTop(layer, new_cnode); if (status != RET_OK) { MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; + continue; } status = ConvertLayerQuantParams(layer, weight, primitive_c); if (status != RET_OK) { MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; + continue; } } return status; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc index eac8e078c61..41a54fef940 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.cc @@ -19,27 +19,22 @@ namespace mindspore { namespace lite { -STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx AdderParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Adder; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Adder; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.h index 0c383ee1de1..59c13aa93c3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_adder_parser.h @@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser { public: OnnxAdderParser() : OnnxNodeParser("Adder") {} ~OnnxAdderParser() override = default; - - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc index 65a3ce45d17..b901e49cb02 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc @@ -19,23 +19,14 @@ namespace mindspore { namespace lite { -STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ArgMaxParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->keepDims = static_cast(onnx_node_attr.i()); } } - - op->primitive->value.type = schema::PrimitiveType_ArgMax; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_ArgMax; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h index e9f90a7a3f7..65f888e1073 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h @@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser { OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} ~OnnxArgMaxParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index 86e02cb2c35..bb23706f171 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -15,256 +15,183 @@ */ #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" -#include "tools/converter/parser/onnx/onnx_tensor_parser.h" #include #include #include namespace mindspore { namespace lite { -STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx AddParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Add; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Add; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxSubParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SubParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Sub; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Sub; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxMulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx MulParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Mul; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Mul; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxDivParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx DivParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Div; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Div; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxPowParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx PowParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr(new schema::PowerT()); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - const auto &onnx_pow_power = onnx_node.input(1); - int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_pow_power); - if (index == -1) { - MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; - return RET_ERROR; - } - auto pow_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; - if (std::accumulate(pow_attr->dims.begin(), pow_attr->dims.end(), 1, std::multiplies()) != 1) { - MS_LOG(ERROR) << "the exponent element num is bigger than 1, which don't support now."; - return RET_NOT_SUPPORT; - } - if (pow_attr->data.data() == nullptr) { - MS_LOG(ERROR) << "power's attr pow can't be obtained."; - return RET_INVALID_OP_ATTR; - } - attr->power = *reinterpret_cast(pow_attr->data.data()); attr->scale = 1.0f; attr->shift = 0.0f; - op->primitive->value.type = schema::PrimitiveType_Power; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Power; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxEqualParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx EqualParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Equal; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Equal; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxLessParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx LessParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Less; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Less; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { + +lite::PrimitiveC *OnnxGreaterParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx GreaterParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Greater; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Greater; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxMinParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx MinParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Minimum; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Minimum; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxEltwiseParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx EltwiseParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } if (onnx_node.op_type() == "Sum") { @@ -272,446 +199,357 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: } else if (onnx_node.op_type() == "Max") { attr->mode = schema::EltwiseMode_MAXIMUM; } - - op->primitive->value.type = schema::PrimitiveType_Eltwise; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Eltwise; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxFloorParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx FloorParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Floor; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Floor; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxAbsParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx AbsParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Abs; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Abs; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxNegParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx NegParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Neg; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Neg; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxExpParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ExpParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Exp; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Exp; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxCosParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx CosParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Cos; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Cos; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxSinParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SinParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Sin; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Sin; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxSqrtParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SqrtParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Sqrt; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Sqrt; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxCeilParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx CeilParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Ceil; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Ceil; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxLogParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx LogParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Log; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Log; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxTanParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TanParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); - + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Tan; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Tan; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxAtanParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx AtanParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Atan; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Atan; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { - MS_LOG(DEBUG) << "onnx AsinParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); +lite::PrimitiveC *OnnxAsinParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - - op->primitive->value.type = schema::PrimitiveType_Asin; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Asin; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxTanhParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TanhParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->type = schema::ActivationType_TANH; - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxSignParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TanhParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->type = schema::ActivationType_SIGN; - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxAndParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx AndParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_LogicalAnd; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalAnd; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxOrParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx OrParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_LogicalOr; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalOr; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxNotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx NotParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_LogicalNot; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LogicalNot; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxRoundParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx RoundParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Round; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Round; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxReciprocalParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxReciprocalParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ReciprocalParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Reciprocal; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Reciprocal; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h index a6c635c0887..7fc62cc3065 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser { public: OnnxAddParser() : OnnxNodeParser("Add") {} ~OnnxAddParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxSubParser : public OnnxNodeParser { public: OnnxSubParser() : OnnxNodeParser("Sub") {} ~OnnxSubParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxMulParser : public OnnxNodeParser { public: OnnxMulParser() : OnnxNodeParser("Mul") {} ~OnnxMulParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxDivParser : public OnnxNodeParser { public: OnnxDivParser() : OnnxNodeParser("Div") {} ~OnnxDivParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxPowParser : public OnnxNodeParser { public: OnnxPowParser() : OnnxNodeParser("Power") {} ~OnnxPowParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxEqualParser : public OnnxNodeParser { public: OnnxEqualParser() : OnnxNodeParser("Equal") {} ~OnnxEqualParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxLessParser : public OnnxNodeParser { public: OnnxLessParser() : OnnxNodeParser("Less") {} ~OnnxLessParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxGreaterParser : public OnnxNodeParser { public: OnnxGreaterParser() : OnnxNodeParser("Greater") {} ~OnnxGreaterParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxMinParser : public OnnxNodeParser { public: OnnxMinParser() : OnnxNodeParser("Min") {} ~OnnxMinParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxEltwiseParser : public OnnxNodeParser { public: OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} ~OnnxEltwiseParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxFloorParser : public OnnxNodeParser { public: OnnxFloorParser() : OnnxNodeParser("Floor") {} ~OnnxFloorParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxAbsParser : public OnnxNodeParser { public: OnnxAbsParser() : OnnxNodeParser("Abs") {} ~OnnxAbsParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxNegParser : public OnnxNodeParser { public: OnnxNegParser() : OnnxNodeParser("Neg") {} ~OnnxNegParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxExpParser : public OnnxNodeParser { public: OnnxExpParser() : OnnxNodeParser("Exp") {} ~OnnxExpParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxCosParser : public OnnxNodeParser { public: OnnxCosParser() : OnnxNodeParser("Cos") {} ~OnnxCosParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxSinParser : public OnnxNodeParser { public: OnnxSinParser() : OnnxNodeParser("Sin") {} ~OnnxSinParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxSqrtParser : public OnnxNodeParser { public: OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} ~OnnxSqrtParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxCeilParser : public OnnxNodeParser { public: OnnxCeilParser() : OnnxNodeParser("Ceil") {} ~OnnxCeilParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxLogParser : public OnnxNodeParser { public: OnnxLogParser() : OnnxNodeParser("Log") {} ~OnnxLogParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxTanParser : public OnnxNodeParser { public: OnnxTanParser() : OnnxNodeParser("Tan") {} ~OnnxTanParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxAtanParser : public OnnxNodeParser { public: OnnxAtanParser() : OnnxNodeParser("Atan") {} ~OnnxAtanParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxAsinParser : public OnnxNodeParser { public: OnnxAsinParser() : OnnxNodeParser("Asin") {} ~OnnxAsinParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxTanhParser : public OnnxNodeParser { public: OnnxTanhParser() : OnnxNodeParser("Tanh") {} ~OnnxTanhParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxSignParser : public OnnxNodeParser { public: OnnxSignParser() : OnnxNodeParser("Sign") {} ~OnnxSignParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxAndParser : public OnnxNodeParser { public: OnnxAndParser() : OnnxNodeParser("And") {} ~OnnxAndParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxOrParser : public OnnxNodeParser { public: OnnxOrParser() : OnnxNodeParser("Or") {} ~OnnxOrParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxNotParser : public OnnxNodeParser { public: OnnxNotParser() : OnnxNodeParser("Not") {} ~OnnxNotParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxRoundParser : public OnnxNodeParser { public: OnnxRoundParser() : OnnxNodeParser("Round") {} ~OnnxRoundParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxReciprocalParser : public OnnxNodeParser { public: OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {} ~OnnxReciprocalParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc index ea301e7f8bb..3ea9a670ed9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx BatchNormParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx attr->spatial = static_cast(onnx_node_attr.i()); } } - - op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h index 643eea0e89a..18f2b7ee3c9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h @@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser { OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} ~OnnxBatchNormParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc index 66742a52063..935c62f3e7b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc @@ -19,30 +19,25 @@ namespace mindspore { namespace lite { -STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx BiasAddParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->axis = {1}; - op->primitive->value.type = schema::PrimitiveType_BiasAdd; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_BiasAdd; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h index c2f5112efa0..01b15db53ee 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h @@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser { OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} ~OnnxBiasAddParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc index d57b6f1719b..1a2a93cc078 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -20,22 +20,13 @@ namespace mindspore { namespace lite { -STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx CastParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod attr->dstT = static_cast(dst_type); } } - - op->primitive->value.type = schema::PrimitiveType_Cast; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Cast; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h index 9fe47368a0f..45389ce2158 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h @@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser { OnnxCastParser() : OnnxNodeParser("Cast") {} ~OnnxCastParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc index ef1cea6e55f..3012b91c04c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -19,39 +19,32 @@ namespace mindspore { namespace lite { -STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ClipParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return nullptr; } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - float min = -1, max = -1; + attr->max = -1; + attr->min = -1; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "max") { - max = onnx_node_attr.f(); + attr->max = onnx_node_attr.f(); } else if (attribute_name == "min") { - min = onnx_node_attr.f(); + attr->min = onnx_node_attr.f(); } } - - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; } - attr->max = max; - attr->min = min; - op->primitive->value.type = schema::PrimitiveType_Clip; - op->primitive->value.value = attr.release(); - - return RET_OK; + primitive->value.type = schema::PrimitiveType_Clip; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h index 8ffffec7584..bd6dcb8d75a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h @@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser { OnnxClipParser() : OnnxNodeParser("Clip") {} ~OnnxClipParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc index f8148924ab3..4c83fa49923 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ConcatParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->axis = static_cast(onnx_node_attr.i()); } } - - op->primitive->value.type = schema::PrimitiveType_Concat; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Concat; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h index 49dcc733943..ccab17ca15e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h @@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser { OnnxConcatParser() : OnnxNodeParser("Concat") {} ~OnnxConcatParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc index ee3cb1c2c69..dc1c930706b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc @@ -20,23 +20,13 @@ namespace mindspore { namespace lite { -STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ConstantOfShapeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons const auto &tensor = onnx_node_attr.t(); auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); if (ret != RET_OK) { - return ret; + MS_LOG(ERROR) << "get data from tensor failed"; + return nullptr; } } break; default: MS_LOG(ERROR) << "The data type is not supported."; - return RET_ERROR; + return nullptr; } } } - - op->primitive->value.type = schema::PrimitiveType_ConstantOfShape; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_ConstantOfShape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h index 2e19397ec3c..09e5d4a1b57 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.h @@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser { OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {} ~OnnxConstantOfShapeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc index 83348f3fe90..51b8b04da08 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -16,33 +16,75 @@ #include "tools/converter/parser/onnx/onnx_constant_parser.h" #include +#include +#include +#include "tools/converter/parser/onnx/onnx_model_parser.h" namespace mindspore { namespace lite { -STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { - MS_LOG(DEBUG) << "onnx ConstantParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; +STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c) { + ParamValueLitePtr param_value = std::make_shared(); + if (param_value == nullptr) { + MS_LOG(ERROR) << "new a paramValueLite failed."; + return RET_ERROR; } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; + auto data_type = + OnnxModelParser::GetDataTypeFromOnnx(static_cast(onnx_const_tensor.data_type())); + if (data_type == kTypeUnknown) { + MS_LOG(ERROR) << "not support onnx data type " + << static_cast(onnx_const_tensor.data_type()); + return RET_ERROR; } - - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + std::vector shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); + std::vector shape; + std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), + [](const int64_t &val) { return static_cast(val); }); + param_value->set_tensor_type(data_type); + param_value->set_tensor_shape(shape); + param_value->set_format(schema::Format_NCHW); + if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) { + MS_LOG(ERROR) << "get value failed."; + return RET_ERROR; } - - op->primitive->value.type = schema::PrimitiveType_Constant; - op->primitive->value.value = attr.release(); + primitive_c->set_attr("const_data", param_value); return RET_OK; } +lite::PrimitiveC *OnnxConstantParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { + MS_LOG(DEBUG) << "onnx ConstantParser"; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Constant; + auto primitive_c = PrimitiveC::Create(primitive.release()); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "create primitiveC failed."; + return nullptr; + } + for (const auto &attr : onnx_node.attribute()) { + if (attr.name() == "sparse_value") { + MS_LOG(WARNING) << "sparse_value"; + continue; + } + if (attr.name() == "value") { + const auto &const_tensor = attr.t(); + if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) { + MS_LOG(ERROR) << "add basic attr failed."; + delete primitive_c; + return nullptr; + } + } else { + MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; + delete primitive_c; + return nullptr; + } + } + return primitive_c; +} + OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h index 1aacd458d24..d58736bf912 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h @@ -27,7 +27,8 @@ class OnnxConstantParser : public OnnxNodeParser { OnnxConstantParser() : OnnxNodeParser("Constant") {} ~OnnxConstantParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c); + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index a0f2c374be0..a81410bf153 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -21,9 +21,14 @@ namespace mindspore::lite { constexpr int32_t kSingleGroup = 1; -bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr &attr, schema::CNodeT *op) { +bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr &attr, + schema::PrimitiveT *primitive) { MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; - std::unique_ptr depthwiseConv2DParam = std::make_unique(); + if (attr == nullptr || primitive == nullptr) { + MS_LOG(ERROR) << "input parameter is nullptr"; + return false; + } + auto depthwiseConv2DParam = std::make_unique(); if (depthwiseConv2DParam == nullptr) { MS_LOG(ERROR) << "new op failed"; return false; @@ -45,27 +50,18 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptrhasBias = attr->hasBias; depthwiseConv2DParam->activationType = attr->activationType; - op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - op->primitive->value.value = depthwiseConv2DParam.release(); + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = depthwiseConv2DParam.release(); return true; } -STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ConvParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->strideH = 1; @@ -83,21 +79,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } else if (onnx_node_attr.name() == "dilations") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->dilateH = static_cast(onnx_node_attr.ints(0)); attr->dilateW = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernels") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->kernelH = static_cast(onnx_node_attr.ints(0)); attr->kernelW = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernel_shape") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->kernelH = static_cast(onnx_node_attr.ints(0)); attr->kernelW = static_cast(onnx_node_attr.ints(1)); @@ -106,7 +102,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } else if (onnx_node_attr.name() == "pads") { if (onnx_node_attr.ints().size() != 4) { MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; - return RET_ERROR; + return nullptr; } attr->padUp = static_cast(onnx_node_attr.ints(0)); attr->padLeft = static_cast(onnx_node_attr.ints(1)); @@ -115,7 +111,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } else if (onnx_node_attr.name() == "strides") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->strideH = static_cast(onnx_node_attr.ints(0)); attr->strideW = static_cast(onnx_node_attr.ints(1)); @@ -124,7 +120,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod attr->format = schema::Format::Format_NHWC; } else { MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); - return RET_ERROR; + return nullptr; } } } @@ -152,7 +148,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); if (node_iter == onnx_graph.node().end()) { MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; - return RET_ERROR; + return nullptr; } std::vector dims; auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), @@ -160,7 +156,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod if (iter != (*node_iter).attribute().end()) { if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { MS_LOG(ERROR) << "dims insert failed"; - return RET_ERROR; + return nullptr; } dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); } @@ -174,16 +170,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod attr->activationType = schema::ActivationType_NO_ACTIVATION; } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } if (attr->group > kSingleGroup && attr->group == attr->channelIn) { - if (!ParseGroupConvolution(attr, op)) { + if (!ParseGroupConvolution(attr, primitive.get())) { MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; - return RET_ERROR; + return nullptr; } } else { - op->primitive->value.type = schema::PrimitiveType_Conv2D; - op->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); } - return RET_OK; + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h index 0162c6ffe44..9f9987e1aec 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h @@ -28,10 +28,10 @@ class OnnxConvParser : public OnnxNodeParser { OnnxConvParser() : OnnxNodeParser("Conv") {} ~OnnxConvParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; private: - static bool ParseGroupConvolution(const std::unique_ptr &attr, schema::CNodeT *op); + static bool ParseGroupConvolution(const std::unique_ptr &attr, schema::PrimitiveT *primitive); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc index 268129438ed..c53bcd28cfa 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -21,11 +21,13 @@ namespace mindspore { namespace lite { -bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr &attr, schema::CNodeT *op) { - if (attr == nullptr || attr->group != attr->channelOut) { +bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr &attr, + schema::PrimitiveT *primitive) { + if (attr == nullptr || attr->group != attr->channelOut || primitive == nullptr) { + MS_LOG(ERROR) << "input parameter is nullptr"; return false; } - std::unique_ptr deDepthwiseConv2DParam = std::make_unique(); + auto deDepthwiseConv2DParam = std::make_unique(); if (deDepthwiseConv2DParam == nullptr) { MS_LOG(ERROR) << "new op failed"; return false; @@ -47,28 +49,18 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptrhasBias = attr->hasBias; deDepthwiseConv2DParam->activationType = attr->activationType; - op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; - op->primitive->value.value = deDepthwiseConv2DParam.release(); + primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; + primitive->value.value = deDepthwiseConv2DParam.release(); return true; } -STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx DeConvParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->padMode = schema::PadMode_NOTSET; @@ -83,21 +75,21 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } else if (onnx_node_attr.name() == "dilations") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->dilateH = static_cast(onnx_node_attr.ints(0)); attr->dilateW = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernels") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->kernelH = static_cast(onnx_node_attr.ints(0)); attr->kernelW = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernel_shape") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->kernelH = static_cast(onnx_node_attr.ints(0)); attr->kernelW = static_cast(onnx_node_attr.ints(1)); @@ -106,7 +98,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } else if (onnx_node_attr.name() == "pads") { if (onnx_node_attr.ints().size() != 4) { MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; - return RET_ERROR; + return nullptr; } attr->padUp = static_cast(onnx_node_attr.ints(0)); attr->padLeft = static_cast(onnx_node_attr.ints(1)); @@ -115,7 +107,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } else if (onnx_node_attr.name() == "strides") { if (onnx_node_attr.ints().size() != 2) { MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; - return RET_ERROR; + return nullptr; } attr->strideH = static_cast(onnx_node_attr.ints(0)); attr->strideW = static_cast(onnx_node_attr.ints(1)); @@ -124,11 +116,11 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->format = schema::Format::Format_NHWC; } else { MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str(); - return RET_ERROR; + return nullptr; } } else if (onnx_node_attr.name() == "output_padding") { MS_LOG(ERROR) << "output_padding param hasn't been supported"; - return RET_NOT_SUPPORT; + return nullptr; } } @@ -138,7 +130,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); if (node_iter == onnx_graph.initializer().end()) { MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); - return RET_ERROR; + return nullptr; } std::vector weight_shape; auto size = (*node_iter).dims_size(); @@ -148,7 +140,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } if (weight_shape.size() != 4) { MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); - return RET_ERROR; + return nullptr; } attr->channelIn = weight_shape[0]; attr->channelOut = weight_shape[1] * attr->group; @@ -156,17 +148,22 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->format = schema::Format::Format_NCHW; attr->hasBias = onnx_node.input().size() == 3; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } if (attr->group != 1) { - if (!ParseGroupDeConvolution(attr, op)) { + if (!ParseGroupDeConvolution(attr, primitive.get())) { MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; - return RET_NOT_SUPPORT; + return nullptr; } } else { - op->primitive->value.type = schema::PrimitiveType_DeConv2D; - op->primitive->value.value = attr.release(); + primitive->value.type = schema::PrimitiveType_DeConv2D; + primitive->value.value = attr.release(); } - return RET_OK; + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h index a0a77e80584..2b83c223cfc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h @@ -28,10 +28,10 @@ class OnnxDeConvParser : public OnnxNodeParser { OnnxDeConvParser() : OnnxNodeParser("DeConv") {} ~OnnxDeConvParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; private: - static bool ParseGroupDeConvolution(const std::unique_ptr &attr, schema::CNodeT *op); + bool ParseGroupDeConvolution(const std::unique_ptr &attr, schema::PrimitiveT *primitive); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc index 13b2d91726e..8ee50819679 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxDepthToSpaceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -44,10 +34,14 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const o attr->blockSize = static_cast(onnx_node_attr.i()); } } - - op->primitive->value.type = schema::PrimitiveType_DepthToSpace; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_DepthToSpace; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h index 9a29cc379fb..a53623b799c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h @@ -27,7 +27,7 @@ class OnnxDepthToSpaceParser : public OnnxNodeParser { OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} ~OnnxDepthToSpaceParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc index a6bdf6cbd14..b29778372e6 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxDropoutParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx DropoutParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -44,10 +34,14 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: attr->ratio = static_cast(onnx_node_attr.f()); } } - - op->primitive->value.type = schema::PrimitiveType_Dropout; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Dropout; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h index 3adbeef8d69..c2c3ca0083d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h @@ -27,7 +27,7 @@ class OnnxDropoutParser : public OnnxNodeParser { OnnxDropoutParser() : OnnxNodeParser("Dropout") {} ~OnnxDropoutParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc index b567a2950fb..bcb77b8f1b5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc @@ -19,22 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxEluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx EluParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -43,10 +34,14 @@ STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node attr->alpha = onnx_node_attr.f(); } } - - op->primitive->value.type = schema::PrimitiveType_Elu; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Elu; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h index 8e49869454f..68a7037aedf 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h @@ -27,7 +27,7 @@ class OnnxEluParser : public OnnxNodeParser { OnnxEluParser() : OnnxNodeParser("Elu") {} ~OnnxEluParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc index 87414ae823b..9c76c278e1d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -20,23 +20,13 @@ namespace mindspore { namespace lite { -STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxExpandParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ExpandParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } std::vector dst_shape; @@ -46,7 +36,7 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); if (node_iter == onnx_graph.node().end()) { MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; - return RET_ERROR; + return nullptr; } for (const auto &attrPower : node_iter->attribute()) { if (attrPower.name() == "value") { @@ -58,9 +48,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } } attr->dst_shape = dst_shape; - op->primitive->value.type = schema::PrimitiveType_BroadcastTo; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_BroadcastTo; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h index c6590d89890..7178aa20449 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h @@ -27,7 +27,7 @@ class OnnxExpandParser : public OnnxNodeParser { OnnxExpandParser() : OnnxNodeParser("Expand") {} ~OnnxExpandParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc index c01c48ce5e1..60be9c822ef 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxFlattenParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx FlattenParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } int axis = 1; @@ -49,10 +39,14 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: attr->shape.emplace_back(0); } attr->shape.emplace_back(-1); - - op->primitive->value.type = schema::PrimitiveType_Reshape; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Reshape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h index 17df69d190e..1b368f67057 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h @@ -27,7 +27,7 @@ class OnnxFlattenParser : public OnnxNodeParser { OnnxFlattenParser() : OnnxNodeParser("Fatten") {} ~OnnxFlattenParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc index 4642b95085d..3d02ca7d0d1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxGatherParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx GatherParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -45,9 +35,14 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } } - op->primitive->value.type = schema::PrimitiveType_Gather; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Gather; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h index 8200e47abca..a1768bd3982 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h @@ -27,7 +27,7 @@ class OnnxGatherParser : public OnnxNodeParser { OnnxGatherParser() : OnnxNodeParser("Gather") {} ~OnnxGatherParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc new file mode 100644 index 00000000000..884a34daffa --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_gemm_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +lite::PrimitiveC *OnnxGemmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { + MS_LOG(DEBUG) << "onnx IdentityParser"; + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul"); + if (node_parser == nullptr) { + MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; + return nullptr; + } + auto *matmul_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); + + node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd"); + if (node_parser == nullptr) { + MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; + return nullptr; + } + + auto *bias_add_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + + primitive->value.type = schema::PrimitiveType_MakeTuple; + auto primitve_c = PrimitiveC::Create(primitive.release()); + primitve_c->set_attr("MatMul", std::shared_ptr(matmul_primitive)); + primitve_c->set_attr("BiasAdd", std::shared_ptr(bias_add_primitive)); + return primitve_c; +} + +OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h similarity index 55% rename from mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h rename to mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h index ad9f66ab28a..4424d2ea6be 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tensor_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h @@ -14,26 +14,21 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H -#include "tools/common/tensor_util.h" +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" namespace mindspore { namespace lite { -class OnnxTensorParser { +class OnnxGemmParser : public OnnxNodeParser { public: - ~OnnxTensorParser() = default; - static OnnxTensorParser *GetInstance() { - static OnnxTensorParser onnxTensorParser; - return &onnxTensorParser; - } - TensorCache *GetTensorCache() { return &tensor_cache_; } + OnnxGemmParser() : OnnxNodeParser("Gemm") {} + ~OnnxGemmParser() override = default; - private: - OnnxTensorParser() = default; - TensorCache tensor_cache_; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc new file mode 100644 index 00000000000..1a0a330fdf8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h" +#include +#include +#include +#include +#include "src/param_value_lite.h" + +namespace mindspore { +namespace lite { +STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, + lite::PrimitiveC *primitive_c, + const std::vector &shape) { + ParamValueLitePtr param_value = std::make_shared(); + if (param_value == nullptr) { + MS_LOG(ERROR) << "new a paramValueLite failed."; + return RET_ERROR; + } + int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); + if (iter == onnx_node.attribute().end()) { + return RET_OK; + } + size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t); + char *param_data = new (std::nothrow) char[data_size]; + if (param_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_MEMORY_FAILED; + } + if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) { + MS_LOG(ERROR) << "memcpy data failed."; + delete[] param_data; + return RET_ERROR; + } + param_value->set_tensor_shape(shape); + param_value->set_format(schema::Format_NUM_OF_FORMAT); + param_value->set_tensor_type(kNumberTypeInt64); + param_value->SetTensorData(param_data, data_size); + primitive_c->set_attr("const_data", param_value); + return RET_OK; +} + +STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, + lite::PrimitiveC *primitive_c, + const std::vector &shape) { + ParamValueLitePtr param_value = std::make_shared(); + if (param_value == nullptr) { + MS_LOG(ERROR) << "new a paramValueLite failed."; + return RET_ERROR; + } + int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); + if (iter == onnx_node.attribute().end()) { + return RET_OK; + } + char *param_data = new (std::nothrow) char[data_count]; + if (param_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_MEMORY_FAILED; + } + if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) { + MS_LOG(ERROR) << "memcpy data failed."; + delete[] param_data; + return RET_ERROR; + } + param_value->set_tensor_shape(shape); + param_value->set_format(schema::Format_NUM_OF_FORMAT); + param_value->set_tensor_type(kNumberTypeUInt8); + param_value->SetTensorData(param_data, data_count); + primitive_c->set_attr("const_data", param_value); + return RET_OK; +} + +lite::PrimitiveC *OnnxGivenTensorFillParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { + MS_LOG(DEBUG) << "onnx GivenTensorFillParser"; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Constant; + auto primitive_c = PrimitiveC::Create(primitive.release()); + std::vector shape_vector; + auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); + if (iter != onnx_node.attribute().end()) { + shape_vector.insert(shape_vector.begin(), iter->ints().begin(), iter->ints().end()); + } + std::vector shape; + std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), + [](const int64_t &val) { return static_cast(val); }); + if (onnx_node.op_type() == "Int8GivenIntTensorFill") { + if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) { + MS_LOG(ERROR) << "given tensor fill parse failed."; + return nullptr; + } + } else if (onnx_node.op_type() == "Int8GivenTensorFill") { + if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) { + MS_LOG(ERROR) << "given tensor fill parse failed."; + return nullptr; + } + } + return primitive_c; +} + +OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser()); +OnnxNodeRegistrar g_onnxInt8GivenTensorFillParser("Int8GivenTensorFill", new OnnxGivenTensorFillParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h new file mode 100644 index 00000000000..4a55f5659f3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H + +#include +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxGivenTensorFillParser : public OnnxNodeParser { + public: + OnnxGivenTensorFillParser() : OnnxNodeParser("GivenTensorFill") {} + ~OnnxGivenTensorFillParser() override = default; + + STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, + const std::vector &shape); + STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, + const std::vector &shape); + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc index c414c64ae91..048f2f65216 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.cc @@ -20,28 +20,23 @@ namespace mindspore { namespace lite { -STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxIdentityParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx IdentityParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Identity; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Identity; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h index c3894819bca..14dad740a97 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_identity_parser.h @@ -27,7 +27,7 @@ class OnnxIdentityParser : public OnnxNodeParser { OnnxIdentityParser() : OnnxNodeParser("Identity") {} ~OnnxIdentityParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc index 34039fb6ac0..7cfe142bb74 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx InstanceNormParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } if (!onnx_node.attribute().empty()) { @@ -44,10 +34,14 @@ STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const o attr->epsilon = onnx_node_attr.f(); } } - - op->primitive->value.type = schema::PrimitiveType_InstanceNorm; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_InstanceNorm; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h index d6e2dc88a38..9979c36dabe 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_instance_norm_parser.h @@ -27,7 +27,7 @@ class OnnxInstanceNormParser : public OnnxNodeParser { OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {} ~OnnxInstanceNormParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc index bddbfe9bdf8..773d81cf375 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc @@ -18,23 +18,13 @@ #include namespace mindspore::lite { -STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxLpNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx LpNormParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -45,10 +35,14 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->p = onnx_node_attr.i(); } } - - op->primitive->value.type = schema::PrimitiveType_LpNormalization; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LpNormalization; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.h index 4c2a928864a..9fa92f8be6b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.h @@ -27,7 +27,7 @@ class OnnxLpNormParser : public OnnxNodeParser { OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} ~OnnxLpNormParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc index 267abfa8b89..83a2bb08c56 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -18,22 +18,13 @@ #include namespace mindspore::lite { -STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxLrnParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx LrnParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } int32_t size = 0; @@ -53,13 +44,18 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node if (size == 0) { MS_LOG(ERROR) << "Divide-by-zero error."; - return RET_ERROR; + return nullptr; } attr->alpha /= size; - op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h index 53e88b1975b..347d13cb175 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h @@ -27,7 +27,7 @@ class OnnxLrnParser : public OnnxNodeParser { OnnxLrnParser() : OnnxNodeParser("Lrn") {} ~OnnxLrnParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc index 7c755b83ef1..09716c1318a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.cc @@ -19,22 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxLstmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx LstmParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -44,9 +35,14 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } } - op->primitive->value.type = schema::PrimitiveType_Lstm; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Lstm; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h index c0009e1db5e..3be45c5b7e5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lstm_parser.h @@ -27,7 +27,7 @@ class OnnxLstmParser : public OnnxNodeParser { OnnxLstmParser() : OnnxNodeParser("LSTM") {} ~OnnxLstmParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc index 4c4d8c0e4c8..18fd66169f8 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxMatmulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx MatMulParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } float alpha = 1.0f; @@ -54,12 +44,17 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } if (alpha != 1 || beta != 1) { MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; - return RET_ERROR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_MatMul; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_MatMul; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h index f98a4d34791..22af92f4f34 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h @@ -27,7 +27,7 @@ class OnnxMatmulParser : public OnnxNodeParser { OnnxMatmulParser() : OnnxNodeParser("MatMul") {} ~OnnxMatmulParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index cd524287715..70cdcd2f906 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -16,18 +16,15 @@ #include "tools/converter/parser/onnx/onnx_model_parser.h" #include -#include -#include +#include #include +#include #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" namespace mindspore { namespace lite { -OnnxModelParser::OnnxModelParser() = default; -OnnxModelParser::~OnnxModelParser() = default; - static const std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8}, {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8}, @@ -39,6 +36,639 @@ static const std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; +std::set SPECIAL_NODE = {"Gemm", "Loop"}; +FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + NoSupportOp::GetInstance()->SetFmkType("ONNX"); + auto status = InitOriginModel(model_file); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "init origin model failed."; + return nullptr; + } + + func_graph_ptr_ = std::make_shared(); + if (func_graph_ptr_ == nullptr) { + MS_LOG(ERROR) << "funcgraph is nullptr."; + return nullptr; + } + + status = ConvertConstTensors(); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert const nodes failed."; + return nullptr; + } + + status = ConvertGraphInputs(); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert graph inputs failed."; + return nullptr; + } + + status = ConvertNodes(); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert nodes failed."; + return nullptr; + } + + status = ConvertGraphOutputs(); + if (RET_OK != status) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + MS_LOG(ERROR) << "convert graph outputs failed."; + return nullptr; + } + return func_graph_ptr_; +} + +STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { + auto status = ValidateFileStr(model_file, ".onnx"); + if (status != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx"; + return status; + } + + status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model_); + if (status != RET_OK) { + MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return status; + } + OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); + onnx_graph_ = onnx_model_.graph(); + return RET_OK; +} + +STATUS OnnxModelParser::ConvertConstTensors() { + for (const auto &onnx_const_value : onnx_graph_.initializer()) { + auto parameter = func_graph_ptr_->add_parameter(); + auto status = BuildParameterNode(parameter, onnx_const_value); + if (status != RET_OK) { + MS_LOG(ERROR) << "parameter node build failed."; + return status; + } + nodes_.emplace(onnx_const_value.name(), parameter); + } + return RET_OK; +} + +STATUS OnnxModelParser::ConvertGraphInputs() { + for (int i = 0; i < onnx_graph_.input().size(); ++i) { + const auto &input_value = onnx_graph_.input(i); + if (nodes_.find(input_value.name()) != nodes_.end()) { + continue; + } + auto parameter = func_graph_ptr_->add_parameter(); + auto data_type = + GetDataTypeFromOnnx(static_cast(input_value.type().tensor_type().elem_type())); + if (data_type == kTypeUnknown) { + MS_LOG(ERROR) << "not support onnx data type " + << static_cast(input_value.type().tensor_type().elem_type()); + return RET_ERROR; + } + auto type_ptr = TypeIdToType(data_type); + std::vector shape_vector; + auto onnx_shape = input_value.type().tensor_type().shape().dim(); + std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector), + [](const onnx::TensorShapeProto_Dimension &val) { return static_cast(val.dim_value()); }); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + parameter->set_abstract(abstract_tensor); + parameter->set_name(input_value.name()); + nodes_.emplace(input_value.name(), parameter); + } + return RET_OK; +} + +STATUS OnnxModelParser::ConvertNodes() { + STATUS status = RET_OK; + for (const auto &onnx_node : onnx_graph_.node()) { + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); + status = status == RET_OK ? RET_NOT_FIND_OP : status; + } + if (status != RET_OK) { + continue; + } + auto primitive_c = node_parser->ParseLitePrimitive(onnx_graph_, onnx_node); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; + status = RET_ERROR; + continue; + } + if (IsSpecialOnnxNode(onnx_node)) { + auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); + status = status == RET_OK ? status_node : status; + continue; + } + status = ConvertOpQuantParams(onnx_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; + continue; + } + // build CNode + status = BuildCNode(onnx_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed."; + } + } + return status; +} + +STATUS OnnxModelParser::ConvertGraphOutputs() { + std::vector return_inputs; + if (onnx_graph_.output_size() > 1) { + std::vector make_tuple_inputs; + auto make_tuple_prim_ptr = GetMakeTuplePrim(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; + return RET_NULL_PTR; + } + for (const auto &graph_out : onnx_graph_.output()) { + if (nodes_.find(graph_out.name()) == nodes_.end()) { + MS_LOG(ERROR) << "graph output get failed."; + return RET_ERROR; + } + auto cnode = nodes_[graph_out.name()]; + if (nullptr == cnode) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; + } + make_tuple_inputs.emplace_back(cnode); + } + auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_prim_ptr, make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); + return_inputs.emplace_back(make_tuple_cnode); + } else { + const auto &graph_out = onnx_graph_.output(0); + if (nodes_.find(graph_out.name()) == nodes_.end()) { + MS_LOG(ERROR) << "graph output get failed."; + return RET_ERROR; + } + auto cnode = nodes_[graph_out.name()]; + if (nullptr == cnode) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_NOT_FIND_OP; + } + return_inputs.emplace_back(cnode); + } + if (BuildReturnNode(return_inputs) != RET_OK) { + MS_LOG(ERROR) << "build return node failed."; + return RET_ERROR; + } + return RET_OK; +} + +STATUS OnnxModelParser::BuildReturnNode(const std::vector &return_inputs) { + auto returnPrim = GetReturnPrim(); + if (returnPrim == nullptr) { + MS_LOG(ERROR) << "GetReturnPrim return nullptr"; + return RET_NULL_PTR; + } + auto returnCnode = func_graph_ptr_->NewCNode(returnPrim, return_inputs); + returnCnode->set_fullname_with_scope("return"); + func_graph_ptr_->set_return(returnCnode); + return RET_OK; +} + +STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c) { + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr."; + return RET_NULL_PTR; + } + std::vector op_inputs; + for (const auto &input_name : onnx_node.input()) { + if (input_name.empty()) { + continue; + } + if (nodes_.find(input_name) == nodes_.end()) { + MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed."; + return RET_ERROR; + } else { + op_inputs.push_back(nodes_[input_name]); + } + } + auto new_cnode = func_graph_ptr_->NewCNode(std::shared_ptr(primitive_c), op_inputs); + new_cnode->set_fullname_with_scope(onnx_node.op_type() + "_" + onnx_node.output(0)); + auto status = BuildOpOutputs(onnx_node, new_cnode); + return status; +} + +STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode) { + if (cnode == nullptr) { + MS_LOG(ERROR) << "parameter is null, get output tensor failed."; + return RET_NULL_PTR; + } + if (onnx_node.output_size() == 1) { + auto type_ptr = TypeIdToType(kTypeUnknown); + std::vector shape_vector; + cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + nodes_.emplace(onnx_node.output(0), cnode); + } else { + AbstractBasePtrList abstract_list; + int op_idx = 0; + for (const auto &output_name : onnx_node.output()) { + std::vector shape_vector; + auto type_ptr = TypeIdToType(kTypeUnknown); + abstract_list.emplace_back(std::make_shared(type_ptr, shape_vector)); + auto tuple_get_item_prim_ptr = GetTupleGetItemPrim(); + if (tuple_get_item_prim_ptr == nullptr) { + MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; + return RET_NULL_PTR; + } + auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); + auto get_item_value = NewValueNode(MakeValue(op_idx)); + std::vector inputs{tuple_get_item_prim, cnode, get_item_value}; + CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs); + get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); + nodes_.emplace(output_name, get_item_cnode); + op_idx++; + } + cnode->set_abstract(std::make_shared(abstract_list)); + } + return RET_OK; +} + +STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c) { + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; + return RET_NULL_PTR; + } + auto status = ParseQuantParam(onnx_node); + if (status != RET_OK) { + MS_LOG(ERROR) << "parse quant param failed."; + return RET_ERROR; + } + // set input tensors + for (int i = 0; i < onnx_node.input_size(); ++i) { + const auto &input_name = onnx_node.input(i); + std::vector quant_params; + status = SetTensorQuantParam(input_name, &quant_params); + if (status != RET_OK) { + MS_LOG(ERROR) << "set input tensor quant param failed."; + return status; + } + primitive_c->AddInputQuantParam(quant_params); + } + // set out tensors + for (int i = 0; i < onnx_node.output_size(); ++i) { + const auto &output_name = onnx_node.output(i); + std::vector quant_params; + status = SetTensorQuantParam(output_name, &quant_params); + if (status != RET_OK) { + MS_LOG(ERROR) << "set output tensor quant param failed."; + return status; + } + primitive_c->AddOutputQuantParam(quant_params); + } + return RET_OK; +} + +STATUS OnnxModelParser::ParseQuantParam(const onnx::NodeProto &onnx_node) { + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "Y_scale") { + float scale = onnx_node_attr.f(); + if (BuildParameterNodeForQuantParam(&scale, "scale_" + onnx_node.output(0), kNumberTypeFloat32) != RET_OK) { + MS_LOG(ERROR) << "parse quant param failed."; + return RET_ERROR; + } + } else if (onnx_node_attr.name() == "Y_zero_point") { + int64_t zero_point = onnx_node_attr.i(); + if (BuildParameterNodeForQuantParam(&zero_point, "zero_point_" + onnx_node.output(0), kNumberTypeInt64) != + RET_OK) { + MS_LOG(ERROR) << "parse quant param failed."; + return RET_ERROR; + } + } + } + return RET_OK; +} + +STATUS OnnxModelParser::SetTensorQuantParam(const std::string &tensor_name, std::vector *quant_params) { + quant_params->clear(); + auto quant_param = std::make_unique(); + for (int i = 0; i < onnx_graph_.quantization_annotation_size(); ++i) { + auto tensor_annotation = onnx_graph_.quantization_annotation(i); + if (!tensor_annotation.has_tensor_name() || tensor_annotation.tensor_name() != tensor_name) { + continue; + } + for (const auto &item : tensor_annotation.quant_parameter_tensor_names()) { + if (!item.has_key() || !item.has_value()) { + continue; + } + + const auto &quant_tensor_name = item.value(); + if (item.key() == "SCALE_TENSOR") { + auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true); + if (status != RET_OK) { + MS_LOG(ERROR) << "quant param scale get failed"; + return status; + } + } else if (item.key() == "ZERO_POINT_TENSOR") { + auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false); + if (status != RET_OK) { + MS_LOG(ERROR) << "quant param zero_point get failed"; + return status; + } + } + } + break; + } + if (quant_param->inited) { + quant_params->push_back(*std::move(quant_param)); + return RET_OK; + } + return SetTensorQuantParamFromNode(tensor_name, quant_params); +} + +STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_name, + std::vector *quant_params) { + quant_params->clear(); + auto quant_param = std::make_unique(); + std::string quant_tensor_name = "scale_" + tensor_name; + auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true); + if (status != RET_OK) { + MS_LOG(ERROR) << "quant param scale get failed"; + return status; + } + quant_tensor_name = "zero_point_" + tensor_name; + status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false); + if (status != RET_OK) { + MS_LOG(ERROR) << "quant param zero_point get failed"; + return status; + } + if (quant_param->inited) { + quant_params->push_back(*std::move(quant_param)); + } else { + std::vector notinited_quant_params(1); + *quant_params = notinited_quant_params; + } + return RET_OK; +} + +STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, + bool scale_or_not) { + if (quant_param == nullptr) { + MS_LOG(ERROR) << "quant_param is nullptr"; + + return RET_NULL_PTR; + } + auto iter = nodes_.find(tensor_name); + if (iter == nodes_.end()) { + MS_LOG(DEBUG) << "has no quant param"; + return RET_OK; + } + if (!utils::isa(iter->second)) { + MS_LOG(ERROR) << "quant param get failed"; + return RET_ERROR; + } + auto quant_parameter_node = iter->second->cast(); + if (!quant_parameter_node->has_default()) { + MS_LOG(ERROR) << "quant param get failed"; + return RET_ERROR; + } + auto param_value_lite = quant_parameter_node->default_param()->cast(); + if (param_value_lite == nullptr) { + MS_LOG(ERROR) << "parameterNode's default param is not paramValueLite"; + return RET_ERROR; + } + if (scale_or_not) { + quant_param->scale = *reinterpret_cast(param_value_lite->tensor_addr()); + quant_param->inited = true; + } else { + quant_param->zeroPoint = *reinterpret_cast(param_value_lite->tensor_addr()); + quant_param->inited = true; + } + return RET_OK; +} + +STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c) { + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "imitive_c is nullptr."; + return RET_NULL_PTR; + } + STATUS status = RET_OK; + if (onnx_node.op_type() == "Loop") { + MS_LOG(ERROR) << "loop hasn't supported."; + return RET_NOT_FIND_OP; + } else if (onnx_node.op_type() == "Gemm") { + status = ConvertOnnxGemmNode(onnx_node, primitive_c); + } else { + MS_LOG(ERROR) << "the node is not special node."; + status = RET_ERROR; + } + delete primitive_c; + return status; +} + +STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c) { + if (onnx_node.op_type() != "Gemm") { + MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type(); + return RET_ERROR; + } + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr."; + return RET_NULL_PTR; + } + auto status = BuildCNodeForGemm(onnx_node, primitive_c, "MatMul"); + if (status != RET_OK) { + MS_LOG(ERROR) << "convert gemm node failed."; + return status; + } + status = BuildCNodeForGemm(onnx_node, primitive_c, "BiasAdd"); + if (status != RET_OK) { + MS_LOG(ERROR) << "convert gemm node failed."; + return status; + } + return RET_OK; +} + +STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, + const std::string &name) { + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr."; + return RET_NULL_PTR; + } + auto value = primitive_c->GetAttr(name); + primitive_c->EraseAttr(name); + if (value == nullptr) { + MS_LOG(ERROR) << "op parse failed."; + return RET_NULL_PTR; + } + auto prim_ptr = value->cast>(); + if (prim_ptr == nullptr) { + MS_LOG(ERROR) << "p parse failed."; + return RET_NULL_PTR; + } + auto type_ptr = TypeIdToType(kTypeUnknown); + std::vector shape_vector; + std::vector op_inputs; + if (name == "MatMul") { + for (int i = 0; i < 2; ++i) { + if (nodes_.find(onnx_node.input(i)) == nodes_.end()) { + MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed."; + return RET_ERROR; + } else { + op_inputs.push_back(nodes_[onnx_node.input(i)]); + } + } + auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); + new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); + new_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + nodes_.emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode); + } else { + if (nodes_.find("Gemm_MatMul_" + onnx_node.output(0)) == nodes_.end() || + nodes_.find(onnx_node.input(2)) == nodes_.end()) { + MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed."; + return RET_ERROR; + } + op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]); + op_inputs.push_back(nodes_[onnx_node.input(2)]); + auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); + new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); + new_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + nodes_.emplace(onnx_node.output(0), new_cnode); + } + return RET_OK; +} + +STATUS OnnxModelParser::BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type) { + if (data == nullptr) { + MS_LOG(ERROR) << "value is nullptr."; + return RET_NULL_PTR; + } + if (type != kNumberTypeInt64 && type != kNumberTypeFloat32) { + MS_LOG(ERROR) << "quant param type don't support."; + return RET_NOT_SUPPORT; + } + std::vector shape_vector; + auto parameter_node = func_graph_ptr_->add_parameter(); + auto abstract_tensor = std::make_shared(TypeIdToType(type), shape_vector); + parameter_node->set_abstract(abstract_tensor); + parameter_node->set_name(name); + std::vector shape; + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + param_value->set_tensor_shape(shape); + param_value->set_format(schema::Format_NUM_OF_FORMAT); + param_value->set_tensor_type(type); + int data_size = 0; + if (type == kNumberTypeFloat32) { + data_size = sizeof(float); + } else { + data_size = sizeof(int64_t); + } + auto *tensor_data = new (std::nothrow) char[data_size]; + if (memcpy_s(tensor_data, data_size, data, data_size) != EOK) { + MS_LOG(ERROR) << "memcpy data failed."; + delete[] tensor_data; + return RET_ERROR; + } + param_value->SetTensorData(tensor_data, data_size); + parameter_node->set_default_param(param_value); + nodes_.emplace(name, parameter_node); + return RET_OK; +} + +STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor) { + auto data_type = GetDataTypeFromOnnx(static_cast(tensor.data_type())); + if (data_type == kTypeUnknown) { + MS_LOG(ERROR) << "not support onnx data type " << static_cast(tensor.data_type()); + return RET_ERROR; + } + auto type_ptr = TypeIdToType(data_type); + std::vector shape_vector(tensor.dims().begin(), tensor.dims().end()); + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + parameter_node->set_abstract(abstract_tensor); + parameter_node->set_name(tensor.name()); + + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + std::vector shape; + std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), + [](const int64_t &value) { return static_cast(value); }); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(data_type); + param_value->set_format(schema::Format::Format_NCHW); + auto status = CopyOnnxTensorData(tensor, param_value); + if (status != RET_OK) { + MS_LOG(ERROR) << "copy data failed."; + return status; + } + parameter_node->set_default_param(param_value); + return RET_OK; +} + +STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, + const ParamValueLitePtr ¶m_value_lite) { + if (param_value_lite == nullptr) { + MS_LOG(ERROR) << "param_value_lite is nullptr."; + return RET_NULL_PTR; + } + size_t data_count = 1; + std::for_each(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end(), + [&data_count](int dim) { data_count *= dim; }); + size_t data_size = 0; + const void *onnx_data = nullptr; + auto data_type = GetDataTypeFromOnnx(static_cast(onnx_const_tensor.data_type())); + switch (data_type) { + case kNumberTypeFloat32: + data_size = data_count * sizeof(float); + if (onnx_const_tensor.float_data_size() == 0) { + onnx_data = onnx_const_tensor.raw_data().data(); + } else { + onnx_data = onnx_const_tensor.float_data().data(); + } + break; + case kNumberTypeInt32: + data_size = data_count * sizeof(int); + if (onnx_const_tensor.int32_data_size() == 0) { + onnx_data = onnx_const_tensor.raw_data().data(); + } else { + onnx_data = onnx_const_tensor.int32_data().data(); + } + break; + case kNumberTypeInt64: + data_size = data_count * sizeof(int64_t); + if (onnx_const_tensor.int64_data_size() == 0) { + onnx_data = onnx_const_tensor.raw_data().data(); + } else { + onnx_data = onnx_const_tensor.int64_data().data(); + } + break; + case kNumberTypeUInt8: + case kNumberTypeInt8: + case kNumberTypeBool: + data_size = data_count * sizeof(uint8_t); + onnx_data = onnx_const_tensor.raw_data().data(); + break; + default: + MS_LOG(ERROR) << "unsupported data type " << data_type; + return RET_ERROR; + } + if (data_size == 0) { + return RET_OK; + } + char *param_data = new (std::nothrow) char[data_size]; + if (param_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_MEMORY_FAILED; + } + if (memcpy_s(static_cast(param_data), data_size, onnx_data, data_size) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + delete[] param_data; + return RET_ERROR; + } + param_value_lite->SetTensorData(param_data, data_size); + return RET_OK; +} + +bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) { + return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end(); +} + TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { auto iter = TYPE_MAP.find(onnx_type); if (iter == TYPE_MAP.end()) { @@ -47,621 +677,5 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type } return iter->second; } - -std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value) { - std::vector dims; - for (const auto &it : onnx_value.type().tensor_type().shape().dim()) { - dims.emplace_back(it.dim_value()); - } - return dims; -} - -STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph) { - MS_LOG(DEBUG) << "set onnx constant tensors"; - for (const auto &onnx_const_value : onnx_graph.initializer()) { - int index; - const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, &index); - if (status != RET_OK) { - return status; - } - MS_LOG(DEBUG) << "add const tensor: " << onnx_const_value.name() << ", index " << index; - } - MS_LOG(DEBUG) << "process onnx Constant ops"; - for (int i = 0; i < onnx_graph.node_size(); i++) { - const auto &node = onnx_graph.node(i); - if (node.op_type().compare("Constant") == 0) { - for (const auto &attr : node.attribute()) { - if (attr.name() == "sparse_value") { - MS_LOG(ERROR) << "sparse_value"; - } - if (attr.name() == "value") { - const auto &t = attr.t(); - int index; - const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, &index); - if (status != RET_OK) { - return status; - } - MS_LOG(DEBUG) << "add const tensor: " << t.name() << ", index " << index; - } else { - MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; - return RET_INVALID_OP_ATTR; - } - } - } - } - return RET_OK; -} - -STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, - int *index) { - auto data_type = GetDataTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); - if (data_type == kTypeUnknown) { - MS_LOG(ERROR) << "not support onnx data type " - << static_cast(proto.type().tensor_type().elem_type()); - return RET_ERROR; - } - std::unique_ptr tensor = std::make_unique(); - if (tensor == nullptr) { - MS_LOG(ERROR) << "new tensor failed"; - return RET_ERROR; - } - tensor->dataType = data_type == kNumberTypeInt64 ? kNumberTypeInt32 : data_type; - tensor->dims = GetDimsFromOnnxValue(proto); - tensor->format = schema::Format::Format_NCHW; - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); - return RET_OK; -} - -STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, - int *index) { - auto data_type = GetDataTypeFromOnnx(static_cast(proto.data_type())); - if (data_type == kTypeUnknown) { - MS_LOG(ERROR) << "not support onnx data type " << static_cast(proto.data_type()); - return RET_ERROR; - } - - std::unique_ptr tensor = std::make_unique(); - if (tensor == nullptr) { - MS_LOG(ERROR) << "new tensor failed"; - return RET_ERROR; - } - tensor->dataType = data_type; - std::copy(proto.dims().begin(), proto.dims().end(), std::back_inserter(tensor->dims)); - tensor->format = schema::Format::Format_NCHW; - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - if (CopyOnnxTensorData(proto, tensor.get())) { - MS_LOG(ERROR) << "copy onnx data failed"; - return RET_ERROR; - } - if (data_type == kNumberTypeInt64) { - tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32 - } - *index = OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(name, tensor.release(), type); - return RET_OK; -} - -STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { - for (const auto &input_value : onnx_graph.input()) { - auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(input_value.name()); - if (ret < 0) { - int index; - const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, &index); - if (status != RET_OK) { - return status; - } - MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index; - graph->inputIndices.emplace_back(static_cast(index)); - } - } - return RET_OK; -} - -STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph) { - for (const auto &output_value : onnx_graph.output()) { - int index; - if (OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()) != -1) { - index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(output_value.name()); - } else { - const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, &index); - if (status != RET_OK) { - return status; - } - } - graph->outputIndices.emplace_back(index); - MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index; - } - return RET_OK; -} - -void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, - const QuantType &quant_type) { - std::unique_ptr dst_op_1 = std::make_unique(); - dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); - dst_op_1->quantType = quant_type; - ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); - auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); - std::vector matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; - std::vector matmul_outputs{matmul_output_id}; - SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node); - SetOpOutputIndex(matmul_outputs, dst_op_1.get()); - graph->nodes.emplace_back(std::move(dst_op_1)); - sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); - - std::unique_ptr dst_op_2 = std::make_unique(); - dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0); - dst_op_2->quantType = quant_type; - ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); - std::vector biasadd_inputs{matmul_output_id, onnx_node.input(2)}; - std::vector biasadd_outputs{onnx_node.output(0)}; - SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node); - SetOpOutputIndex(biasadd_outputs, dst_op_2.get()); - graph->nodes.emplace_back(std::move(dst_op_2)); - sub_graph->nodeIndices.push_back(graph->nodes.size() - 1); -} - -STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node) { - // convert GivenTensorFill node to a weight/bias tensor - auto ret = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node.output(0)); - if (ret < 0) { - std::unique_ptr tensor = std::make_unique(); - std::vector shape; - auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), - [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); - if (iter != onnx_node.attribute().end()) { - (void)shape.insert(shape.begin(), iter->ints().begin(), iter->ints().end()); - std::for_each(shape.begin(), shape.end(), [](int sh) { MS_LOG(DEBUG) << "shape: " << sh; }); - } - tensor->dims = shape; - tensor->format = schema::Format::Format_NUM_OF_FORMAT; - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), - [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); - // copy GivenIntTensorFill node value to tensor - if (iter != onnx_node.attribute().end()) { - size_t data_count = 1; - std::for_each(shape.begin(), shape.end(), [&data_count](int dim) { data_count *= dim; }); - size_t data_size = 0; - if (onnx_node.op_type() == "Int8GivenIntTensorFill") { - tensor->dataType = kNumberTypeInt32; - data_size = data_count * sizeof(int32_t) / sizeof(uint8_t); - tensor->data.resize(data_size); - void *tensorData = tensor->data.data(); - auto castedTensorData = static_cast(tensorData); - MS_ASSERT(castedTensorData != nullptr); - for (size_t i = 0; i < data_count; i++) { - castedTensorData[i] = int32_t(iter->ints().data()[i]); - } - } else if (onnx_node.op_type() == "Int8GivenTensorFill") { - tensor->dataType = kNumberTypeUInt8; - data_size = data_count; - tensor->data.resize(data_size); - MS_LOG(DEBUG) << "tensor data size " << data_size << ", s: " << sizeof(iter->s().data()); - if (memcpy_s(tensor->data.data(), data_size, iter->s().data(), data_size) != 0) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - // set quantParams to Int8GivenTensor. - std::unique_ptr quant_param = std::make_unique(); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - if (onnx_node_attr.name() == "Y_scale") { - quant_param->scale = onnx_node_attr.f(); - } else if (onnx_node_attr.name() == "Y_zero_point") { - quant_param->zeroPoint = static_cast(onnx_node_attr.i()); - } - } - quant_param->inited = true; - tensor->quantParams.emplace_back(std::move(quant_param)); - } else { - MS_LOG(ERROR) << "unsupported data type " << tensor->dataType; - return RET_ERROR; - } - } - auto index = - OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); - MS_LOG(DEBUG) << "add given tensor: " << index; - } - return RET_OK; -} - -STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, const QuantType &quantType, - schema::MetaGraphT *dst_graph) { - // change op_type() to name(), that is unique - static bool interrupt = false; - dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); - dst_op->quantType = quantType; - // dst_op->fmkType = FmkType_ONNX; - MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size " - << onnx_node.input_size(); - // get the real op type - if (onnx_node.op_type() == "Loop") { - NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); - interrupt = true; - return RET_NOT_FIND_OP; - int status = ParseSubgraph(dst_op, onnx_node, quantType, dst_graph); - if (status != RET_OK || interrupt) { - interrupt = true; - return status; - } - } else { - auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); - if (node_parser == nullptr || interrupt) { - interrupt = true; - if (node_parser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); - } - return RET_NOT_FIND_OP; - } - auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); - if (status != RET_OK) { - interrupt = true; - if (status == RET_NOT_FIND_OP) { - NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); - } else { - MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; - } - return status; - } - } - // set op input index - std::vector node_inputs; - (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); - if (SetOpInputIndex(node_inputs, dst_op, onnx_node)) { - interrupt = true; - MS_LOG(ERROR) << "SetOpInputIndex failed"; - return RET_ERROR; - } - if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) { - auto &weight_tensor = - OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); - weight_tensor->format = dst_op->primitive->value.AsConv2D()->format; - } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) { - auto &weight_tensor = - OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex)); - weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format; - } - // set op output index - std::vector node_outputs; - (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); - - if (SetOpOutputIndex(node_outputs, dst_op) != RET_OK) { - interrupt = true; - MS_LOG(ERROR) << "SetOpOutputIndex failed"; - return RET_ERROR; - } - auto &output_tensor = - OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().at(dst_op->outputIndex.front()); - if (output_tensor == nullptr) { - interrupt = true; - MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr."; - return RET_ERROR; - } - SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor); - return RET_OK; -} - -void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, schema::TensorT *dst_tensor) { - MS_ASSERT(dst_op != nullptr); - std::vector quant_node_name; - quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end()); - quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end()); - std::vector quant_node; - for (const auto &str : quant_node_name) { - for (auto &node : onnx_graph.node()) { - if (node.output(0) == str) { - quant_node.emplace_back(node); - break; - } - } - } - auto needQuantParams = size_t(onnx_node.input().size() + onnx_node.output().size()); - for (auto iter = onnx_node.input().begin(); iter != onnx_node.input().end(); iter++) { - if (IsContain(this->graphInputNames, *iter)) { - needQuantParams--; - } - } - size_t findQuantParams = 0; - for (const auto &node : quant_node) { - std::unique_ptr quant_param = std::make_unique(); - if (quant_param == nullptr) { - MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name; - return; - } - int argNum = 0; - for (const auto &onnx_node_attr : node.attribute()) { - if (onnx_node_attr.name() == "Y_scale") { - quant_param->scale = onnx_node_attr.f(); - argNum++; - } else if (onnx_node_attr.name() == "Y_zero_point") { - quant_param->zeroPoint = static_cast(onnx_node_attr.i()); - argNum++; - } - } - if (argNum != 2) { - continue; - } - dst_tensor->quantParams.emplace_back(std::move(quant_param)); - if (argNum == 2) { - findQuantParams++; - } - } - if (findQuantParams == needQuantParams) { - dst_op->quantType = schema::QuantType_AwareTraining; - } -} - -STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - const string &onnx_op_type, schema::CNodeT *dst_op) { - auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); - if (node_parser == nullptr) { - return RET_NOT_FIND_OP; - } - return node_parser->Parse(onnx_graph, onnx_node, dst_op); -} - -STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node) { - for (const auto &onnx_node_input : node_inputs) { - if (onnx_node_input != "") { - int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_input); - if (index < 0) { - MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found"; - return RET_ERROR; - } - MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index; - dst_op->inputIndex.emplace_back(index); - } - } - return RET_OK; -} - -STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op) { - for (const auto &onnx_node_output : node_outputs) { - auto index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_node_output); - if (index < 0) { // when index >= 0, it's graph's output - std::unique_ptr tensor = std::make_unique(); - tensor->nodeType = schema::NodeType_Parameter; - index = - OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); - } - MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index; - dst_op->outputIndex.emplace_back(index); - } - return RET_OK; -} - -STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { - size_t data_count = 1; - std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); - size_t data_size = 0; - const void *tensor_data = nullptr; - std::unique_ptr buffer; - switch (tensor->dataType) { - case kNumberTypeFloat32: - data_size = data_count * sizeof(float); - if (onnx_const_value.float_data_size() == 0) { - tensor_data = onnx_const_value.raw_data().data(); - } else { - tensor_data = onnx_const_value.float_data().data(); - } - break; - case kNumberTypeInt32: - data_size = data_count * sizeof(int); - if (onnx_const_value.int32_data_size() == 0) { - tensor_data = onnx_const_value.raw_data().data(); - } else { - tensor_data = onnx_const_value.int32_data().data(); - } - break; - case kNumberTypeInt64: - data_size = data_count * sizeof(int32_t); - buffer = std::make_unique(data_count); - const int64_t *in_data; - in_data = nullptr; - if (onnx_const_value.int64_data_size() == 0) { - in_data = reinterpret_cast(onnx_const_value.raw_data().data()); - } else { - in_data = onnx_const_value.int64_data().data(); - } - for (size_t i = 0; i < data_count; ++i) { - if (in_data[i] > static_cast(INT32_MAX) || in_data[i] < static_cast(INT32_MIN)) { - MS_LOG(WARNING) << "int64 data " << in_data[i] << "too big to fit into int32"; - buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN; - } else { - buffer[i] = static_cast(in_data[i]); - } - } - tensor_data = reinterpret_cast(buffer.get()); - break; - case kNumberTypeUInt8: - case kNumberTypeInt8: - case kNumberTypeBool: - data_size = data_count * sizeof(uint8_t); - tensor_data = onnx_const_value.raw_data().data(); - break; - default: - MS_LOG(ERROR) << "unsupported data type " << tensor->dataType; - return RET_ERROR; - } - tensor->data.resize(data_size); - if (data_size != 0 && memcpy_s(static_cast(tensor->data.data()), data_size, tensor_data, data_size) != 0) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - return RET_OK; -} - -STATUS OnnxModelParser::SetAllTensors(schema::MetaGraphT *graphDef) { - std::vector tensors = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor(); - for (auto iter : tensors) { - std::unique_ptr temp(iter); - graphDef->allTensors.emplace_back(move(temp)); - } - return RET_OK; -} - -void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) { - this->graphInputNames.clear(); - this->graphConstNames.clear(); - for (auto &onnx_const : onnx_graph.initializer()) { - this->graphConstNames.emplace_back(onnx_const.name()); - } - for (auto &onnx_input : onnx_graph.input()) { - if (!IsContain(this->graphConstNames, onnx_input.name())) { - this->graphInputNames.emplace_back(onnx_input.name()); - } - } -} - -STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, - const QuantType &quantType, schema::MetaGraphT *dst_graph) { - MS_LOG(DEBUG) << "onnx LoopParser"; - if (dst_op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - dst_op->primitive = std::make_unique(); - if (dst_op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - attr->subGraphIndex = subGraphNum; - auto sub_graph = std::make_unique(); - int ret = ParseGraph(dst_graph, sub_graph.get(), onnx_node.attribute().at(0).g(), quantType); - dst_graph->subGraph.push_back(std::move(sub_graph)); - subGraphNum += 1; - if (ret != RET_OK) { - return ret; - } - dst_op->primitive->value.type = schema::PrimitiveType_Loop; - dst_op->primitive->value.value = attr.release(); - return RET_OK; -} - -int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, - const onnx::GraphProto &onnx_graph, const QuantType &quantType) { - // dst_graph->name = onnx_graph.name(); // this is not used - // find out input names and const names - FindGraphInputAndConst(onnx_graph); - // set const tensor - int status = SetGraphConstTensor(onnx_graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetGraphConstTensor failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return RET_ERROR; - } - - // init onnx model graph input tensor - - status = SetGraphInputTensor(onnx_graph, dst_sub_graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetGraphInputTensor failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return RET_ERROR; - } - - // init op node input/output tensor, and dst_op attr - NoSupportOp::GetInstance()->SetFmkType("ONNX"); - for (const auto &onnx_node : onnx_graph.node()) { - int status_node = RET_OK; - if (onnx_node.op_type() == "Constant") { - continue; - } - if (onnx_node.op_type() == "Gemm") { - if (status == RET_OK) { - ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, quantType); - } - continue; - } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { - if (status == RET_OK) { - status_node = ParseOnnxGivenFillNode(onnx_node); - if (status_node != RET_OK) { - MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node; - status = (status == RET_OK ? status_node : status); - } - } - continue; - } - - std::unique_ptr dst_op = std::make_unique(); - status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), quantType, dst_graph); - if (status_node != RET_OK) { - status = (status == RET_OK ? status_node : status); - continue; - } - dst_graph->nodes.emplace_back(std::move(dst_op)); - dst_sub_graph->nodeIndices.push_back((dst_graph->nodes.size() - 1)); - } - if (status != RET_OK) { - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - for (auto &tensor : OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()) { - delete tensor; - } - return RET_ERROR; - } - // init onnx model graph output tensor - status = SetGraphOutputTensor(onnx_graph, dst_sub_graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetGraphOutputTensor failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return RET_ERROR; - } - SetAllTensors(dst_graph); - return RET_OK; -} - -schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - int status = ValidateFileStr(model_file, ".onnx"); - if (status != RET_OK) { - MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - - onnx::ModelProto onnx_model; - status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model); - if (status != RET_OK) { - MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } - OnnxNodeParser::set_opset_version(onnx_model.opset_import().Get(0).version()); - const onnx::GraphProto &onnx_graph = onnx_model.graph(); - MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name(); - - auto dst_graph = std::make_unique(); - auto dst_sub_graph = std::make_unique(); - int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type); - dst_graph->subGraph.push_back(std::move(dst_sub_graph)); - subGraphNum += 1; - if (ret == RET_ERROR) { - return nullptr; - } - dst_graph->name = GetModelName(model_file); - - std::vector input_temp_index; - for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) { - input_temp_index.push_back(dst_graph->subGraph.front()->inputIndices[i]); - } - dst_graph->inputIndex = input_temp_index; - - std::vector output_temp_index; - for (size_t i = 0; i < dst_graph->subGraph.front()->outputIndices.size(); i++) { - output_temp_index.push_back(dst_graph->subGraph.front()->outputIndices[i]); - } - dst_graph->outputIndex = output_temp_index; - - return dst_graph.release(); -} - } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 265d434587b..f56a8e6b752 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -26,75 +26,57 @@ #include #include #include -#include +#include #include "securec/include/securec.h" #include "tools/converter/model_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" -#include "tools/converter/parser/onnx/onnx_tensor_parser.h" #include "proto/onnx.pb.h" +#include "src/param_value_lite.h" namespace mindspore { namespace lite { class OnnxModelParser : public ModelParser { public: - OnnxModelParser(); + OnnxModelParser() = default; - virtual ~OnnxModelParser(); + ~OnnxModelParser() override = default; - // schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE); - int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, - const QuantType &quantType); + MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override { + return nullptr; + } + FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, + const ParamValueLitePtr ¶m_value_lite); private: - schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type = QuantType_QUANT_NONE) override; + STATUS InitOriginModel(const std::string &model_file); + STATUS ConvertNodes(); + STATUS ConvertConstTensors(); + STATUS ConvertGraphInputs(); + STATUS ConvertGraphOutputs(); + STATUS BuildReturnNode(const std::vector &return_inputs); + STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor); + STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type); + STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); + STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode); + STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); + STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); + STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name); + STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); + STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); + STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector *quant_params); + STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector *quant_params); + STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); + bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node); - std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); - - STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); - - STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); - - STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); - - STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index); - - STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index); - - STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph); - - void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type); - - STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node); - - STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - const string &onnx_op_type, schema::CNodeT *dst_op); - - void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, - schema::TensorT *dst_tensor); - - STATUS SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node); - - STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op); - - STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); - - STATUS SetAllTensors(schema::MetaGraphT *graphDef); - - void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); - - STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, - schema::MetaGraphT *dst_graph); - - private: - std::vector graphInputNames; - std::vector graphConstNames; - int subGraphNum = 0; + onnx::ModelProto onnx_model_; + onnx::GraphProto onnx_graph_; + std::unordered_map nodes_; + FuncGraphPtr func_graph_ptr_ = nullptr; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc index 4ffc5fe9fdc..4a1083a43d0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -15,7 +15,9 @@ */ #include "tools/converter/parser/onnx/onnx_node_parser.h" +#include #include +#include #include "tools/converter/parser/onnx/onnx_model_parser.h" namespace mindspore { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h index f428bbc0c79..222d972cf7b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h @@ -20,6 +20,7 @@ #include #include #include +#include "src/ops/primitive_c.h" #include "google/protobuf/message.h" #include "proto/onnx.pb.h" #include "include/errorcode.h" @@ -34,7 +35,8 @@ class OnnxNodeParser { virtual ~OnnxNodeParser() = default; - virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; + virtual lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) = 0; static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector *value, int *type); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc index 62ab3a78983..8c4979e0cda 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxNonMaxSuppressionParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx EluParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -47,9 +37,14 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co } } - op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_NonMaxSuppression; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h index 12fb03fba84..3f20c01296a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.h @@ -27,7 +27,7 @@ class OnnxNonMaxSuppressionParser : public OnnxNodeParser { OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {} ~OnnxNonMaxSuppressionParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc index b001199d984..0ef4932291c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxOneHotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx OneHotParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -45,9 +35,14 @@ STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } } - op->primitive->value.type = schema::PrimitiveType_OneHot; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_OneHot; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h index 86446859d52..394502e1306 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h @@ -27,7 +27,7 @@ class OnnxOneHotParser : public OnnxNodeParser { OnnxOneHotParser() : OnnxNodeParser("OneHot") {} ~OnnxOneHotParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc index 7a26ab8f540..106a35e7c00 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -19,22 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxPadParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx PadParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -58,9 +49,14 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node } } - op->primitive->value.type = schema::PrimitiveType_Pad; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Pad; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h index 8b1bb7adad7..4cdb8a0223e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h @@ -27,7 +27,7 @@ class OnnxPadParser : public OnnxNodeParser { OnnxPadParser() : OnnxNodeParser("Pad") {} ~OnnxPadParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc index 63bce297492..cdfb103f580 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -19,22 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxPoolParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx PoolParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->format = schema::Format::Format_NCHW; @@ -56,7 +47,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod attr->global = false; } else { MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; - return RET_ERROR; + return nullptr; } attr->roundMode = schema::RoundMode_FLOOR; @@ -101,13 +92,18 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } if (attribute_name == "dilations") { MS_LOG(ERROR) << "pooling op not support dilations now"; - return RET_ERROR; + return nullptr; } } - op->primitive->value.type = schema::PrimitiveType_Pooling; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Pooling; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h index c4aee0398c0..4d864358b75 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h @@ -27,7 +27,7 @@ class OnnxPoolParser : public OnnxNodeParser { OnnxPoolParser() : OnnxNodeParser("Pool") {} ~OnnxPoolParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc index bfb45dd89a8..3c73e29f75a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxQuantizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed."; - return RET_NULL_PTR; + return nullptr; } if (onnx_node.op_type() == "Int8Quantize") { attr->srcT = kNumberTypeFloat32; @@ -45,11 +35,16 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: attr->dstT = kNumberTypeFloat32; } else { MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); - return RET_ERROR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_QuantDTypeCast; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h index 182a38b6972..fdaf0b158bc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h @@ -27,7 +27,7 @@ class OnnxQuantizeParser : public OnnxNodeParser { OnnxQuantizeParser() : OnnxNodeParser("Quantize") {} ~OnnxQuantizeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc index 9bc619a4cb7..555e47e64c4 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.cc @@ -19,28 +19,23 @@ namespace mindspore { namespace lite { -STATUS OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxRangeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx RangeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->dType = 0; - op->primitive->value.type = schema::PrimitiveType_Range; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Range; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.h index f565f153a80..cdc02d32c85 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_range_parser.h @@ -27,7 +27,7 @@ class OnnxRangeParser : public OnnxNodeParser { OnnxRangeParser() : OnnxNodeParser("Range") {} ~OnnxRangeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc index 74eec741536..f0180d6bfc3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxReduceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ReduceParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->keepDims = 1; @@ -65,12 +55,17 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->mode = schema::ReduceMode_ReduceSumSquare; } else { MS_LOG(ERROR) << "unsupported type"; - return RET_ERROR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Reduce; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Reduce; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h index 8dc803155b9..412200b2270 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h @@ -27,7 +27,7 @@ class OnnxReduceParser : public OnnxNodeParser { OnnxReduceParser() : OnnxNodeParser("Reduce") {} ~OnnxReduceParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc index 54d91546131..cb80242c35f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -21,22 +21,13 @@ namespace mindspore { namespace lite { -STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ReluParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } const auto &relu_type = onnx_node.op_type(); @@ -54,29 +45,24 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } } - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } -STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx PReluParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - if (onnx_node.input_size() != 2) { MS_LOG(ERROR) << "input num should be 2"; - return RET_ERROR; + return nullptr; } - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); std::vector params; const auto &input_name = onnx_node.input(1); for (const auto &it : onnx_graph.initializer()) { @@ -90,7 +76,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No const onnx::TensorProto *slope = ¶ms[0]; if (slope == nullptr) { MS_LOG(ERROR) << "input error: params[0] is null"; - return RET_ERROR; + return nullptr; } const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); const int64_t slope_size = slope->raw_data().size() / sizeof(float); @@ -102,16 +88,21 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No attr->channelShared = false; if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; + return nullptr; } } } else { MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; } - op->primitive->value.type = schema::PrimitiveType_PReLU; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_PReLU; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h index 95d4303c413..0672da099ba 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h @@ -27,7 +27,7 @@ class OnnxReluParser : public OnnxNodeParser { OnnxReluParser() : OnnxNodeParser("Relu") {} ~OnnxReluParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; class OnnxPReluParser : public OnnxNodeParser { @@ -35,7 +35,7 @@ class OnnxPReluParser : public OnnxNodeParser { OnnxPReluParser() : OnnxNodeParser("Prelu") {} ~OnnxPReluParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index 81e91427329..a9407cfa50a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -20,23 +20,13 @@ namespace mindspore { namespace lite { -STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxReshapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ReshapeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->format = schema::Format_NCHW; @@ -51,28 +41,17 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: } } } - } else { - onnx::TensorProto input_shape; - const auto &shape_name = onnx_node.input(1); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == shape_name) { - input_shape = it; - break; - } - } - if (input_shape.int64_data_size() == 0) { - MS_LOG(INFO) << "shape maybe from another op other than const initializer"; - } else { - for (int i = 0; i < input_shape.int64_data_size(); ++i) { - shape.push_back(input_shape.int64_data(i)); - } - } } attr->shape = shape; - op->primitive->value.type = schema::PrimitiveType_Reshape; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Reshape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h index 897cbf914e8..411329762ae 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h @@ -27,7 +27,7 @@ class OnnxReshapeParser : public OnnxNodeParser { OnnxReshapeParser() : OnnxNodeParser("Reshape") {} ~OnnxReshapeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc index 7f417868a0a..0b9437e5c0a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.cc @@ -22,23 +22,13 @@ namespace mindspore { namespace lite { -STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxResizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ResizeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->format = schema::Format_NCHW; @@ -85,9 +75,14 @@ STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } } - op->primitive->value.type = schema::PrimitiveType_Resize; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Resize; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h index cdc12378f15..7bb19e84a86 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_resize_parser.h @@ -27,7 +27,7 @@ class OnnxResizeParser : public OnnxNodeParser { OnnxResizeParser() : OnnxNodeParser("Resize") {} ~OnnxResizeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc index 83cf58ba410..052c72f7034 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -19,28 +19,23 @@ namespace mindspore { namespace lite { -STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx ShapeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Shape; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Shape; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h index 22879bbfbf6..3da6eed628c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h @@ -27,7 +27,7 @@ class OnnxShapeParser : public OnnxNodeParser { OnnxShapeParser() : OnnxNodeParser("Shape") {} ~OnnxShapeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc index e8a344bcfb2..956d8936fbc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc @@ -19,30 +19,25 @@ namespace mindspore { namespace lite { -STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSigmoidParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SigmoidParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->type = schema::ActivationType_SIGMOID; - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Activation; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h index 9980373d73c..c131af9fb7c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h @@ -27,7 +27,7 @@ class OnnxSigmoidParser : public OnnxNodeParser { OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} ~OnnxSigmoidParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 5fa0d40d63a..facbc955c44 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -23,77 +23,13 @@ namespace mindspore { namespace lite { -STATUS OnnxSliceParser::InsertTensor(const std::vector &onnx_val, const std::string &name, - onnx::NodeProto *onnx_node) { - std::unique_ptr tensor = std::make_unique(); - if (tensor == nullptr) { - MS_LOG(ERROR) << "new tensor failed"; - return RET_ERROR; - } - tensor->dataType = mindspore::kNumberTypeInt32; - tensor->dims.push_back(onnx_val.size()); - tensor->format = schema::Format::Format_NCHW; - tensor->nodeType = schema::NodeType::NodeType_ValueNode; - int data_size = sizeof(int32_t) * onnx_val.size(); - tensor->data.resize(data_size); - if (data_size != 0 && - memcpy_s(static_cast(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size(); - std::string tensor_name = name + std::to_string(tensor_num); - OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT); - onnx_node->add_input(tensor_name); - return RET_OK; -} - -STATUS OnnxSliceParser::GetInputTensor(std::vector *onnx_val, const std::string &name) { - if (onnx_val == nullptr) { - MS_LOG(ERROR) << "input vector is nullptr."; - return RET_ERROR; - } - if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) { - MS_LOG(ERROR) << "cannot get tensorcache."; - return RET_ERROR; - } - int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name); - if (index == -1) { - MS_LOG(ERROR) << "can not find node: " << name; - return RET_ERROR; - } - auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; - if (input_tensor->data.empty()) { - MS_LOG(DEBUG) << "data is empty."; - return RET_NO_CHANGE; - } - int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies()); - onnx_val->resize(data_num); - if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) != - EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_ERROR; - } - return RET_OK; -} - -STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSliceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SliceParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } std::vector starts; @@ -128,36 +64,17 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } } - int status = RET_OK; - switch (onnx_node.input_size()) { - case 5: { - if (steps.empty()) { - status = GetInputTensor(&steps, onnx_node.input(4)); - } - } - case 4: { - if (status != RET_ERROR && axes.empty()) { - status = GetInputTensor(&axes, onnx_node.input(3)); - } - } - case 3: { - if (status != RET_ERROR && ends.empty()) { - status = GetInputTensor(&ends, onnx_node.input(2)); - } - } - case 2: { - if (status != RET_ERROR && starts.empty()) { - status = GetInputTensor(&starts, onnx_node.input(1)); - } - } - default: { - if (status == RET_ERROR) { - MS_LOG(ERROR) << "onnx slice inputs are invalid."; - return RET_INPUT_TENSOR_ERROR; - } - } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_StridedSlice; + primitive->value.value = attr.release(); + auto primitive_c = PrimitiveC::Create(primitive.release()); + if (starts.empty()) { + return primitive_c; } - if (axes.empty()) { for (size_t i = 0; i < starts.size(); ++i) { axes.push_back(i); @@ -166,42 +83,11 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No if (steps.empty()) { steps.assign(starts.size(), 1); } - onnx::NodeProto *slice_node = nullptr; - for (auto &node : onnx_graph.node()) { - if (&node == &onnx_node) { - slice_node = const_cast(&node); - } - } - int insert_num = 5 - onnx_node.input_size(); - switch (insert_num) { - case 4: { - std::string name = "slice/starts/"; - status = InsertTensor(starts, name, slice_node); - } - case 3: - if (status == RET_OK) { - std::string name = "slice/ends/"; - status = InsertTensor(ends, name, slice_node); - } - case 2: - if (status == RET_OK) { - std::string name = "slice/axes/"; - status = InsertTensor(axes, name, slice_node); - } - case 1: - if (status == RET_OK) { - std::string name = "slice/steps/"; - status = InsertTensor(steps, name, slice_node); - } - default: - if (status != RET_OK) { - MS_LOG(ERROR) << "onnx slice insert tensor failed"; - return RET_ERROR; - } - } - op->primitive->value.type = schema::PrimitiveType_StridedSlice; - op->primitive->value.value = attr.release(); - return RET_OK; + primitive_c->set_attr("starts", MakeValue>(starts)); + primitive_c->set_attr("ends", MakeValue>(ends)); + primitive_c->set_attr("axes", MakeValue>(axes)); + primitive_c->set_attr("steps", MakeValue>(steps)); + return primitive_c; } OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h index 7bf60dfcc29..210fd4f3a0e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -21,7 +21,6 @@ #include #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" -#include "tools/converter/parser/onnx/onnx_tensor_parser.h" namespace mindspore { namespace lite { @@ -30,9 +29,7 @@ class OnnxSliceParser : public OnnxNodeParser { OnnxSliceParser() : OnnxNodeParser("Slice") {} ~OnnxSliceParser() override = default; - STATUS InsertTensor(const std::vector &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; - STATUS GetInputTensor(std::vector *onnx_val, const std::string &name); + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc index 5e136e1583f..5facba0cc7f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSoftMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SoftMaxParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } bool axis_is_def = true; @@ -53,9 +43,14 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: attr->axis = 1; } } - op->primitive->value.type = schema::PrimitiveType_SoftMax; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_SoftMax; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h index 6328f8c1019..d60346f65f3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h @@ -27,7 +27,7 @@ class OnnxSoftMaxParser : public OnnxNodeParser { OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} ~OnnxSoftMaxParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc index acfd7ec72ae..d404fe7285f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSpaceToDepthParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -45,9 +35,14 @@ STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const o } } - op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_SpaceToDepth; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h index a3783de172d..daff7831a7c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h @@ -27,7 +27,7 @@ class OnnxSpaceToDepthParser : public OnnxNodeParser { OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} ~OnnxSpaceToDepthParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc index 8ee86e99c57..84a515eaa69 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSplitParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SplitParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->splitDim = 0; @@ -51,9 +41,14 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } - op->primitive->value.type = schema::PrimitiveType_Split; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Split; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h index e3f3787c4f6..bd6fe288ecf 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_split_parser.h @@ -27,7 +27,7 @@ class OnnxSplitParser : public OnnxNodeParser { OnnxSplitParser() : OnnxNodeParser("Split") {} ~OnnxSplitParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc index 91012d71b0a..2c8a14b56f5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx SqueezeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -47,9 +37,14 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: } } - op->primitive->value.type = schema::PrimitiveType_Squeeze; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Squeeze; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h index 12f98d6b4b6..fef408d8bbc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h @@ -27,7 +27,7 @@ class OnnxSqueezeParser : public OnnxNodeParser { OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} ~OnnxSqueezeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc index bdde1605f83..f875e312df9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -20,26 +20,22 @@ namespace mindspore { namespace lite { -STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxTileParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TileParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } - op->primitive->value.type = schema::PrimitiveType_Tile; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Tile; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h index c5204626f54..1117c34bba1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h @@ -27,7 +27,7 @@ class OnnxTileParser : public OnnxNodeParser { OnnxTileParser() : OnnxNodeParser("Tile") {} ~OnnxTileParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc index 299b3efc319..67f7966f7c8 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.cc @@ -19,22 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +lite::PrimitiveC *OnnxTopkParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TopKParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -44,9 +35,14 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } } - op->primitive->value.type = schema::PrimitiveType_TopK; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_TopK; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h index 17111c1a4ee..4c593871d4a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_topk_parser.h @@ -27,7 +27,7 @@ class OnnxTopkParser : public OnnxNodeParser { OnnxTopkParser() : OnnxNodeParser("TopK") {} ~OnnxTopkParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc index 5fee208ad83..afb2b8cf722 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxTransposeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx TransposeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->conjugate = false; @@ -49,9 +39,14 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx } } - op->primitive->value.type = schema::PrimitiveType_Transpose; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Transpose; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h index 15d5d2801b7..63f4b7f19e3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h @@ -27,7 +27,7 @@ class OnnxTransposeParser : public OnnxNodeParser { OnnxTransposeParser() : OnnxNodeParser("Transpose") {} ~OnnxTransposeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc index 9d1f7dadd9e..d01fc0e9549 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxUnSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx UnSqueezeParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -47,9 +37,14 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx } } - op->primitive->value.type = schema::PrimitiveType_Unsqueeze; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Unsqueeze; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h index d523af53c5d..6e01f72d80f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h @@ -27,7 +27,7 @@ class OnnxUnSqueezeParser : public OnnxNodeParser { OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {} ~OnnxUnSqueezeParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index 79893ec6529..8da443c8072 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -19,23 +19,13 @@ namespace mindspore { namespace lite { -STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *op) { +lite::PrimitiveC *OnnxUpsampleParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx UpsampleParser"; - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - std::unique_ptr attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; + return nullptr; } attr->method = schema::ResizeMethod_NEAREST; @@ -44,14 +34,19 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: if (attribute_name == "mode") { if (onnx_node_attr.s() != "nearest" && onnx_node_attr.s() != "linear") { MS_LOG(ERROR) << "the upsample mode don't support now."; - return RET_NOT_SUPPORT; + return nullptr; } attr->method = onnx_node_attr.s() == "nearest" ? schema::ResizeMethod_NEAREST : schema::ResizeMethod_LINEAR; } } - op->primitive->value.type = schema::PrimitiveType_Resize; - op->primitive->value.value = attr.release(); - return RET_OK; + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "new primitive failed"; + return nullptr; + } + primitive->value.type = schema::PrimitiveType_Resize; + primitive->value.value = attr.release(); + return PrimitiveC::Create(primitive.release()); } OnnxNodeRegistrar g_onnxUpsampleParser("Upsample", new OnnxUpsampleParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h index 1e1d4fbee14..7b8158dbb47 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h @@ -27,7 +27,7 @@ class OnnxUpsampleParser : public OnnxNodeParser { OnnxUpsampleParser() : OnnxNodeParser("Upsample") {} ~OnnxUpsampleParser() override = default; - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 0e3cc0b8ddd..3dcce9d9237 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -110,7 +110,7 @@ STATUS TfliteModelParser::ConvertOps() { status = ConvertOpQuantParams(op.get(), primitiveC); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << op_name << " quant param failed."; - return status; + continue; } std::vector op_inputs = {NewValueNode(std::shared_ptr(primitiveC))}; @@ -132,7 +132,7 @@ STATUS TfliteModelParser::ConvertOps() { status = ConvertConstTensor(input_tensor.get(), parameter.get()); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; - return status; + continue; } op_inputs.emplace_back(parameter); nodes_.insert(std::pair(input_idx, parameter)); @@ -144,8 +144,7 @@ STATUS TfliteModelParser::ConvertOps() { status = ConvertOutputTensor(op.get(), new_cnode); if (status != RET_OK) { MS_LOG(ERROR) << "Convert output tensors for " << new_cnode->fullname_with_scope() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return status; + continue; } } return status; @@ -404,6 +403,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const } return RET_OK; } + MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) { return nullptr; diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc new file mode 100644 index 00000000000..0f27c33d53d --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc @@ -0,0 +1,508 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" +#include +#include +#include +#include +#include +#include "mindspore/lite/include/errorcode.h" +#include "src/ops/primitive_c.h" + +namespace mindspore::opt { +bool OnnxInputAdjustOpPass::CheckInputs(const CNodePtr &cnode) { + if (cnode == nullptr) { + MS_LOG(ERROR) << "cnode is nullptr."; + return false; + } + if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(), + [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) { + MS_LOG(ERROR) << "input is nullptr."; + return false; + } + return true; +} + +ParameterPtr OnnxInputAdjustOpPass::BuildParameterNode(const FuncGraphPtr &func_graph, const std::vector &data, + const std::string &node_name) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(data.size() != 0); + auto param_node = func_graph->add_parameter(); + auto type_ptr = TypeIdToType(kNumberTypeInt32); + std::vector shape_vector{static_cast(data.size())}; + auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + param_node->set_abstract(abstract_tensor); + param_node->set_name(node_name); + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + std::vector shape{static_cast(data.size())}; + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(kNumberTypeInt32); + param_value->set_format(schema::Format::Format_NCHW); + char *default_data = new char[data.size() * sizeof(int)]; + if (memcpy_s(default_data, data.size() * sizeof(int), data.data(), data.size() * sizeof(int)) != EOK) { + MS_LOG(ERROR) << "memcpy data failed."; + delete[] default_data; + return nullptr; + } + param_value->SetTensorData(default_data, data.size() * sizeof(int)); + param_node->set_default_param(param_value); + return param_node; +} + +ParameterPtr OnnxInputAdjustOpPass::BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const ParamValueLitePtr ¶m_value) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(cnode != nullptr); + MS_ASSERT(param_value != nullptr); + auto param_node = func_graph->add_parameter(); + auto shape = param_value->tensor_shape(); + std::vector shape_vector; + std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int &val) { return static_cast(val); }); + auto data_type = param_value->tensor_type() == kNumberTypeInt64 ? kNumberTypeInt32 : param_value->tensor_type(); + auto abstract_tensor = std::make_shared(TypeIdToType(data_type), shape_vector); + param_node->set_abstract(abstract_tensor); + if (utils::isa(node)) { + param_node->set_name(node->cast()->fullname_with_scope()); + } else if (utils::isa(node)) { + param_node->set_name(node->cast()->name()); + } + ParamValueLitePtr param_value_new = std::make_shared(); + param_value_new->set_format(param_value->format()); + param_value_new->set_tensor_shape(shape); + size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + if (param_value->tensor_size() == 0) { + if (param_value->tensor_type() == kNumberTypeInt64) { + param_value_new->set_tensor_type(kNumberTypeInt32); + } + param_node->set_default_param(param_value_new); + return param_node; + } + if (param_value->tensor_type() == kNumberTypeInt64) { + param_value_new->set_tensor_type(kNumberTypeInt32); + auto *tensor_data = new (std::nothrow) int[data_count]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return nullptr; + } + auto *origin_data = reinterpret_cast(param_value->tensor_addr()); + for (size_t i = 0; i < data_count; ++i) { + if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { + MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; + tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; + } else { + tensor_data[i] = static_cast(origin_data[i]); + } + } + param_value_new->SetTensorData(tensor_data, data_count * sizeof(int32_t)); + } else { + param_value_new->set_tensor_type(param_value->tensor_type()); + char *tensor_data = new (std::nothrow) char[param_value->tensor_size()]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return nullptr; + } + if (memcpy_s(tensor_data, param_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()) != + RET_OK) { + MS_LOG(ERROR) << "memcpy data failed."; + delete[] tensor_data; + return nullptr; + } + param_value_new->SetTensorData(tensor_data, param_value->tensor_size()); + } + param_node->set_default_param(param_value_new); + return param_node; +} + +STATUS OnnxInputAdjustOpPass::StridedSliceAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::string &attr_name) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(cnode != nullptr); + auto inputs = cnode->inputs(); + auto primitive_c = GetValueNode>(cnode->input(0)); + auto value_ptr = primitive_c->GetAttr(attr_name); + MS_ASSERT(value_ptr != nullptr); + std::vector value_data = GetValue>(value_ptr); + auto param_node = BuildParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name); + inputs.push_back(param_node); + cnode->set_inputs(inputs); + primitive_c->EraseAttr(attr_name); + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, + const ParameterPtr ¶m_node) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(param_node != nullptr); + if (param_node->abstract() == nullptr) { + MS_LOG(ERROR) << "parameter node abstract is invalid."; + return lite::RET_NULL_PTR; + } + auto abstract_tensor = param_node->abstract()->cast(); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "param node has no abstract tensor."; + return lite::RET_NULL_PTR; + } + if (abstract_tensor->element() == nullptr || abstract_tensor->element()->GetTypeTrack() == nullptr) { + MS_LOG(ERROR) << "get typePtr failed."; + return lite::RET_NULL_PTR; + } + if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeInt64) { + MS_LOG(DEBUG) << "don't need to convert to int32."; + return lite::RET_OK; + } + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + if (param_node->has_default()) { + auto default_value = param_node->default_param(); + if (default_value == nullptr) { + MS_LOG(ERROR) << "default data is nullptr."; + return lite::RET_NULL_PTR; + } + auto param_value = default_value->cast(); + if (param_value == nullptr) { + MS_LOG(ERROR) << "default data is not paramvaluelite."; + return lite::RET_NULL_PTR; + } + auto param_node_new = BuildParameterNode(func_graph, param_node, param_value); + manager->Replace(param_node, param_node_new); + } else { + // set graph input + param_node->abstract()->set_type(TypeIdToType(kNumberTypeInt32)); + } + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::AdjustPower(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + if (!CheckInputs(cnode)) { + MS_LOG(ERROR) << "input is invalid."; + return lite::RET_INPUT_TENSOR_ERROR; + } + if (cnode->inputs().size() != 3) { + MS_LOG(ERROR) << "onnx power inputs is 2, but now is " << cnode->inputs().size() - 1; + return lite::RET_ERROR; + } + auto pow_param = cnode->input(2)->cast(); + if (pow_param == nullptr || !pow_param->has_default()) { + MS_LOG(ERROR) << "pow is from other node, which hasn't been supported."; + return lite::RET_NOT_SUPPORT; + } + auto pow_default = pow_param->default_param()->cast(); + if (pow_default == nullptr) { + MS_LOG(ERROR) << "pow is not a paramValueLite."; + return lite::RET_NULL_PTR; + } + if (std::accumulate(pow_default->tensor_shape().begin(), pow_default->tensor_shape().end(), 1, + std::multiplies()) != 1) { + MS_LOG(ERROR) << "the pow element num is bigger than 1, which don't support now."; + return lite::RET_NOT_SUPPORT; + } + if (pow_default->tensor_addr() == nullptr) { + MS_LOG(ERROR) << "power's attr pow can't be obtained."; + return lite::RET_INVALID_OP_ATTR; + } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr || primitive_c->primitiveT() == nullptr || + primitive_c->primitiveT()->value.value == nullptr) { + MS_LOG(ERROR) << "get primitive_c failed."; + return lite::RET_NULL_PTR; + } + reinterpret_cast(primitive_c->primitiveT()->value.value)->power = + *reinterpret_cast(pow_default->tensor_addr()); + auto inputs = cnode->inputs(); + inputs.pop_back(); + cnode->set_inputs(inputs); + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + if (!CheckInputs(cnode)) { + MS_LOG(ERROR) << "input is invalid."; + return lite::RET_INPUT_TENSOR_ERROR; + } + if (cnode->inputs().size() == 2) { + if (StridedSliceAttrToInput(func_graph, cnode, "starts") != lite::RET_OK || + StridedSliceAttrToInput(func_graph, cnode, "ends") != lite::RET_OK || + StridedSliceAttrToInput(func_graph, cnode, "axes") != lite::RET_OK || + StridedSliceAttrToInput(func_graph, cnode, "steps") != lite::RET_OK) { + MS_LOG(ERROR) << "attr to input failed."; + return lite::RET_ERROR; + } + } else if (cnode->inputs().size() < 4) { + MS_LOG(ERROR) << "onnx slice's input size need to be larger than 2, now is " << cnode->inputs().size() - 1; + return lite::RET_INPUT_TENSOR_ERROR; + } + int size = 0; + for (size_t i = 2; i < cnode->inputs().size(); ++i) { + const auto ¶m_node = cnode->input(2)->cast(); + if (param_node == nullptr || !param_node->has_default()) { + continue; + } + const auto &default_data = param_node->default_param()->cast(); + if (default_data == nullptr) { + MS_LOG(ERROR) << "this input is not a paramValueLite."; + return lite::RET_ERROR; + } + auto shape = default_data->tensor_shape(); + size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + break; + } + auto inputs = cnode->inputs(); + switch (cnode->inputs().size()) { + case 4: { + std::vector axises; + for (int i = 0; i < size; ++i) { + axises.push_back(i); + } + auto new_param_node = BuildParameterNode(func_graph, axises, cnode->fullname_with_scope() + "_axises"); + if (new_param_node == nullptr) { + MS_LOG(ERROR) << "new a parameter node failed."; + } + inputs.push_back(new_param_node); + } + case 5: { + std::vector steps; + for (int i = 0; i < size; ++i) { + steps.push_back(1); + } + auto new_param_node = BuildParameterNode(func_graph, steps, cnode->fullname_with_scope() + "_steps"); + if (new_param_node == nullptr) { + MS_LOG(ERROR) << "new a parameter node failed."; + } + inputs.push_back(new_param_node); + break; + } + default: + MS_LOG(DEBUG) << "no need to adjust."; + return lite::RET_NO_CHANGE; + } + cnode->set_inputs(inputs); + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::AdjustConvOrDeConv(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + if (!CheckInputs(cnode)) { + MS_LOG(ERROR) << "input is invalid."; + return lite::RET_INPUT_TENSOR_ERROR; + } + auto type = opt::GetCNodeType(cnode); + if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DeConv2D) { + MS_LOG(DEBUG) << "node is not conv2d and deconv2d."; + return lite::RET_NO_CHANGE; + } + if (cnode->inputs().size() < 3) { + MS_LOG(ERROR) << "conv2d or deconv2d's input size is error, which is " << cnode->inputs().size() - 1; + return lite::RET_ERROR; + } + auto weight_param_node = cnode->input(2)->cast(); + if (weight_param_node == nullptr || !weight_param_node->has_default()) { + MS_LOG(INFO) << "weight tensor is not const tensor, which hasn't been supported."; + return lite::RET_NOT_SUPPORT; + } + auto weight_param_value = weight_param_node->default_param()->cast(); + if (weight_param_value == nullptr) { + MS_LOG(ERROR) << "weight is not a paramValueLite."; + return lite::RET_ERROR; + } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr || primitive_c->primitiveT() == nullptr || + primitive_c->primitiveT()->value.value == nullptr) { + MS_LOG(ERROR) << "get primitive_c failed."; + return lite::RET_NULL_PTR; + } + if (type == schema::PrimitiveType_Conv2D) { + weight_param_value->set_format(reinterpret_cast(primitive_c->primitiveT()->value.value)->format); + } else { + weight_param_value->set_format( + reinterpret_cast(primitive_c->primitiveT()->value.value)->format); + } + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::AdjustTile(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + if (!CheckInputs(cnode)) { + MS_LOG(ERROR) << "input is invalid."; + return lite::RET_INPUT_TENSOR_ERROR; + } + if (cnode->inputs().size() != 3) { + MS_LOG(ERROR) << "x tile input size should be 2, now is " << cnode->inputs().size() - 1; + return lite::RET_INPUT_TENSOR_ERROR; + } + auto multiples_node = cnode->input(2)->cast(); + if (multiples_node == nullptr || !multiples_node->has_default()) { + MS_LOG(INFO) << "multiples tensor is not const tensor, which hasn't been supported."; + return lite::RET_NOT_SUPPORT; + } + auto multiples_param_value = multiples_node->cast(); + if (multiples_param_value == nullptr) { + MS_LOG(ERROR) << "weight is not a paramValueLite."; + return lite::RET_ERROR; + } + size_t dims_size = multiples_param_value->tensor_size() / sizeof(int); + if (dims_size == 0) { + MS_LOG(INFO) << "multiples tensor is not const tensor, which hasn't been supported."; + return lite::RET_NOT_SUPPORT; + } + std::vector multiples(dims_size, 0); + if (memcpy_s(multiples.data(), dims_size * sizeof(int), multiples_param_value->tensor_addr(), + dims_size * sizeof(int)) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + return lite::RET_ERROR; + } + std::vector dims; + for (size_t i = 0; i < dims_size; ++i) { + dims.push_back(i); + } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr || primitive_c->primitiveT() == nullptr || + primitive_c->primitiveT()->value.value == nullptr) { + MS_LOG(ERROR) << "get primitive_c failed."; + return lite::RET_NULL_PTR; + } + reinterpret_cast(primitive_c->primitiveT()->value.value)->multiples = multiples; + reinterpret_cast(primitive_c->primitiveT()->value.value)->dims = dims; + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::AdjustCast(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + auto node = cnode->input(0); + MS_ASSERT(value_node != nullptr); + auto value_node = node->cast(); + if (value_node == nullptr) { + MS_LOG(ERROR) << "cnode input0 is not a valuenode."; + return lite::RET_ERROR; + } + MS_ASSERT(value_node->value != nullptr); + auto primitive_c = value_node->value()->cast(); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "cnode has no primitive_c."; + return lite::RET_ERROR; + } + auto primitive = primitive_c->primitiveT(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "cnode has no schema::primitive."; + return lite::RET_ERROR; + } + if (primitive->value.type != schema::PrimitiveType_Cast) { + MS_LOG(DEBUG) << "cnode is not cast node."; + return RET_OK; + } + auto value = primitive->value.value; + if (value == nullptr) { + MS_LOG(ERROR) << "value is nullptr."; + return lite::RET_ERROR; + } + auto attr = reinterpret_cast(value); + if (attr->dstT == kNumberTypeInt64) { + attr->dstT = kNumberTypeInt32; + } + return lite::RET_OK; +} + +STATUS OnnxInputAdjustOpPass::ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(cnode != nullptr); + if (cnode->inputs().size() < 1 || cnode->input(0) == nullptr) { + MS_LOG(ERROR) << "constant cnode has no primitive."; + return lite::RET_ERROR; + } + auto value_node = cnode->input(0)->cast(); + if (value_node == nullptr) { + MS_LOG(ERROR) << "constant input0 is not valuenode."; + return lite::RET_ERROR; + } + auto value_ptr = value_node->value(); + if (value_ptr == nullptr) { + MS_LOG(ERROR) << "value node has no value."; + return lite::RET_ERROR; + } + auto primitive_c = value_ptr->cast(); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "value is not primitive_c."; + return lite::RET_ERROR; + } + auto param_value = primitive_c->GetAttr("const_data"); + if (param_value == nullptr) { + MS_LOG(ERROR) << "constant cnode has no data."; + return lite::RET_ERROR; + } + auto param_value_lite = param_value->cast(); + if (param_value_lite == nullptr) { + MS_LOG(ERROR) << "valueptr is not paramvalueliteptr."; + return lite::RET_ERROR; + } + auto param_node = BuildParameterNode(func_graph, cnode, param_value_lite); + if (param_node == nullptr) { + MS_LOG(ERROR) << "convert constant to param node failed."; + return lite::RET_ERROR; + } + auto manager = func_graph->manager(); + MS_ASSERT(manager != nullptr); + manager->Replace(cnode, param_node); + return lite::RET_OK; +} + +bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto manager = Manage(func_graph, true); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr."; + return lite::RET_NULL_PTR; + } + auto node_list = TopoSort(func_graph->get_return()); + int status = RET_OK; + for (auto &node : node_list) { + if (utils::isa(node)) { + auto param_node = node->cast(); + status = ReplaceInt64ParameterNode(func_graph, param_node); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "replace int64 param node failed."; + return status; + } + } + auto cnode = node->cast(); + if (cnode == nullptr) { + MS_LOG(DEBUG) << "node is not cnode."; + continue; + } + auto type = opt::GetCNodeType(node); + if (type == schema::PrimitiveType_Power) { + status = AdjustPower(cnode); + } else if (type == schema::PrimitiveType_StridedSlice) { + status = AdjustStridedSlice(func_graph, cnode); + } else if (type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DeConv2D) { + status = AdjustConvOrDeConv(cnode); + } else if (type == schema::PrimitiveType_Tile) { + status = AdjustConvOrDeConv(cnode); + } else if (type == schema::PrimitiveType_Constant) { + status = ReplaceConstant(func_graph, cnode); + } else if (type == schema::PrimitiveType_Cast) { + status = AdjustCast(cnode); + } + if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { + MS_LOG(ERROR) << "adjust input pass is failed."; + return false; + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h new file mode 100644 index 00000000000..66ebd6a2dc6 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_ +#include +#include +#include "backend/optimizer/common/pass.h" +#include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/gllo_utils.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore::opt { +class OnnxInputAdjustOpPass : public Pass { + public: + OnnxInputAdjustOpPass() : Pass("onnx_input_adjust") {} + ~OnnxInputAdjustOpPass() override = default; + bool CheckInputs(const CNodePtr &cnode); + ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const std::vector &data, + const std::string &node_name); + ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const ParamValueLitePtr ¶m_value); + STATUS StridedSliceAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::string &attr_name); + STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node); + STATUS AdjustPower(const CNodePtr &cnode); + STATUS AdjustStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + STATUS AdjustConvOrDeConv(const CNodePtr &cnode); + STATUS AdjustTile(const CNodePtr &cnode); + STATUS AdjustCast(const CNodePtr &cnode); + STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + bool Run(const FuncGraphPtr &func_graph) override; + + private: +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_