!11517 [MS][LITE][GPU] init weights when shape is not inferred

From: @chenzupeng
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-01-25 17:00:47 +08:00 committed by Gitee
commit 78c733ffbe
12 changed files with 253 additions and 175 deletions

View File

@@ -76,13 +76,13 @@ int ArithmeticOpenCLKernel::InitWeights() {
auto fp16_enable = ocl_runtime_->GetFp16Enable();
for (int i = 0; i < 2; ++i) {
const auto &in_tensor = in_tensors_.at(i);
GpuTensorInfo *in_shape = (i == 0) ? &in0_shape_ : &in1_shape_;
GpuTensorInfo in_shape = GpuTensorInfo(in_tensor);
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape->Image2DSize, 0);
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, *in_shape);
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape->width, in_shape->height, dtype};
ImageSize img_size{in_shape.width, in_shape.height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {
@@ -152,7 +152,10 @@ int ArithmeticOpenCLKernel::Prepare() {
}
SetGlobalLocal();
InitWeights();
// BiasAdd InitWeight will be called in opencl_subgraph prepare
if (Type() != PrimitiveType_BiasAdd) {
InitWeights();
}
SetConstArgs();
MS_LOG(DEBUG) << kernel_name_ << " Init Done!";
return RET_OK;
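
For context on the packing above: PackNHWCToNHWC4 copies a const input into a buffer whose channel dimension is rounded up to a multiple of 4 before it is uploaded as an image. Below is a minimal standalone sketch of that idea, assuming C4NUM == 4 and float data; it is illustrative only and is not the library's implementation (which also handles the fp16 conversion driven by src_is_fp16 / fp16_enable).

#include <cstring>
#include <vector>

// Illustrative only: zero-pad the channel dimension of an NHWC float buffer up to a
// multiple of 4, the layout the OpenCL kernels read as NHWC4 images. Assumes C4NUM == 4.
std::vector<float> PackChannelsToMultipleOf4(const std::vector<float> &src, int n, int h, int w, int c) {
  const int c4 = (c + 3) / 4;  // UP_DIV(c, 4): number of 4-channel slices
  std::vector<float> dst(static_cast<size_t>(n) * h * w * c4 * 4, 0.0f);
  for (int i = 0; i < n * h * w; ++i) {
    // copy the c real channels of pixel i; the remaining c4*4 - c channels stay zero
    std::memcpy(&dst[static_cast<size_t>(i) * c4 * 4], &src[static_cast<size_t>(i) * c], c * sizeof(float));
  }
  return dst;
}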

View File

@@ -19,6 +19,7 @@
#include <set>
#include "nnacl/fp32/common_func_fp32.h"
#include "src/kernel_registry.h"
#include "src/ops/conv2d.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc"
#endif
@@ -125,6 +126,14 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() {
}
int Conv2dTransposeOpenCLKernel::InitWeights() {
auto ret = InitFilter();
if (ret != RET_OK) {
return ret;
}
return InitBias();
}
int Conv2dTransposeOpenCLKernel::InitFilter() {
auto ret = DequantWeight();
if (ret != RET_OK) {
return ret;
@@ -185,8 +194,15 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
}
allocator->UnmapBuffer(padWeight_);
FreeDequantedWeight();
return RET_OK;
}
int Conv2dTransposeOpenCLKernel::InitBias() {
// init bias_(image2d mem)
auto allocator = ocl_runtime_->GetAllocator();
auto data_size = enable_fp16_ ? sizeof(int16_t) : sizeof(float);
int co = out_tensors_[0]->shape()[3];
int div_co = UP_DIV(co, C4NUM);
size_t im_dst_x, im_dst_y;
im_dst_x = div_co;
im_dst_y = 1;
@@ -225,6 +241,20 @@ int Conv2dTransposeOpenCLKernel::Run() {
return mindspore::lite::RET_OK;
}
int Conv2dTransposeOpenCLKernel::InferShape() {
auto ret = OpenCLKernel::InferShape();
if (ret != RET_OK) {
return ret;
}
auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
auto conv2d_lite_primitive = (lite::Conv2D *)primitive_;
param->pad_u_ = conv2d_lite_primitive->PadUp();
param->pad_d_ = conv2d_lite_primitive->PadDown();
param->pad_l_ = conv2d_lite_primitive->PadLeft();
param->pad_r_ = conv2d_lite_primitive->PadRight();
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
} // namespace mindspore::kernel
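
The InferShape override added above follows a simple pattern: run the framework's generic shape inference first, and only on success refresh the pad fields of the ConvParameter from the Conv2D primitive. A minimal sketch of that pattern with placeholder types (these are not the real MindSpore Lite classes):

// Placeholder types standing in for ConvParameter and lite::Conv2D.
struct ConvParamSketch {
  int pad_u, pad_d, pad_l, pad_r;
};
struct Conv2dPrimSketch {
  int up, down, left, right;
  int PadUp() const { return up; }
  int PadDown() const { return down; }
  int PadLeft() const { return left; }
  int PadRight() const { return right; }
};

constexpr int kRetOk = 0;

// Mirrors the control flow above: propagate a failed base InferShape unchanged,
// otherwise copy the primitive's pad attributes into the kernel parameter.
int RefreshPadsAfterInfer(int base_ret, ConvParamSketch *param, const Conv2dPrimSketch &prim) {
  if (base_ret != kRetOk) {
    return base_ret;
  }
  param->pad_u = prim.PadUp();
  param->pad_d = prim.PadDown();
  param->pad_l = prim.PadLeft();
  param->pad_r = prim.PadRight();
  return kRetOk;
}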

View File

@@ -34,8 +34,11 @@ class Conv2dTransposeOpenCLKernel : public OpenCLKernel {
int Prepare() override;
int CheckSpecs() override;
int InitWeights() override;
int InitFilter();
int InitBias();
void SetConstArgs() override;
void SetGlobalLocal() override;
int InferShape() override;
private:
void *padWeight_{nullptr};

View File

@@ -52,8 +52,10 @@ int FillOpenCLKernel::RunShape() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto src_data = out_tensors_[0]->data_c();
cl_float4 fill_value = {default_, default_, default_, default_};
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
fill_value.s[0] = in_tensors_[0]->shape()[i];
auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data();
for (int i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[0] = reinterpret_cast<float *>(tensor_shape_data)[i];
size_t index = static_cast<size_t>(i);
auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};

View File

@@ -95,10 +95,6 @@ int MatMulOpenCLKernel::Prepare() {
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
#endif
auto ret = InitWeights();
if (ret != RET_OK) {
return ret;
}
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
@@ -106,7 +102,7 @@ int MatMulOpenCLKernel::Prepare() {
}
int MatMulOpenCLKernel::InitWeights() {
if (act_weight_) {
if (!in_tensors_[1]->IsConst()) {
return RET_OK;
}
// ABMCI @ ABCICO = ABMCO
@@ -115,12 +111,27 @@ int MatMulOpenCLKernel::InitWeights() {
return ret;
}
auto allocator = ocl_runtime_->GetAllocator();
int ci = inShape[3];
auto weight_shape = in_tensors_[1]->shape();
int weight_ndim = weight_shape.size();
std::vector<int> weight_shape_4d(MAX_DIMS, 1);
for (int i = 0; i < weight_ndim; i++) {
weight_shape_4d[MAX_DIMS - weight_ndim + i] = weight_shape[i];
}
auto param = reinterpret_cast<MatMulParameter *>(op_parameter_);
transposeB = param->b_transpose_;
enable_fp16_ = ocl_runtime_->GetFp16Enable();
int ci, co;
if (transposeB) {
ci = weight_shape_4d[3];
co = weight_shape_4d[2];
} else {
ci = weight_shape_4d[2];
co = weight_shape_4d[3];
}
int ci4 = UP_DIV(ci, C4NUM);
int co = outShape[3];
int co4 = UP_DIV(co, C4NUM);
int a = inShape[0];
int b = inShape[1];
int a = weight_shape_4d[0];
int b = weight_shape_4d[1];
size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
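
To make the shape handling above easier to follow, here is a self-contained sketch of how an N-dimensional weight shape is left-padded to 4-D and how the reduce dimension ci and the output dimension co are picked depending on b_transpose_. It assumes MAX_DIMS == 4 and at most four weight dimensions; the names are illustrative, not the kernel's own.

#include <vector>

struct CiCo {
  int ci;  // reduce (input-channel) dimension
  int co;  // output-channel dimension
};

// Pad the weight shape to 4-D on the left with 1s, then read ci/co from the last
// two axes; which one is which depends on whether B is stored transposed.
CiCo WeightCiCo(const std::vector<int> &weight_shape, bool transpose_b) {
  const int kMaxDims = 4;  // assumes weight_shape.size() <= 4
  std::vector<int> shape4d(kMaxDims, 1);
  const int ndim = static_cast<int>(weight_shape.size());
  for (int i = 0; i < ndim; ++i) {
    shape4d[kMaxDims - ndim + i] = weight_shape[i];
  }
  return transpose_b ? CiCo{shape4d[3], shape4d[2]}   // weight laid out as {..., co, ci}
                     : CiCo{shape4d[2], shape4d[3]};  // weight laid out as {..., ci, co}
}

For a 2-D FullConnection weight of shape {co, ci} with b_transpose_ set, this yields ci = shape[1] and co = shape[0], which is what sizes the a * b * ci4 * co4 * C4NUM * C4NUM buffer allocated above.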

View File

@@ -97,7 +97,7 @@ int ReshapeOpenCLKernel::Run() {
}
int ReshapeOpenCLKernel::PreProcess() {
if (Type() == PrimitiveType_Reshape) {
if (Type() == PrimitiveType_Reshape && !infer_shape_flag_) {
auto shape_tensor = in_tensors_[1];
if (!shape_tensor->IsConst()) {
ocl_runtime_->SyncCommandQueue();

View File

@@ -215,7 +215,7 @@ int ScaleOpenCLKernel::Run() {
}
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, param->activation_type_);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}
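
Passing &event_ to RunKernel lets the kernel's device-side execution time be read back afterwards. A minimal sketch of that readout using the cl2.hpp C++ bindings, assuming the command queue was created with CL_QUEUE_PROFILING_ENABLE and the event has completed:

#include <CL/cl2.hpp>

// Device-side execution time of a finished kernel, in nanoseconds.
cl_ulong KernelDurationNs(const cl::Event &event) {
  cl_ulong start = event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
  cl_ulong end = event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
  return end - start;
}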

View File

@@ -53,6 +53,12 @@ int TransposeOpenCLKernel::Prepare() {
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[1]);
if (param->num_axes_ != tensor_size_.NDim) {
perm_4d_[0] = 0;
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = 3;
}
} else if (tensor_size_.NDim == 3) {
perm_4d_[0] = tensor_size_.AlignAxis(param->perm_[0]);
perm_4d_[1] = 1;
@@ -65,9 +71,9 @@ int TransposeOpenCLKernel::Prepare() {
perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[3]);
} else {
perm_4d_[0] = 0;
perm_4d_[0] = 1;
perm_4d_[0] = 2;
perm_4d_[0] = 3;
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = 3;
}
if (tensor_size_.N == 1 && perm_4d_[0] == 0 && perm_4d_[1] == 3 && perm_4d_[2] == 1 && perm_4d_[3] == 2) {
type_ = TransposeType::AXIS0312;

View File

@@ -218,9 +218,12 @@ inline void MergeRemoveB(LiteKernel *a, LiteKernel *b, std::set<LiteKernel *> *r
// Pad + DeConv2D
// Pad + Pooling
template <typename ParamType>
void TryMergePad(LiteKernel *node, std::set<LiteKernel *> *removed_set) {
void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(node);
MS_ASSERT(removed_set);
if (!PredIs(node, schema::PrimitiveType_Pad, nodes)) {
return;
}
LiteKernel *pad = node->in_kernels().front();
MS_ASSERT(pad);
if (pad->in_tensors().front()->shape().size() != 4) {
@@ -245,9 +248,12 @@ void TryMergePad(LiteKernel *node, std::set<LiteKernel *> *removed_set) {
}
// Conv2D + Reshape(N11C->NC)
void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(reshape);
MS_ASSERT(removed_set);
if (!PredIs(reshape, schema::PrimitiveType_Conv2D, nodes)) {
return;
}
if (N11C_NC(reshape)) {
LiteKernel *conv = reshape->in_kernels().front();
MS_ASSERT(conv);
@@ -257,9 +263,12 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_se
}
// FullConnection + Reshape(NC->N11C or N11C->NC)
void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(reshape);
MS_ASSERT(removed_set);
if (!PredIs(reshape, schema::PrimitiveType_FullConnection, nodes)) {
return;
}
bool NC_N11C_flag = NC_N11C(reshape);
if (NC_N11C_flag || N11C_NC(reshape)) {
LiteKernel *fc = reshape->in_kernels().front();
@@ -272,9 +281,12 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set)
// Reshape(NC11->NC) + FullConnection
// Reshape(NC->N11C) + FullConnection
void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set) {
void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(fc);
MS_ASSERT(removed_set);
if (!PredIs(fc, schema::PrimitiveType_Reshape, nodes)) {
return;
}
LiteKernel *reshape = fc->in_kernels().front();
MS_ASSERT(reshape);
bool NC11_NC_flag = NC11_NC(reshape);
@@ -308,7 +320,7 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
// Conv2D(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
// FullConnection(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
template <typename ParamType>
void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
MS_ASSERT(node);
MS_ASSERT(removed_set);
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());
@@ -316,7 +328,6 @@ void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param);
if (param->act_type_ == ActType_No) {
param->act_type_ = static_cast<ActType>(act_param->type_);
std::string act_name;
if (act_param->type_ == ActivationType_RELU) {
act_name = "RELU";
@@ -324,16 +335,25 @@ void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
act_name = "RELU6";
} else if (act_param->type_ == ActivationType_TANH) {
act_name = "TANH";
} else {
MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name +
") is not supported";
return;
}
param->act_type_ = static_cast<ActType>(act_param->type_);
MergeRemoveB(node, act, removed_set);
MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name + ") success";
}
}
// Conv2D(NO_ACTIVATION) + PReLU(weight is scalar)
void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set) {
void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(prelu);
MS_ASSERT(removed_set);
if (!PredIs(prelu, schema::PrimitiveType_Conv2D, nodes)) {
return;
}
if (prelu->in_tensors().size() != 2) {
return;
}
@@ -409,7 +429,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
bias_data[co] *= scale_data[co];
bias_data[co] += offset_data[co];
}
} else { // if deconv dont't have bias, let scale's offset be deconv's bias
} else { // if deconv don't have bias, let scale's offset be deconv's bias
auto tmp = conv_kernel->in_tensors();
tmp.push_back(offset);
conv_kernel->set_in_tensors(tmp);
@@ -418,9 +438,12 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
}
// DeConv2D + Scale (can't both has activation)
void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set) {
void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(scale);
MS_ASSERT(removed_set);
if (!PredIs(scale, schema::PrimitiveType_DeConv2D, nodes)) {
return;
}
LiteKernel *deconv = scale->in_kernels().front();
MS_ASSERT(deconv);
@@ -493,7 +516,7 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol
}
// Eltwise + Eltwise
int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, std::set<LiteKernel *> *removed_set) {
int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(node);
MS_ASSERT(nodes);
MS_ASSERT(removed_set);
@@ -536,12 +559,56 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s
return RET_OK;
}
void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
switch (node->Type()) {
case schema::PrimitiveType_Conv2D:
case schema::PrimitiveType_DepthwiseConv2D:
case schema::PrimitiveType_DeConv2D: {
TryMergePadXxx<ConvParameter>(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Pooling: {
TryMergePadXxx<PoolingParameter>(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Reshape: {
TryMergeFcReshape(node, removed_set, nodes);
TryMergeConvReshape(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_FullConnection: {
TryMergeReshapeFc(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Activation: {
// try merge Conv2D/FC(without act) + RELU/RELU6/TANH
// try merge Arithmetic(without act) + RELU/RELU6
if (PredIs(node, schema::PrimitiveType_Conv2D, nodes)) {
TryMergeXxxActivation<ConvParameter>(node, removed_set);
} else if (PredIs(node, schema::PrimitiveType_FullConnection, nodes)) {
TryMergeXxxActivation<MatMulParameter>(node, removed_set);
} else if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
[&](schema::PrimitiveType type) { return PredIs(node, type, nodes); })) {
TryMergeArithmeticAct(node, removed_set);
}
break;
}
case schema::PrimitiveType_PReLU: {
TryMergeConvPReLU(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Scale: {
TryMergeDeconvScale(node, removed_set, nodes);
break;
}
default:
break;
}
} // namespace
} // namespace
int OpenCLSubGraph::FusionPass() {
if (!this->IsSubGraphInferShapeDone()) {
return RET_OK;
}
MS_LOG(DEBUG) << "start Fusion";
std::vector<LiteKernel *> input_nodes;
@@ -579,77 +646,12 @@ int OpenCLSubGraph::FusionPass() {
}
// do element-wise fusion, like mul+add, mul+add+relu
if (TryMergeEltwiseEltwise(node, &nodes_, &removed_set) == RET_OK) {
if (TryMergeEltwiseEltwise(node, &removed_set, &nodes_) == RET_OK) {
continue;
}
// do special fusion, like pad+conv2d, fc+reshape
switch (node->Type()) {
case schema::PrimitiveType_Conv2D:
case schema::PrimitiveType_DepthwiseConv2D:
case schema::PrimitiveType_DeConv2D: {
if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
TryMergePad<ConvParameter>(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Pooling: {
if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
TryMergePad<PoolingParameter>(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Reshape: {
if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
TryMergeFcReshape(node, &removed_set);
} else if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeConvReshape(node, &removed_set);
}
break;
}
case schema::PrimitiveType_FullConnection: {
if (PredIs(node, schema::PrimitiveType_Reshape, &nodes_)) {
TryMergeReshapeFc(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Activation: {
// try merge Conv2D/FC(without act) + RELU/RELU6/TANH
auto *param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param);
if (param->type_ == ActivationType_RELU || param->type_ == ActivationType_RELU6 ||
param->type_ == ActivationType_TANH) {
if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeActivation<ConvParameter>(node, &removed_set);
break;
} else if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
TryMergeActivation<MatMulParameter>(node, &removed_set);
break;
}
}
if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
[&](schema::PrimitiveType type) { return PredIs(node, type, &nodes_); })) {
TryMergeArithmeticAct(node, &removed_set);
}
break;
}
case schema::PrimitiveType_PReLU: {
if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeConvPReLU(node, &removed_set);
break;
}
break;
}
case schema::PrimitiveType_Scale: {
if (PredIs(node, schema::PrimitiveType_DeConv2D, &nodes_)) {
TryMergeDeconvScale(node, &removed_set);
break;
}
break;
}
default:
break;
}
// do specific fusion, like pad+conv2d, fc+reshape, etc.
DoSpecificFusion(node, &removed_set, &nodes_);
}
for (auto kernel : removed_set) {
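
The refactor above moves the predecessor check into each TryMergeXxx itself via PredIs. The helper's implementation is not part of this diff; the sketch below shows the behaviour the call sites appear to rely on (a single predecessor of the expected type that is still in the node list) and is an assumption, not the real function.

#include <algorithm>
#include <vector>

// Assumed behaviour of PredIs(node, type, nodes): true only if the node has exactly one
// input kernel, that kernel's primitive type matches, and it is still present in nodes.
template <typename KernelT, typename TypeT>
bool PredecessorIs(const KernelT *node, TypeT type, const std::vector<KernelT *> *nodes) {
  if (node == nullptr || node->in_kernels().size() != 1) {
    return false;
  }
  KernelT *pred = node->in_kernels().front();
  return pred->Type() == type && std::find(nodes->begin(), nodes->end(), pred) != nodes->end();
}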

View File

@@ -330,6 +330,14 @@ int OpenCLSubGraph::Prepare() {
return mindspore::lite::RET_NULL_PTR;
}
auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node);
std::set<int> pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd};
if (pre_init_weight_list.find(opencl_kernel->Type()) != pre_init_weight_list.end()) {
ret = opencl_kernel->InitWeights();
if (ret != RET_OK) {
MS_LOG(ERROR) << "init weights " << node->name() << " failed";
return ret;
}
}
if (opencl_kernel->GetInferShapeFlag()) {
ret = node->Prepare();
if (ret != RET_OK) {
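
This hunk is the core of the change named in the commit title: kernels whose weights must be packed even before the whole subgraph has finished shape inference (MatMul and BiasAdd here) get InitWeights() called from the subgraph's Prepare(), while node->Prepare() itself stays gated on GetInferShapeFlag(). A compact sketch of that gating with placeholder enum values (not the real schema::PrimitiveType constants):

#include <set>

// Placeholder primitive ids; the real code uses schema::PrimitiveType_MatMul / _BiasAdd.
enum class PrimId { kMatMul, kBiasAdd, kConv2D, kOther };

// Kernels in this set need their weights initialized during subgraph Prepare(),
// even when their own shape inference has not run yet.
bool NeedsEarlyWeightInit(PrimId type) {
  static const std::set<PrimId> kPreInitWeight = {PrimId::kMatMul, PrimId::kBiasAdd};
  return kPreInitWeight.count(type) != 0;
}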

View File

@@ -72,36 +72,10 @@ void printf_callback(const char *buffer, size_t length, size_t final, void *user
fwrite(buffer, 1, length, stdout);
}
// Init will get platforms info, get devices info, create opencl context.
int OpenCLRuntime::Init() {
std::unique_lock<std::mutex> lck(g_init_mtx);
if (init_state_ == InitSuccess) {
return RET_OK;
} else if (init_state_ == InitFailed) {
return RET_ERROR;
}
init_state_ = InitFailed;
MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;
#ifdef USE_OPENCL_WRAPPER
if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
MS_LOG(ERROR) << "Load OpenCL symbols failed!";
return RET_ERROR;
}
#endif // USE_OPENCL_WRAPPER
std::vector<cl::Platform> platforms;
cl_int ret = cl::Platform::get(&platforms);
if (platforms.empty()) {
MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
return RET_ERROR;
}
int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> &platforms) {
// search GPU
std::vector<cl::Device> devices;
int ret = RET_OK;
for (auto &platform : platforms) {
std::string platform_name;
ret = platform.getInfo(CL_PLATFORM_NAME, &platform_name);
@@ -148,45 +122,6 @@ int OpenCLRuntime::Init() {
<< max_work_item_sizes_[2];
gpu_info_ = ParseGpuInfo(device_name, device_version);
// cl_int ret;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
if (context_ == nullptr) {
delete device_;
MS_LOG(ERROR) << "Create OpenCL context failed!";
return RET_ERROR;
}
}
#else
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
0x1000000,
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
if (ret != CL_SUCCESS) {
delete device_;
MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
return RET_ERROR;
}
// get cache size, compute units and frequency.
ret = device_->getInfo(CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, &global_memery_cachesize_);
if (ret != CL_SUCCESS) {
@@ -235,6 +170,48 @@ int OpenCLRuntime::Init() {
MS_LOG(INFO) << "Max Alloc Size: " << max_alloc_size_;
MS_LOG(INFO) << "Compute Unit: " << compute_units_;
MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";
return RET_OK;
}
int OpenCLRuntime::InitQueue(std::vector<cl::Platform> &platforms) {
cl_int ret;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
if (context_ == nullptr) {
delete device_;
MS_LOG(ERROR) << "Create OpenCL context failed!";
return RET_ERROR;
}
}
#else
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
0x1000000,
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
if (ret != CL_SUCCESS) {
delete device_;
MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
return RET_ERROR;
}
default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, 0, &ret);
if (ret != CL_SUCCESS) {
@@ -252,6 +229,44 @@ int OpenCLRuntime::Init() {
MS_LOG(ERROR) << "Profiling command Queue create failed: " << CLErrorCode(ret);
return RET_ERROR;
}
return RET_OK;
}
// Init will get platforms info, get devices info, create opencl context.
int OpenCLRuntime::Init() {
std::unique_lock<std::mutex> lck(g_init_mtx);
if (init_state_ == InitSuccess) {
return RET_OK;
} else if (init_state_ == InitFailed) {
return RET_ERROR;
}
init_state_ = InitFailed;
MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;
#ifdef USE_OPENCL_WRAPPER
if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
MS_LOG(ERROR) << "Load OpenCL symbols failed!";
return RET_ERROR;
}
#endif // USE_OPENCL_WRAPPER
std::vector<cl::Platform> platforms;
cl_int ret = cl::Platform::get(&platforms);
if (platforms.empty()) {
MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
return RET_ERROR;
}
auto ms_ret = InitGPUDevice(platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}
ms_ret = InitQueue(platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}
allocator_ = new (std::nothrow) OpenCLAllocator(this);
if (allocator_ == nullptr) {
@@ -289,10 +304,6 @@ int OpenCLRuntime::Uninit() {
profiling_command_queue_ = nullptr;
context_ = nullptr;
device_ = nullptr;
#ifdef USE_OPENCL_WRAPPER
lite::opencl::UnLoadOpenCLLibrary(handle_);
handle_ = nullptr;
#endif
init_state_ = UnInit;
return RET_OK;
}

View File

@@ -56,7 +56,7 @@ class OpenCLRuntime {
cl::Context *Context();
cl::Device *Device();
OpenCLAllocator *GetAllocator() { return allocator_; }
cl::CommandQueue *GetDefaultCommandQueue() { return default_command_queue_; }
cl::CommandQueue *GetDefaultCommandQueue() { return profiling_ ? profiling_command_queue_ : default_command_queue_; }
uint64_t DeviceGlobalMemoryCacheSize() const;
int DeviceMaxWorkGroupSize() const;
uint32_t DeviceComputeUnits() const;
@@ -101,7 +101,7 @@ class OpenCLRuntime {
return kernel.setArg(index, *image);
}
default:
MS_LOG(ERROR) << "Unsupport opencl memory type: " << static_cast<int>(mem_type);
MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast<int>(mem_type);
return CL_IMAGE_FORMAT_NOT_SUPPORTED;
}
}
@@ -159,6 +159,8 @@ class OpenCLRuntime {
bool LoadProgram(const std::string &program_name, cl::Program *program);
bool BuildProgram(const std::string &build_options, const cl::Program &program);
int InitGPUDevice(std::vector<cl::Platform> &platforms);
int InitQueue(std::vector<cl::Platform> &platforms);
private:
static InitState init_state_;