!6402 Fix cast test case for OpenCL

Merge pull request !6402 from liuchao/master
This commit is contained in:
mindspore-ci-bot 2020-09-17 18:35:44 +08:00 committed by Gitee
commit 9e04ba7074
3 changed files with 12 additions and 17 deletions

View File

@ -18,7 +18,6 @@
#include <set>
#include<string>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/cast.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/cast.cl.inc"
@ -40,8 +39,7 @@ int CastOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
im_dst_x = out_tensors_[0]->Width();
}
size_t img_dtype = CL_FLOAT;
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto enable_fp16_ = ocl_runtime->GetFp16Enable();
auto enable_fp16_ = ocl_runtime_->GetFp16Enable();
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
@ -83,9 +81,8 @@ int CastOpenCLKernel::Init() {
std::set<std::string> build_options;
std::string source = cast_source;
std::string program_name = "cast";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
return RET_OK;
}
@ -108,7 +105,6 @@ void CastGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *lo
int CastOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto input_shape = in_tensors_[0]->shape();
cl_int4 input_shape_ = {input_shape[0], input_shape[1], input_shape[2], UP_DIV(input_shape[3], C4NUM)};
@ -116,15 +112,15 @@ int CastOpenCLKernel::Run() {
uint32_t OW = input_shape[2];
uint32_t OC = UP_DIV(input_shape[3], C4NUM);
const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize();
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
std::vector<size_t> local = {1, 1, 1}; // init local
std::vector<size_t> global = {OH, OW, OC};
CastGetWorkGroup(global, &local, max_global[0]);
int arg_cn = 0;
ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c()); // input tensor
ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c()); // out tensor
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_);
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c()); // input tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c()); // out tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape_);
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}

View File

@ -21,7 +21,6 @@
#include<string>
#include "ir/anf.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "nnacl/fp32/cast.h"
namespace mindspore::kernel {

View File

@ -119,10 +119,10 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
for (auto tensor : outputs) {
delete tensor;
}
delete param;
delete cast_kernel;
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}
TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@ -199,14 +199,14 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
for (auto tensor : inputs) {
delete tensor;
}
for (auto tensor : outputs) {
delete tensor;
}
delete param;
delete cast_kernel;
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}
} // namespace mindspore