forked from mindspore-Ecosystem/mindspore
!11517 [MS][LITE][GPU]init weight if not infer shape
From: @chenzupeng Reviewed-by: Signed-off-by:
commit 78c733ffbe

@@ -76,13 +76,13 @@ int ArithmeticOpenCLKernel::InitWeights() {
auto fp16_enable = ocl_runtime_->GetFp16Enable();
for (int i = 0; i < 2; ++i) {
const auto &in_tensor = in_tensors_.at(i);
GpuTensorInfo *in_shape = (i == 0) ? &in0_shape_ : &in1_shape_;
GpuTensorInfo in_shape = GpuTensorInfo(in_tensor);
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape->Image2DSize, 0);
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, *in_shape);
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape->width, in_shape->height, dtype};
ImageSize img_size{in_shape.width, in_shape.height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {

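The hunk above stops reading the cached in0_shape_/in1_shape_ members and instead builds a GpuTensorInfo directly from the constant tensor, so the weight image can be sized even when shape inference has not yet filled those members. A minimal self-contained sketch of the NHWC-to-NHWC4 image-size arithmetic this relies on (illustrative names and layout, not the actual GpuTensorInfo implementation):

#include <cstddef>
#include <cstdio>

// Mirrors the usual UP_DIV(C, C4NUM) channel-slice rounding.
constexpr int UpDiv(int x, int align) { return (x + align - 1) / align; }

struct TensorInfoSketch {
  int N, H, W, C;
  std::size_t width() const { return static_cast<std::size_t>(W) * UpDiv(C, 4); }  // W * ceil(C/4)
  std::size_t height() const { return static_cast<std::size_t>(N) * H; }
  // Every image element packs 4 channels; dtype_size is 2 for fp16, 4 for fp32.
  std::size_t Image2DSize(std::size_t dtype_size) const { return width() * height() * 4 * dtype_size; }
};

int main() {
  TensorInfoSketch weight{1, 8, 8, 6};                // an NHWC constant input
  std::printf("%zu bytes\n", weight.Image2DSize(4));  // 16 x 8 image, fp32 -> 2048 bytes
  return 0;
}
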
@@ -152,7 +152,10 @@ int ArithmeticOpenCLKernel::Prepare() {
}

SetGlobalLocal();
InitWeights();
// BiasAdd InitWeight will be called in opencl_subgraph prepare
if (Type() != PrimitiveType_BiasAdd) {
InitWeights();
}
SetConstArgs();
MS_LOG(DEBUG) << kernel_name_ << " Init Done!";
return RET_OK;

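Prepare now skips InitWeights for BiasAdd; as the added comment notes, the OpenCL subgraph calls it later, once the weight tensor is usable. A small self-contained sketch of the guard, with hypothetical stand-in types:

#include <iostream>

enum PrimitiveTypeSketch { kPrimBiasAdd, kPrimAdd };

struct KernelSketch {
  PrimitiveTypeSketch type;
  void InitWeights() { std::cout << "init weights\n"; }
  int Prepare() {
    // Weight init for BiasAdd is deferred: the subgraph calls InitWeights()
    // explicitly (see the OpenCLSubGraph::Prepare hunk further down).
    if (type != kPrimBiasAdd) {
      InitWeights();
    }
    return 0;
  }
};

int main() {
  KernelSketch add{kPrimAdd};
  KernelSketch bias_add{kPrimBiasAdd};
  add.Prepare();       // initializes weights here
  bias_add.Prepare();  // skipped; handled later by the subgraph
  return 0;
}
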
@@ -19,6 +19,7 @@
#include <set>
#include "nnacl/fp32/common_func_fp32.h"
#include "src/kernel_registry.h"
#include "src/ops/conv2d.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc"
#endif

@@ -125,6 +126,14 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() {
}

int Conv2dTransposeOpenCLKernel::InitWeights() {
auto ret = InitFilter();
if (ret != RET_OK) {
return ret;
}
return InitBias();
}

int Conv2dTransposeOpenCLKernel::InitFilter() {
auto ret = DequantWeight();
if (ret != RET_OK) {
return ret;

@@ -185,8 +194,15 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
}
allocator->UnmapBuffer(padWeight_);
FreeDequantedWeight();
return RET_OK;
}

int Conv2dTransposeOpenCLKernel::InitBias() {
// init bias_(image2d mem)
auto allocator = ocl_runtime_->GetAllocator();
auto data_size = enable_fp16_ ? sizeof(int16_t) : sizeof(float);
int co = out_tensors_[0]->shape()[3];
int div_co = UP_DIV(co, C4NUM);
size_t im_dst_x, im_dst_y;
im_dst_x = div_co;
im_dst_y = 1;

@@ -225,6 +241,20 @@ int Conv2dTransposeOpenCLKernel::Run() {
return mindspore::lite::RET_OK;
}

int Conv2dTransposeOpenCLKernel::InferShape() {
auto ret = OpenCLKernel::InferShape();
if (ret != RET_OK) {
return ret;
}
auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
auto conv2d_lite_primitive = (lite::Conv2D *)primitive_;
param->pad_u_ = conv2d_lite_primitive->PadUp();
param->pad_d_ = conv2d_lite_primitive->PadDown();
param->pad_l_ = conv2d_lite_primitive->PadLeft();
param->pad_r_ = conv2d_lite_primitive->PadRight();
return RET_OK;
}

REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, OpenCLKernelCreator<Conv2dTransposeOpenCLKernel>)
}  // namespace mindspore::kernel

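The new InferShape override first runs the base-class inference and only then refreshes the pad fields of ConvParameter from the primitive, since those values are only meaningful once shapes are known. A generic sketch of that ordering, with stand-in types (the real code copies PadUp/PadDown/PadLeft/PadRight from lite::Conv2D):

struct PadInfo { int up, down, left, right; };

struct BaseKernelSketch {
  virtual ~BaseKernelSketch() = default;
  virtual int InferShape() { return 0; }  // generic shape inference
};

struct DeconvKernelSketch : BaseKernelSketch {
  PadInfo *param = nullptr;   // stands in for ConvParameter's pad_u_/pad_d_/pad_l_/pad_r_
  PadInfo primitive_pads{};   // stands in for values read from the primitive
  int InferShape() override {
    int ret = BaseKernelSketch::InferShape();
    if (ret != 0) {
      return ret;             // only refresh the pads once inference succeeded
    }
    *param = primitive_pads;
    return 0;
  }
};

int main() {
  PadInfo conv_param{};
  DeconvKernelSketch kernel;
  kernel.param = &conv_param;
  kernel.primitive_pads = {1, 1, 2, 2};
  return kernel.InferShape();  // conv_param now carries the primitive's pads
}
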
@@ -34,8 +34,11 @@ class Conv2dTransposeOpenCLKernel : public OpenCLKernel {
int Prepare() override;
int CheckSpecs() override;
int InitWeights() override;
int InitFilter();
int InitBias();
void SetConstArgs() override;
void SetGlobalLocal() override;
int InferShape() override;

private:
void *padWeight_{nullptr};

@@ -52,8 +52,10 @@ int FillOpenCLKernel::RunShape() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto src_data = out_tensors_[0]->data_c();
cl_float4 fill_value = {default_, default_, default_, default_};
for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
fill_value.s[0] = in_tensors_[0]->shape()[i];
auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data();
for (int i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[0] = reinterpret_cast<float *>(tensor_shape_data)[i];
size_t index = static_cast<size_t>(i);
auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};

@@ -95,10 +95,6 @@ int MatMulOpenCLKernel::Prepare() {
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);

#endif
auto ret = InitWeights();
if (ret != RET_OK) {
return ret;
}
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";

@@ -106,7 +102,7 @@ int MatMulOpenCLKernel::Prepare() {
}

int MatMulOpenCLKernel::InitWeights() {
if (act_weight_) {
if (!in_tensors_[1]->IsConst()) {
return RET_OK;
}
// ABMCI @ ABCICO = ABMCO

@@ -115,12 +111,27 @@ int MatMulOpenCLKernel::InitWeights() {
return ret;
}
auto allocator = ocl_runtime_->GetAllocator();
int ci = inShape[3];
auto weight_shape = in_tensors_[1]->shape();
int weight_ndim = weight_shape.size();
std::vector<int> weight_shape_4d(MAX_DIMS, 1);
for (int i = 0; i < weight_ndim; i++) {
weight_shape_4d[MAX_DIMS - weight_ndim + i] = weight_shape[i];
}
auto param = reinterpret_cast<MatMulParameter *>(op_parameter_);
transposeB = param->b_transpose_;
enable_fp16_ = ocl_runtime_->GetFp16Enable();
int ci, co;
if (transposeB) {
ci = weight_shape_4d[3];
co = weight_shape_4d[2];
} else {
ci = weight_shape_4d[2];
co = weight_shape_4d[3];
}
int ci4 = UP_DIV(ci, C4NUM);
int co = outShape[3];
int co4 = UP_DIV(co, C4NUM);
int a = inShape[0];
int b = inShape[1];
int a = weight_shape_4d[0];
int b = weight_shape_4d[1];

size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
padWeight_ = allocator->Malloc(a * b * ci4 * co4 * C4NUM * C4NUM * dtype_size);

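InitWeights now derives the packed-weight geometry from the weight tensor itself instead of inShape/outShape, which may not be inferred yet: the weight shape is right-aligned into four dimensions and ci/co are chosen according to b_transpose_. A self-contained sketch of that arithmetic (names illustrative):

#include <cstdio>
#include <vector>

constexpr int kMaxDims = 4;
constexpr int kC4 = 4;
constexpr int UpDiv(int x, int align) { return (x + align - 1) / align; }

int main() {
  std::vector<int> weight_shape = {96, 256};  // e.g. a 2-D FullConnection weight
  bool transpose_b = true;

  // Right-align the weight shape into 4 dims, padding leading dims with 1.
  std::vector<int> shape_4d(kMaxDims, 1);
  int ndim = static_cast<int>(weight_shape.size());
  for (int i = 0; i < ndim; ++i) {
    shape_4d[kMaxDims - ndim + i] = weight_shape[i];
  }

  int ci = transpose_b ? shape_4d[3] : shape_4d[2];
  int co = transpose_b ? shape_4d[2] : shape_4d[3];
  int a = shape_4d[0];
  int b = shape_4d[1];

  // Padded buffer: ci and co are each rounded up to a multiple of 4 channels.
  long long elems = 1LL * a * b * UpDiv(ci, kC4) * UpDiv(co, kC4) * kC4 * kC4;
  std::printf("ci=%d co=%d padded elements=%lld\n", ci, co, elems);
  return 0;
}
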
@@ -97,7 +97,7 @@ int ReshapeOpenCLKernel::Run() {
}

int ReshapeOpenCLKernel::PreProcess() {
if (Type() == PrimitiveType_Reshape) {
if (Type() == PrimitiveType_Reshape && !infer_shape_flag_) {
auto shape_tensor = in_tensors_[1];
if (!shape_tensor->IsConst()) {
ocl_runtime_->SyncCommandQueue();

@@ -215,7 +215,7 @@ int ScaleOpenCLKernel::Run() {
}
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, param->activation_type_);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}

@@ -53,6 +53,12 @@ int TransposeOpenCLKernel::Prepare() {
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[1]);
if (param->num_axes_ != tensor_size_.NDim) {
perm_4d_[0] = 0;
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = 3;
}
} else if (tensor_size_.NDim == 3) {
perm_4d_[0] = tensor_size_.AlignAxis(param->perm_[0]);
perm_4d_[1] = 1;

@@ -65,9 +71,9 @@ int TransposeOpenCLKernel::Prepare() {
perm_4d_[3] = tensor_size_.AlignAxis(param->perm_[3]);
} else {
perm_4d_[0] = 0;
perm_4d_[0] = 1;
perm_4d_[0] = 2;
perm_4d_[0] = 3;
perm_4d_[1] = 1;
perm_4d_[2] = 2;
perm_4d_[3] = 3;
}
if (tensor_size_.N == 1 && perm_4d_[0] == 0 && perm_4d_[1] == 3 && perm_4d_[2] == 1 && perm_4d_[3] == 2) {
type_ = TransposeType::AXIS0312;

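The else-branch above previously wrote perm_4d_[0] four times in a row; the fix assigns each index, i.e. it falls back to the identity permutation. A tiny sketch of that fallback:

#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> perm_4d = {0, 3, 1, 2};  // a previously derived permutation
  int num_axes = 2;                           // the op only names 2 axes
  int aligned_ndim = 4;                       // the kernel always works on 4-D layouts

  if (num_axes != aligned_ndim) {
    // Fall back to the identity permutation; each slot gets its own index
    // (the old code overwrote perm_4d[0] four times instead).
    for (int i = 0; i < 4; ++i) {
      perm_4d[i] = i;
    }
  }
  std::printf("%d %d %d %d\n", perm_4d[0], perm_4d[1], perm_4d[2], perm_4d[3]);
  return 0;
}
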
@@ -218,9 +218,12 @@ inline void MergeRemoveB(LiteKernel *a, LiteKernel *b, std::set<LiteKernel *> *r
// Pad + DeConv2D
// Pad + Pooling
template <typename ParamType>
void TryMergePad(LiteKernel *node, std::set<LiteKernel *> *removed_set) {
void TryMergePadXxx(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(node);
MS_ASSERT(removed_set);
if (!PredIs(node, schema::PrimitiveType_Pad, nodes)) {
return;
}
LiteKernel *pad = node->in_kernels().front();
MS_ASSERT(pad);
if (pad->in_tensors().front()->shape().size() != 4) {

@@ -245,9 +248,12 @@ void TryMergePad(LiteKernel *node, std::set<LiteKernel *> *removed_set) {
}

// Conv2D + Reshape(N11C->NC)
void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(reshape);
MS_ASSERT(removed_set);
if (!PredIs(reshape, schema::PrimitiveType_Conv2D, nodes)) {
return;
}
if (N11C_NC(reshape)) {
LiteKernel *conv = reshape->in_kernels().front();
MS_ASSERT(conv);

@@ -257,9 +263,12 @@ void TryMergeConvReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_se
}

// FullConnection + Reshape(NC->N11C or N11C->NC)
void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set) {
void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(reshape);
MS_ASSERT(removed_set);
if (!PredIs(reshape, schema::PrimitiveType_FullConnection, nodes)) {
return;
}
bool NC_N11C_flag = NC_N11C(reshape);
if (NC_N11C_flag || N11C_NC(reshape)) {
LiteKernel *fc = reshape->in_kernels().front();

@@ -272,9 +281,12 @@ void TryMergeFcReshape(LiteKernel *reshape, std::set<LiteKernel *> *removed_set)

// Reshape(NC11->NC) + FullConnection
// Reshape(NC->N11C) + FullConnection
void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set) {
void TryMergeReshapeFc(LiteKernel *fc, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(fc);
MS_ASSERT(removed_set);
if (!PredIs(fc, schema::PrimitiveType_Reshape, nodes)) {
return;
}
LiteKernel *reshape = fc->in_kernels().front();
MS_ASSERT(reshape);
bool NC11_NC_flag = NC11_NC(reshape);

@@ -308,7 +320,7 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set)
// Conv2D(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
// FullConnection(NO_ACTIVATION) + Activation(RELU/RELU6/TANH)
template <typename ParamType>
void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
MS_ASSERT(node);
MS_ASSERT(removed_set);
auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter());

@@ -316,7 +328,6 @@ void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
auto *param = reinterpret_cast<ParamType *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param);
if (param->act_type_ == ActType_No) {
param->act_type_ = static_cast<ActType>(act_param->type_);
std::string act_name;
if (act_param->type_ == ActivationType_RELU) {
act_name = "RELU";

@@ -324,16 +335,25 @@ void TryMergeActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) {
act_name = "RELU6";
} else if (act_param->type_ == ActivationType_TANH) {
act_name = "TANH";
} else {
MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name +
") is not supported";
return;
}
param->act_type_ = static_cast<ActType>(act_param->type_);
MergeRemoveB(node, act, removed_set);
MS_LOG(DEBUG) << "Merge " + GetTypeName(node) + "(NO_ACTIVATION) and Activation(" + act_name + ") success";
}
}

// Conv2D(NO_ACTIVATION) + PReLU(weight is scalar)
void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set) {
void TryMergeConvPReLU(LiteKernel *prelu, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(prelu);
MS_ASSERT(removed_set);
if (!PredIs(prelu, schema::PrimitiveType_Conv2D, nodes)) {
return;
}

if (prelu->in_tensors().size() != 2) {
return;
}

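TryMergeXxxActivation stays templated on the predecessor's parameter type, so one helper serves both Conv2D (ConvParameter) and FullConnection (MatMulParameter), and it now bails out on unsupported activations before touching the graph. A compact sketch of the idea, with stand-in types:

#include <iostream>

enum ActTypeSketch { kActNone, kActRelu, kActRelu6, kActTanh, kActSigmoid };

struct ConvParamSketch { ActTypeSketch act_type_ = kActNone; };
struct MatMulParamSketch { ActTypeSketch act_type_ = kActNone; };

// One helper, templated on the predecessor's parameter struct: absorb the
// activation only if the predecessor has none yet and the type is supported.
template <typename ParamType>
bool TryMergeActivationSketch(ParamType *param, ActTypeSketch act) {
  if (param->act_type_ != kActNone) {
    return false;
  }
  if (act != kActRelu && act != kActRelu6 && act != kActTanh) {
    return false;  // unsupported activation: leave both kernels in the graph
  }
  param->act_type_ = act;
  return true;
}

int main() {
  ConvParamSketch conv;
  MatMulParamSketch fc;
  std::cout << TryMergeActivationSketch(&conv, kActRelu) << "\n";   // 1: merged
  std::cout << TryMergeActivationSketch(&fc, kActSigmoid) << "\n";  // 0: not merged
  return 0;
}
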
@@ -409,7 +429,7 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
bias_data[co] *= scale_data[co];
bias_data[co] += offset_data[co];
}
} else {  // if deconv dont't have bias, let scale's offset be deconv's bias
} else {  // if deconv don't have bias, let scale's offset be deconv's bias
auto tmp = conv_kernel->in_tensors();
tmp.push_back(offset);
conv_kernel->set_in_tensors(tmp);

@@ -418,9 +438,12 @@ int TryFusionConvScaleWeight(LiteKernel *conv_kernel, LiteKernel *scale_kernel)
}

// DeConv2D + Scale (can't both has activation)
void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set) {
void TryMergeDeconvScale(LiteKernel *scale, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(scale);
MS_ASSERT(removed_set);
if (!PredIs(scale, schema::PrimitiveType_DeConv2D, nodes)) {
return;
}
LiteKernel *deconv = scale->in_kernels().front();
MS_ASSERT(deconv);

@@ -493,7 +516,7 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol
}

// Eltwise + Eltwise
int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, std::set<LiteKernel *> *removed_set) {
int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
MS_ASSERT(node);
MS_ASSERT(nodes);
MS_ASSERT(removed_set);

@@ -536,12 +559,56 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s
return RET_OK;
}

void DoSpecificFusion(LiteKernel *node, std::set<LiteKernel *> *removed_set, std::vector<LiteKernel *> *nodes) {
switch (node->Type()) {
case schema::PrimitiveType_Conv2D:
case schema::PrimitiveType_DepthwiseConv2D:
case schema::PrimitiveType_DeConv2D: {
TryMergePadXxx<ConvParameter>(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Pooling: {
TryMergePadXxx<PoolingParameter>(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Reshape: {
TryMergeFcReshape(node, removed_set, nodes);
TryMergeConvReshape(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_FullConnection: {
TryMergeReshapeFc(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Activation: {
// try merge Conv2D/FC(without act) + RELU/RELU6/TANH
// try merge Arithmetic(without act) + RELU/RELU6
if (PredIs(node, schema::PrimitiveType_Conv2D, nodes)) {
TryMergeXxxActivation<ConvParameter>(node, removed_set);
} else if (PredIs(node, schema::PrimitiveType_FullConnection, nodes)) {
TryMergeXxxActivation<MatMulParameter>(node, removed_set);
} else if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
[&](schema::PrimitiveType type) { return PredIs(node, type, nodes); })) {
TryMergeArithmeticAct(node, removed_set);
}
break;
}
case schema::PrimitiveType_PReLU: {
TryMergeConvPReLU(node, removed_set, nodes);
break;
}
case schema::PrimitiveType_Scale: {
TryMergeDeconvScale(node, removed_set, nodes);
break;
}
default:
break;
}
}  // namespace

}  // namespace

int OpenCLSubGraph::FusionPass() {
if (!this->IsSubGraphInferShapeDone()) {
return RET_OK;
}
MS_LOG(DEBUG) << "start Fusion";

std::vector<LiteKernel *> input_nodes;

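The per-node fusion attempts move into DoSpecificFusion, and each TryMergeXxx now takes the node list and performs its own PredIs predecessor check, so the dispatcher reduces to a flat switch. A self-contained sketch of that shape (toy types, not the LiteKernel API):

#include <set>
#include <vector>

enum class PrimSketch { kConv2D, kPooling, kPad, kOther };

struct NodeSketch {
  PrimSketch type;
  std::vector<NodeSketch *> preds;
};

// Each TryMergeXxx checks its own precondition (the predecessor's type) and returns early.
bool PredIs(const NodeSketch *n, PrimSketch p) {
  return !n->preds.empty() && n->preds.front()->type == p;
}

void TryMergePadSketch(NodeSketch *n, std::set<NodeSketch *> *removed) {
  if (!PredIs(n, PrimSketch::kPad)) {
    return;
  }
  removed->insert(n->preds.front());  // fold the Pad into n (details elided)
}

// The dispatcher is then just a flat switch over the node type.
void DoSpecificFusionSketch(NodeSketch *n, std::set<NodeSketch *> *removed) {
  switch (n->type) {
    case PrimSketch::kConv2D:
    case PrimSketch::kPooling:
      TryMergePadSketch(n, removed);
      break;
    default:
      break;
  }
}

int main() {
  NodeSketch pad{PrimSketch::kPad, {}};
  NodeSketch conv{PrimSketch::kConv2D, {&pad}};
  std::set<NodeSketch *> removed;
  DoSpecificFusionSketch(&conv, &removed);
  return removed.count(&pad) ? 0 : 1;  // the Pad predecessor got marked for removal
}
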
@@ -579,77 +646,12 @@ int OpenCLSubGraph::FusionPass() {
}

// do element-wise fusion, like mul+add, mul+add+relu
if (TryMergeEltwiseEltwise(node, &nodes_, &removed_set) == RET_OK) {
if (TryMergeEltwiseEltwise(node, &removed_set, &nodes_) == RET_OK) {
continue;
}

// do special fusion, like pad+conv2d, fc+reshape
switch (node->Type()) {
case schema::PrimitiveType_Conv2D:
case schema::PrimitiveType_DepthwiseConv2D:
case schema::PrimitiveType_DeConv2D: {
if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
TryMergePad<ConvParameter>(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Pooling: {
if (PredIs(node, schema::PrimitiveType_Pad, &nodes_)) {
TryMergePad<PoolingParameter>(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Reshape: {
if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
TryMergeFcReshape(node, &removed_set);
} else if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeConvReshape(node, &removed_set);
}
break;
}
case schema::PrimitiveType_FullConnection: {
if (PredIs(node, schema::PrimitiveType_Reshape, &nodes_)) {
TryMergeReshapeFc(node, &removed_set);
}
break;
}
case schema::PrimitiveType_Activation: {
// try merge Conv2D/FC(without act) + RELU/RELU6/TANH
auto *param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(node)->GetParameter());
MS_ASSERT(param);
if (param->type_ == ActivationType_RELU || param->type_ == ActivationType_RELU6 ||
param->type_ == ActivationType_TANH) {
if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeActivation<ConvParameter>(node, &removed_set);
break;
} else if (PredIs(node, schema::PrimitiveType_FullConnection, &nodes_)) {
TryMergeActivation<MatMulParameter>(node, &removed_set);
break;
}
}
if (std::any_of(ArithmeticPrimitives.begin(), ArithmeticPrimitives.end(),
[&](schema::PrimitiveType type) { return PredIs(node, type, &nodes_); })) {
TryMergeArithmeticAct(node, &removed_set);
}
break;
}
case schema::PrimitiveType_PReLU: {
if (PredIs(node, schema::PrimitiveType_Conv2D, &nodes_)) {
TryMergeConvPReLU(node, &removed_set);
break;
}
break;
}
case schema::PrimitiveType_Scale: {
if (PredIs(node, schema::PrimitiveType_DeConv2D, &nodes_)) {
TryMergeDeconvScale(node, &removed_set);
break;
}
break;
}
default:
break;
}
// do specific fusion, like pad+conv2d, fc+reshape, etc.
DoSpecificFusion(node, &removed_set, &nodes_);
}

for (auto kernel : removed_set) {

@@ -330,6 +330,14 @@ int OpenCLSubGraph::Prepare() {
return mindspore::lite::RET_NULL_PTR;
}
auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node);
std::set<int> pre_init_weight_list = {schema::PrimitiveType_MatMul, schema::PrimitiveType_BiasAdd};
if (pre_init_weight_list.find(opencl_kernel->Type()) != pre_init_weight_list.end()) {
ret = opencl_kernel->InitWeights();
if (ret != RET_OK) {
MS_LOG(ERROR) << "init weights " << node->name() << " failed";
return ret;
}
}
if (opencl_kernel->GetInferShapeFlag()) {
ret = node->Prepare();
if (ret != RET_OK) {

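The subgraph now initializes weights up front for kernel types that need them even before shape inference finishes (MatMul and BiasAdd here), and calls Prepare only on kernels whose shapes are already inferred. A minimal sketch of that control flow, with stand-in types:

#include <iostream>
#include <set>
#include <vector>

enum class PrimSketch { kMatMul, kBiasAdd, kConv2D };

struct KernelSketch {
  PrimSketch type;
  bool infer_shape_done;
  int InitWeights() { std::cout << "init weights\n"; return 0; }
  int Prepare() { std::cout << "prepare\n"; return 0; }
};

int PrepareSubGraphSketch(const std::vector<KernelSketch *> &nodes) {
  const std::set<PrimSketch> pre_init = {PrimSketch::kMatMul, PrimSketch::kBiasAdd};
  for (auto *node : nodes) {
    if (pre_init.count(node->type) != 0) {
      if (node->InitWeights() != 0) return -1;  // weights are needed regardless of shapes
    }
    if (node->infer_shape_done) {
      if (node->Prepare() != 0) return -1;      // Prepare only once shapes are known
    }
  }
  return 0;
}

int main() {
  KernelSketch matmul{PrimSketch::kMatMul, false};
  KernelSketch conv{PrimSketch::kConv2D, true};
  std::vector<KernelSketch *> nodes = {&matmul, &conv};
  return PrepareSubGraphSketch(nodes);
}
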
@@ -72,36 +72,10 @@ void printf_callback(const char *buffer, size_t length, size_t final, void *user
fwrite(buffer, 1, length, stdout);
}

// Init will get platforms info, get devices info, create opencl context.
int OpenCLRuntime::Init() {
std::unique_lock<std::mutex> lck(g_init_mtx);
if (init_state_ == InitSuccess) {
return RET_OK;
} else if (init_state_ == InitFailed) {
return RET_ERROR;
}
init_state_ = InitFailed;

MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;

#ifdef USE_OPENCL_WRAPPER
if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
MS_LOG(ERROR) << "Load OpenCL symbols failed!";
return RET_ERROR;
}
#endif  // USE_OPENCL_WRAPPER

std::vector<cl::Platform> platforms;
cl_int ret = cl::Platform::get(&platforms);
if (platforms.empty()) {
MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
return RET_ERROR;
}

int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> &platforms) {
// search GPU
std::vector<cl::Device> devices;
int ret = RET_OK;
for (auto &platform : platforms) {
std::string platform_name;
ret = platform.getInfo(CL_PLATFORM_NAME, &platform_name);

@@ -148,45 +122,6 @@ int OpenCLRuntime::Init() {
<< max_work_item_sizes_[2];

gpu_info_ = ParseGpuInfo(device_name, device_version);
// cl_int ret;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);

if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
if (context_ == nullptr) {
delete device_;
MS_LOG(ERROR) << "Create OpenCL context failed!";
return RET_ERROR;
}
}
#else
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
0x1000000,
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
if (ret != CL_SUCCESS) {
delete device_;
MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
return RET_ERROR;
}

// get cache size, compute units and frequency.
ret = device_->getInfo(CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, &global_memery_cachesize_);
if (ret != CL_SUCCESS) {

@@ -235,6 +170,48 @@ int OpenCLRuntime::Init() {
MS_LOG(INFO) << "Max Alloc Size: " << max_alloc_size_;
MS_LOG(INFO) << "Compute Unit: " << compute_units_;
MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz";
return RET_OK;
}

int OpenCLRuntime::InitQueue(std::vector<cl::Platform> &platforms) {
cl_int ret;
#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120)
// create context from glcontext
MS_LOG(INFO) << "Create special opencl context to share with OpenGL";
cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(),
CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0};
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, context_prop, nullptr, nullptr, &ret);

if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create special OpenCL context failed, Create common OpenCL context then.";
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
if (context_ == nullptr) {
delete device_;
MS_LOG(ERROR) << "Create OpenCL context failed!";
return RET_ERROR;
}
}
#else
MS_LOG(INFO) << "Create common opencl context";
#ifdef Debug
std::vector<cl_context_properties> ctx_properties = {CL_CONTEXT_PLATFORM,
(cl_context_properties)platforms[0](),
CL_PRINTF_CALLBACK_ARM,
(cl_context_properties)printf_callback,
CL_PRINTF_BUFFERSIZE_ARM,
0x1000000,
0};
context_ =
new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, ctx_properties.data(), nullptr, nullptr, &ret);
#else
context_ = new (std::nothrow) cl::Context(std::vector<cl::Device>{*device_}, nullptr, nullptr, nullptr, &ret);
#endif
#endif
if (ret != CL_SUCCESS) {
delete device_;
MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(ret);
return RET_ERROR;
}

default_command_queue_ = new (std::nothrow) cl::CommandQueue(*context_, *device_, 0, &ret);
if (ret != CL_SUCCESS) {

@@ -252,6 +229,44 @@ int OpenCLRuntime::Init() {
MS_LOG(ERROR) << "Profiling command Queue create failed: " << CLErrorCode(ret);
return RET_ERROR;
}
return RET_OK;
}

// Init will get platforms info, get devices info, create opencl context.
int OpenCLRuntime::Init() {
std::unique_lock<std::mutex> lck(g_init_mtx);
if (init_state_ == InitSuccess) {
return RET_OK;
} else if (init_state_ == InitFailed) {
return RET_ERROR;
}
init_state_ = InitFailed;

MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION;
MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION;

#ifdef USE_OPENCL_WRAPPER
if (!lite::opencl::LoadOpenCLLibrary(&handle_)) {
MS_LOG(ERROR) << "Load OpenCL symbols failed!";
return RET_ERROR;
}
#endif  // USE_OPENCL_WRAPPER
std::vector<cl::Platform> platforms;
cl_int ret = cl::Platform::get(&platforms);
if (platforms.empty()) {
MS_LOG(ERROR) << "OpenCL Platform not found!" << CLErrorCode(ret);
return RET_ERROR;
}
auto ms_ret = InitGPUDevice(platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}

ms_ret = InitQueue(platforms);
if (ms_ret != RET_OK) {
return ms_ret;
}

allocator_ = new (std::nothrow) OpenCLAllocator(this);
if (allocator_ == nullptr) {

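Init is split so that device/context setup lives in InitGPUDevice and queue creation in InitQueue; Init itself just chains the two and returns the first failure. A short sketch of the chaining pattern:

#include <iostream>

struct RuntimeSketch {
  int InitGPUDevice() { return 0; }  // platform/device discovery, context creation
  int InitQueue() { return 0; }      // default + profiling command queues

  int Init() {
    int ret = InitGPUDevice();
    if (ret != 0) {
      return ret;                    // propagate the first failure unchanged
    }
    return InitQueue();
  }
};

int main() {
  RuntimeSketch runtime;
  std::cout << (runtime.Init() == 0 ? "ok" : "failed") << "\n";
  return 0;
}
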
@@ -289,10 +304,6 @@ int OpenCLRuntime::Uninit() {
profiling_command_queue_ = nullptr;
context_ = nullptr;
device_ = nullptr;
#ifdef USE_OPENCL_WRAPPER
lite::opencl::UnLoadOpenCLLibrary(handle_);
handle_ = nullptr;
#endif
init_state_ = UnInit;
return RET_OK;
}

@@ -56,7 +56,7 @@ class OpenCLRuntime {
cl::Context *Context();
cl::Device *Device();
OpenCLAllocator *GetAllocator() { return allocator_; }
cl::CommandQueue *GetDefaultCommandQueue() { return default_command_queue_; }
cl::CommandQueue *GetDefaultCommandQueue() { return profiling_ ? profiling_command_queue_ : default_command_queue_; }
uint64_t DeviceGlobalMemoryCacheSize() const;
int DeviceMaxWorkGroupSize() const;
uint32_t DeviceComputeUnits() const;

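With this change callers always go through GetDefaultCommandQueue and transparently receive the profiling queue whenever profiling is enabled, which pairs with the ScaleOpenCLKernel::Run change above that records an event. A tiny sketch:

#include <cassert>

struct QueueSketch {};

struct RuntimeSketch {
  bool profiling_ = false;
  QueueSketch *default_queue_ = nullptr;
  QueueSketch *profiling_queue_ = nullptr;
  // Hand out the profiling queue whenever profiling is enabled, so callers
  // that record events never have to pick a queue themselves.
  QueueSketch *GetDefaultCommandQueue() {
    return profiling_ ? profiling_queue_ : default_queue_;
  }
};

int main() {
  QueueSketch default_queue, profiling_queue;
  RuntimeSketch runtime{true, &default_queue, &profiling_queue};
  assert(runtime.GetDefaultCommandQueue() == &profiling_queue);
  return 0;
}
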
@@ -101,7 +101,7 @@ class OpenCLRuntime {
return kernel.setArg(index, *image);
}
default:
MS_LOG(ERROR) << "Unsupport opencl memory type: " << static_cast<int>(mem_type);
MS_LOG(ERROR) << "Unsupported opencl memory type: " << static_cast<int>(mem_type);
return CL_IMAGE_FORMAT_NOT_SUPPORTED;
}
}

@@ -159,6 +159,8 @@ class OpenCLRuntime {

bool LoadProgram(const std::string &program_name, cl::Program *program);
bool BuildProgram(const std::string &build_options, const cl::Program &program);
int InitGPUDevice(std::vector<cl::Platform> &platforms);
int InitQueue(std::vector<cl::Platform> &platforms);

private:
static InitState init_state_;