reduce optimize

This commit is contained in:
chenzupeng 2020-10-24 14:44:38 +08:00
parent a6075cc73b
commit f378a11e30
3 changed files with 53 additions and 19 deletions

View File

@ -1,6 +1,7 @@
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
#define LOCAL_CACHE_THREAD 16
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void mean_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
int X = get_global_id(0); // C4
@ -17,19 +18,31 @@ __kernel void mean_NHWC4(__read_only image2d_t src_data, __write_only image2d_t
WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result));
}
// NOTE(review): this span is a diff hunk whose +/- markers were lost, so pre-commit and
// post-commit lines are interleaved and the text below is not compilable as written.
// The [old]/[new]/[context] labels are derived from the preceding hunk header
// (19 old lines, 31 new lines) -- confirm against the repository before relying on them.
//
// [old] Signature of the kernel this hunk replaces.
__kernel void mean_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
// [new] mean_local_NHWC4: for channel slice X, sums size.x * size.y texels of src_data,
// divides by size.x * size.y and writes one texel to dst_data at (X, 0).
// The work is split across work-items by their local ids in dimensions 1 and 2, with
// partial sums kept in a LOCAL_CACHE_THREAD x LOCAL_CACHE_THREAD __local tile.
// The host code in this commit passes size = {h, w, c4, 1}.
__kernel void mean_local_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
// [context] channel-slice index, shared by the old and new kernels.
int X = get_global_id(0); // C4
// [old] guard, accumulator and serial loop over every (h, w) performed by one work-item.
if (X >= size.z) {
return;
}
FLT4 result = (FLT4)0.f;
for (int h = 0; h < size.x; h++) {
for (int w = 0; w < size.y; w++) {
result += READ_IMAGE(src_data, smp_zero, (int2)(w, X * size.x + h));
// [new] position of this work-item inside the tile.
// Assumes the local size in dimensions 1 and 2 does not exceed LOCAL_CACHE_THREAD,
// otherwise temp[][] is indexed out of bounds -- TODO confirm at the launch site.
int localy = get_local_id(1);
int localz = get_local_id(2);
// NOTE(review): this return precedes barrier() calls, which is only valid if every
// work-item of a work-group takes the same branch. The host hunk in this commit uses
// local = {1, LOCAL_CACHE_THREAD, LOCAL_CACHE_THREAD}, so X is presumably uniform per group.
if (X >= size.z) return;
// One float4 partial sum per work-item; accumulation is done in float regardless of FLT4.
__local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];
temp[localy][localz] = (float4)0.f;
// Each work-item visits rows localy, localy + LOCAL_CACHE_THREAD, ... and columns
// localz, localz + LOCAL_CACHE_THREAD, ...; the texel for (h, w, X) is read at
// x = w * size.z + X, y = h.
for (int h = localy; h < size.x; h += LOCAL_CACHE_THREAD) {
for (int w = localz; w < size.y; w += LOCAL_CACHE_THREAD) {
temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h)));
// [context] closing braces of the two loops (present in both versions).
}
}
// [new] wait until every partial sum has been written to the tile.
barrier(CLK_LOCAL_MEM_FENCE);
// Work-items with localz == 0 fold their own tile row into column 0.
if (localz == 0) {
for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {
temp[localy][0] += temp[localy][i];
}
}
// Make the per-row totals visible before they are combined.
barrier(CLK_LOCAL_MEM_FENCE);
// Every work-item of the group adds up column 0, so all of them hold the same total.
float4 result = temp[0][0];
for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {
result += temp[i][0];
}
// [context] turn the sum into a mean over size.x * size.y elements.
result /= size.x * size.y;
// [old] output write at (0, X).
WRITE_IMAGE(dst_data, (int2)(0, X), result);
// [new] output write at (X, 0), converted back to FLT4; issued by every work-item of
// the group with the same value.
WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result));
}
__kernel void sum_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
@ -46,16 +59,28 @@ __kernel void sum_NHWC4(__read_only image2d_t src_data, __write_only image2d_t d
WRITE_IMAGE(dst_data, (int2)(X, 0), result);
}
// NOTE(review): this span is a diff hunk whose +/- markers were lost, so pre-commit and
// post-commit lines are interleaved and the text below is not compilable as written.
// The [old]/[new]/[context] labels are derived from the preceding hunk header
// (16 old lines, 28 new lines) -- confirm against the repository before relying on them.
//
// [old] Signature of the kernel this hunk replaces.
__kernel void sum_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
// [new] sum_local_NHWC4: for channel slice X, sums size.x * size.y texels of src_data and
// writes one texel to dst_data at (X, 0). The work is split across work-items by their
// local ids in dimensions 1 and 2, with partial sums kept in a
// LOCAL_CACHE_THREAD x LOCAL_CACHE_THREAD __local tile.
// The host code in this commit passes size = {h, w, c4, 1}.
__kernel void sum_local_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
// [context] channel-slice index, shared by the old and new kernels.
int X = get_global_id(0); // C4
// [old] guard, accumulator and serial loop over every (h, w) performed by one work-item.
if (X >= size.z) {
return;
}
FLT4 result = (FLT4)0.f;
for (int h = 0; h < size.x; h++) {
for (int w = 0; w < size.y; w++) {
result += READ_IMAGE(src_data, smp_zero, (int2)(w, X * size.x + h));
// [new] position of this work-item inside the tile.
// Assumes the local size in dimensions 1 and 2 does not exceed LOCAL_CACHE_THREAD,
// otherwise temp[][] is indexed out of bounds -- TODO confirm at the launch site.
int localy = get_local_id(1);
int localz = get_local_id(2);
// NOTE(review): this return precedes barrier() calls, which is only valid if every
// work-item of a work-group takes the same branch. The host hunk in this commit uses
// local = {1, LOCAL_CACHE_THREAD, LOCAL_CACHE_THREAD}, so X is presumably uniform per group.
if (X >= size.z) return;
// One float4 partial sum per work-item; accumulation is done in float regardless of FLT4.
__local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];
temp[localy][localz] = (float4)0.f;
// Each work-item visits rows localy, localy + LOCAL_CACHE_THREAD, ... and columns
// localz, localz + LOCAL_CACHE_THREAD, ...; the texel for (h, w, X) is read at
// x = w * size.z + X, y = h.
for (int h = localy; h < size.x; h += LOCAL_CACHE_THREAD) {
for (int w = localz; w < size.y; w += LOCAL_CACHE_THREAD) {
temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h)));
// [context] closing braces of the two loops (present in both versions).
}
}
// [old] output write at (0, X).
WRITE_IMAGE(dst_data, (int2)(0, X), result);
// [new] wait until every partial sum has been written to the tile.
barrier(CLK_LOCAL_MEM_FENCE);
// Work-items with localz == 0 fold their own tile row into column 0.
if (localz == 0) {
for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {
temp[localy][0] += temp[localy][i];
}
}
// Make the per-row totals visible before they are combined.
barrier(CLK_LOCAL_MEM_FENCE);
// Every work-item of the group adds up column 0, so all of them hold the same total.
float4 result = temp[0][0];
for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {
result += temp[i][0];
}
// [new] output write at (X, 0), converted back to FLT4; issued by every work-item of
// the group with the same value.
WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result));
}

View File

@ -57,6 +57,10 @@ int ReduceOpenCLKernel::Init() {
return RET_PARAM_INVALID;
}
std::string kernel_name = reduce_type2str.at(reduce_param->mode_);
if (in_tensors_[0]->shape()[1] >= LOCAL_CACHE_THREAD || in_tensors_[0]->shape()[2] >= LOCAL_CACHE_THREAD) {
use_local_ = true;
kernel_name += "_local";
}
kernel_name += "_NHWC4";
enable_fp16_ = ocl_runtime_->GetFp16Enable();
@ -101,7 +105,10 @@ int ReduceOpenCLKernel::Run() {
int c = shapex[3];
int c4 = UP_DIV(c, C4NUM);
std::vector<size_t> local = {};
std::vector<size_t> global = {static_cast<size_t>(c4)};
if (use_local_) {
local = {1, LOCAL_CACHE_THREAD, LOCAL_CACHE_THREAD};
}
std::vector<size_t> global = {static_cast<size_t>(c4), 1, 1};
cl_int4 size = {h, w, c4, 1};
int arg_idx = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());

View File

@ -39,6 +39,8 @@ class ReduceOpenCLKernel : public OpenCLKernel {
cl::Kernel kernel_;
bool enable_fp16_{false};
std::vector<size_t> nhwc_shape_;
bool use_local_{false};
static const size_t LOCAL_CACHE_THREAD{16};
};
} // namespace mindspore::kernel