From 797221b1444b3510d273e6a5ce9626ae122ac7bc Mon Sep 17 00:00:00 2001
From: kai00
Date: Thu, 24 Sep 2020 21:24:23 +0800
Subject: [PATCH] float16 quant fix

---
 mindspore/lite/src/lite_kernel.cc                                  | 1 +
 .../src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc      | 4 ++++
 .../lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc           | 4 ++++
 .../runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc        | 4 ++++
 .../lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc         | 4 ++++
 .../lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc        | 4 ++++
 mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc          | 4 ++++
 7 files changed, 25 insertions(+)

diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc
index c20ce8fe9d4..43a7a07f698 100644
--- a/mindspore/lite/src/lite_kernel.cc
+++ b/mindspore/lite/src/lite_kernel.cc
@@ -214,6 +214,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
       dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
     }
   }
+  return dequant_datas;
 }
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
index 8602b14df1c..2c2b060a7a5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
@@ -151,6 +151,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
 
@@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "kernel is nullptr.";
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -177,12 +179,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
index e7337cefe81..a7ae4dd4c70 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
@@ -196,6 +196,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
 
@@ -232,6 +233,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
     MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -243,12 +245,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
                << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
index 09f6627d95f..6c1b6b58446 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
@@ -210,6 +210,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor *>
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
 
@@ -218,6 +219,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "kernel is nullptr.";
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -229,12 +231,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor *>
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
index dd40cb06338..cb1e52bfc6d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
@@ -217,6 +217,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
 
@@ -225,6 +226,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "kernel is nullptr.";
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -236,12 +238,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
index bfded12ce2d..78316811520 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
@@ -198,6 +198,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::Tensor *>
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
   auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -205,6 +206,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "kernel is nullptr.";
     if (!weight_tensor->GetQuantParams().empty()) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -216,12 +218,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::Tensor *>
     delete kernel;
     if (!weight_tensor->GetQuantParams().empty()) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (!weight_tensor->GetQuantParams().empty()) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
index b03bf3c31ca..5ba2e0bae16 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc
@@ -261,6 +261,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
       MS_LOG(ERROR) << "dequant data is nullptr.";
       return nullptr;
     }
+    weight_tensor->set_data_type(kNumberTypeFloat32);
     weight_tensor->SetData(dequant_weight);
   }
   auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "kernel is nullptr.";
     if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
@@ -279,12 +281,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
     delete kernel;
     if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
       weight_tensor->FreeData();
+      weight_tensor->set_data_type(kNumberTypeInt8);
       weight_tensor->SetData(restore_data);
     }
     return nullptr;
   }
   if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
     weight_tensor->FreeData();
+    weight_tensor->set_data_type(kNumberTypeInt8);
     weight_tensor->SetData(restore_data);
   }
   return kernel;
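Note: every fp16 kernel creator touched above follows the same swap/restore pattern. The int8 weight buffer is temporarily replaced by a dequantized float32 copy while the kernel is created and initialized, and the original buffer is put back afterwards, both on success and on every failure path. The fix adds the matching set_data_type calls so the tensor's declared type always agrees with the buffer it currently holds, and adds the missing return in DequantWeight. Below is a minimal, self-contained C++ sketch of that flow; the Tensor struct and helper names here are made up for illustration and are not the MindSpore Lite API.

    // Sketch only: mirrors the dequant/restore flow of the patch with toy types.
    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>

    enum DataType { kInt8, kFloat32 };

    struct Tensor {                  // toy stand-in, not lite::Tensor
      DataType type = kInt8;
      void *data = nullptr;
      std::size_t count = 0;
    };

    // Dequantize int8 weights into a freshly allocated float buffer.
    // The real DequantWeight lacked the final return; the patch adds it.
    float *Dequant(const Tensor &t, float scale, int zero_point) {
      const std::int8_t *src = static_cast<const std::int8_t *>(t.data);
      float *dst = static_cast<float *>(std::malloc(t.count * sizeof(float)));
      if (dst == nullptr) return nullptr;
      for (std::size_t i = 0; i < t.count; ++i) {
        dst[i] = static_cast<float>((src[i] - zero_point) * scale);
      }
      return dst;
    }

    bool BuildWithFp32Weights(Tensor *weight) {
      void *restore_data = weight->data;      // original int8 buffer
      DataType restore_type = weight->type;   // original data type
      float *fp32 = Dequant(*weight, 0.1f, 0);
      if (fp32 == nullptr) return false;      // nothing swapped yet, nothing to restore
      weight->type = kFloat32;                // keep type consistent with the fp32 copy
      weight->data = fp32;

      bool ok = true;                         // kernel creation / Init() would happen here

      std::free(weight->data);                // drop the temporary fp32 copy
      weight->type = restore_type;            // restore the original type (the fix)
      weight->data = restore_data;            // restore the original int8 data
      return ok;
    }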