forked from mindspore-Ecosystem/mindspore

!27468 optimize mixed bit weight quant
Merge pull request !27468 from yeyunpeng2020/quant

commit 31b17f273d
@@ -24,8 +24,8 @@ file_identifier "MSL2";
 file_extension "ms";

 table QuantParam {
-scale: double;
-zeroPoint: int;
+scale: double = 1;
+zeroPoint: int = 0;
 min: double = 0;
 max: double = 0;
 narrowRange: bool = true;
@@ -23,10 +23,11 @@ namespace mindspore::lite::quant {
 namespace {
 constexpr int8_t kCurrentBitCount = 64;
 constexpr int8_t kTableSize = 6;
+constexpr size_t kInt32Mask = 31;
 } // namespace
-int BitStream::Create(int bit_capacity) {
+int FSEBitStream::Create(int bit_capacity) {
 chunk_count_ = (bit_capacity >> kTableSize);
-chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
+chunks_ = static_cast<uint64_t *>(malloc(chunk_count_ * sizeof(uint64_t)));
 if (chunks_ == nullptr) {
 MS_LOG(ERROR) << "malloc memory failed.";
 return RET_ERROR;

@@ -35,7 +36,7 @@ int BitStream::Create(int bit_capacity) {
 return RET_OK;
 }

-void BitStream::Free() {
+void FSEBitStream::Free() {
 curr_chunk_index_ = -1;
 curr_chunk_ = 0;
 curr_bit_count_ = 0;

@@ -46,7 +47,7 @@ void BitStream::Free() {
 }
 }

-void BitStream::Empty() {
+void FSEBitStream::Empty() {
 curr_chunk_index_ = -1;
 curr_chunk_ = 0;
 curr_bit_count_ = 0;

@@ -55,7 +56,7 @@ void BitStream::Empty() {
 }
 }

-int64_t BitStream::Pop(uint8_t bit_count) {
+int64_t FSEBitStream::Pop(uint8_t bit_count) {
 MS_ASSERT(curr_bit_count_ <= kCurrentBitCount);
 int64_t right = curr_chunk_ >> (kCurrentBitCount - curr_bit_count_);
 int64_t res = right & ((1 << bit_count) - 1);

@@ -81,7 +82,7 @@ int64_t BitStream::Pop(uint8_t bit_count) {
 return right;
 }

-void BitStream::Push(int64_t state, uint8_t bit_count) {
+void FSEBitStream::Push(int64_t state, uint8_t bit_count) {
 curr_bit_count_ += bit_count;
 if (curr_bit_count_ <= kCurrentBitCount) {
 // happy path, no split

@@ -104,5 +105,23 @@ void BitStream::Push(int64_t state, uint8_t bit_count) {
 }
 }

-void BitStream::Flush() { curr_chunk_ <<= kCurrentBitCount - curr_bit_count_; }
+void FSEBitStream::Flush() { curr_chunk_ <<= kCurrentBitCount - curr_bit_count_; }
+
+// The function gives the index of most import `1` in the binary representation.
+// e.g. for the number 00100 it gives 2.
+int FSEBitStream::CountBits(int32_t x) {
+#ifdef _MSC_VER
+int num = 0;
+uint32_t tmp = x;
+tmp |= 1;
+while (!(tmp & INT32_MIN)) {
+num += 1;
+tmp <<= 1;
+}
+return num ^ kInt32Mask;
+#else
+return __builtin_clz(x) ^ kInt32Mask;
+#endif
+return 0;
+}
 } // namespace mindspore::lite::quant
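For reference, a minimal standalone sketch (not part of this commit) of what the new FSEBitStream::CountBits helper computes: the index of the most significant set bit. MsbIndex and its shift loop are illustrative stand-ins for `__builtin_clz(x) ^ 31`:

#include <cassert>
#include <cstdint>

// Portable equivalent of __builtin_clz(x) ^ 31 for non-zero 32-bit inputs.
static int MsbIndex(uint32_t x) {
  assert(x != 0);  // __builtin_clz(0) is undefined; this sketch only handles non-zero inputs
  int idx = 0;
  while (x >>= 1) ++idx;
  return idx;
}

int main() {
  assert(MsbIndex(0b00100) == 2);  // matches the "00100 gives 2" example in the comment
  assert(MsbIndex(1) == 0);
  assert(MsbIndex(0x80000000u) == 31);
  return 0;
}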
@@ -19,11 +19,11 @@
 #include <cstdint>

 namespace mindspore::lite::quant {
-class BitStream {
+class FSEBitStream {
 public:
-BitStream() = default;
+FSEBitStream() = default;

-~BitStream() = default;
+~FSEBitStream() = default;

 public:
 int Create(int bit_capacity);

@@ -32,6 +32,7 @@ class BitStream {
 int64_t Pop(uint8_t bit_count);
 void Push(int64_t state, uint8_t bit_count);
 void Flush();
+static int CountBits(int32_t x);

 int32_t GetCurrChunkIndex() { return this->curr_chunk_index_; }
 uint64_t GetCurrChunk() { return this->curr_chunk_; }
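To make the Push/Pop/Flush interface above concrete, here is a toy single-chunk bit packer. It is an assumption for illustration only: the real FSEBitStream manages an array of 64-bit chunks, writes from the high end, and splits values that straddle a chunk boundary.

#include <cassert>
#include <cstdint>

// Toy packer: Push appends the low `bits` bits of `state`; Pop removes them
// again in reverse (LIFO) order, which mirrors how FSE reads the stream back to front.
struct TinyBitStream {
  uint64_t chunk = 0;
  uint8_t used = 0;  // analogue of curr_bit_count_

  void Push(uint64_t state, uint8_t bits) {
    chunk |= (state & ((1ULL << bits) - 1)) << used;
    used += bits;
  }
  uint64_t Pop(uint8_t bits) {
    used -= bits;
    uint64_t value = (chunk >> used) & ((1ULL << bits) - 1);
    chunk &= (1ULL << used) - 1;
    return value;
  }
};

int main() {
  TinyBitStream bs;
  bs.Push(0b101, 3);
  bs.Push(0b01, 2);
  assert(bs.Pop(2) == 0b01);   // last pushed, first popped
  assert(bs.Pop(3) == 0b101);
  return 0;
}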
@@ -23,6 +23,10 @@
 #include "nnacl/op_base.h"

 namespace mindspore::lite::quant {
+namespace {
+constexpr int kTableExtend = 3;
+constexpr int kAlignOffset = 7;
+} // namespace
 int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, int table_log,
 uint16_t *new_state, uint8_t *bit_count, uint16_t *symbol_table) {
 MS_ASSERT(symbol_frequency != nullptr);

@@ -30,8 +34,8 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
 MS_ASSERT(bit_count != nullptr);
 MS_ASSERT(symbol_table != nullptr);
 const int table_size = 1 << table_log;
-int table_mask = table_size - 1;
-int step = ((table_size >> 1) + (table_size >> 3) + 3);
+const int table_mask = table_size - 1;
+int step = ((table_size >> 1) + (table_size >> kTableExtend) + kTableExtend);
 int pos = 0;
 for (int sym = 0; sym < symbol_frequency_count; sym++) {
 for (uint32_t i = 0; i < symbol_frequency[sym]; i++) {

@@ -52,24 +56,13 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
 uint16_t sym = symbol_table[i];
 uint32_t x = frequency[sym];
 frequency[sym] += 1;
-#ifdef _MSC_VER
-int num = 0;
-uint32_t tmp = x;
-tmp |= 1;
-while (!(tmp & 0x80000000)) {
-num += 1;
-tmp <<= 1;
-}
-bit_count[i] = table_log - (num ^ 31);
-#else
-bit_count[i] = table_log - (__builtin_clz(x) ^ 31);
-#endif
+bit_count[i] = table_log - FSEBitStream::CountBits(x);
 new_state[i] = (x << bit_count[i]) - table_size;
 }
 return RET_OK;
 }

-int FSEDecoder::FSEDecode(BitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
+int FSEDecoder::FSEDecode(FSEBitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
 const float *centroids, int table_log) {
 MS_ASSERT(bs != nullptr);
 MS_ASSERT(buff != nullptr);
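The `step`/`table_mask` walk above is the usual FSE symbol spread: because the step is odd for table sizes of 16 and up, repeatedly adding it modulo the power-of-two table size visits every table slot exactly once. A small self-contained check of that property (an illustrative sketch, not MindSpore code):

#include <cassert>
#include <vector>

int main() {
  const int table_log = 6;                                     // 64-state table
  const int table_size = 1 << table_log;
  const int table_mask = table_size - 1;
  const int step = (table_size >> 1) + (table_size >> 3) + 3;  // 32 + 8 + 3 = 43, odd

  std::vector<bool> visited(table_size, false);
  int pos = 0;
  for (int i = 0; i < table_size; ++i) {
    assert(!visited[pos]);          // every slot is reached exactly once
    visited[pos] = true;
    pos = (pos + step) & table_mask;
  }
  assert(pos == 0);                 // and the walk ends where it started
  return 0;
}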
@@ -122,10 +115,10 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
 int out_sz = dst_tensor->ElementsNum();
 MS_CHECK_GT(out_sz, 0, RET_ERROR);
 // deserialize from `data`:
-BitStream bs;
+FSEBitStream bs;

 size_t i = 0;
-auto data8 = reinterpret_cast<unsigned char *>(const_cast<void *>(src_tensor.data()));
+auto data8 = reinterpret_cast<int8_t *>(const_cast<void *>(src_tensor.data()));

 int frequency_count = *(reinterpret_cast<uint16_t *>(&data8[i]));
 i += sizeof(uint16_t);

@@ -152,7 +145,7 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
 auto *frequency = reinterpret_cast<uint32_t *>(&data8[i]);
 i += frequency_count * sizeof(uint32_t);
 // Used for 8-byte alignment
-i = ((i + 7) >> 3) << 3;
+i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
 if (i > total_size) {
 MS_LOG(ERROR) << "index over total size"
 << " index:" << i << " total size:" << total_size;

@@ -162,7 +155,7 @@ int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_te
 auto centroids_float = reinterpret_cast<float *>(centroids);
 i += frequency_count * sizeof(float);
 // Used for 8-byte alignment
-i = ((i + 7) >> 3) << 3;
+i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
 if (i > total_size) {
 MS_LOG(ERROR) << "index over total size"
 << " index:" << i << " total size:" << total_size;
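The two `((i + 7) >> 3) << 3` lines (now spelled with kAlignOffset = 7 and kTableExtend = 3) round the read offset up to the next multiple of 8 bytes. A tiny sketch of the trick, assuming nothing beyond standard C++:

#include <cassert>
#include <cstddef>

// Adding 7 and clearing the low three bits bumps any offset to the next multiple of 8;
// offsets that are already aligned stay unchanged.
static size_t AlignTo8(size_t i) { return ((i + 7) >> 3) << 3; }

int main() {
  assert(AlignTo8(0) == 0);
  assert(AlignTo8(1) == 8);
  assert(AlignTo8(13) == 16);
  assert(AlignTo8(16) == 16);
  return 0;
}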
@@ -31,7 +31,7 @@ class FSEDecoder {
 static int DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor);

 private:
-static int FSEDecode(BitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
+static int FSEDecode(FSEBitStream *bs, float *buff, int buff_count, uint32_t *frequency, int frequency_count,
 const float *centroids, int table_log);

 static int FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, int table_log,
@@ -26,16 +26,12 @@

 namespace mindspore::lite::quant {
 namespace {
-constexpr int kInt32Mask = 31;
 constexpr int kInt16 = 16;
 constexpr int kFseTableExtendSize = 3;
 constexpr int kFrenqTableExtendSize = 2;
 constexpr int kAlignSize = 8;
 constexpr float kUpRoundOffSet = 0.5;
 } // namespace
-// The function gives the index of most import `1` in the binary representation.
-// e.g. for the number 00100 it gives 2.
-int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ kInt32Mask; }

 int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log,
 uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,

@@ -76,7 +72,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint32_t *frequency, int frequency_co
 int total = 0;
 for (int sym = 0; sym < frequency_count; sym++) {
 if (frequency[sym] >= kFrenqTableExtendSize) {
-int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
+int max_bits_out = table_log - FSEBitStream::CountBits(frequency[sym] - 1);
 int min_state_plus = frequency[sym] << max_bits_out;
 delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
 delta_state[sym] = total - frequency[sym];

@@ -159,10 +155,10 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
 MS_LOG(ERROR) << "Normalize frequency failed.";
 return ret;
 }
-BitStream bs;
+FSEBitStream bs;
 ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
 if (ret != RET_OK) {
-MS_LOG(ERROR) << "BitStream Create failed.";
+MS_LOG(ERROR) << "FSEBitStream Create failed.";
 free(fse_quant.symbol_table);
 return ret;
 }

@@ -186,7 +182,7 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
 return RET_OK;
 }

-uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
+uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(FSEBitStream *bs, uint16_t sym, uint16_t state,
 const uint32_t *delta_bit_count, const int16_t *delta_state,
 uint16_t *coding_table) {
 MS_ASSERT(bs != nullptr);

@@ -219,7 +215,7 @@ int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
 CHECK_NULL_RETURN(table_log);
 // The higher the number, the more accurate we'll be to the shannon entropy,
 // but also the larger the table, so `+3` is a good compromise.
-*table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
+*table_log = std::min(MAX_TABLE_LOG, (FSEBitStream::CountBits((uint32_t)q->size) + kFseTableExtendSize));
 const int new_table_size = 1 << (*table_log);
 int curr_table_size = 0;
 for (int i = 0; i < q->size; i++) {
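The `CountBits(size) + kFseTableExtendSize` expression above picks the FSE table size from the number of quantized symbols: roughly log2(symbol count) plus 3, capped at MAX_TABLE_LOG. A hedged sketch of that selection (kMaxTableLog = 16 is an assumption here; the real MAX_TABLE_LOG may differ):

#include <cassert>

// Index of the most significant set bit, standing in for FSEBitStream::CountBits.
static int MsbIndex(unsigned x) {
  int idx = 0;
  while (x >>= 1) ++idx;
  return idx;
}

int main() {
  const int kMaxTableLog = 16;        // assumed cap; the real MAX_TABLE_LOG may differ
  const int kFseTableExtendSize = 3;  // the "+3" compromise from the comment above

  int symbol_count = 300;             // e.g. 300 distinct quantized centroids
  int table_log = MsbIndex(symbol_count) + kFseTableExtendSize;  // 8 + 3 = 11
  if (table_log > kMaxTableLog) table_log = kMaxTableLog;
  assert(table_log == 11);            // 2^11 = 2048 FSE states for 300 symbols
  return 0;
}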
@@ -270,8 +266,8 @@ int FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
 // - determine nbBits, flush them
 // - determine sub-Range Id
 // - look for Symbol position of same Id : you get your next state
-int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
-int table_log) {
+int FSEEncoder::FSEEncode(FSEBitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency,
+int frequency_count, int table_log) {
 MS_ASSERT(bs != nullptr);
 MS_ASSERT(data != nullptr);
 MS_ASSERT(frequency != nullptr);

@@ -305,7 +301,7 @@ int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, u
 return ret;
 }

-int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
+int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant,
 int table_log, uint8_t *out8, size_t max_size, size_t *out_size) {
 MSLITE_CHECK_PTR(tensor_input);
 MSLITE_CHECK_PTR(bs);

@@ -400,7 +396,7 @@ int FSEEncoder::SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs
 return RET_OK;
 }

-int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
+int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant,
 int table_log) {
 MSLITE_CHECK_PTR(tensor_input);
 MSLITE_CHECK_PTR(bs);
@@ -41,17 +41,17 @@ class FSEEncoder {
 int FSECreateStatesForEncoding(uint32_t *frequency, int frequency_count, int table_log, uint32_t *delta_bit_count,
 int16_t *delta_state, uint16_t *coding_table, uint16_t *symbol_table);

-uint16_t FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
+uint16_t FSEEncodeSymbolGetNewState(FSEBitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
 const int16_t *delta_state, uint16_t *coding_table);

-int FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
+int FSEEncode(FSEBitStream *bs, const uint16_t *data, int data_count, uint32_t *frequency, int frequency_count,
 int table_log);

 int NormalizeFrequency(FSEQuant *q, int *table_log);

-int SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log);
+int SerializingToOut(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant, int table_log);

-int SerializingToTensor(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log,
+int SerializingToTensor(schema::TensorT *tensor_input, FSEBitStream *bs, const FSEQuant &fse_quant, int table_log,
 uint8_t *out8, size_t max_size, size_t *offset);
 };
 } // namespace mindspore::lite::quant
@@ -147,7 +147,8 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
 if (type_id == kNumberTypeFloat32 && info != nullptr) {
 auto scale = info->GetScale();
 if (scale == 0) {
-MS_LOG(WARNING) << input_node->fullname_with_scope() << " input index:" << index
+std::string in_out = is_input ? " input" : " output";
+MS_LOG(WARNING) << input_node->fullname_with_scope() << in_out << " index:" << index
 << " values are very close to 0, so set the scale to 1.";
 quant_param.scale = 1;
 } else {
@@ -19,16 +19,69 @@
 #include "tools/common/statistic_utils.h"

 namespace mindspore::lite::quant {
+void MixedBitWeightQuantizer::GetBiasCorrection(float *weights, int element_num, float scale,
+float *origin_dequant_datas) {
+MS_ASSERT(element_num > 0);
+double average_dequant = 0;
+double average_raw = 0;
+for (int i = 0; i < element_num; i++) {
+float dequant = scale * (floorf(weights[i] / scale + 0.5));
+origin_dequant_datas[i] = dequant;
+average_raw += weights[i];
+average_dequant += dequant;
+}
+
+// mean
+average_dequant = average_dequant / element_num;
+average_raw = average_raw / element_num;
+// std
+double variance_dequant = 0;
+double variance_raw = 0;
+const int exponent = 2;
+for (int i = 0; i < element_num; i++) {
+variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
+variance_raw += std::pow(weights[i] - average_raw, exponent);
+}
+variance_dequant = std::sqrt(variance_dequant / element_num);
+variance_raw = std::sqrt(variance_raw / element_num);
+if (variance_dequant == 0) {
+var_corr_ = 1;
+} else {
+var_corr_ = variance_raw / variance_dequant;
+}
+mean_corr_ = average_raw - average_dequant * var_corr_;
+}
+
+// the error is currently measured per channel.
+// it could be measured per layer but it would be less good.
+float MixedBitWeightQuantizer::CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2) {
+int error_count = 0;
+float mse_error = 1e-10f;
+const float soft = 1e-7f;
+const float tolerance_error = 1.0e-10f;
+for (size_t i = 0; i < norms2.size(); i++) {
+if (norms2[i] < tolerance_error) {
+continue;
+}
+error_count += 1;
+mse_error += sqrtf(dnorms2[i] / norms2[i]);
+}
+auto meam_error = mse_error / (error_count + soft);
+return meam_error;
+}
+
 // the `preferred` dim should point to the output channels dimension.
 float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
 float scale) {
 MS_ASSERT(weights != nullptr);
 MS_ASSERT(shape != nullptr);
-int numel = 1;
+// Init
+int element_num = 1;
 for (int i = 0; i < dims; i++) {
-numel *= shape[i];
+element_num *= shape[i];
 }
+if (element_num <= 0) {
+MS_LOG(ERROR) << "Element is less than or equal to 0.";
+return FLT_MAX;
+}
 int bucket_count = shape[preferred_dim];
 std::vector<float> norms2(bucket_count);
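The new GetBiasCorrection computes a variance/mean correction: after round-to-nearest quantization, the dequantized weights are rescaled and shifted so that their standard deviation and mean match the original weights (corrected = var_corr * dequant + mean_corr). A standalone numerical sketch of that idea (illustrative values, not MindSpore code):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float scale = 0.05f;
  std::vector<float> weights = {0.11f, -0.27f, 0.33f, 0.02f, -0.18f, 0.24f};
  std::vector<float> dequant(weights.size());

  // Round-to-nearest quantize/dequantize and accumulate the means.
  double mean_raw = 0, mean_deq = 0;
  for (size_t i = 0; i < weights.size(); ++i) {
    dequant[i] = scale * std::floor(weights[i] / scale + 0.5f);
    mean_raw += weights[i];
    mean_deq += dequant[i];
  }
  mean_raw /= weights.size();
  mean_deq /= weights.size();

  // Population standard deviations of the raw and dequantized weights.
  double var_raw = 0, var_deq = 0;
  for (size_t i = 0; i < weights.size(); ++i) {
    var_raw += std::pow(weights[i] - mean_raw, 2);
    var_deq += std::pow(dequant[i] - mean_deq, 2);
  }
  double std_raw = std::sqrt(var_raw / weights.size());
  double std_deq = std::sqrt(var_deq / weights.size());

  // Correction factors, matching the structure of GetBiasCorrection.
  double var_corr = (std_deq == 0) ? 1.0 : std_raw / std_deq;
  double mean_corr = mean_raw - mean_deq * var_corr;
  for (size_t i = 0; i < weights.size(); ++i) {
    double corrected = var_corr * dequant[i] + mean_corr;
    std::printf("w=% .3f  dequant=% .3f  corrected=% .4f\n", weights[i], dequant[i], corrected);
  }
  return 0;
}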
@@ -37,41 +90,18 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in
 norms2[i] = 0.0;
 dnorms2[i] = 0.0;
 }
-double average_dequant = 0;
-double average_raw = 0;
-std::vector<float> origin_dequant_datas(numel);
-std::vector<float> corr_dequant_datas(numel);

+// Bucketing
+std::vector<float> origin_dequant_datas(element_num);
+std::vector<float> corr_dequant_datas(element_num);
 int bucket_volume = 1;
 for (int i = preferred_dim; i < dims; i++) {
 bucket_volume *= shape[i];
 }
-for (int i = 0; i < numel; i++) {
-float dequant = scale * (floorf(weights[i] / scale + 0.5));
-origin_dequant_datas[i] = dequant;
-average_raw += weights[i];
-average_dequant += dequant;
-}
-// mean
-average_dequant = average_dequant / numel;
-average_raw = average_raw / numel;
-// std
-double variance_dequant = 0;
-double variance_raw = 0;
-const int exponent = 2;
-for (int i = 0; i < numel; i++) {
-variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
-variance_raw += std::pow(weights[i] - average_raw, exponent);
-}
-variance_dequant = std::sqrt(variance_dequant / numel);
-variance_raw = std::sqrt(variance_raw / numel);
-if (variance_dequant == 0) {
-var_corr_ = 1;
-} else {
-var_corr_ = variance_raw / variance_dequant;
-}
-mean_corr_ = average_raw - average_dequant * var_corr_;

-for (int i = 0; i < numel; i++) {
+// Bias Correction
+GetBiasCorrection(weights, element_num, scale, origin_dequant_datas.data());
+for (int i = 0; i < element_num; i++) {
 int bucket = (i / bucket_volume) % bucket_count;
 norms2[bucket] += weights[i] * weights[i];
 float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;

@@ -79,18 +109,7 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in
 float d = weights[i] - dequant;
 dnorms2[bucket] += d * d;
 }

-int c = 0;
-float t = 1e-10;
-for (int i = 0; i < bucket_count; i++) {
-if (norms2[i] < 1.0e-10) continue;
-c += 1;
-t += sqrtf(dnorms2[i] / norms2[i]);
-}
-auto meam_error = t / (c + 1e-7);
-
-auto cos_sim = mindspore::lite::GetCosSimilarity(weights, corr_dequant_datas.data(), numel);
-MS_LOG(INFO) << " meam_error:" << meam_error << " cos_sim:" << cos_sim;
+auto meam_error = CalculateMeanError(norms2, dnorms2);
 return meam_error;
 }

@@ -176,8 +195,6 @@ int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t>
 MS_LOG(ERROR) << "quant failed.";
 return RET_ERROR;
 }
-
-// It is used to calculate the Shannon entropy.
 quant_params->push_back(quant_param);
 return RET_OK;
 }
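CalculateMeanError, factored out above, averages a per-channel relative error: for each output channel it takes sqrt(||w - dequant(w)||^2 / ||w||^2) and skips channels whose weight energy is essentially zero. A small sketch with made-up channel norms (assumption, not MindSpore code):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> norms2 = {4.0f, 1.0f, 0.0f};     // ||w||^2 per channel (last one ~empty)
  std::vector<float> dnorms2 = {0.04f, 0.01f, 0.0f};  // ||w - dequant(w)||^2 per channel

  int counted = 0;
  float acc = 1e-10f;
  for (size_t i = 0; i < norms2.size(); ++i) {
    if (norms2[i] < 1e-10f) continue;  // skip channels with (near) zero energy
    ++counted;
    acc += std::sqrt(dnorms2[i] / norms2[i]);
  }
  float mean_error = acc / (counted + 1e-7f);
  std::printf("mean relative error per channel: %f\n", mean_error);  // ~0.1 here
  return 0;
}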
@@ -53,6 +53,10 @@ class MixedBitWeightQuantizer {
 BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim,
 int max_iters, float target_err, float rel_tol);

+void GetBiasCorrection(float *weights, int element_num, float scale, float *origin_dequant_datas);
+
+float CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2);
+
 private:
 float var_corr_{1};
 float mean_corr_{0};