!12217 [MS][LITE][GPU]arithmetic broadcast support n>1

From: @chenzupeng
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @ddwsky,@zhanghaibo5
This commit is contained in:
mindspore-ci-bot 2021-02-07 16:52:56 +08:00 committed by Gitee
commit 028dcf85ae
3 changed files with 84 additions and 54 deletions

View File

@ -265,18 +265,22 @@ __kernel void BroadcastNHWC4Add(__read_only image2d_t input_a, __read_only image
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
int Z = get_global_id(2); // N * H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return;
}
int a_c = X < a_shape.w ? X : a_shape.w - 1;
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
int b_c = X < b_shape.w ? X : b_shape.w - 1;
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
int H = Z % output_shape.y;
int N = Z / output_shape.y;
int a_c = X < a_shape.w ? X : 0;
int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = a + b;
@ -294,18 +298,22 @@ __kernel void BroadcastNHWC4BiasAdd(__read_only image2d_t input_a, __read_only i
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
int Z = get_global_id(2); // N * H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return;
}
int a_c = X < a_shape.w ? X : a_shape.w - 1;
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
int b_c = X < b_shape.w ? X : b_shape.w - 1;
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
int H = Z % output_shape.y;
int N = Z / output_shape.y;
int a_c = X < a_shape.w ? X : 0;
int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = a + b;
@ -323,18 +331,22 @@ __kernel void BroadcastNHWC4Sub(__read_only image2d_t input_a, __read_only image
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
int Z = get_global_id(2); // N * H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return;
}
int a_c = X < a_shape.w ? X : a_shape.w - 1;
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
int b_c = X < b_shape.w ? X : b_shape.w - 1;
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
int H = Z % output_shape.y;
int N = Z / output_shape.y;
int a_c = X < a_shape.w ? X : 0;
int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = a - b;
@ -352,18 +364,22 @@ __kernel void BroadcastNHWC4Mul(__read_only image2d_t input_a, __read_only image
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
int Z = get_global_id(2); // N * H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return;
}
int a_c = X < a_shape.w ? X : a_shape.w - 1;
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
int b_c = X < b_shape.w ? X : b_shape.w - 1;
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
int H = Z % output_shape.y;
int N = Z / output_shape.y;
int a_c = X < a_shape.w ? X : 0;
int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = a * b;
@ -381,18 +397,22 @@ __kernel void BroadcastNHWC4Div(__read_only image2d_t input_a, __read_only image
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
int X = get_global_id(0); // C4
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
int Z = get_global_id(2); // N * H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return;
}
int a_c = X < a_shape.w ? X : a_shape.w - 1;
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
int b_c = X < b_shape.w ? X : b_shape.w - 1;
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
int H = Z % output_shape.y;
int N = Z / output_shape.y;
int a_c = X < a_shape.w ? X : 0;
int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = a / b;

View File

@ -46,10 +46,6 @@ int ArithmeticOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (param->broadcasting_ && out_tensors_.front()->DimensionSize(0) > 1) {
MS_LOG(ERROR) << "Broadcasting don't support N > 1";
return RET_ERROR;
}
if (!IsArithmetic(Type())) {
MS_LOG(ERROR) << "UnSupported Operator: " << schema::EnumNamePrimitiveType(Type());
return RET_ERROR;

View File

@ -112,6 +112,20 @@ TEST_F(TestOpenCL_Arithmetic, BroadcastSub2) {
}
}
// Broadcast Sub where the left operand is rank-2 {2, 3} and the right operand
// is rank-4 {2, 2, 2, 3}; the output takes the right operand's shape, whose
// leading dimension is 2. The expected values are lhs - rhs with the six lhs
// values reused for every leading (dim0, dim1) index pair.
TEST_F(TestOpenCL_Arithmetic, BroadcastSub3) {
  std::vector<int> lhs_shape = {2, 3};
  std::vector<int> rhs_shape = {2, 2, 2, 3};
  // The result has exactly the shape of the larger (right-hand) operand.
  std::vector<int> result_shape = rhs_shape;

  float lhs_values[] = {1, 2, 3, 1, 2, 3};
  float rhs_values[] = {1, 2, 3, 4,  5,  6,  7, 8, 9, 10, 11, 12,
                        1, 2, 3, 4,  5,  6,  7, 8, 9, 10, 11, 12};
  float expected_values[] = {0, 0, 0, -3, -3, -3, -6, -6, -6, -9, -9, -9,
                             0, 0, 0, -3, -3, -3, -6, -6, -6, -9, -9, -9};

  // Run the same case once with fp16 disabled and once with it enabled.
  const bool fp16_modes[] = {false, true};
  for (bool use_fp16 : fp16_modes) {
    auto *param = CreateParameter(schema::PrimitiveType_Sub, lhs_shape, rhs_shape);
    // The lhs is supplied as VAR and the rhs as CONST_TENSOR.
    TestMain({{lhs_shape, lhs_values, VAR}, {rhs_shape, rhs_values, CONST_TENSOR}},
             {result_shape, expected_values}, param, use_fp16);
  }
}
TEST_F(TestOpenCL_Arithmetic, BroadcastFloorMod) {
std::vector<int> input0_shape = {1, 1, 3, 4};
std::vector<int> input1_shape = {1, 1, 1, 4};