forked from mindspore-Ecosystem/mindspore
promote weight quantization precision
This commit is contained in:
parent
7f35e2e2a1
commit
561ebcca7e
|
@ -32,6 +32,8 @@ table QuantParam {
|
|||
narrowRange: bool = true;
|
||||
numBits: int = 8;
|
||||
inited: bool = false;
|
||||
var_corr: double = 1;
|
||||
mean_corr: double = 0;
|
||||
}
|
||||
|
||||
table Tensor {
|
||||
|
|
|
@ -174,9 +174,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
MS_LOG(ERROR) << "no quant param";
|
||||
return nullptr;
|
||||
}
|
||||
const auto *quant_data = static_cast<const int8_t *>(input_tensor->MutableData());
|
||||
auto *dequant_data = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
|
||||
if (dequant_data == nullptr) {
|
||||
const auto *quant_datas = static_cast<const int8_t *>(input_tensor->MutableData());
|
||||
auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
|
||||
if (dequant_datas == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc faile";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
size_t channels = static_cast<size_t>(input_tensor->Batch());
|
||||
if (input_tensor->GetQuantParams().size() != channels) {
|
||||
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels;
|
||||
free(dequant_data);
|
||||
free(dequant_datas);
|
||||
return nullptr;
|
||||
}
|
||||
size_t per_channel_size = input_tensor->ElementsNum() / channels;
|
||||
|
@ -194,9 +194,15 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
auto param = quant_param.at(i);
|
||||
auto scale = param.scale;
|
||||
auto zero_point = param.zeroPoint;
|
||||
auto var_corr = param.var_corr;
|
||||
auto mean_corr = param.mean_corr;
|
||||
if (var_corr < 0 || var_corr > 10) {
|
||||
MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
|
||||
var_corr = 1;
|
||||
}
|
||||
for (size_t j = 0; j < per_channel_size; j++) {
|
||||
dequant_data[per_channel_size * i + j] =
|
||||
static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
|
||||
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
|
||||
dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -205,9 +211,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
auto scale = param.scale;
|
||||
auto zero_point = param.zeroPoint;
|
||||
for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) {
|
||||
dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
|
||||
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
|
||||
}
|
||||
}
|
||||
return dequant_data;
|
||||
return dequant_datas;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -106,6 +106,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
QuantArg quant_arg{};
|
||||
quant_arg.scale = quant_params->Get(j)->scale();
|
||||
quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint();
|
||||
quant_arg.var_corr = quant_params->Get(j)->var_corr();
|
||||
quant_arg.mean_corr = quant_params->Get(j)->mean_corr();
|
||||
dstTensor->AddQuantParam(quant_arg);
|
||||
}
|
||||
}
|
||||
|
@ -351,7 +353,7 @@ int LiteSession::Init(Context *context) {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
executor = new(std::nothrow) Executor();
|
||||
executor = new (std::nothrow) Executor();
|
||||
if (nullptr == executor) {
|
||||
MS_LOG(ERROR) << "New Executor failed";
|
||||
is_running_.store(false);
|
||||
|
|
|
@ -33,6 +33,8 @@ namespace lite {
|
|||
struct QuantArg {
|
||||
double scale;
|
||||
int32_t zeroPoint;
|
||||
double var_corr{1};
|
||||
double mean_corr{0};
|
||||
};
|
||||
|
||||
class Tensor : public mindspore::tensor::MSTensor {
|
||||
|
|
|
@ -143,7 +143,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|||
return RET_ERROR;
|
||||
}
|
||||
std::vector<T> quant_datas(elem_count);
|
||||
|
||||
std::vector<float> dequant_datas(elem_count);
|
||||
if (per_channel) {
|
||||
// notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
|
||||
// channel at first
|
||||
|
@ -173,8 +173,9 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|||
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
|
||||
return status;
|
||||
}
|
||||
quant_params.emplace_back(quant_param);
|
||||
// do quantization
|
||||
double average_dequant = 0;
|
||||
double average_raw = 0;
|
||||
for (uint32_t j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (index >= elem_count) {
|
||||
|
@ -184,7 +185,44 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|||
float raw_data = raw_datas[index];
|
||||
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
|
||||
quant_datas[index] = quant_data;
|
||||
|
||||
if (quantType == QuantType_WeightQuant) {
|
||||
float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
|
||||
dequant_datas[index] = dequant_data;
|
||||
average_dequant += dequant_data;
|
||||
average_raw += raw_data;
|
||||
}
|
||||
}
|
||||
if (quantType == QuantType_WeightQuant) {
|
||||
// mean
|
||||
average_dequant = average_dequant / one_filter_size;
|
||||
average_raw = average_raw / one_filter_size;
|
||||
// std
|
||||
double variance_dequant = 0;
|
||||
double variance_raw = 0;
|
||||
for (uint32_t j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (index >= elem_count) {
|
||||
MS_LOG(ERROR) << "over flow!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
|
||||
variance_raw += std::pow(raw_datas[index] - average_raw, 2);
|
||||
}
|
||||
variance_dequant = std::sqrt(variance_dequant / one_filter_size);
|
||||
variance_raw = std::sqrt(variance_raw / one_filter_size);
|
||||
quant_param.var_corr = 1;
|
||||
if (variance_raw != 0 && variance_dequant != 0) {
|
||||
auto temp_var_corr = variance_raw / variance_dequant;
|
||||
if (temp_var_corr > 0 && temp_var_corr < 10) {
|
||||
quant_param.var_corr = temp_var_corr;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
|
||||
}
|
||||
}
|
||||
quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr;
|
||||
}
|
||||
quant_params.emplace_back(quant_param);
|
||||
}
|
||||
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
|
||||
if (ret != EOK) {
|
||||
|
|
Loading…
Reference in New Issue