!4885 [MS][LITE][Develop]add new ops named slice for opencl(GPU)

Merge pull request !4885 from pengyongrong/slice
This commit is contained in:
mindspore-ci-bot 2020-08-21 16:37:10 +08:00 committed by Gitee
commit 38498aad7a
10 changed files with 445 additions and 10 deletions

View File

@ -0,0 +1,81 @@
#define INT2 int2
#define INT4 int4
#define FLT4 float4
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void slice(__read_only image2d_t input, __write_only image2d_t output, INT4 input_shape, INT4 out_shape,
INT4 begin, INT2 sharedNoUpdiv) {
int X = get_global_id(1); // H
int Y = get_global_id(2); // W
if (X >= out_shape.y || Y >= out_shape.z) {
return;
}
FLT4 result;
if (sharedNoUpdiv.x % 4 == 0) {
for (int i = 0; i < out_shape.w; i++) {
result = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (i + begin.w), (X + begin.y)));
write_imagef(output, (INT2)((Y)*out_shape.w + i, (X)), result);
}
} else {
int begin_postion = sharedNoUpdiv.y % 4;
FLT4 first = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + begin.w, (X + begin.y)));
if (begin_postion == 1) {
for (int i = 1; i <= out_shape.w; i++) {
FLT4 second =
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
result.x = first.y;
result.y = first.z;
result.z = first.w;
result.w = second.x;
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
first.y = second.y;
first.z = second.z;
first.w = second.w;
}
} else if (begin_postion == 2) {
for (int i = 1; i <= out_shape.w; i++) {
FLT4 second =
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
result.x = first.z;
result.y = first.w;
result.z = second.x;
result.w = second.y;
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
first.z = second.z;
first.w = second.w;
}
} else {
for (int i = 1; i <= out_shape.w; i++) {
FLT4 second =
read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y)));
result.x = first.w;
result.y = second.x;
result.z = second.y;
result.w = second.z;
write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result);
first.w = second.w;
}
}
}
// judge the line of size
int size = sharedNoUpdiv.y % 4;
FLT4 result_fill0;
if (size == 1) {
result_fill0.x = result.x;
result_fill0.y = 0;
result_fill0.z = 0;
result_fill0.w = 0;
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
} else if (size == 2) {
result_fill0.x = result.x;
result_fill0.y = result.y;
result_fill0.z = 0;
result_fill0.w = 0;
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
} else if (size == 3) {
result_fill0.x = result.x;
result_fill0.y = result.y;
result_fill0.z = result.z;
result_fill0.w = 0;
write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0);
}
}

View File

@ -139,7 +139,7 @@ kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector<lite::tensor:
return nullptr;
}
auto ret = kernel->Init();
if (0 != ret) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
delete kernel;
return nullptr;

View File

@ -216,7 +216,7 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Te
return nullptr;
}
auto ret = kernel->Init();
if (0 != ret) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
delete kernel;
return nullptr;

View File

@ -276,7 +276,7 @@ int ConvolutionOpenCLKernel::Run() {
}
if (use_winograd_) {
ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {16, 6, 4}, nullptr);
ocl_runtime->RunKernel(kernel_4x4to36, {size_t(TILES_XY), 6, size_t(CI_SLICES)}, {8, 6, 4}, nullptr);
ocl_runtime->RunKernel(kernel_conv, {size_t(TILES_XY / 2), 36, size_t(CO_SLICES / 2)}, {8, 6, 2}, nullptr);
ocl_runtime->RunKernel(kernel_36to4x4, {size_t(TILES_XY), 4, size_t(CO_SLICES)}, {32, 4, 2}, nullptr);
} else {
@ -677,7 +677,7 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::tenso
return nullptr;
}
auto ret = kernel->Init();
if (0 != ret) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
delete kernel;
return nullptr;

View File

@ -0,0 +1,144 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/slice.h"
#include "src/runtime/kernel/opencl/cl/slice.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
} else {
im_dst_y = out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
}
int SliceOpenCLKernel::Init() {
std::set<std::string> build_options;
std::string source = slice_source;
std::string program_name = "slice";
std::string kernel_name = "slice";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
return RET_OK;
}
int SliceOpenCLKernel::ReSize() { return RET_OK; }
int SliceGetBiggestDividerWithPriority(int number, int max_divider) {
if (number % 8 == 0 && 8 <= max_divider) {
return number / 8;
} else if (number % 4 == 0 && 4 <= max_divider) {
return number / 4;
} else if (number % 2 == 0 && 2 <= max_divider) {
return number / 2;
}
for (int i = max_divider; i != 0; i--) {
if (number % i == 0) {
return i;
}
}
return 1;
}
void SlcieGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
const int max_x = 4, max_y = 8;
int x = std::min(SliceGetBiggestDividerWithPriority(global[0], max_divider), max_x);
int yz = max_size / x;
int y = std::min(std::min(SliceGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
local->clear();
local->push_back(x);
local->push_back(y);
local->push_back(z);
}
int SliceOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
auto param = reinterpret_cast<SliceParameter *>(this->op_parameter_);
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto input_shape = in_tensors_[0]->shape();
cl_int4 input_shape_ = {input_shape[0], input_shape[1], input_shape[2], UP_DIV(input_shape[3], C4NUM)};
cl_int4 size_ = {param->size_[0], param->size_[1], param->size_[2], UP_DIV(param->size_[3], C4NUM)};
cl_int4 begin_ = {param->begin_[0], param->begin_[1], param->begin_[2], param->begin_[3] / 4};
cl_int2 sharedNoUpdiv = {param->begin_[3], param->size_[3]};
uint32_t OH = param->size_[1];
uint32_t OW = param->size_[2];
const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1}; // init local
std::vector<size_t> global = {1, OH, OW};
SlcieGetWorkGroup(global, &local, max_global[0]);
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data()); // input tensor
ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); // out tensor
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, size_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, begin_);
ocl_runtime->SetKernelArg(kernel_, arg_cn++, sharedNoUpdiv);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
} // namespace mindspore::kernel
kernel::LiteKernel *OpenCLSliceKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) SliceOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceOpenCLKernel failed";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: Convolution";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Slice, OpenCLSliceKernelCreator);
} // namespace mindspore::kernel

View File

@ -0,0 +1,49 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SLICE_H_
#include <vector>
#include "ir/anf.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/nnacl/fp32/slice.h"
namespace mindspore::kernel {
class SliceOpenCLKernel : public OpenCLKernel {
public:
explicit SliceOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~SliceOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
private:
cl::Kernel kernel_;
};
} // namespace mindspore::kernel
#endif

View File

@ -46,8 +46,10 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
}
if (mem_type == OpenCLMemType::IMG) {
jv->set_in_tensors({});
jv->SetInKernel({});
} else {
jv->set_out_tensors({});
jv->SetOutKernel({});
}
}
}
@ -129,13 +131,21 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
if (mem_type == OpenCLMemType::IMG) {
for (auto &iv : in_kernels[i]) {
in_opencl_op->AddOutKernel(iv);
reinterpret_cast<OpenCLKernel *>(iv)->SetInKernel({in_convert_op});
reinterpret_cast<OpenCLKernel *>(iv)->set_in_tensors({new_tensor});
auto kernels = iv->in_kernels();
kernels.emplace_back(in_convert_op);
iv->SetInKernel(kernels);
auto tensors = iv->in_tensors();
tensors.emplace_back(new_tensor);
iv->set_in_tensors(tensors);
}
} else {
for (auto &iv : in_kernels[i]) {
reinterpret_cast<OpenCLKernel *>(iv)->SetOutKernel({in_convert_op});
reinterpret_cast<OpenCLKernel *>(iv)->set_out_tensors({new_tensor});
auto kernels = iv->out_kernels();
kernels.emplace_back(in_convert_op);
iv->SetOutKernel(kernels);
auto tensors = iv->out_tensors();
tensors.emplace_back(new_tensor);
iv->set_out_tensors(tensors);
in_convert_op->AddInKernel(iv);
}
}

View File

@ -149,6 +149,7 @@ if (SUPPORT_GPU)
${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/slice.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
@ -322,6 +323,7 @@ if (SUPPORT_GPU)
${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/slice_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc

View File

@ -43,8 +43,8 @@ TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) {
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "Read tensors from .bin";
std::vector<int> input_shape = {1, 256, 256, 48};
std::vector<int> output_shape = {1, 256, 256, 48};
std::vector<int> input_shape = {1, 256, 256, 16};
std::vector<int> output_shape = {1, 256, 256, 16};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;

View File

@ -0,0 +1,149 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/slice.h"
namespace mindspore {
class TestSliceOpenCL : public mindspore::CommonTest {
public:
TestSliceOpenCL() {}
};
template <typename T>
void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) {
for (size_t i = 0; i < size; i++) {
T abs = fabs(output_data[i] - correct_data[i]);
ASSERT_LE(abs, err_bound);
}
}
TEST_F(TestSliceOpenCL, Sliceinput_dim4) {
MS_LOG(INFO) << "begin test";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "Read tensors from .bin";
std::vector<int> input_shape = {1, 256, 256, 48};
std::vector<int> output_shape = {1, 255, 255, 15};
std::vector<int> begin = {0, 1, 1, 7};
std::vector<int> size = {1, 255, 255, 15};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
// get the input from .bin
size_t input_size, output_size;
std::string input_path = "./test_data/in_data.bin";
std::string output_path = "./test_data/out_data.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
MS_LOG(INFO) << "construct tensors";
lite::tensor::Tensor *tensor_data =
new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
if (tensor_data == nullptr) {
MS_LOG(INFO) << "init tensor failed";
return;
}
auto *output_tensor =
new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
if (output_tensor == nullptr) {
delete tensor_data;
MS_LOG(INFO) << "init tensor failed";
return;
}
std::vector<lite::tensor::Tensor *> inputs = {tensor_data};
std::vector<lite::tensor::Tensor *> outputs = {output_tensor};
MS_LOG(INFO) << "setting SliceParameter";
auto param = new (std::nothrow) SliceParameter();
if (param == nullptr) {
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
MS_LOG(INFO) << "new SliceParameter failed";
return;
}
for (int i = 0; i < 4; i++) {
param->begin_[i] = begin[i];
param->size_[i] = size[i];
}
auto *slice_kernel =
new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (slice_kernel == nullptr) {
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
MS_LOG(INFO) << "new kernel::slice_kernel failed";
return;
}
slice_kernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << "initialize sub_graph";
std::vector<kernel::LiteKernel *> kernels{slice_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
delete slice_kernel;
MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed";
return;
}
sub_graph->Init();
MS_LOG(INFO) << "init tensors";
memcpy(inputs[0]->Data(), input_data, input_size);
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete slice_kernel;
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}
} // namespace mindspore