opencl for x86

Fazzie 2021-09-27 16:48:48 +08:00
parent 50be8ead68
commit 7abd1a4935
33 changed files with 794 additions and 66 deletions

View File

@ -20,7 +20,7 @@ endif()
#Options that can be configured through environment variables or manually
set(MSLITE_GPU_BACKEND "" CACHE STRING "enable gpu backend, \
only arm64 support opencl, only x86_64 support tensorrt, opencl/cuda/tensorrt/off")
opencl only supports arm64 and x86_64, tensorrt only supports x86_64, opencl/cuda/tensorrt/off")
option(MSLITE_ENABLE_NPU "enable npu, only arm64 or arm32 support" off)
option(MSLITE_ENABLE_TRAIN "enable train" on)
option(MSLITE_ENABLE_SSE "enable SSE instruction set, only x86_64 support" off)
@ -155,7 +155,8 @@ else()
if(MSLITE_GPU_BACKEND STREQUAL "")
set(MSLITE_GPU_BACKEND "off")
endif()
if((NOT MSLITE_GPU_BACKEND STREQUAL "tensorrt") AND (NOT MSLITE_GPU_BACKEND STREQUAL "off"))
if((NOT MSLITE_GPU_BACKEND STREQUAL "tensorrt") AND (NOT MSLITE_GPU_BACKEND STREQUAL "off") AND
(NOT MSLITE_GPU_BACKEND STREQUAL "opencl"))
message("invalid MSLITE_GPU_BACKEND value ${MSLITE_GPU_BACKEND} for x86_64, MSLITE_GPU_BACKEND is set to off.")
set(MSLITE_GPU_BACKEND "off")
endif()

View File

@ -173,7 +173,7 @@ int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> *platforms) {
int OpenCLRuntime::InitQueue(std::vector<cl::Platform> *platforms) {
MS_ASSERT(platforms);
cl_int ret;
cl_int ret = 0;
#if defined(SHARING_MEM_WITH_OPENGL) && defined(CL_HPP_TARGET_OPENCL_VERSION) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
@ -719,7 +719,7 @@ void OpenCLRuntime::LoadCache() {
MS_LOG(ERROR) << "Load opencl cache fail: bins == nullptr";
return;
}
for (auto i = 0; i < bins->size(); ++i) {
for (size_t i = 0; i < bins->size(); ++i) {
auto *bin = bins->template GetAs<schema::ProgramBinary>(i);
if (bin == nullptr) {
MS_LOG(ERROR) << "kernel_bin[" << i << "] null";

View File

@ -88,7 +88,7 @@ void ArithmeticOpenCLKernel::SetGlobalLocal() {
int ArithmeticOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto fp16_enable = ocl_runtime_->GetFp16Enable();
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
const auto &in_tensor = in_tensors_.at(i);
GpuTensorInfo in_shape = GpuTensorInfo(in_tensor);
if (in_tensor->IsConst()) {
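This int-to-size_t change recurs throughout the commit: loop counters compared against a container's size() are made unsigned so the comparison is well-typed on x86_64 builds with strict warnings. A minimal illustration (not from the sources):

#include <cstddef>
#include <vector>

// Illustration of the recurring int -> size_t loop-index change.
int SumAll(const std::vector<int> &values) {
  int sum = 0;
  // for (int i = 0; i < values.size(); ++i)   // signed/unsigned comparison warning on strict builds
  for (size_t i = 0; i < values.size(); ++i) {  // index type now matches the size_t returned by size()
    sum += values[i];
  }
  return sum;
}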

View File

@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <string>
#include <cfloat>
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"

View File

@ -140,6 +140,7 @@ int BatchNormOpenCLKernel::MapBuffer() {
return RET_OK;
}
#ifdef ENABLE_FP16
int BatchNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
GpuTensorInfo img_info(in_tensors_.at(1));
@ -264,6 +265,86 @@ int BatchNormOpenCLKernel::Prepare() {
return RET_OK;
}
#else
int BatchNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
GpuTensorInfo img_info(in_tensors_.at(1));
size_t weight_size = img_info.OriginSize;
// allocated memory for weight and init value
scale_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (scale_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
offset_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (offset_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
mean_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (mean_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
variance_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (variance_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
if (MapBuffer() != RET_OK) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(scale_, 1, weight_size);
memset(offset_, 0x00, weight_size);
memset(mean_, 0x00, weight_size);
memset(variance_, 0x00, weight_size);
CHECK_NULL_RETURN(in_tensors_.at(kNumInput1)->data());
CHECK_NULL_RETURN(in_tensors_.at(kNumInput2)->data());
CHECK_NULL_RETURN(in_tensors_.at(kNumInput3)->data());
CHECK_NULL_RETURN(in_tensors_.at(kNumInput4)->data());
memcpy(scale_, in_tensors_.at(1)->data(), weight_size);
memcpy(offset_, in_tensors_.at(2)->data(), weight_size);
memcpy(mean_, in_tensors_.at(3)->data(), weight_size);
memcpy(variance_, in_tensors_.at(4)->data(), weight_size);
if (UnmapBuffer() != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
return RET_OK;
}
int BatchNormOpenCLKernel::Prepare() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
const std::string kernel_name = "Batch_normalization_NHWC4";
std::string source = batchnorm_source;
const std::string program_name = "Batch_normalization";
if (!ocl_runtime_->LoadSource(program_name, source)) {
MS_LOG(ERROR) << "Load source failed.";
return RET_ERROR;
}
auto build_options_ext = CreateBuildOptionsExtByDType(this->registry_data_type_);
auto ret = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options_ext);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Build kernel failed.";
return ret;
}
MS_LOG(DEBUG) << kernel_name << " Init Done!";
ret = Initweight();
if (ret) {
MS_LOG(ERROR) << "Initweight failed ";
return RET_ERROR;
}
if (SetConstArgs() != RET_OK) {
MS_LOG(ERROR) << "SeConstArgs failed.";
return RET_ERROR;
}
SetGlobalLocal();
return RET_OK;
}
#endif
int BatchNormOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
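The BatchNorm changes above show the pattern this commit applies to most OpenCL kernels: the existing fp16-capable implementation stays behind ENABLE_FP16 and an FP32-only twin is added for builds (such as x86_64) where fp16 types are unavailable. A schematic sketch with hypothetical names, not taken from the kernel sources:

#include <cstddef>
#include <cstdint>

#ifdef ENABLE_FP16
// arm64 builds keep the fp16-capable path and may reference 16-bit float storage.
inline size_t WeightElemSize(bool fp16_enable) { return fp16_enable ? sizeof(uint16_t) : sizeof(float); }
#else
// x86_64 builds take an FP32-only fallback, so no fp16 types or conversions are needed.
inline size_t WeightElemSize(bool /*fp16_enable*/) { return sizeof(float); }
#endif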

View File

@ -37,7 +37,7 @@ int ConcatOpenCLKernel::RunAxis0() {
MS_ASSERT(dst_data);
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto *out_image = allocator_->GetImage(dst_data);
for (int i = 0; i < in_tensors_.size(); i++) {
for (size_t i = 0; i < in_tensors_.size(); i++) {
auto src_data = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data() : weight_ptrs_.at(i);
if (allocator_->GetImageSize(src_data, &img_size) != RET_OK) {
MS_LOG(ERROR) << "GetImageSize failed.";
@ -125,7 +125,7 @@ int ConcatOpenCLKernel::SetConstArgs() {
size_t dtype = ocl_runtime_->GetFp16Enable() ? sizeof(cl_half) : sizeof(cl_float);
stride_w = img_info.RowPitch() / dtype;
cl_int4 output_shape_ = {};
for (int i = 0; i < out_tensors_[0]->shape().size(); ++i) {
for (size_t i = 0; i < out_tensors_[0]->shape().size(); ++i) {
output_shape_.s[i] = out_tensors_[0]->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, output_shape_.s, out_tensors_[0]->shape().size(), 1);
@ -133,7 +133,7 @@ int ConcatOpenCLKernel::SetConstArgs() {
if (axis_ == 3 && !Align_) {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensor->shape().size(); ++j) {
for (size_t j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
@ -149,7 +149,7 @@ int ConcatOpenCLKernel::SetConstArgs() {
} else {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensor->shape().size(); ++j) {
for (size_t j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
@ -282,7 +282,7 @@ int ConcatOpenCLKernel::Run() {
return RunAxis0();
}
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
auto input_ptr = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data() : weight_ptrs_.at(i);
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_ptr) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";

View File

@ -122,8 +122,12 @@ int Conv2DOpenCLKernel::Prepare() {
int Conv2DOpenCLKernel::InitAttrs() {
CHECK_NULL_RETURN(ocl_runtime_);
#ifdef ENABLE_FP16
use_fp16_ = ocl_runtime_->GetFp16Enable();
sizeof_FLT_ = use_fp16_ ? sizeof(float16_t) : sizeof(float);
#else
sizeof_FLT_ = sizeof(float);
#endif
CHECK_NULL_RETURN(in_tensors_.front());
CHECK_NULL_RETURN(out_tensors_.front());
auto input_shape = in_tensors_.front()->shape();
@ -171,7 +175,7 @@ int Conv2DOpenCLKernel::BuildKernel() {
auto build_options_ext = CreateBuildOptionsExtByDType(this->registry_data_type_);
std::string exceed_max_image_width_option =
(OW_ * CO_SLICES_ <= ocl_runtime_->GetMaxImage2DWidth()) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH";
(OW_ * CO_SLICES_ <= static_cast<int>(ocl_runtime_->GetMaxImage2DWidth())) ? "" : " -DEXCEDD_MAX_IMAGE2D_WIDTH";
build_options_ext.push_back(exceed_max_image_width_option);
auto ret = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name.str(), build_options_ext);
if (ret != RET_OK) {
@ -193,11 +197,15 @@ void Conv2DOpenCLKernel::SetBlockSize() {
KW_ == 1 && param_->stride_w_ == 1 && param_->dilation_w_ == 1 && param_->pad_l_ == 0 && param_->pad_r_ == 0;
bool h_kernel_is_1 =
KH_ == 1 && param_->stride_h_ == 1 && param_->dilation_h_ == 1 && param_->pad_u_ == 0 && param_->pad_d_ == 0;
#ifdef ENABLE_FP16
if (use_fp16_) {
SetMaliFp16BlockSize(task_size_per_cu, w_kernel_is_1, h_kernel_is_1);
} else {
SetMaliFp32BlockSize(task_size_per_cu, w_kernel_is_1, h_kernel_is_1);
}
#else
SetMaliFp32BlockSize(task_size_per_cu, w_kernel_is_1, h_kernel_is_1);
#endif
}
void Conv2DOpenCLKernel::SetMaliFp32BlockSize(int task_size_per_cu, bool w_kernel_is_1, bool h_kernel_is_1) {
@ -269,6 +277,7 @@ int Conv2DOpenCLKernel::InitWeights() {
return RET_OK;
}
#ifdef ENABLE_FP16
void ConvertFilter(void *src, void *dst, TypeId src_dtype, TypeId dst_dtype, FilterFormat src_format,
FilterFormat dst_format, size_t CO, size_t KH, size_t KW, size_t CI, size_t OGroup) {
MS_ASSERT(src);
@ -316,6 +325,45 @@ void ConvertFilter(void *src, void *dst, TypeId src_dtype, TypeId dst_dtype, Fil
}
}
}
#else
void ConvertFilter(void *src, void *dst, TypeId src_dtype, TypeId dst_dtype, FilterFormat src_format,
FilterFormat dst_format, size_t CO, size_t KH, size_t KW, size_t CI, size_t OGroup) {
MS_ASSERT(src);
MS_ASSERT(dst);
MS_ASSERT(src_format == OHWI);
MS_ASSERT(dst_format == HWII4OO4 || dst_format == OHWIOgroupI4O4);
auto src_fp32 = reinterpret_cast<float *>(src);
auto dst_fp32 = reinterpret_cast<float *>(dst);
auto CI_SLICES = UP_DIV(CI, CI_TILE);
auto CO_SLICES = UP_DIV(CO, CO_TILE);
for (size_t co = 0, src_idx = 0; co < CO; ++co) {
for (size_t kh = 0; kh < KH; ++kh) {
for (size_t kw = 0; kw < KW; ++kw) {
for (size_t ci = 0; ci < CI; ++ci, ++src_idx) {
size_t dst_idx = 0;
size_t co_inner = co % CO_TILE;
size_t ci_slice = ci / CI_TILE;
size_t ci_inner = ci % CI_TILE;
if (dst_format == OHWIOgroupI4O4) {
size_t co_slice = co / (CO_TILE * OGroup);
size_t group_idx = co % (CO_TILE * OGroup) / CO_TILE;
dst_idx =
(((((co_slice * KH + kh) * KW + kw) * CI_SLICES + ci_slice) * OGroup + group_idx) * CI_TILE + ci_inner) *
CO_TILE +
co_inner;
} else { // if(dst_format==HWII4OO4)
size_t co_slice = co / CO_TILE;
dst_idx =
((((kh * KW + kw) * CI_SLICES + ci_slice) * CI_TILE + ci_inner) * CO_SLICES + co_slice) * CO_TILE +
co_inner;
}
dst_fp32[dst_idx] = src_fp32[src_idx];
}
}
}
}
}
#endif
int Conv2DOpenCLKernel::InitFilter() {
auto allocator = ocl_runtime_->GetAllocator();
@ -395,6 +443,7 @@ int Conv2DOpenCLKernel::InitBias() {
void *src_data = stored_bias_ == nullptr ? bias_tensor->data() : stored_bias_;
MS_ASSERT(src_data);
#ifdef ENABLE_FP16
if (bias_tensor->data_type() == kNumberTypeFloat16) {
if (use_fp16_) {
memcpy(packed_bias_, src_data, CO_ * sizeof_FLT_);
@ -418,6 +467,9 @@ int Conv2DOpenCLKernel::InitBias() {
memcpy(packed_bias_, src_data, CO_ * sizeof_FLT_);
}
}
#else
memcpy(packed_bias_, src_data, CO_ * sizeof_FLT_);
#endif
}
if (allocator->UnmapBuffer(packed_bias_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";

View File

@ -188,7 +188,6 @@ int Conv2dTransposeOpenCLKernel::InitFilter() {
}
memset(padWeight_, 0x00, div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size);
auto origin_weight = stored_weight_ == nullptr ? in_tensors_.at(kWeightIndex)->data() : stored_weight_;
auto weight_dtype = in_tensors_.at(kWeightIndex)->data_type();
int index = 0;
for (int co_i = 0; co_i < div_co; co_i++) {
for (int kh_i = 0; kh_i < kh; kh_i++) {
@ -200,6 +199,8 @@ int Conv2dTransposeOpenCLKernel::InitFilter() {
int ci_offset = ci_i * C4NUM + ci4_i;
if (co_offset < co && ci_offset < ci) {
int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * co + co_offset;
#ifdef ENABLE_FP16
auto weight_dtype = in_tensors_.at(kWeightIndex)->data_type();
if (enable_fp16_) {
if (weight_dtype == kNumberTypeFloat32) {
reinterpret_cast<float16_t *>(padWeight_)[index++] =
@ -217,6 +218,9 @@ int Conv2dTransposeOpenCLKernel::InitFilter() {
reinterpret_cast<float16_t *>(origin_weight)[ori_index];
}
}
#else
reinterpret_cast<float *>(padWeight_)[index++] = reinterpret_cast<float *>(origin_weight)[ori_index];
#endif
} else {
index++;
}
@ -262,6 +266,7 @@ int Conv2dTransposeOpenCLKernel::InitBias() {
if (in_tensors_.size() == INPUT_TENSOR_SIZE_3) {
void *src_data = stored_bias_ == nullptr ? in_tensors_.at(kBiasIndex)->data() : stored_bias_;
MS_ASSERT(src_data);
#ifdef ENABLE_FP16
auto bias_dtype = in_tensors_[2]->data_type();
if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) {
for (int i = 0; i < co; i++) {
@ -274,6 +279,9 @@ int Conv2dTransposeOpenCLKernel::InitBias() {
} else {
memcpy(bias_, src_data, co * data_size);
}
#else
memcpy(bias_, src_data, co * data_size);
#endif
}
if (allocator->UnmapBuffer(bias_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";

View File

@ -102,6 +102,7 @@ int DepthwiseConv2dOpenCLKernel::Prepare() {
return RET_OK;
}
#ifdef ENABLE_FP16
int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
auto allocator = ocl_runtime_->GetAllocator();
@ -168,7 +169,53 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
FreeStoredData(stored_weight_);
return RET_OK;
}
#else
int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
auto allocator = ocl_runtime_->GetAllocator();
size_t dtype_size = sizeof(float);
auto out_info = GpuTensorInfo(out_tensors_[0]);
// weight: o, h, w, i; o == group, i == 1
void *origin_weight = stored_weight_ == nullptr ? in_tensors_.at(kWeightIndex)->data() : stored_weight_;
MS_ASSERT(origin_weight);
int CO4 = UP_DIV(out_info.C, C4NUM);
int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_;
int plane_in = parameter->kernel_h_ * parameter->kernel_w_;
int plane_out = plane_in * C4NUM;
if (filter_type_ == MemType::IMG) {
int alignment = ocl_runtime_->GetImagePitchAlignment();
plane_out = UP_ROUND(plane_out, alignment) * C4NUM;
pack_weight_size = plane_out * CO4;
}
pack_weight_size = pack_weight_size * dtype_size;
auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out,
size_t channel) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNCHWToNC4HW4<float, float>(src, dst, 1, plane_in, plane_out, channel, to_dtype);
};
std::vector<char> temp_filter(pack_weight_size);
auto src_type = in_tensors_.at(kWeightIndex)->data_type();
auto dst_type = kNumberTypeFloat32;
ConvertFilter(origin_weight, temp_filter.data(), src_type, dst_type, plane_in, plane_out, out_info.C);
if (filter_type_ == MemType::IMG) {
size_t img_dtype = CL_FLOAT;
ImageSize img_size{(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
packed_weight_ = allocator->Malloc(img_size, temp_filter.data());
} else {
packed_weight_ = allocator->Malloc(pack_weight_size, temp_filter.data());
}
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
FreeStoredData(stored_weight_);
return RET_OK;
}
#endif
#ifdef ENABLE_FP16
int DepthwiseConv2dOpenCLKernel::InitBias() {
auto allocator = ocl_runtime_->GetAllocator();
bool is_fp16 = ocl_runtime_->GetFp16Enable();
@ -213,6 +260,39 @@ int DepthwiseConv2dOpenCLKernel::InitBias() {
FreeStoredData(stored_bias_);
return RET_OK;
}
#else
int DepthwiseConv2dOpenCLKernel::InitBias() {
auto allocator = ocl_runtime_->GetAllocator();
size_t dtype_size = sizeof(float);
auto out_info = GpuTensorInfo(out_tensors_[0]);
int CO4 = UP_DIV(out_info.C, C4NUM);
auto src_type = in_tensors_.at(kWeightIndex)->data_type();
auto dst_type = kNumberTypeFloat32;
auto ConvertBias = [](void *src, void *dst, size_t size, size_t dtype_size, TypeId src_type, TypeId dst_type) {
memcpy(dst, src, size * dtype_size);
};
size_t bias_size = C4NUM * CO4 * dtype_size;
std::vector<char> temp_bias(bias_size, 0);
if (in_tensors_.size() == INPUT_TENSOR_SIZE_3) {
src_type = in_tensors_.at(kBiasIndex)->data_type();
dst_type = kNumberTypeFloat32;
auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum();
void *src_data = stored_bias_ == nullptr ? in_tensors_.at(kBiasIndex)->data() : stored_bias_;
MS_ASSERT(src_data);
ConvertBias(src_data, temp_bias.data(), element_size, dtype_size, src_type, dst_type);
}
bias_data_ = allocator->Malloc(bias_size, temp_bias.data());
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
FreeStoredData(stored_bias_);
return RET_OK;
}
#endif
int DepthwiseConv2dOpenCLKernel::SetConstArgs() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
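The FP32 weight and bias fallbacks above size their buffers with UP_DIV and UP_ROUND. A quick check of that arithmetic under the usual round-up definitions and C4NUM == 4 (the helper names below are stand-ins for the project macros):

#include <cassert>
#include <cstddef>

constexpr size_t UpDiv(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t UpRound(size_t a, size_t b) { return UpDiv(a, b) * b; }

int main() {
  static_assert(UpDiv(10, 4) == 3, "CO4 = UP_DIV(C, C4NUM) for 10 channels");
  static_assert(UpRound(10, 4) == 12, "channels padded up to a multiple of 4");
  // FP32 depthwise bias buffer from the fallback above: C4NUM * CO4 * sizeof(float) bytes.
  assert(4 * UpDiv(10, 4) * sizeof(float) == 48);
  return 0;
}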

View File

@ -61,7 +61,7 @@ int FillOpenCLKernel::RunShape() {
auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data();
CHECK_NULL_RETURN(tensor_shape_data);
for (int i = 0; i < tensor_shape.size(); ++i) {
for (size_t i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[i] = reinterpret_cast<int *>(tensor_shape_data)[i];
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};

View File

@ -73,9 +73,8 @@ int FullConnectionOpenCLKernel::CheckSpecs() {
MS_LOG(WARNING) << "If fullconnection input weight is not constant, it should be 2d.";
return RET_ERROR;
}
if (intensor_shape.C != in_tensors_.at(kWeightIndex)->shape()[1]) {
MS_LOG(WARNING)
<< "If fullconnection input weight is not constant, input channel should equal to weight in_channel.";
if (static_cast<int>(intensor_shape.C) != in_tensors_.at(kWeightIndex)->shape()[1]) {
MS_LOG(WARNING) << "input weight is not constant, input channel should equal to weight in_channel.";
return RET_ERROR;
}
}
@ -132,6 +131,7 @@ int FullConnectionOpenCLKernel::InitWeights() {
return InitBias();
} // namespace mindspore::kernel
#ifdef ENABLE_FP16
int FullConnectionOpenCLKernel::InitFilter() {
auto allocator = ocl_runtime_->GetAllocator();
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
@ -249,6 +249,97 @@ int FullConnectionOpenCLKernel::InitBias() {
FreeStoredData(stored_bias_);
return RET_OK;
}
#else
int FullConnectionOpenCLKernel::InitFilter() {
auto allocator = ocl_runtime_->GetAllocator();
auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
int co4 = UP_DIV(CO_, C4NUM);
int nhw_remainder = intensor_shape.N * intensor_shape.H * intensor_shape.W / N_;
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size,
lite::opencl::MemType::BUF);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
auto padWeight = reinterpret_cast<float *>(padWeight_);
memset(padWeight_, 0x00, nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
void *src_data = stored_weight_ == nullptr ? in_tensors_.at(kWeightIndex)->data() : stored_weight_;
MS_ASSERT(src_data);
auto originWeight = reinterpret_cast<float *>(src_data);
// pad weight
// HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)
// if transposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI)
int index = 0;
for (int nhw = 0; nhw < nhw_remainder; nhw++) {
for (size_t i = 0; i < intensor_shape.Slice; ++i) {
for (int j = 0; j < co4; ++j) {
for (int k = 0; k < C4NUM; ++k) {
for (int l = 0; l < C4NUM; ++l) {
size_t src_ci = i * C4NUM + l;
size_t src_co = j * C4NUM + k;
if (src_ci < intensor_shape.C && static_cast<int>(src_co) < CO_) {
int originId = (nhw * intensor_shape.C + src_ci) * CO_ + src_co;
if (transposeB) {
originId = src_co * intensor_shape.C * nhw_remainder + nhw * intensor_shape.C + src_ci;
}
padWeight[index++] = originWeight[originId];
} else {
index++;
}
}
}
}
}
}
if (allocator->UnmapBuffer(padWeight_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
FreeStoredData(stored_weight_);
return RET_OK;
}
int FullConnectionOpenCLKernel::InitBias() {
// pad FC Bias
auto allocator = ocl_runtime_->GetAllocator();
int co4 = UP_DIV(CO_, C4NUM);
size_t dtype_size = sizeof(float);
size_t im_dst_x, im_dst_y;
im_dst_x = co4;
im_dst_y = 1;
size_t img_dtype = CL_FLOAT;
ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(bias_, 0x00, co4 * C4NUM * dtype_size);
if (in_tensors_.size() == INPUT_TENSOR_SIZE_3) {
void *src_data = stored_bias_ == nullptr ? in_tensors_.at(kBiasIndex)->data() : stored_bias_;
MS_ASSERT(src_data);
memcpy(bias_, src_data, CO_ * dtype_size);
}
if (allocator->UnmapBuffer(bias_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
FreeStoredData(stored_bias_);
return RET_OK;
}
#endif
void FullConnectionOpenCLKernel::SetGlobalLocal() {
local_size_ = {32, 4, 1};

View File

@ -145,9 +145,9 @@ bool IsEltwiseAndOperatorSupported(LiteKernel *node) {
MS_ASSERT(in_tensor);
auto shape = in_tensor->shape();
bool is_scalar = shape.empty() || (shape.size() == DIMENSION_1D && shape.front() == 1);
bool is_vector = shape.size() == DIMENSION_1D && shape.front() == output_info.C;
bool _111C =
shape.size() == DIMENSION_4D && shape[0] == 1 && shape[1] == 1 && shape[2] == 1 && shape[3] == output_info.C;
bool is_vector = shape.size() == DIMENSION_1D && shape.front() == static_cast<int>(output_info.C);
bool _111C = shape.size() == DIMENSION_4D && shape[0] == 1 && shape[1] == 1 && shape[2] == 1 &&
shape[3] == static_cast<int>(output_info.C);
bool same_with_out = shape == output_shape;
if (!(is_scalar || is_vector || _111C || same_with_out)) {
return false;
@ -209,6 +209,7 @@ void CopyNumber(void *dst, void *src, size_t n) {
}
}
#ifdef ENABLE_FP16
int FusionEltwiseOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
bool use_fp16 = ocl_runtime_->GetFp16Enable();
@ -256,6 +257,41 @@ int FusionEltwiseOpenCLKernel::InitWeights() {
}
return RET_OK;
}
#else
int FusionEltwiseOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
for (auto *tensor : in_tensors_) {
MS_ASSERT(tensor);
if (tensor->IsConst()) {
if (IsScalar(tensor->shape())) {
float value = *reinterpret_cast<float *>(tensor->data());
scalar_weights_.push_back(value);
} else {
auto tensor_info = GpuTensorInfo(tensor);
size_t num = tensor_info.ElementsNum;
size_t size = tensor_info.Image2DSize;
void *buffer = allocator->Malloc(size, lite::opencl::MemType::BUF);
if (buffer == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
if (allocator->MapBuffer(buffer, CL_MAP_WRITE, nullptr, true) == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(buffer, 0x00, size);
CopyNumber<float, float>(buffer, tensor->data(), num);
if (allocator->UnmapBuffer(buffer) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
buffer_weights_.push_back(buffer);
}
}
}
return RET_OK;
}
#endif
void FusionEltwiseOpenCLKernel::SetGlobalLocal() {
auto output = GpuTensorInfo(out_tensors_.front());
@ -275,6 +311,7 @@ int FusionEltwiseOpenCLKernel::SetConstArgs() {
MS_ASSERT(in_tensor);
if (in_tensor->IsConst()) {
if (IsScalar(in_tensor->shape())) {
#ifdef ENABLE_FP16
if (ocl_runtime_->GetFp16Enable()) {
auto value = static_cast<float16_t>(scalar_weights_[scalar_idx++]);
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, *(reinterpret_cast<cl_half *>(&value))) != CL_SUCCESS) {
@ -287,6 +324,12 @@ int FusionEltwiseOpenCLKernel::SetConstArgs() {
return RET_ERROR;
}
}
#else
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, scalar_weights_[scalar_idx++]) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
#endif
} else {
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx, buffer_weights_[buffer_idx++], true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
@ -333,7 +376,7 @@ std::string FusionEltwiseOpenCLKernel::Codegen() {
"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"__kernel void FusionEltwise(";
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
MS_ASSERT(in_tensors_[i]);
if (in_tensors_[i]->IsConst()) {
if (IsScalar(in_tensors_[i]->shape())) {
@ -359,14 +402,14 @@ std::string FusionEltwiseOpenCLKernel::Codegen() {
" }\n";
auto output = GpuTensorInfo(out_tensors_.front());
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
auto *tensor = in_tensors_[i];
MS_ASSERT(tensor);
auto shape = in_tensors_[i]->shape();
bool is_scalar = IsScalar(shape);
bool is_vector = shape.size() == DIMENSION_1D && shape.front() == output.C;
bool _111C =
shape.size() == DIMENSION_4D && shape[0] == 1 && shape[1] == 1 && shape[2] == 1 && shape[3] == output.C;
bool is_vector = shape.size() == DIMENSION_1D && shape.front() == static_cast<int>(output.C);
bool _111C = shape.size() == DIMENSION_4D && shape[0] == 1 && shape[1] == 1 && shape[2] == 1 &&
shape[3] == static_cast<int>(output.C);
if (tensor->IsConst()) {
if (!is_scalar) {
code << " FLT4 in" << i << " = input" << i << "[";

View File

@ -114,7 +114,7 @@ struct FusionEltwiseParameter {
const std::vector<lite::Tensor *> &in_tensors,
const std::map<lite::Tensor *, FusionEltwiseParameter *> &replace_map = {})
: operator_(operator_init), name_(std::move(kernel_name)) {
for (int i = 0; i < in_tensors.size(); ++i) {
for (size_t i = 0; i < in_tensors.size(); ++i) {
auto *in_tensor = in_tensors[i];
if (replace_map.count(in_tensor)) {
auto *pred_param = replace_map.at(in_tensor);

View File

@ -193,6 +193,7 @@ int GatherOpenCLKernel::ConvertTensorToweight() {
return RET_OK;
}
#ifdef ENABLE_FP16
int GatherOpenCLKernel::InitWeights() {
auto indices_tensor = in_tensors_.at(1);
auto indices_num = indices_tensor->ElementsNum();
@ -226,6 +227,37 @@ int GatherOpenCLKernel::InitWeights() {
}
return RET_OK;
}
#else
int GatherOpenCLKernel::InitWeights() {
auto indices_tensor = in_tensors_.at(1);
auto indices_num = indices_tensor->ElementsNum();
auto allocator = ocl_runtime_->GetAllocator();
indices_data_ =
reinterpret_cast<int32_t *>(allocator->Malloc(sizeof(int32_t) * indices_num, lite::opencl::MemType::BUF));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;
}
auto data_type = indices_tensor->data_type();
auto data = indices_tensor->data();
MS_ASSERT(data);
if (data_type == kNumberTypeInt32) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<int32_t *>(data)[i];
}
} else if (data_type == kNumberTypeInt64) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<int64_t *>(data)[i];
}
} else if (data_type == kNumberTypeFloat32) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<float *>(data)[i];
}
}
return RET_OK;
}
#endif
int GatherOpenCLKernel::PreProcess() {
if (!InferShapeDone()) {

View File

@ -122,6 +122,7 @@ void LayerNormOpenCLKernel::SetGlobalLocal() {
&local_mean_var_);
}
#ifdef ENABLE_FP16
int LayerNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
CHECK_NULL_RETURN(allocator);
@ -195,8 +196,55 @@ int LayerNormOpenCLKernel::Initweight() {
}
return RET_OK;
}
#else
int LayerNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
CHECK_NULL_RETURN(allocator);
GpuTensorInfo img_info(in_tensors_.at(1));
auto weight_tensor = in_tensors_.at(1);
CHECK_NULL_RETURN(weight_tensor);
size_t weight_size = img_info.Image2DSize;
// allocated memory for weight and init value
gamma_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (gamma_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
beta_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (beta_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
if (allocator->MapBuffer(gamma_, CL_MAP_WRITE, nullptr, true) == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
if (allocator->MapBuffer(beta_, CL_MAP_WRITE, nullptr, true) == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(gamma_, 0x01, weight_size);
memset(beta_, 0x00, weight_size);
CHECK_NULL_RETURN(in_tensors_.at(1)->data());
CHECK_NULL_RETURN(in_tensors_.at(2));
CHECK_NULL_RETURN(in_tensors_.at(2)->data());
memcpy(gamma_, in_tensors_.at(1)->data(), weight_size);
memcpy(beta_, in_tensors_.at(2)->data(), weight_size);
if (allocator->UnmapBuffer(gamma_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
if (allocator->UnmapBuffer(beta_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
return RET_OK;
}
#endif
int LayerNormOpenCLKernel::Prepare() {
#ifdef ENABLE_FP16
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
int ret = Initweight();
if (ret != RET_OK) {
@ -210,6 +258,20 @@ int LayerNormOpenCLKernel::Prepare() {
mean_size *= in_tensors_.at(0)->shape()[i];
}
size_t size_dtype = use_fp16_enable_ ? sizeof(float16_t) : sizeof(float);
#else
int ret = Initweight();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Initweight failed ";
return ret;
}
normalized_shape_size_ = in_tensors_.at(0)->shape().at(normalized_axis_);
auto allocator = ocl_runtime_->GetAllocator();
size_t mean_size = 1;
for (int i = 0; i < normalized_axis_; ++i) {
mean_size *= in_tensors_.at(0)->shape()[i];
}
size_t size_dtype = sizeof(float);
#endif
mean_size *= size_dtype;
mean_ = allocator->Malloc(mean_size, lite::opencl::MemType::BUF);
if (mean_ == nullptr) {

View File

@ -104,6 +104,7 @@ int MatMulOpenCLKernel::Prepare() {
return RET_OK;
}
#ifdef ENABLE_FP16
int MatMulOpenCLKernel::PadWeight(std::vector<int> weight_shape_4d, int ci, int co) {
auto allocator = ocl_runtime_->GetAllocator();
int a = weight_shape_4d[0];
@ -170,6 +171,59 @@ int MatMulOpenCLKernel::PadWeight(std::vector<int> weight_shape_4d, int ci, int
}
return RET_OK;
}
#else
int MatMulOpenCLKernel::PadWeight(std::vector<int> weight_shape_4d, int ci, int co) {
auto allocator = ocl_runtime_->GetAllocator();
int a = weight_shape_4d[0];
int b = weight_shape_4d[1];
int ci4 = UP_DIV(ci, C4NUM);
int co4 = UP_DIV(co, C4NUM);
size_t dtype_size = sizeof(float);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size, lite::opencl::MemType::BUF);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
auto padWeight = reinterpret_cast<float *>(padWeight_);
memset(padWeight_, 0x00, a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
void *src_data = stored_weight_ == nullptr ? in_tensors_.at(kWeightIndex)->data() : stored_weight_;
auto originWeight = reinterpret_cast<float *>(src_data);
// pad weight
// ABCICO -> AB(CI4)(CO4)(4 from CO)(4 from CI)
// if transposeB, ABCOCI -> AB(CI4)(CO4)(4 from CO)(4 from CI)
int index = 0;
for (int aa = 0; aa < a; aa++) {
for (int bb = 0; bb < b; bb++) {
int baseAB = (aa * b + bb) * ci * CO_;
for (int i = 0; i < ci4; ++i) {
for (int j = 0; j < co4; ++j) {
for (int k = 0; k < C4NUM; ++k) {
for (int l = 0; l < C4NUM; ++l) {
int src_ci = i * C4NUM + l;
int src_co = j * C4NUM + k;
if (src_ci < ci && src_co < CO_) {
int originId = baseAB + src_ci * CO_ + src_co;
if (transposeB) {
originId = baseAB + src_co * ci + src_ci;
}
padWeight[index++] = originWeight[originId];
} else {
index++;
}
}
}
}
}
}
}
return RET_OK;
}
#endif
int MatMulOpenCLKernel::InitWeights() {
if (!in_tensors_[1]->IsConst()) {
@ -205,6 +259,7 @@ int MatMulOpenCLKernel::InitWeights() {
return InitBias();
}
#ifdef ENABLE_FP16
int MatMulOpenCLKernel::InitBias() {
// pad FC Bias
auto allocator = ocl_runtime_->GetAllocator();
@ -250,6 +305,40 @@ int MatMulOpenCLKernel::InitBias() {
FreeStoredData(stored_bias_);
return RET_OK;
}
#else
int MatMulOpenCLKernel::InitBias() {
// pad FC Bias
auto allocator = ocl_runtime_->GetAllocator();
int co4 = UP_DIV(CO_, C4NUM);
size_t dtype_size = sizeof(float);
size_t im_dst_x, im_dst_y;
im_dst_x = co4;
im_dst_y = 1;
size_t img_dtype = CL_FLOAT;
lite::opencl::ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
if (bias_ == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(bias_, 0x00, co4 * C4NUM * dtype_size);
if (in_tensors_.size() == INPUT_TENSOR_SIZE_3) {
void *src_data = stored_bias_ == nullptr ? in_tensors_.at(kBiasIndex)->data() : stored_bias_;
memcpy(bias_, src_data, CO_ * dtype_size);
}
if (allocator->UnmapBuffer(bias_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
FreeStoredData(stored_bias_);
return RET_OK;
}
#endif
void MatMulOpenCLKernel::SetGlobalLocal() {
// local size should less than MAX_GROUP_SIZE

View File

@ -61,11 +61,11 @@ int PadOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
// Compatibility code
if (param->padding_length == DIMENSION_2D * in_ndim) {
if (param->padding_length == static_cast<int>(DIMENSION_2D * in_ndim)) {
return RET_OK;
}
auto pad_shape = in_tensors_.at(1)->shape();
if (pad_shape.size() != DIMENSION_2D || pad_shape[0] != in_ndim || pad_shape[1] != DIMENSION_2D) {
if (pad_shape.size() != DIMENSION_2D || pad_shape[0] != static_cast<int>(in_ndim) || pad_shape[1] != DIMENSION_2D) {
MS_LOG(WARNING) << "pad tensor shape invalid.";
return RET_ERROR;
}
@ -105,7 +105,7 @@ int PadOpenCLKernel::SetConstArgs() {
std::vector<int> pad_before_ori;
pad_before_ori.reserve(ndim);
auto paddings = reinterpret_cast<int32_t *>(in_tensors_.at(1)->data());
for (size_t i = 0; i < ndim; i++) {
for (auto i = 0; i < ndim; i++) {
pad_before_ori.push_back(paddings[2 * i]);
}
cl_int4 pad_before;

View File

@ -61,8 +61,8 @@ int PoolingOpenCLKernel::BuildKernel() {
kernel_name = "AvgPooling2d";
}
if (parameter_->global_ &&
(parameter_->window_h_ >= LOCAL_CACHE_THREAD || parameter_->window_w_ >= LOCAL_CACHE_THREAD)) {
if (parameter_->global_ && (parameter_->window_h_ >= static_cast<int>(LOCAL_CACHE_THREAD) ||
parameter_->window_w_ >= static_cast<int>(LOCAL_CACHE_THREAD))) {
kernel_name += "_global";
is_use_local_ = true;
}

View File

@ -79,6 +79,7 @@ int PowerOpenCLKernel::SetConstArgs() {
return RET_ERROR;
}
}
#ifdef ENABLE_FP16
if (use_fp16_enable_) {
auto x = static_cast<float16_t>(power_);
auto y = static_cast<float16_t>(shift_);
@ -97,12 +98,19 @@ int PowerOpenCLKernel::SetConstArgs() {
return RET_ERROR;
}
}
#else
cl_float4 parameter = {power_, shift_, scale_, unalign_w};
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, parameter) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
}
#endif
return RET_OK;
}
void PowerOpenCLKernel::SetGlobalLocal() {
cl_int4 output_shape = {};
for (int i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
for (size_t i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
output_shape.s[i] = out_tensors_.at(0)->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1);

View File

@ -32,6 +32,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_PReLUFusion;
namespace mindspore::kernel {
#ifdef ENABLE_FP16
int PReluOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto weight_tensor = in_tensors_.at(1);
@ -84,6 +85,36 @@ int PReluOpenCLKernel::InitWeights() {
}
return RET_OK;
}
#else
int PReluOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto weight_tensor = in_tensors_.at(1);
if (weight_is_scalar) {
weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data());
MS_ASSERT(weight_scalar_);
} else {
int C_ = weight_tensor->ElementsNum();
auto sizeof_FLT = sizeof(float);
size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (weight_vector_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
if (allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true) == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(weight_vector_, 0x00, weight_size);
memcpy(weight_vector_, weight_tensor->data(), C_ * sizeof_FLT);
if (allocator->UnmapBuffer(weight_vector_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
int PReluOpenCLKernel::CheckSpecs() {
if (in_tensors_.size() != INPUT_TENSOR_SIZE_2 || out_tensors_.size() != OUTPUT_TENSOR_SIZE_1) {
@ -132,10 +163,10 @@ void PReluOpenCLKernel::SetGlobalLocal() {
int PReluOpenCLKernel::Prepare() {
cl_int4 output_shape = {};
cl_int4 weight_shape = {};
for (int i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
for (size_t i = 0; i < out_tensors_.at(0)->shape().size(); ++i) {
output_shape.s[i] = out_tensors_.at(0)->shape()[i];
}
for (int i = 0; i < in_tensors_.at(1)->shape().size(); ++i) {
for (size_t i = 0; i < in_tensors_.at(1)->shape().size(); ++i) {
weight_shape.s[i] = in_tensors_.at(1)->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1);

View File

@ -235,6 +235,7 @@ int ScaleOpenCLKernel::SetKernelArg(int *idx) {
return RET_ERROR;
}
} else {
#ifdef ENABLE_FP16
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
float scale = static_cast<float *>(in_tensors_[1]->data())[0];
float offset = static_cast<float *>(in_tensors_[2]->data())[0];
@ -257,6 +258,21 @@ int ScaleOpenCLKernel::SetKernelArg(int *idx) {
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[1]->data_type();
return RET_ERROR;
}
#else
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
float scale = static_cast<float *>(in_tensors_[1]->data())[0];
float offset = static_cast<float *>(in_tensors_[2]->data())[0];
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, scale) != CL_SUCCESS) {
return RET_ERROR;
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, offset) != CL_SUCCESS) {
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[1]->data_type();
return RET_ERROR;
}
#endif
}
if (ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data()) != CL_SUCCESS) {
return RET_ERROR;

View File

@ -53,6 +53,7 @@ int SparseToDenseOpenCLKernel::InitOutputToDefault() {
return RET_OK;
}
#ifdef ENABLE_FP16
int SparseToDenseOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
MS_CHECK_GE(in_tensors_.size(), DIMENSION_3D, RET_ERROR);
@ -109,6 +110,40 @@ int SparseToDenseOpenCLKernel::InitWeights() {
}
return RET_OK;
}
#else
int SparseToDenseOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
MS_CHECK_GE(in_tensors_.size(), DIMENSION_3D, RET_ERROR);
auto weight_tensor = in_tensors_[2];
size_t size = 1;
for (size_t i = 0; i < weight_tensor->shape().size(); ++i) {
size *= weight_tensor->shape()[i];
}
MS_ASSERT(weight_tensor->data());
if (weight_scalar_) {
weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data());
} else {
auto sizeof_FLT = sizeof(float);
size_t weight_size = UP_ROUND(size, C4NUM) * sizeof_FLT;
weight_vector_ = allocator->Malloc(weight_size, lite::opencl::MemType::BUF);
if (weight_vector_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
if (allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true) == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
memset(weight_vector_, 0x00, weight_size);
memcpy(weight_vector_, weight_tensor->data(), size * sizeof_FLT);
if (allocator->UnmapBuffer(weight_vector_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
int SparseToDenseOpenCLKernel::CheckSpecs() {
if (in_tensors_.size() < DIMENSION_3D || out_tensors_.at(0)->shape().size() > DIMENSION_4D) {
@ -200,11 +235,15 @@ int SparseToDenseOpenCLKernel::Prepare() {
if (in_tensors_.size() > INPUT_TENSOR_SIZE_3) {
auto input_tensor3 = in_tensors_[3];
#ifdef ENABLE_FP16
if (input_tensor3->data_type() == kNumberTypeFloat16) {
default_ = static_cast<float>(*reinterpret_cast<float16_t *>(input_tensor3->data()));
} else {
default_ = *reinterpret_cast<float *>(input_tensor3->data());
}
#else
default_ = *reinterpret_cast<float *>(input_tensor3->data());
#endif
MS_ASSERT(default_);
}
ret = InitWeights();

View File

@ -39,7 +39,7 @@ int SplitOpenCLKernel::RunAxis0() {
return RET_ERROR;
}
auto src_area = cl::array<cl::size_type, 3U>{0, 0, 0};
for (int i = 0; i < out_tensors_.size(); i++) {
for (size_t i = 0; i < out_tensors_.size(); i++) {
auto dst_data = out_tensors_[i]->data();
CHECK_NULL_RETURN(dst_data);
ImageSize img_size;
@ -104,7 +104,7 @@ int SplitOpenCLKernel::CheckSpecs() {
int SplitOpenCLKernel::AlignSplitSizes(SplitParameter *param, const std::vector<int> &in_shape) {
auto allocator = ocl_runtime_->GetAllocator();
CHECK_LESS_RETURN(in_shape.size(), param->split_dim_ + 1);
CHECK_LESS_RETURN(static_cast<int>(in_shape.size()), param->split_dim_ + 1);
int shape_dim = in_shape.at(param->split_dim_);
if (num_split_ == 1) {
CHECK_LESS_RETURN(param->split_sizes_[0], 1);
@ -114,7 +114,7 @@ int SplitOpenCLKernel::AlignSplitSizes(SplitParameter *param, const std::vector<
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
for (int i = 0; i < num_split - 1; ++i) {
for (size_t i = 0; i < num_split - 1; ++i) {
split_sizes_[i] = (i + 1) * param->split_sizes_[0];
}
} else {
@ -124,7 +124,7 @@ int SplitOpenCLKernel::AlignSplitSizes(SplitParameter *param, const std::vector<
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
for (int i = 0; i < num_split_ - 1; ++i) {
for (size_t i = 0; i < num_split_ - 1; ++i) {
sum += param->split_sizes_[i];
split_sizes_[i] = sum;
}
@ -142,7 +142,7 @@ int SplitOpenCLKernel::Prepare() {
if (split_dim_ == 0) {
return RET_OK;
}
for (int i = 0; i < out_tensors_.size(); ++i) {
for (size_t i = 0; i < out_tensors_.size(); ++i) {
int length = out_tensors_[0]->shape().size();
if (split_dim_ == 3) {
if (out_tensors_[i]->shape()[length - 1] % C4NUM != 0) {
@ -186,7 +186,7 @@ int SplitOpenCLKernel::Prepare() {
int SplitOpenCLKernel::SetConstArgs() {
int arg_cn = out_tensors_.size() + 2;
cl_int4 shape = {};
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
for (size_t i = 0; i < in_tensors_[0]->shape().size(); ++i) {
shape.s[i] = in_tensors_[0]->shape()[i];
}
Broadcast2GpuShape(in_shape_.s, shape.s, out_tensors_[0]->shape().size(), 1);
@ -198,9 +198,9 @@ int SplitOpenCLKernel::SetConstArgs() {
return RET_ERROR;
}
for (int i = 0; i < out_tensors_.size(); ++i) {
for (size_t i = 0; i < out_tensors_.size(); ++i) {
cl_int4 temp = {};
for (int j = 0; j < out_tensors_[i]->shape().size(); ++j) {
for (size_t j = 0; j < out_tensors_[i]->shape().size(); ++j) {
temp.s[j] = out_tensors_[i]->shape()[j];
}
Broadcast2GpuShape(out_shape_.s, temp.s, out_tensors_[i]->shape().size(), 1);
@ -257,7 +257,7 @@ int SplitOpenCLKernel::Run() {
return RET_ERROR;
}
}
for (int i = 0; i < out_tensors_.size(); ++i) {
for (size_t i = 0; i < out_tensors_.size(); ++i) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(i)->data()) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;

View File

@ -35,7 +35,7 @@ int StackOpenCLKernel::RunAxis0() {
MS_ASSERT(dst_data);
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = allocator_->GetImage(dst_data);
for (int i = 0; i < in_tensors_.size(); i++) {
for (size_t i = 0; i < in_tensors_.size(); i++) {
auto src_data = in_tensors_[i]->data();
MS_ASSERT(src_data);
if (allocator_->GetImageSize(src_data, &img_size) != RET_OK) {
@ -96,7 +96,7 @@ int StackOpenCLKernel::CheckSpecs() {
MS_LOG(WARNING) << " only support axis <= 3 ";
return RET_ERROR;
}
if (axis_ > in_tensors_[0]->shape().size()) {
if (axis_ > static_cast<int>(in_tensors_[0]->shape().size())) {
MS_LOG(WARNING) << " stack axis must been <= in_tensors_[0]->shape().size() ";
return RET_ERROR;
}
@ -106,11 +106,11 @@ int StackOpenCLKernel::CheckSpecs() {
int StackOpenCLKernel::SetConstArgs() {
int arg_cn = in_tensors_.size() + 1;
cl_int4 inshape_tmp = {}, outshape_tmp = {};
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
for (size_t i = 0; i < in_tensors_[0]->shape().size(); ++i) {
inshape_tmp.s[i] = in_tensors_[0]->shape()[i];
}
Broadcast2GpuShape(in_shape_.s, inshape_tmp.s, in_tensors_[0]->shape().size(), 1);
for (int i = 0; i < out_tensors_[0]->shape().size(); ++i) {
for (size_t i = 0; i < out_tensors_[0]->shape().size(); ++i) {
outshape_tmp.s[i] = out_tensors_[0]->shape()[i];
}
Broadcast2GpuShape(out_shape_.s, outshape_tmp.s, out_tensors_[0]->shape().size(), 1);
@ -167,7 +167,7 @@ int StackOpenCLKernel::Prepare() {
}
if (in_tensors_[0]->shape().size() == DIMENSION_1D && axis_ == 1) {
axis_ += 2;
} else if (in_tensors_[0]->shape().size() == axis_) {
} else if (static_cast<int>(in_tensors_[0]->shape().size()) == axis_) {
buffer_button_ = true; // boundary stack judge
}
std::string kernel_name = "stack_";
@ -208,7 +208,7 @@ int StackOpenCLKernel::Run() {
}
int arg_cn = 0;
if (buffer_button_) {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data(), true) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;
@ -219,7 +219,7 @@ int StackOpenCLKernel::Run() {
return RET_ERROR;
}
} else {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (size_t i = 0; i < in_tensors_.size(); ++i) {
if (ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[i]->data()) != CL_SUCCESS) {
MS_LOG(ERROR) << "SetKernelArg failed.";
return RET_ERROR;

View File

@ -115,6 +115,7 @@ int StrassenOpenCLKernel::AllocatorMemoryForStrassen(int NumA, int NumB) {
return RET_OK;
}
#ifdef ENABLE_FP16
int StrassenOpenCLKernel::InitWeights() {
// ABMCI @ ABCICO = ABMCO
auto allocator = ocl_runtime_->GetAllocator();
@ -167,6 +168,41 @@ int StrassenOpenCLKernel::InitWeights() {
}
return RET_OK;
}
#else
int StrassenOpenCLKernel::InitWeights() {
// ABMCI @ ABCICO = ABMCO
auto allocator = ocl_runtime_->GetAllocator();
int NumA = in_tensors_[0]->shape()[0];
int NumB = in_tensors_[1]->shape()[0];
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(NumA * NumB * dtype_size, lite::opencl::MemType::BUF);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Malloc failed.";
return RET_ERROR;
}
padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
if (padWeight_ == nullptr) {
MS_LOG(ERROR) << "Map Buffer failed.";
return RET_ERROR;
}
auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
memset(padWeight_, 0x00, NumA * NumB * dtype_size);
auto weight_tensor_data = in_tensors_.at(kWeightIndex)->data();
MS_ASSERT(weight_tensor_data);
auto originWeightFp32 = reinterpret_cast<float *>(weight_tensor_data);
if (AllocatorMemoryForStrassen(NumA / 2, NumB / 2) != RET_OK) {
MS_LOG(ERROR) << "AllocatorMemoryForStrassen failed.";
return RET_ERROR;
}
size_t size = NumA * NumB * dtype_size;
memcpy(padWeightFp32, originWeightFp32, size);
if (allocator->UnmapBuffer(padWeight_) != RET_OK) {
MS_LOG(ERROR) << "UnmapBuffer failed.";
return RET_ERROR;
}
return RET_OK;
}
#endif
void AlignStrassenGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local,
cl::NDRange *global_range, cl::NDRange *local_range) {

View File

@ -58,7 +58,7 @@ int TransposeOpenCLKernel::Prepare() {
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = tensor_size_.AlignAxis(perm[1]);
if (num_axes != tensor_size_.NDim) {
if (num_axes != static_cast<int>(tensor_size_.NDim)) {
perm_4d_[0] = 0;
perm_4d_[1] = 1;
perm_4d_[2] = 2;
@ -93,8 +93,9 @@ int TransposeOpenCLKernel::Prepare() {
kernel_name += "_general";
}
if (in_tensors_[0]->shape().size() == DIMENSION_4D &&
in_tensors_[0]->shape()[2] * UP_DIV(in_tensors_[0]->shape()[3], C4NUM) > ocl_runtime_->GetMaxImage2DWidth()) {
if (in_tensors_[0]->shape().size() == static_cast<int>(DIMENSION_4D) &&
in_tensors_[0]->shape()[2] * UP_DIV(in_tensors_[0]->shape()[3], C4NUM) >
static_cast<int>(ocl_runtime_->GetMaxImage2DWidth())) {
// just for input
kernel_name += "_oversize";
}

View File

@ -41,6 +41,7 @@ constexpr float G[] = {1.0000000000, 0.0000000000, 0.0000000000, 1.0000000000,
1.0000000000, -0.7071067691, 0.4999999702, 1.0000000000, 1.4142135382, 1.9999998808,
1.0000000000, -1.4142135382, 1.9999998808, 0.0000000000, 0.0000000000, 1.0000000000};
std::vector<float> GenerateWinogradFilter(void *src, TypeId dtype, size_t CO, size_t CI) {
#ifdef ENABLE_FP16
auto src_fp32 = reinterpret_cast<float *>(src);
auto src_fp16 = reinterpret_cast<float16_t *>(src);
std::function<float(int)> access_func;
@ -49,13 +50,18 @@ std::vector<float> GenerateWinogradFilter(void *src, TypeId dtype, size_t CO, si
} else {
access_func = [=](int idx) { return static_cast<float>(src_fp16[idx]); };
}
#else
auto src_fp32 = reinterpret_cast<float *>(src);
std::function<float(int)> access_func;
access_func = [=](int idx) { return src_fp32[idx]; };
#endif
// OHWI -> O66I
std::vector<float> dst(CO * 6 * 6 * CI);
if (src == nullptr) {
return dst;
}
for (int co = 0; co < CO; ++co) {
for (int ci = 0; ci < CI; ++ci) {
for (size_t co = 0; co < CO; ++co) {
for (size_t ci = 0; ci < CI; ++ci) {
float in_vals[9];
for (int kh = 0; kh < 3; ++kh) {
for (int kw = 0; kw < 3; ++kw) {

View File

@ -145,7 +145,7 @@ void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) {
// update b in_tensors: b.in_tensors.replace(a.out_tensors[0], a.in_tensors)
auto b_in_tensors = b->in_tensors();
for (int i = 0; i < b_in_tensors.size(); ++i) {
for (size_t i = 0; i < b_in_tensors.size(); ++i) {
if (b_in_tensors[i] == a->out_tensors().front()) {
// reshape: 2nd input tensor is removed
if (a->type() == schema::PrimitiveType_Reshape) {
@ -162,7 +162,7 @@ void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) {
// update b in_kernels: b.in_kernels.replace(a, a.in_kernels)
auto b_in_kernels = b->in_kernels();
for (int i = 0; i < b_in_kernels.size(); ++i) {
for (size_t i = 0; i < b_in_kernels.size(); ++i) {
if (a == b_in_kernels[i]) {
b_in_kernels.erase(b_in_kernels.begin() + i);
b_in_kernels.insert(b_in_kernels.begin() + i, a->in_kernels().begin(), a->in_kernels().end());

View File

@ -125,7 +125,7 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) {
printf("shape=(");
auto shape = tensor->shape();
for (int i = 0; i < shape.size(); ++i) {
for (size_t i = 0; i < shape.size(); ++i) {
printf("%4d", shape[i]);
if (i + 1 < shape.size()) {
printf(",");
@ -134,7 +134,8 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) {
printf(") ");
auto total_num = mem_type == lite::opencl::MemType::BUF ? img_info.ElementsNum : img_info.ElementsC4Num;
for (int i = 0; i < print_num && i < total_num; ++i) {
for (int i = 0; i < print_num && i < static_cast<int>(total_num); ++i) {
#ifdef ENABLE_FP16
if (tensor->data_type() == kNumberTypeInt32) {
printf("%d %7d | ", i, reinterpret_cast<int32_t *>(data.data())[i]);
} else if (tensor->data_type() == kNumberTypeFloat16) {
@ -144,6 +145,9 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) {
} else if (tensor->data_type() == kNumberTypeInt8) {
printf("%d %7d | ", i, static_cast<int>(reinterpret_cast<int8_t *>(data.data())[i]));
}
#else
printf("%d %7.3f | ", i, reinterpret_cast<float *>(data.data())[i]);
#endif
}
printf("\n");
@ -158,7 +162,7 @@ int OpenCLKernel::PreProcess() {
if (ret != RET_OK) {
return ret;
}
for (auto i = 0; i < out_tensors_.size(); ++i) {
for (size_t i = 0; i < out_tensors_.size(); ++i) {
auto *output = out_tensors_.at(i);
CHECK_NULL_RETURN(output);
CHECK_NULL_RETURN(output->allocator());
@ -293,7 +297,7 @@ int OpenCLKernel::Tune() {
}
int index = -1;
double min_time = MAX_PROFILING_TIME_MILLI_SECOND;
for (int i = 0; i < tuning_params.size(); i++) {
for (size_t i = 0; i < tuning_params.size(); i++) {
AssignTuningParam(tuning_params[i]);
auto ret = Run();
if (ret != RET_OK) {
@ -332,7 +336,7 @@ double OpenCLKernel::GetProfilingTimeMs() {
std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
std::set<size_t> local_ = {};
int index = 1;
while (index <= global_i) {
while (index <= static_cast<int>(global_i)) {
local_.insert(index);
index *= 2;
}
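The loop above collects candidate local work-group sizes as powers of two no larger than the global size; a compact illustrative equivalent:

#include <cstddef>
#include <set>

// Powers of two up to and including global_i, e.g. global_i = 48 -> {1, 2, 4, 8, 16, 32}.
std::set<size_t> LocalCandidates(size_t global_i) {
  std::set<size_t> local;
  for (size_t v = 1; v <= global_i; v *= 2) {
    local.insert(v);
  }
  return local;
}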

View File

@ -21,6 +21,7 @@
#include <set>
#include <map>
#include <string>
#include <cfloat>
#include "src/inner_kernel.h"
#include "include/errorcode.h"
#include "src/runtime/gpu/opencl/opencl_runtime.h"

View File

@ -203,6 +203,7 @@ int GetBroadcastGpuAxis(int ndim, int ori_axis) {
return axis;
}
#ifdef ENABLE_FP16
void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
int data_type) {
MS_ASSERT(src);
@ -238,6 +239,35 @@ void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, c
}
}
}
#else
void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
int data_type) {
MS_ASSERT(src);
MS_ASSERT(dst);
auto src_fp32 = reinterpret_cast<float *>(src);
auto src_int32 = reinterpret_cast<int32_t *>(src);
auto dst_fp32 = reinterpret_cast<float *>(dst);
auto dst_int32 = reinterpret_cast<int32_t *>(dst);
for (size_t n = 0, src_idx = 0; n < tensor.N; n++) {
for (size_t h = 0; h < tensor.H; ++h) {
for (size_t w = 0; w < tensor.W; ++w) {
for (size_t c = 0; c < tensor.C; ++c, ++src_idx) {
int dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
if (data_type == kNumberTypeInt32) {
dst_int32[dst_idx] = src_int32[src_idx];
} else {
dst_fp32[dst_idx] = src_fp32[src_idx];
}
}
}
}
}
// scalar
if (tensor.ElementsNum == 1) {
dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
}
}
#endif
int CheckParamLikeTensor(const std::string &kernel_name, const std::string &tensor_name, lite::Tensor *tensor,
TypeId expect_data_type, const std::vector<int> &expect_shape) {
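The FP32 PackNHWCToNHWC4 fallback above copies NHWC data into the channel-sliced NHWC4 layout, with Slice = UP_DIV(C, 4) and the padding channels left zero. A minimal standalone sketch of that copy, using the same destination index:

#include <cstddef>
#include <vector>

constexpr size_t UpDiv(size_t a, size_t b) { return (a + b - 1) / b; }

// FP32 NHWC -> NHWC4 copy; dst index matches ((n*H+h)*W+w)*Slice*4 + c from the fallback above.
std::vector<float> PackNHWCToNHWC4Fp32(const std::vector<float> &src, size_t N, size_t H, size_t W, size_t C) {
  const size_t slice = UpDiv(C, 4);
  std::vector<float> dst(N * H * W * slice * 4, 0.0f);  // padding channels stay zero
  size_t src_idx = 0;
  for (size_t n = 0; n < N; ++n) {
    for (size_t h = 0; h < H; ++h) {
      for (size_t w = 0; w < W; ++w) {
        for (size_t c = 0; c < C; ++c, ++src_idx) {
          dst[((n * H + h) * W + w) * slice * 4 + c] = src[src_idx];
        }
      }
    }
  }
  return dst;
}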

View File

@ -47,11 +47,13 @@ if(MSLITE_ENABLE_SPARSE_COMPUTE)
endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)
file(GLOB_RECURSE TEST_GPU_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
${TEST_DIR}/ut/src/registry/registry_gpu_custom_op_test.cc
)
list(APPEND TEST_UT_SRC ${TEST_GPU_UT_SRC})
if(PLATFORM_ARM)
file(GLOB_RECURSE TEST_GPU_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/opencl/*.cc
${TEST_DIR}/ut/src/registry/registry_gpu_custom_op_test.cc
)
list(APPEND TEST_UT_SRC ${TEST_GPU_UT_SRC})
endif()
endif()
if(MSLITE_ENABLE_FP16)

View File

@ -182,6 +182,15 @@ set(LITE_SRC
${SRC_DIR}/ops/ops_def.cc
${SRC_DIR}/train/train_populate_parameter.cc
)
if(MSLITE_GPU_BACKEND STREQUAL opencl)
file(GLOB_RECURSE OPENCL_RUNTIME_SRC
${SRC_DIR}/runtime/gpu/opencl/*.cc
)
set(LITE_SRC
${LITE_SRC}
${OPENCL_RUNTIME_SRC}
)
endif()
file(GLOB PROTO_FILE ""
${TOP_DIR}/third_party/proto/caffe/caffe.proto
@ -215,6 +224,11 @@ add_executable(converter_lite
)
add_dependencies(converter_lite fbs_src fbs_inner_src)
if(MSLITE_GPU_BACKEND STREQUAL opencl)
include_directories(${SRC_DIR}/runtime/kernel/opencl)
target_link_libraries(converter_lite PRIVATE opencl_kernel_mid)
endif()
target_link_libraries(converter_lite PRIVATE
ccsrc_src_mid
converter_src_mid