forked from mindspore-Ecosystem/mindspore
fix bug in sigmoid
This commit is contained in:
parent
eb3b093686
commit
c523df059a
|
@ -35,12 +35,27 @@ __kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output,
|
|||
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
|
||||
}
|
||||
|
||||
__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {
|
||||
__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape, const int c4,
|
||||
const int last_c4) {
|
||||
int X = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
if (X >= img_shape.x || Y >= img_shape.y) return;
|
||||
int C4 = X % c4;
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
|
||||
in_c4 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4));
|
||||
if (C4 < c4 - 1) {
|
||||
in_c4 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4));
|
||||
} else {
|
||||
in_c4.x = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.x));
|
||||
if (last_c4 > 1) {
|
||||
in_c4.y = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.y));
|
||||
}
|
||||
if (last_c4 > 2) {
|
||||
in_c4.z = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.z));
|
||||
}
|
||||
if (last_c4 > 3) {
|
||||
in_c4.w = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.w));
|
||||
}
|
||||
}
|
||||
WRITE_IMAGE(output, (int2)(X, Y), in_c4);
|
||||
}
|
||||
|
||||
|
|
|
@ -266,7 +266,7 @@ __kernel void ElementGreaterEqual_IMG(__read_only image2d_t input_a, __read_only
|
|||
|
||||
__kernel void BroadcastNHWC4Add_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
|
||||
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
|
||||
const int4 output_shape, float act_min, float act_max) {
|
||||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
|
@ -281,14 +281,21 @@ __kernel void BroadcastNHWC4Add_IMG(__read_only image2d_t input_a, __read_only i
|
|||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
FLT4 result = a + b;
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a + b;
|
||||
} else if (broadcastC_flag == 1) {
|
||||
result = a.x + b;
|
||||
} else {
|
||||
result = a + b.x;
|
||||
}
|
||||
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
|
||||
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
|
||||
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
|
||||
const int4 output_shape, float act_min, float act_max) {
|
||||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
|
@ -303,14 +310,21 @@ __kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only i
|
|||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
FLT4 result = a - b;
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a - b;
|
||||
} else if (broadcastC_flag == 1) {
|
||||
result = a.x - b;
|
||||
} else {
|
||||
result = a - b.x;
|
||||
}
|
||||
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
|
||||
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
|
||||
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
|
||||
const int4 output_shape, float act_min, float act_max) {
|
||||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
|
@ -325,14 +339,21 @@ __kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only i
|
|||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
FLT4 result = a * b;
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a * b;
|
||||
} else if (broadcastC_flag == 1) {
|
||||
result = a.x * b;
|
||||
} else {
|
||||
result = a * b.x;
|
||||
}
|
||||
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
|
||||
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
|
||||
}
|
||||
|
||||
__kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b,
|
||||
__write_only image2d_t output, const int4 a_shape, const int4 b_shape,
|
||||
const int4 output_shape, float act_min, float act_max) {
|
||||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
|
@ -347,7 +368,14 @@ __kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only i
|
|||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
FLT4 result = a / b;
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a / b;
|
||||
} else if (broadcastC_flag == 1) {
|
||||
result = a.x / b;
|
||||
} else {
|
||||
result = a / b.x;
|
||||
}
|
||||
result = clamp(result, (FLT)(act_min), (FLT)(act_max));
|
||||
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
|
||||
}
|
||||
|
|
|
@ -70,6 +70,12 @@ int ActivationOpenClKernel::SetArgs() {
|
|||
if (type_ == ActivationType_LEAKY_RELU) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, alpha_);
|
||||
}
|
||||
if (type_ == ActivationType_SIGMOID) {
|
||||
int c4 = outShape.Slice;
|
||||
int last_c4 = outShape.C % 4 == 0 ? 4 : outShape.C % 4;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, c4);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, last_c4);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -148,6 +148,13 @@ int ArithmeticOpenCLKernel::SetArgs() {
|
|||
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
|
||||
cl_int4 output_shape{out_shape[0], out_shape[1], out_shape[2], UP_DIV(out_shape[3], C4NUM)};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
|
||||
int broadcastC_flag = 0; // do not need broadcast in C4
|
||||
if (inputs_nhwc_shapes_[0][3] == 1 && inputs_nhwc_shapes_[1][3] != 1) {
|
||||
broadcastC_flag = 1; // BroadCast C4 in input0
|
||||
} else if (inputs_nhwc_shapes_[0][3] != 1 && inputs_nhwc_shapes_[1][3] == 1) {
|
||||
broadcastC_flag = 2; // BroadCast C4 in input1
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, broadcastC_flag);
|
||||
} else {
|
||||
cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
|
||||
|
|
Loading…
Reference in New Issue