forked from mindspore-Ecosystem/mindspore
!7029 [MS][LITE][GPU]add resize op
Merge pull request !7029 from chenzupeng/master-lite
This commit is contained in:
commit
3811e5a933
|
@ -0,0 +1,77 @@
|
|||
#ifdef cl_khr_fp16
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#endif
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
__kernel void resize_nearest_neighbor_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,
|
||||
int4 in_size, int4 out_size, float2 scale_factor) {
|
||||
int X = get_global_id(2); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(0); // C4
|
||||
if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {
|
||||
return;
|
||||
}
|
||||
int src_x = (int)(X * scale_factor.x);
|
||||
int src_y = (int)(Y * scale_factor.y);
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, X),
|
||||
READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x)));
|
||||
}
|
||||
|
||||
__kernel void resize_nearest_neighbor_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data,
|
||||
int4 in_size, int4 out_size, float2 scale_factor) {
|
||||
int X = get_global_id(2); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(0); // C4
|
||||
if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {
|
||||
return;
|
||||
}
|
||||
int src_x = (int)(X * scale_factor.x);
|
||||
int src_y = (int)(Y * scale_factor.y);
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, Z * out_size.y + X),
|
||||
READ_IMAGE(src_data, smp_zero, (int2)(src_y, Z * in_size.y + src_x)));
|
||||
}
|
||||
|
||||
__kernel void resize_bilinear_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size,
|
||||
int4 out_size, float2 scale_factor) {
|
||||
int X = get_global_id(2); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(0); // C4
|
||||
if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {
|
||||
return;
|
||||
}
|
||||
float scale_x = X * scale_factor.x;
|
||||
float scale_y = Y * scale_factor.y;
|
||||
int src_x = (int)(scale_x);
|
||||
int src_y = (int)(scale_y);
|
||||
int src_x_1 = min(src_x + 1, in_size.y - 1);
|
||||
int src_y_1 = min(src_y + 1, in_size.z - 1);
|
||||
FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x));
|
||||
FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, src_x));
|
||||
FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x_1));
|
||||
FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, src_x_1));
|
||||
FLT4 result =
|
||||
mix(mix(src0, src1, TO_FLT(scale_y - src_y)), mix(src2, src3, TO_FLT(scale_y - src_y)), TO_FLT(scale_x - src_x));
|
||||
WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, X), result);
|
||||
}
|
||||
|
||||
__kernel void resize_bilinear_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size,
|
||||
int4 out_size, float2 scale_factor) {
|
||||
int X = get_global_id(2); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(0); // C4
|
||||
if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {
|
||||
return;
|
||||
}
|
||||
float scale_x = X * scale_factor.x;
|
||||
float scale_y = Y * scale_factor.y;
|
||||
int src_x = (int)(scale_x);
|
||||
int src_y = (int)(scale_y);
|
||||
int src_x_1 = min(src_x + 1, in_size.y - 1);
|
||||
int src_y_1 = min(src_y + 1, in_size.z - 1);
|
||||
FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y, in_size.y * Z + src_x));
|
||||
FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1, in_size.y * Z + src_x));
|
||||
FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y, in_size.y * Z + src_x_1));
|
||||
FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1, in_size.y * Z + src_x_1));
|
||||
FLT4 result =
|
||||
mix(mix(src0, src1, TO_FLT(scale_y - src_y)), mix(src2, src3, TO_FLT(scale_y - src_y)), TO_FLT(scale_x - src_x));
|
||||
WRITE_IMAGE(dst_data, (int2)(Y, out_size.w * Z + X), result);
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include "src/runtime/kernel/opencl/cl/resize.cl.inc"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kGPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::RET_PARAM_INVALID;
|
||||
using mindspore::schema::PrimitiveType_Resize;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int ResizeOpenCLKernel::Init() {
|
||||
auto resize_param = reinterpret_cast<ResizeParameter *>(op_parameter_);
|
||||
if (resize_param == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
alignCorner = resize_param->align_corners_;
|
||||
preserveAspectRatio = resize_param->preserve_aspect_ratio_;
|
||||
auto in_shape = in_tensors_[0]->shape();
|
||||
auto out_shape = out_tensors_[0]->shape();
|
||||
if (in_shape.size() != 4 || out_shape.size() != 4 || in_shape[0] != out_shape[0] || in_shape[3] != out_shape[3]) {
|
||||
MS_LOG(ERROR) << "resize op only support 4D and axes HW";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::string kernel_name = "resize";
|
||||
if (resize_param->method_ == schema::ResizeMethod_BILINEAR) {
|
||||
kernel_name += "_bilinear";
|
||||
} else if (resize_param->method_ == schema::ResizeMethod_NEAREST_NEIGHBOR) {
|
||||
kernel_name += "_nearest_neighbor";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported resize method:" << resize_param->method_;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
kernel_name += "_" + std::string(EnumNameFormat(op_format_));
|
||||
#ifdef PROGRAM_WITH_IL
|
||||
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
|
||||
#else
|
||||
std::set<std::string> build_options;
|
||||
std::string source = resize_source;
|
||||
ocl_runtime_->LoadSource(kernel_name, source);
|
||||
ocl_runtime_->BuildKernel(kernel_, kernel_name, kernel_name, build_options);
|
||||
#endif
|
||||
in_ori_format_ = in_tensors_[0]->GetFormat();
|
||||
out_ori_format_ = out_tensors_[0]->GetFormat();
|
||||
in_tensors_[0]->SetFormat(op_format_);
|
||||
out_tensors_[0]->SetFormat(op_format_);
|
||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ResizeOpenCLKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int ResizeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
size_t im_dst_x, im_dst_y;
|
||||
auto nhwc_shape_ = out_tensors_[0]->shape();
|
||||
if (op_format_ == schema::Format_NHWC4) {
|
||||
im_dst_x = nhwc_shape_[2] * UP_DIV(nhwc_shape_[3], C4NUM);
|
||||
im_dst_y = nhwc_shape_[0] * nhwc_shape_[1];
|
||||
} else if (op_format_ == schema::Format_NC4HW4) {
|
||||
im_dst_x = nhwc_shape_[2];
|
||||
im_dst_y = nhwc_shape_[0] * UP_DIV(nhwc_shape_[3], C4NUM) * nhwc_shape_[1];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "not support op format:" << EnumNameFormat(op_format_);
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
if (ocl_runtime_->GetFp16Enable()) {
|
||||
img_dtype = CL_HALF_FLOAT;
|
||||
}
|
||||
img_size->clear();
|
||||
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
|
||||
*img_size = vec;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
float ResizeOpenCLKernel::getResizeScaleFactor(int input_size, int output_size) {
|
||||
return input_size > 1 && output_size > 1 && alignCorner
|
||||
? static_cast<float>(input_size - 1) / static_cast<float>(output_size - 1)
|
||||
: static_cast<float>(input_size) / static_cast<float>(output_size);
|
||||
}
|
||||
|
||||
int ResizeOpenCLKernel::Run() {
|
||||
MS_LOG(DEBUG) << this->name() << " Running!";
|
||||
auto in_shape = in_tensors_[0]->shape();
|
||||
auto out_shape = out_tensors_[0]->shape();
|
||||
int n = out_shape[0];
|
||||
int h = out_shape[1];
|
||||
int w = out_shape[2];
|
||||
int c = out_shape[3];
|
||||
int c4 = UP_DIV(c, C4NUM);
|
||||
float scale_h = getResizeScaleFactor(in_tensors_[0]->shape()[1], out_tensors_[0]->shape()[1]);
|
||||
float scale_w = getResizeScaleFactor(in_tensors_[0]->shape()[2], out_tensors_[0]->shape()[2]);
|
||||
std::vector<size_t> local = {};
|
||||
std::vector<size_t> global = {static_cast<size_t>(c4), static_cast<size_t>(w), static_cast<size_t>(h)};
|
||||
cl_int4 in_size = {in_shape[0], in_shape[1], in_shape[2], UP_DIV(in_shape[3], C4NUM)};
|
||||
cl_int4 out_size = {n, h, w, c4};
|
||||
cl_float2 scale = {scale_h, scale_w};
|
||||
int arg_idx = 0;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_size);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_size);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, scale);
|
||||
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *OpenCLResizeKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) ResizeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel " << opParameter->name_ << " create failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Resize, OpenCLResizeKernelCreator)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Resize, OpenCLResizeKernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_RESIZE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_RESIZE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "nnacl/resize_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ResizeOpenCLKernel : public OpenCLKernel {
|
||||
public:
|
||||
explicit ResizeOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs)
|
||||
: OpenCLKernel(parameter, inputs, outputs) {}
|
||||
~ResizeOpenCLKernel() override{};
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
float getResizeScaleFactor(int input_size, int output_size);
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
bool alignCorner;
|
||||
bool preserveAspectRatio;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_RESIZE_H_
|
|
@ -0,0 +1,185 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h"
|
||||
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestResizeOpenCL : public mindspore::CommonTest {
|
||||
public:
|
||||
TestResizeOpenCL() {}
|
||||
};
|
||||
|
||||
void RunTestCaseResize(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16,
|
||||
int resize_mode, bool align_corners) {
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
|
||||
ocl_runtime->Init();
|
||||
size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float);
|
||||
ocl_runtime->SetFp16Enable(enable_fp16);
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
auto param = static_cast<ResizeParameter *>(malloc(sizeof(ResizeParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "param_ptr create error.";
|
||||
return;
|
||||
}
|
||||
int n = shape[0];
|
||||
int h = shape[1];
|
||||
int w = shape[2];
|
||||
int oh = shape[3];
|
||||
int ow = shape[4];
|
||||
int c = shape[5];
|
||||
param->new_height_ = oh;
|
||||
param->new_width_ = ow;
|
||||
param->align_corners_ = align_corners;
|
||||
param->method_ = resize_mode;
|
||||
std::vector<int> input_shape = {n, h, w, c};
|
||||
auto tensor_x_ptr = std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
|
||||
input_shape, schema::Format_NHWC);
|
||||
auto tensor_x = tensor_x_ptr.get();
|
||||
if (tensor_x == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_x create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> out_shape = {n, oh, ow, c};
|
||||
auto tensor_out_ptr = std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
|
||||
out_shape, schema::Format_NHWC);
|
||||
auto tensor_out = tensor_out_ptr.get();
|
||||
if (tensor_out == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_out create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::Tensor *> inputs{tensor_x};
|
||||
std::vector<lite::Tensor *> outputs{tensor_out};
|
||||
auto arith_kernel_ptr =
|
||||
std::make_unique<kernel::ResizeOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
auto arith_kernel = arith_kernel_ptr.release();
|
||||
if (arith_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "arith_kernel create error.";
|
||||
return;
|
||||
}
|
||||
arith_kernel->Init();
|
||||
|
||||
inputs[0]->MallocData(allocator);
|
||||
|
||||
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
|
||||
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
|
||||
auto pGraph = pGraph_ptr.get();
|
||||
if (pGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "pGraph create error.";
|
||||
return;
|
||||
}
|
||||
pGraph->Init();
|
||||
memcpy(inputs[0]->MutableData(), input_data, inputs[0]->ElementsNum() * dtype_size);
|
||||
pGraph->Run();
|
||||
|
||||
if (enable_fp16) {
|
||||
CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast<float16_t>(1e-3),
|
||||
2e-2);
|
||||
} else {
|
||||
CompareOutput(outputs[0]->MutableData(), output_data, outputs[0]->ElementsNum(), static_cast<float>(1e-5));
|
||||
}
|
||||
for (auto t : inputs) {
|
||||
t->SetData(nullptr);
|
||||
}
|
||||
for (auto t : outputs) {
|
||||
t->SetData(nullptr);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Test Resize passed";
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeBilinearFp32) {
|
||||
int n = 1;
|
||||
int h = 2;
|
||||
int w = 2;
|
||||
int oh = 4;
|
||||
int ow = 4;
|
||||
int c = 1;
|
||||
bool align_corners = false;
|
||||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
|
||||
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeBilinearFp16) {
|
||||
int n = 1;
|
||||
int h = 2;
|
||||
int w = 2;
|
||||
int oh = 4;
|
||||
int ow = 4;
|
||||
int c = 1;
|
||||
bool align_corners = false;
|
||||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float16_t> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f,
|
||||
2.0f, 2.5f, 3.0f, 3.0f, 2.0f, 2.5f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeBilinearAlignFp32) {
|
||||
int n = 1;
|
||||
int h = 2;
|
||||
int w = 2;
|
||||
int oh = 3;
|
||||
int ow = 3;
|
||||
int c = 1;
|
||||
bool align_corners = true;
|
||||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.5f, 1.0f, 1.0f, 1.5f, 2.0f, 2.0f, 2.5f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_BILINEAR, align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp32) {
|
||||
int n = 1;
|
||||
int h = 2;
|
||||
int w = 2;
|
||||
int oh = 4;
|
||||
int ow = 4;
|
||||
int c = 1;
|
||||
bool align_corners = false;
|
||||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
|
||||
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), false, schema::ResizeMethod_NEAREST_NEIGHBOR,
|
||||
align_corners);
|
||||
}
|
||||
|
||||
TEST_F(TestResizeOpenCL, ResizeNearestNeighborFp16) {
|
||||
int n = 1;
|
||||
int h = 2;
|
||||
int w = 2;
|
||||
int oh = 4;
|
||||
int ow = 4;
|
||||
int c = 1;
|
||||
bool align_corners = false;
|
||||
std::vector<int> shape = {n, h, w, oh, ow, c};
|
||||
std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
std::vector<float16_t> output_data = {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f,
|
||||
2.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 3.0f, 3.0f};
|
||||
RunTestCaseResize(shape, input_data.data(), output_data.data(), true, schema::ResizeMethod_NEAREST_NEIGHBOR,
|
||||
align_corners);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue