op_format_toNC4HW4
This commit is contained in:
parent
7786adc3aa
commit
b485dc06ec
|
@ -1,8 +1,9 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
|
||||
int4 input_shape0, int4 input_shape1, int4 output_shape, const int axis) {
|
||||
__kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 output_shape,
|
||||
const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
|
@ -44,9 +45,9 @@ __kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1,
|
|||
}
|
||||
}
|
||||
|
||||
__kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
|
||||
int4 output_shape, const int axis) {
|
||||
__kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
|
||||
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
|
@ -105,3 +106,144 @@ __kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t i
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__write_only image2d_t output, int4 input_shape0, int4 input_shape1,
|
||||
int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || output_shape.y == 0) {
|
||||
return;
|
||||
}
|
||||
int in_postion_x;
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
|
||||
if (axis == 0) {
|
||||
if (X < (input_shape0.x * input_shape0.y)) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
|
||||
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
|
||||
((X - input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
|
||||
(X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
|
||||
__read_only image2d_t input2, __write_only image2d_t output, int4 input_shape0,
|
||||
int4 input_shape1, int4 input_shape2, int4 output_shape, const int axis) {
|
||||
int X = get_global_id(0); // N*H
|
||||
int Y = get_global_id(1); // W
|
||||
int Z = get_global_id(2); // c/4
|
||||
if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
|
||||
return;
|
||||
}
|
||||
if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || output_shape.y == 0) {
|
||||
return;
|
||||
}
|
||||
int in_postion_x;
|
||||
int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
|
||||
if (axis == 0) {
|
||||
if (X < (input_shape0.x * input_shape0.y)) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) {
|
||||
in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y +
|
||||
Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) / input_shape2.y) *
|
||||
input_shape2.w * input_shape2.y +
|
||||
Z * input_shape2.y +
|
||||
(X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) % input_shape2.y;
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 1) {
|
||||
if (X < input_shape0.y) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (X < input_shape0.y + input_shape1.y) {
|
||||
in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
|
||||
((X - input_shape0.y) % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else if (axis == 2) {
|
||||
if (Y < input_shape0.z) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Y < input_shape0.z + input_shape1.z) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
} else {
|
||||
if (Z < input_shape0.w) {
|
||||
in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
|
||||
FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else if (Z < input_shape0.w + input_shape1.w) {
|
||||
in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
|
||||
(X % input_shape1.y);
|
||||
FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
} else {
|
||||
in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
|
||||
(Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
|
||||
FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
|
||||
WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,8 +35,8 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
|
|||
im_dst_x = out_tensors_[0]->Width() * CO4;
|
||||
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
|
||||
} else {
|
||||
im_dst_y = out_tensors_[0]->Height() * CO4;
|
||||
im_dst_x = out_tensors_[0]->Width() * out_tensors_[0]->Batch();
|
||||
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
|
||||
im_dst_x = out_tensors_[0]->Width();
|
||||
}
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
|
@ -61,30 +61,37 @@ int ConcatOpenCLKernel::Init() {
|
|||
MS_LOG(ERROR) << " only support axis >= 0 and axis <= 3 ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (in_tensors_.size() == 2) {
|
||||
std::set<std::string> build_options;
|
||||
std::string source = concat_source;
|
||||
std::string program_name = "Concat";
|
||||
std::string kernel_name = "Concat";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
}
|
||||
|
||||
if (in_tensors_.size() == 3) {
|
||||
std::set<std::string> build_options;
|
||||
std::string source = concat_source;
|
||||
std::string program_name = "Concat3input";
|
||||
std::string kernel_name = "Concat3input";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
auto in_format = op_format_;
|
||||
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
|
||||
MS_LOG(ERROR) << "input format(" << in_format << ") "
|
||||
<< "format not support!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
in_ori_format_ = in_tensors_[0]->GetFormat();
|
||||
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
|
||||
in_tensors_[0]->SetFormat(op_format_);
|
||||
out_ori_format_ = out_tensors_[0]->GetFormat();
|
||||
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
|
||||
out_tensors_[0]->SetFormat(op_format_);
|
||||
|
||||
std::string kernel_name = "Concat";
|
||||
if (in_tensors_.size() == 2) {
|
||||
kernel_name += "2input";
|
||||
} else if (in_tensors_.size() == 3) {
|
||||
kernel_name += "3input";
|
||||
} else {
|
||||
MS_LOG(ERROR) << " input must be 2 or 3";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_format == schema::Format_NC4HW4) {
|
||||
kernel_name += "_NC4HW4";
|
||||
} else if (in_format == schema::Format_NHWC4) {
|
||||
kernel_name += "_NHWC4";
|
||||
}
|
||||
std::set<std::string> build_options;
|
||||
std::string source = concat_source;
|
||||
std::string program_name = "Concat";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -49,6 +49,7 @@ int DepthwiseConv2dOpenCLKernel::Init() {
|
|||
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
|
||||
MS_LOG(ERROR) << "input format(" << in_format << ") "
|
||||
<< "format not support!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
in_tensors_[0]->SetFormat(in_format);
|
||||
out_tensors_[0]->SetFormat(in_format);
|
||||
|
@ -103,6 +104,7 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
|
|||
PackNCHWToNC4HW4<float, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
|
||||
|
@ -112,6 +114,7 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
|
|||
PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -168,10 +168,10 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
|
|||
|
||||
// get the input from .bin
|
||||
size_t input1_size, input2_size, input3_size, output_size;
|
||||
std::string input1Ppath = "./test_data/concat_input1.bin";
|
||||
std::string input2Ppath = "./test_data/concat_input2.bin";
|
||||
std::string input3Ppath = "./test_data/concat_input3.bin";
|
||||
std::string correctOutputPath = "./test_data/concat_output.bin";
|
||||
std::string input1Ppath = "./test_data/concatfp32_input1.bin";
|
||||
std::string input2Ppath = "./test_data/concatfp32_input2.bin";
|
||||
std::string input3Ppath = "./test_data/concatfp32_input3.bin";
|
||||
std::string correctOutputPath = "./test_data/concatfp32_output.bin";
|
||||
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
|
||||
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
|
||||
auto input_data3 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
|
||||
|
@ -180,8 +180,8 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
|
|||
MS_LOG(INFO) << " init tensors ";
|
||||
constexpr int INPUT_NUM = 3;
|
||||
std::array<std::vector<int>, INPUT_NUM> input_shapes = {
|
||||
std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}, std::vector<int>{1, 16, 256, 80}};
|
||||
std::vector<int> output_shape = {1, 48, 256, 80};
|
||||
std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}, std::vector<int>{1, 2, 4, 8}};
|
||||
std::vector<int> output_shape = {3, 2, 4, 8};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
std::vector<lite::tensor::Tensor *> inputs;
|
||||
|
@ -217,7 +217,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
|
|||
}
|
||||
return;
|
||||
}
|
||||
param->axis_ = 1;
|
||||
param->axis_ = 0;
|
||||
auto *concat_kernel =
|
||||
new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
if (concat_kernel == nullptr) {
|
||||
|
|
Loading…
Reference in New Issue