!6859 [MSLITE] fp16 quant fix
Merge pull request !6859 from wangchangkai/master
This commit is contained in:
commit
a11873a7b4
|
@ -214,6 +214,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
return dequant_datas;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -151,6 +151,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
|
||||
|
@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -177,12 +179,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -196,6 +196,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
|
||||
|
@ -232,6 +233,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -243,12 +245,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -210,6 +210,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
|
||||
|
@ -218,6 +219,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
|||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -229,12 +231,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
|||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -217,6 +217,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
|
||||
|
@ -225,6 +226,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -236,12 +238,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -198,6 +198,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
|
@ -205,6 +206,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
|||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (!weight_tensor->GetQuantParams().empty()) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -216,12 +218,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
|||
delete kernel;
|
||||
if (!weight_tensor->GetQuantParams().empty()) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (!weight_tensor->GetQuantParams().empty()) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -261,6 +261,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->SetData(dequant_weight);
|
||||
}
|
||||
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
|
@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -279,12 +281,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
delete kernel;
|
||||
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data_type(kNumberTypeInt8);
|
||||
weight_tensor->SetData(restore_data);
|
||||
}
|
||||
return kernel;
|
||||
|
|
Loading…
Reference in New Issue