Dequantize weight parameters for quantized tflite model with the quantType of None.

This commit is contained in:
wsc 2020-09-09 17:02:40 +08:00
parent 483b364d92
commit a9b8b39cac
2 changed files with 4 additions and 4 deletions

View File

@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}

View File

@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}