append onnx parser

This commit is contained in:
xuanyue 2020-10-19 16:33:56 +08:00
parent 25dd36059b
commit 7f42991624
30 changed files with 769 additions and 116 deletions

View File

@ -219,6 +219,10 @@ union PrimitiveType {
Sgd,
Adam,
GroupConv2DGradInput,
Loop,
NonMaxSuppression,
InstanceNorm,
Identity,
}
enum QuantType: int {
@ -250,6 +254,7 @@ table MetaGraph {
mempoolSize: uint;
nodes: [CNode];
allTensors: [Tensor]; // weight + input + output
subGraph : [MetaGraph];
}
root_type MetaGraph;

View File

@ -18,8 +18,28 @@ namespace mindspore.schema;
enum ResizeMethod: byte {
UNKNOW = -1,
BILINEAR = 0,
NEAREST_NEIGHBOR = 1
LINEAR = 0,
NEAREST = 1,
CUBIC = 2
}
enum CoordinateTransformMode: byte {
COMMON = 0,
HALF_PIXEL = 1,
PYTORCH_HALF_PIXEL = 2,
TF_HALF_PIXEL = 3,
TF_CROP_AND_RESIZE = 4,
ALIGN_CORNERS = 5,
ASYMMETRIC = 6,
ALIGN_CORNERS_WITH_HALF_PIEXL = 7
}
enum NearestMode : byte {
NORMAL = 0,
ROUND_HALF_DOWN = 1,
ROUND_HALF_UP = 2,
FLOOR = 3,
CEIL = 4
}
enum Format : int {
@ -376,8 +396,13 @@ table Resize {
method: ResizeMethod;
newHeight: long;
newWidth: long;
alignCorners: bool = false;
alignCorners: bool = false; // DEPRECATED IN FUTURE: use 'coordinateTransformMode' instead.
preserveAspectRatio: bool = false;
coordinateTransformMode : CoordinateTransformMode;
cubicCoeff : float;
excludeOutside : int;
extrapolationValue : float = 0;
nearestMode : NearestMode;
}
table DetectionPostProcess {
@ -1054,3 +1079,21 @@ table FftReal {
table FftImag {
}
table NonMaxSuppression {
maxOutBoxPerClass : int = 0;
iouThreshold : float = 0;
scoreThreshold : float = 0;
centerPointBox : int = 0;
}
table InstanceNorm {
epsilon : float = 0.00001;
}
table Loop {
subGraphIndex : int;
}
table Identity {
}

View File

@ -51,9 +51,9 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ResizeT();
if (prim.instance_name() == "ResizeNearestNeighbor") {
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
attr->method = schema::ResizeMethod_NEAREST;
} else if (prim.instance_name() == "ResizeBilinear") {
attr->method = schema::ResizeMethod_BILINEAR;
attr->method = schema::ResizeMethod_LINEAR;
} else {
MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR;

View File

@ -41,8 +41,8 @@ int ResizeBaseCPUKernel::CheckParameters() {
return RET_NULL_PTR;
}
method_ = parameter->method_;
if (method_ != static_cast<int>(schema::ResizeMethod_BILINEAR) &&
method_ != static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR)) {
if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR) &&
method_ != static_cast<int>(schema::ResizeMethod_NEAREST)) {
MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_;
return RET_INVALID_OP_ATTR;
}

View File

@ -14,11 +14,11 @@
* limitations under the License.
*/
#include <algorithm>
#include "src/runtime/kernel/arm/fp32/resize.h"
#include "schema/model_generated.h"
#include "nnacl/fp32/resize.h"
#include <algorithm>
#include "include/errorcode.h"
#include "nnacl/fp32/resize.h"
#include "schema/model_generated.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -41,7 +41,7 @@ int ResizeCPUKernel::Init() {
int ResizeCPUKernel::ReSize() {
int ret = RET_OK;
if (method_ == static_cast<int>(schema::ResizeMethod_BILINEAR)) {
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
FreeTmpBuffer();
ret = MallocTmpBuffer();
if (ret != RET_OK) {
@ -162,7 +162,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
int ret = 0;
switch (method_) {
case static_cast<int>(schema::ResizeMethod_BILINEAR): {
case static_cast<int>(schema::ResizeMethod_LINEAR): {
int n_h_begin, n_h_end;
int n = out_tensors_.at(0)->shape()[0];
int h = new_height_;
@ -178,7 +178,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
break;
}
case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
case static_cast<int>(schema::ResizeMethod_NEAREST): {
if (in_tensors_.size() == lite::kDoubleNum && !const_shape_) {
auto out_shape = in_tensors_.at(1);
auto data = reinterpret_cast<int32_t *>(out_shape->MutableData());

View File

@ -14,12 +14,12 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/resize_int8.h"
#include <vector>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/int8/resize.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/int8/resize_int8.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -84,7 +84,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
int ret = 0;
switch (method_) {
case static_cast<int>(schema::ResizeMethod_BILINEAR): {
case static_cast<int>(schema::ResizeMethod_LINEAR): {
if (quant_in_->zp_ == 0) {
ret = ResizeBilinearInt8(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
align_corners_, quant_in_, quant_out_, multiplier_, task_id, context_->thread_num_);
@ -95,7 +95,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
}
break;
}
case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
case static_cast<int>(schema::ResizeMethod_NEAREST): {
bool same_zp = quant_in_->zp_ == quant_out_->zp_;
bool same_scale = abs(quant_out_->scale_ - quant_in_->scale_) < 1e-6;
if (same_zp && same_scale) {

View File

@ -14,12 +14,12 @@
* limitations under the License.
*/
#include "src/runtime/kernel/opencl/kernel/resize.h"
#include <map>
#include <set>
#include <string>
#include <map>
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/kernel/resize.h"
#include "src/runtime/kernel/opencl/cl/resize.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
@ -46,9 +46,9 @@ int ResizeOpenCLKernel::Init() {
return RET_PARAM_INVALID;
}
std::string kernel_name = "resize";
if (resize_param->method_ == schema::ResizeMethod_BILINEAR) {
if (resize_param->method_ == schema::ResizeMethod_LINEAR) {
kernel_name += "_bilinear";
} else if (resize_param->method_ == schema::ResizeMethod_NEAREST_NEIGHBOR) {
} else if (resize_param->method_ == schema::ResizeMethod_NEAREST) {
kernel_name += "_nearest_neighbor";
} else {
MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_;

View File

@ -14,11 +14,11 @@
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/tensor.h"
#include "common/common_test.h"
#include "nnacl/resize_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "schema/ops_generated.h"
using mindspore::schema::Format_NHWC;
@ -62,7 +62,7 @@ void TestResizeBilinearFp32::Prepare(const std::vector<int> &input_shape, const
out_tensor_.SetData(output_data);
ResizeParameter param_ = {
{}, static_cast<int>(schema::ResizeMethod_BILINEAR), output_shape[1], output_shape[2], align_corners};
{}, static_cast<int>(schema::ResizeMethod_LINEAR), output_shape[1], output_shape[2], align_corners};
desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
ctx_ = lite::InnerContext();
ctx_.thread_num_ = thread_num;

View File

@ -16,7 +16,7 @@
#include <vector>
#include "common/common_test.h"
#include "nnacl/resize_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "src/kernel_registry.h"
namespace mindspore {
@ -57,7 +57,7 @@ void TestResizeNearestNeighborFp32::Prepare(const std::vector<int> &input_shape,
out_tensor_.SetData(output_data);
ResizeParameter param_ = {
{}, static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR), output_shape[1], output_shape[2], align_corners};
{}, static_cast<int>(schema::ResizeMethod_NEAREST), output_shape[1], output_shape[2], align_corners};
desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
ctx_ = lite::InnerContext();
ctx_.thread_num_ = thread_num;

View File

@ -19,7 +19,7 @@
#include "include/context.h"
#include "src/tensor.h"
#include "common/common_test.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "src/kernel_registry.h"
#include "nnacl/int8/resize.h"
namespace mindspore {
@ -68,7 +68,7 @@ void TestResizeBilinearInt8::Prepare(const std::vector<int> &in_shape, const std
inputs.push_back(&in_tensor);
outputs.push_back(&out_tensor);
param_.method_ = static_cast<int>(schema::ResizeMethod_BILINEAR);
param_.method_ = static_cast<int>(schema::ResizeMethod_LINEAR);
param_.new_width_ = out_shape[2];
param_.new_height_ = out_shape[1];
param_.align_corners_ = align_corners;

View File

@ -19,7 +19,7 @@
#include "include/context.h"
#include "src/tensor.h"
#include "common/common_test.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "src/kernel_registry.h"
#include "nnacl/int8/resize.h"
namespace mindspore {
@ -63,7 +63,7 @@ void TestResizeNearestNeighborInt8::Prepare(const std::vector<int> &in_shape, co
inputs.push_back(&in_tensor);
outputs.push_back(&out_tensor);
param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR);
param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST);
param_.new_width_ = out_shape[2];
param_.new_height_ = out_shape[1];
param_.align_corners_ = align_corners;

View File

@ -15,13 +15,13 @@
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
#include "src/common/file_utils.h"
#include "src/common/log_adapter.h"
#include "src/runtime/kernel/opencl/kernel/resize.h"
#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestResizeOpenCL : public mindspore::CommonTest {
@ -119,7 +119,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp32) {
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
}
TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
@ -134,7 +134,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
std::vector<float16_t> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_BILINEAR, align_corners);
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_LINEAR, align_corners);
}
TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
@ -148,7 +148,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
std::vector<int> shape = {n, h, w, oh, ow, c};
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f};
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
}
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
@ -163,8 +163,7 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
std::vector<float> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST_NEIGHBOR,
align_corners);
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST, align_corners);
}
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
@ -179,7 +178,6 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
std::vector<float16_t> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST_NEIGHBOR,
align_corners);
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST, align_corners);
}
} // namespace mindspore

View File

@ -40,7 +40,7 @@ TEST_F(TestTfliteParserResizeNN, AttrValue) {
ASSERT_EQ(val->newWidth, 100);
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->preserveAspectRatio, false);
ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST_NEIGHBOR);
ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST);
}
class TestTfliteParserResizeBilinear : public TestTfliteParser {
@ -64,7 +64,7 @@ TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
ASSERT_EQ(val->newWidth, 4);
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->preserveAspectRatio, false);
ASSERT_EQ(val->method, schema::ResizeMethod_BILINEAR);
ASSERT_EQ(val->method, schema::ResizeMethod_LINEAR);
}
} // namespace mindspore

View File

@ -57,7 +57,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe:
attr->newWidth = width;
}
attr->alignCorners = true;
attr->method = schema::ResizeMethod_BILINEAR;
attr->method = schema::ResizeMethod_LINEAR;
op->name = proto.name();
op->primitive->value.type = schema::PrimitiveType_Resize;

View File

@ -582,6 +582,94 @@ STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_OK;
}
STATUS OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx AndParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_LogicalAnd;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx OrParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_LogicalOr;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx NotParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_LogicalNot;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx RoundParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::RoundT> attr = std::make_unique<schema::RoundT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Round;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
@ -608,5 +696,9 @@ OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser());
OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser());
OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser());
OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser());
} // namespace lite
} // namespace mindspore

View File

@ -171,6 +171,30 @@ class OnnxSignParser : public OnnxNodeParser {
OnnxSignParser() : OnnxNodeParser("Sign") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxAndParser : public OnnxNodeParser {
public:
OnnxAndParser() : OnnxNodeParser("And") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxOrParser : public OnnxNodeParser {
public:
OnnxOrParser() : OnnxNodeParser("Or") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxNotParser : public OnnxNodeParser {
public:
OnnxNotParser() : OnnxNodeParser("Not") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxRoundParser : public OnnxNodeParser {
public:
OnnxRoundParser() : OnnxNodeParser("Round") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

View File

@ -15,9 +15,9 @@
*/
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
namespace mindspore {
namespace lite {
@ -176,9 +176,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
} else if (attr->group != 1) {
MS_LOG(ERROR) << "group conv hasn't supported";
return RET_NOT_SUPPORT;
} else {
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_identity_parser.h"
#include <memory>
#include <vector>
namespace mindspore {
namespace lite {
STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx IdentityParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Identity;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxIdentityParser : public OnnxNodeParser {
public:
OnnxIdentityParser() : OnnxNodeParser("Identity") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H

View File

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_instance_norm_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx InstanceNormParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (!onnx_node.attribute().empty()) {
auto onnx_node_attr = onnx_node.attribute().at(0);
if (onnx_node_attr.name() == "epsilon") {
attr->epsilon = onnx_node_attr.f();
}
}
op->primitive->value.type = schema::PrimitiveType_InstanceNorm;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxInstanceNormParser : public OnnxNodeParser {
public:
OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H

View File

@ -15,12 +15,12 @@
*/
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include <algorithm>
#include <cfloat>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include "tools/common/graph_util.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
namespace mindspore {
@ -36,7 +36,8 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
{onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
{onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}};
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
auto iter = TYPE_MAP.find(onnx_type);
@ -161,9 +162,13 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
TensorCache *tensor_cache) {
for (const auto &output_value : onnx_graph.output()) {
int index;
const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
if (status != RET_OK) {
return status;
if (tensor_cache->FindTensor(output_value.name()) != -1) {
index = tensor_cache->FindTensor(output_value.name());
} else {
const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
if (status != RET_OK) {
return status;
}
}
graph->outputIndex.emplace_back(index);
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
@ -250,7 +255,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
TensorCache *tensor_cache, const QuantType &quantType) {
TensorCache *tensor_cache, const QuantType &quantType,
schema::MetaGraphT *dst_graph) {
// change op_type() to name(), that is unique
static bool interrupt = false;
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@ -260,23 +266,34 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
<< onnx_node.input_size();
// get the real op type
SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr || interrupt) {
if (onnx_node.op_type() == "Loop") {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
interrupt = true;
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
}
return RET_NOT_FIND_OP;
}
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
if (status != RET_OK) {
interrupt = true;
if (status == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
} else {
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
if (status != RET_OK || interrupt) {
interrupt = true;
return status;
}
} else {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr || interrupt) {
interrupt = true;
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
}
return RET_NOT_FIND_OP;
}
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
if (status != RET_OK) {
interrupt = true;
if (status == RET_NOT_FIND_OP) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
} else {
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
}
return status;
}
return status;
}
// set op input index
std::vector<string> node_inputs;
@ -366,7 +383,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
for (const auto &onnx_node_input : node_inputs) {
if (onnx_node_input != "") {
auto index = tensor_cache->FindTensor(onnx_node_input);
int index = tensor_cache->FindTensor(onnx_node_input);
if (index < 0) {
MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
return RET_ERROR;
@ -428,6 +445,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
}
for (size_t i = 0; i < data_count; ++i) {
if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
}
MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
return RET_ERROR;
} else {
@ -438,6 +458,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
break;
case kNumberTypeUInt8:
case kNumberTypeInt8:
case kNumberTypeBool:
data_size = data_count * sizeof(uint8_t);
tensor_data = onnx_const_value.raw_data().data();
break;
@ -446,7 +467,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
return RET_ERROR;
}
tensor->data.resize(data_size);
if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
if (data_size != 0 && memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
@ -475,30 +496,39 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
}
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".onnx");
if (status != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
const QuantType &quantType, schema::MetaGraphT *dst_graph) {
MS_LOG(DEBUG) << "onnx LoopParser";
if (dst_op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
onnx::ModelProto onnx_model;
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
dst_op->primitive = std::make_unique<schema::PrimitiveT>();
if (dst_op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
const onnx::GraphProto &onnx_graph = onnx_model.graph();
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
std::unique_ptr<schema::LoopT> attr = std::make_unique<schema::LoopT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->subGraphIndex = subGraphNum;
auto sub_graph = std::make_unique<schema::MetaGraphT>();
sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
dst_graph->subGraph.push_back(std::move(sub_graph));
subGraphNum += 1;
dst_op->primitive->value.type = schema::PrimitiveType_Loop;
dst_op->primitive->value.value = attr.release();
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
TensorCache tensor_cache;
// dst_graph->name = onnx_graph.name(); // this is not used
// find out input names and const names
FindGraphInputAndConst(onnx_graph);
// set const tensor
status = SetGraphConstTensor(onnx_graph, &tensor_cache);
int status = SetGraphConstTensor(onnx_graph, &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphConstTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -512,13 +542,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
// init onnx model graph output tensor
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
// init op node input/output tensor, and dst_op attr
NoSupportOp::GetInstance()->SetFmkType("ONNX");
for (const auto &onnx_node : onnx_graph.node()) {
@ -544,7 +568,8 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
dst_graph.get());
if (status_node != RET_OK) {
status = (status == RET_OK ? status_node : status);
continue;
@ -558,9 +583,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
}
return nullptr;
}
// init onnx model graph output tensor
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
SetAllTensors(tensor_cache, dst_graph.get());
dst_graph->name = GetModelName(modelFile);
return dst_graph.release();
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".onnx");
if (status != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
onnx::ModelProto onnx_model;
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
const onnx::GraphProto &onnx_graph = onnx_model.graph();
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
if (dst_graph == nullptr) {
return nullptr;
}
dst_graph->name = GetModelName(modelFile);
return dst_graph;
}
} // namespace lite
} // namespace mindspore

View File

@ -26,6 +26,7 @@
#include <vector>
#include <memory>
#include <set>
#include <map>
#include "securec/include/securec.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -40,6 +41,7 @@ class OnnxModelParser : public ModelParser {
virtual ~OnnxModelParser();
schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
@ -62,7 +64,7 @@ class OnnxModelParser : public ModelParser {
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
const QuantType &quantType);
const QuantType &quantType, schema::MetaGraphT *dst_graph);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
@ -86,9 +88,13 @@ class OnnxModelParser : public ModelParser {
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
schema::MetaGraphT *dst_graph);
private:
std::vector<string> graphInputNames;
std::vector<string> graphConstNames;
std::vector<std::string> graphInputNames;
std::vector<std::string> graphConstNames;
int subGraphNum = 0;
};
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,80 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx EluParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
if (onnx_node.input_size() > 2) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(2); });
if (it != onnx_graph.initializer().end()) {
attr->maxOutBoxPerClass = it->int64_data(0);
}
}
if (onnx_node.input_size() > 3) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(3); });
if (it != onnx_graph.initializer().end()) {
attr->iouThreshold = it->float_data(0);
}
}
if (onnx_node.input_size() > 4) {
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(4); });
if (it != onnx_graph.initializer().end()) {
attr->scoreThreshold = it->float_data(0);
}
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "center_point_box") {
if (onnx_node_attr.has_i()) {
attr->centerPointBox = onnx_node_attr.i();
}
}
}
op->primitive->value.type = schema::PrimitiveType_Elu;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxNonMaxSuppressionParser : public OnnxNodeParser {
public:
OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H

View File

@ -0,0 +1,95 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_resize_parser.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
namespace mindspore {
namespace lite {
STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ResizeParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NCHW;
attr->nearestMode = schema::NearestMode_ROUND_HALF_DOWN;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "coordinate_transformation_mode") {
attr->coordinateTransformMode = [&]() {
std::map<std::string, schema::CoordinateTransformMode> transform_map = {
{"half_pixel", schema::CoordinateTransformMode_HALF_PIXEL},
{"pytorch_half_pixel", schema::CoordinateTransformMode_PYTORCH_HALF_PIXEL},
{"align_corners", schema::CoordinateTransformMode_ALIGN_CORNERS},
{"asymmetric", schema::CoordinateTransformMode_ASYMMETRIC},
{"tf_half_pixel_for_nn", schema::CoordinateTransformMode_TF_HALF_PIXEL},
{"tf_crop_and_resize", schema::CoordinateTransformMode_TF_CROP_AND_RESIZE},
};
return transform_map[onnx_node_attr.strings(0)];
}();
} else if (attribute_name == "cubic_coeff_a") {
attr->cubicCoeff = onnx_node_attr.f();
} else if (attribute_name == "exclude_outside") {
attr->excludeOutside = onnx_node_attr.i();
} else if (attribute_name == "extrapolation_value") {
attr->extrapolationValue = onnx_node_attr.f();
} else if (attribute_name == "mode") {
attr->method = [&]() {
std::map<std::string, schema::ResizeMethod> resize_mode = {
{"nearest", schema::ResizeMethod_NEAREST},
{"linear", schema::ResizeMethod_LINEAR},
{"cubic", schema::ResizeMethod_CUBIC},
};
return resize_mode[onnx_node_attr.strings(0)];
}();
} else if (attribute_name == "nearest_mode") {
attr->nearestMode = [&]() {
std::map<std::string, schema::NearestMode> nearest_mode = {
{"round_prefer_floor", schema::NearestMode_ROUND_HALF_DOWN},
{"round_prefer_ceil", schema::NearestMode_ROUND_HALF_UP},
{"floor", schema::NearestMode_FLOOR},
{"ceil", schema::NearestMode_CEIL},
};
return nearest_mode[onnx_node_attr.strings(0)];
}();
}
}
op->primitive->value.type = schema::PrimitiveType_Resize;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxResizeParser : public OnnxNodeParser {
public:
OnnxResizeParser() : OnnxNodeParser("Resize") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
@ -42,9 +42,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {
if ("nearest" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
attr->method = schema::ResizeMethod_NEAREST;
} else if ("bilinear" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_BILINEAR;
attr->method = schema::ResizeMethod_LINEAR;
} else {
MS_LOG(ERROR) << "Resize do not support upsample mode";
return RET_ERROR;

View File

@ -15,9 +15,9 @@
*/
#include "tools/converter/parser/tflite/tflite_custom_parser.h"
#include <vector>
#include <memory>
#include <map>
#include <memory>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
@ -206,6 +206,8 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
status = ExtractFeatures(custom_attr, op, tflite_op);
} else if (custom_type == "AudioSpectrogram") {
status = AudioSpectrogram(custom_attr, op, tflite_op);
} else if (custom_type == "Mfcc") {
status = Mfcc(custom_attr, op, tflite_op);
} else if (custom_type == "FlexRFFT") {
status = Rfft(custom_attr, op, tflite_op, tflite_model);
} else if (custom_type == "FlexReal") {

View File

@ -15,10 +15,10 @@
*/
#include "tools/converter/parser/tflite/tflite_resize_parser.h"
#include <vector>
#include <map>
#include <memory>
#include <string>
#include <map>
#include <vector>
namespace mindspore {
namespace lite {
@ -39,7 +39,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON;
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
@ -50,8 +50,16 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->alignCorners = tfliteAttr->align_corners;
attr->method = schema::ResizeMethod_BILINEAR;
if (tfliteAttr->align_corners) {
attr->alignCorners = tfliteAttr->align_corners;
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
}
if (tfliteAttr->half_pixel_centers) {
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
? schema::CoordinateTransformMode_TF_HALF_PIXEL
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
}
attr->method = schema::ResizeMethod_LINEAR;
} else if (std::strcmp(node_name, "NearestNeighbor") == 0) {
MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser";
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions();
@ -59,8 +67,17 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->alignCorners = tfliteAttr->align_corners;
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
if (tfliteAttr->align_corners) {
attr->alignCorners = tfliteAttr->align_corners;
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
}
if (tfliteAttr->half_pixel_centers) {
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
? schema::CoordinateTransformMode_TF_HALF_PIXEL
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
}
attr->method = schema::ResizeMethod_NEAREST;
attr->nearestMode = schema::NearestMode_NORMAL;
} else {
MS_LOG(ERROR) << "wrong resize type";
return RET_ERROR;