!5957 [MSLITE] Dequantize weight data for tflite models whose quantType is unset.

Merge pull request !5957 from wangshaocong/lite_bug
This commit is contained in:
mindspore-ci-bot 2020-09-11 09:53:03 +08:00 committed by Gitee
commit c6838194ef
2 changed files with 4 additions and 4 deletions

View File

@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
} }
@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
return nullptr; return nullptr;
} }
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }

View File

@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
} }
@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
return nullptr; return nullptr;
} }
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }