promote weight quantization precision

xutianchun 2020-09-23 10:12:11 +08:00
parent 7f35e2e2a1
commit 561ebcca7e
5 changed files with 61 additions and 11 deletions

View File

@@ -32,6 +32,8 @@ table QuantParam {
     narrowRange: bool = true;
     numBits: int = 8;
     inited: bool = false;
+    var_corr: double = 1;
+    mean_corr: double = 0;
 }
 table Tensor {
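The two fields added to QuantParam store a per-channel linear correction that the runtime applies on top of ordinary dequantization; with the schema defaults var_corr = 1 and mean_corr = 0 the result is unchanged. A minimal sketch of the corrected formula used further down in DequantWeight (the helper name is made up for illustration, not part of the commit):

#include <cstdint>

// Sketch only: corrected dequantization of one int8 weight value.
inline float DequantWithCorrection(int8_t q, double scale, int32_t zero_point,
                                   double var_corr = 1, double mean_corr = 0) {
  double dequant = (q - zero_point) * scale;                  // plain affine dequantization
  return static_cast<float>(dequant * var_corr + mean_corr);  // variance/mean correction
}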

View File

@@ -174,9 +174,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
     MS_LOG(ERROR) << "no quant param";
     return nullptr;
   }
-  const auto *quant_data = static_cast<const int8_t *>(input_tensor->MutableData());
-  auto *dequant_data = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
-  if (dequant_data == nullptr) {
+  const auto *quant_datas = static_cast<const int8_t *>(input_tensor->MutableData());
+  auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
+  if (dequant_datas == nullptr) {
     MS_LOG(ERROR) << "malloc faile";
     return nullptr;
   }
@@ -185,7 +185,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
     size_t channels = static_cast<size_t>(input_tensor->Batch());
     if (input_tensor->GetQuantParams().size() != channels) {
       MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels;
-      free(dequant_data);
+      free(dequant_datas);
       return nullptr;
     }
     size_t per_channel_size = input_tensor->ElementsNum() / channels;
@@ -194,9 +194,15 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
       auto param = quant_param.at(i);
       auto scale = param.scale;
       auto zero_point = param.zeroPoint;
+      auto var_corr = param.var_corr;
+      auto mean_corr = param.mean_corr;
+      if (var_corr < 0 || var_corr > 10) {
+        MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
+        var_corr = 1;
+      }
       for (size_t j = 0; j < per_channel_size; j++) {
-        dequant_data[per_channel_size * i + j] =
-          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
+        auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
+        dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
       }
     }
   } else {
@@ -205,9 +211,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
     auto scale = param.scale;
     auto zero_point = param.zeroPoint;
     for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) {
-      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+      dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
     }
   }
-  return dequant_data;
+  return dequant_datas;
 }
 } // namespace mindspore::kernel
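Beyond the rename to quant_datas/dequant_datas, the per-channel branch now reads var_corr and mean_corr from each channel's quant param, falls back to var_corr = 1 when the stored value looks implausible, and applies the correction element by element. A self-contained toy example of that per-channel layout and loop (the struct, function, and data here are illustrative, not the lite::Tensor API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Each channel owns a contiguous block of per_channel_size values and its own params.
struct QuantP {
  double scale;
  int32_t zero_point;
  double var_corr{1};
  double mean_corr{0};
};

std::vector<float> DequantPerChannel(const std::vector<int8_t> &quant_datas,
                                     const std::vector<QuantP> &params, size_t per_channel_size) {
  std::vector<float> dequant_datas(quant_datas.size());
  for (size_t i = 0; i < params.size(); i++) {
    auto p = params[i];
    if (p.var_corr < 0 || p.var_corr > 10) {
      p.var_corr = 1;  // same sanity guard as the kernel code
    }
    for (size_t j = 0; j < per_channel_size; j++) {
      auto dequant = (quant_datas[per_channel_size * i + j] - p.zero_point) * p.scale;
      dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant * p.var_corr + p.mean_corr);
    }
  }
  return dequant_datas;
}

int main() {
  std::vector<int8_t> q = {-2, 0, 3, 5};                   // one channel of four values
  std::vector<QuantP> params = {{0.05, 0, 1.02, -0.001}};  // scale, zero point, corrections
  for (float v : DequantPerChannel(q, params, 4)) printf("%f\n", v);
  return 0;
}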

View File

@@ -106,6 +106,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
         QuantArg quant_arg{};
         quant_arg.scale = quant_params->Get(j)->scale();
         quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint();
+        quant_arg.var_corr = quant_params->Get(j)->var_corr();
+        quant_arg.mean_corr = quant_params->Get(j)->mean_corr();
         dstTensor->AddQuantParam(quant_arg);
       }
     }
@@ -351,7 +353,7 @@ int LiteSession::Init(Context *context) {
     }
   }
 #endif
-  executor = new(std::nothrow) Executor();
+  executor = new (std::nothrow) Executor();
   if (nullptr == executor) {
     MS_LOG(ERROR) << "New Executor failed";
     is_running_.store(false);

View File

@@ -33,6 +33,8 @@ namespace lite {
 struct QuantArg {
   double scale;
   int32_t zeroPoint;
+  double var_corr{1};
+  double mean_corr{0};
 };
 class Tensor : public mindspore::tensor::MSTensor {

View File

@@ -143,7 +143,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
     return RET_ERROR;
   }
   std::vector<T> quant_datas(elem_count);
+  std::vector<float> dequant_datas(elem_count);
   if (per_channel) {
     // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
     // channel at first
@@ -173,8 +173,9 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
         MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
         return status;
       }
-      quant_params.emplace_back(quant_param);
       // do quantization
+      double average_dequant = 0;
+      double average_raw = 0;
       for (uint32_t j = 0; j < one_filter_size; j++) {
         auto index = j + i * one_filter_size;
         if (index >= elem_count) {
@@ -184,7 +185,44 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
         float raw_data = raw_datas[index];
         auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
         quant_datas[index] = quant_data;
+        if (quantType == QuantType_WeightQuant) {
+          float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
+          dequant_datas[index] = dequant_data;
+          average_dequant += dequant_data;
+          average_raw += raw_data;
+        }
       }
+      if (quantType == QuantType_WeightQuant) {
+        // mean
+        average_dequant = average_dequant / one_filter_size;
+        average_raw = average_raw / one_filter_size;
+        // std
+        double variance_dequant = 0;
+        double variance_raw = 0;
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "over flow!";
+            return RET_ERROR;
+          }
+          variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
+          variance_raw += std::pow(raw_datas[index] - average_raw, 2);
+        }
+        variance_dequant = std::sqrt(variance_dequant / one_filter_size);
+        variance_raw = std::sqrt(variance_raw / one_filter_size);
+        quant_param.var_corr = 1;
+        if (variance_raw != 0 && variance_dequant != 0) {
+          auto temp_var_corr = variance_raw / variance_dequant;
+          if (temp_var_corr > 0 && temp_var_corr < 10) {
+            quant_param.var_corr = temp_var_corr;
+          } else {
+            MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
+          }
+        }
+        quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr;
+      }
+      quant_params.emplace_back(quant_param);
     }
     auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
     if (ret != EOK) {
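On the converter side, the code added to QuantFilter gathers, per output channel, the mean and standard deviation of both the raw float weights and their quantize-dequantize reconstruction, then stores var_corr = std(raw) / std(dequant) (kept only when it falls in (0, 10)) and mean_corr = mean(raw) - mean(dequant) * var_corr. A sketch of that statistics step factored into a standalone helper (the commit computes it inline; the names below are hypothetical):

#include <cmath>
#include <vector>

struct Correction {
  double var_corr = 1;
  double mean_corr = 0;
};

// Derive the correction factors for one channel from its raw weights and
// their dequantized reconstruction.
Correction ComputeCorrection(const std::vector<float> &raw, const std::vector<float> &dequant) {
  Correction c;
  size_t n = raw.size();
  if (n == 0 || dequant.size() != n) return c;
  double mean_raw = 0, mean_deq = 0;
  for (size_t i = 0; i < n; i++) {
    mean_raw += raw[i];
    mean_deq += dequant[i];
  }
  mean_raw /= n;
  mean_deq /= n;
  double std_raw = 0, std_deq = 0;
  for (size_t i = 0; i < n; i++) {
    std_raw += std::pow(raw[i] - mean_raw, 2);
    std_deq += std::pow(dequant[i] - mean_deq, 2);
  }
  std_raw = std::sqrt(std_raw / n);
  std_deq = std::sqrt(std_deq / n);
  // var_corr rescales the dequantized weights toward the raw standard deviation;
  // ratios outside (0, 10) are treated as unexpected and left at 1, as in the diff.
  if (std_raw != 0 && std_deq != 0) {
    double ratio = std_raw / std_deq;
    if (ratio > 0 && ratio < 10) c.var_corr = ratio;
  }
  // mean_corr shifts the corrected weights so their mean matches the raw mean.
  c.mean_corr = mean_raw - mean_deq * c.var_corr;
  return c;
}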