fix bug in sigmoid

This commit is contained in:
chenzupeng 2020-10-30 16:52:00 +08:00
parent eb3b093686
commit c523df059a
4 changed files with 66 additions and 10 deletions

View File

@ -35,12 +35,27 @@ __kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output,
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
}
// Sigmoid: element-wise logistic function 1 / (1 + exp(-x)) applied to the FLT4 texel at (X, Y) of `input`;
// the result is written to the same coordinate of `output`.
//
// Parameters (post-change signature):
//   input     - image read at (X, Y) through READ_IMAGE with the smp_zero sampler.
//   output    - image written at (X, Y) through WRITE_IMAGE.
//   img_shape - upper bounds for X (.x) and Y (.y); work-items outside these bounds return without writing.
//   c4        - modulus applied to X; texels whose X % c4 equals c4 - 1 take the per-lane path below.
//               The host-side SetArgs in this commit passes outShape.Slice for this argument.
//   last_c4   - number of lanes (1..4) that get the sigmoid in those last texels.
//               The host-side SetArgs in this commit passes outShape.C % 4, or 4 when that remainder is 0.
//
// NOTE(review): this listing is a diff rendering whose +/- markers were lost. The first signature line below
// (without c4 / last_c4) looks like the removed pre-change line and the two lines after it like its replacement —
// confirm against the repository before treating this text as compilable source.
__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {
__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape, const int c4,
const int last_c4) {
int X = get_global_id(0);
int Y = get_global_id(1);
// Guard: skip work-items that fall outside the image extent.
if (X >= img_shape.x || Y >= img_shape.y) return;
// Position of this texel within a run of c4 texels along X.
// presumably X is laid out as W * c4 + slice, making C4 the channel-slice index — TODO confirm with the host layout.
int C4 = X % c4;
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
// NOTE(review): this unconditional form appears to be the pre-change line that the branch below replaces;
// if both were live the sigmoid would be applied twice — confirm.
in_c4 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4));
if (C4 < c4 - 1) {
// Not the last texel of the run: all four lanes are transformed.
in_c4 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4));
} else {
// Last texel of the run: lane x is always transformed; y, z, w only when last_c4 exceeds 1, 2, 3.
// Lanes that are not transformed are written back with the value that was read
// (presumably padding lanes, which would otherwise turn into sigmoid(0) = 0.5 — TODO confirm).
in_c4.x = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.x));
if (last_c4 > 1) {
in_c4.y = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.y));
}
if (last_c4 > 2) {
in_c4.z = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.z));
}
if (last_c4 > 3) {
in_c4.w = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.w));
}
}
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
}

View File

@ -266,7 +266,7 @@ __kernel void ElementGreaterEqual_IMG(__read_only image2d_t input_a, __read_only
__kernel void BroadcastNHWC4Add_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, float act_min, float act_max) {
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
@ -281,14 +281,21 @@ __kernel void BroadcastNHWC4Add_IMG(__read_only image2d_t input_a, __read_only i
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
FLT4 result = a + b;
FLT4 result;
if (broadcastC_flag == 0) {
result = a + b;
} else if (broadcastC_flag == 1) {
result = a.x + b;
} else {
result = a + b.x;
}
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}
__kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, float act_min, float act_max) {
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
@ -303,14 +310,21 @@ __kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only i
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
FLT4 result = a - b;
FLT4 result;
if (broadcastC_flag == 0) {
result = a - b;
} else if (broadcastC_flag == 1) {
result = a.x - b;
} else {
result = a - b.x;
}
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}
__kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, float act_min, float act_max) {
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
@ -325,14 +339,21 @@ __kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only i
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
FLT4 result = a * b;
FLT4 result;
if (broadcastC_flag == 0) {
result = a * b;
} else if (broadcastC_flag == 1) {
result = a.x * b;
} else {
result = a * b.x;
}
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}
__kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
const int4 output_shape, float act_min, float act_max) {
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
@ -347,7 +368,14 @@ __kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only i
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
FLT4 result = a / b;
FLT4 result;
if (broadcastC_flag == 0) {
result = a / b;
} else if (broadcastC_flag == 1) {
result = a.x / b;
} else {
result = a / b.x;
}
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
}

View File

@ -70,6 +70,12 @@ int ActivationOpenClKernel::SetArgs() {
if (type_ == ActivationType_LEAKY_RELU) {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, alpha_);
}
if (type_ == ActivationType_SIGMOID) {
int c4 = outShape.Slice;
int last_c4 = outShape.C % 4 == 0 ? 4 : outShape.C % 4;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, c4);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, last_c4);
}
return RET_OK;
}

View File

@ -148,6 +148,13 @@ int ArithmeticOpenCLKernel::SetArgs() {
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
cl_int4 output_shape{out_shape[0], out_shape[1], out_shape[2], UP_DIV(out_shape[3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
int broadcastC_flag = 0; // do not need broadcast in C4
if (inputs_nhwc_shapes_[0][3] == 1 && inputs_nhwc_shapes_[1][3] != 1) {
broadcastC_flag = 1; // BroadCast C4 in input0
} else if (inputs_nhwc_shapes_[0][3] != 1 && inputs_nhwc_shapes_[1][3] == 1) {
broadcastC_flag = 2; // BroadCast C4 in input1
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, broadcastC_flag);
} else {
cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);