From d5034cad264bb1ca76293e7a5d3f14ae7f6ea7c1 Mon Sep 17 00:00:00 2001 From: kai00 Date: Tue, 22 Sep 2020 19:34:38 +0800 Subject: [PATCH] fp16 weight dequant --- .../arm/fp16/convolution_depthwise_fp16.cc | 24 +++++++++++++++++++ .../kernel/arm/fp16/convolution_fp16.cc | 24 +++++++++++++++++++ .../arm/fp16/deconvolution_depthwise_fp16.cc | 24 +++++++++++++++++++ .../kernel/arm/fp16/deconvolution_fp16.cc | 24 +++++++++++++++++++ .../kernel/arm/fp16/fullconnection_fp16.cc | 23 ++++++++++++++++++ .../runtime/kernel/arm/fp16/matmul_fp16.cc | 22 +++++++++++++++++ 6 files changed, 141 insertions(+) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc index a91c045115..8602b14df1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc @@ -142,6 +142,18 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->MutableData(); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } + auto conv_param = reinterpret_cast(opParameter); kernel::LiteKernel *kernel; if (conv_param->input_channel_ < 32) { @@ -152,6 +164,10 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector } if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -159,8 +175,16 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector delete kernel; MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 85c59e3d15..bb78a4f124 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -218,6 +218,18 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->MutableData(); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } + auto conv_param = reinterpret_cast(opParameter); int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; @@ -249,6 +261,10 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & } if (kernel == nullptr) { MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -256,8 +272,16 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & delete kernel; MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; } REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc index f087932c66..09f6627d95 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc @@ -201,9 +201,25 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vectorMutableData(); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } + auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -211,8 +227,16 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc index 2f09e6dead..dd40cb0633 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc @@ -208,9 +208,25 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->MutableData(); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } + auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -218,8 +234,16 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector delete kernel; MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc index a8b3f5ab7b..b1c4bf44ec 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc @@ -170,9 +170,24 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vectordata_c(); + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; + if (!weight_tensor->GetQuantParams().empty()) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -180,8 +195,16 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vectorname_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); delete kernel; + if (!weight_tensor->GetQuantParams().empty()) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (!weight_tensor->GetQuantParams().empty()) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc index bc9fcbf698..b03bf3c31c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc @@ -253,9 +253,23 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { + auto *weight_tensor = inputs.at(kWeightIndex); + auto *restore_data = weight_tensor->data_c(); + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); + if (dequant_weight == nullptr) { + MS_LOG(ERROR) << "dequant data is nullptr."; + return nullptr; + } + weight_tensor->SetData(dequant_weight); + } auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } auto ret = kernel->Init(); @@ -263,8 +277,16 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); delete kernel; + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return nullptr; } + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } return kernel; }