forked from mindspore-Ecosystem/mindspore
reconstruct onnx
This commit is contained in:
parent
e1e8f1d429
commit
e7151c194c
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Constant : public PrimitiveC {
|
||||
public:
|
||||
Constant() = default;
|
||||
~Constant() = default;
|
||||
MS_DECLARE_PARENT(Constant, PrimitiveC);
|
||||
explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_
|
||||
#endif
|
|
@ -149,6 +149,7 @@
|
|||
#include "src/ops/oneslike.h"
|
||||
#include "src/ops/unsorted_segment_sum.h"
|
||||
#include "src/ops/reciprocal.h"
|
||||
#include "src/ops/constant.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
@ -182,7 +183,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> CastToInt(const ValuePtr value) {
|
||||
std::vector<int> CastToInt(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(WARNING) << "valueptr is nullptr.";
|
||||
return {};
|
||||
|
@ -891,6 +892,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new (std::nothrow) Dequant(primitive);
|
||||
case schema::PrimitiveType_Reciprocal:
|
||||
return new (std::nothrow) Reciprocal(primitive);
|
||||
case schema::PrimitiveType_Constant:
|
||||
return new (std::nothrow) Constant(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{
|
|||
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
|
||||
{"Tanh", schema::ActivationType_TANH},
|
||||
{"Logistic", schema::ActivationType_SIGMOID}};
|
||||
std::vector<int> CastToInt(const ValuePtr value);
|
||||
std::vector<int> CastToInt(const ValuePtr &value);
|
||||
class PrimitiveC : public mindspore::Primitive {
|
||||
public:
|
||||
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
|
||||
|
|
|
@ -205,6 +205,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
|
||||
)
|
||||
endif()
|
||||
### train
|
||||
|
|
|
@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/infershape_pass.cc
|
||||
../optimizer/graph/slice_prepose_pass.cc
|
||||
../optimizer/graph/mindir_adjust_pass.cc
|
||||
../optimizer/graph/onnx_inputs_adjust_pass.cc
|
||||
)
|
||||
|
||||
add_subdirectory(../anf_importer anf_importer)
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
|
||||
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
|
||||
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
|
||||
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
|
||||
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
|
||||
|
@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
}
|
||||
}
|
||||
|
||||
// onnx pre adjustment
|
||||
if (config->fmk == converter::FmkType_ONNX) {
|
||||
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
|
||||
if (!onnx_adjust_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "onnx adjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// for now - trainning is not supporting fuse operations
|
||||
if (!config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
|
|
|
@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
auto primitive_c = node_parser->ParseLitePrimitive(layer, weight);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << layer.name() << " failed.";
|
||||
status = RET_ERROR;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
status = ConvertBottom(layer, &input_nodes);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
continue;
|
||||
}
|
||||
|
||||
// build weights
|
||||
|
@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
status = ConvertBlobs(weight, &const_parameters);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
continue;
|
||||
}
|
||||
|
||||
// build cnode
|
||||
|
@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
status = ConvertTop(layer, new_cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
continue;
|
||||
}
|
||||
|
||||
status = ConvertLayerQuantParams(layer, weight, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
|
|
|
@ -19,27 +19,22 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx AdderParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto attr = std::make_unique<schema::AdderT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Adder;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Adder;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser());
|
||||
|
|
|
@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser {
|
|||
public:
|
||||
OnnxAdderParser() : OnnxNodeParser("Adder") {}
|
||||
~OnnxAdderParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ArgMaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
|
||||
auto attr = std::make_unique<schema::ArgMaxT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->keepDims = static_cast<bool>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser {
|
|||
OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {}
|
||||
~OnnxArgMaxParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser {
|
|||
public:
|
||||
OnnxAddParser() : OnnxNodeParser("Add") {}
|
||||
~OnnxAddParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxSubParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSubParser() : OnnxNodeParser("Sub") {}
|
||||
~OnnxSubParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxMulParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxMulParser() : OnnxNodeParser("Mul") {}
|
||||
~OnnxMulParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxDivParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxDivParser() : OnnxNodeParser("Div") {}
|
||||
~OnnxDivParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxPowParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxPowParser() : OnnxNodeParser("Power") {}
|
||||
~OnnxPowParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxEqualParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxEqualParser() : OnnxNodeParser("Equal") {}
|
||||
~OnnxEqualParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxLessParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxLessParser() : OnnxNodeParser("Less") {}
|
||||
~OnnxLessParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxGreaterParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxGreaterParser() : OnnxNodeParser("Greater") {}
|
||||
~OnnxGreaterParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxMinParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxMinParser() : OnnxNodeParser("Min") {}
|
||||
~OnnxMinParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxEltwiseParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {}
|
||||
~OnnxEltwiseParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxFloorParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxFloorParser() : OnnxNodeParser("Floor") {}
|
||||
~OnnxFloorParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxAbsParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxAbsParser() : OnnxNodeParser("Abs") {}
|
||||
~OnnxAbsParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxNegParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNegParser() : OnnxNodeParser("Neg") {}
|
||||
~OnnxNegParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxExpParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxExpParser() : OnnxNodeParser("Exp") {}
|
||||
~OnnxExpParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxCosParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxCosParser() : OnnxNodeParser("Cos") {}
|
||||
~OnnxCosParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxSinParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSinParser() : OnnxNodeParser("Sin") {}
|
||||
~OnnxSinParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxSqrtParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSqrtParser() : OnnxNodeParser("Sqrt") {}
|
||||
~OnnxSqrtParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxCeilParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxCeilParser() : OnnxNodeParser("Ceil") {}
|
||||
~OnnxCeilParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxLogParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxLogParser() : OnnxNodeParser("Log") {}
|
||||
~OnnxLogParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxTanParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxTanParser() : OnnxNodeParser("Tan") {}
|
||||
~OnnxTanParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxAtanParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxAtanParser() : OnnxNodeParser("Atan") {}
|
||||
~OnnxAtanParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxAsinParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxAsinParser() : OnnxNodeParser("Asin") {}
|
||||
~OnnxAsinParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxTanhParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxTanhParser() : OnnxNodeParser("Tanh") {}
|
||||
~OnnxTanhParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxSignParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSignParser() : OnnxNodeParser("Sign") {}
|
||||
~OnnxSignParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxAndParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxAndParser() : OnnxNodeParser("And") {}
|
||||
~OnnxAndParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxOrParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxOrParser() : OnnxNodeParser("Or") {}
|
||||
~OnnxOrParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxNotParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNotParser() : OnnxNodeParser("Not") {}
|
||||
~OnnxNotParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxRoundParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxRoundParser() : OnnxNodeParser("Round") {}
|
||||
~OnnxRoundParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxReciprocalParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {}
|
||||
~OnnxReciprocalParser() override = default;
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx BatchNormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>();
|
||||
auto attr = std::make_unique<schema::FusedBatchNormT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
|
|||
attr->spatial = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser {
|
|||
OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {}
|
||||
~OnnxBatchNormParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,30 +19,25 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx BiasAddParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>();
|
||||
auto attr = std::make_unique<schema::BiasAddT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->axis = {1};
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_BiasAdd;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_BiasAdd;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser {
|
|||
OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {}
|
||||
~OnnxBiasAddParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,22 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx CastParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
auto attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
attr->dstT = static_cast<int>(dst_type);
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Cast;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser {
|
|||
OnnxCastParser() : OnnxNodeParser("Cast") {}
|
||||
~OnnxCastParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,39 +19,32 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ClipParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
auto attr = std::make_unique<schema::ClipT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
float min = -1, max = -1;
|
||||
attr->max = -1;
|
||||
attr->min = -1;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "max") {
|
||||
max = onnx_node_attr.f();
|
||||
attr->max = onnx_node_attr.f();
|
||||
} else if (attribute_name == "min") {
|
||||
min = onnx_node_attr.f();
|
||||
attr->min = onnx_node_attr.f();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->max = max;
|
||||
attr->min = min;
|
||||
op->primitive->value.type = schema::PrimitiveType_Clip;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
return RET_OK;
|
||||
primitive->value.type = schema::PrimitiveType_Clip;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser {
|
|||
OnnxClipParser() : OnnxNodeParser("Clip") {}
|
||||
~OnnxClipParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ConcatParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
|
||||
auto attr = std::make_unique<schema::ConcatT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Concat;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser {
|
|||
OnnxConcatParser() : OnnxNodeParser("Concat") {}
|
||||
~OnnxConcatParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,23 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ConstantOfShapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ConstantOfShapeT> attr = std::make_unique<schema::ConstantOfShapeT>();
|
||||
auto attr = std::make_unique<schema::ConstantOfShapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
|
|||
const auto &tensor = onnx_node_attr.t();
|
||||
auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
MS_LOG(ERROR) << "get data from tensor failed";
|
||||
return nullptr;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "The data type is not supported.";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_ConstantOfShape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_ConstantOfShape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser {
|
|||
OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {}
|
||||
~OnnxConstantOfShapeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,33 +16,75 @@
|
|||
|
||||
#include "tools/converter/parser/onnx/onnx_constant_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ConstantParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c) {
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "new a paramValueLite failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
auto data_type =
|
||||
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()));
|
||||
if (data_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "not support onnx data type "
|
||||
<< static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type());
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
|
||||
std::vector<int> shape;
|
||||
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
|
||||
[](const int64_t &val) { return static_cast<int32_t>(val); });
|
||||
param_value->set_tensor_type(data_type);
|
||||
param_value->set_tensor_shape(shape);
|
||||
param_value->set_format(schema::Format_NCHW);
|
||||
if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) {
|
||||
MS_LOG(ERROR) << "get value failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Constant;
|
||||
op->primitive->value.value = attr.release();
|
||||
primitive_c->set_attr("const_data", param_value);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
lite::PrimitiveC *OnnxConstantParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ConstantParser";
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Constant;
|
||||
auto primitive_c = PrimitiveC::Create(primitive.release());
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "create primitiveC failed.";
|
||||
return nullptr;
|
||||
}
|
||||
for (const auto &attr : onnx_node.attribute()) {
|
||||
if (attr.name() == "sparse_value") {
|
||||
MS_LOG(WARNING) << "sparse_value";
|
||||
continue;
|
||||
}
|
||||
if (attr.name() == "value") {
|
||||
const auto &const_tensor = attr.t();
|
||||
if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) {
|
||||
MS_LOG(ERROR) << "add basic attr failed.";
|
||||
delete primitive_c;
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented";
|
||||
delete primitive_c;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return primitive_c;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,8 @@ class OnnxConstantParser : public OnnxNodeParser {
|
|||
OnnxConstantParser() : OnnxNodeParser("Constant") {}
|
||||
~OnnxConstantParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c);
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,9 +21,14 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
constexpr int32_t kSingleGroup = 1;
|
||||
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
|
||||
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
|
||||
schema::PrimitiveT *primitive) {
|
||||
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
|
||||
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
if (attr == nullptr || primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "input parameter is nullptr";
|
||||
return false;
|
||||
}
|
||||
auto depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
if (depthwiseConv2DParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return false;
|
||||
|
@ -45,27 +50,18 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT
|
|||
depthwiseConv2DParam->hasBias = attr->hasBias;
|
||||
depthwiseConv2DParam->activationType = attr->activationType;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
op->primitive->value.value = depthwiseConv2DParam.release();
|
||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
primitive->value.value = depthwiseConv2DParam.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ConvParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
|
||||
auto attr = std::make_unique<schema::Conv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->strideH = 1;
|
||||
|
@ -83,21 +79,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
} else if (onnx_node_attr.name() == "dilations") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernels") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernel_shape") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -106,7 +102,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
} else if (onnx_node_attr.name() == "pads") {
|
||||
if (onnx_node_attr.ints().size() != 4) {
|
||||
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -115,7 +111,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
} else if (onnx_node_attr.name() == "strides") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -124,7 +120,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
attr->format = schema::Format::Format_NHWC;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s();
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -152,7 +148,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
|
||||
if (node_iter == onnx_graph.node().end()) {
|
||||
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int> dims;
|
||||
auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(),
|
||||
|
@ -160,7 +156,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
if (iter != (*node_iter).attribute().end()) {
|
||||
if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) {
|
||||
MS_LOG(ERROR) << "dims insert failed";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
|
||||
}
|
||||
|
@ -174,16 +170,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (attr->group > kSingleGroup && attr->group == attr->channelIn) {
|
||||
if (!ParseGroupConvolution(attr, op)) {
|
||||
if (!ParseGroupConvolution(attr, primitive.get())) {
|
||||
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
return RET_OK;
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser());
|
||||
|
|
|
@ -28,10 +28,10 @@ class OnnxConvParser : public OnnxNodeParser {
|
|||
OnnxConvParser() : OnnxNodeParser("Conv") {}
|
||||
~OnnxConvParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
|
||||
private:
|
||||
static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
|
||||
static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::PrimitiveT *primitive);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,11 +21,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
|
||||
if (attr == nullptr || attr->group != attr->channelOut) {
|
||||
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr,
|
||||
schema::PrimitiveT *primitive) {
|
||||
if (attr == nullptr || attr->group != attr->channelOut || primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "input parameter is nullptr";
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
|
||||
auto deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>();
|
||||
if (deDepthwiseConv2DParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return false;
|
||||
|
@ -47,28 +49,18 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
|
|||
deDepthwiseConv2DParam->hasBias = attr->hasBias;
|
||||
deDepthwiseConv2DParam->activationType = attr->activationType;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
||||
op->primitive->value.value = deDepthwiseConv2DParam.release();
|
||||
primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
|
||||
primitive->value.value = deDepthwiseConv2DParam.release();
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx DeConvParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
|
||||
auto attr = std::make_unique<schema::DeConv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->padMode = schema::PadMode_NOTSET;
|
||||
|
@ -83,21 +75,21 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
} else if (onnx_node_attr.name() == "dilations") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernels") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernel_shape") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -106,7 +98,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
} else if (onnx_node_attr.name() == "pads") {
|
||||
if (onnx_node_attr.ints().size() != 4) {
|
||||
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -115,7 +107,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
} else if (onnx_node_attr.name() == "strides") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
|
@ -124,11 +116,11 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->format = schema::Format::Format_NHWC;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
} else if (onnx_node_attr.name() == "output_padding") {
|
||||
MS_LOG(ERROR) << "output_padding param hasn't been supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +130,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
|
||||
if (node_iter == onnx_graph.initializer().end()) {
|
||||
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int> weight_shape;
|
||||
auto size = (*node_iter).dims_size();
|
||||
|
@ -148,7 +140,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
if (weight_shape.size() != 4) {
|
||||
MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size();
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->channelIn = weight_shape[0];
|
||||
attr->channelOut = weight_shape[1] * attr->group;
|
||||
|
@ -156,17 +148,22 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->format = schema::Format::Format_NCHW;
|
||||
attr->hasBias = onnx_node.input().size() == 3;
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (attr->group != 1) {
|
||||
if (!ParseGroupDeConvolution(attr, op)) {
|
||||
if (!ParseGroupDeConvolution(attr, primitive.get())) {
|
||||
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support";
|
||||
return RET_NOT_SUPPORT;
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser());
|
||||
|
|
|
@ -28,10 +28,10 @@ class OnnxDeConvParser : public OnnxNodeParser {
|
|||
OnnxDeConvParser() : OnnxNodeParser("DeConv") {}
|
||||
~OnnxDeConvParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
|
||||
private:
|
||||
static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
|
||||
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxDepthToSpaceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx DepthToSpaceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
|
||||
auto attr = std::make_unique<schema::DepthToSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -44,10 +34,14 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const o
|
|||
attr->blockSize = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxDepthToSpaceParser : public OnnxNodeParser {
|
|||
OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {}
|
||||
~OnnxDepthToSpaceParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxDropoutParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx DropoutParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>();
|
||||
auto attr = std::make_unique<schema::DropoutT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -44,10 +34,14 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
attr->ratio = static_cast<float>(onnx_node_attr.f());
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Dropout;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Dropout;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxDropoutParser : public OnnxNodeParser {
|
|||
OnnxDropoutParser() : OnnxNodeParser("Dropout") {}
|
||||
~OnnxDropoutParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,22 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxEluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx EluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
|
||||
auto attr = std::make_unique<schema::EluT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -43,10 +34,14 @@ STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
|
|||
attr->alpha = onnx_node_attr.f();
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Elu;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Elu;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxEluParser : public OnnxNodeParser {
|
|||
OnnxEluParser() : OnnxNodeParser("Elu") {}
|
||||
~OnnxEluParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,23 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxExpandParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ExpandParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
|
||||
auto attr = std::make_unique<schema::BroadcastToT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> dst_shape;
|
||||
|
@ -46,7 +36,7 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
|
||||
if (node_iter == onnx_graph.node().end()) {
|
||||
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
for (const auto &attrPower : node_iter->attribute()) {
|
||||
if (attrPower.name() == "value") {
|
||||
|
@ -58,9 +48,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
}
|
||||
attr->dst_shape = dst_shape;
|
||||
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxExpandParser : public OnnxNodeParser {
|
|||
OnnxExpandParser() : OnnxNodeParser("Expand") {}
|
||||
~OnnxExpandParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxFlattenParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx FlattenParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
|
||||
auto attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int axis = 1;
|
||||
|
@ -49,10 +39,14 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
attr->shape.emplace_back(0);
|
||||
}
|
||||
attr->shape.emplace_back(-1);
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxFlattenParser : public OnnxNodeParser {
|
|||
OnnxFlattenParser() : OnnxNodeParser("Fatten") {}
|
||||
~OnnxFlattenParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxGatherParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx GatherParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
|
||||
auto attr = std::make_unique<schema::GatherT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -45,9 +35,14 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Gather;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Gather;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxGatherParser : public OnnxNodeParser {
|
|||
OnnxGatherParser() : OnnxNodeParser("Gather") {}
|
||||
~OnnxGatherParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxGemmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx IdentityParser";
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *matmul_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node);
|
||||
|
||||
node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *bias_add_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node);
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_MakeTuple;
|
||||
auto primitve_c = PrimitiveC::Create(primitive.release());
|
||||
primitve_c->set_attr("MatMul", std::shared_ptr<lite::PrimitiveC>(matmul_primitive));
|
||||
primitve_c->set_attr("BiasAdd", std::shared_ptr<lite::PrimitiveC>(bias_add_primitive));
|
||||
return primitve_c;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -14,26 +14,21 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
||||
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxTensorParser {
|
||||
class OnnxGemmParser : public OnnxNodeParser {
|
||||
public:
|
||||
~OnnxTensorParser() = default;
|
||||
static OnnxTensorParser *GetInstance() {
|
||||
static OnnxTensorParser onnxTensorParser;
|
||||
return &onnxTensorParser;
|
||||
}
|
||||
TensorCache *GetTensorCache() { return &tensor_cache_; }
|
||||
OnnxGemmParser() : OnnxNodeParser("Gemm") {}
|
||||
~OnnxGemmParser() override = default;
|
||||
|
||||
private:
|
||||
OnnxTensorParser() = default;
|
||||
TensorCache tensor_cache_;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
|
@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h"
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "src/param_value_lite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node,
|
||||
lite::PrimitiveC *primitive_c,
|
||||
const std::vector<int> &shape) {
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "new a paramValueLite failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
|
||||
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
|
||||
if (iter == onnx_node.attribute().end()) {
|
||||
return RET_OK;
|
||||
}
|
||||
size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t);
|
||||
char *param_data = new (std::nothrow) char[data_size];
|
||||
if (param_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy data failed.";
|
||||
delete[] param_data;
|
||||
return RET_ERROR;
|
||||
}
|
||||
param_value->set_tensor_shape(shape);
|
||||
param_value->set_format(schema::Format_NUM_OF_FORMAT);
|
||||
param_value->set_tensor_type(kNumberTypeInt64);
|
||||
param_value->SetTensorData(param_data, data_size);
|
||||
primitive_c->set_attr("const_data", param_value);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node,
|
||||
lite::PrimitiveC *primitive_c,
|
||||
const std::vector<int> &shape) {
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "new a paramValueLite failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
|
||||
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
|
||||
[](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
|
||||
if (iter == onnx_node.attribute().end()) {
|
||||
return RET_OK;
|
||||
}
|
||||
char *param_data = new (std::nothrow) char[data_count];
|
||||
if (param_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy data failed.";
|
||||
delete[] param_data;
|
||||
return RET_ERROR;
|
||||
}
|
||||
param_value->set_tensor_shape(shape);
|
||||
param_value->set_format(schema::Format_NUM_OF_FORMAT);
|
||||
param_value->set_tensor_type(kNumberTypeUInt8);
|
||||
param_value->SetTensorData(param_data, data_count);
|
||||
primitive_c->set_attr("const_data", param_value);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
lite::PrimitiveC *OnnxGivenTensorFillParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx GivenTensorFillParser";
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Constant;
|
||||
auto primitive_c = PrimitiveC::Create(primitive.release());
|
||||
std::vector<int64_t> shape_vector;
|
||||
auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
|
||||
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
|
||||
if (iter != onnx_node.attribute().end()) {
|
||||
shape_vector.insert(shape_vector.begin(), iter->ints().begin(), iter->ints().end());
|
||||
}
|
||||
std::vector<int> shape;
|
||||
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
|
||||
[](const int64_t &val) { return static_cast<int32_t>(val); });
|
||||
if (onnx_node.op_type() == "Int8GivenIntTensorFill") {
|
||||
if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
|
||||
MS_LOG(ERROR) << "given tensor fill parse failed.";
|
||||
return nullptr;
|
||||
}
|
||||
} else if (onnx_node.op_type() == "Int8GivenTensorFill") {
|
||||
if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) {
|
||||
MS_LOG(ERROR) << "given tensor fill parse failed.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return primitive_c;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser());
|
||||
OnnxNodeRegistrar g_onnxInt8GivenTensorFillParser("Int8GivenTensorFill", new OnnxGivenTensorFillParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxGivenTensorFillParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxGivenTensorFillParser() : OnnxNodeParser("GivenTensorFill") {}
|
||||
~OnnxGivenTensorFillParser() override = default;
|
||||
|
||||
STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c,
|
||||
const std::vector<int> &shape);
|
||||
STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c,
|
||||
const std::vector<int> &shape);
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H
|
|
@ -20,28 +20,23 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxIdentityParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx IdentityParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
|
||||
auto attr = std::make_unique<schema::IdentityT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Identity;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Identity;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxIdentityParser : public OnnxNodeParser {
|
|||
OnnxIdentityParser() : OnnxNodeParser("Identity") {}
|
||||
~OnnxIdentityParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx InstanceNormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>();
|
||||
auto attr = std::make_unique<schema::InstanceNormT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!onnx_node.attribute().empty()) {
|
||||
|
@ -44,10 +34,14 @@ STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const o
|
|||
attr->epsilon = onnx_node_attr.f();
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_InstanceNorm;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_InstanceNorm;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxInstanceNormParser : public OnnxNodeParser {
|
|||
OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {}
|
||||
~OnnxInstanceNormParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,23 +18,13 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mindspore::lite {
|
||||
STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxLpNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx LpNormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LpNormalizationT> attr = std::make_unique<schema::LpNormalizationT>();
|
||||
auto attr = std::make_unique<schema::LpNormalizationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -45,10 +35,14 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->p = onnx_node_attr.i();
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_LpNormalization;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LpNormalization;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxLpNormParser : public OnnxNodeParser {
|
|||
OnnxLpNormParser() : OnnxNodeParser("LpNorm") {}
|
||||
~OnnxLpNormParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,22 +18,13 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mindspore::lite {
|
||||
STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxLrnParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx LrnParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>();
|
||||
auto attr = std::make_unique<schema::LocalResponseNormalizationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t size = 0;
|
||||
|
@ -53,13 +44,18 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
|
|||
|
||||
if (size == 0) {
|
||||
MS_LOG(ERROR) << "Divide-by-zero error.";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->alpha /= size;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxLrnParser : public OnnxNodeParser {
|
|||
OnnxLrnParser() : OnnxNodeParser("Lrn") {}
|
||||
~OnnxLrnParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,22 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxLstmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx LstmParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>();
|
||||
auto attr = std::make_unique<schema::LstmT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -44,9 +35,14 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Lstm;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Lstm;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxLstmParser : public OnnxNodeParser {
|
|||
OnnxLstmParser() : OnnxNodeParser("LSTM") {}
|
||||
~OnnxLstmParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxMatmulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx MatMulParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
|
||||
auto attr = std::make_unique<schema::MatMulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
float alpha = 1.0f;
|
||||
|
@ -54,12 +44,17 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
if (alpha != 1 || beta != 1) {
|
||||
MS_LOG(ERROR) << "not support alpha * A * B + beta * C";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxMatmulParser : public OnnxNodeParser {
|
|||
OnnxMatmulParser() : OnnxNodeParser("MatMul") {}
|
||||
~OnnxMatmulParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -26,75 +26,57 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "src/param_value_lite.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxModelParser : public ModelParser {
|
||||
public:
|
||||
OnnxModelParser();
|
||||
OnnxModelParser() = default;
|
||||
|
||||
virtual ~OnnxModelParser();
|
||||
~OnnxModelParser() override = default;
|
||||
|
||||
// schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
|
||||
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
|
||||
const QuantType &quantType);
|
||||
MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
|
||||
const ParamValueLitePtr ¶m_value_lite);
|
||||
|
||||
private:
|
||||
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type = QuantType_QUANT_NONE) override;
|
||||
STATUS InitOriginModel(const std::string &model_file);
|
||||
STATUS ConvertNodes();
|
||||
STATUS ConvertConstTensors();
|
||||
STATUS ConvertGraphInputs();
|
||||
STATUS ConvertGraphOutputs();
|
||||
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
|
||||
STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor);
|
||||
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
|
||||
STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
|
||||
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name);
|
||||
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
|
||||
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
|
||||
bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
|
||||
|
||||
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
|
||||
|
||||
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);
|
||||
|
||||
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
|
||||
|
||||
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph);
|
||||
|
||||
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index);
|
||||
|
||||
STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index);
|
||||
|
||||
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph);
|
||||
|
||||
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type);
|
||||
|
||||
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node);
|
||||
|
||||
STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
const string &onnx_op_type, schema::CNodeT *dst_op);
|
||||
|
||||
void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op,
|
||||
schema::TensorT *dst_tensor);
|
||||
|
||||
STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
|
||||
const onnx::NodeProto &onnx_node);
|
||||
|
||||
STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op);
|
||||
|
||||
STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor);
|
||||
|
||||
STATUS SetAllTensors(schema::MetaGraphT *graphDef);
|
||||
|
||||
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
|
||||
|
||||
STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
|
||||
schema::MetaGraphT *dst_graph);
|
||||
|
||||
private:
|
||||
std::vector<std::string> graphInputNames;
|
||||
std::vector<std::string> graphConstNames;
|
||||
int subGraphNum = 0;
|
||||
onnx::ModelProto onnx_model_;
|
||||
onnx::GraphProto onnx_graph_;
|
||||
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
||||
FuncGraphPtr func_graph_ptr_ = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "google/protobuf/message.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -34,7 +35,8 @@ class OnnxNodeParser {
|
|||
|
||||
virtual ~OnnxNodeParser() = default;
|
||||
|
||||
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
|
||||
virtual lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) = 0;
|
||||
|
||||
static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);
|
||||
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxNonMaxSuppressionParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx EluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>();
|
||||
auto attr = std::make_unique<schema::NonMaxSuppressionT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -47,9 +37,14 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxNonMaxSuppressionParser : public OnnxNodeParser {
|
|||
OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {}
|
||||
~OnnxNonMaxSuppressionParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxOneHotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx OneHotParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
|
||||
auto attr = std::make_unique<schema::OneHotT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -45,9 +35,14 @@ STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_OneHot;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_OneHot;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxOneHotParser : public OnnxNodeParser {
|
|||
OnnxOneHotParser() : OnnxNodeParser("OneHot") {}
|
||||
~OnnxOneHotParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,22 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxPadParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx PadParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
|
||||
auto attr = std::make_unique<schema::PadT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -58,9 +49,14 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Pad;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxPadParser : public OnnxNodeParser {
|
|||
OnnxPadParser() : OnnxNodeParser("Pad") {}
|
||||
~OnnxPadParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,22 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxPoolParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx PoolParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
|
||||
auto attr = std::make_unique<schema::PoolingT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->format = schema::Format::Format_NCHW;
|
||||
|
@ -56,7 +47,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
attr->global = false;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only.";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->roundMode = schema::RoundMode_FLOOR;
|
||||
|
@ -101,13 +92,18 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
}
|
||||
if (attribute_name == "dilations") {
|
||||
MS_LOG(ERROR) << "pooling op not support dilations now";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxPoolParser : public OnnxNodeParser {
|
|||
OnnxPoolParser() : OnnxNodeParser("Pool") {}
|
||||
~OnnxPoolParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxQuantizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
|
||||
auto attr = std::make_unique<schema::QuantDTypeCastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed.";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
if (onnx_node.op_type() == "Int8Quantize") {
|
||||
attr->srcT = kNumberTypeFloat32;
|
||||
|
@ -45,11 +35,16 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
|
|||
attr->dstT = kNumberTypeFloat32;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxQuantizeParser : public OnnxNodeParser {
|
|||
OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
|
||||
~OnnxQuantizeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,28 +19,23 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxRangeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx RangeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>();
|
||||
auto attr = std::make_unique<schema::RangeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
attr->dType = 0;
|
||||
op->primitive->value.type = schema::PrimitiveType_Range;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Range;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxRangeParser : public OnnxNodeParser {
|
|||
OnnxRangeParser() : OnnxNodeParser("Range") {}
|
||||
~OnnxRangeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxReduceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ReduceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
|
||||
auto attr = std::make_unique<schema::ReduceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->keepDims = 1;
|
||||
|
@ -65,12 +55,17 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
attr->mode = schema::ReduceMode_ReduceSumSquare;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported type";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxReduceParser : public OnnxNodeParser {
|
|||
OnnxReduceParser() : OnnxNodeParser("Reduce") {}
|
||||
~OnnxReduceParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,22 +21,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ReluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
auto attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &relu_type = onnx_node.op_type();
|
||||
|
@ -54,29 +45,24 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Activation;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx PReluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (onnx_node.input_size() != 2) {
|
||||
MS_LOG(ERROR) << "input num should be 2";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
|
||||
auto attr = std::make_unique<schema::PReLUT>();
|
||||
std::vector<onnx::TensorProto> params;
|
||||
const auto &input_name = onnx_node.input(1);
|
||||
for (const auto &it : onnx_graph.initializer()) {
|
||||
|
@ -90,7 +76,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
const onnx::TensorProto *slope = ¶ms[0];
|
||||
if (slope == nullptr) {
|
||||
MS_LOG(ERROR) << "input error: params[0] is null";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
|
||||
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
|
||||
|
@ -102,16 +88,21 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
attr->channelShared = false;
|
||||
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
return RET_ERROR;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors.";
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_PReLU;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_PReLU;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxReluParser : public OnnxNodeParser {
|
|||
OnnxReluParser() : OnnxNodeParser("Relu") {}
|
||||
~OnnxReluParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
|
||||
class OnnxPReluParser : public OnnxNodeParser {
|
||||
|
@ -35,7 +35,7 @@ class OnnxPReluParser : public OnnxNodeParser {
|
|||
OnnxPReluParser() : OnnxNodeParser("Prelu") {}
|
||||
~OnnxPReluParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,23 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxReshapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ReshapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
|
||||
auto attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NCHW;
|
||||
|
@ -51,28 +41,17 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
onnx::TensorProto input_shape;
|
||||
const auto &shape_name = onnx_node.input(1);
|
||||
for (const auto &it : onnx_graph.initializer()) {
|
||||
if (it.name() == shape_name) {
|
||||
input_shape = it;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (input_shape.int64_data_size() == 0) {
|
||||
MS_LOG(INFO) << "shape maybe from another op other than const initializer";
|
||||
} else {
|
||||
for (int i = 0; i < input_shape.int64_data_size(); ++i) {
|
||||
shape.push_back(input_shape.int64_data(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
attr->shape = shape;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxReshapeParser : public OnnxNodeParser {
|
|||
OnnxReshapeParser() : OnnxNodeParser("Reshape") {}
|
||||
~OnnxReshapeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,23 +22,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxResizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ResizeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
|
||||
auto attr = std::make_unique<schema::ResizeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NCHW;
|
||||
|
@ -85,9 +75,14 @@ STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Resize;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Resize;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxResizeParser : public OnnxNodeParser {
|
|||
OnnxResizeParser() : OnnxNodeParser("Resize") {}
|
||||
~OnnxResizeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,28 +19,23 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx ShapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
|
||||
auto attr = std::make_unique<schema::ShapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Shape;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Shape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxShapeParser : public OnnxNodeParser {
|
|||
OnnxShapeParser() : OnnxNodeParser("Shape") {}
|
||||
~OnnxShapeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,30 +19,25 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSigmoidParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SigmoidParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
auto attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Activation;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxSigmoidParser : public OnnxNodeParser {
|
|||
OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {}
|
||||
~OnnxSigmoidParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,77 +23,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name,
|
||||
onnx::NodeProto *onnx_node) {
|
||||
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensor->dataType = mindspore::kNumberTypeInt32;
|
||||
tensor->dims.push_back(onnx_val.size());
|
||||
tensor->format = schema::Format::Format_NCHW;
|
||||
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
int data_size = sizeof(int32_t) * onnx_val.size();
|
||||
tensor->data.resize(data_size);
|
||||
if (data_size != 0 &&
|
||||
memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size();
|
||||
std::string tensor_name = name + std::to_string(tensor_num);
|
||||
OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT);
|
||||
onnx_node->add_input(tensor_name);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) {
|
||||
if (onnx_val == nullptr) {
|
||||
MS_LOG(ERROR) << "input vector is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot get tensorcache.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name);
|
||||
if (index == -1) {
|
||||
MS_LOG(ERROR) << "can not find node: " << name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
|
||||
if (input_tensor->data.empty()) {
|
||||
MS_LOG(DEBUG) << "data is empty.";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>());
|
||||
onnx_val->resize(data_num);
|
||||
if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) !=
|
||||
EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSliceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SliceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>();
|
||||
auto attr = std::make_unique<schema::StridedSliceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> starts;
|
||||
|
@ -128,36 +64,17 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
}
|
||||
}
|
||||
}
|
||||
int status = RET_OK;
|
||||
switch (onnx_node.input_size()) {
|
||||
case 5: {
|
||||
if (steps.empty()) {
|
||||
status = GetInputTensor(&steps, onnx_node.input(4));
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
if (status != RET_ERROR && axes.empty()) {
|
||||
status = GetInputTensor(&axes, onnx_node.input(3));
|
||||
}
|
||||
}
|
||||
case 3: {
|
||||
if (status != RET_ERROR && ends.empty()) {
|
||||
status = GetInputTensor(&ends, onnx_node.input(2));
|
||||
}
|
||||
}
|
||||
case 2: {
|
||||
if (status != RET_ERROR && starts.empty()) {
|
||||
status = GetInputTensor(&starts, onnx_node.input(1));
|
||||
}
|
||||
}
|
||||
default: {
|
||||
if (status == RET_ERROR) {
|
||||
MS_LOG(ERROR) << "onnx slice inputs are invalid.";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_StridedSlice;
|
||||
primitive->value.value = attr.release();
|
||||
auto primitive_c = PrimitiveC::Create(primitive.release());
|
||||
if (starts.empty()) {
|
||||
return primitive_c;
|
||||
}
|
||||
|
||||
if (axes.empty()) {
|
||||
for (size_t i = 0; i < starts.size(); ++i) {
|
||||
axes.push_back(i);
|
||||
|
@ -166,42 +83,11 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
if (steps.empty()) {
|
||||
steps.assign(starts.size(), 1);
|
||||
}
|
||||
onnx::NodeProto *slice_node = nullptr;
|
||||
for (auto &node : onnx_graph.node()) {
|
||||
if (&node == &onnx_node) {
|
||||
slice_node = const_cast<onnx::NodeProto *>(&node);
|
||||
}
|
||||
}
|
||||
int insert_num = 5 - onnx_node.input_size();
|
||||
switch (insert_num) {
|
||||
case 4: {
|
||||
std::string name = "slice/starts/";
|
||||
status = InsertTensor(starts, name, slice_node);
|
||||
}
|
||||
case 3:
|
||||
if (status == RET_OK) {
|
||||
std::string name = "slice/ends/";
|
||||
status = InsertTensor(ends, name, slice_node);
|
||||
}
|
||||
case 2:
|
||||
if (status == RET_OK) {
|
||||
std::string name = "slice/axes/";
|
||||
status = InsertTensor(axes, name, slice_node);
|
||||
}
|
||||
case 1:
|
||||
if (status == RET_OK) {
|
||||
std::string name = "slice/steps/";
|
||||
status = InsertTensor(steps, name, slice_node);
|
||||
}
|
||||
default:
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "onnx slice insert tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_StridedSlice;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
primitive_c->set_attr("starts", MakeValue<std::vector<int>>(starts));
|
||||
primitive_c->set_attr("ends", MakeValue<std::vector<int>>(ends));
|
||||
primitive_c->set_attr("axes", MakeValue<std::vector<int>>(axes));
|
||||
primitive_c->set_attr("steps", MakeValue<std::vector<int>>(steps));
|
||||
return primitive_c;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser());
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <string>
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -30,9 +29,7 @@ class OnnxSliceParser : public OnnxNodeParser {
|
|||
OnnxSliceParser() : OnnxNodeParser("Slice") {}
|
||||
~OnnxSliceParser() override = default;
|
||||
|
||||
STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node);
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name);
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSoftMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SoftMaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>();
|
||||
auto attr = std::make_unique<schema::SoftMaxT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool axis_is_def = true;
|
||||
|
@ -53,9 +43,14 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
attr->axis = 1;
|
||||
}
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_SoftMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_SoftMax;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxSoftMaxParser : public OnnxNodeParser {
|
|||
OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {}
|
||||
~OnnxSoftMaxParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSpaceToDepthParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SpaceToDepthParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>();
|
||||
auto attr = std::make_unique<schema::SpaceToDepthT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -45,9 +35,14 @@ STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const o
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_SpaceToDepth;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxSpaceToDepthParser : public OnnxNodeParser {
|
|||
OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {}
|
||||
~OnnxSpaceToDepthParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSplitParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SplitParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
|
||||
auto attr = std::make_unique<schema::SplitT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->splitDim = 0;
|
||||
|
@ -51,9 +41,14 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Split;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Split;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxSplitParser : public OnnxNodeParser {
|
|||
OnnxSplitParser() : OnnxNodeParser("Split") {}
|
||||
~OnnxSplitParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,23 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx SqueezeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>();
|
||||
auto attr = std::make_unique<schema::SqueezeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -47,9 +37,14 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Squeeze;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Squeeze;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxSqueezeParser : public OnnxNodeParser {
|
|||
OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {}
|
||||
~OnnxSqueezeParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,26 +20,22 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxTileParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx TileParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>();
|
||||
auto attr = std::make_unique<schema::TileT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Tile;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Tile;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser());
|
||||
|
|
|
@ -27,7 +27,7 @@ class OnnxTileParser : public OnnxNodeParser {
|
|||
OnnxTileParser() : OnnxNodeParser("Tile") {}
|
||||
~OnnxTileParser() override = default;
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,22 +19,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
lite::PrimitiveC *OnnxTopkParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx TopKParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TopKT> attr = std::make_unique<schema::TopKT>();
|
||||
auto attr = std::make_unique<schema::TopKT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
|
@ -44,9 +35,14 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_TopK;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_TopK;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser());
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue