forked from mindspore-Ecosystem/mindspore
transpose output set to buffer
This commit is contained in:
parent
19c800a758
commit
4d3be49a66
|
@ -3,7 +3,7 @@
|
||||||
#define READ_IMAGE read_imagef
|
#define READ_IMAGE read_imagef
|
||||||
#define WRITE_IMAGE write_imagef
|
#define WRITE_IMAGE write_imagef
|
||||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||||
__kernel void transpose(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) {
|
__kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) {
|
||||||
int X = get_global_id(0);
|
int X = get_global_id(0);
|
||||||
int Y = get_global_id(1);
|
int Y = get_global_id(1);
|
||||||
if (X >= HW.y || Y >= C.y) {
|
if (X >= HW.y || Y >= C.y) {
|
||||||
|
@ -43,3 +43,44 @@ __kernel void transpose(__read_only image2d_t src_data, __write_only image2d_t d
|
||||||
WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 2), result[2]);
|
WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 2), result[2]);
|
||||||
WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 3), result[3]);
|
WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 3), result[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transposes 4x4 component tiles read from an image into a linear FLT4 buffer.
// Each work-item (X, Y) reads the four texels (Y, 4*X) .. (Y, 4*X + 3) of
// src_data, regroups them so that result[k] holds component k of each of the
// four texels, and stores result[k] at dst_data[(4*Y + k) * HW.y + X].
// Work-items with X >= HW.y or Y >= C.y return without reading or writing.
// Only HW.y and C.y are read here; HW.x and C.x are not used by this kernel.
// NOTE(review): HW.y and C.y look like the H*W and channel extents counted in
// groups of four elements -- confirm against the host-side argument setup.
__kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_data, int2 HW, int2 C) {
|
||||||
|
int X = get_global_id(0);
|
||||||
|
int Y = get_global_id(1);
|
||||||
|
// Skip work-items that fall outside the processed range.
if (X >= HW.y || Y >= C.y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
FLT4 result[4];
|
||||||
|
// Zero-fill the tile; every component is overwritten further down.
result[0] = (FLT4)(0.0f);
|
||||||
|
result[1] = (FLT4)(0.0f);
|
||||||
|
result[2] = (FLT4)(0.0f);
|
||||||
|
result[3] = (FLT4)(0.0f);
|
||||||
|
// Fetch four texels at image x-coordinate Y and y-coordinates 4*X .. 4*X + 3.
FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X));
|
||||||
|
FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1));
|
||||||
|
FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2));
|
||||||
|
FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3));
|
||||||
|
// result[0] gathers the .x component of x0..x3.
result[0].x = x0.x;
|
||||||
|
result[0].y = x1.x;
|
||||||
|
result[0].z = x2.x;
|
||||||
|
result[0].w = x3.x;
|
||||||
|
|
||||||
|
// result[1] gathers the .y component of x0..x3.
result[1].x = x0.y;
|
||||||
|
result[1].y = x1.y;
|
||||||
|
result[1].z = x2.y;
|
||||||
|
result[1].w = x3.y;
|
||||||
|
|
||||||
|
// result[2] gathers the .z component of x0..x3.
result[2].x = x0.z;
|
||||||
|
result[2].y = x1.z;
|
||||||
|
result[2].z = x2.z;
|
||||||
|
result[2].w = x3.z;
|
||||||
|
|
||||||
|
// result[3] gathers the .w component of x0..x3.
result[3].x = x0.w;
|
||||||
|
result[3].y = x1.w;
|
||||||
|
result[3].z = x2.w;
|
||||||
|
result[3].w = x3.w;
|
||||||
|
|
||||||
|
// Store the tile: row (4*Y + k) of the buffer, HW.y FLT4 elements per row, column X.
dst_data[4 * Y * HW.y + X] = result[0];
|
||||||
|
dst_data[(4 * Y + 1) * HW.y + X] = result[1];
|
||||||
|
dst_data[(4 * Y + 2) * HW.y + X] = result[2];
|
||||||
|
dst_data[(4 * Y + 3) * HW.y + X] = result[3];
|
||||||
|
}
|
||||||
|
|
|
@ -36,7 +36,11 @@ namespace mindspore::kernel {
|
||||||
int TransposeOpenCLKernel::Init() {
|
int TransposeOpenCLKernel::Init() {
|
||||||
std::string kernel_name = "transpose";
|
std::string kernel_name = "transpose";
|
||||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||||
|
if (!is_image_out_) {
|
||||||
|
kernel_name += "_BUF";
|
||||||
|
} else {
|
||||||
|
kernel_name += "_IMG";
|
||||||
|
}
|
||||||
#ifdef PROGRAM_WITH_IL
|
#ifdef PROGRAM_WITH_IL
|
||||||
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
|
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
|
||||||
#else
|
#else
|
||||||
|
@ -60,8 +64,12 @@ int TransposeOpenCLKernel::Init() {
|
||||||
MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
|
MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
ori_format_ = out_tensors_[0]->GetFormat();
|
// Transpose::InferShape only copies the input format to the output (output->SetFormat(input->GetFormat())), so the format is set to NCHW explicitly here.
|
||||||
|
ori_format_ = schema::Format_NCHW;
|
||||||
out_tensors_[0]->SetFormat(schema::Format_NCHW);
|
out_tensors_[0]->SetFormat(schema::Format_NCHW);
|
||||||
|
if (!is_image_out_) {
|
||||||
|
out_mem_type_ = OpenCLMemType::BUF;
|
||||||
|
}
|
||||||
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,7 @@ class TransposeOpenCLKernel : public OpenCLKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
cl::Kernel kernel_;
|
cl::Kernel kernel_;
|
||||||
|
bool is_image_out_ = false;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
|
|
@ -34,13 +34,13 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
|
||||||
out_parameters->clear();
|
out_parameters->clear();
|
||||||
out_convert_ops->clear();
|
out_convert_ops->clear();
|
||||||
for (size_t i = 0; i < in_tensors.size(); ++i) {
|
for (size_t i = 0; i < in_tensors.size(); ++i) {
|
||||||
OpenCLKernel* cur_opencl_op = reinterpret_cast<OpenCLKernel*>(in_kernels[i]);
|
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(in_kernels[i]);
|
||||||
schema::Format ori_format = cur_opencl_op->GetOriFormat();
|
schema::Format ori_format = cur_opencl_op->GetOriFormat();
|
||||||
if (mem_type == cur_opencl_op->GetMemType() && in_tensors[i]->GetFormat() == ori_format) {
|
if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() &&
|
||||||
|
in_tensors[i]->GetFormat() == ori_format) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto dst_format =
|
auto dst_format = (mem_type == OpenCLMemType::IMG) ? in_kernels[i]->out_tensors()[0]->GetFormat() : ori_format;
|
||||||
(mem_type == OpenCLMemType::IMG) ? in_kernels[i]->out_tensors()[0]->GetFormat() : ori_format;
|
|
||||||
auto src_format =
|
auto src_format =
|
||||||
(mem_type == OpenCLMemType::IMG) ? in_tensors[i]->GetFormat() : in_kernels[i]->out_tensors()[0]->GetFormat();
|
(mem_type == OpenCLMemType::IMG) ? in_tensors[i]->GetFormat() : in_kernels[i]->out_tensors()[0]->GetFormat();
|
||||||
lite::tensor::Tensor *new_tensor = new (std::nothrow) lite::tensor::Tensor();
|
lite::tensor::Tensor *new_tensor = new (std::nothrow) lite::tensor::Tensor();
|
||||||
|
|
|
@ -125,6 +125,7 @@ int RunSubGraphOpenCLKernel(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
delete sub_graph;
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,6 +181,7 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
||||||
|
|
||||||
delete input_tensor;
|
delete input_tensor;
|
||||||
delete output_tensor;
|
delete output_tensor;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -119,6 +119,11 @@ TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) {
|
||||||
}
|
}
|
||||||
printf("test all close OK!\n");
|
printf("test all close OK!\n");
|
||||||
lite::CompareOutputData(output_data, expect, 4);
|
lite::CompareOutputData(output_data, expect, 4);
|
||||||
|
delete tensor_in;
|
||||||
|
delete tensor_out;
|
||||||
|
delete pooling_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -175,6 +175,14 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) {
|
||||||
sub_graph->Run();
|
sub_graph->Run();
|
||||||
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
|
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
|
||||||
CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001);
|
CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001);
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete concat_kernel;
|
||||||
|
delete sub_graph;
|
||||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -108,5 +108,14 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
|
||||||
CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
|
CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
|
||||||
|
|
||||||
MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
|
MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete arith_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -120,6 +120,14 @@ void TEST_MAIN(ConvParameter *param, schema::Format data_format, const std::stri
|
||||||
MyCompareOutput(output_tensor, expect_file);
|
MyCompareOutput(output_tensor, expect_file);
|
||||||
// lite::CompareOutput(reinterpret_cast<float *>(output_tensor->Data()), expect_file);
|
// lite::CompareOutput(reinterpret_cast<float *>(output_tensor->Data()), expect_file);
|
||||||
|
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete conv_kernel;
|
||||||
|
delete sub_graph;
|
||||||
mindspore::lite::opencl::OpenCLRuntime::DeleteInstance();
|
mindspore::lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,12 +75,15 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
|
||||||
// compare
|
// compare
|
||||||
CompareOutputData(output_data, correct_data, co, 0.00001);
|
CompareOutputData(output_data, correct_data, co, 0.00001);
|
||||||
|
|
||||||
delete input_data;
|
|
||||||
delete weight_data;
|
|
||||||
delete tensor_x;
|
|
||||||
delete tensor_w;
|
|
||||||
delete tensor_out;
|
|
||||||
delete correct_data;
|
|
||||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete arith_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -92,6 +92,15 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {
|
||||||
MS_LOG(INFO) << "compare result";
|
MS_LOG(INFO) << "compare result";
|
||||||
std::cout << "compare result" << std::endl;
|
std::cout << "compare result" << std::endl;
|
||||||
CompareOutput(output_tensor, expect_file);
|
CompareOutput(output_tensor, expect_file);
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete pooling_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -77,6 +77,15 @@ void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, st
|
||||||
MS_LOG(INFO) << "compare result";
|
MS_LOG(INFO) << "compare result";
|
||||||
std::cout << "compare result" << std::endl;
|
std::cout << "compare result" << std::endl;
|
||||||
CompareOutput(output_tensor, expect_file);
|
CompareOutput(output_tensor, expect_file);
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestSoftmaxOpenCL, Softmax_1) {
|
TEST_F(TestSoftmaxOpenCL, Softmax_1) {
|
||||||
|
|
|
@ -75,5 +75,14 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
|
||||||
// compare
|
// compare
|
||||||
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
||||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete arith_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -75,5 +75,14 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
|
||||||
// compare
|
// compare
|
||||||
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
||||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||||
|
for (auto tensor : inputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
for (auto tensor : outputs) {
|
||||||
|
delete tensor;
|
||||||
|
}
|
||||||
|
delete arith_kernel;
|
||||||
|
delete pGraph;
|
||||||
|
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue