modify onnx parsers format

This commit is contained in:
lyvette 2020-08-21 10:34:43 +08:00
parent ec1cf059a7
commit 6313af4d2f
76 changed files with 1548 additions and 575 deletions

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_argmax_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ArgMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
@ -32,11 +47,9 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxArgMaxParser : public OnnxNodeParser {
public:
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H

View File

@ -14,114 +14,238 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AddParser";
if (op != nullptr) {
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SubParser";
if (op != nullptr) {
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sub;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MulParser";
if (op != nullptr) {
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Mul;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DivParser";
if (op != nullptr) {
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Div;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PowParser";
if (op != nullptr) {
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EqualParser";
if (op != nullptr) {
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Equal;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LessParser";
if (op != nullptr) {
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Less;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx GreaterParser";
if (op != nullptr) {
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Greater;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MinParser";
if (op != nullptr) {
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Min;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Min;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EltwiseParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EltwiseT> attr = std::make_unique<schema::EltwiseT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// there is no Prod in onnx
if (onnx_node.op_type() == "Sum") {
attr->mode = schema::EltwiseMode_SUM;
@ -129,143 +253,303 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
attr->mode = schema::EltwiseMode_MAXIMUM;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Eltwise;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Eltwise;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx FloorParser";
if (op != nullptr) {
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AbsParser";
if (op != nullptr) {
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx NegParser";
if (op != nullptr) {
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpParser";
if (op != nullptr) {
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CosParser";
if (op != nullptr) {
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Cos;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SinParser";
if (op != nullptr) {
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sin;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SqrtParser";
if (op != nullptr) {
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Sqrt;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CeilParser";
if (op != nullptr) {
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LogParser";
if (op != nullptr) {
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Log;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanParser";
if (op != nullptr) {
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Tan;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Tan;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AtanParser";
if (op != nullptr) {
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Atan;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Atan;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AsinParser";
if (op != nullptr) {
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Asin;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Asin;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanhParser";
if (op != nullptr) {
MS_LOG(ERROR) << "mslite don't support tanh now";
return RET_ERROR;
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
return RET_OK;
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(ERROR) << "mslite don't support tanh now";
return RET_ERROR;
}
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -167,5 +167,5 @@ class OnnxTanhParser : public OnnxNodeParser {
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

View File

@ -19,10 +19,26 @@
namespace mindspore {
namespace lite {
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BatchNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f();
@ -32,11 +48,9 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ADD_PARSER_H
#define MS_ONNX_ADD_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxBatchNormParser : public OnnxNodeParser {
public:
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ADD_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H

View File

@ -14,26 +14,36 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
#include <memory>
// using namespace mindspore::predict;
// using namespace onnx;
// using namespace std;
namespace mindspore {
namespace lite {
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx BiasAddParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// use channel dim as axis
attr->axis = {1};
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_BIASADD_PARSER_H
#define MS_ONNX_BIASADD_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -26,9 +26,11 @@ class OnnxBiasAddParser : public OnnxNodeParser {
public:
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_BIASADD_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H

View File

@ -14,25 +14,40 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx CastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CAST_PARSER_H
#define MS_ONNX_CAST_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxCastParser : public OnnxNodeParser {
public:
OnnxCastParser() : OnnxNodeParser("Cast") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CAST_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H

View File

@ -14,13 +14,25 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_clip_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ClipParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
float min = -1, max = -1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@ -32,15 +44,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
}
if (min == 0 && max == 6) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
attr->type = schema::ActivationType_RELU6;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_RELU6;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_PARAM_INVALID;
return RET_ERROR;
}
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CLIP_PARSER_H
#define MS_ONNX_CLIP_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxClipParser : public OnnxNodeParser {
public:
OnnxClipParser() : OnnxNodeParser("Clip") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_concat_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONCAT_PARSER_H
#define MS_ONNX_CONCAT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConcatParser : public OnnxNodeParser {
public:
OnnxConcatParser() : OnnxNodeParser("Concat") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONCAT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_constant_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,12 +23,24 @@ STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConstantParser";
if (op != nullptr) {
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Constant;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONSTANT_PARSER_H
#define MS_ONNX_CONSTANT_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxConstantParser : public OnnxNodeParser {
public:
OnnxConstantParser() : OnnxNodeParser("Constant") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONSTANT_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H

View File

@ -14,21 +14,22 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
namespace mindspore {
namespace lite {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) {
return false;
}
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
if (depthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
MS_LOG(ERROR) << "new op failed";
return false;
}
depthwiseConv2DParam->format = attr->format;
@ -47,15 +48,32 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
depthwiseConv2DParam->dilateH = attr->dilateH;
depthwiseConv2DParam->hasBias = attr->hasBias;
depthwiseConv2DParam->activationType = attr->activationType;
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
op->primitive->value.value = depthwiseConv2DParam.release();
return true;
}
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
@ -149,13 +167,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
if (attr->group != 1) {
if (!ParseGroupConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxConvParser : public OnnxNodeParser {
public:
OnnxConvParser() : OnnxNodeParser("Conv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private:
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H

View File

@ -22,6 +22,7 @@ namespace lite {
OnnxConverter::OnnxConverter() {
modelParser = new OnnxModelParser();
}
} // namespace lite
} // namespace mindspore

View File

@ -14,8 +14,9 @@
* limitations under the License.
*/
#ifndef MS_ONNX_CONVERTER_H
#define MS_ONNX_CONVERTER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
#include <string>
#include <memory>
#include "tools/converter/converter.h"
@ -27,10 +28,10 @@ class OnnxConverter : public Converter {
public:
OnnxConverter();
~OnnxConverter() override = default;
~OnnxConverter() = default;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONVERTER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H

View File

@ -14,21 +14,21 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
namespace mindspore {
namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::CNodeT *op) {
if (attr == nullptr || attr->group != attr->channelOut) {
return false;
}
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
if (deDepthwiseConv2DParam == nullptr) {
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed";
MS_LOG(WARNING) << "new op failed";
return false;
}
deDepthwiseConv2DParam->format = attr->format;
@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
deDepthwiseConv2DParam->dilateH = attr->dilateH;
deDepthwiseConv2DParam->hasBias = attr->hasBias;
deDepthwiseConv2DParam->activationType = attr->activationType;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
}
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
op->primitive->value.value = deDepthwiseConv2DParam.release();
return true;
}
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return RET_ERROR;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return RET_ERROR;
}
}
@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) {
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
return RET_ERROR;
}
std::vector<int> weight_shape;
@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_DECONV_PARSER_H
#define MS_ONNX_DECONV_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
@ -26,11 +26,16 @@ namespace lite {
class OnnxDeConvParser : public OnnxNodeParser {
public:
OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
private:
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_DECONV_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxDepthToSpaceParser : public OnnxNodeParser {
public:
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_dropout_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DropoutParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "ratio") {
attr->ratio = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Dropout;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ARGMAX_PARSER_H
#define MS_ONNX_ARGMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxDropoutParser : public OnnxNodeParser {
public:
OnnxDropoutParser() : OnnxNodeParser("Dropout") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H

View File

@ -14,25 +14,40 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_elu_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "alpha") {
attr->alpha = onnx_node_attr.f();
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_ELU_PARSER_H
#define MS_ONNX_ELU_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxEluParser : public OnnxNodeParser {
public:
OnnxEluParser() : OnnxNodeParser("Elu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ELU_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H

View File

@ -14,20 +14,32 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_expand_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ExpandParser";
if (op != nullptr) {
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Broadcast;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Broadcast;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_EXPAND_PARSER_H
#define MS_ONNX_EXPAND_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxExpandParser : public OnnxNodeParser {
public:
OnnxExpandParser() : OnnxNodeParser("Expand") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_EXPAND_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_flatten_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx FlattenParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
int axis = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@ -36,11 +51,8 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
}
attr->shape.emplace_back(-1);
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_FLATTEN_PARSER_H
#define MS_ONNX_FLATTEN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -26,9 +26,11 @@ class OnnxFlattenParser : public OnnxNodeParser {
public:
OnnxFlattenParser() : OnnxNodeParser("Fatten") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_FLATTEN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_gather_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx GatherParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Gather;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_GATHER_PARSER_H
#define MS_ONNX_GATHER_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxGatherParser : public OnnxNodeParser {
public:
OnnxGatherParser() : OnnxNodeParser("Gather") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_GATHER_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H

View File

@ -14,31 +14,69 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_lrn_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx LrnParser";
std::unique_ptr<schema::LrnT> attr = std::make_unique<schema::LrnT>();
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "size") {
attr->size = static_cast<int32_t>(onnx_node_attr.i());
} else if (attribute_name == "alpha") {
attr->alpha = onnx_node_attr.f();
} else if (attribute_name == "beta") {
attr->beta = onnx_node_attr.f();
} else if (attribute_name == "bias") {
attr->bias = onnx_node_attr.f();
}
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Lrn;
op->primitive->value.value = attr.release();
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LocalResponseNormalizationT> attr
= std::make_unique<schema::LocalResponseNormalizationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
auto onnx_node_attr = onnx_node.attribute().at(0);
int32_t size = 0;
if (onnx_node_attr.name() == "size") {
size = static_cast<int32_t>(onnx_node_attr.i());
attr->depth_radius = static_cast<int32_t>(size / 2);
} else {
MS_LOG(ERROR) << "the first attr is not size";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(1);
if (onnx_node_attr.name() == "alpha") {
auto alpha = onnx_node_attr.f();
attr->alpha = alpha / size;
} else {
MS_LOG(ERROR) << "the second attr is not alpha";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(2);
if (onnx_node_attr.name() == "beta") {
attr->beta = onnx_node_attr.f();
} else {
MS_LOG(ERROR) << "the third attr is not beta";
return RET_ERROR;
}
onnx_node_attr = onnx_node.attribute().at(3);
if (onnx_node_attr.name() == "bias") {
attr->bias = onnx_node_attr.f();
} else {
MS_LOG(ERROR) << "the third attr is not bias";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_LRN_PARSER_H
#define MS_ONNX_LRN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxLrnParser : public OnnxNodeParser {
public:
OnnxLrnParser() : OnnxNodeParser("Lrn") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_LRN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H

View File

@ -14,15 +14,31 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_matmul_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx MatMulParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
float alpha = 1.0f;
float beta = 1.0f;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
@ -39,14 +55,11 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
}
if (alpha != 1 || beta != 1) {
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
return RET_PARAM_INVALID;
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_MatMul;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_MatMul;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_MATMUL_PARSER_H
#define MS_ONNX_MATMUL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxMatmulParser : public OnnxNodeParser {
public:
OnnxMatmulParser() : OnnxNodeParser("MatMul") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_MATMUL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H

View File

@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include <cfloat>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/common/graph_util.h"
#include "src/common/utils.h"
@ -54,7 +54,8 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
return dims;
}
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile,
google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
#ifdef _WIN32
if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
@ -81,7 +82,8 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
return RET_OK;
}
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache) {
MS_LOG(DEBUG) << "set onnx constant tensors";
for (const auto &onnx_const_value : onnx_graph.initializer()) {
int index;
@ -117,8 +119,11 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
return RET_OK;
}
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index) {
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type "
@ -137,8 +142,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
*index = tensor_cache->AddTensor(name, tensor.release(), type);
return RET_OK;
}
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index) {
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index) {
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
if (data_type == kTypeUnknown) {
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
@ -165,7 +174,8 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
return RET_OK;
}
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) {
for (const auto &input_value : onnx_graph.input()) {
auto ret = tensor_cache->FindTensor(input_value.name());
@ -182,7 +192,8 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK;
}
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) {
for (const auto &output_value : onnx_graph.output()) {
int index;
@ -196,8 +207,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
return RET_OK;
}
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph,
TensorCache *tensor_cache) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
@ -218,7 +231,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
graph->nodes.emplace_back(std::move(dst_op_2));
}
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache) {
// convert GivenTensorFill node to a weight/bias tensor
auto ret = tensor_cache->FindTensor(onnx_node.output(0));
if (ret < 0) {
@ -270,8 +284,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
return RET_OK;
}
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) {
// change op_type() to name(), that is unique
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@ -303,8 +319,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
return RET_OK;
}
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache) {
MS_ASSERT(dst_op != nullptr);
MS_ASSERT(tensor_cache != nullptr);
std::vector<string> quant_node_name;
@ -361,8 +380,10 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
}
}
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op) {
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
const string &onnx_op_type,
schema::CNodeT *dst_op) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
if (node_parser == nullptr) {
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
@ -371,8 +392,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
return node_parser->Parse(onnx_graph, onnx_node, dst_op);
}
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache) {
for (const auto &onnx_node_input : node_inputs) {
auto index = tensor_cache->FindTensor(onnx_node_input);
if (index < 0) {
@ -385,7 +408,8 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
return RET_OK;
}
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs,
schema::CNodeT *dst_op,
TensorCache *tensor_cache) {
for (const auto &onnx_node_output : node_outputs) {
auto index = tensor_cache->FindTensor(onnx_node_output);
@ -400,7 +424,8 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
return RET_OK;
}
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
schema::TensorT *tensor) {
size_t data_count = 1;
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
size_t data_size = 0;
@ -459,7 +484,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
return RET_OK;
}
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache,
schema::MetaGraphT *graphDef) {
std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter);
@ -481,7 +507,8 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
}
}
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile,
const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_MODEL_PARSER_H
#define MS_ONNX_MODEL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
@ -37,35 +37,83 @@ namespace lite {
class OnnxModelParser : public ModelParser {
public:
OnnxModelParser();
virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index);
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
TensorCache *tensor_cache, int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
const string &onnx_op_type, schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
schema::TensorT *dst_tensor, TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
STATUS ReadOnnxModelFromBinary(const std::string &modelFile,
google::protobuf::Message *model_proto);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index);
STATUS AddTensorProto(const onnx::TensorProto &proto,
const std::string &name,
const TensorType &type,
TensorCache *tensor_cache,
int *index);
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph,
TensorCache *tensor_cache);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
const string &onnx_op_type,
schema::CNodeT *dst_op);
void SetOpQuantParams(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op,
schema::TensorT *dst_tensor,
TensorCache *tensor_cache);
STATUS SetOpInputIndex(const std::vector<string> &node_inputs,
schema::CNodeT *dst_op,
const onnx::NodeProto &onnx_node,
TensorCache *tensor_cache);
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs,
schema::CNodeT *dst_op,
TensorCache *tensor_cache);
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value,
schema::TensorT *tensor);
STATUS SetAllTensors(const TensorCache &tensor_cache,
schema::MetaGraphT *graphDef);
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
private:
@ -75,4 +123,4 @@ class OnnxModelParser : public ModelParser {
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_MODEL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H

View File

@ -15,6 +15,7 @@
*/
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include <vector>
namespace mindspore {
namespace lite {
@ -30,6 +31,20 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_
return schema::PadMode_NOTSET;
}
}
void OnnxNodeParser::Split(const std::string &src_str,
std::vector<std::string> *dst_str,
const std::string &chr) {
std::string ::size_type p1 = 0, p2 = src_str.find(chr);
while (std::string::npos != p2) {
dst_str->push_back(src_str.substr(p1, p2 - p1));
p1 = p2 + chr.size();
p2 = src_str.find(chr, p1);
}
if (p1 != src_str.length()) {
dst_str->push_back(src_str.substr(p1));
}
}
} // namespace lite
} // namespace mindspore

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#ifndef MS_ONNX_NODE_PARSER_H
#define MS_ONNX_NODE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
#include <string>
#include <vector>
#include "google/protobuf/message.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "include/errorcode.h"
@ -29,14 +30,23 @@ namespace lite {
class OnnxNodeParser {
public:
explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {}
virtual ~OnnxNodeParser() = default;
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
virtual STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) = 0;
protected:
schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
void Split(const std::string &src_str,
std::vector<std::string> *dst_str,
const std::string &chr);
const std::string &name;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_NODE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H

View File

@ -21,7 +21,14 @@ namespace mindspore {
namespace lite {
OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default;
OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default;
OnnxNodeParserRegistry::~OnnxNodeParserRegistry() {
for (auto ite : parsers) {
if (ite.second != nullptr) {
delete ite.second;
ite.second = nullptr;
}
}
}
OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() {
static OnnxNodeParserRegistry instance;

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_OP_REGISTRY_H
#define MS_ONNX_OP_REGISTRY_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
#include <string>
#include <unordered_map>
@ -30,6 +30,7 @@ class OnnxNodeParserRegistry {
virtual ~OnnxNodeParserRegistry();
static OnnxNodeParserRegistry *GetInstance();
OnnxNodeParser *GetNodeParser(const std::string &name);
std::unordered_map<std::string, OnnxNodeParser *> parsers;
@ -37,12 +38,13 @@ class OnnxNodeParserRegistry {
class OnnxNodeRegistrar {
public:
OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) {
OnnxNodeRegistrar(const std::string &name,
OnnxNodeParser *parser) {
OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser;
}
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_OP_REGISTRY_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H

View File

@ -14,14 +14,31 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_pad_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PadParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "pads") {
@ -42,11 +59,9 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
}
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Pad;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_LRN_PARSER_H
#define MS_ONNX_LRN_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxPadParser : public OnnxNodeParser {
public:
OnnxPadParser() : OnnxNodeParser("Pad") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_LRN_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H

View File

@ -14,14 +14,30 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_pool_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PoolParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW;
const auto &pool_type = onnx_node.op_type();
@ -79,11 +95,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_ERROR;
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Pooling;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_POOL_PARSER_H
#define MS_ONNX_POOL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxPoolParser : public OnnxNodeParser {
public:
OnnxPoolParser() : OnnxNodeParser("Pool") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_POOL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_reduce_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReduceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {
@ -45,13 +60,12 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
} else if (type == "ReduceSum") {
attr->mode = schema::ReduceMode_ReduceSum;
} else {
// MS_LOGE("unsupoort type");
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
MS_LOG(ERROR) << "unsupported type";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Reduce;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_REDUCE_PARSER_H
#define MS_ONNX_REDUCE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxReduceParser : public OnnxNodeParser {
public:
OnnxReduceParser() : OnnxNodeParser("Reduce") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_REDUCE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H

View File

@ -14,36 +14,64 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
#include <vector>
#include <memory>
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
#include "securec/include/securec.h"
namespace mindspore {
namespace lite {
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &relu_type = onnx_node.op_type();
if (relu_type == "Relu") {
MS_LOG(DEBUG) << "onnx ReluParser";
attr->type = schema::ActivationType_RELU;
} else if (relu_type == "LeakyRelu") {
MS_LOG(DEBUG) << "onnx LeakyReluParser";
attr->type = schema::ActivationType_LEAKY_RELU;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx PReluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
if (onnx_node.input_size() != 2) {
MS_LOG(ERROR) << "input num is not 2";
return RET_PARAM_INVALID;
MS_LOG(ERROR) << "input num should be 2";
return RET_ERROR;
}
std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
std::vector<onnx::TensorProto> params;
@ -57,8 +85,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) {
MS_LOG(ERROR) << "input error";
return RET_PARAM_INVALID;
MS_LOG(ERROR) << "input error: params[0] is null";
return RET_ERROR;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
@ -74,11 +102,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_RELU_PARSER_H
#define MS_ONNX_RELU_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,7 +25,10 @@ namespace lite {
class OnnxReluParser : public OnnxNodeParser {
public:
OnnxReluParser() : OnnxNodeParser("Relu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
class OnnxLeakeyReluParser : public OnnxReluParser {
@ -36,9 +39,12 @@ class OnnxLeakeyReluParser : public OnnxReluParser {
class OnnxPReluParser : public OnnxNodeParser {
public:
OnnxPReluParser() : OnnxNodeParser("Prelu") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_RELU_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H

View File

@ -14,16 +14,32 @@
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
#include <vector>
#include <memory>
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
namespace mindspore {
namespace lite {
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ReshapeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW;
std::vector<onnx::TensorProto> params;
for (int i = 0; i < onnx_node.input_size(); ++i) {
@ -40,18 +56,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
} else {
if (params.size() != 1) {
MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
return RET_PARAM_INVALID;
return RET_ERROR;
}
for (int i = 0; i < params[0].int64_data_size(); ++i) {
attr->shape.emplace_back(params[0].int64_data(i));
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Reshape;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_RESHAPE_PARSER_H
#define MS_ONNX_RESHAPE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxReshapeParser : public OnnxNodeParser {
public:
OnnxReshapeParser() : OnnxNodeParser("Reshape") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_RESHAPE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_shape_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,12 +23,24 @@ STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ShapeParser";
if (op != nullptr) {
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Shape;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SHAPE_PARSER_H
#define MS_ONNX_SHAPE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxShapeParser : public OnnxNodeParser {
public:
OnnxShapeParser() : OnnxNodeParser("Shape") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SHAPE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_sigmoid_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,13 +23,26 @@ STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SigmoidParser";
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
attr->type = schema::ActivationType_SIGMOID;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_SIGMOID;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SIGMOID_PARSER_H
#define MS_ONNX_SIGMOID_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSigmoidParser : public OnnxNodeParser {
public:
OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SIGMOID_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H

View File

@ -14,15 +14,30 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_slice_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "starts") {
@ -38,11 +53,9 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Slice;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SLICE_PARSER_H
#define MS_ONNX_SLICE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSliceParser : public OnnxNodeParser {
public:
OnnxSliceParser() : OnnxNodeParser("Slice") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SLICE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_softmax_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,18 +23,31 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SoftMaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_SoftMax;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_SoftMax;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SOFTMAX_PARSER_H
#define MS_ONNX_SOFTMAX_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSoftMaxParser : public OnnxNodeParser {
public:
OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SOFTMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H

View File

@ -14,26 +14,40 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "blocksize") {
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H
#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSpaceToDepthParser : public OnnxNodeParser {
public:
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_squeeze_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SqueezeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {
@ -32,11 +47,9 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
}
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Squeeze;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Squeeze;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_SQUEEZE_PARSER_H
#define MS_ONNX_SQUEEZE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxSqueezeParser : public OnnxNodeParser {
public:
OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_SQUEEZE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H

View File

@ -14,19 +14,33 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_tile_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TileParser";
if (op != nullptr) {
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_TILE_PARSER_H
#define MS_ONNX_TILE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxTileParser : public OnnxNodeParser {
public:
OnnxTileParser() : OnnxNodeParser("Tile") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_ARGMAX_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_transpose_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,7 +23,22 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TransposeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TransposeT> attr = std::make_unique<schema::TransposeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->conjugate = false;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
@ -40,11 +55,9 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
}
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Transpose;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Transpose;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_TRANSPOSE_PARSER_H
#define MS_ONNX_TRANSPOSE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxTransposeParser : public OnnxNodeParser {
public:
OnnxTransposeParser() : OnnxNodeParser("Transpose") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_TRANSPOSE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H

View File

@ -23,7 +23,22 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UpsampleParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::UpsampleT> attr = std::make_unique<schema::UpsampleT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {
@ -34,12 +49,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
}
}
}
// to do
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Upsample;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Upsample;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_UPSAMPLE_PARSER_H
#define MS_ONNX_UPSAMPLE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUpsampleParser : public OnnxNodeParser {
public:
OnnxUpsampleParser() : OnnxNodeParser("Upsample") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_UPSAMPLE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H

View File

@ -14,15 +14,30 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnSqueezeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::UnsqueezeT> attr = std::make_unique<schema::UnsqueezeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {
@ -31,11 +46,9 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
}
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Unsqueeze;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Unsqueeze;
op->primitive->value.value = attr.release();
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_UNSQUEEZE_PARSER_H
#define MS_ONNX_UNSQUEEZE_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUnSqueezeParser : public OnnxNodeParser {
public:
OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_UNSQUEEZE_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -23,25 +23,35 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
if (onnx_node.op_type() == "Int8Quantize") {
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
op->primitive->value.value = std::make_unique<schema::OnnxInt8QuantizeT>().release();
} else if (onnx_node.op_type() == "Int8Dequantize") {
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
op->primitive->value.value = std::make_unique<schema::OnnxInt8DequantizeT>().release();
} else {
// MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str());
return RET_ERROR;
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
if (onnx_node.op_type() == "Int8Quantize") {
std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (op->primitive->value.value == nullptr) {
// MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str());
return RET_ERROR;
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
op->primitive->value.value = attr.release();
} else if (onnx_node.op_type() == "Int8Dequantize") {
std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
op->primitive->value.value = attr.release();
} else {
// MS_LOGE("Input opDef is nullptr");
return RET_PARAM_INVALID;
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
return RET_ERROR;
}
return RET_OK;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MS_ONNX_UNUSEFUL_PARSER_H
#define MS_ONNX_UNUSEFUL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -25,9 +25,12 @@ namespace lite {
class OnnxUnusefulNodeParser : public OnnxNodeParser {
public:
OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
STATUS Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_UNUSEFUL_PARSER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H