!27468 optimize mixed bit weight quant

Merge pull request !27468 from yeyunpeng2020/quant
i-robot 2021-12-09 13:06:45 +00:00 committed by Gitee
commit 31b17f273d
10 changed files with 128 additions and 97 deletions

View File

@ -24,8 +24,8 @@ file_identifier "MSL2";
file_extension "ms";
table QuantParam {
scale: double;
zeroPoint: int;
scale: double = 1;
zeroPoint: int = 0;
min: double = 0;
max: double = 0;
narrowRange: bool = true;

View File

@ -23,10 +23,11 @@ namespace mindspore::lite::quant {
namespace {
constexpr int8_t kCurrentBitCount = 64;
constexpr int8_t kTableSize = 6;
constexpr size_t kInt32Mask = 31;
} // namespace
int BitStream::Create(int bit_capacity) {
int FSEBitStream::Create(int bit_capacity) {
chunk_count_ = (bit_capacity >> kTableSize);
chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
chunks_ = static_cast<uint64_t *>(malloc(chunk_count_ * sizeof(uint64_t)));
if (chunks_ == nullptr) {
MS_LOG(ERROR) << "malloc memory failed.";
return RET_ERROR;
@ -35,7 +36,7 @@ int BitStream::Create(int bit_capacity) {
return RET_OK;
}
void BitStream::Free() {
void FSEBitStream::Free() {
curr_chunk_index_ = -1;
curr_chunk_ = 0;
curr_bit_count_ = 0;
@ -46,7 +47,7 @@ void BitStream::Free() {
}
}
void BitStream::Empty() {
void FSEBitStream::Empty() {
curr_chunk_index_ = -1;
curr_chunk_ = 0;
curr_bit_count_ = 0;
@ -55,7 +56,7 @@ void BitStream::Empty() {
}
}
int64_t BitStream::Pop(uint8_t bit_count) {
int64_t FSEBitStream::Pop(uint8_t bit_count) {
MS_ASSERT(curr_bit_count_ <= kCurrentBitCount);
int64_t right = curr_chunk_ >> (kCurrentBitCount - curr_bit_count_);
int64_t res = right & ((1 << bit_count) - 1);
@ -81,7 +82,7 @@ int64_t BitStream::Pop(uint8_t bit_count) {
return right;
}
void BitStream::Push(int64_t state, uint8_t bit_count) {
void FSEBitStream::Push(int64_t state, uint8_t bit_count) {
curr_bit_count_ += bit_count;
if (curr_bit_count_ <= kCurrentBitCount) {
// happy path, no split
@ -104,5 +105,23 @@ void BitStream::Push(int64_t state, uint8_t bit_count) {
}
}
void BitStream::Flush() { curr_chunk_ <<= kCurrentBitCount - curr_bit_count_; }
void FSEBitStream::Flush() { curr_chunk_ <<= kCurrentBitCount - curr_bit_count_; }
// The function gives the index of the most significant `1` in the binary representation,
// e.g. for the number 00100 it gives 2.
int FSEBitStream::CountBits(int32_t x) {
#ifdef _MSC_VER
int num = 0;
uint32_t tmp = x;
tmp |= 1;
while (!(tmp & INT32_MIN)) {
num += 1;
tmp <<= 1;
}
return num ^ kInt32Mask;
#else
return __builtin_clz(x) ^ kInt32Mask;
#endif
return 0;
}
} // namespace mindspore::lite::quant
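
For reference, here is a minimal standalone sketch of the bit trick behind FSEBitStream::CountBits: for a 32-bit value, the count of leading zeros plus the index of the most significant set bit equals 31, and since 31 is 0b11111, XOR with 31 behaves as "31 minus" for values in that range. The portable loop mirrors the MSVC branch above; the function name is only for the demo.

#include <cstdint>
#include <iostream>

int CountBitsSketch(uint32_t x) {
  x |= 1;                        // avoid the undefined clz(0) case
  int num = 0;
  while (!(x & 0x80000000u)) {   // count leading zeros by shifting until the top bit is set
    num += 1;
    x <<= 1;
  }
  return num ^ 31;               // num == clz(x); 31 is 0b11111, so num ^ 31 == 31 - num
}

int main() {
  std::cout << CountBitsSketch(0b00100) << " "        // 2
            << CountBitsSketch(1) << " "              // 0
            << CountBitsSketch(0x80000000u) << "\n";  // 31
  return 0;
}
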

View File

@ -19,11 +19,11 @@
#include <cstdint>
namespace mindspore::lite::quant {
class BitStream {
class FSEBitStream {
public:
BitStream() = default;
FSEBitStream() = default;
~BitStream() = default;
~FSEBitStream() = default;
public:
int Create(int bit_capacity);
@ -32,6 +32,7 @@ class BitStream {
int64_t Pop(uint8_t bit_count);
void Push(int64_t state, uint8_t bit_count);
void Flush();
static int CountBits(int32_t x);
int32_t GetCurrChunkIndex() { return this->curr_chunk_index_; }
uint64_t GetCurrChunk() { return this->curr_chunk_; }
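
As context for the Push/Pop/Flush interface above, the following is a simplified, self-contained sketch of packing variable-width values into 64-bit chunks, including the split case referred to by the "happy path, no split" comment in fse_bit_stream.cc. It is not the FSEBitStream class itself; the LSB-first layout and all names are chosen only for the demo.

#include <cstdint>
#include <iostream>
#include <vector>

class BitWriterSketch {
 public:
  void Push(uint64_t value, uint8_t bit_count) {  // bit_count is assumed to be 1..64
    uint64_t mask = (bit_count >= 64) ? ~0ULL : ((1ULL << bit_count) - 1);
    value &= mask;
    if (used_ + bit_count <= 64) {  // happy path, no split
      curr_ |= value << used_;
      used_ += bit_count;
    } else {                        // value straddles two chunks: split it
      uint8_t low = 64 - used_;
      curr_ |= (value & ((1ULL << low) - 1)) << used_;
      chunks_.push_back(curr_);
      curr_ = value >> low;
      used_ = bit_count - low;
    }
    if (used_ == 64) {              // current chunk is full, start a new one
      chunks_.push_back(curr_);
      curr_ = 0;
      used_ = 0;
    }
  }
  void Flush() {
    if (used_ > 0) chunks_.push_back(curr_);
  }
  const std::vector<uint64_t> &chunks() const { return chunks_; }

 private:
  std::vector<uint64_t> chunks_;
  uint64_t curr_ = 0;
  uint8_t used_ = 0;  // bits already used in curr_
};

int main() {
  BitWriterSketch bw;
  bw.Push(0x3, 2);        // 2 bits
  bw.Push(0x1FF, 9);      // 9 bits
  bw.Push(0xABCDEF, 60);  // forces a split into a second chunk
  bw.Flush();
  std::cout << "chunks: " << bw.chunks().size() << "\n";  // 2
  return 0;
}
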

View File

@ -23,6 +23,10 @@
#include "nnacl/op_base.h"
namespace mindspore::lite::quant {
namespace {
constexpr int kTableExtend = 3;
constexpr int kAlignOffset = 7;
} // namespace
int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, int table_log,
uint16_t *new_state, uint8_t *bit_count, uint16_t *symbol_table) {
MS_ASSERT(symbol_frequency != nullptr);
@ -30,8 +34,8 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
MS_ASSERT(bit_count != nullptr);
MS_ASSERT(symbol_table != nullptr);
const int table_size = 1 << table_log;
int table_mask = table_size - 1;
int step = ((table_size >> 1) + (table_size >> 3) + 3);
const int table_mask = table_size - 1;
int step = ((table_size >> 1) + (table_size >> kTableExtend) + kTableExtend);
int pos = 0;
for (int sym = 0; sym < symbol_frequency_count; sym++) {
for (uint32_t i = 0; i < symbol_frequency[sym]; i++) {
@ -52,24 +56,13 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
uint16_t sym = symbol_table[i];
uint32_t x = frequency[sym];
frequency[sym] += 1;
#ifdef _MSC_VER
int num = 0;
uint32_t tmp = x;
tmp |= 1;
while (!(tmp & 0x80000000)) {
num += 1;
tmp <<= 1;
}
bit_count[i] = table_log - (num ^ 31);
#else
bit_count[i] = table_log - (__builtin_clz(x) ^ 31);
#endif
bit_count[i] = table_log - FSEBitStream::CountBits(x);
new_state[i] = (x << bit_count[i]) - table_size;
}
return RET_OK;
}
int FSEDecoder::FSEDecode(BitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
int FSEDecoder::FSEDecode(FSEBitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
const float *centroids, int table_log) {
MS_ASSERT(bs != nullptr);
MS_ASSERT(buff != nullptr);
@ -122,10 +115,10 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
int out_sz = dst_tensor->ElementsNum();
MS_CHECK_GT(out_sz, 0, RET_ERROR);
// deserialize from `data`:
BitStream bs;
FSEBitStream bs;
size_t i = 0;
auto data8 = reinterpret_cast<unsigned char *>(const_cast<void *>(src_tensor.data()));
auto data8 = reinterpret_cast<int8_t *>(const_cast<void *>(src_tensor.data()));
int frequency_count = *(reinterpret_cast<uint16_t *>(&data8[i]));
i += sizeof(uint16_t);
@ -152,7 +145,7 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
auto *frequency = reinterpret_cast<uint32_t *>(&data8[i]);
i += frequency_count * sizeof(uint32_t);
// Used for 8-byte alignment
i = ((i + 7) >> 3) << 3;
i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
if (i > total_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << total_size;
@ -162,7 +155,7 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
auto centroids_float = reinterpret_cast<float *>(centroids);
i += frequency_count * sizeof(float);
// Used for 8-byte alignment
i = ((i + 7) >> 3) << 3;
i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
if (i > total_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << total_size;

View File

@ -31,7 +31,7 @@ class FSEDecoder {
static int DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor);
private:
static int FSEDecode(BitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
static int FSEDecode(FSEBitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
const float *centroids, int table_log);
static int FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, int table_log,

View File

@ -26,16 +26,12 @@
namespace mindspore::lite::quant {
namespace {
constexpr int kInt32Mask = 31;
constexpr int kInt16 = 16;
constexpr int kFseTableExtendSize = 3;
constexpr int kFrenqTableExtendSize = 2;
constexpr int kAlignSize = 8;
constexpr float kUpRoundOffSet = 0.5;
} // namespace
// The function gives the index of most import `1` in the binary representation.
// e.g. for the number 00100 it gives 2.
int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ kInt32Mask; }
int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log,
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
@ -76,7 +72,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_co
int total = 0;
for (int sym = 0; sym < frequency_count; sym++) {
if (frequency[sym] >= kFrenqTableExtendSize) {
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
int max_bits_out = table_log - FSEBitStream::CountBits(frequency[sym] - 1);
int min_state_plus = frequency[sym] << max_bits_out;
delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
delta_state[sym] = total - frequency[sym];
@ -159,10 +155,10 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
MS_LOG(ERROR) << "Normalize frequency failed.";
return ret;
}
BitStream bs;
FSEBitStream bs;
ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "BitStream Create failed.";
MS_LOG(ERROR) << "FSEBitStream Create failed.";
free(fse_quant.symbol_table);
return ret;
}
@ -186,7 +182,7 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
return RET_OK;
}
uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(FSEBitStream *bs, uint16_t sym, uint16_t state,
const uint32_t *delta_bit_count, const int16_t *delta_state,
uint16_t *coding_table) {
MS_ASSERT(bs != nullptr);
@ -219,7 +215,7 @@ int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
CHECK_NULL_RETURN(table_log);
// The higher the number, the closer we get to the Shannon entropy,
// but also the larger the table, so `+3` is a good compromise.
*table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
*table_log = std::min(MAX_TABLE_LOG, (FSEBitStream::CountBits((uint32_t)q->size) + kFseTableExtendSize));
const int new_table_size = 1 << (*table_log);
int curr_table_size = 0;
for (int i = 0; i < q->size; i++) {
@ -270,8 +266,8 @@ int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
// - determine nbBits, flush them
// - determine sub-Range Id
// - look for Symbol position of same Id : you get your next state
int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
int table_log) {
int FSEEncoder::FSEEncode(FSEBitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency,
int frequency_count, int table_log) {
MS_ASSERT(bs != nullptr);
MS_ASSERT(data != nullptr);
MS_ASSERT(frequency != nullptr);
@ -305,7 +301,7 @@ int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, u
return ret;
}
int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant,
int table_log, uint8_t *out8, size_t max_size, size_t *out_size) {
MSLITE_CHECK_PTR(tensor_input);
MSLITE_CHECK_PTR(bs);
@ -400,7 +396,7 @@ int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs
return RET_OK;
}
int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant,
int table_log) {
MSLITE_CHECK_PTR(tensor_input);
MSLITE_CHECK_PTR(bs);
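
A minimal sketch (not the encoder itself) of the table_log selection in NormalizeFrequency above: the bit width of the symbol count plus kFseTableExtendSize, clamped to a maximum. The value of MAX_TABLE_LOG is not shown in this diff, so the cap of 16 below is only an assumption for the demo; __builtin_clz is GCC/Clang-specific.

#include <algorithm>
#include <cstdint>
#include <iostream>

constexpr int kMaxTableLog = 16;         // assumed cap, standing in for MAX_TABLE_LOG
constexpr int kFseTableExtendSize = 3;

int CountBitsSketch(uint32_t x) { return 31 - __builtin_clz(x | 1); }  // index of the top set bit

int main() {
  for (uint32_t symbol_count : {4u, 16u, 300u, 100000u}) {
    int table_log = std::min(kMaxTableLog, CountBitsSketch(symbol_count) + kFseTableExtendSize);
    std::cout << symbol_count << " symbols -> table_log " << table_log
              << ", table_size " << (1 << table_log) << "\n";
  }
  // prints table_log 5, 7, 11 and the capped 16 for the four counts above
  return 0;
}
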

View File

@ -41,17 +41,17 @@ class FSEEncoder {
int FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log, uint32_t *delta_bit_count,
int16_t *delta_state, uint16_t *coding_table, uint16_t *symbol_table);
uint16_t FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
uint16_t FSEEncodeSymbolGetNewState(FSEBitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
const int16_t *delta_state, uint16_t *coding_table);
int FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
int FSEEncode(FSEBitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
int table_log);
int NormalizeFrequency(FSEQuant *q, int *table_log);
int SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log);
int SerializingToOut(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant, int table_log);
int SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log,
int SerializingToTensor(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant, int table_log,
uint8_t *out8, size_t max_size, size_t *offset);
};
} // namespace mindspore::lite::quant

View File

@ -147,7 +147,8 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
if (type_id == kNumberTypeFloat32 && info != nullptr) {
auto scale = info->GetScale();
if (scale == 0) {
MS_LOG(WARNING) << input_node->fullname_with_scope() << " input index:" << index
std::string in_out = is_input ? " input" : " output";
MS_LOG(WARNING) << input_node->fullname_with_scope() << in_out << " index:" << index
<< " values are very close to 0, so set the scale to 1.";
quant_param.scale = 1;
} else {

View File

@ -19,16 +19,69 @@
#include "tools/common/statistic_utils.h"
namespace mindspore::lite::quant {
void MixedBitWeightQuantizer::GetBiasCorrection(float *weights, int element_num, float scale,
float *origin_dequant_datas) {
MS_ASSERT(element_num > 0);
double average_dequant = 0;
double average_raw = 0;
for (int i = 0; i < element_num; i++) {
float dequant = scale * (floorf(weights[i] / scale + 0.5));
origin_dequant_datas[i] = dequant;
average_raw += weights[i];
average_dequant += dequant;
}
// mean
average_dequant = average_dequant / element_num;
average_raw = average_raw / element_num;
// std
double variance_dequant = 0;
double variance_raw = 0;
const int exponent = 2;
for (int i = 0; i < element_num; i++) {
variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
variance_raw += std::pow(weights[i] - average_raw, exponent);
}
variance_dequant = std::sqrt(variance_dequant / element_num);
variance_raw = std::sqrt(variance_raw / element_num);
if (variance_dequant == 0) {
var_corr_ = 1;
} else {
var_corr_ = variance_raw / variance_dequant;
}
mean_corr_ = average_raw - average_dequant * var_corr_;
}
// The error is currently measured per channel.
// It could be measured per layer, but that would be less accurate.
float MixedBitWeightQuantizer::CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2) {
int error_count = 0;
float mse_error = 1e-10f;
const float soft = 1e-7f;
const float tolerance_error = 1.0e-10f;
for (size_t i = 0; i < norms2.size(); i++) {
if (norms2[i] < tolerance_error) {
continue;
}
error_count += 1;
mse_error += sqrtf(dnorms2[i] / norms2[i]);
}
auto meam_error = mse_error / (error_count + soft);
return meam_error;
}
// the `preferred` dim should point to the output channels dimension.
float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
float scale) {
MS_ASSERT(weights != nullptr);
MS_ASSERT(shape != nullptr);
int numel = 1;
// Init
int element_num = 1;
for (int i = 0; i < dims; i++) {
numel *= shape[i];
element_num *= shape[i];
}
if (element_num <= 0) {
MS_LOG(ERROR) << "Element is less than or equal to 0.";
return FLT_MAX;
}
int bucket_count = shape[preferred_dim];
std::vector<float> norms2(bucket_count);
@ -37,41 +90,18 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in
norms2[i] = 0.0;
dnorms2[i] = 0.0;
}
double average_dequant = 0;
double average_raw = 0;
std::vector<float> origin_dequant_datas(numel);
std::vector<float> corr_dequant_datas(numel);
// Bucketing
std::vector<float> origin_dequant_datas(element_num);
std::vector<float> corr_dequant_datas(element_num);
int bucket_volume = 1;
for (int i = preferred_dim; i < dims; i++) {
bucket_volume *= shape[i];
}
for (int i = 0; i < numel; i++) {
float dequant = scale * (floorf(weights[i] / scale + 0.5));
origin_dequant_datas[i] = dequant;
average_raw += weights[i];
average_dequant += dequant;
}
// mean
average_dequant = average_dequant / numel;
average_raw = average_raw / numel;
// std
double variance_dequant = 0;
double variance_raw = 0;
const int exponent = 2;
for (int i = 0; i < numel; i++) {
variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
variance_raw += std::pow(weights[i] - average_raw, exponent);
}
variance_dequant = std::sqrt(variance_dequant / numel);
variance_raw = std::sqrt(variance_raw / numel);
if (variance_dequant == 0) {
var_corr_ = 1;
} else {
var_corr_ = variance_raw / variance_dequant;
}
mean_corr_ = average_raw - average_dequant * var_corr_;
for (int i = 0; i < numel; i++) {
// Bias Correction
GetBiasCorrection(weights, element_num, scale, origin_dequant_datas.data());
for (int i = 0; i < element_num; i++) {
int bucket = (i / bucket_volume) % bucket_count;
norms2[bucket] += weights[i] * weights[i];
float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;
@ -79,18 +109,7 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in
float d = weights[i] - dequant;
dnorms2[bucket] += d * d;
}
int c = 0;
float t = 1e-10;
for (int i = 0; i < bucket_count; i++) {
if (norms2[i] < 1.0e-10) continue;
c += 1;
t += sqrtf(dnorms2[i] / norms2[i]);
}
auto meam_error = t / (c + 1e-7);
auto cos_sim = mindspore::lite::GetCosSimilarity(weights, corr_dequant_datas.data(), numel);
MS_LOG(INFO) << " meam_error:" << meam_error << " cos_sim:" << cos_sim;
auto meam_error = CalculateMeanError(norms2, dnorms2);
return meam_error;
}
@ -176,8 +195,6 @@ int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t>
MS_LOG(ERROR) << "quant failed.";
return RET_ERROR;
}
// It is used to calculate the Shannon entropy.
quant_params->push_back(quant_param);
return RET_OK;
}
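
A self-contained sketch of the bias-correction math that GetBiasCorrection factors out above: quantize/dequantize each weight with the candidate scale, then derive a standard-deviation ratio (var_corr_) and a mean offset (mean_corr_) so that var_corr * dequant + mean_corr matches the mean and spread of the raw weights. The names and sample data below are only for the demo.

#include <cmath>
#include <iostream>
#include <vector>

void GetBiasCorrectionSketch(const std::vector<float> &weights, float scale,
                             float *var_corr, float *mean_corr) {
  const int n = static_cast<int>(weights.size());
  std::vector<float> dequant(n);
  double mean_raw = 0, mean_deq = 0;
  for (int i = 0; i < n; ++i) {
    dequant[i] = scale * std::floor(weights[i] / scale + 0.5f);  // round-to-nearest dequantization
    mean_raw += weights[i];
    mean_deq += dequant[i];
  }
  mean_raw /= n;
  mean_deq /= n;
  double std_raw = 0, std_deq = 0;
  for (int i = 0; i < n; ++i) {
    std_raw += std::pow(weights[i] - mean_raw, 2);
    std_deq += std::pow(dequant[i] - mean_deq, 2);
  }
  std_raw = std::sqrt(std_raw / n);   // standard deviations, as in the diff
  std_deq = std::sqrt(std_deq / n);
  *var_corr = (std_deq == 0) ? 1.0f : static_cast<float>(std_raw / std_deq);
  *mean_corr = static_cast<float>(mean_raw - mean_deq * (*var_corr));
}

int main() {
  std::vector<float> weights = {0.11f, -0.42f, 0.37f, 0.05f, -0.18f};
  float var_corr = 1, mean_corr = 0;
  GetBiasCorrectionSketch(weights, 0.1f, &var_corr, &mean_corr);
  // corrected dequantized value = var_corr * dequant + mean_corr
  std::cout << "var_corr: " << var_corr << " mean_corr: " << mean_corr << "\n";
  return 0;
}
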

View File

@ -53,6 +53,10 @@ class MixedBitWeightQuantizer {
BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim,
int max_iters, float target_err, float rel_tol);
void GetBiasCorrection(float *weights, int element_num, float scale, float *origin_dequant_datas);
float CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2);
private:
float var_corr_{1};
float mean_corr_{0};
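
For a quick numeric check of CalculateMeanError declared above: buckets whose raw norm is below the tolerance are skipped, and the result is approximately the mean of sqrt(dnorms2[i] / norms2[i]) over the remaining buckets. The function name and sample values below are only for the demo.

#include <cmath>
#include <iostream>
#include <vector>

float MeanErrorSketch(const std::vector<float> &norms2, const std::vector<float> &dnorms2) {
  int count = 0;
  float error = 1e-10f;        // small initial value, as in the diff
  const float soft = 1e-7f;    // keeps the division finite when no bucket is counted
  const float tolerance = 1.0e-10f;
  for (size_t i = 0; i < norms2.size(); ++i) {
    if (norms2[i] < tolerance) continue;  // skip buckets with (near-)zero raw norm
    count += 1;
    error += std::sqrt(dnorms2[i] / norms2[i]);
  }
  return error / (count + soft);
}

int main() {
  // two buckets with 10% and 20% relative error, plus one empty bucket that is skipped
  std::cout << MeanErrorSketch({1.0f, 4.0f, 0.0f}, {0.01f, 0.16f, 0.0f}) << "\n";  // ~0.15
  return 0;
}
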