From f8fce407eb1e8447d514c637f21e3f39d473fb89 Mon Sep 17 00:00:00 2001 From: greatpanc Date: Mon, 27 Sep 2021 17:43:47 +0800 Subject: [PATCH] ms opencl performance drop bugfix --- .../src/runtime/kernel/opencl/cl/softmax.cl | 73 +++---------------- 1 file changed, 9 insertions(+), 64 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl index 26de948477e..9fb07aac0c6 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl @@ -59,26 +59,14 @@ __kernel void SoftmaxAxis1_NHWC4(__read_only image2d_t input, __write_only image if (n >= input_shape.x || X >= W || Y >= C4) return; - // get max - float input_max = 0.0f; - for (int d = 0; d < H; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d))); - input_max = max(input_max, t.x); - input_max = max(input_max, t.y); - input_max = max(input_max, t.z); - input_max = max(input_max, t.w); - } - float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max); - - // get sum float4 sum = 0.0f; for (int d = 0; d < H; ++d) { float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d))); - sum += exp(t - input_max_f4); + sum += exp(t); } for (int d = 0; d < H; ++d) { float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d))); - result = exp(result - input_max_f4) / sum; + result = exp(result) / sum; WRITE_IMAGEOUT(output, (int2)(X * C4 + Y, n * H + d), OUT_FLT4(result)); } } @@ -94,26 +82,14 @@ __kernel void SoftmaxAxis2_NHWC4(__read_only image2d_t input, __write_only image if (n >= input_shape.x || X >= H || Y >= C4) return; - // get max - float input_max = 0.0f; - for (int d = 0; d < W; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X))); - input_max = max(input_max, t.x); - input_max = max(input_max, t.y); - input_max = max(input_max, t.z); - input_max = max(input_max, t.w); - } - float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max); - - // get sum float4 sum = 0.0f; for (int d = 0; d < W; ++d) { float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X))); - sum += exp(t - input_max); + sum += exp(t); } for (int d = 0; d < W; ++d) { float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X))); - result = exp(result - input_max) / sum; + result = exp(result) / sum; WRITE_IMAGEOUT(output, (int2)(d * C4 + Y, n * H + X), OUT_FLT4(result)); } } @@ -124,49 +100,18 @@ __kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d int n = get_global_id(1); if (n >= input_shape.x) return; int C4 = input_shape.w; - - // get max - float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, n))); - float input_max = last.x; - if (mask.y > 0.5f) input_max = max(input_max, last.y); - if (mask.z > 0.5f) input_max = max(input_max, last.z); - if (mask.w > 0.5f) input_max = max(input_max, last.w); - for (size_t i = tid; i < C4 - 1; i += 32) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n))); - input_max = max(input_max, t.x); - input_max = max(input_max, t.y); - input_max = max(input_max, t.z); - input_max = max(input_max, t.w); - } - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = input_max; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - input_max = max(input_max, tmpx1[0]); - input_max = max(input_max, tmpx1[1]); - input_max = max(input_max, tmpx1[2]); - input_max = max(input_max, tmpx1[3]); - input_max = max(input_max, tmpx1[4]); - input_max = max(input_max, tmpx1[5]); - input_max = max(input_max, tmpx1[6]); - input_max = max(input_max, tmpx1[7]); - tmpx1[0] = input_max; - } - barrier(CLK_GLOBAL_MEM_FENCE); - float4 input_max_f4 = (float4)(tmpx1[0], tmpx1[0], tmpx1[0], tmpx1[0]); - - // get sum float sum = 0.0f; for (size_t i = tid; i < C4 - 1; i += 32) { float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n))); - sum += dot((float4)(1.0f), exp(src - input_max_f4)); + sum += dot((float4)(1.0f), exp(src)); } if ((C4 - 1) % 32 == tid) { float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, n))); - sum += dot(convert_float4(mask), exp(src - input_max_f4)); + sum += dot(convert_float4(mask), exp(src)); } + __local float4 tmp[8]; + __local float *tmpx1 = (__local float *)tmp; tmpx1[tid] = sum; barrier(CLK_LOCAL_MEM_FENCE); if (tid == 0) { @@ -184,7 +129,7 @@ __kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d sum = tmpx1[0]; for (size_t i = tid; i < C4; i += 32) { float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n))); - result = exp(result - input_max_f4) * sum; + result = exp(result) * sum; WRITE_IMAGEOUT(output, (int2)(i, n), OUT_FLT4(result)); } }