From c29e9596dd5de3ef1446572e56c416c211fb5ef8 Mon Sep 17 00:00:00 2001
From: liuzhongkai
Date: Mon, 31 Aug 2020 04:22:56 -0700
Subject: [PATCH] opencl: support fp16 in activation kernels

---
 .../runtime/kernel/opencl/cl/activation.cl    |  38 ++---
 .../kernel/opencl/cl/depthwise_conv2d.cl      |   4 +-
 .../src/runtime/kernel/opencl/cl/prelu.cl     |  16 +-
 .../kernel/opencl/kernel/activation.cc        |  31 +++-
 .../runtime/kernel/opencl/kernel/activation.h |   4 +
 .../runtime/kernel/opencl/kernel/biasadd.cc   |   2 +-
 .../kernel/opencl/kernel/depthwise_conv2d.cc  |  10 +-
 .../src/runtime/kernel/opencl/kernel/prelu.cc |   2 +-
 .../runtime/kernel/opencl/activation_tests.cc | 155 ++++++++++++------
 9 files changed, 169 insertions(+), 93 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/activation.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/activation.cl
index 20287b25f7..26b65e626f 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/activation.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/activation.cl
@@ -1,22 +1,22 @@
-#pragma OPENCL EXTENSION cl_arm_printf : enable
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #define SLICES 4
 #define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
 #define MIN(X, Y) (X < Y ? X : Y)
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
-__kernel void ReluScalar(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
-                         const float alpha) {
+__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
+                        __global FLT *alpha) {
   int C = input_shape.w;     // channel size
   int Y = get_global_id(0);  // height id
   int X = get_global_id(1);  // weight id
   for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
     FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y));  // NHWC4: H WC
     FLT4 tmp;
-    tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha;
-    tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha;
-    tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha;
-    tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha;
+    tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * alpha[0];
+    tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * alpha[0];
+    tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * alpha[0];
+    tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * alpha[0];
     WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp);  // NHWC4: H WC
   }
 }
@@ -28,10 +28,10 @@ __kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, c
   for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
     FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y));  // NHWC4: H WC
     FLT4 tmp;
-    tmp.x = in_c4.x >= 0 ? in_c4.x : 0;
-    tmp.y = in_c4.y >= 0 ? in_c4.y : 0;
-    tmp.z = in_c4.z >= 0 ? in_c4.z : 0;
-    tmp.w = in_c4.w >= 0 ? in_c4.w : 0;
+    tmp.x = in_c4.x > 0.0f ? in_c4.x : 0.0f;
+    tmp.y = in_c4.y > 0.0f ? in_c4.y : 0.0f;
+    tmp.z = in_c4.z > 0.0f ? in_c4.z : 0.0f;
+    tmp.w = in_c4.w > 0.0f ? in_c4.w : 0.0f;
     WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp);  // NHWC4: H WC
   }
 }
@@ -43,10 +43,10 @@ __kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output,
   for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
     FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y));  // NHWC4: H WC
     FLT4 tmp;
-    tmp.x = in_c4.x >= 0 ? MIN(in_c4.x, 6) : 0;
-    tmp.y = in_c4.y >= 0 ? MIN(in_c4.y, 6) : 0;
-    tmp.z = in_c4.z >= 0 ? MIN(in_c4.z, 6) : 0;
-    tmp.w = in_c4.w >= 0 ? MIN(in_c4.w, 6) : 0;
+    tmp.x = in_c4.x > 0.0f ? MIN(in_c4.x, 6.0f) : 0.0f;
+    tmp.y = in_c4.y > 0.0f ? MIN(in_c4.y, 6.0f) : 0.0f;
+    tmp.z = in_c4.z > 0.0f ? MIN(in_c4.z, 6.0f) : 0.0f;
+    tmp.w = in_c4.w > 0.0f ? MIN(in_c4.w, 6.0f) : 0.0f;
     WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp);  // NHWC4: H WC
   }
 }
@@ -58,10 +58,10 @@ __kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output
   for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
     FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y));  // NHWC4: H WC
     FLT4 tmp;
-    tmp.x = 1 / (1 + exp(-in_c4.x));
-    tmp.y = 1 / (1 + exp(-in_c4.y));
-    tmp.z = 1 / (1 + exp(-in_c4.z));
-    tmp.w = 1 / (1 + exp(-in_c4.w));
+    tmp.x = 1.0f / (1.0f + exp(-in_c4.x));
+    tmp.y = 1.0f / (1.0f + exp(-in_c4.y));
+    tmp.z = 1.0f / (1.0f + exp(-in_c4.z));
+    tmp.w = 1.0f / (1.0f + exp(-in_c4.w));
     WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp);  // NHWC4: H WC
   }
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl
index 0c1b83d444..17d4bd890a 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl
@@ -84,7 +84,7 @@ __kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__read_only image2d_t src_data, __gl
       bool outside_x = x_c < 0 || x_c >= src_size.x;
       if (!outside_x && !outside_y) {
         FLT4 flt_p = filter[fx_c];
-        FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z, (y_c * src_size.x + x_c) * src_size.z));
+        FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
         r += TO_FLT4(src_p * flt_p);
       }
     }
@@ -92,7 +92,7 @@ __kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__read_only image2d_t src_data, __gl
   FLT4 bias_p = bias[Z];
   FLT4 res = TO_FLT4(r) + bias_p;
   res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
-  WRITE_IMAGE(dst_data, (int2)(Z, (Y * dst_size.x + X) * dst_size.z), res);
+  WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
 }
 __kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *src_data, __global FLT4 *filter, __global FLT4 *bias,
                                          __global FLT4 *dst_data, int2 kernel_size, int2 stride,
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl
index 65166c588c..608c232caa 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl
@@ -14,16 +14,16 @@ __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output,
     FLT4 tmp;
     if (dim == 1) {
       FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(0, 0));
-      tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
-      tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.x;
-      tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.x;
-      tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.x;
+      tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
+      tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.x;
+      tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.x;
+      tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.x;
     } else {
       FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(num, 0));
-      tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
-      tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.y;
-      tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.z;
-      tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.w;
+      tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
+      tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.y;
+      tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.z;
+      tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.w;
     }
     WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp);  // NHWC4: H WC
   }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc
index b40cd5ea7c..57fa9e31d7 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.cc
@@ -38,15 +38,33 @@ using mindspore::schema::PrimitiveType_Activation;
 
 namespace mindspore::kernel {
 
+void ActivationOpenClKernel::InitBuffer() {
+  auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
+  alpha_buff_ = allocator->Malloc(fp_size);
+  alpha_buff_ = allocator->MapBuffer(alpha_buff_, CL_MAP_WRITE, nullptr, true);
+  memset(alpha_buff_, 0x00, fp_size);
+  if (enable_fp16_) {
+    auto fp16 = (float16_t)alpha_;
+    memcpy(alpha_buff_, &fp16, fp_size);
+  } else {
+    memcpy(alpha_buff_, &alpha_, fp_size);
+  }
+  allocator->UnmapBuffer(alpha_buff_);
+}
+
 int ActivationOpenClKernel::Init() {
   in_size_ = in_tensors_[0]->shape().size();
   out_size_ = out_tensors_[0]->shape().size();
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  enable_fp16_ = ocl_runtime->GetFp16Enable();
+  fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
   if (in_size_ != 2 && in_size_ != 4) {
     MS_LOG(ERROR) << "Activate fun only support dim=4 or 2, but your dim=" << in_size_;
     return RET_ERROR;
   }
+  InitBuffer();
   std::map<int, std::vector<std::string>> Program_Kernel{
-    {ActivationType_LEAKY_RELU, std::vector<std::string>{"LEAKY_RELU", "ReluScalar"}},
+    {ActivationType_LEAKY_RELU, std::vector<std::string>{"LEAKY_RELU", "LeakyRelu"}},
     {ActivationType_RELU, std::vector<std::string>{"RELU", "Relu"}},
     {ActivationType_SIGMOID, std::vector<std::string>{"SIGMOID", "Sigmoid"}},
     {ActivationType_RELU6, std::vector<std::string>{"RELU6", "Relu6"}}};
@@ -57,7 +75,6 @@ int ActivationOpenClKernel::Init() {
 
   std::string source = activation_source;
   std::set<std::string> build_options;
-  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->LoadSource(Program_Kernel[type_][0], source);
   ocl_runtime->BuildKernel(kernel_, Program_Kernel[type_][0], Program_Kernel[type_][1], build_options);
 
@@ -87,7 +104,7 @@ int ActivationOpenClKernel::Run() {
   ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
   ocl_runtime->SetKernelArg(kernel_, arg_idx++, img2d_shape);
   if (type_ == ActivationType_LEAKY_RELU) {
-    ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_);
+    ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_buff_, lite::opencl::MemType::BUF);
   }
   std::vector<size_t> local = {1, 1};
   std::vector<size_t> global = {static_cast<size_t>(img2d_shape.s[1]), static_cast<size_t>(img2d_shape.s[2])};
@@ -114,12 +131,10 @@ cl_int4 ActivationOpenClKernel::GetImg2dShape() {
 
 int ActivationOpenClKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
   cl_int4 img_shape = GetImg2dShape();
-#ifdef ENABLE_FP16
-  size_t img_dtype = CL_HALF_FLOAT;
-#else
   size_t img_dtype = CL_FLOAT;
-#endif
-
+  if (enable_fp16_) {
+    img_dtype = CL_HALF_FLOAT;
+  }
   img_size->clear();
   img_size->push_back(img_shape.s[2] * UP_DIV(img_shape.s[3], C4NUM));
   img_size->push_back(img_shape.s[1]);
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
index d31d744a63..039a4a419c 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
@@ -39,13 +39,17 @@ class ActivationOpenClKernel : public OpenCLKernel {
   int Run() override;
   int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
   cl_int4 GetImg2dShape();
+  void InitBuffer();
 
  private:
   cl::Kernel kernel_;
   int type_;
   float alpha_;
+  void *alpha_buff_;
   int in_size_;
   int out_size_;
+  size_t fp_size;
+  bool enable_fp16_{false};
 };
 
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc
index cd6c880f66..89153e5aef 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc
@@ -56,7 +56,7 @@ int BiasAddOpenCLKernel::Init() {
   out_size_ = out_tensors_[0]->shape().size();
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   enable_fp16_ = ocl_runtime->GetFp16Enable();
-  fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
+  fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
   if (in_size_ != 4 && in_size_ != 2) {
     MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_;
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index 1ed841409c..1077394db3 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -44,13 +44,17 @@ int DepthwiseConv2dOpenCLKernel::Init() {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   std::string kernel_name = "DepthwiseConv2d";
   auto in_format = in_tensors_[0]->GetFormat();
-  in_ori_format_ = in_format;
+  in_ori_format_ = in_tensors_[0]->GetFormat();
   out_ori_format_ = out_tensors_[0]->GetFormat();
-  out_tensors_[0]->SetFormat(in_format);
+  in_format = (in_format == schema::Format_NHWC)
+                ? schema::Format_NHWC4
+                : ((in_format == schema::Format_NCHW) ? schema::Format_NC4HW4 : in_format);
   if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
     MS_LOG(ERROR) << "input format(" << in_format << ") "
                   << "format not support!";
   }
+  in_tensors_[0]->SetFormat(in_format);
+  out_tensors_[0]->SetFormat(in_format);
   if (out_mem_type_ == OpenCLMemType::BUF) {
     kernel_name += "_BUF";
   } else {
@@ -182,7 +186,7 @@ int DepthwiseConv2dOpenCLKernel::Run() {
   GetLocalSize(0, global, &local);
 
   std::map<ActType, std::pair<float, float>> relu_clips{
-    {ActType_No, {FLT_MIN, FLT_MAX}}, {ActType_Relu, {0.0, FLT_MAX}}, {ActType_Relu6, {0, 6.0}}};
+    {ActType_No, {-FLT_MAX, FLT_MAX}}, {ActType_Relu, {0.0, FLT_MAX}}, {ActType_Relu6, {0, 6.0}}};
   cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_};
   cl_int2 stride = {parameter->stride_h_, parameter->stride_w_};
   cl_int2 padding = {-parameter->pad_u_, -parameter->pad_l_};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
index 8b09d42773..c8385fcb2a 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
@@ -68,7 +68,7 @@ int PReluOpenCLKernel::Init() {
   std::string kernel_name = "PRelu";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   enable_fp16_ = ocl_runtime->GetFp16Enable();
-  fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
+  fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
   InitBuffer();
   ocl_runtime->LoadSource(program_name, source);
   ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
index 1298ec65ef..13608a5e2f 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
@@ -46,16 +46,17 @@ void LoadActivationData(void *dst, size_t dst_size, const std::string &file_path
   }
 }
 
+template <typename T>
 void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
-  auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
+  auto *output_data = reinterpret_cast<T *>(output_tensor->Data());
   size_t output_size = output_tensor->Size();
-  auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
+  auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
   constexpr float atol = 0.0002;
   for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
     if (std::fabs(output_data[i] - expect_data[i]) > atol) {
-      printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
-      printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
-      printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
+      printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
+      printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
+      printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]);
       return;
     }
   }
@@ -64,8 +65,10 @@ void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard
   printf("compare success!\n\n\n");
 }
 
-void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
-  auto input_data = reinterpret_cast<float *>(in_data->Data());
+template <typename T>
+void printf_tensor(const std::string &str, mindspore::lite::tensor::Tensor *in_data) {
+  MS_LOG(INFO) << str;
+  auto input_data = reinterpret_cast<T *>(in_data->Data());
   for (int i = 0; i < in_data->ElementsNum(); ++i) {
     printf("%f ", input_data[i]);
   }
@@ -73,24 +76,29 @@ void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
   MS_LOG(INFO) << "Print tensor done";
 }
 
-TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
+TEST_F(TestActivationOpenCL, ReluFp_dim4) {
   std::string in_file = "/data/local/tmp/in_data.bin";
   std::string out_file = "/data/local/tmp/relu.bin";
   MS_LOG(INFO) << "Relu Begin test!";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
-
+  auto data_type = kNumberTypeFloat16;
+  ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
+  bool enable_fp16 = ocl_runtime->GetFp16Enable();
   MS_LOG(INFO) << "Init tensors.";
   std::vector<int> input_shape = {1, 9};
-  auto data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  if (input_shape.size() == 2) {
+    format = schema::Format_NC;
+  }
   auto tensor_type = schema::NodeType_ValueNode;
-  auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
+  auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
   if (input_tensor == nullptr) {
     MS_LOG(ERROR) << "new input tensor error!";
     return;
   }
-  auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
+  auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
   if (output_tensor == nullptr) {
     MS_LOG(ERROR) << "new output tensor error!";
     delete input_tensor;
@@ -99,10 +107,12 @@ TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
   std::vector<lite::tensor::Tensor *> inputs{input_tensor};
   std::vector<lite::tensor::Tensor *> outputs{output_tensor};
   inputs[0]->MallocData(allocator);
-  MS_LOG(INFO) << "Initialize input data";
   LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
-  MS_LOG(INFO) << "==================input data================";
-  printf_tensor(inputs[0]);
+  if (enable_fp16) {
+    printf_tensor<float16_t>("ReluFp16:--input data---", inputs[0]);
+  } else {
+    printf_tensor<float>("ReluFp32:--input data---", inputs[0]);
+  }
 
   auto *param = new (std::nothrow) ActivationParameter();
   if (param == nullptr) {
@@ -164,35 +174,44 @@ TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
     MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
     return;
   }
-
-  MS_LOG(INFO) << "==================output data================";
-  printf_tensor(outputs[0]);
-  CompareRes(output_tensor, out_file);
+  if (enable_fp16) {
+    printf_tensor<float16_t>("ReluFp16--output data---", outputs[0]);
+    CompareRes<float16_t>(output_tensor, out_file);
+  } else {
+    printf_tensor<float>("ReluFp32--output data--", outputs[0]);
+    CompareRes<float>(output_tensor, out_file);
+  }
   delete kernel;
   delete param;
   delete input_tensor;
   delete output_tensor;
   delete sub_graph;
+  lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
-TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) {
+TEST_F(TestActivationOpenCL, Relu6Fp_dim4) {
   std::string in_file = "/data/local/tmp/in_data.bin";
   std::string out_file = "/data/local/tmp/relu6.bin";
   MS_LOG(INFO) << "Relu6 Begin test!";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  auto data_type = kNumberTypeFloat32;
+  ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
+  bool enable_fp16 = ocl_runtime->GetFp16Enable();
   ocl_runtime->Init();
-  auto allocator = ocl_runtime->GetAllocator();
 
   MS_LOG(INFO) << "Init tensors.";
   std::vector<int> input_shape = {1, 9};
-  auto data_type = kNumberTypeFloat32;
+  schema::Format format = schema::Format_NHWC;
+  if (input_shape.size() == 2) {
+    format = schema::Format_NC;
+  }
   auto tensor_type = schema::NodeType_ValueNode;
-  auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
+  auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
   if (input_tensor == nullptr) {
     MS_LOG(ERROR) << "new input tensor error!";
     return;
   }
-  auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
+  auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
   if (output_tensor == nullptr) {
     MS_LOG(ERROR) << "new output tensor error!";
     delete input_tensor;
@@ -200,11 +219,15 @@ TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) {
   }
   std::vector<lite::tensor::Tensor *> inputs{input_tensor};
   std::vector<lite::tensor::Tensor *> outputs{output_tensor};
+  auto allocator = ocl_runtime->GetAllocator();
   inputs[0]->MallocData(allocator);
   MS_LOG(INFO) << "Initialize input data";
   LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
-  MS_LOG(INFO) << "==================input data================";
-  printf_tensor(inputs[0]);
+  if (enable_fp16) {
+    printf_tensor<float16_t>("Relu6:FP16--input data--", inputs[0]);
+  } else {
+    printf_tensor<float>("Relu6:FP32--input data--", inputs[0]);
inputs[0]); + } auto *param = new (std::nothrow) ActivationParameter(); if (param == nullptr) { @@ -267,34 +290,44 @@ TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) { return; } - MS_LOG(INFO) << "==================output data================"; - printf_tensor(outputs[0]); - CompareRes(output_tensor, out_file); + if (enable_fp16) { + printf_tensor("Relu6:FP16--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } else { + printf_tensor("Relu6:FP32--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } delete kernel; delete param; delete input_tensor; delete output_tensor; delete sub_graph; + lite::opencl::OpenCLRuntime::DeleteInstance(); } -TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) { +TEST_F(TestActivationOpenCL, SigmoidFp_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/sigmoid.bin"; MS_LOG(INFO) << "Sigmoid Begin test!"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); - auto allocator = ocl_runtime->GetAllocator(); + auto data_type = kNumberTypeFloat16; + ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); + bool enable_fp16 = ocl_runtime->GetFp16Enable(); MS_LOG(INFO) << "Init tensors."; std::vector input_shape = {1, 9}; - auto data_type = kNumberTypeFloat32; + schema::Format format = schema::Format_NHWC; + if (input_shape.size() == 2) { + format = schema::Format_NC; + } auto tensor_type = schema::NodeType_ValueNode; - auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type); if (input_tensor == nullptr) { MS_LOG(ERROR) << "new input tensor error!"; return; } - auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type); if (output_tensor == nullptr) { MS_LOG(ERROR) << "new output tensor error!"; delete input_tensor; @@ -302,11 +335,15 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) { } std::vector inputs{input_tensor}; std::vector outputs{output_tensor}; + auto allocator = ocl_runtime->GetAllocator(); inputs[0]->MallocData(allocator); MS_LOG(INFO) << "Initialize input data"; LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file); - MS_LOG(INFO) << "==================input data================"; - printf_tensor(inputs[0]); + if (enable_fp16) { + printf_tensor("Sigmoid:FP16--input data--", inputs[0]); + } else { + printf_tensor("Sigmoid:FP32--input data--", inputs[0]); + } auto *param = new (std::nothrow) ActivationParameter(); if (param == nullptr) { @@ -369,9 +406,13 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) { return; } - MS_LOG(INFO) << "==================output data================"; - printf_tensor(outputs[0]); - CompareRes(output_tensor, out_file); + if (enable_fp16) { + printf_tensor("Sigmoid:FP16--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } else { + printf_tensor("Sigmoid:FP32--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } delete kernel; delete param; delete input_tensor; @@ -380,24 +421,29 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) { lite::opencl::OpenCLRuntime::DeleteInstance(); } -TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { +TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) { std::string in_file = 
"/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/leaky_relu.bin"; MS_LOG(INFO) << "Leaky relu Begin test!"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); - auto allocator = ocl_runtime->GetAllocator(); + auto data_type = kNumberTypeFloat32; + ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); + bool enable_fp16 = ocl_runtime->GetFp16Enable(); MS_LOG(INFO) << "Init tensors."; std::vector input_shape = {1, 9}; - auto data_type = kNumberTypeFloat32; auto tensor_type = schema::NodeType_ValueNode; - auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + schema::Format format = schema::Format_NHWC; + if (input_shape.size() == 2) { + format = schema::Format_NC; + } + auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type); if (input_tensor == nullptr) { MS_LOG(ERROR) << "new input tensor error!"; return; } - auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type); if (output_tensor == nullptr) { MS_LOG(ERROR) << "new output tensor error!"; delete input_tensor; @@ -405,11 +451,15 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { } std::vector inputs{input_tensor}; std::vector outputs{output_tensor}; + auto allocator = ocl_runtime->GetAllocator(); inputs[0]->MallocData(allocator); MS_LOG(INFO) << "Initialize input data"; LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file); - MS_LOG(INFO) << "==================input data================"; - printf_tensor(inputs[0]); + if (enable_fp16) { + printf_tensor("Leaky Relu:FP16--input data--", inputs[0]); + } else { + printf_tensor("Leaky Relu:FP32--input data--", inputs[0]); + } auto *param = new (std::nothrow) ActivationParameter(); if (param == nullptr) { @@ -418,7 +468,7 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { delete output_tensor; return; } - param->alpha_ = 0.3; + param->alpha_ = 0.3f; param->type_ = ActivationType_LEAKY_RELU; auto *kernel = new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast(param), inputs, outputs); @@ -472,10 +522,13 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error."; return; } - - MS_LOG(INFO) << "==================output data================"; - printf_tensor(outputs[0]); - CompareRes(output_tensor, out_file); + if (enable_fp16) { + printf_tensor("Leaky Relu:FP16--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } else { + printf_tensor("Leaky Relu:FP32--output data---", outputs[0]); + CompareRes(output_tensor, out_file); + } delete kernel; delete param; delete input_tensor;