diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 04cae95820d..b9f3a5896a4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); } @@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & return nullptr; } - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc index e9ac09a5b1b..b24e05cf9b7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); } @@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector return nullptr; } - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); }