!6183 fix opencl fp16 convertion

Merge pull request !6183 from wandongdong/master
This commit is contained in:
mindspore-ci-bot 2020-09-14 21:44:40 +08:00 committed by Gitee
commit 3b4d855160
9 changed files with 434 additions and 47 deletions

View File

@ -0,0 +1,72 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void gather_NHWC4(__read_only image2d_t src_data, __global int *indices, __write_only image2d_t dst_data,
int4 src_size, int4 dst_size, int indices_num, int axis) {
int X = get_global_id(0); // w
int Y = get_global_id(1); // h
int Z = get_global_id(2); // c
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
return;
}
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
if (axis == 1) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, indices[Y]));
} else if (axis == 2) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X] * src_size.z + Z, Y));
} else if (axis == 3) {
int offset[4] = {indices[Z * 4] / 4, indices[Z * 4 + 1] / 4, indices[Z * 4 + 2] / 4, indices[Z * 4 + 3] / 4};
FLT tmp[4];
FLT res_tmp[4];
for (int i = 0; i < 4; ++i) {
if (i >= 1 && offset[i] != offset[i - 1]) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], Y));
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
}
res_tmp[i] = tmp[indices[Z * 4 + i] % 4];
}
res_data.x = res_tmp[0];
res_data.y = res_tmp[1];
res_data.z = res_tmp[2];
res_data.w = res_tmp[3];
}
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res_data);
}
__kernel void gather_NC4HW4(__read_only image2d_t src_data, __global int *indices, __write_only image2d_t dst_data,
int4 src_size, int4 dst_size, int indices_num, int axis) {
int X = get_global_id(0); // w
int Y = get_global_id(1); // h
int Z = get_global_id(2); // c
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
return;
}
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
if (axis == 1) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(X, Z * dst_size.y + indices[Y]));
} else if (axis == 2) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X], Z * dst_size.y + Y));
} else if (axis == 3) {
int offset[4] = {indices[Z * 4] / 4, indices[Z * 4 + 1] / 4, indices[Z * 4 + 2] / 4, indices[Z * 4 + 3] / 4};
FLT tmp[4];
FLT res_tmp[4];
for (int i = 0; i < 4; ++i) {
if (i >= 1 && offset[i] != offset[i - 1]) {
FLT4 rd_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X, offset[i] * dst_size.y + Y));
tmp[0] = rd_data.x;
tmp[1] = rd_data.y;
tmp[2] = rd_data.z;
tmp[3] = rd_data.w;
}
res_tmp[i] = tmp[indices[Z * 4 + i] % 4];
}
res_data.x = res_tmp[0];
res_data.y = res_tmp[1];
res_data.z = res_tmp[2];
res_data.w = res_tmp[3];
}
WRITE_IMAGE(dst_data, (int2)(X, (Z * dst_size.y + Y)), res_data);
}

View File

@ -101,8 +101,8 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<int16_t(float)> to_dtype = Float32ToShort;
PackNCHWToNC4HW4<float, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
return RET_ERROR;
@ -111,8 +111,11 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return (float)x; };
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
} else {
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
return RET_ERROR;

View File

@ -0,0 +1,170 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include <utility>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/gather.h"
#include "src/runtime/kernel/opencl/cl/gather.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
int GatherOpenCLKernel::Init() {
std::string kernel_name = "gather";
auto in_format = op_format_;
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "
<< "format not support!";
return RET_ERROR;
}
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
if (in_format == schema::Format_NC4HW4) {
kernel_name += "_NC4HW4";
} else {
kernel_name += "_NHWC4";
}
std::set<std::string> build_options;
std::string source = gather_source;
std::string program_name = "gather";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
// init indices_data_
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
if (!isIndicesInt32) {
indices_data_ = reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;
}
}
return RET_OK;
}
int GatherOpenCLKernel::InitBuffer() {
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
if (!isIndicesInt32) {
if (indices_tensor->data_type() == kNumberTypeInt64) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->data_c())[i];
}
} else if (indices_tensor->data_type() == kNumberTypeFloat32) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<float *>(indices_tensor->data_c())[i];
}
} else if (indices_tensor->data_type() == kNumberTypeFloat16) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<float16_t *>(indices_tensor->data_c())[i];
}
} else {
MS_LOG(ERROR) << "Unsupported data type: " << indices_tensor->data_type();
return RET_ERROR;
}
} else {
indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->data_c());
}
return RET_OK;
}
int GatherOpenCLKernel::ReSize() { return RET_OK; }
int GatherOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
} else {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
size_t img_dtype = CL_FLOAT;
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto enable_fp16_ = ocl_runtime->GetFp16Enable();
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = std::move(vec);
return RET_OK;
}
int GatherOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto param = reinterpret_cast<GatherParameter *>(this->op_parameter_);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
if (InitBuffer() != RET_OK) {
return RET_ERROR;
}
auto input_shape = in_tensors_[0]->shape();
auto output_shape = out_tensors_[0]->shape();
int indices_num = in_tensors_[1]->ElementsNum();
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM);
cl_int4 src_size = {in_tensors_[0]->Width(), in_tensors_[0]->Height(), (cl_int)CI4, in_tensors_[0]->Batch()};
cl_int4 dst_size = {(cl_int)out_tensors_[0]->Width(), (cl_int)out_tensors_[0]->Height(), (cl_int)CO4,
(cl_int)out_tensors_[0]->Batch()};
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {(size_t)out_tensors_[0]->Width(), (size_t)out_tensors_[0]->Height(), CO4};
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, indices_data_, lite::opencl::MemType::BUF);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, src_size);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, dst_size);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, indices_num);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *OpenCLGatherKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::Context *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) GatherOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " new failed.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " init failed.";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Gather, OpenCLGatherKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Gather, OpenCLGatherKernelCreator);
} // namespace mindspore::kernel

View File

@ -0,0 +1,51 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_GATHER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_GATHER_H_
#include <vector>
#include "ir/anf.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "nnacl/gather_parameter.h"
namespace mindspore::kernel {
class GatherOpenCLKernel : public OpenCLKernel {
public:
explicit GatherOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs), indices_data_(nullptr) {}
~GatherOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int InitBuffer();
private:
cl::Kernel kernel_;
int32_t *indices_data_;
};
} // namespace mindspore::kernel
#endif

View File

@ -52,7 +52,7 @@ void PReluOpenCLKernel::InitBuffer() {
auto PReluWeight_fp16 = reinterpret_cast<uint16_t *>(PReluWeight_);
auto in_tensor_data_fp32 = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
for (int i = 0; i < elem_num; i++) {
PReluWeight_fp16[i] = Float32ToShort(in_tensor_data_fp32[i]);
PReluWeight_fp16[i] = static_cast<float16_t>(in_tensor_data_fp32[i]);
}
} else {
memcpy(PReluWeight_, in_tensors_[1]->MutableData(), elem_num * fp_size);
@ -60,9 +60,9 @@ void PReluOpenCLKernel::InitBuffer() {
} else {
if (in_tensors_[1]->data_type() == kNumberTypeFloat16) {
auto PReluWeight_fp32 = reinterpret_cast<float *>(PReluWeight_);
auto in_tensor_data_fp16 = reinterpret_cast<uint16_t *>(in_tensors_[1]->MutableData());
auto in_tensor_data_fp16 = reinterpret_cast<float16_t *>(in_tensors_[1]->MutableData());
for (int i = 0; i < elem_num; i++) {
PReluWeight_fp32[i] = ShortToFloat32(in_tensor_data_fp16[i]);
PReluWeight_fp32[i] = static_cast<float>(in_tensor_data_fp16[i]);
}
} else {
memcpy(PReluWeight_, in_tensors_[1]->MutableData(), elem_num * fp_size);

View File

@ -55,6 +55,9 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te
}
}
for (size_t i = 0; i < in_tensors.size(); ++i) {
if (in_tensors.at(i)->shape().size() <= 1) {
continue;
}
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(in_kernels[i][0]);
schema::Format out_ori_format = cur_opencl_op->GetOutOriFormat();
schema::Format in_ori_format = cur_opencl_op->GetInOriFormat();

View File

@ -95,6 +95,7 @@ class OpenCLRuntime {
}
default:
MS_LOG(ERROR) << "Unsupport opencl memory type: " << static_cast<int>(mem_type);
return CL_IMAGE_FORMAT_NOT_SUPPORTED;
}
}

View File

@ -104,30 +104,14 @@ if (ENABLE_FP16)
endif ()
### gpu kernel
if (SUPPORT_GPU)
file(GLOB GPU_KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/opencl/kernel/*.cc
)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
${GPU_KERNEL_OP_SRC}
${LITE_DIR}/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc
${LITE_DIR}/src/runtime/kernel/opencl/utils.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/arithmetic.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/convolution.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/pooling2d.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/arithmetic_self.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/slice.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/reshape.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/biasadd.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/scale.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/reduce.cc
)
endif()
### minddata lite
@ -294,29 +278,12 @@ else()
endif()
if (SUPPORT_GPU)
file(GLOB_RECURSE TEST_CASE_KERNEL_GPU_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
)
set(TEST_SRC
${TEST_SRC}
${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/slice_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/to_format_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/prelu_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/reshape_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/biasadd_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/scale_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/reduce_tests.cc
${TEST_CASE_KERNEL_GPU_SRC}
)
endif()

View File

@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h"
namespace mindspore {
class TestGatherOpenCL : public mindspore::CommonTest {
public:
TestGatherOpenCL() {}
};
template <typename T>
void test_main_gather(void *input_data, void *correct_data, const std::vector<int> &input_shape,
const std::vector<int> &indices, GatherParameter *param, TypeId data_type,
schema::Format format) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
std::vector<int> indices_shape = {static_cast<int>(indices.size())};
std::vector<int> output_shape = input_shape;
output_shape[param->axis_] = indices.size();
auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
auto tensor_b = lite::Tensor(TypeId(data_type), indices_shape, schema::Format_NC);
auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
std::vector<lite::Tensor *> inputs{&tensor_a, &tensor_b};
std::vector<lite::Tensor *> outputs{&tensor_c};
size_t input_size = tensor_a.Size();
auto *pkernel =
new (std::nothrow) kernel::GatherOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (pkernel == nullptr) {
MS_LOG(INFO) << "new GatherOpenCLKernel failed ";
return;
}
pkernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{pkernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
delete pkernel;
MS_LOG(INFO) << " new SubGraphOpenCLKernel failed ";
return;
}
sub_graph->Init();
MS_LOG(INFO) << " init tensors ";
memcpy(inputs[0]->MutableData(), input_data, input_size);
sub_graph->Run();
std::cout << "==================output data================" << std::endl;
auto *output_data = reinterpret_cast<T *>(outputs[0]->data_c());
CommonTest::CompareOutputData<T>(output_data, static_cast<T*>(correct_data), outputs[0]->ElementsNum(), 0.0001);
delete pkernel;
delete sub_graph;
}
TEST_F(TestGatherOpenCL, Axis1Fp32) {
std::vector<int> input_shape{1, 5, 4, 4};
std::vector<int> indices{1, 3};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 1;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79};
float correct_data[] = {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
if (param == nullptr) {
return;
}
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_gather<float>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
TEST_F(TestGatherOpenCL, Axis2Int32) {
std::vector<int> input_shape{1, 5, 4, 4};
std::vector<int> indices{1, 3};
GatherParameter *param = std::make_unique<GatherParameter>().release();
param->axis_ = 1;
float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79};
float correct_data[] = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39,
44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63, 68, 69, 70, 71, 76, 77, 78, 79};
if (param == nullptr) {
return;
}
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_gather<int>(input_data, correct_data, input_shape, indices, param, data_type, format);
}
} // namespace mindspore