forked from OSSInnovation/mindspore
!5588 activation support f16 in opencl
Merge pull request !5588 from liuzhongkai/activation1_fp16
This commit is contained in:
commit
39e2791149
|
@ -1,22 +1,22 @@
|
|||
#pragma OPENCL EXTENSION cl_arm_printf : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define SLICES 4
|
||||
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
|
||||
#define MIN(X, Y) (X < Y ? X : Y)
|
||||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
|
||||
__kernel void ReluScalar(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
const float alpha) {
|
||||
__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
|
||||
__global FLT *alpha) {
|
||||
int C = input_shape.w; // channel size
|
||||
int Y = get_global_id(0); // height id
|
||||
int X = get_global_id(1); // weight id
|
||||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha;
|
||||
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * alpha[0];
|
||||
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * alpha[0];
|
||||
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * alpha[0];
|
||||
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * alpha[0];
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
@ -28,10 +28,10 @@ __kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, c
|
|||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : 0;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : 0;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : 0;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : 0;
|
||||
tmp.x = in_c4.x > 0.0f ? in_c4.x : 0.0f;
|
||||
tmp.y = in_c4.y > 0.0f ? in_c4.y : 0.0f;
|
||||
tmp.z = in_c4.z > 0.0f ? in_c4.z : 0.0f;
|
||||
tmp.w = in_c4.w > 0.0f ? in_c4.w : 0.0f;
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
@ -43,10 +43,10 @@ __kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output,
|
|||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = in_c4.x >= 0 ? MIN(in_c4.x, 6) : 0;
|
||||
tmp.y = in_c4.y >= 0 ? MIN(in_c4.y, 6) : 0;
|
||||
tmp.z = in_c4.z >= 0 ? MIN(in_c4.z, 6) : 0;
|
||||
tmp.w = in_c4.w >= 0 ? MIN(in_c4.w, 6) : 0;
|
||||
tmp.x = in_c4.x > 0.0f ? MIN(in_c4.x, 6.0f) : 0.0f;
|
||||
tmp.y = in_c4.y > 0.0f ? MIN(in_c4.y, 6.0f) : 0.0f;
|
||||
tmp.z = in_c4.z > 0.0f ? MIN(in_c4.z, 6.0f) : 0.0f;
|
||||
tmp.w = in_c4.w > 0.0f ? MIN(in_c4.w, 6.0f) : 0.0f;
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
@ -58,10 +58,10 @@ __kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output
|
|||
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
|
||||
FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
|
||||
FLT4 tmp;
|
||||
tmp.x = 1 / (1 + exp(-in_c4.x));
|
||||
tmp.y = 1 / (1 + exp(-in_c4.y));
|
||||
tmp.z = 1 / (1 + exp(-in_c4.z));
|
||||
tmp.w = 1 / (1 + exp(-in_c4.w));
|
||||
tmp.x = 1.0f / (1.0f + exp(-in_c4.x));
|
||||
tmp.y = 1.0f / (1.0f + exp(-in_c4.y));
|
||||
tmp.z = 1.0f / (1.0f + exp(-in_c4.z));
|
||||
tmp.w = 1.0f / (1.0f + exp(-in_c4.w));
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ __kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__read_only image2d_t src_data, __gl
|
|||
bool outside_x = x_c < 0 || x_c >= src_size.x;
|
||||
if (!outside_x && !outside_y) {
|
||||
FLT4 flt_p = filter[fx_c];
|
||||
FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z, (y_c * src_size.x + x_c) * src_size.z));
|
||||
FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c));
|
||||
r += TO_FLT4(src_p * flt_p);
|
||||
}
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ __kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__read_only image2d_t src_data, __gl
|
|||
FLT4 bias_p = bias[Z];
|
||||
FLT4 res = TO_FLT4(r) + bias_p;
|
||||
res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));
|
||||
WRITE_IMAGE(dst_data, (int2)(Z, (Y * dst_size.x + X) * dst_size.z), res);
|
||||
WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);
|
||||
}
|
||||
__kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *src_data, __global FLT4 *filter, __global FLT4 *bias,
|
||||
__global FLT4 *dst_data, int2 kernel_size, int2 stride,
|
||||
|
|
|
@ -14,16 +14,16 @@ __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output,
|
|||
FLT4 tmp;
|
||||
if (dim == 1) {
|
||||
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(0, 0));
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.x;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.x;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.x;
|
||||
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
|
||||
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.x;
|
||||
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.x;
|
||||
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.x;
|
||||
} else {
|
||||
FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(num, 0));
|
||||
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x;
|
||||
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.y;
|
||||
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.z;
|
||||
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.w;
|
||||
tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * weight.x;
|
||||
tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * weight.y;
|
||||
tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * weight.z;
|
||||
tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * weight.w;
|
||||
}
|
||||
WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
|
||||
}
|
||||
|
|
|
@ -38,15 +38,33 @@ using mindspore::schema::PrimitiveType_Activation;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void ActivationOpenClKernel::InitBuffer() {
|
||||
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
|
||||
alpha_buff_ = allocator->Malloc(fp_size);
|
||||
alpha_buff_ = allocator->MapBuffer(alpha_buff_, CL_MAP_WRITE, nullptr, true);
|
||||
memset(alpha_buff_, 0x00, fp_size);
|
||||
if (enable_fp16_) {
|
||||
auto fp16 = (float16_t)alpha_;
|
||||
memcpy(alpha_buff_, &fp16, fp_size);
|
||||
} else {
|
||||
memcpy(alpha_buff_, &alpha_, fp_size);
|
||||
}
|
||||
allocator->UnmapBuffer(alpha_buff_);
|
||||
}
|
||||
|
||||
int ActivationOpenClKernel::Init() {
|
||||
in_size_ = in_tensors_[0]->shape().size();
|
||||
out_size_ = out_tensors_[0]->shape().size();
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
enable_fp16_ = ocl_runtime->GetFp16Enable();
|
||||
fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
|
||||
if (in_size_ != 2 && in_size_ != 4) {
|
||||
MS_LOG(ERROR) << "Activate fun only support dim=4 or 2, but your dim=" << in_size_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
InitBuffer();
|
||||
std::map<int, std::vector<std::string>> Program_Kernel{
|
||||
{ActivationType_LEAKY_RELU, std::vector<std::string>{"LEAKY_RELU", "ReluScalar"}},
|
||||
{ActivationType_LEAKY_RELU, std::vector<std::string>{"LEAKY_RELU", "LeakyRelu"}},
|
||||
{ActivationType_RELU, std::vector<std::string>{"RELU", "Relu"}},
|
||||
{ActivationType_SIGMOID, std::vector<std::string>{"SIGMOID", "Sigmoid"}},
|
||||
{ActivationType_RELU6, std::vector<std::string>{"RELU6", "Relu6"}}};
|
||||
|
@ -57,7 +75,6 @@ int ActivationOpenClKernel::Init() {
|
|||
|
||||
std::string source = activation_source;
|
||||
std::set<std::string> build_options;
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->LoadSource(Program_Kernel[type_][0], source);
|
||||
ocl_runtime->BuildKernel(kernel_, Program_Kernel[type_][0], Program_Kernel[type_][1], build_options);
|
||||
|
||||
|
@ -87,7 +104,7 @@ int ActivationOpenClKernel::Run() {
|
|||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, img2d_shape);
|
||||
if (type_ == ActivationType_LEAKY_RELU) {
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_);
|
||||
ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_buff_, lite::opencl::MemType::BUF);
|
||||
}
|
||||
std::vector<size_t> local = {1, 1};
|
||||
std::vector<size_t> global = {static_cast<size_t>(img2d_shape.s[1]), static_cast<size_t>(img2d_shape.s[2])};
|
||||
|
@ -114,12 +131,10 @@ cl_int4 ActivationOpenClKernel::GetImg2dShape() {
|
|||
|
||||
int ActivationOpenClKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
cl_int4 img_shape = GetImg2dShape();
|
||||
#ifdef ENABLE_FP16
|
||||
size_t img_dtype = CL_HALF_FLOAT;
|
||||
#else
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
#endif
|
||||
|
||||
if (enable_fp16_) {
|
||||
img_dtype = CL_HALF_FLOAT;
|
||||
}
|
||||
img_size->clear();
|
||||
img_size->push_back(img_shape.s[2] * UP_DIV(img_shape.s[3], C4NUM));
|
||||
img_size->push_back(img_shape.s[1]);
|
||||
|
|
|
@ -39,13 +39,17 @@ class ActivationOpenClKernel : public OpenCLKernel {
|
|||
int Run() override;
|
||||
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
|
||||
cl_int4 GetImg2dShape();
|
||||
void InitBuffer();
|
||||
|
||||
private:
|
||||
cl::Kernel kernel_;
|
||||
int type_;
|
||||
float alpha_;
|
||||
void *alpha_buff_;
|
||||
int in_size_;
|
||||
int out_size_;
|
||||
size_t fp_size;
|
||||
bool enable_fp16_{false};
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -56,7 +56,7 @@ int BiasAddOpenCLKernel::Init() {
|
|||
out_size_ = out_tensors_[0]->shape().size();
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
enable_fp16_ = ocl_runtime->GetFp16Enable();
|
||||
fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
|
||||
fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
|
||||
if (in_size_ != 4 && in_size_ != 2) {
|
||||
MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -44,13 +44,17 @@ int DepthwiseConv2dOpenCLKernel::Init() {
|
|||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
std::string kernel_name = "DepthwiseConv2d";
|
||||
auto in_format = in_tensors_[0]->GetFormat();
|
||||
in_ori_format_ = in_format;
|
||||
in_ori_format_ = in_tensors_[0]->GetFormat();
|
||||
out_ori_format_ = out_tensors_[0]->GetFormat();
|
||||
out_tensors_[0]->SetFormat(in_format);
|
||||
in_format = (in_format == schema::Format_NHWC)
|
||||
? schema::Format_NHWC4
|
||||
: ((in_format == schema::Format_NCHW) ? schema::Format_NC4HW4 : in_format);
|
||||
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
|
||||
MS_LOG(ERROR) << "input format(" << in_format << ") "
|
||||
<< "format not support!";
|
||||
}
|
||||
in_tensors_[0]->SetFormat(in_format);
|
||||
out_tensors_[0]->SetFormat(in_format);
|
||||
if (out_mem_type_ == OpenCLMemType::BUF) {
|
||||
kernel_name += "_BUF";
|
||||
} else {
|
||||
|
@ -182,7 +186,7 @@ int DepthwiseConv2dOpenCLKernel::Run() {
|
|||
GetLocalSize(0, global, &local);
|
||||
|
||||
std::map<ActType, std::pair<float, float>> relu_clips{
|
||||
{ActType_No, {FLT_MIN, FLT_MAX}}, {ActType_Relu, {0.0, FLT_MAX}}, {ActType_Relu6, {0, 6.0}}};
|
||||
{ActType_No, {-FLT_MAX, FLT_MAX}}, {ActType_Relu, {0.0, FLT_MAX}}, {ActType_Relu6, {0, 6.0}}};
|
||||
cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_};
|
||||
cl_int2 stride = {parameter->stride_h_, parameter->stride_w_};
|
||||
cl_int2 padding = {-parameter->pad_u_, -parameter->pad_l_};
|
||||
|
|
|
@ -68,7 +68,7 @@ int PReluOpenCLKernel::Init() {
|
|||
std::string kernel_name = "PRelu";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
enable_fp16_ = ocl_runtime->GetFp16Enable();
|
||||
fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float);
|
||||
fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
|
||||
InitBuffer();
|
||||
ocl_runtime->LoadSource(program_name, source);
|
||||
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
|
||||
|
|
|
@ -46,16 +46,17 @@ void LoadActivationData(void *dst, size_t dst_size, const std::string &file_path
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
|
||||
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
|
||||
auto *output_data = reinterpret_cast<T *>(output_tensor->Data());
|
||||
size_t output_size = output_tensor->Size();
|
||||
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
|
||||
auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
|
||||
constexpr float atol = 0.0002;
|
||||
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
|
||||
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]);
|
||||
printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -64,8 +65,10 @@ void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard
|
|||
printf("compare success!\n\n\n");
|
||||
}
|
||||
|
||||
void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
|
||||
auto input_data = reinterpret_cast<float *>(in_data->Data());
|
||||
template <typename T>
|
||||
void printf_tensor(const std::string &str, mindspore::lite::tensor::Tensor *in_data) {
|
||||
MS_LOG(INFO) << str;
|
||||
auto input_data = reinterpret_cast<T *>(in_data->Data());
|
||||
for (int i = 0; i < in_data->ElementsNum(); ++i) {
|
||||
printf("%f ", input_data[i]);
|
||||
}
|
||||
|
@ -73,24 +76,29 @@ void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
|
|||
MS_LOG(INFO) << "Print tensor done";
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
|
||||
TEST_F(TestActivationOpenCL, ReluFp_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string out_file = "/data/local/tmp/relu.bin";
|
||||
MS_LOG(INFO) << "Relu Begin test!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
auto data_type = kNumberTypeFloat16;
|
||||
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
bool enable_fp16 = ocl_runtime->GetFp16Enable();
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 9};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
schema::Format format = schema::Format_NHWC;
|
||||
if (input_shape.size() == 2) {
|
||||
format = schema::Format_NC;
|
||||
}
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input tensor error!";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output tensor error!";
|
||||
delete input_tensor;
|
||||
|
@ -99,10 +107,12 @@ TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
|
|||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
inputs[0]->MallocData(allocator);
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("ReluFp16:--input data---", inputs[0]);
|
||||
} else {
|
||||
printf_tensor<float>("ReluFp32:--input data---", inputs[0]);
|
||||
}
|
||||
|
||||
auto *param = new (std::nothrow) ActivationParameter();
|
||||
if (param == nullptr) {
|
||||
|
@ -164,35 +174,44 @@ TEST_F(TestActivationOpenCL, ReluFp32_dim4) {
|
|||
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareRes(output_tensor, out_file);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("ReluFp16--output data---", outputs[0]);
|
||||
CompareRes<float16_t>(output_tensor, out_file);
|
||||
} else {
|
||||
printf_tensor<float>("ReluFp32--output data--", outputs[0]);
|
||||
CompareRes<float>(output_tensor, out_file);
|
||||
}
|
||||
delete kernel;
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete sub_graph;
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) {
|
||||
TEST_F(TestActivationOpenCL, Relu6Fp_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string out_file = "/data/local/tmp/relu6.bin";
|
||||
MS_LOG(INFO) << "Relu6 Begin test!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
bool enable_fp16 = ocl_runtime->GetFp16Enable();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 9};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
schema::Format format = schema::Format_NHWC;
|
||||
if (input_shape.size() == 2) {
|
||||
format = schema::Format_NC;
|
||||
}
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input tensor error!";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output tensor error!";
|
||||
delete input_tensor;
|
||||
|
@ -200,11 +219,15 @@ TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) {
|
|||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
inputs[0]->MallocData(allocator);
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Relu6:FP16--input data--", inputs[0]);
|
||||
} else {
|
||||
printf_tensor<float>("Relu6:FP32--input data--", inputs[0]);
|
||||
}
|
||||
|
||||
auto *param = new (std::nothrow) ActivationParameter();
|
||||
if (param == nullptr) {
|
||||
|
@ -267,34 +290,44 @@ TEST_F(TestActivationOpenCL, Relu6Fp32_dim4) {
|
|||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareRes(output_tensor, out_file);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Relu6:FP16--output data---", outputs[0]);
|
||||
CompareRes<float16_t>(output_tensor, out_file);
|
||||
} else {
|
||||
printf_tensor<float>("Relu6:FP32--output data---", outputs[0]);
|
||||
CompareRes<float>(output_tensor, out_file);
|
||||
}
|
||||
delete kernel;
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
delete output_tensor;
|
||||
delete sub_graph;
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) {
|
||||
TEST_F(TestActivationOpenCL, SigmoidFp_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string out_file = "/data/local/tmp/sigmoid.bin";
|
||||
MS_LOG(INFO) << "Sigmoid Begin test!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
auto data_type = kNumberTypeFloat16;
|
||||
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
bool enable_fp16 = ocl_runtime->GetFp16Enable();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 9};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
schema::Format format = schema::Format_NHWC;
|
||||
if (input_shape.size() == 2) {
|
||||
format = schema::Format_NC;
|
||||
}
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input tensor error!";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output tensor error!";
|
||||
delete input_tensor;
|
||||
|
@ -302,11 +335,15 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) {
|
|||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
inputs[0]->MallocData(allocator);
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Sigmoid:FP16--input data--", inputs[0]);
|
||||
} else {
|
||||
printf_tensor<float>("Sigmoid:FP32--input data--", inputs[0]);
|
||||
}
|
||||
|
||||
auto *param = new (std::nothrow) ActivationParameter();
|
||||
if (param == nullptr) {
|
||||
|
@ -369,9 +406,13 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) {
|
|||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareRes(output_tensor, out_file);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Sigmoid:FP16--output data---", outputs[0]);
|
||||
CompareRes<float16_t>(output_tensor, out_file);
|
||||
} else {
|
||||
printf_tensor<float>("Sigmoid:FP32--output data---", outputs[0]);
|
||||
CompareRes<float>(output_tensor, out_file);
|
||||
}
|
||||
delete kernel;
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
|
@ -380,24 +421,29 @@ TEST_F(TestActivationOpenCL, SigmoidFp32_dim4) {
|
|||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
|
||||
TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
||||
TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) {
|
||||
std::string in_file = "/data/local/tmp/in_data.bin";
|
||||
std::string out_file = "/data/local/tmp/leaky_relu.bin";
|
||||
MS_LOG(INFO) << "Leaky relu Begin test!";
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
|
||||
bool enable_fp16 = ocl_runtime->GetFp16Enable();
|
||||
|
||||
MS_LOG(INFO) << "Init tensors.";
|
||||
std::vector<int> input_shape = {1, 9};
|
||||
auto data_type = kNumberTypeFloat32;
|
||||
auto tensor_type = schema::NodeType_ValueNode;
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
schema::Format format = schema::Format_NHWC;
|
||||
if (input_shape.size() == 2) {
|
||||
format = schema::Format_NC;
|
||||
}
|
||||
auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (input_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new input tensor error!";
|
||||
return;
|
||||
}
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type);
|
||||
auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensor_type);
|
||||
if (output_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new output tensor error!";
|
||||
delete input_tensor;
|
||||
|
@ -405,11 +451,15 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
|||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
|
||||
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
inputs[0]->MallocData(allocator);
|
||||
MS_LOG(INFO) << "Initialize input data";
|
||||
LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file);
|
||||
MS_LOG(INFO) << "==================input data================";
|
||||
printf_tensor(inputs[0]);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Leaky Relu:FP16--input data--", inputs[0]);
|
||||
} else {
|
||||
printf_tensor<float>("Leaky Relu:FP32--input data--", inputs[0]);
|
||||
}
|
||||
|
||||
auto *param = new (std::nothrow) ActivationParameter();
|
||||
if (param == nullptr) {
|
||||
|
@ -418,7 +468,7 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
|||
delete output_tensor;
|
||||
return;
|
||||
}
|
||||
param->alpha_ = 0.3;
|
||||
param->alpha_ = 0.3f;
|
||||
param->type_ = ActivationType_LEAKY_RELU;
|
||||
auto *kernel =
|
||||
new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
|
||||
|
@ -472,10 +522,13 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) {
|
|||
MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error.";
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "==================output data================";
|
||||
printf_tensor(outputs[0]);
|
||||
CompareRes(output_tensor, out_file);
|
||||
if (enable_fp16) {
|
||||
printf_tensor<float16_t>("Leaky Relu:FP16--output data---", outputs[0]);
|
||||
CompareRes<float16_t>(output_tensor, out_file);
|
||||
} else {
|
||||
printf_tensor<float>("Leaky Relu:FP32--output data---", outputs[0]);
|
||||
CompareRes<float>(output_tensor, out_file);
|
||||
}
|
||||
delete kernel;
|
||||
delete param;
|
||||
delete input_tensor;
|
||||
|
|
Loading…
Reference in New Issue