reconstruct onnx

This commit is contained in:
xuanyue 2020-12-09 00:24:35 +08:00
parent e1e8f1d429
commit e7151c194c
110 changed files with 2612 additions and 2257 deletions

View File

@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef PRIMITIVE_WRITEABLE
#ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Constant : public PrimitiveC {
public:
Constant() = default;
~Constant() = default;
MS_DECLARE_PARENT(Constant, PrimitiveC);
explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
#endif

View File

@ -149,6 +149,7 @@
#include "src/ops/oneslike.h"
#include "src/ops/unsorted_segment_sum.h"
#include "src/ops/reciprocal.h"
#include "src/ops/constant.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -182,7 +183,7 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> CastToInt(const ValuePtr value) {
std::vector<int> CastToInt(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(WARNING) << "valueptr is nullptr.";
return {};
@ -891,6 +892,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Dequant(primitive);
case schema::PrimitiveType_Reciprocal:
return new (std::nothrow) Reciprocal(primitive);
case schema::PrimitiveType_Constant:
return new (std::nothrow) Constant(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

View File

@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value);
std::vector<int> CastToInt(const ValuePtr &value);
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

View File

@ -205,6 +205,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
)
endif()
### train

View File

@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/infershape_pass.cc
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

View File

@ -36,6 +36,7 @@
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
}
// onnx pre adjustment
if (config->fmk == converter::FmkType_ONNX) {
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
if (!onnx_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "onnx adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
}
// for now - trainning is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining

View File

@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() {
auto primitive_c = node_parser->ParseLitePrimitive(layer, weight);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
status = RET_ERROR;
continue;
}
@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertBottom(layer, &input_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
// build weights
@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertBlobs(weight, &const_parameters);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
// build cnode
@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() {
status = ConvertTop(layer, new_cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
status = ConvertLayerQuantParams(layer, weight, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
continue;
}
}
return status;

View File

@ -19,27 +19,22 @@
namespace mindspore {
namespace lite {
STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx AdderParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::AdderT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Adder;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Adder;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser());

View File

@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser {
public:
OnnxAdderParser() : OnnxNodeParser("Adder") {}
~OnnxAdderParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,14 @@
namespace mindspore {
namespace lite {
STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ArgMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
auto attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_ArgMax;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser());

View File

@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser {
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
~OnnxArgMaxParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser {
public:
OnnxAddParser() : OnnxNodeParser("Add") {}
~OnnxAddParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSubParser : public OnnxNodeParser {
public:
OnnxSubParser() : OnnxNodeParser("Sub") {}
~OnnxSubParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxMulParser : public OnnxNodeParser {
public:
OnnxMulParser() : OnnxNodeParser("Mul") {}
~OnnxMulParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxDivParser : public OnnxNodeParser {
public:
OnnxDivParser() : OnnxNodeParser("Div") {}
~OnnxDivParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxPowParser : public OnnxNodeParser {
public:
OnnxPowParser() : OnnxNodeParser("Power") {}
~OnnxPowParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxEqualParser : public OnnxNodeParser {
public:
OnnxEqualParser() : OnnxNodeParser("Equal") {}
~OnnxEqualParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxLessParser : public OnnxNodeParser {
public:
OnnxLessParser() : OnnxNodeParser("Less") {}
~OnnxLessParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxGreaterParser : public OnnxNodeParser {
public:
OnnxGreaterParser() : OnnxNodeParser("Greater") {}
~OnnxGreaterParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxMinParser : public OnnxNodeParser {
public:
OnnxMinParser() : OnnxNodeParser("Min") {}
~OnnxMinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxEltwiseParser : public OnnxNodeParser {
public:
OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {}
~OnnxEltwiseParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxFloorParser : public OnnxNodeParser {
public:
OnnxFloorParser() : OnnxNodeParser("Floor") {}
~OnnxFloorParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAbsParser : public OnnxNodeParser {
public:
OnnxAbsParser() : OnnxNodeParser("Abs") {}
~OnnxAbsParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxNegParser : public OnnxNodeParser {
public:
OnnxNegParser() : OnnxNodeParser("Neg") {}
~OnnxNegParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxExpParser : public OnnxNodeParser {
public:
OnnxExpParser() : OnnxNodeParser("Exp") {}
~OnnxExpParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxCosParser : public OnnxNodeParser {
public:
OnnxCosParser() : OnnxNodeParser("Cos") {}
~OnnxCosParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSinParser : public OnnxNodeParser {
public:
OnnxSinParser() : OnnxNodeParser("Sin") {}
~OnnxSinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSqrtParser : public OnnxNodeParser {
public:
OnnxSqrtParser() : OnnxNodeParser("Sqrt") {}
~OnnxSqrtParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxCeilParser : public OnnxNodeParser {
public:
OnnxCeilParser() : OnnxNodeParser("Ceil") {}
~OnnxCeilParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxLogParser : public OnnxNodeParser {
public:
OnnxLogParser() : OnnxNodeParser("Log") {}
~OnnxLogParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxTanParser : public OnnxNodeParser {
public:
OnnxTanParser() : OnnxNodeParser("Tan") {}
~OnnxTanParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAtanParser : public OnnxNodeParser {
public:
OnnxAtanParser() : OnnxNodeParser("Atan") {}
~OnnxAtanParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAsinParser : public OnnxNodeParser {
public:
OnnxAsinParser() : OnnxNodeParser("Asin") {}
~OnnxAsinParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxTanhParser : public OnnxNodeParser {
public:
OnnxTanhParser() : OnnxNodeParser("Tanh") {}
~OnnxTanhParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxSignParser : public OnnxNodeParser {
public:
OnnxSignParser() : OnnxNodeParser("Sign") {}
~OnnxSignParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxAndParser : public OnnxNodeParser {
public:
OnnxAndParser() : OnnxNodeParser("And") {}
~OnnxAndParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxOrParser : public OnnxNodeParser {
public:
OnnxOrParser() : OnnxNodeParser("Or") {}
~OnnxOrParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxNotParser : public OnnxNodeParser {
public:
OnnxNotParser() : OnnxNodeParser("Not") {}
~OnnxNotParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxRoundParser : public OnnxNodeParser {
public:
OnnxRoundParser() : OnnxNodeParser("Round") {}
~OnnxRoundParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxReciprocalParser : public OnnxNodeParser {
public:
OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {}
~OnnxReciprocalParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx BatchNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
auto attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser());

View File

@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser {
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
~OnnxBatchNormParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,30 +19,25 @@
namespace mindspore {
namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx BiasAddParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
auto attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->axis = {1};
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BiasAdd;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser());

View File

@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser {
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
~OnnxBiasAddParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -20,22 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx CastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
auto attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->dstT = static_cast<int>(dst_type);
}
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Cast;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser());

View File

@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser {
OnnxCastParser() : OnnxNodeParser("Cast") {}
~OnnxCastParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,39 +19,32 @@
namespace mindspore {
namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ClipParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
auto attr = std::make_unique<schema::ClipT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
float min = -1, max = -1;
attr->max = -1;
attr->min = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "max") {
max = onnx_node_attr.f();
attr->max = onnx_node_attr.f();
} else if (attribute_name == "min") {
min = onnx_node_attr.f();
attr->min = onnx_node_attr.f();
}
}
std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
attr->max = max;
attr->min = min;
op->primitive->value.type = schema::PrimitiveType_Clip;
op->primitive->value.value = attr.release();
return RET_OK;
primitive->value.type = schema::PrimitiveType_Clip;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());

View File

@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser {
OnnxClipParser() : OnnxNodeParser("Clip") {}
~OnnxClipParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
auto attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser());

View File

@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser {
OnnxConcatParser() : OnnxNodeParser("Concat") {}
~OnnxConcatParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -20,23 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConstantOfShapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConstantOfShapeT> attr = std::make_unique<schema::ConstantOfShapeT>();
auto attr = std::make_unique<schema::ConstantOfShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
const auto &tensor = onnx_node_attr.t();
auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
if (ret != RET_OK) {
return ret;
MS_LOG(ERROR) << "get data from tensor failed";
return nullptr;
}
} break;
default:
MS_LOG(ERROR) << "The data type is not supported.";
return RET_ERROR;
return nullptr;
}
}
}
op->primitive->value.type = schema::PrimitiveType_ConstantOfShape;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_ConstantOfShape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser());

View File

@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser {
OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {}
~OnnxConstantOfShapeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -16,33 +16,75 @@
#include "tools/converter/parser/onnx/onnx_constant_parser.h"
#include <memory>
#include <vector>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
return RET_ERROR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
auto data_type =
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()));
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type "
<< static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type());
return RET_ERROR;
}
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
std::vector<int> shape;
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &val) { return static_cast<int32_t>(val); });
param_value->set_tensor_type(data_type);
param_value->set_tensor_shape(shape);
param_value->set_format(schema::Format_NCHW);
if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) {
MS_LOG(ERROR) << "get value failed.";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
primitive_c->set_attr("const_data", param_value);
return RET_OK;
}
lite::PrimitiveC *OnnxConstantParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConstantParser";
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Constant;
auto primitive_c = PrimitiveC::Create(primitive.release());
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "create primitiveC failed.";
return nullptr;
}
for (const auto &attr : onnx_node.attribute()) {
if (attr.name() == "sparse_value") {
MS_LOG(WARNING) << "sparse_value";
continue;
}
if (attr.name() == "value") {
const auto &const_tensor = attr.t();
if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) {
MS_LOG(ERROR) << "add basic attr failed.";
delete primitive_c;
return nullptr;
}
} else {
MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented";
delete primitive_c;
return nullptr;
}
}
return primitive_c;
}
OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser());
} // namespace lite
} // namespace mindspore

View File

@ -27,7 +27,8 @@ class OnnxConstantParser : public OnnxNodeParser {
OnnxConstantParser() : OnnxNodeParser("Constant") {}
~OnnxConstantParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c);
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -21,9 +21,14 @@
namespace mindspore::lite {
constexpr int32_t kSingleGroup = 1;
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::PrimitiveT *primitive) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
if (attr == nullptr || primitive == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr";
return false;
}
auto depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
if (depthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new op failed";
return false;
@ -45,27 +50,18 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
depthwiseConv2DParam->hasBias = attr->hasBias;
depthwiseConv2DParam->activationType = attr->activationType;
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = depthwiseConv2DParam.release();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = depthwiseConv2DParam.release();
return true;
}
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
auto attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->strideH = 1;
@ -83,21 +79,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -106,7 +102,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR;
return nullptr;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -115,7 +111,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -124,7 +120,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->format = schema::Format::Format_NHWC;
} else {
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
return RET_ERROR;
return nullptr;
}
}
}
@ -152,7 +148,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
if (node_iter == onnx_graph.node().end()) {
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
return RET_ERROR;
return nullptr;
}
std::vector<int> dims;
auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(),
@ -160,7 +156,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
if (iter != (*node_iter).attribute().end()) {
if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) {
MS_LOG(ERROR) << "dims insert failed";
return RET_ERROR;
return nullptr;
}
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
}
@ -174,16 +170,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
if (attr->group > kSingleGroup && attr->group == attr->channelIn) {
if (!ParseGroupConvolution(attr, op)) {
if (!ParseGroupConvolution(attr, primitive.get())) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
return nullptr;
}
} else {
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
return RET_OK;
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser());

View File

@ -28,10 +28,10 @@ class OnnxConvParser : public OnnxNodeParser {
OnnxConvParser() : OnnxNodeParser("Conv") {}
~OnnxConvParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
private:
static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::PrimitiveT *primitive);
};
} // namespace lite
} // namespace mindspore

View File

@ -21,11 +21,13 @@
namespace mindspore {
namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
if (attr == nullptr || attr->group != attr->channelOut) {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::PrimitiveT *primitive) {
if (attr == nullptr || attr->group != attr->channelOut || primitive == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr";
return false;
}
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
auto deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
if (deDepthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new op failed";
return false;
@ -47,28 +49,18 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
deDepthwiseConv2DParam->hasBias = attr->hasBias;
deDepthwiseConv2DParam->activationType = attr->activationType;
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
primitive->value.value = deDepthwiseConv2DParam.release();
return true;
}
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
auto attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->padMode = schema::PadMode_NOTSET;
@ -83,21 +75,21 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -106,7 +98,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR;
return nullptr;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -115,7 +107,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
return nullptr;
}
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -124,11 +116,11 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->format = schema::Format::Format_NHWC;
} else {
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return RET_ERROR;
return nullptr;
}
} else if (onnx_node_attr.name() == "output_padding") {
MS_LOG(ERROR) << "output_padding param hasn't been supported";
return RET_NOT_SUPPORT;
return nullptr;
}
}
@ -138,7 +130,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (node_iter == onnx_graph.initializer().end()) {
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
return RET_ERROR;
return nullptr;
}
std::vector<int> weight_shape;
auto size = (*node_iter).dims_size();
@ -148,7 +140,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
if (weight_shape.size() != 4) {
MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size();
return RET_ERROR;
return nullptr;
}
attr->channelIn = weight_shape[0];
attr->channelOut = weight_shape[1] * attr->group;
@ -156,17 +148,22 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->format = schema::Format::Format_NCHW;
attr->hasBias = onnx_node.input().size() == 3;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
if (attr->group != 1) {
if (!ParseGroupDeConvolution(attr, op)) {
if (!ParseGroupDeConvolution(attr, primitive.get())) {
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support";
return RET_NOT_SUPPORT;
return nullptr;
}
} else {
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_DeConv2D;
primitive->value.value = attr.release();
}
return RET_OK;
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser());

View File

@ -28,10 +28,10 @@ class OnnxDeConvParser : public OnnxNodeParser {
OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
~OnnxDeConvParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
private:
static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive);
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxDepthToSpaceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
auto attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,10 +34,14 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const o
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_DepthToSpace;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser());

View File

@ -27,7 +27,7 @@ class OnnxDepthToSpaceParser : public OnnxNodeParser {
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
~OnnxDepthToSpaceParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxDropoutParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DropoutParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
auto attr = std::make_unique<schema::DropoutT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,10 +34,14 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
attr->ratio = static_cast<float>(onnx_node_attr.f());
}
}
op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Dropout;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser());

View File

@ -27,7 +27,7 @@ class OnnxDropoutParser : public OnnxNodeParser {
OnnxDropoutParser() : OnnxNodeParser("Dropout") {}
~OnnxDropoutParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,22 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxEluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx EluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
auto attr = std::make_unique<schema::EluT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -43,10 +34,14 @@ STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
attr->alpha = onnx_node_attr.f();
}
}
op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Elu;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser());

View File

@ -27,7 +27,7 @@ class OnnxEluParser : public OnnxNodeParser {
OnnxEluParser() : OnnxNodeParser("Elu") {}
~OnnxEluParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -20,23 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxExpandParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ExpandParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
auto attr = std::make_unique<schema::BroadcastToT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
std::vector<int> dst_shape;
@ -46,7 +36,7 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
if (node_iter == onnx_graph.node().end()) {
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
return RET_ERROR;
return nullptr;
}
for (const auto &attrPower : node_iter->attribute()) {
if (attrPower.name() == "value") {
@ -58,9 +48,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
}
attr->dst_shape = dst_shape;
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BroadcastTo;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser());

View File

@ -27,7 +27,7 @@ class OnnxExpandParser : public OnnxNodeParser {
OnnxExpandParser() : OnnxNodeParser("Expand") {}
~OnnxExpandParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxFlattenParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx FlattenParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
auto attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
int axis = 1;
@ -49,10 +39,14 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
attr->shape.emplace_back(0);
}
attr->shape.emplace_back(-1);
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Reshape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser());

View File

@ -27,7 +27,7 @@ class OnnxFlattenParser : public OnnxNodeParser {
OnnxFlattenParser() : OnnxNodeParser("Fatten") {}
~OnnxFlattenParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxGatherParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx GatherParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
auto attr = std::make_unique<schema::GatherT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -45,9 +35,14 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
}
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Gather;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser());

View File

@ -27,7 +27,7 @@ class OnnxGatherParser : public OnnxNodeParser {
OnnxGatherParser() : OnnxNodeParser("Gather") {}
~OnnxGatherParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxGemmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx IdentityParser";
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
if (node_parser == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
return nullptr;
}
auto *matmul_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node);
node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
if (node_parser == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
return nullptr;
}
auto *bias_add_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node);
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_MakeTuple;
auto primitve_c = PrimitiveC::Create(primitive.release());
primitve_c->set_attr("MatMul", std::shared_ptr<lite::PrimitiveC>(matmul_primitive));
primitve_c->set_attr("BiasAdd", std::shared_ptr<lite::PrimitiveC>(bias_add_primitive));
return primitve_c;
}
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
} // namespace lite
} // namespace mindspore

View File

@ -14,26 +14,21 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxTensorParser {
class OnnxGemmParser : public OnnxNodeParser {
public:
~OnnxTensorParser() = default;
static OnnxTensorParser *GetInstance() {
static OnnxTensorParser onnxTensorParser;
return &onnxTensorParser;
}
TensorCache *GetTensorCache() { return &tensor_cache_; }
OnnxGemmParser() : OnnxNodeParser("Gemm") {}
~OnnxGemmParser() override = default;
private:
OnnxTensorParser() = default;
TensorCache tensor_cache_;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H

View File

@ -0,0 +1,127 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h"
#include <functional>
#include <memory>
#include <vector>
#include <algorithm>
#include "src/param_value_lite.h"
namespace mindspore {
namespace lite {
STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node,
lite::PrimitiveC *primitive_c,
const std::vector<int> &shape) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
return RET_ERROR;
}
int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
if (iter == onnx_node.attribute().end()) {
return RET_OK;
}
size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t);
char *param_data = new (std::nothrow) char[data_size];
if (param_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
delete[] param_data;
return RET_ERROR;
}
param_value->set_tensor_shape(shape);
param_value->set_format(schema::Format_NUM_OF_FORMAT);
param_value->set_tensor_type(kNumberTypeInt64);
param_value->SetTensorData(param_data, data_size);
primitive_c->set_attr("const_data", param_value);
return RET_OK;
}
STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node,
lite::PrimitiveC *primitive_c,
const std::vector<int> &shape) {
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
if (param_value == nullptr) {
MS_LOG(ERROR) << "new a paramValueLite failed.";
return RET_ERROR;
}
int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
if (iter == onnx_node.attribute().end()) {
return RET_OK;
}
char *param_data = new (std::nothrow) char[data_count];
if (param_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
delete[] param_data;
return RET_ERROR;
}
param_value->set_tensor_shape(shape);
param_value->set_format(schema::Format_NUM_OF_FORMAT);
param_value->set_tensor_type(kNumberTypeUInt8);
param_value->SetTensorData(param_data, data_count);
primitive_c->set_attr("const_data", param_value);
return RET_OK;
}
lite::PrimitiveC *OnnxGivenTensorFillParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx GivenTensorFillParser";
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Constant;
auto primitive_c = PrimitiveC::Create(primitive.release());
std::vector<int64_t> shape_vector;
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
if (iter != onnx_node.attribute().end()) {
shape_vector.insert(shape_vector.begin(), iter->ints().begin(), iter->ints().end());
}
std::vector<int> shape;
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &val) { return static_cast<int32_t>(val); });
if (onnx_node.op_type() == "Int8GivenIntTensorFill") {
if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
MS_LOG(ERROR) << "given tensor fill parse failed.";
return nullptr;
}
} else if (onnx_node.op_type() == "Int8GivenTensorFill") {
if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
MS_LOG(ERROR) << "given tensor fill parse failed.";
return nullptr;
}
}
return primitive_c;
}
OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser());
OnnxNodeRegistrar g_onnxInt8GivenTensorFillParser("Int8GivenTensorFill", new OnnxGivenTensorFillParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H
#include <vector>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxGivenTensorFillParser : public OnnxNodeParser {
public:
OnnxGivenTensorFillParser() : OnnxNodeParser("GivenTensorFill") {}
~OnnxGivenTensorFillParser() override = default;
STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c,
const std::vector<int> &shape);
STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c,
const std::vector<int> &shape);
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H

View File

@ -20,28 +20,23 @@
namespace mindspore {
namespace lite {
STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxIdentityParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx IdentityParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
auto attr = std::make_unique<schema::IdentityT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Identity;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Identity;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());

View File

@ -27,7 +27,7 @@ class OnnxIdentityParser : public OnnxNodeParser {
OnnxIdentityParser() : OnnxNodeParser("Identity") {}
~OnnxIdentityParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx InstanceNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>();
auto attr = std::make_unique<schema::InstanceNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
if (!onnx_node.attribute().empty()) {
@ -44,10 +34,14 @@ STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const o
attr->epsilon = onnx_node_attr.f();
}
}
op->primitive->value.type = schema::PrimitiveType_InstanceNorm;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_InstanceNorm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());

View File

@ -27,7 +27,7 @@ class OnnxInstanceNormParser : public OnnxNodeParser {
OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {}
~OnnxInstanceNormParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -18,23 +18,13 @@
#include <memory>
namespace mindspore::lite {
STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxLpNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx LpNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LpNormalizationT> attr = std::make_unique<schema::LpNormalizationT>();
auto attr = std::make_unique<schema::LpNormalizationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -45,10 +35,14 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->p = onnx_node_attr.i();
}
}
op->primitive->value.type = schema::PrimitiveType_LpNormalization;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LpNormalization;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser());

View File

@ -27,7 +27,7 @@ class OnnxLpNormParser : public OnnxNodeParser {
OnnxLpNormParser() : OnnxNodeParser("LpNorm") {}
~OnnxLpNormParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -18,22 +18,13 @@
#include <memory>
namespace mindspore::lite {
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxLrnParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx LrnParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>();
auto attr = std::make_unique<schema::LocalResponseNormalizationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
int32_t size = 0;
@ -53,13 +44,18 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
if (size == 0) {
MS_LOG(ERROR) << "Divide-by-zero error.";
return RET_ERROR;
return nullptr;
}
attr->alpha /= size;
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser());

View File

@ -27,7 +27,7 @@ class OnnxLrnParser : public OnnxNodeParser {
OnnxLrnParser() : OnnxNodeParser("Lrn") {}
~OnnxLrnParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,22 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxLstmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx LstmParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>();
auto attr = std::make_unique<schema::LstmT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,9 +35,14 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
}
op->primitive->value.type = schema::PrimitiveType_Lstm;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Lstm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser());

View File

@ -27,7 +27,7 @@ class OnnxLstmParser : public OnnxNodeParser {
OnnxLstmParser() : OnnxNodeParser("LSTM") {}
~OnnxLstmParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxMatmulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx MatMulParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
auto attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
float alpha = 1.0f;
@ -54,12 +44,17 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
if (alpha != 1 || beta != 1) {
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
return RET_ERROR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_MatMul;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());

View File

@ -27,7 +27,7 @@ class OnnxMatmulParser : public OnnxNodeParser {
OnnxMatmulParser() : OnnxNodeParser("MatMul") {}
~OnnxMatmulParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

File diff suppressed because it is too large Load Diff

View File

@ -26,75 +26,57 @@
#include <vector>
#include <memory>
#include <set>
#include <map>
#include <unordered_map>
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
#include "proto/onnx.pb.h"
#include "src/param_value_lite.h"
namespace mindspore {
namespace lite {
class OnnxModelParser : public ModelParser {
public:
OnnxModelParser();
OnnxModelParser() = default;
virtual ~OnnxModelParser();
~OnnxModelParser() override = default;
// schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);
MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override {
return nullptr;
}
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
const ParamValueLitePtr &param_value_lite);
private:
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertNodes();
STATUS ConvertConstTensors();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor);
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index);
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
schema::TensorT *dst_tensor);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
STATUS SetAllTensors(schema::MetaGraphT *graphDef);
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
schema::MetaGraphT *dst_graph);
private:
std::vector<std::string> graphInputNames;
std::vector<std::string> graphConstNames;
int subGraphNum = 0;
onnx::ModelProto onnx_model_;
onnx::GraphProto onnx_graph_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_ = nullptr;
};
} // namespace lite
} // namespace mindspore

View File

@ -15,7 +15,9 @@
*/
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <algorithm>
#include <vector>
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {

View File

@ -20,6 +20,7 @@
#include <string>
#include <utility>
#include <vector>
#include "src/ops/primitive_c.h"
#include "google/protobuf/message.h"
#include "proto/onnx.pb.h"
#include "include/errorcode.h"
@ -34,7 +35,8 @@ class OnnxNodeParser {
virtual ~OnnxNodeParser() = default;
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
virtual lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) = 0;
static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxNonMaxSuppressionParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx EluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>();
auto attr = std::make_unique<schema::NonMaxSuppressionT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -47,9 +37,14 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co
}
}
op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());

View File

@ -27,7 +27,7 @@ class OnnxNonMaxSuppressionParser : public OnnxNodeParser {
OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {}
~OnnxNonMaxSuppressionParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxOneHotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx OneHotParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
auto attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -45,9 +35,14 @@ STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
}
op->primitive->value.type = schema::PrimitiveType_OneHot;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_OneHot;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser());

View File

@ -27,7 +27,7 @@ class OnnxOneHotParser : public OnnxNodeParser {
OnnxOneHotParser() : OnnxNodeParser("OneHot") {}
~OnnxOneHotParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,22 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxPadParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx PadParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
auto attr = std::make_unique<schema::PadT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -58,9 +49,14 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
}
}
op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Pad;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser());

View File

@ -27,7 +27,7 @@ class OnnxPadParser : public OnnxNodeParser {
OnnxPadParser() : OnnxNodeParser("Pad") {}
~OnnxPadParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,22 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxPoolParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx PoolParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
auto attr = std::make_unique<schema::PoolingT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->format = schema::Format::Format_NCHW;
@ -56,7 +47,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
attr->global = false;
} else {
MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only.";
return RET_ERROR;
return nullptr;
}
attr->roundMode = schema::RoundMode_FLOOR;
@ -101,13 +92,18 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
if (attribute_name == "dilations") {
MS_LOG(ERROR) << "pooling op not support dilations now";
return RET_ERROR;
return nullptr;
}
}
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Pooling;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser());

View File

@ -27,7 +27,7 @@ class OnnxPoolParser : public OnnxNodeParser {
OnnxPoolParser() : OnnxNodeParser("Pool") {}
~OnnxPoolParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxQuantizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
auto attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed.";
return RET_NULL_PTR;
return nullptr;
}
if (onnx_node.op_type() == "Int8Quantize") {
attr->srcT = kNumberTypeFloat32;
@ -45,11 +35,16 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
attr->dstT = kNumberTypeFloat32;
} else {
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
return RET_ERROR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());

View File

@ -27,7 +27,7 @@ class OnnxQuantizeParser : public OnnxNodeParser {
OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
~OnnxQuantizeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,28 +19,23 @@
namespace mindspore {
namespace lite {
STATUS OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxRangeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx RangeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>();
auto attr = std::make_unique<schema::RangeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->dType = 0;
op->primitive->value.type = schema::PrimitiveType_Range;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Range;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser());

View File

@ -27,7 +27,7 @@ class OnnxRangeParser : public OnnxNodeParser {
OnnxRangeParser() : OnnxNodeParser("Range") {}
~OnnxRangeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxReduceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ReduceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
auto attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->keepDims = 1;
@ -65,12 +55,17 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->mode = schema::ReduceMode_ReduceSumSquare;
} else {
MS_LOG(ERROR) << "unsupported type";
return RET_ERROR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Reduce;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser());

View File

@ -27,7 +27,7 @@ class OnnxReduceParser : public OnnxNodeParser {
OnnxReduceParser() : OnnxNodeParser("Reduce") {}
~OnnxReduceParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -21,22 +21,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
auto attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
const auto &relu_type = onnx_node.op_type();
@ -54,29 +45,24 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx PReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
if (onnx_node.input_size() != 2) {
MS_LOG(ERROR) << "input num should be 2";
return RET_ERROR;
return nullptr;
}
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
auto attr = std::make_unique<schema::PReLUT>();
std::vector<onnx::TensorProto> params;
const auto &input_name = onnx_node.input(1);
for (const auto &it : onnx_graph.initializer()) {
@ -90,7 +76,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) {
MS_LOG(ERROR) << "input error: params[0] is null";
return RET_ERROR;
return nullptr;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
@ -102,16 +88,21 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
attr->channelShared = false;
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
return nullptr;
}
}
} else {
MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors.";
}
op->primitive->value.type = schema::PrimitiveType_PReLU;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_PReLU;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser());

View File

@ -27,7 +27,7 @@ class OnnxReluParser : public OnnxNodeParser {
OnnxReluParser() : OnnxNodeParser("Relu") {}
~OnnxReluParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
class OnnxPReluParser : public OnnxNodeParser {
@ -35,7 +35,7 @@ class OnnxPReluParser : public OnnxNodeParser {
OnnxPReluParser() : OnnxNodeParser("Prelu") {}
~OnnxPReluParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -20,23 +20,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxReshapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ReshapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
auto attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->format = schema::Format_NCHW;
@ -51,28 +41,17 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
}
}
}
} else {
onnx::TensorProto input_shape;
const auto &shape_name = onnx_node.input(1);
for (const auto &it : onnx_graph.initializer()) {
if (it.name() == shape_name) {
input_shape = it;
break;
}
}
if (input_shape.int64_data_size() == 0) {
MS_LOG(INFO) << "shape maybe from another op other than const initializer";
} else {
for (int i = 0; i < input_shape.int64_data_size(); ++i) {
shape.push_back(input_shape.int64_data(i));
}
}
}
attr->shape = shape;
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Reshape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser());

View File

@ -27,7 +27,7 @@ class OnnxReshapeParser : public OnnxNodeParser {
OnnxReshapeParser() : OnnxNodeParser("Reshape") {}
~OnnxReshapeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -22,23 +22,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxResizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ResizeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
auto attr = std::make_unique<schema::ResizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->format = schema::Format_NCHW;
@ -85,9 +75,14 @@ STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
}
op->primitive->value.type = schema::PrimitiveType_Resize;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Resize;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());

View File

@ -27,7 +27,7 @@ class OnnxResizeParser : public OnnxNodeParser {
OnnxResizeParser() : OnnxNodeParser("Resize") {}
~OnnxResizeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,28 +19,23 @@
namespace mindspore {
namespace lite {
STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx ShapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
auto attr = std::make_unique<schema::ShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Shape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser());

View File

@ -27,7 +27,7 @@ class OnnxShapeParser : public OnnxNodeParser {
OnnxShapeParser() : OnnxNodeParser("Shape") {}
~OnnxShapeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,30 +19,25 @@
namespace mindspore {
namespace lite {
STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSigmoidParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SigmoidParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
auto attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->type = schema::ActivationType_SIGMOID;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser());

View File

@ -27,7 +27,7 @@ class OnnxSigmoidParser : public OnnxNodeParser {
OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {}
~OnnxSigmoidParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -23,77 +23,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name,
onnx::NodeProto *onnx_node) {
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
if (tensor == nullptr) {
MS_LOG(ERROR) << "new tensor failed";
return RET_ERROR;
}
tensor->dataType = mindspore::kNumberTypeInt32;
tensor->dims.push_back(onnx_val.size());
tensor->format = schema::Format::Format_NCHW;
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
int data_size = sizeof(int32_t) * onnx_val.size();
tensor->data.resize(data_size);
if (data_size != 0 &&
memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size();
std::string tensor_name = name + std::to_string(tensor_num);
OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT);
onnx_node->add_input(tensor_name);
return RET_OK;
}
STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) {
if (onnx_val == nullptr) {
MS_LOG(ERROR) << "input vector is nullptr.";
return RET_ERROR;
}
if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) {
MS_LOG(ERROR) << "cannot get tensorcache.";
return RET_ERROR;
}
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << name;
return RET_ERROR;
}
auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (input_tensor->data.empty()) {
MS_LOG(DEBUG) << "data is empty.";
return RET_NO_CHANGE;
}
int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>());
onnx_val->resize(data_num);
if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) !=
EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
return RET_OK;
}
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSliceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SliceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>();
auto attr = std::make_unique<schema::StridedSliceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
std::vector<int> starts;
@ -128,36 +64,17 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
}
int status = RET_OK;
switch (onnx_node.input_size()) {
case 5: {
if (steps.empty()) {
status = GetInputTensor(&steps, onnx_node.input(4));
}
}
case 4: {
if (status != RET_ERROR && axes.empty()) {
status = GetInputTensor(&axes, onnx_node.input(3));
}
}
case 3: {
if (status != RET_ERROR && ends.empty()) {
status = GetInputTensor(&ends, onnx_node.input(2));
}
}
case 2: {
if (status != RET_ERROR && starts.empty()) {
status = GetInputTensor(&starts, onnx_node.input(1));
}
}
default: {
if (status == RET_ERROR) {
MS_LOG(ERROR) << "onnx slice inputs are invalid.";
return RET_INPUT_TENSOR_ERROR;
}
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_StridedSlice;
primitive->value.value = attr.release();
auto primitive_c = PrimitiveC::Create(primitive.release());
if (starts.empty()) {
return primitive_c;
}
if (axes.empty()) {
for (size_t i = 0; i < starts.size(); ++i) {
axes.push_back(i);
@ -166,42 +83,11 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
if (steps.empty()) {
steps.assign(starts.size(), 1);
}
onnx::NodeProto *slice_node = nullptr;
for (auto &node : onnx_graph.node()) {
if (&node == &onnx_node) {
slice_node = const_cast<onnx::NodeProto *>(&node);
}
}
int insert_num = 5 - onnx_node.input_size();
switch (insert_num) {
case 4: {
std::string name = "slice/starts/";
status = InsertTensor(starts, name, slice_node);
}
case 3:
if (status == RET_OK) {
std::string name = "slice/ends/";
status = InsertTensor(ends, name, slice_node);
}
case 2:
if (status == RET_OK) {
std::string name = "slice/axes/";
status = InsertTensor(axes, name, slice_node);
}
case 1:
if (status == RET_OK) {
std::string name = "slice/steps/";
status = InsertTensor(steps, name, slice_node);
}
default:
if (status != RET_OK) {
MS_LOG(ERROR) << "onnx slice insert tensor failed";
return RET_ERROR;
}
}
op->primitive->value.type = schema::PrimitiveType_StridedSlice;
op->primitive->value.value = attr.release();
return RET_OK;
primitive_c->set_attr("starts", MakeValue<std::vector<int>>(starts));
primitive_c->set_attr("ends", MakeValue<std::vector<int>>(ends));
primitive_c->set_attr("axes", MakeValue<std::vector<int>>(axes));
primitive_c->set_attr("steps", MakeValue<std::vector<int>>(steps));
return primitive_c;
}
OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser());

View File

@ -21,7 +21,6 @@
#include <string>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
namespace mindspore {
namespace lite {
@ -30,9 +29,7 @@ class OnnxSliceParser : public OnnxNodeParser {
OnnxSliceParser() : OnnxNodeParser("Slice") {}
~OnnxSliceParser() override = default;
STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name);
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSoftMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SoftMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>();
auto attr = std::make_unique<schema::SoftMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
bool axis_is_def = true;
@ -53,9 +43,14 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
attr->axis = 1;
}
}
op->primitive->value.type = schema::PrimitiveType_SoftMax;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_SoftMax;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser());

View File

@ -27,7 +27,7 @@ class OnnxSoftMaxParser : public OnnxNodeParser {
OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {}
~OnnxSoftMaxParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSpaceToDepthParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>();
auto attr = std::make_unique<schema::SpaceToDepthT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -45,9 +35,14 @@ STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const o
}
}
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_SpaceToDepth;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser());

View File

@ -27,7 +27,7 @@ class OnnxSpaceToDepthParser : public OnnxNodeParser {
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
~OnnxSpaceToDepthParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSplitParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SplitParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
auto attr = std::make_unique<schema::SplitT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
attr->splitDim = 0;
@ -51,9 +41,14 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
op->primitive->value.type = schema::PrimitiveType_Split;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Split;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser());

View File

@ -27,7 +27,7 @@ class OnnxSplitParser : public OnnxNodeParser {
OnnxSplitParser() : OnnxNodeParser("Split") {}
~OnnxSplitParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,23 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
lite::PrimitiveC *OnnxSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx SqueezeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
auto attr = std::make_unique<schema::SqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -47,9 +37,14 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
}
}
op->primitive->value.type = schema::PrimitiveType_Squeeze;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Squeeze;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser());

View File

@ -27,7 +27,7 @@ class OnnxSqueezeParser : public OnnxNodeParser {
OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {}
~OnnxSqueezeParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -20,26 +20,22 @@
namespace mindspore {
namespace lite {
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxTileParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx TileParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
auto attr = std::make_unique<schema::TileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Tile;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser());

View File

@ -27,7 +27,7 @@ class OnnxTileParser : public OnnxNodeParser {
OnnxTileParser() : OnnxNodeParser("Tile") {}
~OnnxTileParser() override = default;
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,22 +19,13 @@
namespace mindspore {
namespace lite {
STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
lite::PrimitiveC *OnnxTopkParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx TopKParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TopKT> attr = std::make_unique<schema::TopKT>();
auto attr = std::make_unique<schema::TopKT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
return nullptr;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -44,9 +35,14 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
}
op->primitive->value.type = schema::PrimitiveType_TopK;
op->primitive->value.value = attr.release();
return RET_OK;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_TopK;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser());

Some files were not shown because too many files have changed in this diff Show More