oencl depthwise optimize for adreno gpu

This commit is contained in:
wandongdong 2021-01-13 00:22:21 -08:00
parent 0bda42634f
commit 82669cfc67
6 changed files with 89 additions and 93 deletions

View File

@ -1,12 +1,12 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read_only image2d_t src_data,
__global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,
int2 padding, int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min,
float relu_clip_max) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
__kernel void DepthwiseConv2d_IMG_NHWC4(__write_only image2d_t dst_data, __read_only image2d_t src_data,
__read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,
int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,
float relu_clip_min, float relu_clip_max) {
int X = get_global_id(1);
int Y = get_global_id(2);
int Z = get_global_id(0);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offset = X * stride.x + padding.x;
@ -19,8 +19,8 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read
int x_c = x_offset + kx * dilation.x;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 flt_p = filter[fx_c];
FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(x_c, (Z * src_size.y + y_c)));
FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));
FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
r += TO_FLT4(src_p * flt_p);
}
fx_c++;
@ -29,9 +29,39 @@ __kernel void DepthwiseConv2d_IMG_NC4HW4(__write_only image2d_t dst_data, __read
FLT4 bias_p = bias[Z];
FLT4 res = TO_FLT4(r) + bias_p;
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
WRITE_IMAGE(dst_data, (int2)(X, (Z * dst_size.y + Y)), res);
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
}
__kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__write_only image2d_t dst_data, __read_only image2d_t src_data,
__read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,
int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,
float relu_clip_min, float relu_clip_max) {
int X = get_global_id(1);
int Y = get_global_id(2);
int Z = get_global_id(0);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offset = X * stride.x + padding.x;
int y_offset = Y * stride.y + padding.y;
int fx_c = Z;
{
int y_c = y_offset;
bool outside_y = y_c < 0 || y_c >= src_size.y;
{
int x_c = x_offset;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(0, Z));
FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
r += TO_FLT4(src_p * flt_p);
}
}
}
FLT4 bias_p = bias[Z];
FLT4 res = TO_FLT4(r) + bias_p;
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
}
__kernel void DepthwiseConv2d_IMG_NHWC4_b222(__write_only image2d_t dst_data, __read_only image2d_t src_data,
__global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,
int2 padding, int2 dilation, int4 src_size, int4 dst_size,
@ -264,65 +294,3 @@ __kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *dst_data, __global FLT4
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res;
}
__kernel void DepthwiseConv2d_BUF_NHWC4(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter,
__global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding, int2 dilation,
int4 src_size, int4 dst_size, float relu_clip_min, float relu_clip_max) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offset = X * stride.x + padding.x;
int y_offset = Y * stride.y + padding.y;
int fx_c = Z * kernel_size.x * kernel_size.y;
for (int ky = 0; ky < kernel_size.y; ++ky) {
int y_c = y_offset + ky * dilation.y;
bool outside_y = y_c < 0 || y_c >= src_size.y;
for (int kx = 0; kx < kernel_size.x; ++kx) {
int x_c = x_offset + kx * dilation.x;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 flt_p = filter[fx_c];
FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
r += TO_FLT4(src_p * flt_p);
}
fx_c++;
}
}
FLT4 bias_p = bias[Z];
FLT4 res = TO_FLT4(r) + bias_p;
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res;
}
__kernel void DepthwiseConv2d_BUF_NHWC4_1x1(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter,
__global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding,
int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min,
float relu_clip_max) {
int X = get_global_id(0);
int Y = get_global_id(1);
int Z = get_global_id(2);
if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;
FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
int x_offset = X * stride.x + padding.x;
int y_offset = Y * stride.y + padding.y;
int fx_c = Z;
{
int y_c = y_offset;
bool outside_y = y_c < 0 || y_c >= src_size.y;
{
int x_c = x_offset;
bool outside_x = x_c < 0 || x_c >= src_size.x;
if (!outside_x && !outside_y) {
FLT4 flt_p = filter[fx_c];
FLT4 src_p = src_data[((y_c * src_size.x + x_c) * src_size.z + Z)];
r += TO_FLT4(src_p * flt_p);
}
}
}
FLT4 bias_p = bias[Z];
FLT4 res = TO_FLT4(r) + bias_p;
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res;
}

View File

@ -73,7 +73,11 @@ int DepthwiseConv2dOpenCLKernel::Prepare() {
if (parameter->kernel_h_ == 1 && parameter->kernel_w_ == 1) {
kernel_name += "_1x1";
}
kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C);
if (filter_type_ == lite::opencl::MemType::BUF) {
kernel_name += "_b" + std::to_string(block_size_.H) + std::to_string(block_size_.W) + std::to_string(block_size_.C);
} else {
block_size_.C = block_size_.H = block_size_.W = 1;
}
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
@ -107,32 +111,42 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
int CO4 = UP_DIV(out_info.C, C4NUM * block_size_.C);
int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_;
int plane = parameter->kernel_h_ * parameter->kernel_w_;
int plane_in = parameter->kernel_h_ * parameter->kernel_w_;
int plane_out = plane_in * C4NUM;
std::vector<size_t> img_size;
if (filter_type_ == MemType::IMG) {
int alignment = ocl_runtime_->GetImagePitchAlignment();
plane_out = UP_ROUND(plane_out, alignment) * C4NUM;
pack_weight_size = plane_out * CO4;
auto shape = in_tensors_[1]->shape();
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
img_size = {(size_t)plane_out / C4NUM, (size_t)shape[0] * CO4, img_dtype};
}
if (is_fp16) {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t));
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(int16_t), img_size);
packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
} else { // int8 or int16
std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
}
} else {
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float), img_size);
packed_weight_ = allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true);
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
} else { // int8 or int16
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_info.C, to_dtype);
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane_in, plane_out, out_info.C, to_dtype);
}
}
allocator->UnmapBuffer(packed_weight_);
@ -184,7 +198,7 @@ void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
cl_int4 dst_size = {(cl_int)out_info.W, (cl_int)out_info.H, (cl_int)CO4, (cl_int)out_info.N};
int arg_cnt = 2;
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, lite::opencl::MemType::BUF);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, packed_weight_, filter_type_);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, bias_data_, lite::opencl::MemType::BUF);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, kernel_size);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, stride);

View File

@ -21,13 +21,18 @@
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/conv_parameter.h"
using mindspore::lite::opencl::MemType;
namespace mindspore::kernel {
class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
public:
DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
: OpenCLKernel(parameter, inputs, outputs) {
bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO;
filter_type_ = is_adreno ? MemType::IMG : MemType::BUF;
}
~DepthwiseConv2dOpenCLKernel() override = default;
@ -47,6 +52,7 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
int W{2};
int C{1};
} block_size_;
MemType filter_type_{MemType::BUF};
};
} // namespace mindspore::kernel

View File

@ -62,19 +62,20 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape);
std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format);
template <class T1, class T2>
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane, int channel, const std::function<T2(T1)> &to_dtype) {
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
const std::function<T2(T1)> &to_dtype) {
MS_ASSERT(src);
MS_ASSERT(dst);
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c4 * C4NUM;
int src_offset = b * plane_in * channel;
int dst_offset = b * plane_out * c4;
for (int c = 0; c < channel; c++) {
int c4_block_num = c / C4NUM;
int c4_block_rem = c % C4NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
for (int k = 0; k < plane; k++) {
int src_c_offset = src_offset + c * plane_in;
int dst_c_offset = dst_offset + c4_block_num * plane_out;
for (int k = 0; k < plane_in; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
(static_cast<T2 *>(dst) + dst_kernel_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_kernel_offset)[0]);

View File

@ -187,13 +187,20 @@ class OpenCLRuntime {
std::vector<size_t> max_work_item_sizes_;
void *handle_{nullptr};
TuningMode tuning_mode_{TuningMode::DEFAULT};
#if MS_OPENCL_PROFILE
bool profiling_{true};
#else
bool profiling_{false};
#endif
// for cache
private:
void LoadCache();
void StoreCache();
#ifdef MS_OPENCL_BINARY_CACHE
bool enable_cache_{true};
#else
bool enable_cache_{false};
#endif
bool flush_cache_{false};
std::string cache_path_{"/data/local/tmp/.opencl_cache"};
const std::string cache_version_{"V0.1"};

View File

@ -81,7 +81,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, NoPad) {
TestMain({{input_shape, input_data, VAR},
{weight_shape, weight_data, CONST_TENSOR},
{bias_shape, bias_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5);
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true);
}
}
@ -128,7 +128,7 @@ TEST_F(TestOpenCL_DepthwiseConv2d, Pad) {
TestMain({{input_shape, input_data, VAR},
{weight_shape, weight_data, CONST_TENSOR},
{bias_shape, bias_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5);
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-5, 1e-1, true);
}
}