forked from mindspore-Ecosystem/mindspore

Commit 7f42991624 (parent 25dd36059b): append onnx parser
@@ -219,6 +219,10 @@ union PrimitiveType {
     Sgd,
     Adam,
     GroupConv2DGradInput,
+    Loop,
+    NonMaxSuppression,
+    InstanceNorm,
+    Identity,
 }
 
 enum QuantType: int {
@@ -250,6 +254,7 @@ table MetaGraph {
     mempoolSize: uint;
     nodes: [CNode];
     allTensors: [Tensor]; // weight + input + output
+    subGraph : [MetaGraph];
 }
 
 root_type MetaGraph;
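
The new subGraph field makes MetaGraph recursive; it is what the Loop support later in this commit builds on (each Loop op records a subGraphIndex into this list). A minimal sketch of walking the nested graphs through the generated C++ object API (only the field names come from the schema; the traversal itself is illustrative):

    #include "schema/model_generated.h"  // object API generated from the schema above

    // Sketch: visit the top-level graph and every nested subgraph.
    void VisitGraphs(const mindspore::schema::MetaGraphT &graph) {
      for (const auto &node : graph.nodes) {
        // node->primitive->value.type identifies the op, e.g. PrimitiveType_Loop.
        (void)node;
      }
      for (const auto &sub : graph.subGraph) {  // field added in this commit
        VisitGraphs(*sub);
      }
    }
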
@@ -18,8 +18,28 @@ namespace mindspore.schema;
 
 enum ResizeMethod: byte {
     UNKNOW = -1,
-    BILINEAR = 0,
-    NEAREST_NEIGHBOR = 1
+    LINEAR = 0,
+    NEAREST = 1,
+    CUBIC = 2
+}
+
+enum CoordinateTransformMode: byte {
+    COMMON = 0,
+    HALF_PIXEL = 1,
+    PYTORCH_HALF_PIXEL = 2,
+    TF_HALF_PIXEL = 3,
+    TF_CROP_AND_RESIZE = 4,
+    ALIGN_CORNERS = 5,
+    ASYMMETRIC = 6,
+    ALIGN_CORNERS_WITH_HALF_PIEXL = 7
+}
+
+enum NearestMode : byte {
+    NORMAL = 0,
+    ROUND_HALF_DOWN = 1,
+    ROUND_HALF_UP = 2,
+    FLOOR = 3,
+    CEIL = 4
 }
 
 enum Format : int {
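
ONNX Resize expresses the same choice as a string attribute named coordinate_transformation_mode; a converter-side mapping onto the new enum could look like the sketch below (the string keys follow the ONNX operator spec; the helper itself is not part of this commit):

    #include <map>
    #include <string>
    #include "schema/ops_generated.h"

    namespace ms = mindspore::schema;

    // Hypothetical helper: map ONNX's coordinate_transformation_mode string onto the new enum.
    ms::CoordinateTransformMode ToCoordMode(const std::string &mode) {
      static const std::map<std::string, ms::CoordinateTransformMode> kMap = {
          {"half_pixel", ms::CoordinateTransformMode_HALF_PIXEL},
          {"pytorch_half_pixel", ms::CoordinateTransformMode_PYTORCH_HALF_PIXEL},
          {"tf_half_pixel_for_nn", ms::CoordinateTransformMode_TF_HALF_PIXEL},
          {"tf_crop_and_resize", ms::CoordinateTransformMode_TF_CROP_AND_RESIZE},
          {"align_corners", ms::CoordinateTransformMode_ALIGN_CORNERS},
          {"asymmetric", ms::CoordinateTransformMode_ASYMMETRIC},
      };
      auto it = kMap.find(mode);
      return it != kMap.end() ? it->second : ms::CoordinateTransformMode_COMMON;
    }
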
@@ -376,8 +396,13 @@ table Resize {
     method: ResizeMethod;
     newHeight: long;
     newWidth: long;
-    alignCorners: bool = false;
+    alignCorners: bool = false; // DEPRECATED IN FUTURE: use 'coordinateTransformMode' instead.
     preserveAspectRatio: bool = false;
+    coordinateTransformMode : CoordinateTransformMode;
+    cubicCoeff : float;
+    excludeOutside : int;
+    extrapolationValue : float = 0;
+    nearestMode : NearestMode;
 }
 
 table DetectionPostProcess {
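
A converter that emits the extended attributes would fill the generated ResizeT object along these lines (a sketch; only the field and enum names come from the tables above, the chosen values are examples):

    #include <memory>
    #include "schema/ops_generated.h"

    // Sketch: populate the extended Resize attributes.
    std::unique_ptr<mindspore::schema::ResizeT> MakeResizeAttr() {
      auto attr = std::make_unique<mindspore::schema::ResizeT>();
      attr->method = mindspore::schema::ResizeMethod_LINEAR;
      attr->coordinateTransformMode = mindspore::schema::CoordinateTransformMode_HALF_PIXEL;
      attr->cubicCoeff = -0.75f;  // ONNX default coefficient for cubic resize
      attr->excludeOutside = 0;
      attr->extrapolationValue = 0.0f;
      attr->nearestMode = mindspore::schema::NearestMode_NORMAL;
      return attr;
    }
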
@@ -1054,3 +1079,21 @@ table FftReal {
 
 table FftImag {
 }
+
+table NonMaxSuppression {
+    maxOutBoxPerClass : int = 0;
+    iouThreshold : float = 0;
+    scoreThreshold : float = 0;
+    centerPointBox : int = 0;
+}
+
+table InstanceNorm {
+    epsilon : float = 0.00001;
+}
+
+table Loop {
+    subGraphIndex : int;
+}
+
+table Identity {
+}
@@ -51,9 +51,9 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
   if (this->primitive_->value.value == nullptr) {
     auto attr = new (std::nothrow) schema::ResizeT();
     if (prim.instance_name() == "ResizeNearestNeighbor") {
-      attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
+      attr->method = schema::ResizeMethod_NEAREST;
     } else if (prim.instance_name() == "ResizeBilinear") {
-      attr->method = schema::ResizeMethod_BILINEAR;
+      attr->method = schema::ResizeMethod_LINEAR;
     } else {
       MS_LOG(ERROR) << "wrong resize type";
       return RET_ERROR;
@@ -41,8 +41,8 @@ int ResizeBaseCPUKernel::CheckParameters() {
     return RET_NULL_PTR;
   }
   method_ = parameter->method_;
-  if (method_ != static_cast<int>(schema::ResizeMethod_BILINEAR) &&
-      method_ != static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR)) {
+  if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR) &&
+      method_ != static_cast<int>(schema::ResizeMethod_NEAREST)) {
     MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_;
     return RET_INVALID_OP_ATTR;
   }
@@ -14,11 +14,11 @@
  * limitations under the License.
  */
 
-#include <algorithm>
 #include "src/runtime/kernel/arm/fp32/resize.h"
-#include "schema/model_generated.h"
-#include "nnacl/fp32/resize.h"
+#include <algorithm>
 #include "include/errorcode.h"
+#include "nnacl/fp32/resize.h"
+#include "schema/model_generated.h"
 #include "src/runtime/runtime_api.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -41,7 +41,7 @@ int ResizeCPUKernel::Init() {
 
 int ResizeCPUKernel::ReSize() {
   int ret = RET_OK;
-  if (method_ == static_cast<int>(schema::ResizeMethod_BILINEAR)) {
+  if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
     FreeTmpBuffer();
     ret = MallocTmpBuffer();
     if (ret != RET_OK) {
@@ -162,7 +162,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
 
   int ret = 0;
   switch (method_) {
-    case static_cast<int>(schema::ResizeMethod_BILINEAR): {
+    case static_cast<int>(schema::ResizeMethod_LINEAR): {
       int n_h_begin, n_h_end;
       int n = out_tensors_.at(0)->shape()[0];
      int h = new_height_;
@@ -178,7 +178,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
 
       break;
     }
-    case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
+    case static_cast<int>(schema::ResizeMethod_NEAREST): {
       if (in_tensors_.size() == lite::kDoubleNum && !const_shape_) {
         auto out_shape = in_tensors_.at(1);
         auto data = reinterpret_cast<int32_t *>(out_shape->MutableData());
@@ -14,12 +14,12 @@
  * limitations under the License.
  */
 
+#include "src/runtime/kernel/arm/int8/resize_int8.h"
 #include <vector>
-#include "src/kernel_registry.h"
+#include "include/errorcode.h"
 #include "nnacl/int8/resize.h"
 #include "schema/model_generated.h"
-#include "include/errorcode.h"
-#include "src/runtime/kernel/arm/int8/resize_int8.h"
+#include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -84,7 +84,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
 
   int ret = 0;
   switch (method_) {
-    case static_cast<int>(schema::ResizeMethod_BILINEAR): {
+    case static_cast<int>(schema::ResizeMethod_LINEAR): {
       if (quant_in_->zp_ == 0) {
         ret = ResizeBilinearInt8(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
                                  align_corners_, quant_in_, quant_out_, multiplier_, task_id, context_->thread_num_);
@@ -95,7 +95,7 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) {
       }
       break;
     }
-    case static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR): {
+    case static_cast<int>(schema::ResizeMethod_NEAREST): {
       bool same_zp = quant_in_->zp_ == quant_out_->zp_;
       bool same_scale = abs(quant_out_->scale_ - quant_in_->scale_) < 1e-6;
       if (same_zp && same_scale) {
@@ -14,12 +14,12 @@
  * limitations under the License.
  */
 
+#include "src/runtime/kernel/opencl/kernel/resize.h"
+#include <map>
 #include <set>
 #include <string>
-#include <map>
 #include "include/errorcode.h"
 #include "src/kernel_registry.h"
-#include "src/runtime/kernel/opencl/kernel/resize.h"
 #include "src/runtime/kernel/opencl/cl/resize.cl.inc"
 
 using mindspore::kernel::KERNEL_ARCH::kGPU;
@@ -46,9 +46,9 @@ int ResizeOpenCLKernel::Init() {
     return RET_PARAM_INVALID;
   }
   std::string kernel_name = "resize";
-  if (resize_param->method_ == schema::ResizeMethod_BILINEAR) {
+  if (resize_param->method_ == schema::ResizeMethod_LINEAR) {
     kernel_name += "_bilinear";
-  } else if (resize_param->method_ == schema::ResizeMethod_NEAREST_NEIGHBOR) {
+  } else if (resize_param->method_ == schema::ResizeMethod_NEAREST) {
     kernel_name += "_nearest_neighbor";
   } else {
     MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_;
@@ -14,11 +14,11 @@
  * limitations under the License.
  */
 #include <vector>
+#include "common/common_test.h"
+#include "mindspore/lite/src/kernel_registry.h"
 #include "mindspore/lite/src/lite_kernel.h"
 #include "mindspore/lite/src/tensor.h"
-#include "common/common_test.h"
 #include "nnacl/resize_parameter.h"
-#include "mindspore/lite/src/kernel_registry.h"
 #include "schema/ops_generated.h"
 using mindspore::schema::Format_NHWC;
 
@@ -62,7 +62,7 @@ void TestResizeBilinearFp32::Prepare(const std::vector<int> &input_shape, const
   out_tensor_.SetData(output_data);
 
   ResizeParameter param_ = {
-    {}, static_cast<int>(schema::ResizeMethod_BILINEAR), output_shape[1], output_shape[2], align_corners};
+    {}, static_cast<int>(schema::ResizeMethod_LINEAR), output_shape[1], output_shape[2], align_corners};
   desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
   ctx_ = lite::InnerContext();
   ctx_.thread_num_ = thread_num;
@@ -16,7 +16,7 @@
 #include <vector>
 #include "common/common_test.h"
 #include "nnacl/resize_parameter.h"
-#include "mindspore/lite/src/kernel_registry.h"
+#include "src/kernel_registry.h"
 
 namespace mindspore {
 
@@ -57,7 +57,7 @@ void TestResizeNearestNeighborFp32::Prepare(const std::vector<int> &input_shape,
   out_tensor_.SetData(output_data);
 
   ResizeParameter param_ = {
-    {}, static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR), output_shape[1], output_shape[2], align_corners};
+    {}, static_cast<int>(schema::ResizeMethod_NEAREST), output_shape[1], output_shape[2], align_corners};
   desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Resize};
   ctx_ = lite::InnerContext();
   ctx_.thread_num_ = thread_num;
@@ -19,7 +19,7 @@
 #include "include/context.h"
 #include "src/tensor.h"
 #include "common/common_test.h"
-#include "mindspore/lite/src/kernel_registry.h"
+#include "src/kernel_registry.h"
 #include "nnacl/int8/resize.h"
 
 namespace mindspore {
@@ -68,7 +68,7 @@ void TestResizeBilinearInt8::Prepare(const std::vector<int> &in_shape, const std
   inputs.push_back(&in_tensor);
   outputs.push_back(&out_tensor);
 
-  param_.method_ = static_cast<int>(schema::ResizeMethod_BILINEAR);
+  param_.method_ = static_cast<int>(schema::ResizeMethod_LINEAR);
   param_.new_width_ = out_shape[2];
   param_.new_height_ = out_shape[1];
   param_.align_corners_ = align_corners;
@@ -19,7 +19,7 @@
 #include "include/context.h"
 #include "src/tensor.h"
 #include "common/common_test.h"
-#include "mindspore/lite/src/kernel_registry.h"
+#include "src/kernel_registry.h"
 #include "nnacl/int8/resize.h"
 
 namespace mindspore {
@@ -63,7 +63,7 @@ void TestResizeNearestNeighborInt8::Prepare(const std::vector<int> &in_shape, co
   inputs.push_back(&in_tensor);
   outputs.push_back(&out_tensor);
 
-  param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST_NEIGHBOR);
+  param_.method_ = static_cast<int>(schema::ResizeMethod_NEAREST);
   param_.new_width_ = out_shape[2];
   param_.new_height_ = out_shape[1];
   param_.align_corners_ = align_corners;
@@ -15,13 +15,13 @@
 */
 #include <iostream>
 #include <memory>
-#include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "mindspore/lite/src/common/file_utils.h"
-#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
-#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
-#include "mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h"
-#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
+#include "src/common/file_utils.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime/kernel/opencl/kernel/resize.h"
+#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "src/runtime/opencl/opencl_runtime.h"
+#include "test/ut/src/runtime/kernel/opencl/utils_tests.h"
 
 namespace mindspore {
 class TestResizeOpenCL : public mindspore::CommonTest {
@@ -119,7 +119,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp32) {
   std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
   std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
                                     2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
-  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
+  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
 }
 
 TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
@@ -134,7 +134,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
   std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
   std::vector<float16_t> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
                                         2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
-  RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_BILINEAR, align_corners);
+  RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_LINEAR, align_corners);
 }
 
 TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
@@ -148,7 +148,7 @@ TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
   std::vector<int> shape = {n, h, w, oh, ow, c};
   std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
   std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f};
-  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
+  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_LINEAR, align_corners);
 }
 
 TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
@@ -163,8 +163,7 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
   std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
   std::vector<float> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
                                     2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
-  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST_NEIGHBOR,
-                    align_corners);
+  RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST, align_corners);
 }
 
 TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
@@ -179,7 +178,6 @@ TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
   std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
   std::vector<float16_t> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
                                         2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
-  RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST_NEIGHBOR,
-                    align_corners);
+  RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST, align_corners);
 }
 } // namespace mindspore
@@ -40,7 +40,7 @@ TEST_F(TestTfliteParserResizeNN, AttrValue) {
   ASSERT_EQ(val->newWidth, 100);
   ASSERT_EQ(val->format, schema::Format_NHWC);
   ASSERT_EQ(val->preserveAspectRatio, false);
-  ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST_NEIGHBOR);
+  ASSERT_EQ(val->method, schema::ResizeMethod_NEAREST);
 }
 
 class TestTfliteParserResizeBilinear : public TestTfliteParser {
@@ -64,7 +64,7 @@ TEST_F(TestTfliteParserResizeBilinear, AttrValue) {
   ASSERT_EQ(val->newWidth, 4);
   ASSERT_EQ(val->format, schema::Format_NHWC);
   ASSERT_EQ(val->preserveAspectRatio, false);
-  ASSERT_EQ(val->method, schema::ResizeMethod_BILINEAR);
+  ASSERT_EQ(val->method, schema::ResizeMethod_LINEAR);
 }
 
 } // namespace mindspore
@@ -57,7 +57,7 @@ STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe:
     attr->newWidth = width;
   }
   attr->alignCorners = true;
-  attr->method = schema::ResizeMethod_BILINEAR;
+  attr->method = schema::ResizeMethod_LINEAR;
 
   op->name = proto.name();
   op->primitive->value.type = schema::PrimitiveType_Resize;
@@ -582,6 +582,94 @@ STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   return RET_OK;
 }
 
+STATUS OnnxAndParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx AndParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_LogicalAnd;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS OnnxOrParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx OrParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_LogicalOr;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS OnnxNotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx NotParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_LogicalNot;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS OnnxRoundParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                              schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx RoundParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::RoundT> attr = std::make_unique<schema::RoundT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_Round;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
 OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
 OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
 OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
@@ -608,5 +696,9 @@ OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
 OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
 OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
 OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
+OnnxNodeRegistrar g_onnxAndParser("And", new OnnxAndParser());
+OnnxNodeRegistrar g_onnxOrParser("Or", new OnnxOrParser());
+OnnxNodeRegistrar g_onnxNotParser("Not", new OnnxNotParser());
+OnnxNodeRegistrar g_onnxRoundParser("Round", new OnnxRoundParser());
 } // namespace lite
 } // namespace mindspore
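
All four parsers follow one shape: allocate a PrimitiveT, attach the op-specific attribute table, set the PrimitiveType, and register a factory under the ONNX op type. A condensed template of that pattern ("Foo" and FooT are placeholders, not symbols from the commit):

    // Template of the recurring parser pattern; "Foo"/"FooT" are placeholders.
    STATUS OnnxFooParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                schema::CNodeT *op) {
      if (op == nullptr) {
        return RET_NULL_PTR;
      }
      op->primitive = std::make_unique<schema::PrimitiveT>();
      if (op->primitive == nullptr) {
        return RET_NULL_PTR;
      }
      auto attr = std::make_unique<schema::FooT>();  // op-specific attribute table from the schema
      if (attr == nullptr) {
        return RET_NULL_PTR;
      }
      // read onnx_node.attribute() here if the op carries attributes
      op->primitive->value.type = schema::PrimitiveType_Foo;
      op->primitive->value.value = attr.release();
      return RET_OK;
    }
    // Static registration keyed by the ONNX op type makes the parser discoverable:
    OnnxNodeRegistrar g_onnxFooParser("Foo", new OnnxFooParser());
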
@@ -171,6 +171,30 @@ class OnnxSignParser : public OnnxNodeParser {
   OnnxSignParser() : OnnxNodeParser("Sign") {}
   STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
 };
+
+class OnnxAndParser : public OnnxNodeParser {
+ public:
+  OnnxAndParser() : OnnxNodeParser("And") {}
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
+
+class OnnxOrParser : public OnnxNodeParser {
+ public:
+  OnnxOrParser() : OnnxNodeParser("Or") {}
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
+
+class OnnxNotParser : public OnnxNodeParser {
+ public:
+  OnnxNotParser() : OnnxNodeParser("Not") {}
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
+
+class OnnxRoundParser : public OnnxNodeParser {
+ public:
+  OnnxRoundParser() : OnnxNodeParser("Round") {}
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
 } // namespace lite
 } // namespace mindspore
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
@@ -15,9 +15,9 @@
 */
 
 #include "tools/converter/parser/onnx/onnx_conv_parser.h"
-#include <vector>
-#include <memory>
 #include <algorithm>
+#include <memory>
+#include <vector>
 
 namespace mindspore {
 namespace lite {
@@ -176,9 +176,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
       MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
       return RET_ERROR;
     }
-  } else if (attr->group != 1) {
-    MS_LOG(ERROR) << "group conv hasn't supported";
-    return RET_NOT_SUPPORT;
   } else {
     op->primitive->value.type = schema::PrimitiveType_Conv2D;
     op->primitive->value.value = attr.release();
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/converter/parser/onnx/onnx_identity_parser.h"
+#include <memory>
+#include <vector>
+
+namespace mindspore {
+namespace lite {
+STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                                 schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx IdentityParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+
+  op->primitive->value.type = schema::PrimitiveType_Identity;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser());
+} // namespace lite
+} // namespace mindspore
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
+
+#include "tools/converter/parser/onnx/onnx_node_parser.h"
+#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
+
+namespace mindspore {
+namespace lite {
+class OnnxIdentityParser : public OnnxNodeParser {
+ public:
+  OnnxIdentityParser() : OnnxNodeParser("Identity") {}
+
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
+} // namespace lite
+} // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IDENTITY_PARSER_H
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/converter/parser/onnx/onnx_instance_norm_parser.h"
+#include <memory>
+
+namespace mindspore {
+namespace lite {
+STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                                     schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx InstanceNormParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+
+  if (!onnx_node.attribute().empty()) {
+    auto onnx_node_attr = onnx_node.attribute().at(0);
+    if (onnx_node_attr.name() == "epsilon") {
+      attr->epsilon = onnx_node_attr.f();
+    }
+  }
+
+  op->primitive->value.type = schema::PrimitiveType_InstanceNorm;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser());
+} // namespace lite
+} // namespace mindspore
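
As committed, the parser inspects only the first attribute; ONNX does not guarantee attribute order, so a slightly more defensive helper (a sketch, not part of the commit, assuming the ONNX protobuf types are visible as in the parser above) would scan all of them:

    // Sketch: return "epsilon" wherever it appears, falling back to the schema default (1e-5).
    float GetInstanceNormEpsilon(const onnx::NodeProto &onnx_node) {
      for (const auto &onnx_node_attr : onnx_node.attribute()) {
        if (onnx_node_attr.name() == "epsilon") {
          return onnx_node_attr.f();
        }
      }
      return 1e-5f;
    }
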
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
+
+#include "tools/converter/parser/onnx/onnx_node_parser.h"
+#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
+
+namespace mindspore {
+namespace lite {
+class OnnxInstanceNormParser : public OnnxNodeParser {
+ public:
+  OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {}
+
+  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
+};
+} // namespace lite
+} // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_INSTANCE_NORM_PARSER_H
@@ -15,12 +15,12 @@
 */
 
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
+#include <algorithm>
 #include <cfloat>
 #include <unordered_map>
-#include <algorithm>
 #include <utility>
-#include "tools/common/graph_util.h"
 #include "src/common/utils.h"
+#include "tools/common/graph_util.h"
 #include "tools/common/protobuf_utils.h"
 
 namespace mindspore {
@@ -36,7 +36,8 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
   {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
   {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
   {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
-  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}};
+  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
+  {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
 
 TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
   auto iter = TYPE_MAP.find(onnx_type);
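
The new BOOL entry lets boolean initializers (for example the outputs of the logical ops added above) resolve to a lite tensor type. The rest of GetDataTypeFromOnnx is not shown in this hunk; a self-contained restatement of the lookup with an explicit fallback, written as an assumption rather than as the project code, would be:

    // Assumption: ONNX types missing from TYPE_MAP should be reported as unknown/unsupported.
    mindspore::TypeId LookupOnnxType(int onnx_type) {
      auto iter = TYPE_MAP.find(onnx_type);
      return iter == TYPE_MAP.end() ? mindspore::kTypeUnknown : iter->second;
    }
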
@@ -161,10 +162,14 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
                                              TensorCache *tensor_cache) {
   for (const auto &output_value : onnx_graph.output()) {
     int index;
+    if (tensor_cache->FindTensor(output_value.name()) != -1) {
+      index = tensor_cache->FindTensor(output_value.name());
+    } else {
       const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
       if (status != RET_OK) {
         return status;
       }
+    }
     graph->outputIndex.emplace_back(index);
     MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
   }
@@ -250,7 +255,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 
 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache, const QuantType &quantType) {
+                                             TensorCache *tensor_cache, const QuantType &quantType,
+                                             schema::MetaGraphT *dst_graph) {
   // change op_type() to name(), that is unique
   static bool interrupt = false;
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -260,6 +266,16 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
                 << onnx_node.input_size();
   // get the real op type
   SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
+  if (onnx_node.op_type() == "Loop") {
+    NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
+    interrupt = true;
+    return RET_NOT_FIND_OP;
+    int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
+    if (status != RET_OK || interrupt) {
+      interrupt = true;
+      return status;
+    }
+  } else {
   auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
   if (node_parser == nullptr || interrupt) {
     interrupt = true;
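
Note that the Loop branch above registers the op as unsupported and returns RET_NOT_FIND_OP before ParseLoopAttr is reached, so the subgraph path is effectively disabled in this commit. A sketch of the flow once that early return is removed (an assumption about the intent, not the committed behaviour):

    // Sketch (assumed intent): dispatch Loop to ParseLoopAttr, everything else to the registry.
    if (onnx_node.op_type() == "Loop") {
      int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
      if (status != RET_OK) {
        interrupt = true;
        return status;
      }
    } else {
      // registry lookup as in the hunk above
    }
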
@@ -271,13 +287,14 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
   if (status != RET_OK) {
     interrupt = true;
-    if (status == RET_NOT_SUPPORT) {
+    if (status == RET_NOT_FIND_OP) {
       NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
     } else {
       MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
     }
     return status;
   }
+  }
   // set op input index
   std::vector<string> node_inputs;
   (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end());
@@ -366,7 +383,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
                                         const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
   for (const auto &onnx_node_input : node_inputs) {
     if (onnx_node_input != "") {
-      auto index = tensor_cache->FindTensor(onnx_node_input);
+      int index = tensor_cache->FindTensor(onnx_node_input);
       if (index < 0) {
         MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
         return RET_ERROR;
@@ -428,6 +445,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       }
       for (size_t i = 0; i < data_count; ++i) {
         if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
+          if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
+            buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
+          }
           MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
           return RET_ERROR;
         } else {
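
The added branch only saturates the sentinel values INT64_MAX/INT64_MIN and then still falls through to the error return. A generic saturating int64-to-int32 helper (a sketch, not the committed logic) would be:

    #include <cstdint>

    // Sketch: clamp any out-of-range int64 value into the int32 range instead of failing.
    int32_t SaturateToInt32(int64_t v) {
      if (v > static_cast<int64_t>(INT32_MAX)) {
        return INT32_MAX;
      }
      if (v < static_cast<int64_t>(INT32_MIN)) {
        return INT32_MIN;
      }
      return static_cast<int32_t>(v);
    }
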
@@ -438,6 +458,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       break;
     case kNumberTypeUInt8:
     case kNumberTypeInt8:
+    case kNumberTypeBool:
       data_size = data_count * sizeof(uint8_t);
       tensor_data = onnx_const_value.raw_data().data();
       break;
@@ -446,7 +467,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
     return RET_ERROR;
   }
   tensor->data.resize(data_size);
-  if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
+  if (data_size != 0 && memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
     MS_LOG(ERROR) << "memcpy_s failed";
     return RET_ERROR;
   }
@@ -475,30 +496,39 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
   }
 }
 
-schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
-                                               const QuantType &quantType) {
-  int status = ValidateFileStr(modelFile, ".onnx");
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
+STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
+                                      const QuantType &quantType, schema::MetaGraphT *dst_graph) {
+  MS_LOG(DEBUG) << "onnx LoopParser";
+  if (dst_op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
   }
-  onnx::ModelProto onnx_model;
-  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
+  dst_op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (dst_op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
   }
-  const onnx::GraphProto &onnx_graph = onnx_model.graph();
-  MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
+  std::unique_ptr<schema::LoopT> attr = std::make_unique<schema::LoopT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  attr->subGraphIndex = subGraphNum;
+  auto sub_graph = std::make_unique<schema::MetaGraphT>();
+  sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
+  dst_graph->subGraph.push_back(std::move(sub_graph));
+  subGraphNum += 1;
+  dst_op->primitive->value.type = schema::PrimitiveType_Loop;
+  dst_op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
   TensorCache tensor_cache;
   // dst_graph->name = onnx_graph.name(); // this is not used
   // find out input names and const names
   FindGraphInputAndConst(onnx_graph);
   // set const tensor
-  status = SetGraphConstTensor(onnx_graph, &tensor_cache);
+  int status = SetGraphConstTensor(onnx_graph, &tensor_cache);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "SetGraphConstTensor failed";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
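
ParseLoopAttr links each Loop op to its body by index: the loop body graph attribute is parsed recursively through ParseGraph, appended to MetaGraph.subGraph, and the position is stored in LoopT.subGraphIndex. A consumer-side sketch of resolving that link (illustrative, not part of the commit):

    #include "schema/model_generated.h"

    // Sketch: given a Loop node in a parsed MetaGraphT, fetch its body graph.
    const mindspore::schema::MetaGraphT *GetLoopBody(const mindspore::schema::MetaGraphT &graph,
                                                     const mindspore::schema::CNodeT &loop_node) {
      if (loop_node.primitive == nullptr ||
          loop_node.primitive->value.type != mindspore::schema::PrimitiveType_Loop) {
        return nullptr;
      }
      auto *loop_attr = static_cast<const mindspore::schema::LoopT *>(loop_node.primitive->value.value);
      if (loop_attr == nullptr || loop_attr->subGraphIndex < 0 ||
          static_cast<size_t>(loop_attr->subGraphIndex) >= graph.subGraph.size()) {
        return nullptr;
      }
      return graph.subGraph[loop_attr->subGraphIndex].get();
    }
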
@@ -512,13 +542,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     return nullptr;
   }
-  // init onnx model graph output tensor
-  status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "SetGraphOutputTensor failed";
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
-  }
   // init op node input/output tensor, and dst_op attr
   NoSupportOp::GetInstance()->SetFmkType("ONNX");
   for (const auto &onnx_node : onnx_graph.node()) {
@@ -544,7 +568,8 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
 
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
     std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
+    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
+                                       dst_graph.get());
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;
@@ -558,9 +583,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     }
     return nullptr;
   }
+  // init onnx model graph output tensor
+  status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "SetGraphOutputTensor failed";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
   SetAllTensors(tensor_cache, dst_graph.get());
-  dst_graph->name = GetModelName(modelFile);
   return dst_graph.release();
 }
+
+schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
+                                               const QuantType &quantType) {
+  int status = ValidateFileStr(modelFile, ".onnx");
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
+
+  onnx::ModelProto onnx_model;
+  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
+  const onnx::GraphProto &onnx_graph = onnx_model.graph();
+  MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
+
+  schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
+  if (dst_graph == nullptr) {
+    return nullptr;
+  }
+  dst_graph->name = GetModelName(modelFile);
+  return dst_graph;
+}
+
 } // namespace lite
 } // namespace mindspore
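
With the split in place, ParseToFb is a thin file-reading wrapper around ParseGraph, which can now also be called recursively for Loop bodies. A minimal driver for the public entry point (file names are examples only):

    #include <memory>
    #include "tools/converter/parser/onnx/onnx_model_parser.h"

    // Sketch: parse an .onnx file into a flatbuffer MetaGraphT.
    int main() {
      mindspore::lite::OnnxModelParser parser;
      std::unique_ptr<mindspore::schema::MetaGraphT> meta_graph(parser.ParseToFb("model.onnx", ""));
      return meta_graph == nullptr ? 1 : 0;
    }
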
@ -26,6 +26,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <map>
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
#include "tools/converter/model_parser.h"
|
#include "tools/converter/model_parser.h"
|
||||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||||
|
@ -40,6 +41,7 @@ class OnnxModelParser : public ModelParser {
|
||||||
|
|
||||||
virtual ~OnnxModelParser();
|
virtual ~OnnxModelParser();
|
||||||
|
|
||||||
|
schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
|
||||||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||||
|
|
||||||
|
@ -62,7 +64,7 @@ class OnnxModelParser : public ModelParser {
   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                               schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
-                              const QuantType &quantType);
+                              const QuantType &quantType, schema::MetaGraphT *dst_graph);
 
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
@ -86,9 +88,13 @@ class OnnxModelParser : public ModelParser {
   void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);
 
+  STATUS ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType,
+                       schema::MetaGraphT *dst_graph);
+
  private:
-  std::vector<string> graphInputNames;
-  std::vector<string> graphConstNames;
+  std::vector<std::string> graphInputNames;
+  std::vector<std::string> graphConstNames;
+  int subGraphNum = 0;
 };
 } // namespace lite
 } // namespace mindspore

@ -0,0 +1,80 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h"
#include <algorithm>
#include <memory>

namespace mindspore {
namespace lite {
STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                          schema::CNodeT *op) {
  MS_LOG(DEBUG) << "onnx NonMaxSuppressionParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
  }
  op->primitive = std::make_unique<schema::PrimitiveT>();
  if (op->primitive == nullptr) {
    MS_LOG(ERROR) << "op->primitive is null";
    return RET_NULL_PTR;
  }

  std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  // ONNX NonMaxSuppression carries max_output_boxes_per_class, iou_threshold and
  // score_threshold as optional input tensors (inputs 2-4), so read them from the
  // graph initializers when they are present.
  if (onnx_node.input_size() > 2) {
    auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                           [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(2); });
    if (it != onnx_graph.initializer().end()) {
      attr->maxOutBoxPerClass = it->int64_data(0);
    }
  }

  if (onnx_node.input_size() > 3) {
    auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                           [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(3); });
    if (it != onnx_graph.initializer().end()) {
      attr->iouThreshold = it->float_data(0);
    }
  }

  if (onnx_node.input_size() > 4) {
    auto it = std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                           [&](const onnx::TensorProto &it) { return it.name() == onnx_node.input(4); });
    if (it != onnx_graph.initializer().end()) {
      attr->scoreThreshold = it->float_data(0);
    }
  }

  // center_point_box is the only true node attribute.
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "center_point_box") {
      if (onnx_node_attr.has_i()) {
        attr->centerPointBox = onnx_node_attr.i();
      }
    }
  }

  op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression;
  op->primitive->value.value = attr.release();
  return RET_OK;
}

OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxNonMaxSuppressionParser : public OnnxNodeParser {
 public:
  OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {}

  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NON_MAX_SUPPRESSION_PARSER_H

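Every ONNX op added in this commit follows the same two-file pattern seen above: a header that declares a subclass of OnnxNodeParser, and a .cc file that implements Parse and defines a static OnnxNodeRegistrar so the registry can find the parser by op type. A sketch for a hypothetical op; the class name and op name "Foo" are invented for illustration and are not part of this commit:

// onnx_foo_parser.h -- hypothetical example following the pattern above
class OnnxFooParser : public OnnxNodeParser {
 public:
  OnnxFooParser() : OnnxNodeParser("Foo") {}
  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
               schema::CNodeT *op) override;
};

// onnx_foo_parser.cc -- a static registrar makes the parser discoverable:
// OnnxNodeRegistrar g_onnxFooParser("Foo", new OnnxFooParser());
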
@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/onnx/onnx_resize_parser.h"
#include <map>
#include <memory>
#include <string>
#include <vector>

namespace mindspore {
namespace lite {
STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                               schema::CNodeT *op) {
  MS_LOG(DEBUG) << "onnx ResizeParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
  }
  op->primitive = std::make_unique<schema::PrimitiveT>();
  if (op->primitive == nullptr) {
    MS_LOG(ERROR) << "op->primitive is null";
    return RET_NULL_PTR;
  }

  std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  attr->format = schema::Format_NCHW;
  attr->nearestMode = schema::NearestMode_ROUND_HALF_DOWN;
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    const auto &attribute_name = onnx_node_attr.name();
    if (attribute_name == "coordinate_transformation_mode") {
      // ONNX string attributes are stored in the singular 's' field.
      attr->coordinateTransformMode = [&]() {
        std::map<std::string, schema::CoordinateTransformMode> transform_map = {
          {"half_pixel", schema::CoordinateTransformMode_HALF_PIXEL},
          {"pytorch_half_pixel", schema::CoordinateTransformMode_PYTORCH_HALF_PIXEL},
          {"align_corners", schema::CoordinateTransformMode_ALIGN_CORNERS},
          {"asymmetric", schema::CoordinateTransformMode_ASYMMETRIC},
          {"tf_half_pixel_for_nn", schema::CoordinateTransformMode_TF_HALF_PIXEL},
          {"tf_crop_and_resize", schema::CoordinateTransformMode_TF_CROP_AND_RESIZE},
        };
        return transform_map[onnx_node_attr.s()];
      }();
    } else if (attribute_name == "cubic_coeff_a") {
      attr->cubicCoeff = onnx_node_attr.f();
    } else if (attribute_name == "exclude_outside") {
      attr->excludeOutside = onnx_node_attr.i();
    } else if (attribute_name == "extrapolation_value") {
      attr->extrapolationValue = onnx_node_attr.f();
    } else if (attribute_name == "mode") {
      attr->method = [&]() {
        std::map<std::string, schema::ResizeMethod> resize_mode = {
          {"nearest", schema::ResizeMethod_NEAREST},
          {"linear", schema::ResizeMethod_LINEAR},
          {"cubic", schema::ResizeMethod_CUBIC},
        };
        return resize_mode[onnx_node_attr.s()];
      }();
    } else if (attribute_name == "nearest_mode") {
      attr->nearestMode = [&]() {
        std::map<std::string, schema::NearestMode> nearest_mode = {
          {"round_prefer_floor", schema::NearestMode_ROUND_HALF_DOWN},
          {"round_prefer_ceil", schema::NearestMode_ROUND_HALF_UP},
          {"floor", schema::NearestMode_FLOOR},
          {"ceil", schema::NearestMode_CEIL},
        };
        return nearest_mode[onnx_node_attr.s()];
      }();
    }
  }

  op->primitive->value.type = schema::PrimitiveType_Resize;
  op->primitive->value.value = attr.release();
  return RET_OK;
}

OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser());
} // namespace lite
} // namespace mindspore

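The parser above only records which coordinate_transformation_mode was requested; the numeric meaning of each mode comes from the ONNX Resize specification, not from this commit. As a standalone illustration (not part of the converter), the sketch below shows how the three most common modes map an output index back to a source coordinate along one axis:

#include <cstdio>
#include <initializer_list>
#include <string>

// x_resized -> x_original along one axis, per the ONNX Resize spec.
double SourceCoord(const std::string &mode, double x_resized, double scale,
                   double in_len, double out_len) {
  if (mode == "half_pixel") {
    return (x_resized + 0.5) / scale - 0.5;
  }
  if (mode == "align_corners") {
    return out_len > 1 ? x_resized * (in_len - 1) / (out_len - 1) : 0.0;
  }
  // "asymmetric"
  return x_resized / scale;
}

int main() {
  // Upscale an axis of length 4 to length 8 (scale factor 2).
  for (const char *mode : {"half_pixel", "align_corners", "asymmetric"}) {
    std::printf("%-13s x_dst=3 -> x_src=%.3f\n", mode, SourceCoord(mode, 3.0, 2.0, 4.0, 8.0));
  }
  return 0;
}
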
@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxResizeParser : public OnnxNodeParser {
 public:
  OnnxResizeParser() : OnnxNodeParser("Resize") {}

  STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_RESIZE_PARSER_H

@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#include <memory>
 #include "tools/converter/parser/onnx/onnx_upsample_parser.h"
+#include <memory>
 
 namespace mindspore {
 namespace lite {
@ -42,9 +42,9 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "mode") {
       if ("nearest" == onnx_node_attr.s()) {
-        attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
+        attr->method = schema::ResizeMethod_NEAREST;
       } else if ("bilinear" == onnx_node_attr.s()) {
-        attr->method = schema::ResizeMethod_BILINEAR;
+        attr->method = schema::ResizeMethod_LINEAR;
       } else {
         MS_LOG(ERROR) << "Resize do not support upsample mode";
         return RET_ERROR;
@ -15,9 +15,9 @@
  */
 
 #include "tools/converter/parser/tflite/tflite_custom_parser.h"
-#include <vector>
-#include <memory>
 #include <map>
+#include <memory>
+#include <vector>
 #include "flatbuffers/flatbuffers.h"
 #include "flatbuffers/flexbuffers.h"
 
@ -206,6 +206,8 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
       status = ExtractFeatures(custom_attr, op, tflite_op);
     } else if (custom_type == "AudioSpectrogram") {
       status = AudioSpectrogram(custom_attr, op, tflite_op);
+    } else if (custom_type == "Mfcc") {
+      status = Mfcc(custom_attr, op, tflite_op);
     } else if (custom_type == "FlexRFFT") {
       status = Rfft(custom_attr, op, tflite_op, tflite_model);
     } else if (custom_type == "FlexReal") {
@ -15,10 +15,10 @@
  */
 
 #include "tools/converter/parser/tflite/tflite_resize_parser.h"
-#include <vector>
+#include <map>
 #include <memory>
 #include <string>
-#include <map>
+#include <vector>
 
 namespace mindspore {
 namespace lite {
@ -39,7 +39,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
+  attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON;
   std::vector<std::string> node_name_str;
   Split(op->name.data(), &node_name_str, "-");
   const char *node_name = node_name_str.data()->c_str();
@ -50,8 +50,16 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
       MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
       return RET_NULL_PTR;
     }
+    if (tfliteAttr->align_corners) {
       attr->alignCorners = tfliteAttr->align_corners;
-    attr->method = schema::ResizeMethod_BILINEAR;
+      attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
+    }
+    if (tfliteAttr->half_pixel_centers) {
+      attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
+                                         ? schema::CoordinateTransformMode_TF_HALF_PIXEL
+                                         : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
+    }
+    attr->method = schema::ResizeMethod_LINEAR;
   } else if (std::strcmp(node_name, "NearestNeighbor") == 0) {
     MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser";
     const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions();
@ -59,8 +67,17 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
       MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
       return RET_NULL_PTR;
     }
+    if (tfliteAttr->align_corners) {
       attr->alignCorners = tfliteAttr->align_corners;
-    attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
+      attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
+    }
+    if (tfliteAttr->half_pixel_centers) {
+      attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
+                                         ? schema::CoordinateTransformMode_TF_HALF_PIXEL
+                                         : schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
+    }
+    attr->method = schema::ResizeMethod_NEAREST;
+    attr->nearestMode = schema::NearestMode_NORMAL;
   } else {
     MS_LOG(ERROR) << "wrong resize type";
     return RET_ERROR;

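The two TFLite resize flags collapse into a single CoordinateTransformMode. A compact restatement of the mapping implemented above (a sketch only, mirroring the parser logic; the schema's ALIGN_CORNERS_WITH_HALF_PIEXL spelling is kept as defined):

// How align_corners / half_pixel_centers select a CoordinateTransformMode.
schema::CoordinateTransformMode ToCoordinateTransformMode(bool align_corners, bool half_pixel_centers) {
  if (align_corners && half_pixel_centers) {
    return schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL;
  }
  if (align_corners) {
    return schema::CoordinateTransformMode_ALIGN_CORNERS;
  }
  if (half_pixel_centers) {
    return schema::CoordinateTransformMode_TF_HALF_PIXEL;
  }
  return schema::CoordinateTransformMode_COMMON;
}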