modify onnx parsers format

This commit is contained in:
lyvette 2020-08-21 10:34:43 +08:00
parent ec1cf059a7
commit 6313af4d2f
76 changed files with 1548 additions and 575 deletions

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_argmax_parser.h" #include "tools/converter/parser/onnx/onnx_argmax_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ArgMaxParser"; MS_LOG(DEBUG) << "onnx ArgMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>(); std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {
@ -32,11 +47,9 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
attr->keepDims = static_cast<bool>(onnx_node_attr.i()); attr->keepDims = static_cast<bool>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_ARGMAX_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxArgMaxParser : public OnnxNodeParser { class OnnxArgMaxParser : public OnnxNodeParser {
public: public:
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H #endif // MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H

View File

@ -14,114 +14,238 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" #include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AddParser"; MS_LOG(DEBUG) << "onnx AddParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SubParser"; MS_LOG(DEBUG) << "onnx SubParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MulParser"; MS_LOG(DEBUG) << "onnx MulParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DivParser"; MS_LOG(DEBUG) << "onnx DivParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PowParser"; MS_LOG(DEBUG) << "onnx PowParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::PowerT> attr(new schema::PowerT()); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EqualParser"; MS_LOG(DEBUG) << "onnx EqualParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LessParser"; MS_LOG(DEBUG) << "onnx LessParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx GreaterParser"; MS_LOG(DEBUG) << "onnx GreaterParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MinParser"; MS_LOG(DEBUG) << "onnx MinParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Min;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Min;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EltwiseParser"; MS_LOG(DEBUG) << "onnx EltwiseParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EltwiseT> attr = std::make_unique<schema::EltwiseT>(); std::unique_ptr<schema::EltwiseT> attr = std::make_unique<schema::EltwiseT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// there is no Prod in onnx // there is no Prod in onnx
if (onnx_node.op_type() == "Sum") { if (onnx_node.op_type() == "Sum") {
attr->mode = schema::EltwiseMode_SUM; attr->mode = schema::EltwiseMode_SUM;
@ -129,143 +253,303 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
attr->mode = schema::EltwiseMode_MAXIMUM; attr->mode = schema::EltwiseMode_MAXIMUM;
} }
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_Eltwise;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Eltwise;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }
STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx FloorParser"; MS_LOG(DEBUG) << "onnx FloorParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AbsParser"; MS_LOG(DEBUG) << "onnx AbsParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx NegParser"; MS_LOG(DEBUG) << "onnx NegParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpParser"; MS_LOG(DEBUG) << "onnx ExpParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CosParser"; MS_LOG(DEBUG) << "onnx CosParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SinParser"; MS_LOG(DEBUG) << "onnx SinParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SqrtParser"; MS_LOG(DEBUG) << "onnx SqrtParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CeilParser"; MS_LOG(DEBUG) << "onnx CeilParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LogParser"; MS_LOG(DEBUG) << "onnx LogParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanParser"; MS_LOG(DEBUG) << "onnx TanParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Tan;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Tan;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AtanParser"; MS_LOG(DEBUG) << "onnx AtanParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Atan;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Atan;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AsinParser"; MS_LOG(DEBUG) << "onnx AsinParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Asin;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Asin;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }
STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanhParser"; MS_LOG(DEBUG) << "onnx TanhParser";
if (op != nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "mslite don't support tanh now"; MS_LOG(ERROR) << "op is null";
return RET_ERROR; return RET_NULL_PTR;
} }
return RET_OK; op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(ERROR) << "mslite don't support tanh now";
return RET_ERROR;
} }
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -167,5 +167,5 @@ class OnnxTanhParser : public OnnxNodeParser {
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

View File

@ -19,10 +19,26 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BatchNormParser"; MS_LOG(DEBUG) << "onnx BatchNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>(); std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") { if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f(); attr->epsilon = onnx_node_attr.f();
@ -32,11 +48,9 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
attr->spatial = static_cast<int32_t>(onnx_node_attr.i()); attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_ADD_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#define MS_ONNX_ADD_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxBatchNormParser : public OnnxNodeParser { class OnnxBatchNormParser : public OnnxNodeParser {
public: public:
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ADD_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H

View File

@ -14,26 +14,36 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_biasadd_parser.h" #include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
#include <memory>
// using namespace mindspore::predict;
// using namespace onnx;
// using namespace std;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BiasAddParser"; MS_LOG(DEBUG) << "onnx BiasAddParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>(); std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// use channel dim as axis // use channel dim as axis
attr->axis = {1}; attr->axis = {1};
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.type = schema::PrimitiveType_BiasAdd; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_BIASADD_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#define MS_ONNX_BIASADD_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -26,9 +26,11 @@ class OnnxBiasAddParser : public OnnxNodeParser {
public: public:
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_BIASADD_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H

View File

@ -14,25 +14,40 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_cast_parser.h" #include "tools/converter/parser/onnx/onnx_cast_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CastParser"; MS_LOG(DEBUG) << "onnx CastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") { if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(onnx_node_attr.i()); attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CAST_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#define MS_ONNX_CAST_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxCastParser : public OnnxNodeParser { class OnnxCastParser : public OnnxNodeParser {
public: public:
OnnxCastParser() : OnnxNodeParser("Cast") {} OnnxCastParser() : OnnxNodeParser("Cast") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_CAST_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H

View File

@ -14,13 +14,25 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_clip_parser.h" #include "tools/converter/parser/onnx/onnx_clip_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ClipParser"; MS_LOG(DEBUG) << "onnx ClipParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
float min = -1, max = -1; float min = -1, max = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
@ -32,15 +44,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} }
if (min == 0 && max == 6) { if (min == 0 && max == 6) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
attr->type = schema::ActivationType_RELU6; if (attr == nullptr) {
if (op != nullptr) { MS_LOG(ERROR) << "new op failed";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} }
attr->type = schema::ActivationType_RELU6;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} else { } else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported"; MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_PARAM_INVALID; return RET_ERROR;
} }
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CLIP_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#define MS_ONNX_CLIP_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxClipParser : public OnnxNodeParser { class OnnxClipParser : public OnnxNodeParser {
public: public:
OnnxClipParser() : OnnxNodeParser("Clip") {} OnnxClipParser() : OnnxNodeParser("Clip") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_concat_parser.h" #include "tools/converter/parser/onnx/onnx_concat_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConcatParser"; MS_LOG(DEBUG) << "onnx ConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>(); std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i()); attr->axis = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CONCAT_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#define MS_ONNX_CONCAT_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConcatParser : public OnnxNodeParser { class OnnxConcatParser : public OnnxNodeParser {
public: public:
OnnxConcatParser() : OnnxNodeParser("Concat") {} OnnxConcatParser() : OnnxNodeParser("Concat") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_CONCAT_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_constant_parser.h" #include "tools/converter/parser/onnx/onnx_constant_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,12 +23,24 @@ STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser"; MS_LOG(DEBUG) << "onnx ConstantParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CONSTANT_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#define MS_ONNX_CONSTANT_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConstantParser : public OnnxNodeParser { class OnnxConstantParser : public OnnxNodeParser {
public: public:
OnnxConstantParser() : OnnxNodeParser("Constant") {} OnnxConstantParser() : OnnxNodeParser("Constant") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_CONSTANT_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H

View File

@ -14,21 +14,22 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) { bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) { if (attr == nullptr || attr->group != attr->channelIn) {
return false; return false;
} }
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>(); std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
if (depthwiseConv2DParam == nullptr) { if (depthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DepthwiseConv2DT failed"; MS_LOG(ERROR) << "new op failed";
return false; return false;
} }
depthwiseConv2DParam->format = attr->format; depthwiseConv2DParam->format = attr->format;
@ -47,15 +48,32 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
depthwiseConv2DParam->dilateH = attr->dilateH; depthwiseConv2DParam->dilateH = attr->dilateH;
depthwiseConv2DParam->hasBias = attr->hasBias; depthwiseConv2DParam->hasBias = attr->hasBias;
depthwiseConv2DParam->activationType = attr->activationType; depthwiseConv2DParam->activationType = attr->activationType;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = depthwiseConv2DParam.release(); op->primitive->value.value = depthwiseConv2DParam.release();
return true; return true;
} }
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser"; MS_LOG(DEBUG) << "onnx ConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>(); std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params // set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") { if (onnx_node_attr.name() == "group") {
@ -149,13 +167,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else { } else {
attr->activationType = schema::ActivationType_NO_ACTIVATION; attr->activationType = schema::ActivationType_NO_ACTIVATION;
} }
if (attr->group != 1) { if (attr->group != 1) {
if (!ParseGroupConvolution(attr, op)) { if (!ParseGroupConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR; return RET_ERROR;
} }
} else { } else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CONV_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#include <memory> #include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxConvParser : public OnnxNodeParser { class OnnxConvParser : public OnnxNodeParser {
public: public:
OnnxConvParser() : OnnxNodeParser("Conv") {} OnnxConvParser() : OnnxNodeParser("Conv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private: private:
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op); bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_CONV_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H

View File

@ -22,6 +22,7 @@ namespace lite {
OnnxConverter::OnnxConverter() { OnnxConverter::OnnxConverter() {
modelParser = new OnnxModelParser(); modelParser = new OnnxModelParser();
} }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_CONVERTER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#define MS_ONNX_CONVERTER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#include <string> #include <string>
#include <memory> #include <memory>
#include "tools/converter/converter.h" #include "tools/converter/converter.h"
@ -27,10 +28,10 @@ class OnnxConverter : public Converter {
public: public:
OnnxConverter(); OnnxConverter();
~OnnxConverter() override = default; ~OnnxConverter() = default;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_CONVERTER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H

View File

@ -14,21 +14,21 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) { bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
MS_LOG(DEBUG) << "onnx DeConvParser"; schema::CNodeT *op) {
if (attr == nullptr || attr->group != attr->channelOut) { if (attr == nullptr || attr->group != attr->channelOut) {
return false; return false;
} }
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
if (deDepthwiseConv2DParam == nullptr) { if (deDepthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed"; MS_LOG(WARNING) << "new op failed";
return false; return false;
} }
deDepthwiseConv2DParam->format = attr->format; deDepthwiseConv2DParam->format = attr->format;
@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
deDepthwiseConv2DParam->dilateH = attr->dilateH; deDepthwiseConv2DParam->dilateH = attr->dilateH;
deDepthwiseConv2DParam->hasBias = attr->hasBias; deDepthwiseConv2DParam->hasBias = attr->hasBias;
deDepthwiseConv2DParam->activationType = attr->activationType; deDepthwiseConv2DParam->activationType = attr->activationType;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; op->primitive->value.value = deDepthwiseConv2DParam.release();
op->primitive->value.value = deDepthwiseConv2DParam.release();
}
return true; return true;
} }
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>(); std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params // set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") { if (onnx_node_attr.name() == "group") {
attr->group = static_cast<int32_t>(onnx_node_attr.i()); attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") { } else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") { } else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") { } else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padMode = GetOnnxPadMode(onnx_node_attr); attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") { } else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) { if (onnx_node_attr.ints().size() != 4) {
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR; return RET_ERROR;
} }
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3)); attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
} else if (onnx_node_attr.name() == "strides") { } else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) { if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR; return RET_ERROR;
} }
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
if (onnx_node_attr.s() == "NHWC") { if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format_NHWC; attr->format = schema::Format_NHWC;
} else { } else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return RET_ERROR; return RET_ERROR;
} }
} }
@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) { if (nodeIter == onnx_graph.initializer().end()) {
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
return RET_ERROR; return RET_ERROR;
} }
std::vector<int> weight_shape; std::vector<int> weight_shape;
@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
return RET_ERROR; return RET_ERROR;
} }
} else { } else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_DECONV_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#define MS_ONNX_DECONV_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#include <memory> #include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxDeConvParser : public OnnxNodeParser { class OnnxDeConvParser : public OnnxNodeParser {
public: public:
OnnxDeConvParser() : OnnxNodeParser("DeConv") {} OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private: private:
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::CNodeT *op);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_DECONV_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h" #include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name(); const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") { if (attribute_name == "blocksize") {
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i()); attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxDepthToSpaceParser : public OnnxNodeParser { class OnnxDepthToSpaceParser : public OnnxNodeParser {
public: public:
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_dropout_parser.h" #include "tools/converter/parser/onnx/onnx_dropout_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DropoutParser"; MS_LOG(DEBUG) << "onnx DropoutParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>(); std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "ratio") { if (attribute_name == "ratio") {
attr->ratio = static_cast<int32_t>(onnx_node_attr.i()); attr->ratio = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.type = schema::PrimitiveType_Dropout; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_ARGMAX_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxDropoutParser : public OnnxNodeParser { class OnnxDropoutParser : public OnnxNodeParser {
public: public:
OnnxDropoutParser() : OnnxNodeParser("Dropout") {} OnnxDropoutParser() : OnnxNodeParser("Dropout") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H

View File

@ -14,25 +14,40 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_elu_parser.h" #include "tools/converter/parser/onnx/onnx_elu_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EluParser"; MS_LOG(DEBUG) << "onnx EluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>(); std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name(); const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "alpha") { if (attribute_name == "alpha") {
attr->alpha = onnx_node_attr.f(); attr->alpha = onnx_node_attr.f();
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.type = schema::PrimitiveType_Elu; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_ELU_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
#define MS_ONNX_ELU_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxEluParser : public OnnxNodeParser { class OnnxEluParser : public OnnxNodeParser {
public: public:
OnnxEluParser() : OnnxNodeParser("Elu") {} OnnxEluParser() : OnnxNodeParser("Elu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ELU_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H

View File

@ -14,20 +14,32 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_expand_parser.h" #include "tools/converter/parser/onnx/onnx_expand_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpandParser"; MS_LOG(DEBUG) << "onnx ExpandParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Broadcast;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Broadcast;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_EXPAND_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
#define MS_ONNX_EXPAND_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxExpandParser : public OnnxNodeParser { class OnnxExpandParser : public OnnxNodeParser {
public: public:
OnnxExpandParser() : OnnxNodeParser("Expand") {} OnnxExpandParser() : OnnxNodeParser("Expand") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_EXPAND_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_flatten_parser.h" #include "tools/converter/parser/onnx/onnx_flatten_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx FlattenParser"; MS_LOG(DEBUG) << "onnx FlattenParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
int axis = 1; int axis = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
@ -36,11 +51,8 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
} }
attr->shape.emplace_back(-1); attr->shape.emplace_back(-1);
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_FLATTEN_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
#define MS_ONNX_FLATTEN_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -26,9 +26,11 @@ class OnnxFlattenParser : public OnnxNodeParser {
public: public:
OnnxFlattenParser() : OnnxNodeParser("Fatten") {} OnnxFlattenParser() : OnnxNodeParser("Fatten") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_FLATTEN_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_gather_parser.h" #include "tools/converter/parser/onnx/onnx_gather_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx GatherParser"; MS_LOG(DEBUG) << "onnx GatherParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>(); std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name(); const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i()); attr->axis = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.type = schema::PrimitiveType_Gather; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_GATHER_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
#define MS_ONNX_GATHER_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxGatherParser : public OnnxNodeParser { class OnnxGatherParser : public OnnxNodeParser {
public: public:
OnnxGatherParser() : OnnxNodeParser("Gather") {} OnnxGatherParser() : OnnxNodeParser("Gather") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_GATHER_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H

View File

@ -14,31 +14,69 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_lrn_parser.h" #include "tools/converter/parser/onnx/onnx_lrn_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LrnParser"; MS_LOG(DEBUG) << "onnx LrnParser";
std::unique_ptr<schema::LrnT> attr = std::make_unique<schema::LrnT>(); if (op == nullptr) {
for (const auto &onnx_node_attr : onnx_node.attribute()) { MS_LOG(ERROR) << "op is null";
const auto& attribute_name = onnx_node_attr.name(); return RET_NULL_PTR;
if (attribute_name == "size") {
attr->size = static_cast<int32_t>(onnx_node_attr.i());
} else if (attribute_name == "alpha") {
attr->alpha = onnx_node_attr.f();
} else if (attribute_name == "beta") {
attr->beta = onnx_node_attr.f();
} else if (attribute_name == "bias") {
attr->bias = onnx_node_attr.f();
}
} }
if (op != nullptr) { op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive = std::make_unique<schema::PrimitiveT>(); if (op->primitive == nullptr) {
op->primitive->value.type = schema::PrimitiveType_Lrn; MS_LOG(ERROR) << "op->primitive is null";
op->primitive->value.value = attr.release(); return RET_NULL_PTR;
} }
std::unique_ptr<schema::LocalResponseNormalizationT> attr
= std::make_unique<schema::LocalResponseNormalizationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
auto onnx_node_attr = onnx_node.attribute().at(0);
int32_t size = 0;
if (onnx_node_attr.name() == "size") {
size = static_cast<int32_t>(onnx_node_attr.i());
attr->depth_radius = static_cast<int32_t>(size / 2);
} else {
MS_LOG(ERROR) << "the first attr is not size";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(1);
if (onnx_node_attr.name() == "alpha") {
auto alpha = onnx_node_attr.f();
attr->alpha = alpha / size;
} else {
MS_LOG(ERROR) << "the second attr is not alpha";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(2);
if (onnx_node_attr.name() == "beta") {
attr->beta = onnx_node_attr.f();
} else {
MS_LOG(ERROR) << "the third attr is not beta";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(3);
if (onnx_node_attr.name() == "bias") {
attr->bias = onnx_node_attr.f();
} else {
MS_LOG(ERROR) << "the third attr is not bias";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_LRN_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
#define MS_ONNX_LRN_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxLrnParser : public OnnxNodeParser { class OnnxLrnParser : public OnnxNodeParser {
public: public:
OnnxLrnParser() : OnnxNodeParser("Lrn") {} OnnxLrnParser() : OnnxNodeParser("Lrn") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_LRN_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H

View File

@ -14,15 +14,31 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_matmul_parser.h" #include "tools/converter/parser/onnx/onnx_matmul_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MatMulParser"; MS_LOG(DEBUG) << "onnx MatMulParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -39,14 +55,11 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
} }
if (alpha != 1 || beta != 1) { if (alpha != 1 || beta != 1) {
MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
return RET_PARAM_INVALID; return RET_ERROR;
} }
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_MatMul;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_MatMul;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_MATMUL_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
#define MS_ONNX_MATMUL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxMatmulParser : public OnnxNodeParser { class OnnxMatmulParser : public OnnxNodeParser {
public: public:
OnnxMatmulParser() : OnnxNodeParser("MatMul") {} OnnxMatmulParser() : OnnxNodeParser("MatMul") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_MATMUL_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H

View File

@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include <cfloat> #include <cfloat>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "src/common/utils.h" #include "src/common/utils.h"
@ -54,7 +54,8 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
return dims; return dims;
} }
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile,
google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
#ifdef _WIN32 #ifdef _WIN32
if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) { if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
@ -81,7 +82,8 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache) {
MS_LOG(DEBUG) << "set onnx constant tensors"; MS_LOG(DEBUG) << "set onnx constant tensors";
for (const auto &onnx_const_value : onnx_graph.initializer()) { for (const auto &onnx_const_value : onnx_graph.initializer()) {
int index; int index;
@ -117,8 +119,11 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
TensorCache *tensor_cache, int *index) { const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
if (data_type == kTypeUnknown) { if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " MS_LOG(ERROR) << "not support onnx data type "
@ -137,8 +142,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
*index = tensor_cache->AddTensor(name, tensor.release(), type); *index = tensor_cache->AddTensor(name, tensor.release(), type);
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index) { STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
if (data_type == kTypeUnknown) { if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
@ -165,7 +174,8 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &input_value : onnx_graph.input()) { for (const auto &input_value : onnx_graph.input()) {
auto ret = tensor_cache->FindTensor(input_value.name()); auto ret = tensor_cache->FindTensor(input_value.name());
@ -182,7 +192,8 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &output_value : onnx_graph.output()) { for (const auto &output_value : onnx_graph.output()) {
int index; int index;
@ -196,8 +207,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph, TensorCache *tensor_cache) { const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>(); std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
@ -218,7 +231,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
graph->nodes.emplace_back(std::move(dst_op_2)); graph->nodes.emplace_back(std::move(dst_op_2));
} }
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache) {
// convert GivenTensorFill node to a weight/bias tensor // convert GivenTensorFill node to a weight/bias tensor
auto ret = tensor_cache->FindTensor(onnx_node.output(0)); auto ret = tensor_cache->FindTensor(onnx_node.output(0));
if (ret < 0) { if (ret < 0) {
@ -270,8 +284,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
// change op_type() to name(), that is unique // change op_type() to name(), that is unique
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@ -303,8 +319,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
return RET_OK; return RET_OK;
} }
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) {
MS_ASSERT(dst_op != nullptr); MS_ASSERT(dst_op != nullptr);
MS_ASSERT(tensor_cache != nullptr); MS_ASSERT(tensor_cache != nullptr);
std::vector<string> quant_node_name; std::vector<string> quant_node_name;
@ -361,8 +380,10 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
} }
} }
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
const string &onnx_op_type, schema::CNodeT *dst_op) { const onnx::NodeProto &onnx_node,
const string &onnx_op_type,
schema::CNodeT *dst_op) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
@ -371,8 +392,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
return node_parser->Parse(onnx_graph, onnx_node, dst_op); return node_parser->Parse(onnx_graph, onnx_node, dst_op);
} }
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache) {
for (const auto &onnx_node_input : node_inputs) { for (const auto &onnx_node_input : node_inputs) {
auto index = tensor_cache->FindTensor(onnx_node_input); auto index = tensor_cache->FindTensor(onnx_node_input);
if (index < 0) { if (index < 0) {
@ -385,7 +408,8 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs,
schema::CNodeT *dst_op,
TensorCache *tensor_cache) { TensorCache *tensor_cache) {
for (const auto &onnx_node_output : node_outputs) { for (const auto &onnx_node_output : node_outputs) {
auto index = tensor_cache->FindTensor(onnx_node_output); auto index = tensor_cache->FindTensor(onnx_node_output);
@ -400,7 +424,8 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
schema::TensorT *tensor) {
size_t data_count = 1; size_t data_count = 1;
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
size_t data_size = 0; size_t data_size = 0;
@ -459,7 +484,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
return RET_OK; return RET_OK;
} }
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache,
schema::MetaGraphT *graphDef) {
std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor(); std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
for (auto iter : tensors) { for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter); std::unique_ptr<schema::TensorT> temp(iter);
@ -481,7 +507,8 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
} }
} }
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile, MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile,
const std::string &weightFile,
const QuantType &quantType) { const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_MODEL_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
#define MS_ONNX_MODEL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
#include <google/protobuf/io/coded_stream.h> #include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
@ -37,35 +37,83 @@ namespace lite {
class OnnxModelParser : public ModelParser { class OnnxModelParser : public ModelParser {
public: public:
OnnxModelParser(); OnnxModelParser();
virtual ~OnnxModelParser(); virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override; const QuantType &quantType = QuantType_QUANT_NONE) override;
private: private:
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); STATUS ReadOnnxModelFromBinary(const std::string &modelFile,
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); google::protobuf::Message *model_proto);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache, int *index); TensorCache *tensor_cache);
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index); STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::MetaGraphT *graph,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache); TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache); STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); schema::MetaGraphT *graph,
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
const string &onnx_op_type, schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, STATUS AddValueInfo(const onnx::ValueInfoProto &proto,
schema::TensorT *dst_tensor, TensorCache *tensor_cache); const std::string &name,
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, const TensorType &type,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); TensorCache *tensor_cache,
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); int *index);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); STATUS AddTensorProto(const onnx::TensorProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
const string &onnx_op_type,
schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs,
schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs,
schema::CNodeT *dst_op,
TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value,
schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache,
schema::MetaGraphT *graphDef);
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
private: private:
@ -75,4 +123,4 @@ class OnnxModelParser : public ModelParser {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_MODEL_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H

View File

@ -15,6 +15,7 @@
*/ */
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <vector>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -30,6 +31,20 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_
return schema::PadMode_NOTSET; return schema::PadMode_NOTSET;
} }
} }
void OnnxNodeParser::Split(const std::string &src_str,
std::vector<std::string> *dst_str,
const std::string &chr) {
std::string ::size_type p1 = 0, p2 = src_str.find(chr);
while (std::string::npos != p2) {
dst_str->push_back(src_str.substr(p1, p2 - p1));
p1 = p2 + chr.size();
p2 = src_str.find(chr, p1);
}
if (p1 != src_str.length()) {
dst_str->push_back(src_str.substr(p1));
}
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -14,10 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_NODE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
#define MS_ONNX_NODE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
#include <string> #include <string>
#include <vector>
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/parser/onnx/onnx.pb.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -29,14 +30,23 @@ namespace lite {
class OnnxNodeParser { class OnnxNodeParser {
public: public:
explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {} explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {}
virtual ~OnnxNodeParser() = default; virtual ~OnnxNodeParser() = default;
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
virtual STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) = 0;
protected: protected:
schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
void Split(const std::string &src_str,
std::vector<std::string> *dst_str,
const std::string &chr);
const std::string &name; const std::string &name;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_NODE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H

View File

@ -21,7 +21,14 @@ namespace mindspore {
namespace lite { namespace lite {
OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default; OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default;
OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default; OnnxNodeParserRegistry::~OnnxNodeParserRegistry() {
for (auto ite : parsers) {
if (ite.second != nullptr) {
delete ite.second;
ite.second = nullptr;
}
}
}
OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() { OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() {
static OnnxNodeParserRegistry instance; static OnnxNodeParserRegistry instance;

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_OP_REGISTRY_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
#define MS_ONNX_OP_REGISTRY_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -30,6 +30,7 @@ class OnnxNodeParserRegistry {
virtual ~OnnxNodeParserRegistry(); virtual ~OnnxNodeParserRegistry();
static OnnxNodeParserRegistry *GetInstance(); static OnnxNodeParserRegistry *GetInstance();
OnnxNodeParser *GetNodeParser(const std::string &name); OnnxNodeParser *GetNodeParser(const std::string &name);
std::unordered_map<std::string, OnnxNodeParser *> parsers; std::unordered_map<std::string, OnnxNodeParser *> parsers;
@ -37,12 +38,13 @@ class OnnxNodeParserRegistry {
class OnnxNodeRegistrar { class OnnxNodeRegistrar {
public: public:
OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) { OnnxNodeRegistrar(const std::string &name,
OnnxNodeParser *parser) {
OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser; OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser;
} }
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_OP_REGISTRY_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H

View File

@ -14,14 +14,31 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_pad_parser.h" #include "tools/converter/parser/onnx/onnx_pad_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PadParser"; MS_LOG(DEBUG) << "onnx PadParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>(); std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "pads") { if (attribute_name == "pads") {
@ -42,11 +59,9 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
} }
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.type = schema::PrimitiveType_Pad; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_LRN_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
#define MS_ONNX_LRN_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxPadParser : public OnnxNodeParser { class OnnxPadParser : public OnnxNodeParser {
public: public:
OnnxPadParser() : OnnxNodeParser("Pad") {} OnnxPadParser() : OnnxNodeParser("Pad") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_LRN_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H

View File

@ -14,14 +14,30 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_pool_parser.h" #include "tools/converter/parser/onnx/onnx_pool_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PoolParser"; MS_LOG(DEBUG) << "onnx PoolParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>(); std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
const auto &pool_type = onnx_node.op_type(); const auto &pool_type = onnx_node.op_type();
@ -79,11 +95,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_ERROR; return RET_ERROR;
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.type = schema::PrimitiveType_Pooling; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_POOL_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
#define MS_ONNX_POOL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxPoolParser : public OnnxNodeParser { class OnnxPoolParser : public OnnxNodeParser {
public: public:
OnnxPoolParser() : OnnxNodeParser("Pool") {} OnnxPoolParser() : OnnxNodeParser("Pool") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_POOL_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_reduce_parser.h" #include "tools/converter/parser/onnx/onnx_reduce_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReduceParser"; MS_LOG(DEBUG) << "onnx ReduceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") { if (attribute_name == "axes") {
@ -45,13 +60,12 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
} else if (type == "ReduceSum") { } else if (type == "ReduceSum") {
attr->mode = schema::ReduceMode_ReduceSum; attr->mode = schema::ReduceMode_ReduceSum;
} else { } else {
// MS_LOGE("unsupoort type"); MS_LOG(ERROR) << "unsupported type";
} return RET_ERROR;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
} }
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_REDUCE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
#define MS_ONNX_REDUCE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxReduceParser : public OnnxNodeParser { class OnnxReduceParser : public OnnxNodeParser {
public: public:
OnnxReduceParser() : OnnxNodeParser("Reduce") {} OnnxReduceParser() : OnnxNodeParser("Reduce") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_REDUCE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H

View File

@ -14,36 +14,64 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReluParser"; MS_LOG(DEBUG) << "onnx ReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &relu_type = onnx_node.op_type(); const auto &relu_type = onnx_node.op_type();
if (relu_type == "Relu") { if (relu_type == "Relu") {
MS_LOG(DEBUG) << "onnx ReluParser";
attr->type = schema::ActivationType_RELU; attr->type = schema::ActivationType_RELU;
} else if (relu_type == "LeakyRelu") { } else if (relu_type == "LeakyRelu") {
MS_LOG(DEBUG) << "onnx LeakyReluParser";
attr->type = schema::ActivationType_LEAKY_RELU; attr->type = schema::ActivationType_LEAKY_RELU;
} }
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PReluParser"; MS_LOG(DEBUG) << "onnx PReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
if (onnx_node.input_size() != 2) { if (onnx_node.input_size() != 2) {
MS_LOG(ERROR) << "input num is not 2"; MS_LOG(ERROR) << "input num should be 2";
return RET_PARAM_INVALID; return RET_ERROR;
} }
std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>(); std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
std::vector<onnx::TensorProto> params; std::vector<onnx::TensorProto> params;
@ -57,8 +85,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
const onnx::TensorProto *slope = &params[0]; const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) { if (slope == nullptr) {
MS_LOG(ERROR) << "input error"; MS_LOG(ERROR) << "input error: params[0] is null";
return RET_PARAM_INVALID; return RET_ERROR;
} }
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data()); const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float); const int64_t slope_size = slope->raw_data().size() / sizeof(float);
@ -74,11 +102,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
} }
} }
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_RELU_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
#define MS_ONNX_RELU_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,7 +25,10 @@ namespace lite {
class OnnxReluParser : public OnnxNodeParser { class OnnxReluParser : public OnnxNodeParser {
public: public:
OnnxReluParser() : OnnxNodeParser("Relu") {} OnnxReluParser() : OnnxNodeParser("Relu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
class OnnxLeakeyReluParser : public OnnxReluParser { class OnnxLeakeyReluParser : public OnnxReluParser {
@ -36,9 +39,12 @@ class OnnxLeakeyReluParser : public OnnxReluParser {
class OnnxPReluParser : public OnnxNodeParser { class OnnxPReluParser : public OnnxNodeParser {
public: public:
OnnxPReluParser() : OnnxNodeParser("Prelu") {} OnnxPReluParser() : OnnxNodeParser("Prelu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_RELU_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H

View File

@ -14,16 +14,32 @@
* limitations under the License. * limitations under the License.
*/ */
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReshapeParser"; MS_LOG(DEBUG) << "onnx ReshapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
std::vector<onnx::TensorProto> params; std::vector<onnx::TensorProto> params;
for (int i = 0; i < onnx_node.input_size(); ++i) { for (int i = 0; i < onnx_node.input_size(); ++i) {
@ -40,18 +56,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
} else { } else {
if (params.size() != 1) { if (params.size() != 1) {
MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1"; MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
return RET_PARAM_INVALID; return RET_ERROR;
} }
for (int i = 0; i < params[0].int64_data_size(); ++i) { for (int i = 0; i < params[0].int64_data_size(); ++i) {
attr->shape.emplace_back(params[0].int64_data(i)); attr->shape.emplace_back(params[0].int64_data(i));
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.type = schema::PrimitiveType_Reshape; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_RESHAPE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
#define MS_ONNX_RESHAPE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxReshapeParser : public OnnxNodeParser { class OnnxReshapeParser : public OnnxNodeParser {
public: public:
OnnxReshapeParser() : OnnxNodeParser("Reshape") {} OnnxReshapeParser() : OnnxNodeParser("Reshape") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_RESHAPE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_shape_parser.h" #include "tools/converter/parser/onnx/onnx_shape_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,12 +23,24 @@ STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ShapeParser"; MS_LOG(DEBUG) << "onnx ShapeParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SHAPE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
#define MS_ONNX_SHAPE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxShapeParser : public OnnxNodeParser { class OnnxShapeParser : public OnnxNodeParser {
public: public:
OnnxShapeParser() : OnnxNodeParser("Shape") {} OnnxShapeParser() : OnnxNodeParser("Shape") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SHAPE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_sigmoid_parser.h" #include "tools/converter/parser/onnx/onnx_sigmoid_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,13 +23,26 @@ STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SigmoidParser"; MS_LOG(DEBUG) << "onnx SigmoidParser";
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); if (op == nullptr) {
attr->type = schema::ActivationType_SIGMOID; MS_LOG(ERROR) << "op is null";
if (op != nullptr) { return RET_NULL_PTR;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_SIGMOID;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SIGMOID_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
#define MS_ONNX_SIGMOID_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSigmoidParser : public OnnxNodeParser { class OnnxSigmoidParser : public OnnxNodeParser {
public: public:
OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SIGMOID_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H

View File

@ -14,15 +14,30 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_slice_parser.h" #include "tools/converter/parser/onnx/onnx_slice_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser"; MS_LOG(DEBUG) << "onnx SliceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>(); std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "starts") { if (attribute_name == "starts") {
@ -38,11 +53,9 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
} }
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SLICE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#define MS_ONNX_SLICE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSliceParser : public OnnxNodeParser { class OnnxSliceParser : public OnnxNodeParser {
public: public:
OnnxSliceParser() : OnnxNodeParser("Slice") {} OnnxSliceParser() : OnnxNodeParser("Slice") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SLICE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_softmax_parser.h" #include "tools/converter/parser/onnx/onnx_softmax_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SoftMaxParser"; MS_LOG(DEBUG) << "onnx SoftMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>(); std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name(); const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") { if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i()); attr->axis = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_SoftMax;
op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SOFTMAX_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
#define MS_ONNX_SOFTMAX_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSoftMaxParser : public OnnxNodeParser { class OnnxSoftMaxParser : public OnnxNodeParser {
public: public:
OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SOFTMAX_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H

View File

@ -14,26 +14,40 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h" #include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>(); std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") { if (attribute_name == "blocksize") {
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i()); attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSpaceToDepthParser : public OnnxNodeParser { class OnnxSpaceToDepthParser : public OnnxNodeParser {
public: public:
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_squeeze_parser.h" #include "tools/converter/parser/onnx/onnx_squeeze_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SqueezeParser"; MS_LOG(DEBUG) << "onnx SqueezeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>(); std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") { if (attribute_name == "axes") {
@ -32,11 +47,9 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
} }
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Squeeze;
op->primitive->value.type = schema::PrimitiveType_Squeeze; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_SQUEEZE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
#define MS_ONNX_SQUEEZE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSqueezeParser : public OnnxNodeParser { class OnnxSqueezeParser : public OnnxNodeParser {
public: public:
OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_SQUEEZE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H

View File

@ -14,19 +14,33 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_tile_parser.h" #include "tools/converter/parser/onnx/onnx_tile_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TileParser"; MS_LOG(DEBUG) << "onnx TileParser";
if (op != nullptr) { if (op == nullptr) {
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>(); MS_LOG(ERROR) << "op is null";
op->primitive = std::make_unique<schema::PrimitiveT>(); return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
} }
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_TILE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
#define MS_ONNX_TILE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxTileParser : public OnnxNodeParser { class OnnxTileParser : public OnnxNodeParser {
public: public:
OnnxTileParser() : OnnxNodeParser("Tile") {} OnnxTileParser() : OnnxNodeParser("Tile") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_transpose_parser.h" #include "tools/converter/parser/onnx/onnx_transpose_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TransposeParser"; MS_LOG(DEBUG) << "onnx TransposeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TransposeT> attr = std::make_unique<schema::TransposeT>(); std::unique_ptr<schema::TransposeT> attr = std::make_unique<schema::TransposeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->conjugate = false; attr->conjugate = false;
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
@ -40,11 +55,9 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
} }
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Transpose;
op->primitive->value.type = schema::PrimitiveType_Transpose; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_TRANSPOSE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
#define MS_ONNX_TRANSPOSE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxTransposeParser : public OnnxNodeParser { class OnnxTransposeParser : public OnnxNodeParser {
public: public:
OnnxTransposeParser() : OnnxNodeParser("Transpose") {} OnnxTransposeParser() : OnnxNodeParser("Transpose") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_TRANSPOSE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H

View File

@ -23,7 +23,22 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UpsampleParser"; MS_LOG(DEBUG) << "onnx UpsampleParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::UpsampleT> attr = std::make_unique<schema::UpsampleT>(); std::unique_ptr<schema::UpsampleT> attr = std::make_unique<schema::UpsampleT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") { if (attribute_name == "mode") {
@ -34,12 +49,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
} }
} }
} }
// to do
if (op != nullptr) { op->primitive->value.type = schema::PrimitiveType_Upsample;
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Upsample;
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_UPSAMPLE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
#define MS_ONNX_UPSAMPLE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUpsampleParser : public OnnxNodeParser { class OnnxUpsampleParser : public OnnxNodeParser {
public: public:
OnnxUpsampleParser() : OnnxNodeParser("Upsample") {} OnnxUpsampleParser() : OnnxNodeParser("Upsample") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_UPSAMPLE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H

View File

@ -14,15 +14,30 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h" #include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnSqueezeParser"; MS_LOG(DEBUG) << "onnx UnSqueezeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::UnsqueezeT> attr = std::make_unique<schema::UnsqueezeT>(); std::unique_ptr<schema::UnsqueezeT> attr = std::make_unique<schema::UnsqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") { if (attribute_name == "axes") {
@ -31,11 +46,9 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
} }
} }
} }
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive->value.type = schema::PrimitiveType_Unsqueeze;
op->primitive->value.type = schema::PrimitiveType_Unsqueeze; op->primitive->value.value = attr.release();
op->primitive->value.value = attr.release();
}
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_UNSQUEEZE_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
#define MS_ONNX_UNSQUEEZE_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUnSqueezeParser : public OnnxNodeParser { class OnnxUnSqueezeParser : public OnnxNodeParser {
public: public:
OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {} OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_UNSQUEEZE_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <memory>
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h" #include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
#include <memory>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -23,25 +23,35 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) { schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnusefulNodeParser"; MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
if (op != nullptr) { if (op == nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); MS_LOG(ERROR) << "op is null";
if (onnx_node.op_type() == "Int8Quantize") { return RET_NULL_PTR;
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; }
op->primitive->value.value = std::make_unique<schema::OnnxInt8QuantizeT>().release(); op->primitive = std::make_unique<schema::PrimitiveT>();
} else if (onnx_node.op_type() == "Int8Dequantize") { if (op->primitive == nullptr) {
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; MS_LOG(ERROR) << "op->primitive is null";
op->primitive->value.value = std::make_unique<schema::OnnxInt8DequantizeT>().release(); return RET_NULL_PTR;
} else { }
// MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str());
return RET_ERROR; if (onnx_node.op_type() == "Int8Quantize") {
std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
} }
if (op->primitive->value.value == nullptr) { op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
// MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str()); op->primitive->value.value = attr.release();
return RET_ERROR; } else if (onnx_node.op_type() == "Int8Dequantize") {
std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
} }
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
op->primitive->value.value = attr.release();
} else { } else {
// MS_LOGE("Input opDef is nullptr"); MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
return RET_PARAM_INVALID; return RET_ERROR;
} }
return RET_OK; return RET_OK;
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MS_ONNX_UNUSEFUL_PARSER_H #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#define MS_ONNX_UNUSEFUL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUnusefulNodeParser : public OnnxNodeParser { class OnnxUnusefulNodeParser : public OnnxNodeParser {
public: public:
OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {} OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MS_ONNX_UNUSEFUL_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H