forked from mindspore-Ecosystem/mindspore
append onnx parser
This commit is contained in:
parent
25dd36059b
commit
7f42991624
|
@ -219,6 +219,10 @@ union PrimitiveType {
|
|||
Sgd,
|
||||
Adam,
|
||||
GroupConv2DGradInput,
|
||||
Loop,
|
||||
NonMaxSuppression,
|
||||
InstanceNorm,
|
||||
Identity,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
@ -250,6 +254,7 @@ table MetaGraph {
|
|||
mempoolSize: uint;
|
||||
nodes: [CNode];
|
||||
allTensors: [Tensor]; // weight + input + output
|
||||
subGraph : [MetaGraph];
|
||||
}
|
||||
|
||||
root_type MetaGraph;
|
||||
|
|
|
@ -18,8 +18,28 @@ namespace mindspore.schema;
|
|||
|
||||
enum ResizeMethod: byte {
|
||||
UNKNOW = -1,
|
||||
BILINEAR = 0,
|
||||
NEAREST_NEIGHBOR = 1
|
||||
LINEAR = 0,
|
||||
NEAREST = 1,
|
||||
CUBIC = 2
|
||||
}
|
||||
|
||||
enum CoordinateTransformMode: byte {
|
||||
COMMON = 0,
|
||||
HALF_PIXEL = 1,
|
||||
PYTORCH_HALF_PIXEL = 2,
|
||||
TF_HALF_PIXEL = 3,
|
||||
TF_CROP_AND_RESIZE = 4,
|
||||
ALIGN_CORNERS = 5,
|
||||
ASYMMETRIC = 6,
|
||||
ALIGN_CORNERS_WITH_HALF_PIEXL = 7
|
||||
}
|
||||
|
||||
enum NearestMode : byte {
|
||||
NORMAL = 0,
|
||||
ROUND_HALF_DOWN = 1,
|
||||
ROUND_HALF_UP = 2,
|
||||
FLOOR = 3,
|
||||
CEIL = 4
|
||||
}
|
||||
|
||||
enum Format : int {
|
||||
|
@ -376,8 +396,13 @@ table Resize {
|
|||
method: ResizeMethod;
|
||||
newHeight: long;
|
||||
newWidth: long;
|
||||
alignCorners: bool = false;
|
||||
alignCorners: bool = false; // DEPRECATED IN FUTURE: use 'coordinateTransformMode' instead.
|
||||
preserveAspectRatio: bool = false;
|
||||
coordinateTransformMode : CoordinateTransformMode;
|
||||
cubicCoeff : float;
|
||||
excludeOutside : int;
|
||||
extrapolationValue : float = 0;
|
||||
nearestMode : NearestMode;
|
||||
}
|
||||
|
||||
table DetectionPostProcess {
|
||||
|
@ -1054,3 +1079,21 @@ table FftReal {
|
|||
|
||||
table FftImag {
|
||||
}
|
||||
|
||||
table NonMaxSuppression {
|
||||
maxOutBoxPerClass : int = 0;
|
||||
iouThreshold : float = 0;
|
||||
scoreThreshold : float = 0;
|
||||
centerPointBox : int = 0;
|
||||
}
|
||||
|
||||
table InstanceNorm {
|
||||
epsilon : float = 0.00001;
|
||||
}
|
||||
|
||||
table Loop {
|
||||
subGraphIndex : int;
|
||||
}
|
||||
|
||||
table Identity {
|
||||
}
|
||||
|
|
|
@ -51,9 +51,9 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::ResizeT();
|
||||
if (prim.instance_name() == "ResizeNearestNeighbor") {
|
||||
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
} else if (prim.instance_name() == "ResizeBilinear") {
|
||||
attr->method = schema::ResizeMethod_BILINEAR;
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong resize type";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -41,8 +41,8 @@ int ResizeBaseCPUKernel::CheckParameters() {
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
method_ = parameter->method_;
|
||||
if (method_ != static_cast<int>(schema::ResizeMethod_BILINEAR) &&
|
||||
method_ != static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR)) {
|
||||
if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR) &&
|
||||
method_ != static_cast<int>(schema::ResizeMethod_NEAREST)) {
|
||||
MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_;
|
||||
return RET_INVALID_OP_ATTR;
|
||||
}
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "src/runtime/kernel/arm/fp32/resize.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "nnacl/fp32/resize.h"
|
||||
#include <algorithm>
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/resize.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
|
@ -41,7 +41,7 @@ int ResizeCPUKernel::Init() {
|
|||
|
||||
int ResizeCPUKernel::ReSize() {
|
||||
int ret = RET_OK;
|
||||
if (method_ == static_cast<int>(schema::ResizeMethod_BILINEAR)) {
|
||||
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
|
||||
FreeTmpBuffer();
|
||||
ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
|
@ -162,7 +162,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
|
|||
|
||||
int ret = 0;
|
||||
switch (method_) {
|
||||
case static_cast<int>(schema::ResizeMethod_BILINEAR): {
|
||||
case static_cast<int>(schema::ResizeMethod_LINEAR): {
|
||||
int n_h_begin, n_h_end;
|
||||
int n = out_tensors_.at(0)->shape()[0];
|
||||
int h = new_height_;
|
||||
|
@ -178,7 +178,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
|
|||
|
||||
break;
|
||||
}
|
||||
case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
|
||||
case static_cast<int>(schema::ResizeMethod_NEAREST): {
|
||||
if (in_tensors_.size() == lite::kDoubleNum && !const_shape_) {
|
||||
auto out_shape = in_tensors_.at(1);
|
||||
auto data = reinterpret_cast<int32_t *>(out_shape->MutableData());
|
||||
|
|
|
@ -14,12 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/resize_int8.h"
|
||||
#include <vector>
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/int8/resize.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/arm/int8/resize_int8.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
|
@ -84,7 +84,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
|
|||
|
||||
int ret = 0;
|
||||
switch (method_) {
|
||||
case static_cast<int>(schema::ResizeMethod_BILINEAR): {
|
||||
case static_cast<int>(schema::ResizeMethod_LINEAR): {
|
||||
if (quant_in_->zp_ == 0) {
|
||||
ret = ResizeBilinearInt8(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
|
||||
align_corners_, quant_in_, quant_out_, multiplier_, task_id, context_->thread_num_);
|
||||
|
@ -95,7 +95,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
|
||||
case static_cast<int>(schema::ResizeMethod_NEAREST): {
|
||||
bool same_zp = quant_in_->zp_ == quant_out_->zp_;
|
||||
bool same_scale = abs(quant_out_->scale_ - quant_in_->scale_) < 1e-6;
|
||||
if (same_zp && same_scale) {
|
||||
|
|
|
@ -14,12 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include "src/runtime/kernel/opencl/cl/resize.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
|
@ -46,9 +46,9 @@ int ResizeOpenCLKernel::Init() {
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::string kernel_name = "resize";
|
||||
if (resize_param->method_ == schema::ResizeMethod_BILINEAR) {
|
||||
if (resize_param->method_ == schema::ResizeMethod_LINEAR) {
|
||||
kernel_name += "_bilinear";
|
||||
} else if (resize_param->method_ == schema::ResizeMethod_NEAREST_NEIGHBOR) {
|
||||
} else if (resize_param->method_ == schema::ResizeMethod_NEAREST) {
|
||||
kernel_name += "_nearest_neighbor";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_;
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "mindspore/lite/src/lite_kernel.h"
|
||||
#include "mindspore/lite/src/tensor.h"
|
||||
#include "common/common_test.h"
|
||||
#include "nnacl/resize_parameter.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "schema/ops_generated.h"
|
||||
using mindspore::schema::Format_NHWC;
|
||||
|
||||
|
@ -62,7 +62,7 @@ void TestResizeBilinearFp32::Prepare(const std::vector<int> &input_shape, const
|
|||
out_tensor_.SetData(output_data);
|
||||
|
||||
ResizeParameter param_ = {
|
||||
{}, static_cast<int>(schema::ResizeMethod_BILINEAR), output_shape[1], output_shape[2], align_corners};
|
||||
{}, static_cast<int>(schema::ResizeMethod_LINEAR), output_shape[1], output_shape[2], align_corners};
|
||||
desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
|
||||
ctx_ = lite::InnerContext();
|
||||
ctx_.thread_num_ = thread_num;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "nnacl/resize_parameter.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
@ -57,7 +57,7 @@ void TestResizeNearestNeighborFp32::Prepare(const std::vector<int> &input_shape,
|
|||
out_tensor_.SetData(output_data);
|
||||
|
||||
ResizeParameter param_ = {
|
||||
{}, static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR), output_shape[1], output_shape[2], align_corners};
|
||||
{}, static_cast<int>(schema::ResizeMethod_NEAREST), output_shape[1], output_shape[2], align_corners};
|
||||
desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
|
||||
ctx_ = lite::InnerContext();
|
||||
ctx_.thread_num_ = thread_num;
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "include/context.h"
|
||||
#include "src/tensor.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "nnacl/int8/resize.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -68,7 +68,7 @@ void TestResizeBilinearInt8::Prepare(const std::vector<int> &in_shape, const std
|
|||
inputs.push_back(&in_tensor);
|
||||
outputs.push_back(&out_tensor);
|
||||
|
||||
param_.method_ = static_cast<int>(schema::ResizeMethod_BILINEAR);
|
||||
param_.method_ = static_cast<int>(schema::ResizeMethod_LINEAR);
|
||||
param_.new_width_ = out_shape[2];
|
||||
param_.new_height_ = out_shape[1];
|
||||
param_.align_corners_ = align_corners;
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "include/context.h"
|
||||
#include "src/tensor.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "nnacl/int8/resize.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -63,7 +63,7 @@ void TestResizeNearestNeighborInt8::Prepare(const std::vector<int> &in_shape, co
|
|||
inputs.push_back(&in_tensor);
|
||||
outputs.push_back(&out_tensor);
|
||||
|
||||
param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR);
|
||||
param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST);
|
||||
param_.new_width_ = out_shape[2];
|
||||
param_.new_height_ = out_shape[1];
|
||||
param_.align_corners_ = align_corners;
|
||||
|
|
|
@ -15,13 +15,13 @@
|
|||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "test/ut/src/runtime/kernel/opencl/utils_tests.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestResizeOpenCL : public mindspore::CommonTest {
|
||||
|
@ -119,7 +119,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp32) {
|
|||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
|
||||
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
|
||||
|
@ -134,7 +134,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
|
|||
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float16_t> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
|
||||
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_LINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
|
||||
|
@ -148,7 +148,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
|
|||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
|
||||
|
@ -163,8 +163,7 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
|
|||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
|
||||
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST_NEIGHBOR,
|
||||
align_corners);
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
|
||||
|
@ -179,7 +178,6 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
|
|||
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float16_t> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
|
||||
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST_NEIGHBOR,
|
||||
align_corners);
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST, align_corners);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,7 +40,7 @@ TEST_F(TestTfliteParserResizeNN, AttrValue) {
|
|||
ASSERT_EQ(val->newWidth, 100);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->preserveAspectRatio, false);
|
||||
ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST_NEIGHBOR);
|
||||
ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST);
|
||||
}
|
||||
|
||||
class TestTfliteParserResizeBilinear : public TestTfliteParser {
|
||||
|
@ -64,7 +64,7 @@ TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
|
|||
ASSERT_EQ(val->newWidth, 4);
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->preserveAspectRatio, false);
|
||||
ASSERT_EQ(val->method, schema::ResizeMethod_BILINEAR);
|
||||
ASSERT_EQ(val->method, schema::ResizeMethod_LINEAR);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -57,7 +57,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe:
|
|||
attr->newWidth = width;
|
||||
}
|
||||
attr->alignCorners = true;
|
||||
attr->method = schema::ResizeMethod_BILINEAR;
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
|
||||
op->name = proto.name();
|
||||
op->primitive->value.type = schema::PrimitiveType_Resize;
|
||||
|
|
|
@ -582,6 +582,94 @@ STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx AndParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_LogicalAnd;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx OrParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_LogicalOr;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx NotParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_LogicalNot;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx RoundParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::RoundT> attr = std::make_unique<schema::RoundT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Round;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
|
||||
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
|
||||
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
|
||||
|
@ -608,5 +696,9 @@ OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
|
|||
OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
|
||||
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
|
||||
OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
|
||||
OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser());
|
||||
OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser());
|
||||
OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser());
|
||||
OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -171,6 +171,30 @@ class OnnxSignParser : public OnnxNodeParser {
|
|||
OnnxSignParser() : OnnxNodeParser("Sign") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxAndParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxAndParser() : OnnxNodeParser("And") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxOrParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxOrParser() : OnnxNodeParser("Or") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxNotParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNotParser() : OnnxNodeParser("Not") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxRoundParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxRoundParser() : OnnxNodeParser("Round") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_conv_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -176,9 +176,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (attr->group != 1) {
|
||||
MS_LOG(ERROR) << "group conv hasn't supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_identity_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx IdentityParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Identity;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxIdentityParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxIdentityParser() : OnnxNodeParser("Identity") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_instance_norm_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx InstanceNormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (!onnx_node.attribute().empty()) {
|
||||
auto onnx_node_attr = onnx_node.attribute().at(0);
|
||||
if (onnx_node_attr.name() == "epsilon") {
|
||||
attr->epsilon = onnx_node_attr.f();
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_InstanceNorm;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxInstanceNormParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
|
|
@ -15,12 +15,12 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/protobuf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -36,7 +36,8 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
{onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
|
||||
{onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}};
|
||||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||
|
||||
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
|
||||
auto iter = TYPE_MAP.find(onnx_type);
|
||||
|
@ -161,9 +162,13 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|||
TensorCache *tensor_cache) {
|
||||
for (const auto &output_value : onnx_graph.output()) {
|
||||
int index;
|
||||
const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
if (tensor_cache->FindTensor(output_value.name()) != -1) {
|
||||
index = tensor_cache->FindTensor(output_value.name());
|
||||
} else {
|
||||
const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
graph->outputIndex.emplace_back(index);
|
||||
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
|
||||
|
@ -250,7 +255,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
|||
|
||||
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
|
||||
TensorCache *tensor_cache, const QuantType &quantType) {
|
||||
TensorCache *tensor_cache, const QuantType &quantType,
|
||||
schema::MetaGraphT *dst_graph) {
|
||||
// change op_type() to name(), that is unique
|
||||
static bool interrupt = false;
|
||||
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
|
||||
|
@ -260,23 +266,34 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|||
<< onnx_node.input_size();
|
||||
// get the real op type
|
||||
SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
|
||||
if (node_parser == nullptr || interrupt) {
|
||||
if (onnx_node.op_type() == "Loop") {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
interrupt = true;
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
}
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
||||
if (status != RET_OK) {
|
||||
interrupt = true;
|
||||
if (status == RET_NOT_SUPPORT) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
|
||||
int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
|
||||
if (status != RET_OK || interrupt) {
|
||||
interrupt = true;
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
|
||||
if (node_parser == nullptr || interrupt) {
|
||||
interrupt = true;
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
}
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
||||
if (status != RET_OK) {
|
||||
interrupt = true;
|
||||
if (status == RET_NOT_FIND_OP) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
|
||||
}
|
||||
return status;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
// set op input index
|
||||
std::vector<string> node_inputs;
|
||||
|
@ -366,7 +383,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
|
|||
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
|
||||
for (const auto &onnx_node_input : node_inputs) {
|
||||
if (onnx_node_input != "") {
|
||||
auto index = tensor_cache->FindTensor(onnx_node_input);
|
||||
int index = tensor_cache->FindTensor(onnx_node_input);
|
||||
if (index < 0) {
|
||||
MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
|
||||
return RET_ERROR;
|
||||
|
@ -428,6 +445,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|||
}
|
||||
for (size_t i = 0; i < data_count; ++i) {
|
||||
if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
|
||||
if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
|
||||
buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
|
||||
}
|
||||
MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
|
@ -438,6 +458,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|||
break;
|
||||
case kNumberTypeUInt8:
|
||||
case kNumberTypeInt8:
|
||||
case kNumberTypeBool:
|
||||
data_size = data_count * sizeof(uint8_t);
|
||||
tensor_data = onnx_const_value.raw_data().data();
|
||||
break;
|
||||
|
@ -446,7 +467,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|||
return RET_ERROR;
|
||||
}
|
||||
tensor->data.resize(data_size);
|
||||
if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
|
||||
if (data_size != 0 && memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -475,30 +496,39 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
|
|||
}
|
||||
}
|
||||
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int status = ValidateFileStr(modelFile, ".onnx");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
|
||||
const QuantType &quantType, schema::MetaGraphT *dst_graph) {
|
||||
MS_LOG(DEBUG) << "onnx LoopParser";
|
||||
if (dst_op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
onnx::ModelProto onnx_model;
|
||||
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
dst_op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (dst_op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const onnx::GraphProto &onnx_graph = onnx_model.graph();
|
||||
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
|
||||
std::unique_ptr<schema::LoopT> attr = std::make_unique<schema::LoopT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->subGraphIndex = subGraphNum;
|
||||
auto sub_graph = std::make_unique<schema::MetaGraphT>();
|
||||
sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
|
||||
dst_graph->subGraph.push_back(std::move(sub_graph));
|
||||
subGraphNum += 1;
|
||||
dst_op->primitive->value.type = schema::PrimitiveType_Loop;
|
||||
dst_op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
|
||||
TensorCache tensor_cache;
|
||||
// dst_graph->name = onnx_graph.name(); // this is not used
|
||||
// find out input names and const names
|
||||
FindGraphInputAndConst(onnx_graph);
|
||||
// set const tensor
|
||||
status = SetGraphConstTensor(onnx_graph, &tensor_cache);
|
||||
int status = SetGraphConstTensor(onnx_graph, &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphConstTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -512,13 +542,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
// init onnx model graph output tensor
|
||||
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// init op node input/output tensor, and dst_op attr
|
||||
NoSupportOp::GetInstance()->SetFmkType("ONNX");
|
||||
for (const auto &onnx_node : onnx_graph.node()) {
|
||||
|
@ -544,7 +568,8 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
|
||||
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
|
||||
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
|
||||
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
|
||||
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
|
||||
dst_graph.get());
|
||||
if (status_node != RET_OK) {
|
||||
status = (status == RET_OK ? status_node : status);
|
||||
continue;
|
||||
|
@ -558,9 +583,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
// init onnx model graph output tensor
|
||||
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
SetAllTensors(tensor_cache, dst_graph.get());
|
||||
dst_graph->name = GetModelName(modelFile);
|
||||
return dst_graph.release();
|
||||
}
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int status = ValidateFileStr(modelFile, ".onnx");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
onnx::ModelProto onnx_model;
|
||||
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
const onnx::GraphProto &onnx_graph = onnx_model.graph();
|
||||
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
|
||||
|
||||
schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
|
||||
if (dst_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
dst_graph->name = GetModelName(modelFile);
|
||||
return dst_graph;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
@ -40,6 +41,7 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
virtual ~OnnxModelParser();
|
||||
|
||||
schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
|
||||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||
|
||||
|
@ -62,7 +64,7 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
|
||||
const QuantType &quantType);
|
||||
const QuantType &quantType, schema::MetaGraphT *dst_graph);
|
||||
|
||||
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
|
||||
|
@ -86,9 +88,13 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
|
||||
|
||||
STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
|
||||
schema::MetaGraphT *dst_graph);
|
||||
|
||||
private:
|
||||
std::vector<string> graphInputNames;
|
||||
std::vector<string> graphConstNames;
|
||||
std::vector<std::string> graphInputNames;
|
||||
std::vector<std::string> graphConstNames;
|
||||
int subGraphNum = 0;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx EluParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (onnx_node.input_size() > 2) {
|
||||
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
||||
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(2); });
|
||||
if (it != onnx_graph.initializer().end()) {
|
||||
attr->maxOutBoxPerClass = it->int64_data(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (onnx_node.input_size() > 3) {
|
||||
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
||||
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(3); });
|
||||
if (it != onnx_graph.initializer().end()) {
|
||||
attr->iouThreshold = it->float_data(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (onnx_node.input_size() > 4) {
|
||||
auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
|
||||
[&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(4); });
|
||||
if (it != onnx_graph.initializer().end()) {
|
||||
attr->scoreThreshold = it->float_data(0);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "center_point_box") {
|
||||
if (onnx_node_attr.has_i()) {
|
||||
attr->centerPointBox = onnx_node_attr.i();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Elu;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxNonMaxSuppressionParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_resize_parser.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx ResizeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NCHW;
|
||||
attr->nearestMode = schema::NearestMode_ROUND_HALF_DOWN;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "coordinate_transformation_mode") {
|
||||
attr->coordinateTransformMode = [&]() {
|
||||
std::map<std::string, schema::CoordinateTransformMode> transform_map = {
|
||||
{"half_pixel", schema::CoordinateTransformMode_HALF_PIXEL},
|
||||
{"pytorch_half_pixel", schema::CoordinateTransformMode_PYTORCH_HALF_PIXEL},
|
||||
{"align_corners", schema::CoordinateTransformMode_ALIGN_CORNERS},
|
||||
{"asymmetric", schema::CoordinateTransformMode_ASYMMETRIC},
|
||||
{"tf_half_pixel_for_nn", schema::CoordinateTransformMode_TF_HALF_PIXEL},
|
||||
{"tf_crop_and_resize", schema::CoordinateTransformMode_TF_CROP_AND_RESIZE},
|
||||
};
|
||||
return transform_map[onnx_node_attr.strings(0)];
|
||||
}();
|
||||
} else if (attribute_name == "cubic_coeff_a") {
|
||||
attr->cubicCoeff = onnx_node_attr.f();
|
||||
} else if (attribute_name == "exclude_outside") {
|
||||
attr->excludeOutside = onnx_node_attr.i();
|
||||
} else if (attribute_name == "extrapolation_value") {
|
||||
attr->extrapolationValue = onnx_node_attr.f();
|
||||
} else if (attribute_name == "mode") {
|
||||
attr->method = [&]() {
|
||||
std::map<std::string, schema::ResizeMethod> resize_mode = {
|
||||
{"nearest", schema::ResizeMethod_NEAREST},
|
||||
{"linear", schema::ResizeMethod_LINEAR},
|
||||
{"cubic", schema::ResizeMethod_CUBIC},
|
||||
};
|
||||
return resize_mode[onnx_node_attr.strings(0)];
|
||||
}();
|
||||
} else if (attribute_name == "nearest_mode") {
|
||||
attr->nearestMode = [&]() {
|
||||
std::map<std::string, schema::NearestMode> nearest_mode = {
|
||||
{"round_prefer_floor", schema::NearestMode_ROUND_HALF_DOWN},
|
||||
{"round_prefer_ceil", schema::NearestMode_ROUND_HALF_UP},
|
||||
{"floor", schema::NearestMode_FLOOR},
|
||||
{"ceil", schema::NearestMode_CEIL},
|
||||
};
|
||||
return nearest_mode[onnx_node_attr.strings(0)];
|
||||
}();
|
||||
}
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Resize;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxResizeParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxResizeParser() : OnnxNodeParser("Resize") {}
|
||||
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -42,9 +42,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
|
|||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "mode") {
|
||||
if ("nearest" == onnx_node_attr.s()) {
|
||||
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
} else if ("bilinear" == onnx_node_attr.s()) {
|
||||
attr->method = schema::ResizeMethod_BILINEAR;
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Resize do not support upsample mode";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_custom_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
|
||||
|
@ -206,6 +206,8 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
status = ExtractFeatures(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "AudioSpectrogram") {
|
||||
status = AudioSpectrogram(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "Mfcc") {
|
||||
status = Mfcc(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexRFFT") {
|
||||
status = Rfft(custom_attr, op, tflite_op, tflite_model);
|
||||
} else if (custom_type == "FlexReal") {
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_resize_parser.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -39,7 +39,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON;
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name.data(), &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
|
@ -50,8 +50,16 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->method = schema::ResizeMethod_BILINEAR;
|
||||
if (tfliteAttr->align_corners) {
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
|
||||
}
|
||||
if (tfliteAttr->half_pixel_centers) {
|
||||
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
|
||||
? schema::CoordinateTransformMode_TF_HALF_PIXEL
|
||||
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
|
||||
}
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
} else if (std::strcmp(node_name, "NearestNeighbor") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser";
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions();
|
||||
|
@ -59,8 +67,17 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
|
||||
if (tfliteAttr->align_corners) {
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
|
||||
}
|
||||
if (tfliteAttr->half_pixel_centers) {
|
||||
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
|
||||
? schema::CoordinateTransformMode_TF_HALF_PIXEL
|
||||
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
|
||||
}
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
attr->nearestMode = schema::NearestMode_NORMAL;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong resize type";
|
||||
return RET_ERROR;
|
||||
|
|
Loading…
Reference in New Issue