!6859 [MSLITE] fp16 quant fix

Merge pull request !6859 from wangchangkai/master
mindspore-ci-bot 2020-09-24 23:03:36 +08:00 committed by Gitee
commit a11873a7b4
7 changed files with 25 additions and 0 deletions
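All seven files apply the same pattern: each fp16 kernel creator dequantizes an int8 weight tensor into a temporary float buffer before building the kernel, and puts the original int8 buffer back on every exit path; the fix adds the matching set_data_type calls so the tensor's declared dtype always tracks whichever buffer it currently holds. A minimal, self-contained sketch of that pattern follows. The Tensor type and creator function below are stand-ins, not the real mindspore::lite classes; only set_data_type/SetData/FreeData/data_type, the dtype enum names, and the call sequence come from the diff, everything else is illustrative.

// Stand-in sketch of the dequant/restore flow this commit completes.
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>

enum TypeId { kNumberTypeInt8, kNumberTypeFloat32 };

struct Tensor {  // mock, not mindspore::lite::Tensor
  TypeId dtype = kNumberTypeInt8;
  void *data = nullptr;
  TypeId data_type() const { return dtype; }
  void set_data_type(TypeId t) { dtype = t; }
  void SetData(void *d) { data = d; }
  void FreeData() { std::free(data); data = nullptr; }
};

// Affine dequantization, matching the loop body shown in DequantWeight:
// dequant[j] = (quant[j] - zero_point) * scale.
float *DequantWeight(const int8_t *quant, size_t n, int zero_point, float scale) {
  auto *dequant = static_cast<float *>(std::malloc(n * sizeof(float)));
  if (dequant == nullptr) return nullptr;
  for (size_t j = 0; j < n; ++j) {
    dequant[j] = static_cast<float>((quant[j] - zero_point) * scale);
  }
  return dequant;
}

// Mirrors a CpuXxxFp16KernelCreator: on failure (and after a successful init,
// which copies the weights) both the original int8 buffer AND the int8 dtype
// are restored -- the dtype half is what this commit adds.
bool CreateFp16KernelLike(Tensor *weight_tensor, bool simulate_failure) {
  void *restore_data = weight_tensor->data;  // keep the original int8 weights
  float *dequant_weight =
      DequantWeight(static_cast<int8_t *>(restore_data), 4, /*zero_point=*/0, /*scale=*/0.5f);
  if (dequant_weight == nullptr) {
    std::cerr << "dequant data is nullptr." << std::endl;
    return false;
  }
  weight_tensor->set_data_type(kNumberTypeFloat32);  // added by the fix
  weight_tensor->SetData(dequant_weight);

  if (simulate_failure) {  // e.g. kernel == nullptr or Init() failed
    weight_tensor->FreeData();                      // drop the temporary float buffer
    weight_tensor->set_data_type(kNumberTypeInt8);  // added by the fix
    weight_tensor->SetData(restore_data);
    return false;
  }

  // Success path restores the weight tensor the same way.
  weight_tensor->FreeData();
  weight_tensor->set_data_type(kNumberTypeInt8);
  weight_tensor->SetData(restore_data);
  return true;
}

int main() {
  auto *quant = static_cast<int8_t *>(std::malloc(4));
  quant[0] = -2; quant[1] = 0; quant[2] = 3; quant[3] = 7;
  Tensor weight;
  weight.SetData(quant);
  CreateFp16KernelLike(&weight, /*simulate_failure=*/true);
  std::cout << "dtype restored to int8: "
            << (weight.data_type() == kNumberTypeInt8) << std::endl;
  std::free(quant);
  return 0;
}

Keeping the dtype in step with the data is the point of the change: without it, the restored int8 buffer would still be described as kNumberTypeFloat32 (and the float buffer as int8) to whatever reads the tensor next.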


@@ -214,6 +214,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
}
}
return dequant_datas;
}
} // namespace mindspore::kernel
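For reference, the line above is the usual affine dequantization, dequant = (quant - zero_point) * scale; with illustrative values not taken from this diff, say scale = 0.5 and zero_point = 3, a quantized weight of 11 dequantizes to (11 - 3) * 0.5 = 4.0.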


@@ -151,6 +151,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
@@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -177,12 +179,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


@@ -196,6 +196,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
@@ -232,6 +233,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -243,12 +245,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


@@ -210,6 +210,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
@@ -218,6 +219,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -229,12 +231,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


@@ -217,6 +217,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
@@ -225,6 +226,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -236,12 +238,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


@@ -198,6 +198,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -205,6 +206,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -216,12 +218,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
delete kernel;
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


@@ -261,6 +261,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -279,12 +281,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;