optimize magic number

This commit is contained in:
yeyunpeng2020 2021-09-28 11:29:49 +08:00
parent 5bd7e4def9
commit 8370571b65
6 changed files with 46 additions and 28 deletions

View File

@ -27,6 +27,10 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
// Shared bounds for the range checks on size-like preprocess settings.
// Lower bound: calibrate_size is rejected when it is below this value, while
// the resize and center-crop width/height are rejected when they are less
// than or equal to it (their accepted range is [1, kMaxSize]).
constexpr int kMinSize = 0;
// Inclusive upper bound for the resize and center-crop width/height.
constexpr int kMaxSize = 65535;
} // namespace
int PreprocessParser::ParseInputType(const std::string &input_type_str, preprocess::InputType *input_type) { int PreprocessParser::ParseInputType(const std::string &input_type_str, preprocess::InputType *input_type) {
if (input_type_str == "IMAGE") { if (input_type_str == "IMAGE") {
(*input_type) = preprocess::IMAGE; (*input_type) = preprocess::IMAGE;
@ -56,7 +60,7 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
MS_LOG(ERROR) << "calibrate_size should be a valid number."; MS_LOG(ERROR) << "calibrate_size should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (data_pre_process->calibrate_size < 0) { if (data_pre_process->calibrate_size < kMinSize) {
MS_LOG(ERROR) << "calibrate_size must larger 0."; MS_LOG(ERROR) << "calibrate_size must larger 0.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -190,6 +194,7 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
image_dir = readdir(root); image_dir = readdir(root);
} }
closedir(root); closedir(root);
std::sort(inputs->at(image_path.first).begin(), inputs->at(image_path.first).end());
if (count != limited_count) { if (count != limited_count) {
MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
<< " < limited_count:" << limited_count; << " < limited_count:" << limited_count;
@ -221,7 +226,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
MS_LOG(ERROR) << "resize_width should be a valid number."; MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (image_pre_process->resize_width <= 0 || image_pre_process->resize_width > 65535) { if (image_pre_process->resize_width <= kMinSize || image_pre_process->resize_width > kMaxSize) {
MS_LOG(ERROR) << "resize_width must be in [1, 65535]."; MS_LOG(ERROR) << "resize_width must be in [1, 65535].";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -231,7 +236,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
MS_LOG(ERROR) << "resize_width should be a valid number."; MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (image_pre_process->resize_height <= 0 || image_pre_process->resize_height > 65535) { if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) {
MS_LOG(ERROR) << "resize_height must be in [1, 65535]."; MS_LOG(ERROR) << "resize_height must be in [1, 65535].";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -253,7 +258,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
MS_LOG(ERROR) << "center_crop_width should be a valid number."; MS_LOG(ERROR) << "center_crop_width should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (image_pre_process->center_crop_width <= 0 || image_pre_process->center_crop_width > 65535) { if (image_pre_process->center_crop_width <= kMinSize || image_pre_process->center_crop_width > kMaxSize) {
MS_LOG(ERROR) << "center_crop_width must be in [1, 65535]."; MS_LOG(ERROR) << "center_crop_width must be in [1, 65535].";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -263,7 +268,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
MS_LOG(ERROR) << "center_crop_height should be a valid number."; MS_LOG(ERROR) << "center_crop_height should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (image_pre_process->center_crop_height <= 0 || image_pre_process->center_crop_height > 65535) { if (image_pre_process->center_crop_height <= kMinSize || image_pre_process->center_crop_height > kMaxSize) {
MS_LOG(ERROR) << "center_crop_height must be in [1, 65535]."; MS_LOG(ERROR) << "center_crop_height must be in [1, 65535].";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }

View File

@ -23,6 +23,8 @@ namespace lite {
namespace { namespace {
constexpr int kQuantBitNumInt16 = 16; constexpr int kQuantBitNumInt16 = 16;
constexpr int kQuantBitNumInt8 = 8; constexpr int kQuantBitNumInt8 = 8;
constexpr int kMinSize = 0;
constexpr int kMaxSize = 65535;
} // namespace } // namespace
int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string, int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
quant::CommonQuantParam *common_quant) { quant::CommonQuantParam *common_quant) {
@ -59,12 +61,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number."; MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (common_quant->min_quant_weight_size < 0 || common_quant->min_quant_weight_size > 65535) { if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl; MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (common_quant->min_quant_weight_channel < 0 || common_quant->min_quant_weight_channel > 65535) { if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl; MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }

View File

@ -22,9 +22,10 @@
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace { namespace {
constexpr int8_t kCurrentBitCount = 64; constexpr int8_t kCurrentBitCount = 64;
constexpr int8_t kTableSize = 6;
} // namespace } // namespace
int BitStream::Create(int bit_capacity) { int BitStream::Create(int bit_capacity) {
chunk_count_ = (bit_capacity >> 6); chunk_count_ = (bit_capacity >> kTableSize);
chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t))); chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
if (chunks_ == nullptr) { if (chunks_ == nullptr) {
MS_LOG(ERROR) << "malloc memory failed."; MS_LOG(ERROR) << "malloc memory failed.";

View File

@ -23,9 +23,17 @@
#include "include/errorcode.h" #include "include/errorcode.h"
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace {
// Named constants for the literals used by the FSE encoder in this file.

// XOR-ed with __builtin_clz(x) in fse_count_bits to convert a leading-zero
// count of a 32-bit value into the index of its highest set bit.
constexpr int kInt32Mask = 31;
// Shift width used when packing a bit count above the low 16 bits of
// delta_bit_count and when extracting it again; also used as the multiplier
// for symbol_table_count when sizing the BitStream in Compress.
constexpr int kInt16 = 16;
// Used as both the shift amount and the additive term when computing the
// table spreading step, and as the amount added to the bit count of q->size
// when choosing table_log.
// NOTE(review): one name covers several independent uses of the literal 3 -
// confirm they are meant to change together.
constexpr int kFseTableExtendSize = 3;
// Number of extra slots in the cumulative frequency vector (frequency_count + 2);
// also used as the frequency threshold (>= 2) that selects the general branch
// when filling delta_bit_count/delta_state.
// NOTE(review): "Frenq" looks like a misspelling of "Freq", and the two uses
// appear unrelated - confirm before renaming or splitting.
constexpr int kFrenqTableExtendSize = 2;
// Byte alignment that the output offset is padded to in SerializingToOut.
constexpr int kAlignSize = 8;
// Added before floorf so that scaled frequencies are rounded to the nearest
// integer rather than truncated.
constexpr float kUpRoundOffSet = 0.5;
} // namespace
// The function gives the index of most import `1` in the binary representation. // The function gives the index of most import `1` in the binary representation.
// e.g. for the number 00100 it gives 2. // e.g. for the number 00100 it gives 2.
int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ 31; } int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ kInt32Mask; }
int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log, int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table, uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
@ -37,7 +45,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
MS_ASSERT(coding_table != nullptr); MS_ASSERT(coding_table != nullptr);
int tablesize = 1 << table_log; int tablesize = 1 << table_log;
int tablemask = tablesize - 1; int tablemask = tablesize - 1;
int step = ((tablesize >> 1) + (tablesize >> 3) + 3); int step = ((tablesize >> 1) + (tablesize >> kFseTableExtendSize) + kFseTableExtendSize);
int pos = 0; int pos = 0;
// Separate the same symbols, coding will be better if the same characters are distributed evenly across the table. // Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
for (int sym = 0; sym < frequency_count; sym++) { for (int sym = 0; sym < frequency_count; sym++) {
@ -49,7 +57,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
} }
if (pos != 0) return 1; if (pos != 0) return 1;
std::vector<uint16_t> cfreqs(frequency_count + 2); std::vector<uint16_t> cfreqs(frequency_count + kFrenqTableExtendSize);
cfreqs[0] = 0; cfreqs[0] = 0;
for (int i = 1; i < frequency_count + 1; i++) { for (int i = 1; i < frequency_count + 1; i++) {
cfreqs[i] = cfreqs[i - 1] + frequency[i - 1]; cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
@ -63,15 +71,15 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
int total = 0; int total = 0;
for (int sym = 0; sym < frequency_count; sym++) { for (int sym = 0; sym < frequency_count; sym++) {
if (frequency[sym] >= 2) { if (frequency[sym] >= kFrenqTableExtendSize) {
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1); int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
int min_state_plus = frequency[sym] << max_bits_out; int min_state_plus = frequency[sym] << max_bits_out;
delta_bit_count[sym] = (max_bits_out << 16) - min_state_plus; delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
delta_state[sym] = total - frequency[sym]; delta_state[sym] = total - frequency[sym];
total += frequency[sym]; total += frequency[sym];
} else { } else {
// we assume minimum `frequency` is 1 // we assume minimum `frequency` is 1
delta_bit_count[sym] = (table_log << 16) - (1 << table_log); delta_bit_count[sym] = (table_log << kInt16) - (1 << table_log);
delta_state[sym] = total - 1; delta_state[sym] = total - 1;
total++; total++;
} }
@ -138,8 +146,7 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
ConvertTensor2Quant(tensor_input, &fse_quant); ConvertTensor2Quant(tensor_input, &fse_quant);
NormalizeFrequency(&fse_quant, &table_log); NormalizeFrequency(&fse_quant, &table_log);
BitStream bs; BitStream bs;
int ret; auto ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
ret = bs.Create(16 * fse_quant.symbol_table_count);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "BitStream Create failed."; MS_LOG(ERROR) << "BitStream Create failed.";
return ret; return ret;
@ -171,7 +178,7 @@ uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uin
MS_ASSERT(coding_table != nullptr); MS_ASSERT(coding_table != nullptr);
// It is to determine the number of bits to flush. // It is to determine the number of bits to flush.
// This is basically one of 2 values, n or n+1, depending on state crossing a threshold. // This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
uint8_t bits_out = (state + delta_bit_count[sym]) >> 16; uint8_t bits_out = (state + delta_bit_count[sym]) >> kInt16;
bs->Push(state, bits_out); bs->Push(state, bits_out);
// subrangeID = state >> nbBitsOut // subrangeID = state >> nbBitsOut
return coding_table[(state >> bits_out) + delta_state[sym]]; return coding_table[(state >> bits_out) + delta_state[sym]];
@ -194,18 +201,19 @@ void FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
MS_ASSERT(q != nullptr); MS_ASSERT(q != nullptr);
// The higher the number, the more accurate we'll be to the shannon entropy, // The higher the number, the more accurate we'll be to the shannon entropy,
// but also the larger the table, so `+3` is a good compromise. // but also the larger the table, so `+3` is a good compromise.
*table_log = std::min(MAX_TABLE_LOG, fse_count_bits((uint32_t)q->size) + 3); *table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
int new_table_size = 1 << (*table_log); int new_table_size = 1 << (*table_log);
int curr_table_size = 0; int curr_table_size = 0;
for (int i = 0; i < q->size; i++) { for (int i = 0; i < q->size; i++) {
curr_table_size += q->frequency[i]; curr_table_size += q->frequency[i];
} }
MS_ASSERT(curr_table_size != 0);
// normalize // normalize
int updated_table_size = 0; int updated_table_size = 0;
float rat = (static_cast<float>(new_table_size)) / curr_table_size; float rat = (static_cast<float>(new_table_size)) / curr_table_size;
for (int i = 0; i < q->size; i++) { for (int i = 0; i < q->size; i++) {
q->frequency[i] = std::max(1, static_cast<int>(floorf(0.5 + rat * q->frequency[i]))); q->frequency[i] = std::max(1, static_cast<int>(floorf(kUpRoundOffSet + rat * q->frequency[i])));
updated_table_size += q->frequency[i]; updated_table_size += q->frequency[i];
} }
@ -271,7 +279,8 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
int table_log) { int table_log) {
MS_ASSERT(tensor_input != nullptr); MS_ASSERT(tensor_input != nullptr);
MS_ASSERT(bs != nullptr); MS_ASSERT(bs != nullptr);
auto max_size = tensor_input->data.size() * 2; const int extend_size = 2;
auto max_size = tensor_input->data.size() * extend_size;
auto *out8 = static_cast<uint8_t *>(malloc(max_size)); auto *out8 = static_cast<uint8_t *>(malloc(max_size));
if (out8 == nullptr) { if (out8 == nullptr) {
MS_LOG(ERROR) << "malloc memory failed."; MS_LOG(ERROR) << "malloc memory failed.";
@ -286,7 +295,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
} }
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log; *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
offset += sizeof(uint16_t); offset += sizeof(uint16_t);
int chunksc = bs->GetCurrChunkIndex() + 2; int chunksc = bs->GetCurrChunkIndex() + sizeof(uint16_t);
if (offset + sizeof(uint32_t) > max_size) { if (offset + sizeof(uint32_t) > max_size) {
MS_LOG(ERROR) << "offset over max size" MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size; << " offset:" << offset << " max_size:" << max_size;
@ -301,7 +310,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j]; *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
offset += sizeof(uint16_t); offset += sizeof(uint16_t);
} }
while (offset % 8 != 0) { while (offset % kAlignSize != 0) {
if (offset + sizeof(uint16_t) > max_size) { if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size" MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size; << " offset:" << offset << " max_size:" << max_size;
@ -317,7 +326,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]); *(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
offset += sizeof(float); offset += sizeof(float);
} }
while (offset % 8 != 0) { while (offset % kAlignSize != 0) {
if (offset + sizeof(uint16_t) > max_size) { if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size" MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size; << " offset:" << offset << " max_size:" << max_size;

View File

@ -52,7 +52,9 @@ namespace {
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMul, prim::kPrimFullConnection, prim::kPrimMatMul, prim::kPrimFullConnection,
prim::kPrimLayerNormFusion}; prim::kPrimLayerNormFusion};
} constexpr int kMinSize = 0;
constexpr int kMaxSize = 65535;
} // namespace
namespace { namespace {
STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales, STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
const float *raw_datas, const QuantParamHolderPtr &quant_param_holder, const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
@ -732,6 +734,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
if (input_primitive_quant_holder->IsOutputQuantParamsInited()) { if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
auto quant_param = input_primitive_quant_holder->get_output_quant_params().front(); auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
primitive_quant_holder->set_input_quant_param(i - 1, quant_param); primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
activation_input_index++;
} else { } else {
// do input quant // do input quant
auto &info = (*inputs_diverg_info)[op_name][activation_input_index++]; auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
@ -1239,7 +1242,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir."; MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (flags.dataPreProcessParam.calibrate_size <= 0 || flags.dataPreProcessParam.calibrate_size > 65535) { if (flags.dataPreProcessParam.calibrate_size <= kMinSize || flags.dataPreProcessParam.calibrate_size > kMaxSize) {
MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535]."; MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -1247,8 +1250,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "input_type must pass IMAGE | BIN."; MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
STATUS status; STATUS status = PreProcess();
status = PreProcess();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!"; MS_LOG(ERROR) << "do pre process failed!";
return status; return status;

View File

@ -132,7 +132,6 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferr
return false; return false;
} }
size_t shape_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>()); size_t shape_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
if (shape_size < min_quant_weight_size_) { if (shape_size < min_quant_weight_size_) {
MS_LOG(INFO) << "shape_size " << shape_size << " less min_quant_weight_size_ " << shape_size; MS_LOG(INFO) << "shape_size " << shape_size << " less min_quant_weight_size_ " << shape_size;
return false; return false;