!6778 [MS][LITE][Develop] optimization GPU ops concat

Merge pull request !6778 from pengyongrong/optimizationForConcat
This commit is contained in:
mindspore-ci-bot 2020-09-30 11:47:05 +08:00 committed by Gitee
commit 640c1eb19c
5 changed files with 734 additions and 398 deletions

View File

@ -510,9 +510,9 @@ gene_clhpp() {
do
file="$(basename ${file_path})"
inc_file=$(echo ${CL_SRC_DIR}/${file} | sed 's/$/.inc/')
sed 's/^/\"/;s/$/ \\n\" \\/' ${CL_SRC_DIR}/${file} > ${inc_file}
sed 's/\\/\\\\/g;s/\"/\\\"/g;s/^/\"/;s/$/\\n\" \\/' ${CL_SRC_DIR}/${file} > ${inc_file}
kernel_name=$(echo ${file} | sed s'/.\{3\}$//')
sed -i "1i\static const char *${kernel_name}_source =\"\\n\" \\" ${inc_file}
sed -i "1i\static const char *${kernel_name}_source =\"\\n\" \\" ${inc_file}
sed -i '$a\;' ${inc_file}
done
}

View File

@ -1,339 +1,545 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 output_shape,
const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
} else {
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
} else {
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
}
} else {
if (Z < input_shape0.w) {
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
} else {
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
}
}
}
#define CHECK_IDXConcat2input_NHWC4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
FLT4 result;
__kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (X < (input_shape0.y + input_shape1.y)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Y < (input_shape0.z + input_shape1.z)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
}
} else {
if (Z < input_shape0.w) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Z < (input_shape0.w + input_shape1.w)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
}
}
}
#define DOConcat2inputaxis1_NHWC4 \
if (X < input_shape0.y) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
__kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
int4 output_shape, const int axis) {
int X = get_global_id(0); // H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
#define DOConcat2inputaxis2_NHWC4 \
if (Y < input_shape0.z) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else { \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
__kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y) {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
#define DOConcat2inputaxis3_NHWC4 \
if (Z < input_shape0.w) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define CHECK_IDXConcat2input_NC4HW4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) { \
return; \
} \
int in_postion_x; \
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
FLT4 result;
#define DOConcat2inputaxis1_NC4HW4 \
if (X < input_shape0.y) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
((X - input_shape0.y) % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat2inputaxis2_NC4HW4 \
if (Y < input_shape0.z) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat2inputaxis3_NC4HW4 \
if (Z < input_shape0.w) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
(X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define CHECK_IDXConcat3input_NC4HW4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) { \
return; \
} \
int in_postion_x; \
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
FLT4 result;
#define DOConcat3inputaxis1_NC4HW4 \
if (X < input_shape0.y) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < input_shape0.y + input_shape1.y) { \
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
((X - input_shape0.y) % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat3inputaxis2_NC4HW4 \
if (Y < input_shape0.z) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Y < input_shape0.z + input_shape1.z) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat3inputaxis3_NC4HW4 \
if (Z < input_shape0.w) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < input_shape0.w + input_shape1.w) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
(X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define CHECK_IDXConcat3input_NHWC4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
FLT4 result;
#define DOConcat3inputaxis1_NHWC4 \
if (X < input_shape0.y) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (X < (input_shape0.y + input_shape1.y)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
} else { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat3inputaxis2_NHWC4 \
if (Y < input_shape0.z) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
} else { \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat3inputaxis3_NHWC4 \
if (Z < input_shape0.w) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
} else { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define CHECK_IDXConcat4input_NHWC4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
FLT4 result;
#define DOConcat4inputaxis1_NHWC4 \
if (X < input_shape0.y) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (X < (input_shape0.y + input_shape1.y)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
} else { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat4inputaxis2_NHWC4 \
if (Y < input_shape0.z) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
} else { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat4inputaxis3_NHWC4 \
if (Z < input_shape0.w) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
} else { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define CHECK_IDXConcat4input_NC4HW4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || input_shape3.y == 0 || \
output_shape.y == 0) { \
return; \
} \
int in_postion_x; \
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
FLT4 result;
#define DOConcat4inputaxis1_NC4HW4 \
if (X < input_shape0.y) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < input_shape0.y + input_shape1.y) { \
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
((X - input_shape0.y) % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < input_shape0.y + input_shape1.y + input_shape2.y) { \
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = \
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y + \
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y); \
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat4inputaxis2_NC4HW4 \
if (Y < input_shape0.z) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Y < input_shape0.z + input_shape1.z) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
} else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y); \
result = \
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat4inputaxis3_NC4HW4 \
if (Z < input_shape0.w) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < input_shape0.w + input_shape1.w) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
(X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + \
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y); \
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
__kernel void Concat4input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __read_only image2d_t input3,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
}
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
return;
}
int in_postion_x;
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
if (axis == 1) {
if (X < input_shape0.y) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y) {
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
((X - input_shape0.y) % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (X < input_shape0.y + input_shape1.y + input_shape2.y) {
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x =
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y +
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y);
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y);
FLT4 result =
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
} else {
if (Z < input_shape0.w) {
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w) {
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
(X % input_shape1.y);
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) {
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
} else {
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y +
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y);
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
}
}
}
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {}
__kernel void Concat4input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
__read_only image2d_t input2, __read_only image2d_t input3,
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
int4 input_shape3, int4 output_shape, const int axis) {
int X = get_global_id(0); // N*H
int Y = get_global_id(1); // W
int Z = get_global_id(2); // c/4
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
return;
#define CHECK_IDXConcat6input_NHWC4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
FLT4 result;
#define DOConcat6inputaxis1_NHWC4 \
if (X < input_shape0.y) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (X < (input_shape0.y + input_shape1.y)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y)) { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y))); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y + input_shape4.y)) { \
result = READ_IMAGE( \
input4, smp_none, \
(int2)((Y)*input_shape4.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y))); \
} else { \
result = READ_IMAGE(input5, smp_none, \
(int2)((Y)*input_shape5.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y - \
input_shape3.y - input_shape4.y))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat6inputaxis2_NHWC4 \
if (Y < input_shape0.z) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z)) { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X))); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z + input_shape4.z)) { \
result = READ_IMAGE( \
input4, smp_none, \
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z) * input_shape4.w + Z, (X))); \
} else { \
result = READ_IMAGE( \
input5, smp_none, \
(int2)( \
(Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z - input_shape4.z) * input_shape5.w + Z, \
(X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define DOConcat6inputaxis3_NHWC4 \
if (Z < input_shape0.w) { \
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w)) { \
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w)) { \
result = READ_IMAGE(input3, smp_none, \
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X))); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w + input_shape4.w)) { \
result = READ_IMAGE( \
input4, smp_none, \
(int2)((Y)*input_shape4.w + Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w, (X))); \
} else { \
result = READ_IMAGE(input5, smp_none, \
(int2)((Y)*input_shape5.w + Z - input_shape0.w - input_shape1.w - input_shape2.w - \
input_shape3.w - input_shape4.w, \
(X))); \
} \
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
#define CHECK_IDXConcat6input_NC4HW4 \
int X = get_global_id(0); \
int Y = get_global_id(1); \
int Z = get_global_id(2); \
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
return; \
} \
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || input_shape3.y == 0 || \
input_shape4.y == 0 || input_shape5.y == 0 || output_shape.y == 0) { \
return; \
} \
int in_postion_x; \
FLT4 result; \
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
#define DOConcat6inputaxis1_NC4HW4 \
if (X < input_shape0.y) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < (input_shape0.y + input_shape1.y)) { \
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
((X - input_shape0.y) % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y)) { \
in_postion_x = \
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y + \
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y); \
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y + input_shape4.y)) { \
in_postion_x = ((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y) / input_shape4.y) * \
input_shape4.w * input_shape4.y + \
Z * input_shape4.y + \
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y) % input_shape4.y); \
result = READ_IMAGE(input4, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = \
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y - input_shape4.y) / input_shape5.y) * \
input_shape5.w * input_shape5.y + \
Z * input_shape5.y + \
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y - input_shape4.y) % input_shape5.y); \
result = READ_IMAGE(input5, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat6inputaxis2_NC4HW4 \
if (Y < input_shape0.z) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Y < (input_shape0.z + input_shape1.z)) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z)) { \
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y); \
result = \
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x)); \
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z + input_shape4.z)) { \
in_postion_x = (X / input_shape4.y) * input_shape4.w * input_shape4.y + Z * input_shape4.y + (X % input_shape4.y); \
result = \
READ_IMAGE(input4, smp_none, \
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z), in_postion_x)); \
} else { \
in_postion_x = (X / input_shape5.y) * input_shape5.w * input_shape5.y + Z * input_shape5.y + (X % input_shape5.y); \
result = READ_IMAGE( \
input5, smp_none, \
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z - input_shape4.z), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define DOConcat6inputaxis3_NC4HW4 \
if (Z < input_shape0.w) { \
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < (input_shape0.w + input_shape1.w)) { \
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
(X % input_shape1.y); \
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w)) { \
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + \
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y); \
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w + input_shape4.w)) { \
in_postion_x = (X / input_shape4.y) * input_shape4.w * input_shape4.y + \
(Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w) * input_shape4.y + \
(X % input_shape4.y); \
result = READ_IMAGE(input4, smp_none, (int2)((Y), in_postion_x)); \
} else { \
in_postion_x = \
(X / input_shape5.y) * input_shape5.w * input_shape5.y + \
(Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w - input_shape4.w) * input_shape5.y + \
(X % input_shape5.y); \
result = READ_IMAGE(input5, smp_none, (int2)((Y), in_postion_x)); \
} \
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
#define CONCAT6(Inputnum, Axis, ToFormat) \
__kernel void Concat##Inputnum##Axis##ToFormat( \
__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \
__read_only image2d_t input3, __read_only image2d_t input4, __read_only image2d_t input5, \
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2, int4 input_shape3, \
int4 input_shape4, int4 input_shape5, int4 output_shape, const int axis) { \
CHECK_IDXConcat6input##ToFormat; \
DOConcat##Inputnum##Axis##ToFormat; \
}
if (axis == 1) {
if (X < input_shape0.y) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (X < (input_shape0.y + input_shape1.y)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(input3, smp_none,
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
} else if (axis == 2) {
if (Y < input_shape0.z) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Y < (input_shape0.z + input_shape1.z)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(
input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
} else {
if (Z < input_shape0.w) {
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
} else if (Z < (input_shape0.w + input_shape1.w)) {
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) {
FLT4 result2 =
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
} else {
FLT4 result3 = READ_IMAGE(input3, smp_none,
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X)));
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
}
#define CONCAT4(Inputnum, Axis, ToFormat) \
__kernel void Concat##Inputnum##Axis##ToFormat( \
__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \
__read_only image2d_t input3, __write_only image2d_t output, int4 input_shape0, int4 input_shape1, \
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) { \
CHECK_IDXConcat4input##ToFormat; \
DOConcat##Inputnum##Axis##ToFormat; \
}
}
#define CONCAT3(Inputnum, Axis, ToFormat) \
__kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \
__read_only image2d_t input2, __write_only image2d_t output, \
int4 input_shape0, int4 input_shape1, int4 input_shape2, \
int4 output_shape, const int axis) { \
CHECK_IDXConcat3input##ToFormat; \
DOConcat##Inputnum##Axis##ToFormat; \
}
#define CONCAT2(Inputnum, Axis, ToFormat) \
__kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, \
int4 output_shape, const int axis) { \
CHECK_IDXConcat2input##ToFormat; \
DOConcat##Inputnum##Axis##ToFormat; \
}
// nc4hw4 ?
CONCAT6(6input, axis1, _NC4HW4)
CONCAT6(6input, axis2, _NC4HW4)
CONCAT6(6input, axis3, _NC4HW4)
CONCAT4(4input, axis1, _NC4HW4)
CONCAT4(4input, axis2, _NC4HW4)
CONCAT4(4input, axis3, _NC4HW4)
CONCAT3(3input, axis1, _NC4HW4)
CONCAT3(3input, axis2, _NC4HW4)
CONCAT3(3input, axis3, _NC4HW4)
CONCAT2(2input, axis1, _NC4HW4)
CONCAT2(2input, axis2, _NC4HW4)
CONCAT2(2input, axis3, _NC4HW4)
// nhwc4?
CONCAT6(6input, axis1, _NHWC4)
CONCAT6(6input, axis2, _NHWC4)
CONCAT6(6input, axis3, _NHWC4)
CONCAT4(4input, axis1, _NHWC4)
CONCAT4(4input, axis2, _NHWC4)
CONCAT4(4input, axis3, _NHWC4)
CONCAT3(3input, axis1, _NHWC4)
CONCAT3(3input, axis2, _NHWC4)
CONCAT3(3input, axis3, _NHWC4)
CONCAT2(2input, axis1, _NHWC4)
CONCAT2(2input, axis2, _NHWC4)
CONCAT2(2input, axis3, _NHWC4)

View File

@ -93,13 +93,19 @@ int ConcatOpenCLKernel::Init() {
std::string kernel_name = "Concat";
if (in_tensors_.size() == 2) {
kernel_name += "2input";
kernel_name += "2inputaxis";
kernel_name += std::to_string(param->axis_);
} else if (in_tensors_.size() == 3) {
kernel_name += "3input";
kernel_name += "3inputaxis";
kernel_name += std::to_string(param->axis_);
} else if (in_tensors_.size() == 4) {
kernel_name += "4input";
kernel_name += "4inputaxis";
kernel_name += std::to_string(param->axis_);
} else if (in_tensors_.size() == 6) {
kernel_name += "6inputaxis";
kernel_name += std::to_string(param->axis_);
} else {
MS_LOG(ERROR) << " input must be 2 3 or 4";
MS_LOG(ERROR) << " input must be 2 , 3 , 4 or 6";
return RET_ERROR;
}
if (in_format == schema::Format_NC4HW4) {
@ -107,6 +113,7 @@ int ConcatOpenCLKernel::Init() {
} else if (in_format == schema::Format_NHWC4) {
kernel_name += "_NHWC4";
}
MS_LOG(DEBUG) << "kernel_name=: " << kernel_name;
std::set<std::string> build_options;
std::string source = concat_source;
std::string program_name = "Concat";
@ -118,16 +125,36 @@ int ConcatOpenCLKernel::Init() {
int ConcatOpenCLKernel::ReSize() { return RET_OK; }
int ConcatOpenCLKernel::GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape) {
std::vector<int> temp_sum = {0, 0, 0, 0};
for (int i = 0; i < in_tensors_.size(); ++i) {
auto temp = in_tensors_[i]->shape();
for (int j = 0; j < temp.size(); ++j) {
in_shape->push_back(temp[j]);
temp_sum.at(j) += temp[j];
sum_shape->push_back(temp_sum.at(j));
int ConcatOpenCLKernel::IntegraShapeToXYZ() {
auto in_format = op_format_;
if (out_tensors_[0]->shape().size() > 4 || out_tensors_[0]->shape().size() <= 0) {
MS_LOG(ERROR) << "in_tensors_.shape() must between 0~4";
return RET_ERROR;
}
if (in_format == schema::Format_NHWC4 || in_format == schema::Format_NC4HW4) {
for (int i = 0; i < in_tensors_.size(); ++i) {
cl_int4 temp_cl;
auto temp = in_tensors_[i]->shape();
temp_cl = {temp[0], temp[1], temp[2], UP_DIV(temp[3], C4NUM)};
XYZShape.push_back(temp_cl);
}
} else {
for (int i = 0; i < in_tensors_.size(); ++i) {
auto temp = in_tensors_[i]->shape();
for (int j = temp.size(); j < C4NUM; ++j) {
temp.push_back(1);
}
cl_int4 temp_cl = {temp[0], temp[1], temp[2], UP_DIV(temp[3], C4NUM)};
XYZShape.push_back(temp_cl);
}
auto temp = out_tensors_[0]->shape();
for (int i = out_tensors_[0]->shape().size(); i < C4NUM; ++i) {
temp.push_back(1);
}
}
shape_nhwc = {out_tensors_[0]->shape()[0] * out_tensors_[0]->shape()[1], out_tensors_[0]->shape()[2],
UP_DIV(out_tensors_[0]->shape()[3], C4NUM)};
return RET_OK;
}
@ -151,70 +178,31 @@ int ConcatOpenCLKernel::Run() {
if (param->axis_ == 0) {
return RunAxis0();
}
auto input1_shape = in_tensors_[0]->shape();
auto input2_shape = in_tensors_[1]->shape();
auto output_shape = out_tensors_[0]->shape();
cl_int4 input_shape1_ = {input1_shape[0], input1_shape[1], input1_shape[2], UP_DIV(input1_shape[3], C4NUM)};
cl_int4 input_shape2_ = {input2_shape[0], input2_shape[1], input2_shape[2], UP_DIV(input2_shape[3], C4NUM)};
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], UP_DIV(output_shape[3], C4NUM)};
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
uint32_t OW = output_shape[2];
uint32_t OC = UP_DIV(output_shape[3], C4NUM);
IntegraShapeToXYZ();
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1}; // init local
std::vector<size_t> global = {OH, OW, OC};
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {static_cast<size_t>(shape_nhwc.s[0]), static_cast<size_t>(shape_nhwc.s[1]),
static_cast<size_t>(shape_nhwc.s[2])};
ConcatGetWorkGroup(global, &local, max_global[0]);
GetSumShape(&sum_shape, &in_shape);
int arg_cn = 0;
if (in_tensors_.size() == 2) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
if (in_tensors_.size() == 2 || in_tensors_.size() == 3 || in_tensors_.size() == 4 || in_tensors_.size() == 6) {
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c());
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
} else if (in_tensors_.size() == 3) {
auto input3_shape = in_tensors_[2]->shape();
cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape3_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
} else if (in_tensors_.size() == 4) {
auto input3_shape = in_tensors_[2]->shape();
auto input4_shape = in_tensors_[3]->shape();
cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)};
cl_int4 input_shape4_ = {input4_shape[0], input4_shape[1], input4_shape[2], UP_DIV(input4_shape[3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape3_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape4_);
for (int i = 0; i < XYZShape.size(); ++i) {
cl_int4 temp = {XYZShape[i].s[0], XYZShape[i].s[1], XYZShape[i].s[2], XYZShape[i].s[3]};
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, temp);
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
} else {
MS_LOG(ERROR) << " input sizes must 2 or 3 or 4";
MS_LOG(ERROR) << "unsupported input size :" << in_tensors_.size();
return RET_ERROR;
}
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}

View File

@ -41,12 +41,12 @@ class ConcatOpenCLKernel : public OpenCLKernel {
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape);
int IntegraShapeToXYZ();
private:
cl::Kernel kernel_;
std::vector<int> sum_shape;
std::vector<int> in_shape;
std::vector<cl_int3> XYZShape;
cl_int4 shape_nhwc;
};
} // namespace mindspore::kernel

View File

@ -145,7 +145,7 @@ TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) {
delete sub_graph;
}
TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) {
TEST_F(TestConcatOpenCLfp16, ConcatFp16_4input_dim4_axis1) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->SetFp16Enable(true);
@ -274,7 +274,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) {
delete sub_graph;
}
TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
TEST_F(TestConcatOpenCLfp32, ConcatFp32_3input_dim4_axis1) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->Init();
@ -393,4 +393,146 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
}
delete sub_graph;
}
TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->SetFp16Enable(true);
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
// get the input from .bin
size_t input1_size, input2_size, input3_size, input4_size, input5_size, input6_size, output_size;
std::string input1Ppath = "./test_data/concatfp16_input1.bin";
std::string input2Ppath = "./test_data/concatfp16_input2.bin";
std::string input3Ppath = "./test_data/concatfp16_input3.bin";
std::string input4Ppath = "./test_data/concatfp16_input4.bin";
std::string input5Ppath = "./test_data/concatfp16_input5.bin";
std::string input6Ppath = "./test_data/concatfp16_input6.bin";
std::string correctOutputPath = "./test_data/concatfp16_output.bin";
auto input_data1 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto input_data3 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
auto input_data4 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size));
auto input_data5 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input5Ppath.c_str(), &input5_size));
auto input_data6 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input6Ppath.c_str(), &input6_size));
auto correctOutput =
reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
MS_LOG(INFO) << " init tensors ";
constexpr int INPUT_NUM = 6;
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
std::vector<int>{1, 1200, 3, 4}, std::vector<int>{1, 600, 3, 4}, std::vector<int>{1, 150, 3, 4},
std::vector<int>{1, 50, 3, 4}, std::vector<int>{1, 30, 3, 4}, std::vector<int>{1, 4, 3, 4}};
std::vector<int> output_shape = {1, 2034, 3, 4};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
std::vector<lite::Tensor *> inputs;
for (auto &shape : input_shapes) {
auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
inputs.push_back(input_temp);
if (input_temp == nullptr) {
MS_LOG(INFO) << " new input_tensor failed ";
return;
}
}
auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(INFO) << " new output_tensor failed ";
for (auto tensor : inputs) {
delete tensor;
}
return;
}
std::vector<lite::Tensor *> outputs{output_tensor};
MS_LOG(INFO) << " input_shapes size =: " << input_shapes.size();
MS_LOG(INFO) << " initialize tensors ";
auto param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ConcatParameter failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
return;
}
param->axis_ = 1;
auto *concat_kernel =
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (concat_kernel == nullptr) {
MS_LOG(INFO) << " new kernel::ConcatOpenCLKernel failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
return;
}
concat_kernel->SetFormatType(schema::Format_NC4HW4);
concat_kernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{concat_kernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
delete concat_kernel;
return;
}
sub_graph->Init();
MS_LOG(INFO) << " initialize input data ";
if (inputs.size() == 2) {
memcpy(inputs[0]->data_c(), input_data1, input1_size);
memcpy(inputs[1]->data_c(), input_data2, input2_size);
} else if (inputs.size() == 3) {
memcpy(inputs[0]->data_c(), input_data1, input1_size);
memcpy(inputs[1]->data_c(), input_data2, input2_size);
memcpy(inputs[2]->data_c(), input_data3, input3_size);
} else if (inputs.size() == 4) {
memcpy(inputs[0]->data_c(), input_data1, input1_size);
memcpy(inputs[1]->data_c(), input_data2, input2_size);
memcpy(inputs[2]->data_c(), input_data3, input3_size);
memcpy(inputs[3]->data_c(), input_data4, input4_size);
} else if (inputs.size() == 6) {
memcpy(inputs[0]->data_c(), input_data1, input1_size);
memcpy(inputs[1]->data_c(), input_data2, input2_size);
memcpy(inputs[2]->data_c(), input_data3, input3_size);
memcpy(inputs[3]->data_c(), input_data4, input4_size);
memcpy(inputs[4]->data_c(), input_data5, input5_size);
memcpy(inputs[5]->data_c(), input_data6, input6_size);
} else {
MS_LOG(ERROR) << " input size must be 2 or 3 or 4";
}
std::cout << "==================output data================" << std::endl;
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
}
} // namespace mindspore