forked from mindspore-Ecosystem/mindspore
modify onnx parsers format
This commit is contained in:
parent
ec1cf059a7
commit
6313af4d2f
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_argmax_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,7 +23,22 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ArgMaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axis") {
|
||||
|
@ -32,11 +47,9 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_ARGMAX_PARSER_H
|
||||
#define MS_ONNX_ARGMAX_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxArgMaxParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ARGMAX_PARSER_H
|
||||
#endif // MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARGMAX_PARSER_H
|
||||
|
||||
|
|
|
@ -14,114 +14,238 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx AddParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Add;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SubParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Sub;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Sub;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx MulParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Mul;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Mul;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DivParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Div;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx PowParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Power;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PowerT> attr(new schema::PowerT());
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Power;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx EqualParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Equal;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Equal;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx LessParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Less;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Less;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx GreaterParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Greater;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Greater;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx MinParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Min;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MinT> attr = std::make_unique<schema::MinT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Min;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx EltwiseParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::EltwiseT> attr = std::make_unique<schema::EltwiseT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// there is no Prod in onnx
|
||||
if (onnx_node.op_type() == "Sum") {
|
||||
attr->mode = schema::EltwiseMode_SUM;
|
||||
|
@ -129,143 +253,303 @@ STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
attr->mode = schema::EltwiseMode_MAXIMUM;
|
||||
}
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Eltwise;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Eltwise;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx FloorParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Floor;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Floor;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx AbsParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Abs;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Abs;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx NegParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Neg;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::NegT> attr = std::make_unique<schema::NegT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Neg;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ExpParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Exp;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Exp;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx CosParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Cos;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Cos;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SinParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Sin;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Sin;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SqrtParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Sqrt;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Sqrt;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx CeilParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Ceil;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Ceil;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx LogParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Log;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Log;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx TanParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Tan;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TanT> attr = std::make_unique<schema::TanT>();
|
||||
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Tan;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx AtanParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Atan;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::AtanT> attr = std::make_unique<schema::AtanT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Atan;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx AsinParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Asin;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::AsinT> attr = std::make_unique<schema::AsinT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Asin;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx TanhParser";
|
||||
if (op != nullptr) {
|
||||
MS_LOG(ERROR) << "mslite don't support tanh now";
|
||||
return RET_ERROR;
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
return RET_OK;
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "mslite don't support tanh now";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -167,5 +167,5 @@ class OnnxTanhParser : public OnnxNodeParser {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
|
||||
|
|
|
@ -19,10 +19,26 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx BatchNormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
if (onnx_node_attr.name() == "epsilon") {
|
||||
attr->epsilon = onnx_node_attr.f();
|
||||
|
@ -32,11 +48,9 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
|
|||
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_ADD_PARSER_H
|
||||
#define MS_ONNX_ADD_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxBatchNormParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ADD_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BATCHNORM_PARSER_H
|
||||
|
||||
|
|
|
@ -14,26 +14,36 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_biasadd_parser.h"
|
||||
#include <memory>
|
||||
|
||||
// using namespace mindspore::predict;
|
||||
// using namespace onnx;
|
||||
// using namespace std;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx BiasAddParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// use channel dim as axis
|
||||
attr->axis = {1};
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_BIASADD_PARSER_H
|
||||
#define MS_ONNX_BIASADD_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -26,9 +26,11 @@ class OnnxBiasAddParser : public OnnxNodeParser {
|
|||
public:
|
||||
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_BIASADD_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_BIASADD_PARSER_H
|
||||
|
||||
|
|
|
@ -14,25 +14,40 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx CastParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "to") {
|
||||
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CAST_PARSER_H
|
||||
#define MS_ONNX_CAST_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxCastParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxCastParser() : OnnxNodeParser("Cast") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_CAST_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CAST_PARSER_H
|
||||
|
||||
|
|
|
@ -14,13 +14,25 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_clip_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ClipParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
float min = -1, max = -1;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
|
@ -32,15 +44,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
}
|
||||
if (min == 0 && max == 6) {
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
|
||||
return RET_PARAM_INVALID;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CLIP_PARSER_H
|
||||
#define MS_ONNX_CLIP_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxClipParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxClipParser() : OnnxNodeParser("Clip") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ARGMAX_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CLIP_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_concat_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,18 +23,31 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ConcatParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axis") {
|
||||
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CONCAT_PARSER_H
|
||||
#define MS_ONNX_CONCAT_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxConcatParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxConcatParser() : OnnxNodeParser("Concat") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_CONCAT_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONCAT_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_constant_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,12 +23,24 @@ STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ConstantParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Constant;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Constant;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CONSTANT_PARSER_H
|
||||
#define MS_ONNX_CONSTANT_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxConstantParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxConstantParser() : OnnxNodeParser("Constant") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_CONSTANT_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONSTANT_PARSER_H
|
||||
|
||||
|
|
|
@ -14,21 +14,22 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
|
||||
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
|
||||
if (attr == nullptr || attr->group != attr->channelIn) {
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
if (depthwiseConv2DParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return false;
|
||||
}
|
||||
depthwiseConv2DParam->format = attr->format;
|
||||
|
@ -47,15 +48,32 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
|
|||
depthwiseConv2DParam->dilateH = attr->dilateH;
|
||||
depthwiseConv2DParam->hasBias = attr->hasBias;
|
||||
depthwiseConv2DParam->activationType = attr->activationType;
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
op->primitive->value.value = depthwiseConv2DParam.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ConvParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// set opdef each attr params
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
if (onnx_node_attr.name() == "group") {
|
||||
|
@ -149,13 +167,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
} else {
|
||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
}
|
||||
|
||||
if (attr->group != 1) {
|
||||
if (!ParseGroupConvolution(attr, op)) {
|
||||
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CONV_PARSER_H
|
||||
#define MS_ONNX_CONV_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
|
@ -26,11 +26,16 @@ namespace lite {
|
|||
class OnnxConvParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxConvParser() : OnnxNodeParser("Conv") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
|
||||
private:
|
||||
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
|
||||
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
|
||||
schema::CNodeT *op);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_CONV_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ namespace lite {
|
|||
OnnxConverter::OnnxConverter() {
|
||||
modelParser = new OnnxModelParser();
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,8 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_CONVERTER_H
|
||||
#define MS_ONNX_CONVERTER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "tools/converter/converter.h"
|
||||
|
@ -27,10 +28,10 @@ class OnnxConverter : public Converter {
|
|||
public:
|
||||
OnnxConverter();
|
||||
|
||||
~OnnxConverter() override = default;
|
||||
~OnnxConverter() = default;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MS_ONNX_CONVERTER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
||||
|
||||
|
|
|
@ -14,21 +14,21 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "tools/converter/parser/onnx/onnx_deconv_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DeConvParser";
|
||||
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
|
||||
schema::CNodeT *op) {
|
||||
if (attr == nullptr || attr->group != attr->channelOut) {
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
|
||||
if (deDepthwiseConv2DParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed";
|
||||
MS_LOG(WARNING) << "new op failed";
|
||||
return false;
|
||||
}
|
||||
deDepthwiseConv2DParam->format = attr->format;
|
||||
|
@ -47,38 +47,53 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
|
|||
deDepthwiseConv2DParam->dilateH = attr->dilateH;
|
||||
deDepthwiseConv2DParam->hasBias = attr->hasBias;
|
||||
deDepthwiseConv2DParam->activationType = attr->activationType;
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
||||
op->primitive->value.value = deDepthwiseConv2DParam.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
||||
op->primitive->value.value = deDepthwiseConv2DParam.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DeConvParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// set opdef each attr params
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
if (onnx_node_attr.name() == "group") {
|
||||
attr->group = static_cast<int32_t>(onnx_node_attr.i());
|
||||
} else if (onnx_node_attr.name() == "dilations") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
// MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size());
|
||||
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernels") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernel_shape") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
// MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size());
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
|
@ -87,7 +102,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->padMode = GetOnnxPadMode(onnx_node_attr);
|
||||
} else if (onnx_node_attr.name() == "pads") {
|
||||
if (onnx_node_attr.ints().size() != 4) {
|
||||
// MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size());
|
||||
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
|
@ -96,7 +111,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->padRight = static_cast<int32_t>(onnx_node_attr.ints(3));
|
||||
} else if (onnx_node_attr.name() == "strides") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
// MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size());
|
||||
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
|
@ -105,7 +120,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
if (onnx_node_attr.s() == "NHWC") {
|
||||
attr->format = schema::Format_NHWC;
|
||||
} else {
|
||||
// MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str());
|
||||
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
@ -116,7 +131,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
||||
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
||||
if (nodeIter == onnx_graph.initializer().end()) {
|
||||
// MS_LOGE("not find node: %s", onnx_conv_weight.c_str())
|
||||
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> weight_shape;
|
||||
|
@ -137,7 +152,6 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_DECONV_PARSER_H
|
||||
#define MS_ONNX_DECONV_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
|
@ -26,11 +26,16 @@ namespace lite {
|
|||
class OnnxDeConvParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
|
||||
private:
|
||||
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
|
||||
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
|
||||
schema::CNodeT *op);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_DECONV_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DECONV_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,18 +23,31 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto& attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "blocksize") {
|
||||
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxDepthToSpaceParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DEPTH_TO_SPACE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_dropout_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,18 +23,31 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx DropoutParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "ratio") {
|
||||
attr->ratio = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Dropout;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Dropout;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_ARGMAX_PARSER_H
|
||||
#define MS_ONNX_ARGMAX_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxDropoutParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxDropoutParser() : OnnxNodeParser("Dropout") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ARGMAX_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_DROPOUT_PARSER_H
|
||||
|
||||
|
|
|
@ -14,25 +14,40 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_elu_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx EluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto& attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "alpha") {
|
||||
attr->alpha = onnx_node_attr.f();
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Elu;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Elu;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_ELU_PARSER_H
|
||||
#define MS_ONNX_ELU_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxEluParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxEluParser() : OnnxNodeParser("Elu") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ELU_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ELU_PARSER_H
|
||||
|
||||
|
|
|
@ -14,20 +14,32 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_expand_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ExpandParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Broadcast;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Broadcast;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_EXPAND_PARSER_H
|
||||
#define MS_ONNX_EXPAND_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxExpandParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxExpandParser() : OnnxNodeParser("Expand") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_EXPAND_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_EXPAND_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_flatten_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,7 +23,22 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx FlattenParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
int axis = 1;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
|
@ -36,11 +51,8 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
}
|
||||
attr->shape.emplace_back(-1);
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_FLATTEN_PARSER_H
|
||||
#define MS_ONNX_FLATTEN_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -26,9 +26,11 @@ class OnnxFlattenParser : public OnnxNodeParser {
|
|||
public:
|
||||
OnnxFlattenParser() : OnnxNodeParser("Fatten") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_FLATTEN_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_FLATTEN_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_gather_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,18 +23,31 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx GatherParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto& attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axis") {
|
||||
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Gather;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Gather;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_GATHER_PARSER_H
|
||||
#define MS_ONNX_GATHER_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxGatherParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxGatherParser() : OnnxNodeParser("Gather") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_GATHER_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GATHER_PARSER_H
|
||||
|
||||
|
|
|
@ -14,31 +14,69 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_lrn_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx LrnParser";
|
||||
std::unique_ptr<schema::LrnT> attr = std::make_unique<schema::LrnT>();
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto& attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "size") {
|
||||
attr->size = static_cast<int32_t>(onnx_node_attr.i());
|
||||
} else if (attribute_name == "alpha") {
|
||||
attr->alpha = onnx_node_attr.f();
|
||||
} else if (attribute_name == "beta") {
|
||||
attr->beta = onnx_node_attr.f();
|
||||
} else if (attribute_name == "bias") {
|
||||
attr->bias = onnx_node_attr.f();
|
||||
}
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Lrn;
|
||||
op->primitive->value.value = attr.release();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LocalResponseNormalizationT> attr
|
||||
= std::make_unique<schema::LocalResponseNormalizationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto onnx_node_attr = onnx_node.attribute().at(0);
|
||||
int32_t size = 0;
|
||||
if (onnx_node_attr.name() == "size") {
|
||||
size = static_cast<int32_t>(onnx_node_attr.i());
|
||||
attr->depth_radius = static_cast<int32_t>(size / 2);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the first attr is not size";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
onnx_node_attr = onnx_node.attribute().at(1);
|
||||
if (onnx_node_attr.name() == "alpha") {
|
||||
auto alpha = onnx_node_attr.f();
|
||||
attr->alpha = alpha / size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the second attr is not alpha";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
onnx_node_attr = onnx_node.attribute().at(2);
|
||||
if (onnx_node_attr.name() == "beta") {
|
||||
attr->beta = onnx_node_attr.f();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the third attr is not beta";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
onnx_node_attr = onnx_node.attribute().at(3);
|
||||
if (onnx_node_attr.name() == "bias") {
|
||||
attr->bias = onnx_node_attr.f();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the third attr is not bias";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_LRN_PARSER_H
|
||||
#define MS_ONNX_LRN_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxLrnParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxLrnParser() : OnnxNodeParser("Lrn") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_LRN_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LRN_PARSER_H
|
||||
|
||||
|
|
|
@ -14,15 +14,31 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_matmul_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx MatMulParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -39,14 +55,11 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
if (alpha != 1 || beta != 1) {
|
||||
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
|
||||
return RET_PARAM_INVALID;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_MATMUL_PARSER_H
|
||||
#define MS_ONNX_MATMUL_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxMatmulParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxMatmulParser() : OnnxNodeParser("MatMul") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_MATMUL_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MATMUL_PARSER_H
|
||||
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include <cfloat>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
|
@ -54,7 +54,8 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
|
|||
return dims;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
|
||||
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile,
|
||||
google::protobuf::Message *onnx_model) {
|
||||
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
||||
#ifdef _WIN32
|
||||
if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
|
||||
|
@ -81,7 +82,8 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
|
||||
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
|
||||
TensorCache *tensor_cache) {
|
||||
MS_LOG(DEBUG) << "set onnx constant tensors";
|
||||
for (const auto &onnx_const_value : onnx_graph.initializer()) {
|
||||
int index;
|
||||
|
@ -117,8 +119,11 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
|
||||
TensorCache *tensor_cache, int *index) {
|
||||
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
|
||||
const std::string &name,
|
||||
const TensorType &type,
|
||||
TensorCache *tensor_cache,
|
||||
int *index) {
|
||||
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
|
||||
if (data_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "not support onnx data type "
|
||||
|
@ -137,8 +142,12 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const st
|
|||
*index = tensor_cache->AddTensor(name, tensor.release(), type);
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
|
||||
TensorCache *tensor_cache, int *index) {
|
||||
|
||||
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
|
||||
const std::string &name,
|
||||
const TensorType &type,
|
||||
TensorCache *tensor_cache,
|
||||
int *index) {
|
||||
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
|
||||
if (data_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
|
||||
|
@ -165,7 +174,8 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
|
||||
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache) {
|
||||
for (const auto &input_value : onnx_graph.input()) {
|
||||
auto ret = tensor_cache->FindTensor(input_value.name());
|
||||
|
@ -182,7 +192,8 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
|
||||
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache) {
|
||||
for (const auto &output_value : onnx_graph.output()) {
|
||||
int index;
|
||||
|
@ -196,8 +207,10 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
|
||||
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache) {
|
||||
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
|
||||
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
|
||||
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
|
||||
|
@ -218,7 +231,8 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
|
|||
graph->nodes.emplace_back(std::move(dst_op_2));
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
|
||||
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
||||
TensorCache *tensor_cache) {
|
||||
// convert GivenTensorFill node to a weight/bias tensor
|
||||
auto ret = tensor_cache->FindTensor(onnx_node.output(0));
|
||||
if (ret < 0) {
|
||||
|
@ -270,8 +284,10 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
|
||||
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor,
|
||||
TensorCache *tensor_cache) {
|
||||
// change op_type() to name(), that is unique
|
||||
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
|
||||
|
@ -303,8 +319,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
|
||||
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor,
|
||||
TensorCache *tensor_cache) {
|
||||
MS_ASSERT(dst_op != nullptr);
|
||||
MS_ASSERT(tensor_cache != nullptr);
|
||||
std::vector<string> quant_node_name;
|
||||
|
@ -361,8 +380,10 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
|
|||
}
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
const string &onnx_op_type, schema::CNodeT *dst_op) {
|
||||
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
const string &onnx_op_type,
|
||||
schema::CNodeT *dst_op) {
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
|
||||
|
@ -371,8 +392,10 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
|
|||
return node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
|
||||
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
|
||||
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
|
||||
schema::CNodeT *dst_op,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
TensorCache *tensor_cache) {
|
||||
for (const auto &onnx_node_input : node_inputs) {
|
||||
auto index = tensor_cache->FindTensor(onnx_node_input);
|
||||
if (index < 0) {
|
||||
|
@ -385,7 +408,8 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
|
||||
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs,
|
||||
schema::CNodeT *dst_op,
|
||||
TensorCache *tensor_cache) {
|
||||
for (const auto &onnx_node_output : node_outputs) {
|
||||
auto index = tensor_cache->FindTensor(onnx_node_output);
|
||||
|
@ -400,7 +424,8 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
|
||||
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
|
||||
schema::TensorT *tensor) {
|
||||
size_t data_count = 1;
|
||||
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
|
||||
size_t data_size = 0;
|
||||
|
@ -459,7 +484,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
|
||||
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache,
|
||||
schema::MetaGraphT *graphDef) {
|
||||
std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
|
||||
for (auto iter : tensors) {
|
||||
std::unique_ptr<schema::TensorT> temp(iter);
|
||||
|
@ -481,7 +507,8 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
|
|||
}
|
||||
}
|
||||
|
||||
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
|
||||
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile,
|
||||
const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
|
||||
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_MODEL_PARSER_H
|
||||
#define MS_ONNX_MODEL_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
|
||||
|
||||
#include <google/protobuf/io/coded_stream.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
|
@ -37,35 +37,83 @@ namespace lite {
|
|||
class OnnxModelParser : public ModelParser {
|
||||
public:
|
||||
OnnxModelParser();
|
||||
|
||||
virtual ~OnnxModelParser();
|
||||
|
||||
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||
|
||||
private:
|
||||
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||
|
||||
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
|
||||
STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto);
|
||||
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
|
||||
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
|
||||
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
|
||||
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
|
||||
TensorCache *tensor_cache, int *index);
|
||||
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
|
||||
TensorCache *tensor_cache, int *index);
|
||||
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache);
|
||||
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache);
|
||||
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
|
||||
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
const string &onnx_op_type, schema::CNodeT *dst_op);
|
||||
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor, TensorCache *tensor_cache);
|
||||
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
|
||||
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
|
||||
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache);
|
||||
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
|
||||
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef);
|
||||
|
||||
STATUS ReadOnnxModelFromBinary(const std::string &modelFile,
|
||||
google::protobuf::Message *model_proto);
|
||||
|
||||
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS AddValueInfo(const onnx::ValueInfoProto &proto,
|
||||
const std::string &name,
|
||||
const TensorType &type,
|
||||
TensorCache *tensor_cache,
|
||||
int *index);
|
||||
|
||||
STATUS AddTensorProto(const onnx::TensorProto &proto,
|
||||
const std::string &name,
|
||||
const TensorType &type,
|
||||
TensorCache *tensor_cache,
|
||||
int *index);
|
||||
|
||||
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
const string &onnx_op_type,
|
||||
schema::CNodeT *dst_op);
|
||||
|
||||
void SetOpQuantParams(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS SetOpInputIndex(const std::vector<string> &node_inputs,
|
||||
schema::CNodeT *dst_op,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs,
|
||||
schema::CNodeT *dst_op,
|
||||
TensorCache *tensor_cache);
|
||||
|
||||
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value,
|
||||
schema::TensorT *tensor);
|
||||
|
||||
STATUS SetAllTensors(const TensorCache &tensor_cache,
|
||||
schema::MetaGraphT *graphDef);
|
||||
|
||||
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
|
||||
|
||||
private:
|
||||
|
@ -75,4 +123,4 @@ class OnnxModelParser : public ModelParser {
|
|||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MS_ONNX_MODEL_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_MODEL_PARSER_H
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -30,6 +31,20 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_
|
|||
return schema::PadMode_NOTSET;
|
||||
}
|
||||
}
|
||||
|
||||
void OnnxNodeParser::Split(const std::string &src_str,
|
||||
std::vector<std::string> *dst_str,
|
||||
const std::string &chr) {
|
||||
std::string ::size_type p1 = 0, p2 = src_str.find(chr);
|
||||
while (std::string::npos != p2) {
|
||||
dst_str->push_back(src_str.substr(p1, p2 - p1));
|
||||
p1 = p2 + chr.size();
|
||||
p2 = src_str.find(chr, p1);
|
||||
}
|
||||
if (p1 != src_str.length()) {
|
||||
dst_str->push_back(src_str.substr(p1));
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_NODE_PARSER_H
|
||||
#define MS_ONNX_NODE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "google/protobuf/message.h"
|
||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -29,14 +30,23 @@ namespace lite {
|
|||
class OnnxNodeParser {
|
||||
public:
|
||||
explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {}
|
||||
|
||||
virtual ~OnnxNodeParser() = default;
|
||||
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
|
||||
|
||||
virtual STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) = 0;
|
||||
|
||||
protected:
|
||||
schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
|
||||
|
||||
void Split(const std::string &src_str,
|
||||
std::vector<std::string> *dst_str,
|
||||
const std::string &chr);
|
||||
|
||||
const std::string &name;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_NODE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H
|
||||
|
||||
|
|
|
@ -21,7 +21,14 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default;
|
||||
|
||||
OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default;
|
||||
OnnxNodeParserRegistry::~OnnxNodeParserRegistry() {
|
||||
for (auto ite : parsers) {
|
||||
if (ite.second != nullptr) {
|
||||
delete ite.second;
|
||||
ite.second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() {
|
||||
static OnnxNodeParserRegistry instance;
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_OP_REGISTRY_H
|
||||
#define MS_ONNX_OP_REGISTRY_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
@ -30,6 +30,7 @@ class OnnxNodeParserRegistry {
|
|||
virtual ~OnnxNodeParserRegistry();
|
||||
|
||||
static OnnxNodeParserRegistry *GetInstance();
|
||||
|
||||
OnnxNodeParser *GetNodeParser(const std::string &name);
|
||||
|
||||
std::unordered_map<std::string, OnnxNodeParser *> parsers;
|
||||
|
@ -37,12 +38,13 @@ class OnnxNodeParserRegistry {
|
|||
|
||||
class OnnxNodeRegistrar {
|
||||
public:
|
||||
OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) {
|
||||
OnnxNodeRegistrar(const std::string &name,
|
||||
OnnxNodeParser *parser) {
|
||||
OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser;
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MS_ONNX_OP_REGISTRY_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_REGISTRY_H
|
||||
|
||||
|
|
|
@ -14,14 +14,31 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_pad_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx PadParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "pads") {
|
||||
|
@ -42,11 +59,9 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
|
|||
}
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_LRN_PARSER_H
|
||||
#define MS_ONNX_LRN_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxPadParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxPadParser() : OnnxNodeParser("Pad") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_LRN_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_PAD_PARSER_H
|
||||
|
||||
|
|
|
@ -14,14 +14,30 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_pool_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx PoolParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NCHW;
|
||||
const auto &pool_type = onnx_node.op_type();
|
||||
|
@ -79,11 +95,9 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_POOL_PARSER_H
|
||||
#define MS_ONNX_POOL_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxPoolParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxPoolParser() : OnnxNodeParser("Pool") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_POOL_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_POOL_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_reduce_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,7 +23,22 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ReduceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axes") {
|
||||
|
@ -45,13 +60,12 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
} else if (type == "ReduceSum") {
|
||||
attr->mode = schema::ReduceMode_ReduceSum;
|
||||
} else {
|
||||
// MS_LOGE("unsupoort type");
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
op->primitive->value.value = attr.release();
|
||||
MS_LOG(ERROR) << "unsupported type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_REDUCE_PARSER_H
|
||||
#define MS_ONNX_REDUCE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxReduceParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxReduceParser() : OnnxNodeParser("Reduce") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_REDUCE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_REDUCE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,36 +14,64 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_relu_parser.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ReluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &relu_type = onnx_node.op_type();
|
||||
if (relu_type == "Relu") {
|
||||
MS_LOG(DEBUG) << "onnx ReluParser";
|
||||
attr->type = schema::ActivationType_RELU;
|
||||
} else if (relu_type == "LeakyRelu") {
|
||||
MS_LOG(DEBUG) << "onnx LeakyReluParser";
|
||||
attr->type = schema::ActivationType_LEAKY_RELU;
|
||||
}
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx PReluParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (onnx_node.input_size() != 2) {
|
||||
MS_LOG(ERROR) << "input num is not 2";
|
||||
return RET_PARAM_INVALID;
|
||||
MS_LOG(ERROR) << "input num should be 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::unique_ptr<schema::CaffePReLUT> attr = std::make_unique<schema::CaffePReLUT>();
|
||||
std::vector<onnx::TensorProto> params;
|
||||
|
@ -57,8 +85,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
|
||||
const onnx::TensorProto *slope = ¶ms[0];
|
||||
if (slope == nullptr) {
|
||||
MS_LOG(ERROR) << "input error";
|
||||
return RET_PARAM_INVALID;
|
||||
MS_LOG(ERROR) << "input error: params[0] is null";
|
||||
return RET_ERROR;
|
||||
}
|
||||
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
|
||||
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
|
||||
|
@ -74,11 +102,8 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
}
|
||||
}
|
||||
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_CaffePReLU;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_RELU_PARSER_H
|
||||
#define MS_ONNX_RELU_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,7 +25,10 @@ namespace lite {
|
|||
class OnnxReluParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxReluParser() : OnnxNodeParser("Relu") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxLeakeyReluParser : public OnnxReluParser {
|
||||
|
@ -36,9 +39,12 @@ class OnnxLeakeyReluParser : public OnnxReluParser {
|
|||
class OnnxPReluParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxPReluParser() : OnnxNodeParser("Prelu") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_RELU_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RELU_PARSER_H
|
||||
|
||||
|
|
|
@ -14,16 +14,32 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_reshape_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ReshapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NCHW;
|
||||
std::vector<onnx::TensorProto> params;
|
||||
for (int i = 0; i < onnx_node.input_size(); ++i) {
|
||||
|
@ -40,18 +56,16 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
} else {
|
||||
if (params.size() != 1) {
|
||||
MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
|
||||
return RET_PARAM_INVALID;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (int i = 0; i < params[0].int64_data_size(); ++i) {
|
||||
attr->shape.emplace_back(params[0].int64_data(i));
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_RESHAPE_PARSER_H
|
||||
#define MS_ONNX_RESHAPE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxReshapeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxReshapeParser() : OnnxNodeParser("Reshape") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_RESHAPE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESHAPE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_shape_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,12 +23,24 @@ STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ShapeParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Shape;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Shape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SHAPE_PARSER_H
|
||||
#define MS_ONNX_SHAPE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxShapeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxShapeParser() : OnnxNodeParser("Shape") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SHAPE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SHAPE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_sigmoid_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,13 +23,26 @@ STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SigmoidParser";
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SIGMOID_PARSER_H
|
||||
#define MS_ONNX_SIGMOID_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxSigmoidParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SIGMOID_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SIGMOID_PARSER_H
|
||||
|
||||
|
|
|
@ -14,15 +14,30 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_slice_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SliceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "starts") {
|
||||
|
@ -38,11 +53,9 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
}
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Slice;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Slice;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SLICE_PARSER_H
|
||||
#define MS_ONNX_SLICE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxSliceParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSliceParser() : OnnxNodeParser("Slice") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SLICE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SLICE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_softmax_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,18 +23,31 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SoftMaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto& attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axis") {
|
||||
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_SoftMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_SoftMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SOFTMAX_PARSER_H
|
||||
#define MS_ONNX_SOFTMAX_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxSoftMaxParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SOFTMAX_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SOFTMAX_PARSER_H
|
||||
|
||||
|
|
|
@ -14,26 +14,40 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "blocksize") {
|
||||
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxSpaceToDepthParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPACE_TO_DEPTH_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_squeeze_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,7 +23,22 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx SqueezeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axes") {
|
||||
|
@ -32,11 +47,9 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
}
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Squeeze;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Squeeze;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_SQUEEZE_PARSER_H
|
||||
#define MS_ONNX_SQUEEZE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxSqueezeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_SQUEEZE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SQUEEZE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,19 +14,33 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_tile_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx TileParser";
|
||||
if (op != nullptr) {
|
||||
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Tile;
|
||||
op->primitive->value.value = attr.release();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Tile;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_TILE_PARSER_H
|
||||
#define MS_ONNX_TILE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxTileParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxTileParser() : OnnxNodeParser("Tile") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_ARGMAX_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TILE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_transpose_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,7 +23,22 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx TransposeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TransposeT> attr = std::make_unique<schema::TransposeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->conjugate = false;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
|
@ -40,11 +55,9 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
}
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_TRANSPOSE_PARSER_H
|
||||
#define MS_ONNX_TRANSPOSE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxTransposeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxTransposeParser() : OnnxNodeParser("Transpose") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_TRANSPOSE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TRANSPOSE_PARSER_H
|
||||
|
||||
|
|
|
@ -23,7 +23,22 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx UpsampleParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::UpsampleT> attr = std::make_unique<schema::UpsampleT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "mode") {
|
||||
|
@ -34,12 +49,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
}
|
||||
}
|
||||
}
|
||||
// to do
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Upsample;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Upsample;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_UPSAMPLE_PARSER_H
|
||||
#define MS_ONNX_UPSAMPLE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxUpsampleParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxUpsampleParser() : OnnxNodeParser("Upsample") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_UPSAMPLE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UPSAMPLE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,15 +14,30 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_unsqueeze_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx UnSqueezeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::UnsqueezeT> attr = std::make_unique<schema::UnsqueezeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "axes") {
|
||||
|
@ -31,11 +46,9 @@ STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
|
|||
}
|
||||
}
|
||||
}
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
op->primitive->value.type = schema::PrimitiveType_Unsqueeze;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Unsqueeze;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_UNSQUEEZE_PARSER_H
|
||||
#define MS_ONNX_UNSQUEEZE_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxUnSqueezeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_UNSQUEEZE_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_UNSQUEEZE_PARSER_H
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -23,25 +23,35 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph,
|
|||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
|
||||
if (op != nullptr) {
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (onnx_node.op_type() == "Int8Quantize") {
|
||||
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
|
||||
op->primitive->value.value = std::make_unique<schema::OnnxInt8QuantizeT>().release();
|
||||
} else if (onnx_node.op_type() == "Int8Dequantize") {
|
||||
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
op->primitive->value.value = std::make_unique<schema::OnnxInt8DequantizeT>().release();
|
||||
} else {
|
||||
// MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str());
|
||||
return RET_ERROR;
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (onnx_node.op_type() == "Int8Quantize") {
|
||||
std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (op->primitive->value.value == nullptr) {
|
||||
// MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str());
|
||||
return RET_ERROR;
|
||||
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else if (onnx_node.op_type() == "Int8Dequantize") {
|
||||
std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else {
|
||||
// MS_LOGE("Input opDef is nullptr");
|
||||
return RET_PARAM_INVALID;
|
||||
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MS_ONNX_UNUSEFUL_PARSER_H
|
||||
#define MS_ONNX_UNUSEFUL_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -25,9 +25,12 @@ namespace lite {
|
|||
class OnnxUnusefulNodeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MS_ONNX_UNUSEFUL_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
|
||||
|
||||
|
|
Loading…
Reference in New Issue