forked from mindspore-Ecosystem/mindspore
!4766 [MS][LITE][GPU]memory not release to testcase
Merge pull request !4766 from chenzupeng/master-lite
This commit is contained in:
commit
dc8b3db126
|
@ -192,9 +192,7 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::t
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
// MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
|
||||
// << ", type: " << lite::EnumNameOpT(opDef.attr_type());
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -40,7 +40,6 @@ int MatMulOpenCLKernel::Init() {
|
|||
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
|
||||
#else
|
||||
std::set<std::string> build_options;
|
||||
// build_options.emplace("-DPOOL_AVG");
|
||||
#ifdef ENABLE_FP16
|
||||
std::string source = matmul_source_fp16;
|
||||
#else
|
||||
|
@ -169,9 +168,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::tensor::Te
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
// MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
|
||||
// << ", type: " << lite::EnumNameOpT(opDef.attr_type());
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -83,7 +83,6 @@ int ReshapeOpenCLKernel::Run() {
|
|||
int c = shapex[3];
|
||||
int c4 = UP_DIV(c, C4NUM);
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
// local size should less than MAX_GROUP_SIZE
|
||||
std::vector<size_t> local = {};
|
||||
std::vector<size_t> global = {(size_t)h, (size_t)w, (size_t)c4};
|
||||
cl_int4 size = {h, w, c4, 1};
|
||||
|
|
|
@ -91,7 +91,9 @@ int SoftmaxOpenCLKernel::Init() {
|
|||
std::string source = softmax_source_fp32;
|
||||
runtime_ = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
// framework not set this param yet! just use default.
|
||||
parameter_->axis_ = 1;
|
||||
if (parameter_->axis_ == -1) {
|
||||
parameter_->axis_ = 1;
|
||||
}
|
||||
if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) {
|
||||
// support 4d tensor
|
||||
onexone_flag_ = false;
|
||||
|
@ -180,7 +182,7 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init `Softmax` kernel failed!";
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
|
|
|
@ -64,7 +64,6 @@ int TransposeOpenCLKernel::Init() {
|
|||
MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// Transpose::InferShape just set output->SetFormat(input->GetFormat()); -^-!
|
||||
ori_format_ = schema::Format_NCHW;
|
||||
out_tensors_[0]->SetFormat(schema::Format_NCHW);
|
||||
if (!is_image_out_) {
|
||||
|
@ -100,7 +99,6 @@ int TransposeOpenCLKernel::Run() {
|
|||
int c4 = UP_DIV(c, 4);
|
||||
int hw4 = UP_DIV(h * w, 4);
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
// local size should less than MAX_GROUP_SIZE
|
||||
std::vector<size_t> local = {16, 16};
|
||||
std::vector<size_t> global = {UP_ROUND(hw4, local[0]), UP_ROUND(c4, local[1])};
|
||||
|
||||
|
@ -126,7 +124,7 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor:
|
|||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (0 != ret) {
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -152,6 +152,7 @@ if (SUPPORT_GPU)
|
|||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc
|
||||
${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
|
||||
)
|
||||
endif()
|
||||
### minddata lite
|
||||
|
|
|
@ -30,7 +30,6 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
|
|||
};
|
||||
|
||||
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
|
||||
// setbuf(stdout, NULL);
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
|
@ -48,27 +47,67 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
|
|||
size_t input_size;
|
||||
std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin";
|
||||
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "input_data load error.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t weight_size;
|
||||
std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin";
|
||||
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
|
||||
if (weight_data == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_data load error.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t bias_size;
|
||||
std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin";
|
||||
auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size));
|
||||
if (bias_data == nullptr) {
|
||||
MS_LOG(ERROR) << "bias_data load error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> input_shape = {n, h, w, ci};
|
||||
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
|
||||
auto tensor_x = tensor_x_ptr.get();
|
||||
if (tensor_x == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_x create error.";
|
||||
return;
|
||||
}
|
||||
|
||||
lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci});
|
||||
|
||||
lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci});
|
||||
std::vector<int> weight_shape = {co, kh, kw, ci};
|
||||
auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), weight_shape);
|
||||
auto tensor_w = tensor_w_ptr.get();
|
||||
if (tensor_w == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_w create error.";
|
||||
return;
|
||||
}
|
||||
tensor_w->SetData(weight_data);
|
||||
|
||||
lite::tensor::Tensor *tensor_bias = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co});
|
||||
std::vector<int> bias_shape = {co};
|
||||
auto tensor_bias_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), bias_shape);
|
||||
auto tensor_bias = tensor_bias_ptr.get();
|
||||
if (tensor_bias == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_bias create error.";
|
||||
return;
|
||||
}
|
||||
tensor_bias->SetData(bias_data);
|
||||
|
||||
lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, oh, ow, co});
|
||||
std::vector<int> out_shape = {1, oh, ow, co};
|
||||
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
|
||||
auto tensor_out = tensor_out_ptr.get();
|
||||
if (tensor_out == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_out create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w, tensor_bias};
|
||||
std::vector<lite::tensor::Tensor *> outputs{tensor_out};
|
||||
ConvParameter *opParameter = new ConvParameter();
|
||||
auto opParameter_ptr = std::make_unique<ConvParameter>();
|
||||
auto opParameter = opParameter_ptr.get();
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "opParameter create error.";
|
||||
return;
|
||||
}
|
||||
opParameter->kernel_h_ = kh;
|
||||
opParameter->kernel_w_ = kw;
|
||||
opParameter->stride_h_ = 2;
|
||||
|
@ -77,23 +116,39 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
|
|||
opParameter->pad_w_ = pad;
|
||||
opParameter->input_channel_ = ci;
|
||||
opParameter->output_channel_ = co;
|
||||
auto *arith_kernel =
|
||||
new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
auto arith_kernel_ptr = std::make_unique<kernel::Conv2dTransposeOpenCLKernel>(
|
||||
reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
|
||||
auto arith_kernel = arith_kernel_ptr.get();
|
||||
if (arith_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "arith_kernel create error.";
|
||||
return;
|
||||
}
|
||||
arith_kernel->Init();
|
||||
|
||||
inputs[0]->MallocData(allocator);
|
||||
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
|
||||
auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
|
||||
std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
|
||||
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
|
||||
auto pGraph = pGraph_ptr.get();
|
||||
if (pGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "pGraph create error.";
|
||||
return;
|
||||
}
|
||||
|
||||
pGraph->Init();
|
||||
memcpy(inputs[0]->Data(), input_data, input_size);
|
||||
pGraph->Run();
|
||||
|
||||
printf("==================output data=================\n");
|
||||
std::cout << "==================output data=================" << std::endl;
|
||||
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
|
||||
std::cout << std::endl;
|
||||
size_t output_size;
|
||||
std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin";
|
||||
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
|
||||
if (correct_data == nullptr) {
|
||||
MS_LOG(ERROR) << "correct_data create error.";
|
||||
return;
|
||||
}
|
||||
int size_n = oh * ow * co;
|
||||
size_n = size_n > 100 ? 100 : size_n;
|
||||
for (int i = 0; i < size_n; i++) {
|
||||
|
@ -108,14 +163,6 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
|
|||
CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
|
||||
|
||||
MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete arith_kernel;
|
||||
delete pGraph;
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,25 +36,61 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
|
|||
int co = 1001;
|
||||
std::string input_path = "./test_data/matmul/matmul_fp32_input.bin";
|
||||
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "input_data load error.";
|
||||
return;
|
||||
}
|
||||
size_t weight_size;
|
||||
std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin";
|
||||
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
|
||||
|
||||
lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, ci});
|
||||
if (weight_data == nullptr) {
|
||||
MS_LOG(ERROR) << "weight_data load error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> input_shape = {1, 1, 1, ci};
|
||||
auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
|
||||
auto tensor_x = tensor_x_ptr.get();
|
||||
if (tensor_x == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_x create error.";
|
||||
return;
|
||||
}
|
||||
tensor_x->SetData(input_data);
|
||||
|
||||
lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, 1, 1, ci});
|
||||
std::vector<int> w_shape = {co, 1, 1, ci};
|
||||
auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), w_shape);
|
||||
auto tensor_w = tensor_w_ptr.get();
|
||||
if (tensor_w == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_w create error.";
|
||||
return;
|
||||
}
|
||||
tensor_w->SetData(weight_data);
|
||||
|
||||
lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, co});
|
||||
std::vector<int> out_shape = {1, 1, 1, co};
|
||||
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
|
||||
auto tensor_out = tensor_out_ptr.get();
|
||||
if (tensor_out == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_out create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w};
|
||||
std::vector<lite::tensor::Tensor *> outputs{tensor_out};
|
||||
auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false);
|
||||
auto arith_kernel_ptr = std::make_unique<kernel::MatMulOpenCLKernel>(nullptr, inputs, outputs, false);
|
||||
auto arith_kernel = arith_kernel_ptr.get();
|
||||
if (arith_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "arith_kernel create error.";
|
||||
return;
|
||||
}
|
||||
arith_kernel->Init();
|
||||
|
||||
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
|
||||
auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
|
||||
|
||||
std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
|
||||
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
|
||||
auto pGraph = pGraph_ptr.get();
|
||||
if (pGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "pGraph create error.";
|
||||
return;
|
||||
}
|
||||
pGraph->Init();
|
||||
pGraph->Run();
|
||||
|
||||
|
@ -71,19 +107,10 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
|
|||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
|
||||
// compare
|
||||
CompareOutputData(output_data, correct_data, co, 0.00001);
|
||||
|
||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete arith_kernel;
|
||||
delete pGraph;
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "mindspore/lite/src/common/file_utils.h"
|
||||
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestToFormatOpenCL : public mindspore::CommonTest {
|
||||
|
@ -28,8 +28,8 @@ class TestToFormatOpenCL : public mindspore::CommonTest {
|
|||
TestToFormatOpenCL() {}
|
||||
};
|
||||
|
||||
TEST_F(TestToFormatOpenCL, TransposeFp32) {
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) {
|
||||
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
||||
ocl_runtime->Init();
|
||||
auto allocator = ocl_runtime->GetAllocator();
|
||||
int h = 64;
|
||||
|
@ -38,20 +38,44 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
|
|||
size_t input_size;
|
||||
std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
|
||||
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||
|
||||
lite::tensor::Tensor *tensor_x =
|
||||
new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4);
|
||||
|
||||
lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w});
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "input_data load error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> input_shape = {1, h, w, c};
|
||||
auto tensor_x_ptr =
|
||||
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
|
||||
auto tensor_x = tensor_x_ptr.get();
|
||||
if (tensor_x == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_x create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> out_shape = {1, c, h, w};
|
||||
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
|
||||
auto tensor_out = tensor_out_ptr.get();
|
||||
if (tensor_out == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_out create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{tensor_x};
|
||||
std::vector<lite::tensor::Tensor *> outputs{tensor_out};
|
||||
auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs);
|
||||
auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs);
|
||||
auto arith_kernel = arith_kernel_ptr.get();
|
||||
if (arith_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "arith_kernel create error.";
|
||||
return;
|
||||
}
|
||||
arith_kernel->Init();
|
||||
|
||||
inputs[0]->MallocData(allocator);
|
||||
|
||||
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
|
||||
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
|
||||
auto pGraph = pGraph_ptr.get();
|
||||
if (pGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "pGraph create error.";
|
||||
return;
|
||||
}
|
||||
pGraph->Init();
|
||||
memcpy(inputs[0]->Data(), input_data, input_size);
|
||||
pGraph->Run();
|
||||
|
@ -59,6 +83,10 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
|
|||
size_t output_size;
|
||||
std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
|
||||
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
|
||||
if (correct_data == nullptr) {
|
||||
MS_LOG(ERROR) << "correct_data create error.";
|
||||
return;
|
||||
}
|
||||
printf("==================output data=================\n");
|
||||
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
|
||||
std::cout << std::endl;
|
||||
|
@ -74,15 +102,7 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
|
|||
|
||||
// compare
|
||||
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete arith_kernel;
|
||||
delete pGraph;
|
||||
MS_LOG(INFO) << "Test TransposeFp32 passed";
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,20 +38,44 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
|
|||
size_t input_size;
|
||||
std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
|
||||
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
|
||||
|
||||
lite::tensor::Tensor *tensor_x =
|
||||
new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4);
|
||||
|
||||
lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w});
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "input_data load error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> input_shape = {1, h, w, c};
|
||||
auto tensor_x_ptr =
|
||||
std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
|
||||
auto tensor_x = tensor_x_ptr.get();
|
||||
if (tensor_x == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_x create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<int> out_shape = {1, c, h, w};
|
||||
auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
|
||||
auto tensor_out = tensor_out_ptr.get();
|
||||
if (tensor_out == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_out create error.";
|
||||
return;
|
||||
}
|
||||
std::vector<lite::tensor::Tensor *> inputs{tensor_x};
|
||||
std::vector<lite::tensor::Tensor *> outputs{tensor_out};
|
||||
auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs);
|
||||
auto arith_kernel_ptr = std::make_unique<kernel::TransposeOpenCLKernel>(nullptr, inputs, outputs);
|
||||
auto arith_kernel = arith_kernel_ptr.get();
|
||||
if (arith_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "arith_kernel create error.";
|
||||
return;
|
||||
}
|
||||
arith_kernel->Init();
|
||||
|
||||
inputs[0]->MallocData(allocator);
|
||||
|
||||
std::vector<kernel::LiteKernel *> kernels{arith_kernel};
|
||||
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
|
||||
auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
|
||||
auto pGraph = pGraph_ptr.get();
|
||||
if (pGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "pGraph create error.";
|
||||
return;
|
||||
}
|
||||
pGraph->Init();
|
||||
memcpy(inputs[0]->Data(), input_data, input_size);
|
||||
pGraph->Run();
|
||||
|
@ -59,6 +83,10 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
|
|||
size_t output_size;
|
||||
std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
|
||||
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
|
||||
if (correct_data == nullptr) {
|
||||
MS_LOG(ERROR) << "correct_data create error.";
|
||||
return;
|
||||
}
|
||||
printf("==================output data=================\n");
|
||||
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
|
||||
std::cout << std::endl;
|
||||
|
@ -74,15 +102,7 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
|
|||
|
||||
// compare
|
||||
CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
|
||||
MS_LOG(INFO) << "TestMatMulFp32 passed";
|
||||
for (auto tensor : inputs) {
|
||||
delete tensor;
|
||||
}
|
||||
for (auto tensor : outputs) {
|
||||
delete tensor;
|
||||
}
|
||||
delete arith_kernel;
|
||||
delete pGraph;
|
||||
MS_LOG(INFO) << "Test TransposeFp32 passed";
|
||||
lite::opencl::OpenCLRuntime::DeleteInstance();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue