forked from mindspore-Ecosystem/mindspore
!12217 [MS][LITE][GPU]arithmetic broadcast support n>1
From: @chenzupeng Reviewed-by: @ddwsky,@zhanghaibo5 Signed-off-by: @ddwsky,@zhanghaibo5
This commit is contained in:
commit
028dcf85ae
|
@ -265,18 +265,22 @@ __kernel void BroadcastNHWC4Add(__read_only image2d_t input_a, __read_only image
|
|||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
|
||||
int Z = get_global_id(2); // N * H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
|
||||
return;
|
||||
}
|
||||
int a_c = X < a_shape.w ? X : a_shape.w - 1;
|
||||
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
|
||||
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
|
||||
int b_c = X < b_shape.w ? X : b_shape.w - 1;
|
||||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
int H = Z % output_shape.y;
|
||||
int N = Z / output_shape.y;
|
||||
int a_c = X < a_shape.w ? X : 0;
|
||||
int a_w = Y < a_shape.z ? Y : 0;
|
||||
int a_h = H < a_shape.y ? H : 0;
|
||||
int a_n = N < a_shape.x ? N : 0;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
|
||||
int b_c = X < b_shape.w ? X : 0;
|
||||
int b_w = Y < b_shape.z ? Y : 0;
|
||||
int b_h = H < b_shape.y ? H : 0;
|
||||
int b_n = N < b_shape.x ? N : 0;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a + b;
|
||||
|
@ -294,18 +298,22 @@ __kernel void BroadcastNHWC4BiasAdd(__read_only image2d_t input_a, __read_only i
|
|||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
|
||||
int Z = get_global_id(2); // N * H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
|
||||
return;
|
||||
}
|
||||
int a_c = X < a_shape.w ? X : a_shape.w - 1;
|
||||
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
|
||||
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
|
||||
int b_c = X < b_shape.w ? X : b_shape.w - 1;
|
||||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
int H = Z % output_shape.y;
|
||||
int N = Z / output_shape.y;
|
||||
int a_c = X < a_shape.w ? X : 0;
|
||||
int a_w = Y < a_shape.z ? Y : 0;
|
||||
int a_h = H < a_shape.y ? H : 0;
|
||||
int a_n = N < a_shape.x ? N : 0;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
|
||||
int b_c = X < b_shape.w ? X : 0;
|
||||
int b_w = Y < b_shape.z ? Y : 0;
|
||||
int b_h = H < b_shape.y ? H : 0;
|
||||
int b_n = N < b_shape.x ? N : 0;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a + b;
|
||||
|
@ -323,18 +331,22 @@ __kernel void BroadcastNHWC4Sub(__read_only image2d_t input_a, __read_only image
|
|||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
|
||||
int Z = get_global_id(2); // N * H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
|
||||
return;
|
||||
}
|
||||
int a_c = X < a_shape.w ? X : a_shape.w - 1;
|
||||
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
|
||||
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
|
||||
int b_c = X < b_shape.w ? X : b_shape.w - 1;
|
||||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
int H = Z % output_shape.y;
|
||||
int N = Z / output_shape.y;
|
||||
int a_c = X < a_shape.w ? X : 0;
|
||||
int a_w = Y < a_shape.z ? Y : 0;
|
||||
int a_h = H < a_shape.y ? H : 0;
|
||||
int a_n = N < a_shape.x ? N : 0;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
|
||||
int b_c = X < b_shape.w ? X : 0;
|
||||
int b_w = Y < b_shape.z ? Y : 0;
|
||||
int b_h = H < b_shape.y ? H : 0;
|
||||
int b_n = N < b_shape.x ? N : 0;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a - b;
|
||||
|
@ -352,18 +364,22 @@ __kernel void BroadcastNHWC4Mul(__read_only image2d_t input_a, __read_only image
|
|||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
|
||||
int Z = get_global_id(2); // N * H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
|
||||
return;
|
||||
}
|
||||
int a_c = X < a_shape.w ? X : a_shape.w - 1;
|
||||
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
|
||||
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
|
||||
int b_c = X < b_shape.w ? X : b_shape.w - 1;
|
||||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
int H = Z % output_shape.y;
|
||||
int N = Z / output_shape.y;
|
||||
int a_c = X < a_shape.w ? X : 0;
|
||||
int a_w = Y < a_shape.z ? Y : 0;
|
||||
int a_h = H < a_shape.y ? H : 0;
|
||||
int a_n = N < a_shape.x ? N : 0;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
|
||||
int b_c = X < b_shape.w ? X : 0;
|
||||
int b_w = Y < b_shape.z ? Y : 0;
|
||||
int b_h = H < b_shape.y ? H : 0;
|
||||
int b_n = N < b_shape.x ? N : 0;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a * b;
|
||||
|
@ -381,18 +397,22 @@ __kernel void BroadcastNHWC4Div(__read_only image2d_t input_a, __read_only image
|
|||
const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {
|
||||
int X = get_global_id(0); // C4
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
|
||||
int Z = get_global_id(2); // N * H
|
||||
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
|
||||
return;
|
||||
}
|
||||
int a_c = X < a_shape.w ? X : a_shape.w - 1;
|
||||
int a_w = Y < a_shape.z ? Y : a_shape.z - 1;
|
||||
int a_h = Z < a_shape.y ? Z : a_shape.y - 1;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h));
|
||||
int b_c = X < b_shape.w ? X : b_shape.w - 1;
|
||||
int b_w = Y < b_shape.z ? Y : b_shape.z - 1;
|
||||
int b_h = Z < b_shape.y ? Z : b_shape.y - 1;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h));
|
||||
int H = Z % output_shape.y;
|
||||
int N = Z / output_shape.y;
|
||||
int a_c = X < a_shape.w ? X : 0;
|
||||
int a_w = Y < a_shape.z ? Y : 0;
|
||||
int a_h = H < a_shape.y ? H : 0;
|
||||
int a_n = N < a_shape.x ? N : 0;
|
||||
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
|
||||
int b_c = X < b_shape.w ? X : 0;
|
||||
int b_w = Y < b_shape.z ? Y : 0;
|
||||
int b_h = H < b_shape.y ? H : 0;
|
||||
int b_n = N < b_shape.x ? N : 0;
|
||||
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
|
||||
FLT4 result;
|
||||
if (broadcastC_flag == 0) {
|
||||
result = a / b;
|
||||
|
|
|
@ -46,10 +46,6 @@ int ArithmeticOpenCLKernel::CheckSpecs() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
|
||||
if (param->broadcasting_ && out_tensors_.front()->DimensionSize(0) > 1) {
|
||||
MS_LOG(ERROR) << "Broadcasting don't support N > 1";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsArithmetic(Type())) {
|
||||
MS_LOG(ERROR) << "UnSupported Operator: " << schema::EnumNamePrimitiveType(Type());
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -112,6 +112,20 @@ TEST_F(TestOpenCL_Arithmetic, BroadcastSub2) {
|
|||
}
|
||||
}
|
||||
|
||||
// BroadcastSub3: Sub with a rank-2 first operand {2, 3} against a rank-4
// second operand {2, 2, 2, 3}; the output shape {2, 2, 2, 3} has a leading
// dimension of 2. The expected values equal input0 (its six values repeated
// over the two leading dimensions of input1) minus input1, element by element.
TEST_F(TestOpenCL_Arithmetic, BroadcastSub3) {
|
||||
std::vector<int> input0_shape = {2, 3};
|
||||
std::vector<int> input1_shape = {2, 2, 2, 3};
|
||||
std::vector<int> output_shape = {2, 2, 2, 3};
|
||||
// Both rows of input0 are {1, 2, 3}.
float input0_data[] = {1, 2, 3, 1, 2, 3};
|
||||
// input1 holds the sequence 1..12 twice (24 values, matching 2*2*2*3).
float input1_data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
||||
// Each group of three is {1,2,3} - {k+1,k+2,k+3} = -k for k = 0, 3, 6, 9, and the
// same 12 values repeat for the second half of the output.
float output_data[] = {0, 0, 0, -3, -3, -3, -6, -6, -6, -9, -9, -9, 0, 0, 0, -3, -3, -3, -6, -6, -6, -9, -9, -9};
|
||||
// Run the same case once with fp16_enable == false and once with true.
for (auto fp16_enable : {false, true}) {
|
||||
auto *param = CreateParameter(schema::PrimitiveType_Sub, input0_shape, input1_shape);
|
||||
// input0 is passed tagged VAR and input1 tagged CONST_TENSOR.
// NOTE(review): presumably these select a runtime input vs. a constant tensor — confirm against TestMain.
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
|
||||
param, fp16_enable);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestOpenCL_Arithmetic, BroadcastFloorMod) {
|
||||
std::vector<int> input0_shape = {1, 1, 3, 4};
|
||||
std::vector<int> input1_shape = {1, 1, 1, 4};
|
||||
|
|
Loading…
Reference in New Issue