diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc index 0d34c7d705e..95c84cc9419 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_argmax_parser.h" +#include namespace mindspore { namespace lite { @@ -23,7 +23,22 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ArgMaxParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { @@ -32,11 +47,9 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, attr->keepDims = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_ArgMax; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_ArgMax; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h index b878d18194b..77e616593d2 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_ARGMAX_PARSER_H -#define MS_ONNX_ARGMAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxArgMaxParser : public OnnxNodeParser { public: OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ARGMAX_PARSER_H +#endif // MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index 00741915052..ea70899bd2f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -14,114 +14,238 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" +#include namespace mindspore { namespace lite { STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx AddParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Add; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SubParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Sub; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx MulParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Mul; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx DivParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Div; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Div; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx PowParser"; - if (op != nullptr) { - std::unique_ptr attr(new schema::PowerT()); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Power; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr(new schema::PowerT()); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx EqualParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Equal; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Equal; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx LessParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Less; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Less; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx GreaterParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Greater; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Greater; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx MinParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Min; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Min; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx EltwiseParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + // there is no Prod in onnx if (onnx_node.op_type() == "Sum") { attr->mode = schema::EltwiseMode_SUM; @@ -129,143 +253,303 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: attr->mode = schema::EltwiseMode_MAXIMUM; } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Eltwise; - op->primitive->value.value = attr.release(); - } + op->primitive->value.type = schema::PrimitiveType_Eltwise; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx FloorParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Floor; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Floor; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx AbsParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Abs; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Abs; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx NegParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Neg; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Neg; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ExpParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Exp; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Exp; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx CosParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Cos; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Cos; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SinParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Sin; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Sin; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SqrtParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Sqrt; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Sqrt; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx CeilParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Ceil; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Ceil; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx LogParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Log; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Log; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx TanParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Tan; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Tan; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx AtanParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Atan; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Atan; + op->primitive->value.value = attr.release(); return RET_OK; } + STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx AsinParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Asin; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Asin; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx TanhParser"; - if (op != nullptr) { - MS_LOG(ERROR) << "mslite don't support tanh now"; - return RET_ERROR; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } - return RET_OK; + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + MS_LOG(ERROR) << "mslite don't support tanh now"; + return RET_ERROR; } OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h index 3457e43b225..28f9a5d8f0e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H -#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -167,5 +167,5 @@ class OnnxTanhParser : public OnnxNodeParser { }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc index 98837a5da56..73650e15825 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc @@ -19,10 +19,26 @@ namespace mindspore { namespace lite { -STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, +STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx BatchNormParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { if (onnx_node_attr.name() == "epsilon") { attr->epsilon = onnx_node_attr.f(); @@ -32,11 +48,9 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx attr->spatial = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h index f5fe947c085..0328fb76279 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_ADD_PARSER_H -#define MS_ONNX_ADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxBatchNormParser : public OnnxNodeParser { public: OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc index 73fe05a2c78..c92e02151b1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc @@ -14,26 +14,36 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_biasadd_parser.h" +#include -// using namespace mindspore::predict; -// using namespace onnx; -// using namespace std; namespace mindspore { namespace lite { STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx BiasAddParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + // use channel dim as axis attr->axis = {1}; - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_BiasAdd; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_BiasAdd; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h index 52874544814..892802acdcf 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_BIASADD_PARSER_H -#define MS_ONNX_BIASADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -26,9 +26,11 @@ class OnnxBiasAddParser : public OnnxNodeParser { public: OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_BIASADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc index 598cf6c8302..66423500f16 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -14,25 +14,40 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_cast_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx CastParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "to") { attr->dstT = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Cast; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Cast; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h index 83812999c82..027fd0bdbe0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_CAST_PARSER_H -#define MS_ONNX_CAST_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxCastParser : public OnnxNodeParser { public: OnnxCastParser() : OnnxNodeParser("Cast") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_CAST_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc index f2c989ecdb6..1ac60d4330e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -14,13 +14,25 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_clip_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ClipParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + float min = -1, max = -1; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); @@ -32,15 +44,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } if (min == 0 && max == 6) { std::unique_ptr attr = std::make_unique(); - attr->type = schema::ActivationType_RELU6; - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; } + attr->type = schema::ActivationType_RELU6; + + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); } else { MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported"; - return RET_PARAM_INVALID; + return RET_ERROR; } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h index 150cf764046..9f6e2eba6f0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_CLIP_PARSER_H -#define MS_ONNX_CLIP_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxClipParser : public OnnxNodeParser { public: OnnxClipParser() : OnnxNodeParser("Clip") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ARGMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc index 039cde44695..539da26e62a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_concat_parser.h" +#include namespace mindspore { namespace lite { @@ -23,18 +23,31 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ConcatParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { attr->axis = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Concat; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h index 69828176a23..ca5a407cf18 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_CONCAT_PARSER_H -#define MS_ONNX_CONCAT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxConcatParser : public OnnxNodeParser { public: OnnxConcatParser() : OnnxNodeParser("Concat") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_CONCAT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc index df54cbc165c..06afd29f61c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_constant_parser.h" +#include namespace mindspore { namespace lite { @@ -23,12 +23,24 @@ STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ConstantParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Constant; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Constant; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h index 26736a5e3ce..43d84a7a6cf 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_CONSTANT_PARSER_H -#define MS_ONNX_CONSTANT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxConstantParser : public OnnxNodeParser { public: OnnxConstantParser() : OnnxNodeParser("Constant") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_CONSTANT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 02abc62e506..4c1e4b22d7e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -14,21 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/onnx/onnx_conv_parser.h" #include #include #include -#include "tools/converter/parser/onnx/onnx_conv_parser.h" namespace mindspore { namespace lite { -bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr &attr, schema::CNodeT *op) { +bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr &attr, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; if (attr == nullptr || attr->group != attr->channelIn) { return false; } std::unique_ptr depthwiseConv2DParam = std::make_unique(); if (depthwiseConv2DParam == nullptr) { - MS_LOG(ERROR) << "new DepthwiseConv2DT failed"; + MS_LOG(ERROR) << "new op failed"; return false; } depthwiseConv2DParam->format = attr->format; @@ -47,15 +48,32 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptrdilateH = attr->dilateH; depthwiseConv2DParam->hasBias = attr->hasBias; depthwiseConv2DParam->activationType = attr->activationType; - op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.value = depthwiseConv2DParam.release(); return true; } -STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ConvParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + // set opdef each attr params for (const auto &onnx_node_attr : onnx_node.attribute()) { if (onnx_node_attr.name() == "group") { @@ -149,13 +167,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } else { attr->activationType = schema::ActivationType_NO_ACTIVATION; } + if (attr->group != 1) { if (!ParseGroupConvolution(attr, op)) { MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; return RET_ERROR; } } else { - op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h index 4d890d29557..6fceb2dc3b3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_CONV_PARSER_H -#define MS_ONNX_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H #include #include "tools/converter/parser/onnx/onnx_node_parser.h" @@ -26,11 +26,16 @@ namespace lite { class OnnxConvParser : public OnnxNodeParser { public: OnnxConvParser() : OnnxNodeParser("Conv") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; private: - bool ParseGroupConvolution(const std::unique_ptr &attr, schema::CNodeT *op); + bool ParseGroupConvolution(const std::unique_ptr &attr, + schema::CNodeT *op); }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc index 92e02a3aa86..f61e34630aa 100755 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc @@ -22,6 +22,7 @@ namespace lite { OnnxConverter::OnnxConverter() { modelParser = new OnnxModelParser(); } + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h index 7d51c8d2e4d..ad8da5a2abb 100755 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h @@ -14,8 +14,9 @@ * limitations under the License. */ -#ifndef MS_ONNX_CONVERTER_H -#define MS_ONNX_CONVERTER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H + #include #include #include "tools/converter/converter.h" @@ -27,10 +28,10 @@ class OnnxConverter : public Converter { public: OnnxConverter(); - ~OnnxConverter() override = default; + ~OnnxConverter() = default; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_CONVERTER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc index 8d51d2bc36a..44541d8dfd3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -14,21 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/onnx/onnx_deconv_parser.h" #include #include #include -#include "tools/converter/parser/onnx/onnx_deconv_parser.h" namespace mindspore { namespace lite { -bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr &attr, schema::CNodeT *op) { - MS_LOG(DEBUG) << "onnx DeConvParser"; +bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr &attr, + schema::CNodeT *op) { if (attr == nullptr || attr->group != attr->channelOut) { return false; } std::unique_ptr deDepthwiseConv2DParam = std::make_unique(); if (deDepthwiseConv2DParam == nullptr) { - MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed"; + MS_LOG(WARNING) << "new op failed"; return false; } deDepthwiseConv2DParam->format = attr->format; @@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptrdilateH = attr->dilateH; deDepthwiseConv2DParam->hasBias = attr->hasBias; deDepthwiseConv2DParam->activationType = attr->activationType; - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; - op->primitive->value.value = deDepthwiseConv2DParam.release(); - } + + op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; + op->primitive->value.value = deDepthwiseConv2DParam.release(); return true; } -STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, +STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx DeConvParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + // set opdef each attr params for (const auto &onnx_node_attr : onnx_node.attribute()) { if (onnx_node_attr.name() == "group") { attr->group = static_cast(onnx_node_attr.i()); } else if (onnx_node_attr.name() == "dilations") { if (onnx_node_attr.ints().size() != 2) { - // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); + MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } attr->dilateW = static_cast(onnx_node_attr.ints(0)); attr->dilateH = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernels") { if (onnx_node_attr.ints().size() != 2) { - // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } attr->kernelH = static_cast(onnx_node_attr.ints(0)); attr->kernelW = static_cast(onnx_node_attr.ints(1)); } else if (onnx_node_attr.name() == "kernel_shape") { if (onnx_node_attr.ints().size() != 2) { - // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } attr->kernelW = static_cast(onnx_node_attr.ints(0)); @@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->padMode = GetOnnxPadMode(onnx_node_attr); } else if (onnx_node_attr.name() == "pads") { if (onnx_node_attr.ints().size() != 4) { - // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); + MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; return RET_ERROR; } attr->padUp = static_cast(onnx_node_attr.ints(0)); @@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->padRight = static_cast(onnx_node_attr.ints(3)); } else if (onnx_node_attr.name() == "strides") { if (onnx_node_attr.ints().size() != 2) { - // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); + MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; return RET_ERROR; } attr->strideW = static_cast(onnx_node_attr.ints(0)); @@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N if (onnx_node_attr.s() == "NHWC") { attr->format = schema::Format_NHWC; } else { - // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str(); return RET_ERROR; } } @@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); if (nodeIter == onnx_graph.initializer().end()) { - // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) + MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); return RET_ERROR; } std::vector weight_shape; @@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N return RET_ERROR; } } else { - op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.value = attr.release(); } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h index 2dc0a02b16e..0c0730ab118 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_DECONV_PARSER_H -#define MS_ONNX_DECONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H #include #include "tools/converter/parser/onnx/onnx_node_parser.h" @@ -26,11 +26,16 @@ namespace lite { class OnnxDeConvParser : public OnnxNodeParser { public: OnnxDeConvParser() : OnnxNodeParser("DeConv") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; private: - bool ParseGroupDeConvolution(const std::unique_ptr &attr, schema::CNodeT *op); + bool ParseGroupDeConvolution(const std::unique_ptr &attr, + schema::CNodeT *op); }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_DECONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc index adb61697861..0738060ffc9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h" +#include namespace mindspore { namespace lite { @@ -23,18 +23,31 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto& attribute_name = onnx_node_attr.name(); if (attribute_name == "blocksize") { attr->blockSize = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_DepthToSpace; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_DepthToSpace; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h index 9b3d51ec54e..b176a123a11 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H -#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxDepthToSpaceParser : public OnnxNodeParser { public: OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc index 065f270e856..820d957e18a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_dropout_parser.h" +#include namespace mindspore { namespace lite { @@ -23,18 +23,31 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx DropoutParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "ratio") { attr->ratio = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Dropout; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Dropout; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h index 1a8f99eda32..454a8805e05 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_ARGMAX_PARSER_H -#define MS_ONNX_ARGMAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxDropoutParser : public OnnxNodeParser { public: OnnxDropoutParser() : OnnxNodeParser("Dropout") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ARGMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc index 9106a58ff04..8a3723b1cd2 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc @@ -14,25 +14,40 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_elu_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx EluParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto& attribute_name = onnx_node_attr.name(); if (attribute_name == "alpha") { attr->alpha = onnx_node_attr.f(); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Elu; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Elu; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h index 5387e42edf8..76201660c00 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_ELU_PARSER_H -#define MS_ONNX_ELU_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxEluParser : public OnnxNodeParser { public: OnnxEluParser() : OnnxNodeParser("Elu") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ELU_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc index d474ef565b6..7994faf770b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -14,20 +14,32 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_expand_parser.h" +#include namespace mindspore { namespace lite { STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ExpandParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Broadcast; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Broadcast; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h index 3b805f5d924..b1bb3fe777e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_EXPAND_PARSER_H -#define MS_ONNX_EXPAND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxExpandParser : public OnnxNodeParser { public: OnnxExpandParser() : OnnxNodeParser("Expand") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_EXPAND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc index 2e30d702b36..cff8b530fed 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_flatten_parser.h" +#include namespace mindspore { namespace lite { @@ -23,7 +23,22 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx FlattenParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + int axis = 1; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); @@ -36,11 +51,8 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, } attr->shape.emplace_back(-1); - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Reshape; - op->primitive->value.value = attr.release(); - } + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h index f3637a2a16b..6f28794aa8a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_FLATTEN_PARSER_H -#define MS_ONNX_FLATTEN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -26,9 +26,11 @@ class OnnxFlattenParser : public OnnxNodeParser { public: OnnxFlattenParser() : OnnxNodeParser("Fatten") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_FLATTEN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc index 68d916ba46e..98f66ed82b9 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_gather_parser.h" +#include namespace mindspore { namespace lite { @@ -23,18 +23,31 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx GatherParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto& attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { attr->axis = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Gather; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Gather; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h index 4bfafdf29be..778c0e8538b 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_GATHER_PARSER_H -#define MS_ONNX_GATHER_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxGatherParser : public OnnxNodeParser { public: OnnxGatherParser() : OnnxNodeParser("Gather") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_GATHER_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc index ee2c345ac4e..228881edcd5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -14,31 +14,69 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_lrn_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx LrnParser"; - std::unique_ptr attr = std::make_unique(); - for (const auto &onnx_node_attr : onnx_node.attribute()) { - const auto& attribute_name = onnx_node_attr.name(); - if (attribute_name == "size") { - attr->size = static_cast(onnx_node_attr.i()); - } else if (attribute_name == "alpha") { - attr->alpha = onnx_node_attr.f(); - } else if (attribute_name == "beta") { - attr->beta = onnx_node_attr.f(); - } else if (attribute_name == "bias") { - attr->bias = onnx_node_attr.f(); - } + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Lrn; - op->primitive->value.value = attr.release(); + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; } + + std::unique_ptr attr + = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + auto onnx_node_attr = onnx_node.attribute().at(0); + int32_t size = 0; + if (onnx_node_attr.name() == "size") { + size = static_cast(onnx_node_attr.i()); + attr->depth_radius = static_cast(size / 2); + } else { + MS_LOG(ERROR) << "the first attr is not size"; + return RET_ERROR; + } + + onnx_node_attr = onnx_node.attribute().at(1); + if (onnx_node_attr.name() == "alpha") { + auto alpha = onnx_node_attr.f(); + attr->alpha = alpha / size; + } else { + MS_LOG(ERROR) << "the second attr is not alpha"; + return RET_ERROR; + } + + onnx_node_attr = onnx_node.attribute().at(2); + if (onnx_node_attr.name() == "beta") { + attr->beta = onnx_node_attr.f(); + } else { + MS_LOG(ERROR) << "the third attr is not beta"; + return RET_ERROR; + } + + onnx_node_attr = onnx_node.attribute().at(3); + if (onnx_node_attr.name() == "bias") { + attr->bias = onnx_node_attr.f(); + } else { + MS_LOG(ERROR) << "the third attr is not bias"; + return RET_ERROR; + } + + op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h index 3fd38d2660e..d7c6cf88d0c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_LRN_PARSER_H -#define MS_ONNX_LRN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxLrnParser : public OnnxNodeParser { public: OnnxLrnParser() : OnnxNodeParser("Lrn") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_LRN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc index fcbfbb46bd1..53fb2a4a872 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -14,15 +14,31 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_matmul_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, +STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx MatMulParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + float alpha = 1.0f; float beta = 1.0f; for (const auto &onnx_node_attr : onnx_node.attribute()) { @@ -39,14 +55,11 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N } if (alpha != 1 || beta != 1) { MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; - return RET_PARAM_INVALID; + return RET_ERROR; } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_MatMul; - op->primitive->value.value = attr.release(); - } + op->primitive->value.type = schema::PrimitiveType_MatMul; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h index 37c6a669f26..2b099d19926 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_MATMUL_PARSER_H -#define MS_ONNX_MATMUL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxMatmulParser : public OnnxNodeParser { public: OnnxMatmulParser() : OnnxNodeParser("MatMul") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_MATMUL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 6539494ddfc..b6edf4ebd42 100755 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -14,11 +14,11 @@ * limitations under the License. */ +#include "tools/converter/parser/onnx/onnx_model_parser.h" #include #include #include #include -#include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/common/graph_util.h" #include "src/common/utils.h" @@ -54,7 +54,8 @@ std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo return dims; } -STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { +STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, + google::protobuf::Message *onnx_model) { std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); #ifdef _WIN32 if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) { @@ -81,7 +82,8 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go return RET_OK; } -STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, + TensorCache *tensor_cache) { MS_LOG(DEBUG) << "set onnx constant tensors"; for (const auto &onnx_const_value : onnx_graph.initializer()) { int index; @@ -117,8 +119,11 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, - TensorCache *tensor_cache, int *index) { +STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, + const std::string &name, + const TensorType &type, + TensorCache *tensor_cache, + int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " @@ -137,8 +142,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st *index = tensor_cache->AddTensor(name, tensor.release(), type); return RET_OK; } -STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, - TensorCache *tensor_cache, int *index) { + +STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, + const std::string &name, + const TensorType &type, + TensorCache *tensor_cache, + int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.data_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " << static_cast(proto.data_type()); @@ -165,7 +174,8 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std return RET_OK; } -STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, +STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, + schema::MetaGraphT *graph, TensorCache *tensor_cache) { for (const auto &input_value : onnx_graph.input()) { auto ret = tensor_cache->FindTensor(input_value.name()); @@ -182,7 +192,8 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, +STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, + schema::MetaGraphT *graph, TensorCache *tensor_cache) { for (const auto &output_value : onnx_graph.output()) { int index; @@ -196,8 +207,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::MetaGraphT *graph, TensorCache *tensor_cache) { +void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, + TensorCache *tensor_cache) { std::unique_ptr dst_op_1 = std::make_unique(); dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); @@ -218,7 +231,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons graph->nodes.emplace_back(std::move(dst_op_2)); } -STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { +STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache) { // convert GivenTensorFill node to a weight/bias tensor auto ret = tensor_cache->FindTensor(onnx_node.output(0)); if (ret < 0) { @@ -270,8 +284,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, return RET_OK; } -STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, schema::TensorT *dst_tensor, +STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache) { // change op_type() to name(), that is unique dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); @@ -303,8 +319,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, return RET_OK; } -void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { +void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, + TensorCache *tensor_cache) { MS_ASSERT(dst_op != nullptr); MS_ASSERT(tensor_cache != nullptr); std::vector quant_node_name; @@ -361,8 +380,10 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const } } -STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - const string &onnx_op_type, schema::CNodeT *dst_op) { +STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + const string &onnx_op_type, + schema::CNodeT *dst_op) { auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); if (node_parser == nullptr) { MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; @@ -371,8 +392,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co return node_parser->Parse(onnx_graph, onnx_node, dst_op); } -STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, + schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache) { for (const auto &onnx_node_input : node_inputs) { auto index = tensor_cache->FindTensor(onnx_node_input); if (index < 0) { @@ -385,7 +408,8 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, return RET_OK; } -STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, +STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, + schema::CNodeT *dst_op, TensorCache *tensor_cache) { for (const auto &onnx_node_output : node_outputs) { auto index = tensor_cache->FindTensor(onnx_node_output); @@ -400,7 +424,8 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs return RET_OK; } -STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { +STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, + schema::TensorT *tensor) { size_t data_count = 1; std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); size_t data_size = 0; @@ -459,7 +484,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v return RET_OK; } -STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { +STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, + schema::MetaGraphT *graphDef) { std::vector tensors = tensor_cache.GetCachedTensor(); for (auto iter : tensors) { std::unique_ptr temp(iter); @@ -481,7 +507,8 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) } } -MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile, +MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, + const std::string &weightFile, const QuantType &quantType) { if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 4caf8751b7a..838b82f727c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_MODEL_PARSER_H -#define MS_ONNX_MODEL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H #include #include @@ -37,35 +37,83 @@ namespace lite { class OnnxModelParser : public ModelParser { public: OnnxModelParser(); + virtual ~OnnxModelParser(); + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType = QuantType_QUANT_NONE) override; private: TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); - STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto); - STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); - STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); - STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); - STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, - TensorCache *tensor_cache, int *index); - STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, - TensorCache *tensor_cache, int *index); - STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache); - void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - schema::MetaGraphT *graph, TensorCache *tensor_cache); - STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); - STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, - const string &onnx_op_type, schema::CNodeT *dst_op); - void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, TensorCache *tensor_cache); - STATUS SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); - STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); - STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); - STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); + + STATUS ReadOnnxModelFromBinary(const std::string &modelFile, + google::protobuf::Message *model_proto); + + STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, + TensorCache *tensor_cache); + + STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, + schema::MetaGraphT *graph, + TensorCache *tensor_cache); + + STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, + schema::MetaGraphT *graph, + TensorCache *tensor_cache); + + STATUS AddValueInfo(const onnx::ValueInfoProto &proto, + const std::string &name, + const TensorType &type, + TensorCache *tensor_cache, + int *index); + + STATUS AddTensorProto(const onnx::TensorProto &proto, + const std::string &name, + const TensorType &type, + TensorCache *tensor_cache, + int *index); + + STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, + TensorCache *tensor_cache); + + void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, + TensorCache *tensor_cache); + + STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache); + + STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + const string &onnx_op_type, + schema::CNodeT *dst_op); + + void SetOpQuantParams(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, + TensorCache *tensor_cache); + + STATUS SetOpInputIndex(const std::vector &node_inputs, + schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache); + + STATUS SetOpOutputIndex(const std::vector &node_outputs, + schema::CNodeT *dst_op, + TensorCache *tensor_cache); + + STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, + schema::TensorT *tensor); + + STATUS SetAllTensors(const TensorCache &tensor_cache, + schema::MetaGraphT *graphDef); + void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); private: @@ -75,4 +123,4 @@ class OnnxModelParser : public ModelParser { } // namespace lite } // namespace mindspore -#endif // MS_ONNX_MODEL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc index 59ba5256dee..ea23f518115 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -15,6 +15,7 @@ */ #include "tools/converter/parser/onnx/onnx_node_parser.h" +#include namespace mindspore { namespace lite { @@ -30,6 +31,20 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_ return schema::PadMode_NOTSET; } } + +void OnnxNodeParser::Split(const std::string &src_str, + std::vector *dst_str, + const std::string &chr) { + std::string ::size_type p1 = 0, p2 = src_str.find(chr); + while (std::string::npos != p2) { + dst_str->push_back(src_str.substr(p1, p2 - p1)); + p1 = p2 + chr.size(); + p2 = src_str.find(chr, p1); + } + if (p1 != src_str.length()) { + dst_str->push_back(src_str.substr(p1)); + } +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h index 59a07f9e10f..901abc920fd 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h @@ -14,10 +14,11 @@ * limitations under the License. */ -#ifndef MS_ONNX_NODE_PARSER_H -#define MS_ONNX_NODE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H #include +#include #include "google/protobuf/message.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "include/errorcode.h" @@ -29,14 +30,23 @@ namespace lite { class OnnxNodeParser { public: explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {} + virtual ~OnnxNodeParser() = default; - virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; + + virtual STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) = 0; protected: schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); + + void Split(const std::string &src_str, + std::vector *dst_str, + const std::string &chr); + const std::string &name; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_NODE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc index ba165104624..ce9b6309177 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc @@ -21,7 +21,14 @@ namespace mindspore { namespace lite { OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default; -OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default; +OnnxNodeParserRegistry::~OnnxNodeParserRegistry() { + for (auto ite : parsers) { + if (ite.second != nullptr) { + delete ite.second; + ite.second = nullptr; + } + } +} OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() { static OnnxNodeParserRegistry instance; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h index b7fa61d2322..7027abab1f3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_OP_REGISTRY_H -#define MS_ONNX_OP_REGISTRY_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H #include #include @@ -30,6 +30,7 @@ class OnnxNodeParserRegistry { virtual ~OnnxNodeParserRegistry(); static OnnxNodeParserRegistry *GetInstance(); + OnnxNodeParser *GetNodeParser(const std::string &name); std::unordered_map parsers; @@ -37,12 +38,13 @@ class OnnxNodeParserRegistry { class OnnxNodeRegistrar { public: - OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) { + OnnxNodeRegistrar(const std::string &name, + OnnxNodeParser *parser) { OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser; } }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_OP_REGISTRY_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc index 96729b973ad..14e4ef5662d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -14,14 +14,31 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_pad_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx PadParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "pads") { @@ -42,11 +59,9 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node } } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Pad; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Pad; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h index 0e93953d8ed..e8f60ae30ce 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_LRN_PARSER_H -#define MS_ONNX_LRN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxPadParser : public OnnxNodeParser { public: OnnxPadParser() : OnnxNodeParser("Pad") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_LRN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc index 95525c0bf68..b8113a16855 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -14,14 +14,30 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_pool_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx PoolParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } attr->format = schema::Format_NCHW; const auto &pool_type = onnx_node.op_type(); @@ -79,11 +95,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod return RET_ERROR; } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Pooling; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h index eaa154e2aa3..39b8e4d2412 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_POOL_PARSER_H -#define MS_ONNX_POOL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxPoolParser : public OnnxNodeParser { public: OnnxPoolParser() : OnnxNodeParser("Pool") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_POOL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc index 9d7aed8b957..fdf3b3ebe04 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_reduce_parser.h" +#include namespace mindspore { namespace lite { @@ -23,7 +23,22 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ReduceParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axes") { @@ -45,13 +60,12 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, } else if (type == "ReduceSum") { attr->mode = schema::ReduceMode_ReduceSum; } else { - // MS_LOGE("unsupoort type"); - } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Reduce; - op->primitive->value.value = attr.release(); + MS_LOG(ERROR) << "unsupported type"; + return RET_ERROR; } + + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h index 50570ce917d..edb1c7b094c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_REDUCE_PARSER_H -#define MS_ONNX_REDUCE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxReduceParser : public OnnxNodeParser { public: OnnxReduceParser() : OnnxNodeParser("Reduce") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_REDUCE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc index 5e233583666..d14a5054a74 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -14,36 +14,64 @@ * limitations under the License. */ +#include "tools/converter/parser/onnx/onnx_relu_parser.h" #include #include -#include "tools/converter/parser/onnx/onnx_relu_parser.h" #include "securec/include/securec.h" + namespace mindspore { namespace lite { -STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ReluParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + const auto &relu_type = onnx_node.op_type(); if (relu_type == "Relu") { + MS_LOG(DEBUG) << "onnx ReluParser"; attr->type = schema::ActivationType_RELU; } else if (relu_type == "LeakyRelu") { + MS_LOG(DEBUG) << "onnx LeakyReluParser"; attr->type = schema::ActivationType_LEAKY_RELU; } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - } + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); return RET_OK; } STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx PReluParser"; + + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + if (onnx_node.input_size() != 2) { - MS_LOG(ERROR) << "input num is not 2"; - return RET_PARAM_INVALID; + MS_LOG(ERROR) << "input num should be 2"; + return RET_ERROR; } std::unique_ptr attr = std::make_unique(); std::vector params; @@ -57,8 +85,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No const onnx::TensorProto *slope = ¶ms[0]; if (slope == nullptr) { - MS_LOG(ERROR) << "input error"; - return RET_PARAM_INVALID; + MS_LOG(ERROR) << "input error: params[0] is null"; + return RET_ERROR; } const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); const int64_t slope_size = slope->raw_data().size() / sizeof(float); @@ -74,11 +102,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_CaffePReLU; - op->primitive->value.value = attr.release(); - } + op->primitive->value.type = schema::PrimitiveType_CaffePReLU; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h index 82b0fb3c38f..049c39b3aa7 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_RELU_PARSER_H -#define MS_ONNX_RELU_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,7 +25,10 @@ namespace lite { class OnnxReluParser : public OnnxNodeParser { public: OnnxReluParser() : OnnxNodeParser("Relu") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; class OnnxLeakeyReluParser : public OnnxReluParser { @@ -36,9 +39,12 @@ class OnnxLeakeyReluParser : public OnnxReluParser { class OnnxPReluParser : public OnnxNodeParser { public: OnnxPReluParser() : OnnxNodeParser("Prelu") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_RELU_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index d89777598a4..83b3d3ecf2e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -14,16 +14,32 @@ * limitations under the License. */ +#include "tools/converter/parser/onnx/onnx_reshape_parser.h" #include #include -#include "tools/converter/parser/onnx/onnx_reshape_parser.h" namespace mindspore { namespace lite { -STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, +STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ReshapeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->format = schema::Format_NCHW; std::vector params; for (int i = 0; i < onnx_node.input_size(); ++i) { @@ -40,18 +56,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: } else { if (params.size() != 1) { MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1"; - return RET_PARAM_INVALID; + return RET_ERROR; } for (int i = 0; i < params[0].int64_data_size(); ++i) { attr->shape.emplace_back(params[0].int64_data(i)); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Reshape; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h index 1e43621d0ca..6bd227426bd 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_RESHAPE_PARSER_H -#define MS_ONNX_RESHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxReshapeParser : public OnnxNodeParser { public: OnnxReshapeParser() : OnnxNodeParser("Reshape") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_RESHAPE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc index 5b7d1955590..fbfc87a5689 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_shape_parser.h" +#include namespace mindspore { namespace lite { @@ -23,12 +23,24 @@ STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx ShapeParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Shape; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Shape; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h index ec3e33481f2..d504f5d69f0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SHAPE_PARSER_H -#define MS_ONNX_SHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxShapeParser : public OnnxNodeParser { public: OnnxShapeParser() : OnnxNodeParser("Shape") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SHAPE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc index 15383689ac8..67cd08c8362 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_sigmoid_parser.h" +#include namespace mindspore { namespace lite { @@ -23,13 +23,26 @@ STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SigmoidParser"; - std::unique_ptr attr = std::make_unique(); - attr->type = schema::ActivationType_SIGMOID; - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + attr->type = schema::ActivationType_SIGMOID; + + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h index dd1758dc258..e721f72931f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SIGMOID_PARSER_H -#define MS_ONNX_SIGMOID_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxSigmoidParser : public OnnxNodeParser { public: OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SIGMOID_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 541e0c86a88..c2bda2c3d86 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -14,15 +14,30 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_slice_parser.h" +#include namespace mindspore { namespace lite { STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SliceParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "starts") { @@ -38,11 +53,9 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Slice; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h index 6dd66415954..bda83c1866c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SLICE_PARSER_H -#define MS_ONNX_SLICE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxSliceParser : public OnnxNodeParser { public: OnnxSliceParser() : OnnxNodeParser("Slice") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SLICE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc index 4f375598211..53dfe860e3c 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_softmax_parser.h" +#include namespace mindspore { namespace lite { @@ -23,18 +23,31 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SoftMaxParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto& attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { attr->axis = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_SoftMax; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_SoftMax; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h index 012f2225ab5..668ab25ea94 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SOFTMAX_PARSER_H -#define MS_ONNX_SOFTMAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxSoftMaxParser : public OnnxNodeParser { public: OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SOFTMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc index fbb6d44d1b0..4928c8a05c7 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -14,26 +14,40 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, +STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "blocksize") { attr->blockSize = static_cast(onnx_node_attr.i()); } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h index fb45aaf2fe6..62340ba3818 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H -#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxSpaceToDepthParser : public OnnxNodeParser { public: OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc index 7e8e2c202be..4f4c9ff868f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_squeeze_parser.h" +#include namespace mindspore { namespace lite { @@ -23,7 +23,22 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx SqueezeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axes") { @@ -32,11 +47,9 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, } } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Squeeze; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Squeeze; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h index 686381ceeb0..741c9754a4e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_SQUEEZE_PARSER_H -#define MS_ONNX_SQUEEZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxSqueezeParser : public OnnxNodeParser { public: OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_SQUEEZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc index e2d7782b104..20ade294857 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -14,19 +14,33 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_tile_parser.h" +#include namespace mindspore { namespace lite { -STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { +STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx TileParser"; - if (op != nullptr) { - std::unique_ptr attr = std::make_unique(); - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Tile; - op->primitive->value.value = attr.release(); + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + op->primitive->value.type = schema::PrimitiveType_Tile; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h index 7c2d98418d1..a921911e1ca 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_TILE_PARSER_H -#define MS_ONNX_TILE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxTileParser : public OnnxNodeParser { public: OnnxTileParser() : OnnxNodeParser("Tile") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_ARGMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc index ffa5fc69166..7846a023582 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_transpose_parser.h" +#include namespace mindspore { namespace lite { @@ -23,7 +23,22 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx TransposeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->conjugate = false; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); @@ -40,11 +55,9 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, } } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Transpose; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Transpose; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h index 225d1bd10f9..e9e84d025d2 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_TRANSPOSE_PARSER_H -#define MS_ONNX_TRANSPOSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxTransposeParser : public OnnxNodeParser { public: OnnxTransposeParser() : OnnxNodeParser("Transpose") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_TRANSPOSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc index b18c6ca3099..576fb2d4493 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc @@ -23,7 +23,22 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx UpsampleParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "mode") { @@ -34,12 +49,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, } } } - // to do - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Upsample; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Upsample; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h index d76fff1ea7c..426d5d8b5a1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_UPSAMPLE_PARSER_H -#define MS_ONNX_UPSAMPLE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxUpsampleParser : public OnnxNodeParser { public: OnnxUpsampleParser() : OnnxNodeParser("Upsample") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_UPSAMPLE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc index c5cc5716208..9d1f7dadd9e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -14,15 +14,30 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h" +#include namespace mindspore { namespace lite { STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx UnSqueezeParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axes") { @@ -31,11 +46,9 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx } } } - if (op != nullptr) { - op->primitive = std::make_unique(); - op->primitive->value.type = schema::PrimitiveType_Unsqueeze; - op->primitive->value.value = attr.release(); - } + + op->primitive->value.type = schema::PrimitiveType_Unsqueeze; + op->primitive->value.value = attr.release(); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h index bc96443fbee..10abcae3f76 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_UNSQUEEZE_PARSER_H -#define MS_ONNX_UNSQUEEZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxUnSqueezeParser : public OnnxNodeParser { public: OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_UNSQUEEZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc index 4b3425eae69..869cbd8c025 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h" +#include namespace mindspore { namespace lite { @@ -23,25 +23,35 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { MS_LOG(DEBUG) << "onnx UnusefulNodeParser"; - if (op != nullptr) { - op->primitive = std::make_unique(); - if (onnx_node.op_type() == "Int8Quantize") { - op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; - op->primitive->value.value = std::make_unique().release(); - } else if (onnx_node.op_type() == "Int8Dequantize") { - op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; - op->primitive->value.value = std::make_unique().release(); - } else { - // MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str()); - return RET_ERROR; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + if (onnx_node.op_type() == "Int8Quantize") { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; } - if (op->primitive->value.value == nullptr) { - // MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str()); - return RET_ERROR; + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; + op->primitive->value.value = attr.release(); + } else if (onnx_node.op_type() == "Int8Dequantize") { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; } + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + op->primitive->value.value = attr.release(); } else { - // MS_LOGE("Input opDef is nullptr"); - return RET_PARAM_INVALID; + MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); + return RET_ERROR; } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h index 418e33af38a..94cb3db72e4 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MS_ONNX_UNUSEFUL_PARSER_H -#define MS_ONNX_UNUSEFUL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" @@ -25,9 +25,12 @@ namespace lite { class OnnxUnusefulNodeParser : public OnnxNodeParser { public: OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {} - STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + STATUS Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore -#endif // MS_ONNX_UNUSEFUL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H