add percent method

This commit is contained in:
guohongzilong 2020-09-18 11:04:16 +08:00
parent d5fad7804e
commit 7ad32a773d
4 changed files with 123 additions and 1 deletions

View File

@ -54,6 +54,21 @@ STATUS DivergInfo::RecordMaxValue(const std::vector<float> &datas) {
return RET_OK;
}
STATUS DivergInfo::RecordMaxValueArray(const std::vector<float> &datas) {
if (datas.size() == 0) {
return RET_ERROR;
}
float max_num = datas.at(0);
float min_num = datas.at(0);
for (float data : datas) {
max_num = std::max(data, max_num);
min_num = std::min(data, min_num);
}
this->max_datas.emplace_back(max_num);
this->min_datas.emplace_back(min_num);
return RET_OK;
}
void DivergInfo::UpdateInterval() {
auto max_value = std::max(fabs(this->max), fabs(this->min));
this->interval = max_value / static_cast<float>(bin_num);
@ -85,6 +100,12 @@ STATUS DivergInfo::ComputeThreshold() {
return RET_OK;
}
if (method_x == kMethodOutlier) {
this->percent_result = PercentMethod(min_datas, max_datas);
this->best_T = std::max(std::fabs(percent_result.first), std::fabs(percent_result.second));
return RET_OK;
}
constexpr int quant_bint_nums = 128;
int threshold = quant_bint_nums;
float min_kl = FLT_MAX;
@ -195,8 +216,14 @@ std::pair<CNodePtr, float> DivergInfo::GetScale() {
float max_value = this->best_T;
float min_value = -max_value;
if (this->method_x == kMethodOutlier) {
min_value = percent_result.first;
max_value = percent_result.second;
}
MS_ASSERT(quant_max - quant_min != 0);
float scale = (max_value - min_value) / (quant_max - quant_min);
this->scale_tmp = scale;
MS_ASSERT(scale != 0);
return std::make_pair(this->cnode, scale);
}
@ -210,6 +237,10 @@ std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
} else {
MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
}
if (this->method_x == kMethodOutlier) {
zero_point = std::round(quant_max - percent_result.second / scale_tmp);
}
return std::make_pair(this->cnode, zero_point);
}
@ -267,6 +298,7 @@ STATUS Calibrator::RecordMaxValue(const std::string &op_name, const vector<float
auto got = (*diverg_info).find(op_name);
if (got != (*diverg_info).end()) {
((*got).second)->RecordMaxValue(data);
((*got).second)->RecordMaxValueArray(data);
}
return RET_OK;
}
@ -445,7 +477,7 @@ STATUS Calibrator::ReadConfig() {
} else if (key == "thread_num") {
config_param_.thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin) {
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
config_param_.method_x = value;

View File

@ -42,6 +42,7 @@ struct MaxMin {
const char kMethodMaxMin[] = "MAX_MIN";
const char kMethodKL[] = "KL";
const char kMethodOutlier[] = "RemovalOutlier";
constexpr int kDefaultBinNumber = 2048;
struct ConfigParam {
@ -127,6 +128,10 @@ struct DivergInfo {
int quant_max = 255;
int quant_min = 0;
std::string method_x = kMethodKL;
std::vector<float> min_datas;
std::vector<float> max_datas;
std::pair<float, float> percent_result{0.0, 0.0};
float scale_tmp = 0;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
@ -143,6 +148,8 @@ struct DivergInfo {
STATUS RecordMaxValue(const std::vector<float> &datas);
STATUS RecordMaxValueArray(const std::vector<float> &datas);
void UpdateInterval();
STATUS UpdateHistogram(const std::vector<float> &data);

View File

@ -304,6 +304,74 @@ STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
return RET_OK;
}
bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp,
size_t *min_idx) {
size_t length = data.size();
if (max_tmp - data.at(index) < delta) {
return false;
}
float range_ratio = (data.at(index) - *min_tmp) / (max_tmp - *min_tmp);
float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
*min_idx = index;
*min_tmp = data.at(index);
}
return true;
}
bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp, const float &min_tmp,
size_t *max_idx) {
size_t length = data.size();
if (data.at(index) - min_tmp < delta) {
return false;
}
float range_ratio = (*max_tmp - data.at(index)) / (*max_tmp - min_tmp);
float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
*max_idx = index;
*max_tmp = data.at(index);
}
return true;
}
float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
const int size = datas.size();
float val = outlier_percent / 100.0 * size;
int index = std::ceil(val);
float result = 0.0;
if (index - val > 0) {
result = datas.at(index - 1);
} else {
result = (datas.at(index - 1) + datas.at(index)) / 2;
}
return result;
}
std::pair<float, float> PercentMethod(std::vector<float> min_datas, std::vector<float> max_datas) {
std::sort(max_datas.begin(), max_datas.end());
std::sort(min_datas.begin(), min_datas.end());
float min_val = CalPercentile(min_datas, percent);
float max_val = CalPercentile(max_datas, 100 - percent);
std::reverse(max_datas.begin(), max_datas.end());
MS_ASSERT(min_val < max_val);
MS_ASSERT(min_datas.size() == max_datas.size());
float min_tmp = min_val;
float max_tmp = max_val;
size_t min_idx = 0;
size_t max_idx = 0;
size_t length = min_datas.size();
for (size_t i = 0; i < length; i++) {
if (!SearchLowerBound(min_datas, i, max_tmp, &min_tmp, &min_idx)) {
break;
}
if (!SearchUpperBound(min_datas, i, &max_tmp, min_tmp, &max_idx)) {
break;
}
}
std::pair<float, float> result{min_tmp, max_tmp};
return result;
}
} // namespace quant
} // namespace lite
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <vector>
#include <algorithm>
#include <limits>
#include <utility>
#include "tools/converter/quantizer/quantizer.h"
#include "src/ops/primitive_c.h"
#include "include/errorcode.h"
@ -61,12 +62,26 @@ class QuantStrategy {
static const std::vector<schema::PrimitiveType> mul_types;
};
constexpr float delta = 0.1;
constexpr float ratio = 10.0;
constexpr int percent = 10;
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
int quant_min, int num_bits);
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange = false,
int numBits = UINT8_QUANTIZATION);
bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp,
size_t *min_idx);
bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp, const float &min_tmp,
size_t *max_idx);
float CalPercentile(const std::vector<float> &datas, const int &percent);
std::pair<float, float> PercentMethod(std::vector<float> min_datas, std::vector<float> max_datas);
template <typename T>
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
MS_ASSERT(quantParam != nullptr);