diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
index 93b9e660eed..abe905b6363 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
@@ -76,13 +76,13 @@ int ArithmeticOpenCLKernel::InitWeights() {
   auto fp16_enable = ocl_runtime_->GetFp16Enable();
   for (int i = 0; i < 2; ++i) {
     const auto &in_tensor = in_tensors_.at(i);
-    GpuTensorInfo *in_shape = (i == 0) ? &in0_shape_ : &in1_shape_;
+    GpuTensorInfo in_shape = GpuTensorInfo(in_tensor);
     if (in_tensor->IsConst()) {
-      std::vector weight(in_shape->Image2DSize, 0);
+      std::vector weight(in_shape.Image2DSize, 0);
       bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
-      PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, *in_shape);
+      PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
       size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
-      ImageSize img_size{in_shape->width, in_shape->height, dtype};
+      ImageSize img_size{in_shape.width, in_shape.height, dtype};
       auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
       weight_ptrs_.push_back(weight_ptr_);
     } else {
@@ -152,7 +152,10 @@ int ArithmeticOpenCLKernel::Prepare() {
   }
   SetGlobalLocal();
-  InitWeights();
+  // BiasAdd InitWeight will be called in opencl_subgraph prepare
+  if (Type() != PrimitiveType_BiasAdd) {
+    InitWeights();
+  }
   SetConstArgs();
   MS_LOG(DEBUG) << kernel_name_ << " Init Done!";
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
index ac65f13fe73..bce92f538bd 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
@@ -19,6 +19,7 @@
 #include
 #include "nnacl/fp32/common_func_fp32.h"
 #include "src/kernel_registry.h"
+#include "src/ops/conv2d.h"
 #ifndef PROGRAM_WITH_IL
 #include "src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc"
 #endif
@@ -125,6 +126,14 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() {
 }
 
 int Conv2dTransposeOpenCLKernel::InitWeights() {
+  auto ret = InitFilter();
+  if (ret != RET_OK) {
+    return ret;
+  }
+  return InitBias();
+}
+
+int Conv2dTransposeOpenCLKernel::InitFilter() {
   auto ret = DequantWeight();
   if (ret != RET_OK) {
     return ret;
@@ -185,8 +194,15 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
   }
   allocator->UnmapBuffer(padWeight_);
   FreeDequantedWeight();
+  return RET_OK;
+}
 
+int Conv2dTransposeOpenCLKernel::InitBias() {
   // init bias_(image2d mem)
+  auto allocator = ocl_runtime_->GetAllocator();
+  auto data_size = enable_fp16_ ? sizeof(int16_t) : sizeof(float);
+  int co = out_tensors_[0]->shape()[3];
+  int div_co = UP_DIV(co, C4NUM);
   size_t im_dst_x, im_dst_y;
   im_dst_x = div_co;
   im_dst_y = 1;
@@ -225,6 +241,20 @@ int Conv2dTransposeOpenCLKernel::Run() {
   return mindspore::lite::RET_OK;
 }
 
+int Conv2dTransposeOpenCLKernel::InferShape() {
+  auto ret = OpenCLKernel::InferShape();
+  if (ret != RET_OK) {
+    return ret;
+  }
+  auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
+  auto conv2d_lite_primitive = (lite::Conv2D *)primitive_;
+  param->pad_u_ = conv2d_lite_primitive->PadUp();
+  param->pad_d_ = conv2d_lite_primitive->PadDown();
+  param->pad_l_ = conv2d_lite_primitive->PadLeft();
+  param->pad_r_ = conv2d_lite_primitive->PadRight();
+  return RET_OK;
+}
+
 REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
 REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
index d8064dd19a8..b1c1c789227 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
@@ -34,8 +34,11 @@ class Conv2dTransposeOpenCLKernel : public OpenCLKernel {
   int Prepare() override;
   int CheckSpecs() override;
   int InitWeights() override;
+  int InitFilter();
+  int InitBias();
   void SetConstArgs() override;
   void SetGlobalLocal() override;
+  int InferShape() override;
 
  private:
   void *padWeight_{nullptr};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
index 4b4ead6383f..ae79a65bc81 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
@@ -52,8 +52,10 @@ int FillOpenCLKernel::RunShape() {
   auto allocator_ = ocl_runtime_->GetAllocator();
   auto src_data = out_tensors_[0]->data_c();
   cl_float4 fill_value = {default_, default_, default_, default_};
-  for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
-    fill_value.s[0] = in_tensors_[0]->shape()[i];
+  auto tensor_shape = in_tensors_[0]->shape();
+  void *tensor_shape_data = tensor_shape.data();
+  for (int i = 0; i < tensor_shape.size(); ++i) {
+    fill_value.s[0] = reinterpret_cast(tensor_shape_data)[i];
     size_t index = static_cast<size_t>(i);
     auto src_origin = cl::array{0, index, 0};
     auto region = cl::array{1, 1, 1};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
index dfa63ae24d4..47e19636c70 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
@@ -95,10 +95,6 @@ int MatMulOpenCLKernel::Prepare() {
   ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
 #endif
 
-  auto ret = InitWeights();
-  if (ret != RET_OK) {
-    return ret;
-  }
   SetConstArgs();
   SetGlobalLocal();
   MS_LOG(DEBUG) << kernel_name << " Init Done!";
@@ -106,7 +102,7 @@
 }
 
 int MatMulOpenCLKernel::InitWeights() {
-  if (act_weight_) {
+  if (!in_tensors_[1]->IsConst()) {
     return RET_OK;
   }
   // ABMCI @ ABCICO = ABMCO
@@ -115,12 +111,27 @@
     return ret;
   }
   auto allocator = ocl_runtime_->GetAllocator();
-  int ci = inShape[3];
+  auto weight_shape = in_tensors_[1]->shape();
+  int weight_ndim = weight_shape.size();
+  std::vector<int> weight_shape_4d(MAX_DIMS, 1);
+  for (int i = 0; i < weight_ndim; i++) {
+    weight_shape_4d[MAX_DIMS - weight_ndim + i] = weight_shape[i];
+  }
+  auto param = reinterpret_cast<MatMulParameter *>(op_parameter_);
+  transposeB = param->b_transpose_;
+  enable_fp16_ = ocl_runtime_->GetFp16Enable();
+  int ci, co;
+  if (transposeB) {
+    ci = weight_shape_4d[3];
+    co = weight_shape_4d[2];
+  } else {
+    ci = weight_shape_4d[2];
+    co = weight_shape_4d[3];
+  }
   int ci4 = UP_DIV(ci, C4NUM);
-  int co = outShape[3];
   int co4 = UP_DIV(co, C4NUM);
-  int a = inShape[0];
-  int b = inShape[1];
+  int a = weight_shape_4d[0];
+  int b = weight_shape_4d[1];
   size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
   padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
index 496bbae6b3f..248d69d4970 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
@@ -97,7 +97,7 @@ int ReshapeOpenCLKernel::Run() {
 }
 
 int ReshapeOpenCLKernel::PreProcess() {
-  if (Type() == PrimitiveType_Reshape) {
+  if (Type() == PrimitiveType_Reshape && !infer_shape_flag_) {
     auto shape_tensor = in_tensors_[1];
     if (!shape_tensor->IsConst()) {
       ocl_runtime_->SyncCommandQueue();
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
index 1e38716d9e5..67cdd3552a6 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
@@ -215,7 +215,7 @@ int ScaleOpenCLKernel::Run() {
     }
   }
   ocl_runtime_->SetKernelArg(kernel_, arg_idx++, param->activation_type_);
-  ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
+  ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
index f67116124b7..58601eb5ac0 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -53,6 +53,12 @@ int TransposeOpenCLKernel::Prepare() {
     perm_4d_[1] = 1;
     perm_4d_[2] = 2;
     perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[1]);
+    if (param->num_axes_ != tensor_size_.NDim) {
+      perm_4d_[0] = 0;
+      perm_4d_[1] = 1;
+      perm_4d_[2] = 2;
+      perm_4d_[3] = 3;
+    }
   } else if (tensor_size_.NDim == 3) {
     perm_4d_[0] = tensor_size_.AlignAxis(param->perm_[0]);
     perm_4d_[1] = 1;
@@ -65,9 +71,9 @@ int TransposeOpenCLKernel::Prepare() {
     perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[3]);
   } else {
     perm_4d_[0] = 0;
-    perm_4d_[0] = 1;
-    perm_4d_[0] = 2;
-    perm_4d_[0] = 3;
+    perm_4d_[1] = 1;
+    perm_4d_[2] = 2;
+    perm_4d_[3] = 3;
   }
   if (tensor_size_.N == 1 && perm_4d_[0] == 0 && perm_4d_[1] == 3 && perm_4d_[2] == 1 && perm_4d_[3] == 2) {
     type_ = TransposeType::AXIS0312;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
index e12a6a240cd..9d2e157c496 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
@@ -218,9 +218,12 @@ inline void MergeRemoveB(LiteKernel *a, LiteKernel *b, std::set *r
 // Pad + DeConv2D
 // Pad + Pooling
 template <typename ParamType>
-void TryMergePad(LiteKernel *node, std::set<LiteKernel *> *removed_set) {
+void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(node);
   MS_ASSERT(removed_set);
+  if (!PredIs(node, schema::PrimitiveType_Pad, nodes)) {
+    return;
+  }
   LiteKernel *pad = node->in_kernels().front();
   MS_ASSERT(pad);
   if (pad->in_tensors().front()->shape().size() != 4) {
@@ -245,9 +248,12 @@ void TryMergePad(LiteKernel *node, std::set *removed_set) {
 }
 
 // Conv2D + Reshape(N11C->NC)
-void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
+void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(reshape);
   MS_ASSERT(removed_set);
+  if (!PredIs(reshape, schema::PrimitiveType_Conv2D, nodes)) {
+    return;
+  }
   if (N11C_NC(reshape)) {
     LiteKernel *conv = reshape->in_kernels().front();
     MS_ASSERT(conv);
@@ -257,9 +263,12 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set *removed_se
 }
 
 // FullConnection + Reshape(NC->N11C or N11C->NC)
-void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
+void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(reshape);
   MS_ASSERT(removed_set);
+  if (!PredIs(reshape, schema::PrimitiveType_FullConnection, nodes)) {
+    return;
+  }
   bool NC_N11C_flag = NC_N11C(reshape);
   if (NC_N11C_flag || N11C_NC(reshape)) {
     LiteKernel *fc = reshape->in_kernels().front();
@@ -272,9 +281,12 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set *removed_set)
 
 // Reshape(NC11->NC) + FullConnection
 // Reshape(NC->N11C) + FullConnection
-void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set) {
+void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(fc);
   MS_ASSERT(removed_set);
+  if (!PredIs(fc, schema::PrimitiveType_Reshape, nodes)) {
+    return;
+  }
   LiteKernel *reshape = fc->in_kernels().front();
   MS_ASSERT(reshape);
   bool NC11_NC_flag = NC11_NC(reshape);
@@ -308,7 +320,7 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set *removed_set)
 // Conv2D(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
 // FullConnection(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
 template <typename ParamType>
-void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
+void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
   MS_ASSERT(node);
   MS_ASSERT(removed_set);
   auto *act_param = reinterpret_cast(reinterpret_cast(act)->GetParameter());
@@ -316,7 +328,6 @@ void TryMergeActivation(LiteKernel *act, std::set *removed_set) {
   auto *param = reinterpret_cast(reinterpret_cast(node)->GetParameter());
   MS_ASSERT(param);
   if (param->act_type_ == ActType_No) {
-    param->act_type_ = static_cast<ActType>(act_param->type_);
     std::string act_name;
     if (act_param->type_ == ActivationType_RELU) {
       act_name = "RELU";
@@ -324,16 +335,25 @@ void TryMergeActivation(LiteKernel *act, std::set *removed_set) {
       act_name = "RELU6";
     } else if (act_param->type_ == ActivationType_TANH) {
       act_name = "TANH";
+    } else {
+      MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name +
+                         ") is not supported";
+      return;
+    }
+    param->act_type_ = static_cast<ActType>(act_param->type_);
     MergeRemoveB(node, act, removed_set);
     MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name + ") success";
   }
 }
 
 // Conv2D(NO_ACTIVATION) + PReLU(weight is scalar)
-void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set) {
+void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(prelu);
   MS_ASSERT(removed_set);
+  if (!PredIs(prelu, schema::PrimitiveType_Conv2D, nodes)) {
+    return;
+  }
   if (prelu->in_tensors().size() != 2) {
     return;
   }
@@ -409,7 +429,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
       bias_data[co] *= scale_data[co];
       bias_data[co] += offset_data[co];
     }
-  } else {  // if deconv dont't have bias, let scale's offset be deconv's bias
+  } else {  // if deconv don't have bias, let scale's offset be deconv's bias
     auto tmp = conv_kernel->in_tensors();
     tmp.push_back(offset);
     conv_kernel->set_in_tensors(tmp);
@@ -418,9 +438,12 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
 }
 
 // DeConv2D + Scale (can't both has activation)
-void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set) {
+void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(scale);
   MS_ASSERT(removed_set);
+  if (!PredIs(scale, schema::PrimitiveType_DeConv2D, nodes)) {
+    return;
+  }
   LiteKernel *deconv = scale->in_kernels().front();
   MS_ASSERT(deconv);
 
@@ -493,7 +516,7 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol
 }
 
 // Eltwise + Eltwise
-int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, std::set<LiteKernel *> *removed_set) {
+int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
   MS_ASSERT(node);
   MS_ASSERT(nodes);
   MS_ASSERT(removed_set);
@@ -536,12 +559,56 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector *nodes, s
   return RET_OK;
 }
 
+void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
+  switch (node->Type()) {
+    case schema::PrimitiveType_Conv2D:
+    case schema::PrimitiveType_DepthwiseConv2D:
+    case schema::PrimitiveType_DeConv2D: {
+      TryMergePadXxx(node, removed_set, nodes);
+      break;
+    }
+    case schema::PrimitiveType_Pooling: {
+      TryMergePadXxx(node, removed_set, nodes);
+      break;
+    }
+    case schema::PrimitiveType_Reshape: {
+      TryMergeFcReshape(node, removed_set, nodes);
+      TryMergeConvReshape(node, removed_set, nodes);
+      break;
+    }
+    case schema::PrimitiveType_FullConnection: {
+      TryMergeReshapeFc(node, removed_set, nodes);
+      break;
+    }
+    case schema::PrimitiveType_Activation: {
+      // try merge Conv2D/FC(without act) + RELU/RELU6/TANH
+      // try merge Arithmetic(without act) + RELU/RELU6
+      if (PredIs(node, schema::PrimitiveType_Conv2D, nodes)) {
+        TryMergeXxxActivation(node, removed_set);
+      } else if (PredIs(node, schema::PrimitiveType_FullConnection, nodes)) {
+        TryMergeXxxActivation(node, removed_set);
+      } else if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
+                             [&](schema::PrimitiveType type) { return PredIs(node, type, nodes); })) {
+        TryMergeArithmeticAct(node, removed_set);
+      }
+      break;
+    }
+    case schema::PrimitiveType_PReLU: {
+      TryMergeConvPReLU(node, removed_set, nodes);
+      break;
+    }
+    case schema::PrimitiveType_Scale: {
+      TryMergeDeconvScale(node, removed_set, nodes);
+      break;
+    }
+    default:
+      break;
+  }
+}  // namespace
+
 }  // namespace
 
 int OpenCLSubGraph::FusionPass() {
-  if (!this->IsSubGraphInferShapeDone()) {
-    return RET_OK;
-  }
   MS_LOG(DEBUG) << "start Fusion";
 
   std::vector<LiteKernel *> input_nodes;
@@ -579,77 +646,12 @@
     }
 
     // do element-wise fusion, like mul+add, mul+add+relu
-    if (TryMergeEltwiseEltwise(node, &nodes_, &removed_set) == RET_OK) {
+    if (TryMergeEltwiseEltwise(node, &removed_set, &nodes_) == RET_OK) {
       continue;
     }
-    // do special fusion, like pad+conv2d, fc+reshape
-    switch (node->Type()) {
-      case schema::PrimitiveType_Conv2D:
-      case schema::PrimitiveType_DepthwiseConv2D:
-      case schema::PrimitiveType_DeConv2D: {
-        if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
-          TryMergePad(node, &removed_set);
-        }
-        break;
-      }
-      case schema::PrimitiveType_Pooling: {
-        if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
-          TryMergePad(node, &removed_set);
-        }
-        break;
-      }
-      case schema::PrimitiveType_Reshape: {
-        if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
-          TryMergeFcReshape(node, &removed_set);
-        } else if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
-          TryMergeConvReshape(node, &removed_set);
-        }
-        break;
-      }
-      case schema::PrimitiveType_FullConnection: {
-        if (PredIs(node, schema::PrimitiveType_Reshape, &nodes_)) {
-          TryMergeReshapeFc(node, &removed_set);
-        }
-        break;
-      }
-      case schema::PrimitiveType_Activation: {
-        // try merge Conv2D/FC(without act) + RELU/RELU6/TANH
-        auto *param = reinterpret_cast(reinterpret_cast(node)->GetParameter());
-        MS_ASSERT(param);
-        if (param->type_ == ActivationType_RELU || param->type_ == ActivationType_RELU6 ||
-            param->type_ == ActivationType_TANH) {
-          if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
-            TryMergeActivation(node, &removed_set);
-            break;
-          } else if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
-            TryMergeActivation(node, &removed_set);
-            break;
-          }
-        }
-        if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
-                        [&](schema::PrimitiveType type) { return PredIs(node, type, &nodes_); })) {
-          TryMergeArithmeticAct(node, &removed_set);
-        }
-        break;
-      }
-      case schema::PrimitiveType_PReLU: {
-        if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
-          TryMergeConvPReLU(node, &removed_set);
-          break;
-        }
-        break;
-      }
-      case schema::PrimitiveType_Scale: {
-        if (PredIs(node, schema::PrimitiveType_DeConv2D, &nodes_)) {
-          TryMergeDeconvScale(node, &removed_set);
-          break;
-        }
-        break;
-      }
-      default:
-        break;
-    }
+    // do specific fusion, like pad+conv2d, fc+reshape, etc.
+    DoSpecificFusion(node, &removed_set, &nodes_);
   }
 
   for (auto kernel : removed_set) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
index fd3aae6330e..67f4f2051a0 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
@@ -330,6 +330,14 @@ int OpenCLSubGraph::Prepare() {
       return mindspore::lite::RET_NULL_PTR;
     }
     auto opencl_kernel = reinterpret_cast<OpenCLKernel *>(node);
+    std::set<schema::PrimitiveType> pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd};
+    if (pre_init_weight_list.find(opencl_kernel->Type()) != pre_init_weight_list.end()) {
+      ret = opencl_kernel->InitWeights();
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "init weights " << node->name() << " failed";
+        return ret;
+      }
+    }
     if (opencl_kernel->GetInferShapeFlag()) {
       ret = node->Prepare();
       if (ret != RET_OK) {
diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
index 2bbe79b990b..989bbee5294 100644
--- a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
+++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc
@@ -72,36 +72,10 @@ void printf_callback(const char *buffer, size_t length, size_t final, void *user
   fwrite(buffer, 1, length, stdout);
 }
 
-// Init will get platforms info, get devices info, create opencl context.
-int OpenCLRuntime::Init() {
-  std::unique_lock<std::mutex> lck(g_init_mtx);
-  if (init_state_ == InitSuccess) {
-    return RET_OK;
-  } else if (init_state_ == InitFailed) {
-    return RET_ERROR;
-  }
-  init_state_ = InitFailed;
-
-  MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
-  MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
-  MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;
-
-#ifdef USE_OPENCL_WRAPPER
-  if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
-    MS_LOG(ERROR) << "Load OpenCL symbols failed!";
-    return RET_ERROR;
-  }
-#endif  // USE_OPENCL_WRAPPER
-
-  std::vector<cl::Platform> platforms;
-  cl_int ret = cl::Platform::get(&platforms);
-  if (platforms.empty()) {
-    MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
-    return RET_ERROR;
-  }
-
+int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> &platforms) {
   // search GPU
   std::vector<cl::Device> devices;
+  int ret = RET_OK;
   for (auto &platform : platforms) {
     std::string platform_name;
     ret = platform.getInfo(CL_PLATFORM_NAME, &platform_name);
@@ -148,45 +122,6 @@
                 << max_work_item_sizes_[2];
   gpu_info_ = ParseGpuInfo(device_name, device_version);
-// cl_int ret;
-#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
-  // create context from glcontext
-  MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
-  cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
-                                          CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
-  context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);
-
-  if (ret != CL_SUCCESS) {
-    MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
-    context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
-    if (context_ == nullptr) {
-      delete device_;
-      MS_LOG(ERROR) << "Create OpenCL context failed!";
-      return RET_ERROR;
-    }
-  }
-#else
-  MS_LOG(INFO) << "Create common opencl context";
-#ifdef Debug
-  std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
-                                                       (cl_context_properties)platforms[0](),
-                                                       CL_PRINTF_CALLBACK_ARM,
-                                                       (cl_context_properties)printf_callback,
-                                                       CL_PRINTF_BUFFERSIZE_ARM,
-                                                       0x1000000,
-                                                       0};
-  context_ =
-    new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
-#else
-  context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
-#endif
-#endif
-  if (ret != CL_SUCCESS) {
-    delete device_;
-    MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
-    return RET_ERROR;
-  }
-
   // get cache size, compute units and frequency.
   ret = device_->getInfo(CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, &global_memery_cachesize_);
   if (ret != CL_SUCCESS) {
@@ -235,6 +170,48 @@
   MS_LOG(INFO) << "Max Alloc Size: " << max_alloc_size_;
   MS_LOG(INFO) << "Compute Unit: " << compute_units_;
   MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";
+  return RET_OK;
+}
+
+int OpenCLRuntime::InitQueue(std::vector<cl::Platform> &platforms) {
+  cl_int ret;
+#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
+  // create context from glcontext
+  MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
+  cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
+                                          CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
+  context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);
+
+  if (ret != CL_SUCCESS) {
+    MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
+    context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
+    if (context_ == nullptr) {
+      delete device_;
+      MS_LOG(ERROR) << "Create OpenCL context failed!";
+      return RET_ERROR;
+    }
+  }
+#else
+  MS_LOG(INFO) << "Create common opencl context";
+#ifdef Debug
+  std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
+                                                       (cl_context_properties)platforms[0](),
+                                                       CL_PRINTF_CALLBACK_ARM,
+                                                       (cl_context_properties)printf_callback,
+                                                       CL_PRINTF_BUFFERSIZE_ARM,
+                                                       0x1000000,
+                                                       0};
+  context_ =
+    new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
+#else
+  context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
+#endif
+#endif
+  if (ret != CL_SUCCESS) {
+    delete device_;
+    MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
+    return RET_ERROR;
+  }
 
   default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, 0, &ret);
   if (ret != CL_SUCCESS) {
@@ -252,6 +229,44 @@
     MS_LOG(ERROR) << "Profiling command Queue create failed: " << CLErrorCode(ret);
     return RET_ERROR;
   }
+  return RET_OK;
+}
+
+// Init will get platforms info, get devices info, create opencl context.
+int OpenCLRuntime::Init() {
+  std::unique_lock<std::mutex> lck(g_init_mtx);
+  if (init_state_ == InitSuccess) {
+    return RET_OK;
+  } else if (init_state_ == InitFailed) {
+    return RET_ERROR;
+  }
+  init_state_ = InitFailed;
+
+  MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
+  MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
+  MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;
+
+#ifdef USE_OPENCL_WRAPPER
+  if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
+    MS_LOG(ERROR) << "Load OpenCL symbols failed!";
+    return RET_ERROR;
+  }
+#endif  // USE_OPENCL_WRAPPER
+  std::vector<cl::Platform> platforms;
+  cl_int ret = cl::Platform::get(&platforms);
+  if (platforms.empty()) {
+    MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
+    return RET_ERROR;
+  }
+  auto ms_ret = InitGPUDevice(platforms);
+  if (ms_ret != RET_OK) {
+    return ms_ret;
+  }
+
+  ms_ret = InitQueue(platforms);
+  if (ms_ret != RET_OK) {
+    return ms_ret;
+  }
 
   allocator_ = new (std::nothrow) OpenCLAllocator(this);
   if (allocator_ == nullptr) {
@@ -289,10 +304,6 @@ int OpenCLRuntime::Uninit() {
   profiling_command_queue_ = nullptr;
   context_ = nullptr;
   device_ = nullptr;
-#ifdef USE_OPENCL_WRAPPER
-  lite::opencl::UnLoadOpenCLLibrary(handle_);
-  handle_ = nullptr;
-#endif
   init_state_ = UnInit;
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.h b/mindspore/lite/src/runtime/opencl/opencl_runtime.h
index 17c2c7d749b..1c4c0d044dd 100644
--- a/mindspore/lite/src/runtime/opencl/opencl_runtime.h
+++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.h
@@ -56,7 +56,7 @@ class OpenCLRuntime {
   cl::Context *Context();
   cl::Device *Device();
   OpenCLAllocator *GetAllocator() { return allocator_; }
-  cl::CommandQueue *GetDefaultCommandQueue() { return default_command_queue_; }
+  cl::CommandQueue *GetDefaultCommandQueue() { return profiling_ ? profiling_command_queue_ : default_command_queue_; }
   uint64_t DeviceGlobalMemoryCacheSize() const;
   int DeviceMaxWorkGroupSize() const;
   uint32_t DeviceComputeUnits() const;
@@ -101,7 +101,7 @@ class OpenCLRuntime {
         return kernel.setArg(index, *image);
       }
       default:
-        MS_LOG(ERROR) << "Unsupport opencl memory type: " << static_cast(mem_type);
+        MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast(mem_type);
         return CL_IMAGE_FORMAT_NOT_SUPPORTED;
     }
   }
@@ -159,6 +159,8 @@ class OpenCLRuntime {
   bool LoadProgram(const std::string &program_name, cl::Program *program);
   bool BuildProgram(const std::string &build_options, const cl::Program &program);
+  int InitGPUDevice(std::vector<cl::Platform> &platforms);
+  int InitQueue(std::vector<cl::Platform> &platforms);
 
  private:
   static InitState init_state_;
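
Editor's note on the channel-alignment arithmetic that recurs in this patch (InitBias/InitFilter in conv2d_transpose.cc and the padWeight_ allocation in matmul.cc): channel counts are rounded up to multiples of four so data can be packed as 4-channel (FLT4/HALF4) slices. The snippet below is only an illustrative, standalone sketch; UP_DIV and C4NUM are defined locally here (in MindSpore Lite they come from the nnacl headers), and the numbers are made up for the example.

#include <cstdio>

// Local stand-ins for the nnacl macros, shown for illustration only.
constexpr int C4NUM = 4;
constexpr int UP_DIV(int x, int y) { return (x + y - 1) / y; }

int main() {
  int co = 10;                        // output channel count
  int co4 = UP_DIV(co, C4NUM);        // number of 4-channel slices -> 3
  auto dtype_size = sizeof(float);    // would be sizeof(int16_t) when FP16 is enabled
  // Mirrors the sizing logic above: the bias image is co4 texels wide, and the
  // matmul padWeight_ buffer scales with a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size.
  std::printf("co=%d co4=%d bytes per 4-channel slice=%zu\n", co, co4, C4NUM * dtype_size);
  return 0;
}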
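Editor's note on the opencl_fusion.cc refactor: each TryMergeXxx helper now performs its own PredIs() predecessor-type check, and FusionPass() simply calls DoSpecificFusion(), which dispatches on the node type. The following is a minimal, self-contained sketch of that dispatch pattern using simplified stand-in types (Node, NodeType); it is not the actual LiteKernel/PredIs API.

#include <iostream>
#include <set>
#include <vector>

enum class NodeType { kPad, kConv2D, kPooling };

struct Node {
  NodeType type;
  std::vector<Node *> in_nodes;  // predecessor nodes
};

// True when `node` has exactly one predecessor of the expected type.
bool PredIs(const Node *node, NodeType expect) {
  return node->in_nodes.size() == 1 && node->in_nodes.front()->type == expect;
}

// The guard lives inside the helper, as in the refactored TryMergePadXxx.
void TryMergePadInto(Node *node, std::set<Node *> *removed_set) {
  if (!PredIs(node, NodeType::kPad)) {
    return;
  }
  removed_set->insert(node->in_nodes.front());  // "merge" = drop the Pad node
}

// One switch replaces the per-case predecessor checks FusionPass used to carry.
void DoSpecificFusion(Node *node, std::set<Node *> *removed_set) {
  switch (node->type) {
    case NodeType::kConv2D:
    case NodeType::kPooling:
      TryMergePadInto(node, removed_set);
      break;
    default:
      break;
  }
}

int main() {
  Node pad{NodeType::kPad, {}};
  Node conv{NodeType::kConv2D, {&pad}};
  std::set<Node *> removed_set;
  DoSpecificFusion(&conv, &removed_set);
  std::cout << "removed " << removed_set.size() << " node(s)\n";
  return 0;
}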