forked from mindspore-Ecosystem/mindspore
!5957 [MSLITE] Dequantize weight data for tflite models whose quantType is unset.
Merge pull request !5957 from wangshaocong/lite_bug
This commit is contained in:
commit
c6838194ef
|
@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
|
||||||
|
|
||||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
auto *restore_data = weight_tensor->MutableData();
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
|
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
weight_tensor->FreeData();
|
weight_tensor->FreeData();
|
||||||
weight_tensor->SetData(restore_data);
|
weight_tensor->SetData(restore_data);
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
||||||
|
|
||||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
auto *restore_data = weight_tensor->MutableData();
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
|
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
weight_tensor->FreeData();
|
weight_tensor->FreeData();
|
||||||
weight_tensor->SetData(restore_data);
|
weight_tensor->SetData(restore_data);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue