!9435 [MS][LITE][Develop] GPU Ops for PRRelu Optimization

From: @pengyongrong
Reviewed-by: @zhanghaibo5,@ddwsky
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-04 16:52:12 +08:00 committed by Gitee
commit af32a33e60
11 changed files with 179 additions and 314 deletions

View File

@ -1,30 +1,19 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#define NHWC4 2
#define NC4HW4 100
__kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t output, float weight, int4 shape,
int data_format) {
int h = get_global_id(0);
int nh = get_global_id(0);
int w = get_global_id(1);
int slice = get_global_id(2);
int H = shape.y;
int W = shape.z;
int SLICES = shape.w;
if (h >= H || w >= W || slice >= SLICES) {
int c = get_global_id(2);
if (nh >= shape.x * shape.y || w >= shape.z || c >= shape.w || shape.y == 0) {
return;
}
int x, y;
if (data_format == 2) {
x = w * SLICES + slice;
y = h;
} else {
x = w;
y = slice * H + h;
}
int n = nh / shape.y;
int h = nh % shape.y;
int x = w * shape.w + c;
int y = n * shape.y + h;
FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
if (out.x < 0) {
out.x *= weight;
@ -43,25 +32,17 @@ __kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t o
__kernel void PRelu_vector(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight_vector,
int4 shape, int data_format) {
int h = get_global_id(0);
int nh = get_global_id(0);
int w = get_global_id(1);
int slice = get_global_id(2);
int H = shape.y;
int W = shape.z;
int SLICES = shape.w;
if (h >= H || w >= W || slice >= SLICES) {
int c = get_global_id(2);
if (nh >= shape.x * shape.y || w >= shape.z || c >= shape.w || shape.y == 0) {
return;
}
FLT4 weight = weight_vector[slice];
int x, y;
if (data_format == 2) {
x = w * SLICES + slice;
y = h;
} else {
x = w;
y = slice * H + h;
}
int n = nh / shape.y;
int h = nh % shape.y;
int x = w * shape.w + c;
int y = n * shape.y + h;
FLT4 weight = weight_vector[c];
FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
if (out.x < 0) {

View File

@ -23,7 +23,7 @@ __kernel void stack_2input_3axis_1inshape(__read_only image2d_t input0, __read_o
FLT4 result1 = READ_IMAGE(input0, smp_none, (int2)(0, (X)));
FLT4 result2 = READ_IMAGE(input1, smp_none, (int2)(0, (X)));
FLT4 result = {result1.x, result2.x, 0, 0};
WRITE_IMAGE(output, (int2)(coordinate_x_out, (X)), result);
WRITE_IMAGE(output, (int2)(Y, (X)), result);
}
// input -2D -axis = 1

View File

@ -35,6 +35,10 @@ int BatchNormOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.at(0)->shape().size() == 4) {
MS_LOG(ERROR) << "The dim of in_tensors->shape must be 4 but your dim is : " << in_tensors_.at(0)->shape().size();
return RET_ERROR;
}
if (in_tensors_.at(0)->shape()[0] > 1) {
MS_LOG(ERROR) << " Unsupported batch_size >1 ";
return RET_ERROR;

View File

@ -48,6 +48,10 @@ int CastOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.at(0)->shape().size() == 4) {
MS_LOG(ERROR) << "The dim of in_tensors->shape must be 4 but your dim is : " << in_tensors_.at(0)->shape().size();
return RET_ERROR;
}
return RET_OK;
}

View File

@ -69,12 +69,19 @@ int ConcatOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
auto param = reinterpret_cast<ConcatParameter *>(this->op_parameter_);
auto out_tensors_shape_size = out_tensors_[0]->shape().size();
MS_LOG(DEBUG) << " concat at axis=: " << param->axis_;
if (out_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << " GPU Unsupported shape.size > 4 "
<< "your shape().size()=: " << out_tensors_[0]->shape().size();
if (out_tensors_shape_size > 4) {
MS_LOG(ERROR) << " GPU Unsupported shape.size > 4 ";
return RET_ERROR;
}
for (int i = 0; i < in_tensors_.size(); ++i) {
auto in_tensors_shape_size = in_tensors_[i]->shape().size();
if (in_tensors_shape_size > 4) {
MS_LOG(ERROR) << " GPU Unsupported in_tensor shape.size > 4 ";
return RET_ERROR;
}
}
axis_ = param->axis_;
if (axis_ < 0) {
axis_ += in_tensors_.front()->shape().size();
@ -83,16 +90,20 @@ int ConcatOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << " only support axis >= 0 and axis <= 3 ";
return RET_ERROR;
}
if (out_tensors_[0]->shape().size() < 4 && op_parameter_->type_ == PrimitiveType_Concat && axis_ != 0) {
if (out_tensors_[0]->shape().size() == 2) {
if (out_tensors_shape_size < 4 && Type() == PrimitiveType_Concat && axis_ != 0) {
if (out_tensors_shape_size == 2) {
axis_ = axis_ + 2;
} else if (out_tensors_[0]->shape().size() == 3) {
} else if (out_tensors_shape_size == 3) {
axis_ = axis_ + 1;
} else {
MS_LOG(ERROR) << " Unsupported axis =: " << axis_ << " shape().size()=: " << out_tensors_[0]->shape().size();
MS_LOG(ERROR) << " Unsupported axis =: " << axis_ << " shape().size()=: " << out_tensors_shape_size;
return RET_ERROR;
}
}
if (in_tensors_.size() < 2 || in_tensors_.size() > 6) {
MS_LOG(ERROR) << "unsupported input size :" << in_tensors_.size();
return RET_ERROR;
}
return RET_OK;
}
@ -161,12 +172,7 @@ int ConcatOpenCLKernel::Prepare() {
if (axis_ == 3 && !Align_) {
kernel_name += "Input" + std::to_string(in_tensors_.size()) + "UnAlign";
} else {
if (2 <= in_tensors_.size() && in_tensors_.size() <= 6) {
kernel_name += std::to_string(in_tensors_.size()) + "inputaxis" + std::to_string(axis_);
} else {
MS_LOG(ERROR) << " input must be less than 6 and more than 2 ";
return RET_ERROR;
}
kernel_name += std::to_string(in_tensors_.size()) + "inputaxis" + std::to_string(axis_);
}
kernel_name += "_NHWC4";
@ -186,19 +192,14 @@ int ConcatOpenCLKernel::Run() {
if (axis_ == 0) {
return RunAxis0();
}
if (2 <= in_tensors_.size() && in_tensors_.size() <= 6) {
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c());
}
if (axis_ == 3 && !Align_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
}
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data_c());
}
if (axis_ == 3 && !Align_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF);
} else {
MS_LOG(ERROR) << "unsupported input size :" << in_tensors_.size();
return RET_ERROR;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;

View File

@ -34,16 +34,22 @@ namespace mindspore::kernel {
int PowerOpenCLKernel::CheckSpecs() {
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
broadcast_ = param->broadcast_;
if (!(broadcast_ && in_tensors_.size() == 1)) {
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
<< "!=" << in_tensors_.at(1)->shape().size();
return RET_ERROR;
} else if (in_tensors_.size() > 2 || in_tensors_.at(0)->shape().size() > 4) {
MS_LOG(ERROR) << "Unsupported in_tensors_->shape.size " << in_tensors_.size() << " or "
<< "in_tensors_[0]->shape().size(): " << in_tensors_.at(0)->shape().size();
return RET_ERROR;
}
if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.size() == 1 && !broadcast_) {
MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 ";
return RET_ERROR;
}
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
<< "!=" << in_tensors_.at(1)->shape().size();
return RET_ERROR;
}
if (in_tensors_.at(0)->shape().size() > 4) {
MS_LOG(ERROR) << "in_tensors_->shape.size must be less than 4";
return RET_ERROR;
}
return RET_OK;
}

View File

@ -35,7 +35,8 @@ namespace mindspore::kernel {
int PReluOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto weight_tensor = in_tensors_[1];
auto weight_tensor = in_tensors_.at(1);
int C_ = weight_shape_.s[3];
if (weight_is_scalar) {
if (weight_tensor->data_type() == kNumberTypeFloat16) {
weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c()));
@ -74,51 +75,75 @@ int PReluOpenCLKernel::InitWeights() {
return RET_OK;
}
int PReluOpenCLKernel::Init() {
auto input_tensor = in_tensors_[0];
auto weight_tensor = in_tensors_[1];
if (input_tensor->shape().size() != 4) {
MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << input_tensor->shape().size();
return mindspore::lite::RET_ERROR;
}
batch_size_ = input_tensor->Batch();
C_ = input_tensor->Channel();
H_ = input_tensor->Height();
W_ = input_tensor->Width();
if (batch_size_ != 1) {
MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch.";
int PReluOpenCLKernel::CheckSpecs() {
if (in_tensors_.size() != 2 || out_tensors_.size() != 1) {
MS_LOG(ERROR) << "PRelu Only supported in_tensors_.size=2 and out_tensors_.size()= 2 but your in_tensors_.size = "
<< in_tensors_.size() << "out_tensors_.size()=: " << out_tensors_.size();
return RET_ERROR;
}
auto weight_channel = weight_tensor->shape()[0];
if (weight_channel != 1 && weight_channel != C_) {
MS_LOG(ERROR)
<< "PRelu weight channel size must be 1 or must be equal with in_teneors channel size, but your weight size is "
<< weight_channel << " and your input channel size is " << C_;
GpuTensorInfo img_info_in_tensors0(in_tensors_[0]);
GpuTensorInfo img_info_in_tensors1(in_tensors_[1]);
auto weight_tensor = in_tensors_.at(1);
auto in_tensor_channel = img_info_in_tensors0.C;
auto weight_channel = img_info_in_tensors1.C;
if (weight_channel != 1 && weight_channel != in_tensor_channel) {
MS_LOG(ERROR) << "PRelu weight must be equal with in_teneors channel size, but your weight size is "
<< weight_channel << " and your input channel size is " << in_tensor_channel;
return mindspore::lite::RET_ERROR;
}
weight_is_scalar = weight_channel == 1;
if (weight_tensor->data_type() != kNumberTypeFloat16 && weight_tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "PRelu weight must be float32 or float16";
return RET_ERROR;
}
enable_fp16_ = ocl_runtime_->GetFp16Enable();
return RET_OK;
}
void PReluOpenCLKernel::SetConstArgs() {
int arg_idx = 3;
out_shape_.s[3] = UP_DIV(out_shape_.s[3], C4NUM);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 2);
}
void PReluOpenCLKernel::SetGlobalLocal() {
std::vector<size_t> local = {4, 4, 1};
OH = out_shape_.s[0] * out_shape_.s[1];
OW = out_shape_.s[2];
OC = out_shape_.s[3];
std::vector<size_t> global = {OH, OW, OC};
AlignGlobalLocal(global, local);
}
int PReluOpenCLKernel::Prepare() {
cl_int4 output_shape = {};
cl_int4 weight_shape = {};
for (int i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
output_shape.s[i] = out_tensors_.at(0)->shape()[i];
}
for (int i = 0; i < in_tensors_.at(1)->shape().size(); ++i) {
weight_shape.s[i] = in_tensors_.at(1)->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1);
Broadcast2GpuShape(weight_shape_.s, weight_shape.s, in_tensors_.at(1)->shape().size(), 1);
weight_is_scalar = weight_shape_.s[3] == 1;
enable_fp16_ = ocl_runtime_->GetFp16Enable();
std::string source = prelu_source;
std::string program_name = "PRelu";
std::string kernel_name = "PRelu_" + std::string(weight_is_scalar ? "scalar" : "vector");
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
InitWeights();
MS_LOG(DEBUG) << program_name << " init Done!";
MS_LOG(DEBUG) << "kernel_name=: " << kernel_name << " init Done!";
SetConstArgs();
SetGlobalLocal();
return mindspore::lite::RET_OK;
}
int PReluOpenCLKernel::Run() {
MS_LOG(DEBUG) << op_parameter_->name_ << " Running!";
auto CO_SLICES_ = UP_DIV(C_, C4NUM);
cl_int4 shape = {batch_size_, H_, W_, CO_SLICES_};
int arg_idx = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c());
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
@ -127,13 +152,7 @@ int PReluOpenCLKernel::Run() {
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_, lite::opencl::MemType::BUF);
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 2);
std::vector<size_t> local = {4, 4, 1};
std::vector<size_t> global = {static_cast<size_t>(H_), static_cast<size_t>(W_), static_cast<size_t>(CO_SLICES_)};
AlignGlobalLocal(global, local);
auto ret = ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
auto ret = ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error.";
return mindspore::lite::RET_ERROR;
@ -141,30 +160,6 @@ int PReluOpenCLKernel::Run() {
return mindspore::lite::RET_OK;
}
kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const lite::PrimitiveC *primitive) {
if (inputs.empty()) {
MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
free(opParameter);
return nullptr;
}
auto *kernel = new (std::nothrow) PReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr.";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Init PRelu kernel failed!";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLKernelCreator<PReluOpenCLKernel>);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_PReLU, OpenCLKernelCreator<PReluOpenCLKernel>);
} // namespace mindspore::kernel

View File

@ -32,16 +32,20 @@ class PReluOpenCLKernel : public OpenCLKernel {
: OpenCLKernel(parameter, inputs, outputs) {}
~PReluOpenCLKernel() override = default;
int Init() override;
int Prepare() override;
int CheckSpecs() override;
void SetConstArgs() override;
void SetGlobalLocal() override;
int Run() override;
int InitWeights() override;
private:
bool enable_fp16_{false};
int batch_size_{};
int C_{};
int H_{};
int W_{};
uint32_t OH = {1};
uint32_t OW = {1};
uint32_t OC = {1};
cl_int4 weight_shape_{};
cl_int4 out_shape_{};
void *weight_vector_{nullptr};
float weight_scalar_{0.f};
bool weight_is_scalar{false};

View File

@ -91,13 +91,13 @@ int SparseToDenseOpenCLKernel::InitWeights() {
}
int SparseToDenseOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->shape().size() > 4 || out_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size() << "outdim"
<< out_tensors_[0]->shape().size();
if (in_tensors_.size() < 3 || out_tensors_.at(0)->shape().size() > 4) {
MS_LOG(ERROR) << " only support out_tensors_ dim <= 4 and in_tensors_.size >= 3";
return RET_ERROR;
}
if (out_tensors_[0]->shape().size() > 3 || in_tensors_.size() < 3) {
MS_LOG(ERROR) << " only support dim <= 2 and in_tensors_.size >= 3";
if (in_tensors_.at(0)->shape().size() > 4 || out_tensors_.at(0)->shape().size() > 4) {
MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size() << "outdim"
<< out_tensors_[0]->shape().size();
return RET_ERROR;
}
if (input_dim_ == 2) {

View File

@ -64,7 +64,9 @@ void StackGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *l
}
int StackOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->shape().size() > 2 && (axis_ != 0)) {
auto param = reinterpret_cast<StackParameter *>(this->op_parameter_);
axis_ = param->axis_;
if (in_tensors_.size() != 2 && (axis_ != 0)) {
MS_LOG(ERROR) << " only support input size = 2 ";
return RET_ERROR;
}
@ -72,8 +74,6 @@ int StackOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << " only support dim <= 4 ";
return RET_ERROR;
}
auto param = reinterpret_cast<StackParameter *>(this->op_parameter_);
axis_ = param->axis_;
axis_ = axis_ < 0 ? axis_ + in_tensors_[0]->shape().size() : axis_;
if (axis_ > 3) {
MS_LOG(ERROR) << " only support axis <= 3 ";
@ -178,7 +178,7 @@ int StackOpenCLKernel::Run() {
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Stack, OpenCLKernelCreator<StackOpenCLKernel>);

View File

@ -13,186 +13,56 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h"
#include "ut/src/runtime/kernel/opencl/common.h"
#include "mindspore/lite/nnacl/prelu_parameter.h"
using mindspore::kernel::LiteKernel;
using mindspore::kernel::OpenCLSubGraph;
using mindspore::kernel::PReluOpenCLKernel;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
// PrimitiveType_PReLU: src/ops/populate/p_relu_populate.cc
namespace mindspore::lite::opencl::test {
class TestPReluOpenCL : public CommonTest {};
void LoadDataPRelu(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) {
memset(dst, 0x00, dst_size);
} else {
auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size);
memcpy(dst, src_data, dst_size);
class TestOpenCL_PRrelu : public CommonTest {};
namespace {
// PrimitiveType_PReLU: src/ops/populate/p_relu_populate.cc
OpParameter *CreateParameter() {
auto *param = test::CreateParameter<PReluParameter>(schema::PrimitiveType_PReLU);
return reinterpret_cast<OpParameter *>(param);
}
} // namespace
TEST_F(TestOpenCL_PRrelu, testcase1) {
std::vector<int> input_shape1 = {1, 4, 5, 6};
std::vector<int> input_shape2 = {1};
std::vector<int> output_shape = {1, 4, 5, 6};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/PRRelufp32_input1.bin";
std::string input2Ppath = "./test_data/PRRelufp32_input2.bin";
std::string correctOutputPath = "./test_data/PRRelufp32fp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
for (auto fp16_enable : {true}) {
auto *param = CreateParameter();
TestMain({{input_shape1, input_data1, VAR}, {input_shape2, input_data2, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-9);
}
}
template <typename T>
void CompareOutPRelu(lite::Tensor *output_tensor, const std::string &standard_answer_file) {
auto *output_data = reinterpret_cast<T *>(output_tensor->data_c());
size_t output_size = output_tensor->Size();
auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
return;
}
}
printf("compare success!\n");
printf("compare success!\n");
printf("compare success!\n\n\n");
}
TEST_F(TestOpenCL_PRrelu, testcase2) {
std::vector<int> input_shape1 = {1, 4, 5, 6};
std::vector<int> input_shape2 = {1, 1, 1, 6};
std::vector<int> output_shape = {1, 4, 5, 6};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/PRRelufp32_input1.bin";
std::string input2Ppath = "./test_data/PRRelufp32_input2.bin";
std::string correctOutputPath = "./test_data/PRRelufp32fp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
template <typename T>
void printf_tensor_Prelu(const std::string &log, mindspore::lite::Tensor *in_data, int size) {
MS_LOG(INFO) << log;
auto input_data = reinterpret_cast<T *>(in_data->data_c());
for (int i = 0; i < size; ++i) {
printf("%f ", input_data[i]);
for (auto fp16_enable : {true}) {
auto *param = CreateParameter();
TestMain({{input_shape1, input_data1, VAR}, {input_shape2, input_data2, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-9);
}
printf("\n");
MS_LOG(INFO) << "Print tensor done";
}
TEST_F(TestPReluOpenCL, PReluFp32_dim4) {
std::string in_file = "/data/local/tmp/in_data.bin";
std::string weight_file = "/data/local/tmp/weight_data.bin";
std::string standard_answer_file = "/data/local/tmp/caffe_prelu.bin";
MS_LOG(INFO) << "-------------------->> Begin test PRelu!";
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 4, 3, 9};
auto data_type = kNumberTypeFloat16;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
schema::Format format = schema::Format_NHWC;
auto tensor_type = lite::Tensor::CONST_TENSOR;
auto input_tensor = new (std::nothrow) lite::Tensor(data_type, input_shape, format, tensor_type);
if (input_tensor == nullptr) {
MS_LOG(ERROR) << "new input_tensor error!";
return;
}
auto output_tensor = new (std::nothrow) lite::Tensor(data_type, input_shape, format, tensor_type);
if (output_tensor == nullptr) {
MS_LOG(ERROR) << "new output_tensor error";
delete input_tensor;
return;
}
auto weight_tensor =
new (std::nothrow) lite::Tensor(data_type, std::vector<int>{input_shape[3]}, schema::Format_NHWC, tensor_type);
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "new weight_tensor error";
delete input_tensor;
delete output_tensor;
return;
}
std::vector<lite::Tensor *> inputs{input_tensor, weight_tensor};
std::vector<lite::Tensor *> outputs{output_tensor};
inputs[0]->MallocData(allocator);
inputs[1]->MallocData(allocator);
MS_LOG(INFO) << "initialize input data";
LoadDataPRelu(input_tensor->data_c(), input_tensor->Size(), in_file);
LoadDataPRelu(weight_tensor->data_c(), weight_tensor->Size(), weight_file);
if (ocl_runtime->GetFp16Enable()) {
printf_tensor_Prelu<float16_t>("PRELU:FP16--input data", input_tensor, inputs[0]->ElementsNum());
printf_tensor_Prelu<float16_t>("PRELU:FP16--weight data", weight_tensor, weight_tensor->ElementsNum());
} else {
printf_tensor_Prelu<float>("PRELU:FP32--input data", input_tensor, inputs[0]->ElementsNum());
printf_tensor_Prelu<float>("PRELU:FP32--weight data", weight_tensor, inputs[1]->ElementsNum());
}
auto param = new (std::nothrow) PReluParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "new PreluParameter error";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
return;
}
auto prelu_kernel =
new (std::nothrow) kernel::PReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (prelu_kernel == nullptr) {
MS_LOG(ERROR) << "new PReluOpenCLKernel error";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
return;
}
auto ret = prelu_kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init prelu kernel error";
return;
}
MS_LOG(INFO) << "initialize sub_graph";
std::vector<kernel::LiteKernel *> kernels{prelu_kernel};
auto *sub_graph = new (std::nothrow) kernel::OpenCLSubGraph({input_tensor}, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
MS_LOG(ERROR) << "Create kernel sub_graph error";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete prelu_kernel;
return;
}
ret = sub_graph->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init sub graph error";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete sub_graph;
return;
}
ret = sub_graph->Run();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run sub graph error";
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete sub_graph;
return;
}
if (ocl_runtime->GetFp16Enable()) {
printf_tensor_Prelu<float16_t>("PRelu:FP16--output_data", output_tensor, outputs[0]->ElementsNum());
CompareOutPRelu<float16_t>(output_tensor, standard_answer_file);
} else {
printf_tensor_Prelu<float>("PRelu:FP32--output_data", output_tensor, outputs[0]->ElementsNum());
CompareOutPRelu<float>(output_tensor, standard_answer_file);
}
delete input_tensor;
delete output_tensor;
delete weight_tensor;
delete param;
delete sub_graph;
}
} // namespace mindspore::lite::opencl::test