fix bug in matmul testcase

This commit is contained in:
chenzupeng 2020-08-21 11:30:12 +08:00
parent 11e670c54b
commit 89ca31a03d
21 changed files with 165 additions and 117 deletions

3
.gitignore vendored
View File

@ -98,3 +98,6 @@ mindspore/.commit_id
# lite test file
mindspore/lite/test/do_test/
# lite opencl compile file
*.cl.inc

View File

@ -67,7 +67,9 @@ int ActivationOpenClKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << op_parameter_->name_ << " init Done!";
return RET_OK;

View File

@ -122,7 +122,9 @@ int ArithmeticOpenCLKernel::Init() {
if (error_code != RET_OK) {
return error_code;
}
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
Image2dGetWorkGroupSize();
return RET_OK;

View File

@ -56,7 +56,9 @@ int BatchNormOpenCLKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
return RET_OK;

View File

@ -63,7 +63,9 @@ int CaffePReluOpenCLKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << program_name << " Init Done!";
return RET_OK;

View File

@ -83,7 +83,9 @@ int ConcatOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
}
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
return RET_OK;

View File

@ -51,7 +51,9 @@ int Conv2dTransposeOpenCLKernel::Init() {
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
PadWeight();
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;

View File

@ -92,8 +92,11 @@ int ConvolutionOpenCLKernel::Init() {
}
this->InitBuffer();
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << "Convolution Init Done!";
return RET_OK;
}

View File

@ -42,7 +42,8 @@ int DepthwiseConv2dOpenCLKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
std::string kernel_name = "DepthwiseConv2d";
auto in_format = in_tensors_[0]->GetFormat();
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_format;
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(in_format);
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "

View File

@ -44,26 +44,28 @@ int MatMulOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
auto weight_format = in_tensors_[1]->GetFormat();
if (weight_format != schema::Format_NHWC) {
MS_LOG(ERROR) << "weight format(" << weight_format << ") "
<< "format not support!";
return 1;
int ci, co;
if (in_tensors_[1]->shape().size() == 2) {
ci = in_tensors_[1]->shape()[1];
co = in_tensors_[1]->shape()[0];
} else {
ci = in_tensors_[1]->shape()[3];
co = in_tensors_[1]->shape()[0];
}
int ci = in_tensors_[1]->shape()[3];
int co = in_tensors_[1]->shape()[0];
sizeCI = {ci, UP_DIV(ci, 4)};
sizeCO = {co, UP_DIV(co, 4)};
auto allocator = ocl_runtime->GetAllocator();
padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * 16 * sizeof(FLOAT_T)));
padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
bias_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T)));
bias_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
sizeCI = {ci, UP_DIV(ci, C4NUM)};
sizeCO = {co, UP_DIV(co, C4NUM)};
PadWeight();
allocator->UnmapBuffer(padWeight_);
allocator->UnmapBuffer(bias_);
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
if (out_tensors_[0]->shape().size() == 2) {
out_ori_format_ = schema::Format_NC;
out_tensors_[0]->SetFormat(schema::Format_NC4);
in_ori_format_ = schema::Format_NC;
in_tensors_[0]->SetFormat(schema::Format_NC4);
}
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
}
@ -71,16 +73,22 @@ int MatMulOpenCLKernel::Init() {
int MatMulOpenCLKernel::ReSize() { return 0; }
void MatMulOpenCLKernel::PadWeight() {
auto origin_weight = reinterpret_cast<FLOAT_T *>(in_tensors_.at(kWeightIndex)->Data());
auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
padWeight_ =
reinterpret_cast<FLOAT_t *>(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * C4NUM * C4NUM * sizeof(FLOAT_t)));
padWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kWeightIndex)->Data());
int divCI = sizeCI.s[1];
int divCO = sizeCO.s[1];
int co = sizeCO.s[0];
int index = 0;
for (int i = 0; i < divCI; ++i) {
for (int j = 0; j < divCO; ++j) {
for (int k = 0; k < 4; ++k) {
for (int l = 0; l < 4; ++l) {
int src_x = i * 4 + l;
int src_y = j * 4 + k;
for (int k = 0; k < C4NUM; ++k) {
for (int l = 0; l < C4NUM; ++l) {
int src_x = i * C4NUM + l;
int src_y = j * C4NUM + k;
if (src_x < sizeCI.s[0] && src_y < sizeCO.s[0]) {
padWeight_[index++] = origin_weight[src_y * sizeCI.s[0] + src_x];
} else {
@ -90,60 +98,55 @@ void MatMulOpenCLKernel::PadWeight() {
}
}
}
if (hasBias_) {
memcpy(bias_, in_tensors_[2]->Data(), sizeof(FLOAT_T) * sizeCO.s[0]);
for (int i = sizeCO.s[0]; i < sizeCO.s[1] * 4; i++) {
bias_[i] = 0;
}
} else {
for (int i = 0; i < sizeCO.s[1] * 4; i++) {
bias_[i] = 0;
}
size_t im_dst_x, im_dst_y;
im_dst_x = divCO;
im_dst_y = 1;
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(im_dst_x * im_dst_y * C4NUM * sizeof(FLOAT_t), img_size));
bias_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
memset(bias_, 0x00, divCO * C4NUM * sizeof(FLOAT_t));
if (in_tensors_.size() >= 3) {
memcpy(bias_, in_tensors_[2]->Data(), co * sizeof(FLOAT_t));
}
allocator->UnmapBuffer(bias_);
}
int MatMulOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t im_dst_x, im_dst_y;
im_dst_x = sizeCO.s[1];
im_dst_y = 1;
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
}
int MatMulOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
std::vector<int> shapex = in_tensors_[0]->shape();
int n = shapex[0];
if (n > 1) {
MS_LOG(ERROR) << "MatMul n > 1 not supported!";
return 1;
}
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
// local size should less than MAX_GROUP_SIZE
std::vector<size_t> local = {64, 4};
std::vector<size_t> global = {UP_ROUND(sizeCO.s[1], local[0]), 4};
cl::ImageFormat image_format;
{
image_format.image_channel_order = CL_RGBA;
#ifdef ENABLE_FP16
image_format.image_channel_data_type = CL_HALF_FLOAT;
#else
image_format.image_channel_data_type = CL_FLOAT;
#endif
}
cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code;
cl::Image2D img_input(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCI.s[1], 1,
0, in_tensors_[0]->Data(), &in_error_code);
cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCO.s[1], 1,
0, bias_, &in_error_code_bias);
cl::Image2D img_out(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, sizeCO.s[1], 1, 0, nullptr,
&out_error_code);
ocl_runtime->SetKernelArg(kernel_, 0, img_input);
ocl_runtime->SetKernelArg(kernel_, 1, padWeight_);
ocl_runtime->SetKernelArg(kernel_, 2, img_bias);
ocl_runtime->SetKernelArg(kernel_, 3, img_out);
ocl_runtime->SetKernelArg(kernel_, 4, sizeCI);
ocl_runtime->SetKernelArg(kernel_, 5, sizeCO);
ocl_runtime->SetKernelArg(kernel_, 6, hasBias_ ? 1 : 0);
int arg_count = 0;
ocl_runtime->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_count++, padWeight_);
ocl_runtime->SetKernelArg(kernel_, arg_count++, bias_);
ocl_runtime->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCI);
ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCO);
ocl_runtime->SetKernelArg(kernel_, arg_count++, hasBias_ ? 1 : 0);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
auto origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{(size_t)(sizeCO.s[1]), 1, 1};
ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(img_out, CL_TRUE, origin, region, 0, 0,
out_tensors_[0]->Data());
return 0;
}

View File

@ -23,11 +23,6 @@
#include "src/runtime/kernel/arm/nnacl/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#ifdef ENABLE_FP16
using FLOAT_T = float16_t;
#else
using FLOAT_T = float;
#endif
namespace mindspore::kernel {
@ -44,11 +39,12 @@ class MatMulOpenCLKernel : public OpenCLKernel {
int ReSize() override;
int Run() override;
void PadWeight();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
private:
cl::Kernel kernel_;
FLOAT_T *padWeight_;
FLOAT_T *bias_;
FLOAT_t *padWeight_;
FLOAT_t *bias_;
bool hasBias_ = false;
cl_int2 sizeCI;
cl_int2 sizeCO;

View File

@ -73,7 +73,9 @@ int PoolingOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!";

View File

@ -46,7 +46,9 @@ int PReluOpenCLKernel::Init() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << program_name << " init Done!";
return RET_OK;

View File

@ -43,8 +43,14 @@ int ReshapeOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC);
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
if (out_tensors_[0]->shape().size() == 2) {
out_ori_format_ = schema::Format_NC;
out_tensors_[0]->SetFormat(schema::Format_NC4);
}
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}

View File

@ -123,10 +123,12 @@ int SoftmaxOpenCLKernel::Init() {
runtime_->LoadSource(program_name, source);
runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
ori_format_ = out_tensors_[0]->GetFormat();
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = out_tensors_[0]->GetFormat();
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
if (!is_image_out_) {
ori_format_ = schema::Format_NC;
out_ori_format_ = schema::Format_NC;
out_tensors_[0]->SetFormat(schema::Format_NC);
}
MS_LOG(DEBUG) << kernel_name << " Init Done!";

View File

@ -49,17 +49,13 @@ int TransposeOpenCLKernel::Init() {
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
auto input_format = in_tensors_[0]->GetFormat();
if (input_format != schema::Format_NHWC4) {
MS_LOG(ERROR) << "input format(" << input_format << ") "
<< "format not support!";
return RET_ERROR;
}
if ((in_tensors_[0]->Height() * in_tensors_[0]->Width()) % 4 != 0) {
MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
return RET_ERROR;
}
ori_format_ = schema::Format_NCHW;
in_ori_format_ = in_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(schema::Format_NHWC4);
out_ori_format_ = schema::Format_NCHW;
out_tensors_[0]->SetFormat(schema::Format_NCHW);
if (!is_image_out_) {
out_mem_type_ = OpenCLMemType::BUF;

View File

@ -35,7 +35,7 @@ class OpenCLKernel : public LiteKernel {
public:
explicit OpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
: LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
virtual int Init() { return -1; }
virtual int Prepare() { return -1; }
@ -49,11 +49,13 @@ class OpenCLKernel : public LiteKernel {
}
OpenCLMemType GetMemType() { return out_mem_type_; }
void SetMemType(OpenCLMemType mem_type) { out_mem_type_ = mem_type; }
schema::Format GetOriFormat() { return ori_format_;}
schema::Format GetInOriFormat() { return in_ori_format_; }
schema::Format GetOutOriFormat() { return out_ori_format_; }
protected:
OpenCLMemType out_mem_type_{OpenCLMemType::IMG};
schema::Format ori_format_{schema::Format_NHWC4};
schema::Format in_ori_format_{schema::Format_NHWC};
schema::Format out_ori_format_{schema::Format_NHWC4};
};
} // namespace mindspore::kernel

View File

@ -38,10 +38,10 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
for (auto &iv : in_kernels) {
for (auto &jv : iv) {
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(jv);
schema::Format ori_format = cur_opencl_op->GetOriFormat();
schema::Format out_ori_format = cur_opencl_op->GetOutOriFormat();
auto tens = cur_opencl_op->out_tensors();
if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() &&
tens[0]->GetFormat() == ori_format) {
tens[0]->GetFormat() == out_ori_format) {
continue;
}
if (mem_type == OpenCLMemType::IMG) {
@ -53,14 +53,16 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
}
for (size_t i = 0; i < in_tensors.size(); ++i) {
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(in_kernels[i][0]);
schema::Format ori_format = cur_opencl_op->GetOriFormat();
schema::Format out_ori_format = cur_opencl_op->GetOutOriFormat();
schema::Format in_ori_format = cur_opencl_op->GetInOriFormat();
if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() &&
in_tensors[i]->GetFormat() == ori_format) {
in_tensors[i]->GetFormat() == out_ori_format) {
continue;
}
auto dst_format = (mem_type == OpenCLMemType::IMG) ? in_kernels[i][0]->out_tensors()[0]->GetFormat() : ori_format;
auto dst_format =
(mem_type == OpenCLMemType::IMG) ? in_kernels[i][0]->in_tensors()[0]->GetFormat() : out_ori_format;
auto src_format =
(mem_type == OpenCLMemType::IMG) ? in_tensors[i]->GetFormat() : in_kernels[i][0]->out_tensors()[0]->GetFormat();
(mem_type == OpenCLMemType::IMG) ? in_ori_format : in_kernels[i][0]->out_tensors()[0]->GetFormat();
lite::tensor::Tensor *new_tensor = new (std::nothrow) lite::tensor::Tensor();
MS_ASSERT(new_tensor);
if (new_tensor == nullptr) {
@ -80,7 +82,14 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor *
std::vector<int> dst_shape{shape[0], shape[2], shape[3], shape[1]};
new_tensor->set_shape(shape);
}
new_tensor->SetFormat(in_kernels[i][0]->out_tensors()[0]->GetFormat());
if (mem_type == OpenCLMemType::IMG) {
new_tensor->SetFormat(dst_format);
in_tensors[i]->SetFormat(src_format);
} else {
new_tensor->SetFormat(src_format);
in_tensors[i]->SetFormat(dst_format);
}
out_tensors->emplace_back(new_tensor);
#ifdef ENABLE_FP16
KernelKey desc{kGPU, kNumberTypeFloat16, schema::PrimitiveType_ToFormat};

View File

@ -161,7 +161,8 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
// compare
CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
lite::opencl::OpenCLRuntime::DeleteInstance();
}

View File

@ -31,6 +31,7 @@ class TestMatMulOpenCL : public mindspore::CommonTest {
TEST_F(TestMatMulOpenCL, MatMulFp32) {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
size_t input_size;
int ci = 1280;
int co = 1001;
@ -47,16 +48,16 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
MS_LOG(ERROR) << "weight_data load error.";
return;
}
std::vector<int> input_shape = {1, 1, 1, ci};
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
std::vector<int> input_shape = {1, ci};
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
tensor_x->SetData(input_data);
std::vector<int> w_shape = {co, 1, 1, ci};
std::vector<int> w_shape = {co, ci};
auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), w_shape);
auto tensor_w = tensor_w_ptr.get();
if (tensor_w == nullptr) {
@ -65,8 +66,9 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
}
tensor_w->SetData(weight_data);
std::vector<int> out_shape = {1, 1, 1, co};
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
std::vector<int> out_shape = {1, co};
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NC);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -81,6 +83,7 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
return;
}
arith_kernel->Init();
inputs[0]->MallocData(allocator);
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
@ -92,6 +95,7 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
return;
}
pGraph->Init();
memcpy(inputs[0]->Data(), input_data, input_size);
pGraph->Run();
size_t output_size;
@ -108,9 +112,10 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
std::cout << std::endl;
// compare
CompareOutputData(output_data, correct_data, co, 0.00001);
MS_LOG(INFO) << "TestMatMulFp32 passed";
CompareOutputData(output_data, correct_data, co, 0.0001);
tensor_x->SetData(nullptr);
tensor_out->SetData(nullptr);
lite::opencl::OpenCLRuntime::DeleteInstance();
MS_LOG(INFO) << "TestMatMulFp32 passed";
}
} // namespace mindspore

View File

@ -44,14 +44,15 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
}
std::vector<int> input_shape = {1, h, w, c};
auto tensor_x_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
std::vector<int> out_shape = {1, c, h, w};
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
auto tensor_out_ptr =
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NCHW);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
@ -102,7 +103,11 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
// compare
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
MS_LOG(INFO) << "Test TransposeFp32 passed";
inputs[0]->SetData(nullptr);
outputs[0]->SetData(nullptr);
lite::opencl::OpenCLRuntime::DeleteInstance();
MS_LOG(INFO) << "Test TransposeFp32 passed";
}
} // namespace mindspore