forked from mindspore-Ecosystem/mindspore
concat_optimization
This commit is contained in:
parent
859e704644
commit
a7c88c0643
|
@ -480,7 +480,7 @@ gene_clhpp() {
|
|||
do
|
||||
file="$(basename ${file_path})"
|
||||
inc_file=$(echo ${CL_SRC_DIR}/${file} | sed 's/$/.inc/')
|
||||
sed 's/^/\"/;s/$/ \\n\" \\/' ${CL_SRC_DIR}/${file} > ${inc_file}
|
||||
sed 's/\\/\\\\/g;s/\"/\\\"/g;s/^/\"/;s/$/\\n\" \\/' ${CL_SRC_DIR}/${file} > ${inc_file}
|
||||
kernel_name=$(echo ${file} | sed s'/.\{3\}$//')
|
||||
sed -i "1i\static const char *${kernel_name}_source =\"\\n\" \\" ${inc_file}
|
||||
sed -i '$a\;' ${inc_file}
|
||||
|
|
|
@ -1,339 +1,545 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 output_shape,
|
||||
const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
} else {
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
} else {
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
} else {
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
#define CHECK_IDXConcat2input_NHWC4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
FLT4 result;
|
||||
|
||||
__kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
|
||||
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
}
|
||||
}
|
||||
}
|
||||
#define DOConcat2inputaxis1_NHWC4 \
|
||||
if (X < input_shape0.y) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
__kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
|
||||
int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) {
|
||||
return;
|
||||
}
|
||||
int in_postion_x;
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
|
||||
((X - input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
|
||||
(X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
#define DOConcat2inputaxis2_NHWC4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
__kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
|
||||
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
|
||||
return;
|
||||
}
|
||||
int in_postion_x;
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
#define DOConcat2inputaxis3_NHWC4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define CHECK_IDXConcat2input_NC4HW4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) { \
|
||||
return; \
|
||||
} \
|
||||
int in_postion_x; \
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat2inputaxis1_NC4HW4 \
|
||||
if (X < input_shape0.y) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
|
||||
((X - input_shape0.y) % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (X < input_shape0.y + input_shape1.y) {
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
|
||||
((X - input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat2inputaxis2_NC4HW4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat2inputaxis3_NC4HW4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
|
||||
(X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define CHECK_IDXConcat3input_NC4HW4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) { \
|
||||
return; \
|
||||
} \
|
||||
int in_postion_x; \
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat3inputaxis1_NC4HW4 \
|
||||
if (X < input_shape0.y) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < input_shape0.y + input_shape1.y) { \
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
|
||||
((X - input_shape0.y) % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Y < input_shape0.z + input_shape1.z) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
|
||||
|
||||
#define DOConcat3inputaxis2_NC4HW4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Y < input_shape0.z + input_shape1.z) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
|
||||
|
||||
#define DOConcat3inputaxis3_NC4HW4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < input_shape0.w + input_shape1.w) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
|
||||
(X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define CHECK_IDXConcat3input_NHWC4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat3inputaxis1_NHWC4 \
|
||||
if (X < input_shape0.y) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat3inputaxis2_NHWC4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat3inputaxis3_NHWC4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define CHECK_IDXConcat4input_NHWC4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat4inputaxis1_NHWC4 \
|
||||
if (X < input_shape0.y) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat4inputaxis2_NHWC4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat4inputaxis3_NHWC4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define CHECK_IDXConcat4input_NC4HW4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || input_shape3.y == 0 || \
|
||||
output_shape.y == 0) { \
|
||||
return; \
|
||||
} \
|
||||
int in_postion_x; \
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat4inputaxis1_NC4HW4 \
|
||||
if (X < input_shape0.y) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < input_shape0.y + input_shape1.y) { \
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
|
||||
((X - input_shape0.y) % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < input_shape0.y + input_shape1.y + input_shape2.y) { \
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = \
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y + \
|
||||
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y); \
|
||||
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Z < input_shape0.w + input_shape1.w) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
|
||||
(X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat4inputaxis2_NC4HW4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Y < input_shape0.z + input_shape1.z) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
|
||||
} else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y); \
|
||||
result = \
|
||||
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat4inputaxis3_NC4HW4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < input_shape0.w + input_shape1.w) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
|
||||
(X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + \
|
||||
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y); \
|
||||
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Concat4input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __read_only image2d_t input3,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
|
||||
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
|
||||
return;
|
||||
}
|
||||
int in_postion_x;
|
||||
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {}
|
||||
|
||||
#define CHECK_IDXConcat6input_NHWC4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
FLT4 result;
|
||||
|
||||
#define DOConcat6inputaxis1_NHWC4 \
|
||||
if (X < input_shape0.y) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y)) { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y))); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y + input_shape4.y)) { \
|
||||
result = READ_IMAGE( \
|
||||
input4, smp_none, \
|
||||
(int2)((Y)*input_shape4.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input5, smp_none, \
|
||||
(int2)((Y)*input_shape5.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y - \
|
||||
input_shape3.y - input_shape4.y))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat6inputaxis2_NHWC4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z)) { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X))); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z + input_shape4.z)) { \
|
||||
result = READ_IMAGE( \
|
||||
input4, smp_none, \
|
||||
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z) * input_shape4.w + Z, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE( \
|
||||
input5, smp_none, \
|
||||
(int2)( \
|
||||
(Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z - input_shape4.z) * input_shape5.w + Z, \
|
||||
(X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define DOConcat6inputaxis3_NHWC4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) { \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w)) { \
|
||||
result = READ_IMAGE(input3, smp_none, \
|
||||
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X))); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w + input_shape4.w)) { \
|
||||
result = READ_IMAGE( \
|
||||
input4, smp_none, \
|
||||
(int2)((Y)*input_shape4.w + Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w, (X))); \
|
||||
} else { \
|
||||
result = READ_IMAGE(input5, smp_none, \
|
||||
(int2)((Y)*input_shape5.w + Z - input_shape0.w - input_shape1.w - input_shape2.w - \
|
||||
input_shape3.w - input_shape4.w, \
|
||||
(X))); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);
|
||||
|
||||
#define CHECK_IDXConcat6input_NC4HW4 \
|
||||
int X = get_global_id(0); \
|
||||
int Y = get_global_id(1); \
|
||||
int Z = get_global_id(2); \
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \
|
||||
return; \
|
||||
} \
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || input_shape3.y == 0 || \
|
||||
input_shape4.y == 0 || input_shape5.y == 0 || output_shape.y == 0) { \
|
||||
return; \
|
||||
} \
|
||||
int in_postion_x; \
|
||||
FLT4 result; \
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat6inputaxis1_NC4HW4 \
|
||||
if (X < input_shape0.y) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) { \
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + \
|
||||
((X - input_shape0.y) % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) { \
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y)) { \
|
||||
in_postion_x = \
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y + \
|
||||
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y); \
|
||||
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y + input_shape3.y + input_shape4.y)) { \
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y) / input_shape4.y) * \
|
||||
input_shape4.w * input_shape4.y + \
|
||||
Z * input_shape4.y + \
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y) % input_shape4.y); \
|
||||
result = READ_IMAGE(input4, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = \
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y - input_shape4.y) / input_shape5.y) * \
|
||||
input_shape5.w * input_shape5.y + \
|
||||
Z * input_shape5.y + \
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y - input_shape3.y - input_shape4.y) % input_shape5.y); \
|
||||
result = READ_IMAGE(input5, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (X < input_shape0.y + input_shape1.y) {
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
|
||||
((X - input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat6inputaxis2_NC4HW4 \
|
||||
if (Y < input_shape0.z) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x)); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x)); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z)) { \
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y); \
|
||||
result = \
|
||||
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x)); \
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z + input_shape3.z + input_shape4.z)) { \
|
||||
in_postion_x = (X / input_shape4.y) * input_shape4.w * input_shape4.y + Z * input_shape4.y + (X % input_shape4.y); \
|
||||
result = \
|
||||
READ_IMAGE(input4, smp_none, \
|
||||
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = (X / input_shape5.y) * input_shape5.w * input_shape5.y + Z * input_shape5.y + (X % input_shape5.y); \
|
||||
result = READ_IMAGE( \
|
||||
input5, smp_none, \
|
||||
(int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z - input_shape3.z - input_shape4.z), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (X < input_shape0.y + input_shape1.y + input_shape2.y) {
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
|
||||
#define DOConcat6inputaxis3_NC4HW4 \
|
||||
if (Z < input_shape0.w) { \
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; \
|
||||
result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) { \
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y + \
|
||||
(X % input_shape1.y); \
|
||||
result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) { \
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + \
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y); \
|
||||
result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w)) { \
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + \
|
||||
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y); \
|
||||
result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w + input_shape3.w + input_shape4.w)) { \
|
||||
in_postion_x = (X / input_shape4.y) * input_shape4.w * input_shape4.y + \
|
||||
(Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w) * input_shape4.y + \
|
||||
(X % input_shape4.y); \
|
||||
result = READ_IMAGE(input4, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} else { \
|
||||
in_postion_x = \
|
||||
(X / input_shape5.y) * input_shape5.w * input_shape5.y + \
|
||||
(Z - input_shape0.w - input_shape1.w - input_shape2.w - input_shape3.w - input_shape4.w) * input_shape5.y + \
|
||||
(X % input_shape5.y); \
|
||||
result = READ_IMAGE(input5, smp_none, (int2)((Y), in_postion_x)); \
|
||||
} \
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x =
|
||||
((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y +
|
||||
Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y);
|
||||
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Y < input_shape0.z + input_shape1.z) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y);
|
||||
FLT4 result =
|
||||
READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Z < input_shape0.w + input_shape1.w) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
|
||||
(X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y +
|
||||
(Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y);
|
||||
FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
}
|
||||
|
||||
#define CONCAT6(Inputnum, Axis, ToFormat) \
|
||||
__kernel void Concat##Inputnum##Axis##ToFormat( \
|
||||
__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \
|
||||
__read_only image2d_t input3, __read_only image2d_t input4, __read_only image2d_t input5, \
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2, int4 input_shape3, \
|
||||
int4 input_shape4, int4 input_shape5, int4 output_shape, const int axis) { \
|
||||
CHECK_IDXConcat6input##ToFormat; \
|
||||
DOConcat##Inputnum##Axis##ToFormat; \
|
||||
}
|
||||
|
||||
__kernel void Concat4input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __read_only image2d_t input3,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
|
||||
int4 input_shape3, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (X < (input_shape0.y + input_shape1.y)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
} else {
|
||||
FLT4 result3 = READ_IMAGE(input3, smp_none,
|
||||
(int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (Y < (input_shape0.z + input_shape1.z)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
} else {
|
||||
FLT4 result3 = READ_IMAGE(
|
||||
input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
|
||||
} else if (Z < (input_shape0.w + input_shape1.w)) {
|
||||
FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
|
||||
} else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) {
|
||||
FLT4 result2 =
|
||||
READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
|
||||
} else {
|
||||
FLT4 result3 = READ_IMAGE(input3, smp_none,
|
||||
(int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X)));
|
||||
WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
|
||||
#define CONCAT4(Inputnum, Axis, ToFormat) \
|
||||
__kernel void Concat##Inputnum##Axis##ToFormat( \
|
||||
__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \
|
||||
__read_only image2d_t input3, __write_only image2d_t output, int4 input_shape0, int4 input_shape1, \
|
||||
int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) { \
|
||||
CHECK_IDXConcat4input##ToFormat; \
|
||||
DOConcat##Inputnum##Axis##ToFormat; \
|
||||
}
|
||||
|
||||
#define CONCAT3(Inputnum, Axis, ToFormat) \
|
||||
__kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \
|
||||
__read_only image2d_t input2, __write_only image2d_t output, \
|
||||
int4 input_shape0, int4 input_shape1, int4 input_shape2, \
|
||||
int4 output_shape, const int axis) { \
|
||||
CHECK_IDXConcat3input##ToFormat; \
|
||||
DOConcat##Inputnum##Axis##ToFormat; \
|
||||
}
|
||||
|
||||
#define CONCAT2(Inputnum, Axis, ToFormat) \
|
||||
__kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, \
|
||||
int4 output_shape, const int axis) { \
|
||||
CHECK_IDXConcat2input##ToFormat; \
|
||||
DOConcat##Inputnum##Axis##ToFormat; \
|
||||
}
|
||||
|
||||
// nc4hw4 ?
|
||||
CONCAT6(6input, axis1, _NC4HW4)
|
||||
CONCAT6(6input, axis2, _NC4HW4)
|
||||
CONCAT6(6input, axis3, _NC4HW4)
|
||||
CONCAT4(4input, axis1, _NC4HW4)
|
||||
CONCAT4(4input, axis2, _NC4HW4)
|
||||
CONCAT4(4input, axis3, _NC4HW4)
|
||||
CONCAT3(3input, axis1, _NC4HW4)
|
||||
CONCAT3(3input, axis2, _NC4HW4)
|
||||
CONCAT3(3input, axis3, _NC4HW4)
|
||||
CONCAT2(2input, axis1, _NC4HW4)
|
||||
CONCAT2(2input, axis2, _NC4HW4)
|
||||
CONCAT2(2input, axis3, _NC4HW4)
|
||||
|
||||
// nhwc4?
|
||||
CONCAT6(6input, axis1, _NHWC4)
|
||||
CONCAT6(6input, axis2, _NHWC4)
|
||||
CONCAT6(6input, axis3, _NHWC4)
|
||||
CONCAT4(4input, axis1, _NHWC4)
|
||||
CONCAT4(4input, axis2, _NHWC4)
|
||||
CONCAT4(4input, axis3, _NHWC4)
|
||||
CONCAT3(3input, axis1, _NHWC4)
|
||||
CONCAT3(3input, axis2, _NHWC4)
|
||||
CONCAT3(3input, axis3, _NHWC4)
|
||||
CONCAT2(2input, axis1, _NHWC4)
|
||||
CONCAT2(2input, axis2, _NHWC4)
|
||||
CONCAT2(2input, axis3, _NHWC4)
|
||||
|
|
|
@ -93,13 +93,19 @@ int ConcatOpenCLKernel::Init() {
|
|||
|
||||
std::string kernel_name = "Concat";
|
||||
if (in_tensors_.size() == 2) {
|
||||
kernel_name += "2input";
|
||||
kernel_name += "2inputaxis";
|
||||
kernel_name += std::to_string(param->axis_);
|
||||
} else if (in_tensors_.size() == 3) {
|
||||
kernel_name += "3input";
|
||||
kernel_name += "3inputaxis";
|
||||
kernel_name += std::to_string(param->axis_);
|
||||
} else if (in_tensors_.size() == 4) {
|
||||
kernel_name += "4input";
|
||||
kernel_name += "4inputaxis";
|
||||
kernel_name += std::to_string(param->axis_);
|
||||
} else if (in_tensors_.size() == 6) {
|
||||
kernel_name += "6inputaxis";
|
||||
kernel_name += std::to_string(param->axis_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << " input must be 2 3 or 4";
|
||||
MS_LOG(ERROR) << " input must be 2 , 3 , 4 or 6";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_format == schema::Format_NC4HW4) {
|
||||
|
@ -107,6 +113,7 @@ int ConcatOpenCLKernel::Init() {
|
|||
} else if (in_format == schema::Format_NHWC4) {
|
||||
kernel_name += "_NHWC4";
|
||||
}
|
||||
MS_LOG(DEBUG) << "kernel_name=: " << kernel_name;
|
||||
std::set<std::string> build_options;
|
||||
std::string source = concat_source;
|
||||
std::string program_name = "Concat";
|
||||
|
@ -118,16 +125,36 @@ int ConcatOpenCLKernel::Init() {
|
|||
|
||||
int ConcatOpenCLKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int ConcatOpenCLKernel::GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape) {
|
||||
std::vector<int> temp_sum = {0, 0, 0, 0};
|
||||
int ConcatOpenCLKernel::IntegraShapeToXYZ() {
|
||||
auto in_format = op_format_;
|
||||
if (out_tensors_[0]->shape().size() > 4 || out_tensors_[0]->shape().size() <= 0) {
|
||||
MS_LOG(ERROR) << "in_tensors_.shape() must between 0~4";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (in_format == schema::Format_NHWC4 || in_format == schema::Format_NC4HW4) {
|
||||
for (int i = 0; i < in_tensors_.size(); ++i) {
|
||||
cl_int4 temp_cl;
|
||||
auto temp = in_tensors_[i]->shape();
|
||||
temp_cl = {temp[0], temp[1], temp[2], UP_DIV(temp[3], C4NUM)};
|
||||
XYZShape.push_back(temp_cl);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < in_tensors_.size(); ++i) {
|
||||
auto temp = in_tensors_[i]->shape();
|
||||
for (int j = 0; j < temp.size(); ++j) {
|
||||
in_shape->push_back(temp[j]);
|
||||
temp_sum.at(j) += temp[j];
|
||||
sum_shape->push_back(temp_sum.at(j));
|
||||
for (int j = temp.size(); j < C4NUM; ++j) {
|
||||
temp.push_back(1);
|
||||
}
|
||||
cl_int4 temp_cl = {temp[0], temp[1], temp[2], UP_DIV(temp[3], C4NUM)};
|
||||
XYZShape.push_back(temp_cl);
|
||||
}
|
||||
auto temp = out_tensors_[0]->shape();
|
||||
for (int i = out_tensors_[0]->shape().size(); i < C4NUM; ++i) {
|
||||
temp.push_back(1);
|
||||
}
|
||||
}
|
||||
shape_nhwc = {out_tensors_[0]->shape()[0] * out_tensors_[0]->shape()[1], out_tensors_[0]->shape()[2],
|
||||
UP_DIV(out_tensors_[0]->shape()[3], C4NUM)};
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -151,70 +178,31 @@ int ConcatOpenCLKernel::Run() {
|
|||
if (param->axis_ == 0) {
|
||||
return RunAxis0();
|
||||
}
|
||||
|
||||
auto input1_shape = in_tensors_[0]->shape();
|
||||
auto input2_shape = in_tensors_[1]->shape();
|
||||
auto output_shape = out_tensors_[0]->shape();
|
||||
|
||||
cl_int4 input_shape1_ = {input1_shape[0], input1_shape[1], input1_shape[2], UP_DIV(input1_shape[3], C4NUM)};
|
||||
cl_int4 input_shape2_ = {input2_shape[0], input2_shape[1], input2_shape[2], UP_DIV(input2_shape[3], C4NUM)};
|
||||
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], UP_DIV(output_shape[3], C4NUM)};
|
||||
|
||||
uint32_t OH = output_shape[0] * output_shape[1]; // N*H
|
||||
uint32_t OW = output_shape[2];
|
||||
uint32_t OC = UP_DIV(output_shape[3], C4NUM);
|
||||
|
||||
IntegraShapeToXYZ();
|
||||
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
|
||||
std::vector<size_t> local = {1, 1, 1}; // init local
|
||||
std::vector<size_t> global = {OH, OW, OC};
|
||||
std::vector<size_t> local = {1, 1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(shape_nhwc.s[0]), static_cast<size_t>(shape_nhwc.s[1]),
|
||||
static_cast<size_t>(shape_nhwc.s[2])};
|
||||
ConcatGetWorkGroup(global, &local, max_global[0]);
|
||||
GetSumShape(&sum_shape, &in_shape);
|
||||
|
||||
if (in_tensors_.size() == 2 || in_tensors_.size() == 3 || in_tensors_.size() == 4 || in_tensors_.size() == 6) {
|
||||
int arg_cn = 0;
|
||||
if (in_tensors_.size() == 2) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
|
||||
for (int i = 0; i < in_tensors_.size(); ++i) {
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c());
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
|
||||
} else if (in_tensors_.size() == 3) {
|
||||
auto input3_shape = in_tensors_[2]->shape();
|
||||
cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)};
|
||||
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape3_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
|
||||
} else if (in_tensors_.size() == 4) {
|
||||
auto input3_shape = in_tensors_[2]->shape();
|
||||
auto input4_shape = in_tensors_[3]->shape();
|
||||
cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)};
|
||||
cl_int4 input_shape4_ = {input4_shape[0], input4_shape[1], input4_shape[2], UP_DIV(input4_shape[3], C4NUM)};
|
||||
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape1_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape2_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape3_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape4_);
|
||||
for (int i = 0; i < XYZShape.size(); ++i) {
|
||||
cl_int4 temp = {XYZShape[i].s[0], XYZShape[i].s[1], XYZShape[i].s[2], XYZShape[i].s[3]};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, temp);
|
||||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape_);
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, param->axis_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << " input sizes must 2 or 3 or 4";
|
||||
MS_LOG(ERROR) << "unsupported input size :" << in_tensors_.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -41,12 +41,12 @@ class ConcatOpenCLKernel : public OpenCLKernel {
|
|||
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
|
||||
int GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape);
|
||||
int IntegraShapeToXYZ();
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
std::vector<int> sum_shape;
|
||||
std::vector<int> in_shape;
|
||||
std::vector<cl_int3> XYZShape;
|
||||
cl_int4 shape_nhwc;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -146,7 +146,7 @@ TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) {
|
|||
delete sub_graph;
|
||||
}
|
||||
|
||||
TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) {
|
||||
TEST_F(TestConcatOpenCLfp16, ConcatFp16_4input_dim4_axis1) {
|
||||
MS_LOG(INFO) << " begin test ";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->SetFp16Enable(true);
|
||||
|
@ -276,7 +276,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) {
|
|||
delete sub_graph;
|
||||
}
|
||||
|
||||
TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
|
||||
TEST_F(TestConcatOpenCLfp32, ConcatFp32_3input_dim4_axis1) {
|
||||
MS_LOG(INFO) << " begin test ";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
|
@ -396,4 +396,146 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
|
|||
}
|
||||
delete sub_graph;
|
||||
}
|
||||
|
||||
TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) {
|
||||
MS_LOG(INFO) << " begin test ";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->SetFp16Enable(true);
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
// get the input from .bin
|
||||
size_t input1_size, input2_size, input3_size, input4_size, input5_size, input6_size, output_size;
|
||||
std::string input1Ppath = "./test_data/concatfp16_input1.bin";
|
||||
std::string input2Ppath = "./test_data/concatfp16_input2.bin";
|
||||
std::string input3Ppath = "./test_data/concatfp16_input3.bin";
|
||||
std::string input4Ppath = "./test_data/concatfp16_input4.bin";
|
||||
std::string input5Ppath = "./test_data/concatfp16_input5.bin";
|
||||
std::string input6Ppath = "./test_data/concatfp16_input6.bin";
|
||||
std::string correctOutputPath = "./test_data/concatfp16_output.bin";
|
||||
auto input_data1 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
|
||||
auto input_data2 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
|
||||
auto input_data3 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
|
||||
auto input_data4 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size));
|
||||
auto input_data5 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input5Ppath.c_str(), &input5_size));
|
||||
auto input_data6 = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input6Ppath.c_str(), &input6_size));
|
||||
auto correctOutput =
|
||||
reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
|
||||
|
||||
MS_LOG(INFO) << " init tensors ";
|
||||
constexpr int INPUT_NUM = 6;
|
||||
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
|
||||
std::vector<int>{1, 1200, 3, 4}, std::vector<int>{1, 600, 3, 4}, std::vector<int>{1, 150, 3, 4},
|
||||
std::vector<int>{1, 50, 3, 4}, std::vector<int>{1, 30, 3, 4}, std::vector<int>{1, 4, 3, 4}};
|
||||
std::vector<int> output_shape = {1, 2034, 3, 4};
|
||||
auto data_type = kNumberTypeFloat16;
|
||||
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
|
||||
std::vector<lite::Tensor *> inputs;
|
||||
for (auto &shape : input_shapes) {
|
||||
auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
|
||||
inputs.push_back(input_temp);
|
||||
if (input_temp == nullptr) {
|
||||
MS_LOG(INFO) << " new input_tensor failed ";
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(INFO) << " new output_tensor failed ";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
return;
|
||||
}
|
||||
std::vector<lite::Tensor *> outputs{output_tensor};
|
||||
MS_LOG(INFO) << " input_shapes size =: " << input_shapes.size();
|
||||
|
||||
MS_LOG(INFO) << " initialize tensors ";
|
||||
auto param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(INFO) << " new ConcatParameter failed ";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
return;
|
||||
}
|
||||
param->axis_ = 1;
|
||||
auto *concat_kernel =
|
||||
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (concat_kernel == nullptr) {
|
||||
MS_LOG(INFO) << " new kernel::ConcatOpenCLKernel failed ";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete param;
|
||||
return;
|
||||
}
|
||||
concat_kernel->SetFormatType(schema::Format_NC4HW4);
|
||||
concat_kernel->Init();
|
||||
// to do allocate memory for inputs and outputs
|
||||
for (auto &input_tensor : inputs) {
|
||||
input_tensor->MallocData(allocator);
|
||||
}
|
||||
MS_LOG(INFO) << " initialize sub_graph ";
|
||||
std::vector<kernel::LiteKernel *> kernels{concat_kernel};
|
||||
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
if (sub_graph == nullptr) {
|
||||
MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete param;
|
||||
delete concat_kernel;
|
||||
return;
|
||||
}
|
||||
sub_graph->Init();
|
||||
MS_LOG(INFO) << " initialize input data ";
|
||||
if (inputs.size() == 2) {
|
||||
memcpy(inputs[0]->data_c(), input_data1, input1_size);
|
||||
memcpy(inputs[1]->data_c(), input_data2, input2_size);
|
||||
} else if (inputs.size() == 3) {
|
||||
memcpy(inputs[0]->data_c(), input_data1, input1_size);
|
||||
memcpy(inputs[1]->data_c(), input_data2, input2_size);
|
||||
memcpy(inputs[2]->data_c(), input_data3, input3_size);
|
||||
} else if (inputs.size() == 4) {
|
||||
memcpy(inputs[0]->data_c(), input_data1, input1_size);
|
||||
memcpy(inputs[1]->data_c(), input_data2, input2_size);
|
||||
memcpy(inputs[2]->data_c(), input_data3, input3_size);
|
||||
memcpy(inputs[3]->data_c(), input_data4, input4_size);
|
||||
} else if (inputs.size() == 6) {
|
||||
memcpy(inputs[0]->data_c(), input_data1, input1_size);
|
||||
memcpy(inputs[1]->data_c(), input_data2, input2_size);
|
||||
memcpy(inputs[2]->data_c(), input_data3, input3_size);
|
||||
memcpy(inputs[3]->data_c(), input_data4, input4_size);
|
||||
memcpy(inputs[4]->data_c(), input_data5, input5_size);
|
||||
memcpy(inputs[5]->data_c(), input_data6, input6_size);
|
||||
} else {
|
||||
MS_LOG(ERROR) << " input size must be 2 or 3 or 4";
|
||||
}
|
||||
|
||||
std::cout << "==================output data================" << std::endl;
|
||||
sub_graph->Run();
|
||||
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData());
|
||||
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
for (auto tensor : inputs) {
|
||||
tensor->SetData(nullptr);
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
tensor->SetData(nullptr);
|
||||
delete tensor;
|
||||
}
|
||||
delete sub_graph;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue