forked from mindspore-Ecosystem/mindspore

optimize magic number

parent 5bd7e4def9
commit 8370571b65
@@ -27,6 +27,10 @@

 namespace mindspore {
 namespace lite {
+namespace {
+constexpr int kMinSize = 0;
+constexpr int kMaxSize = 65535;
+}  // namespace
 int PreprocessParser::ParseInputType(const std::string &input_type_str, preprocess::InputType *input_type) {
   if (input_type_str == "IMAGE") {
     (*input_type) = preprocess::IMAGE;
@@ -56,7 +60,7 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
     MS_LOG(ERROR) << "calibrate_size should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (data_pre_process->calibrate_size < 0) {
+  if (data_pre_process->calibrate_size < kMinSize) {
     MS_LOG(ERROR) << "calibrate_size must larger 0.";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -190,6 +194,7 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
       image_dir = readdir(root);
     }
     closedir(root);
+    std::sort(inputs->at(image_path.first).begin(), inputs->at(image_path.first).end());
     if (count != limited_count) {
       MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
                     << " < limited_count:" << limited_count;
@@ -221,7 +226,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
     MS_LOG(ERROR) << "resize_width should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (image_pre_process->resize_width <= 0 || image_pre_process->resize_width > 65535) {
+  if (image_pre_process->resize_width <= kMinSize || image_pre_process->resize_width > kMaxSize) {
     MS_LOG(ERROR) << "resize_width must be in [1, 65535].";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -231,7 +236,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
     MS_LOG(ERROR) << "resize_width should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (image_pre_process->resize_height <= 0 || image_pre_process->resize_height > 65535) {
+  if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) {
     MS_LOG(ERROR) << "resize_height must be in [1, 65535].";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -253,7 +258,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
     MS_LOG(ERROR) << "center_crop_width should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (image_pre_process->center_crop_width <= 0 || image_pre_process->center_crop_width > 65535) {
+  if (image_pre_process->center_crop_width <= kMinSize || image_pre_process->center_crop_width > kMaxSize) {
     MS_LOG(ERROR) << "center_crop_width must be in [1, 65535].";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -263,7 +268,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
     MS_LOG(ERROR) << "center_crop_height should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (image_pre_process->center_crop_height <= 0 || image_pre_process->center_crop_height > 65535) {
+  if (image_pre_process->center_crop_height <= kMinSize || image_pre_process->center_crop_height > kMaxSize) {
     MS_LOG(ERROR) << "center_crop_height must be in [1, 65535].";
     return RET_INPUT_PARAM_INVALID;
   }
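All of the checks above share one pattern: a value is accepted only if it lies in [1, 65535], i.e. it is positive and fits in an unsigned 16-bit integer, which is what the new kMinSize/kMaxSize constants encode. A minimal standalone sketch of that predicate (the helper name is hypothetical and not from the repository):

#include <cstdio>

namespace {
constexpr int kMinSize = 0;
constexpr int kMaxSize = 65535;  // largest value an unsigned 16-bit integer can hold
}  // namespace

// Hypothetical helper mirroring the checks above: reject value <= kMinSize || value > kMaxSize.
static bool IsValidSize(int value) { return value > kMinSize && value <= kMaxSize; }

int main() {
  printf("%d %d %d\n", IsValidSize(0), IsValidSize(224), IsValidSize(65536));  // prints: 0 1 0
  return 0;
}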
@@ -23,6 +23,8 @@ namespace lite {
 namespace {
 constexpr int kQuantBitNumInt16 = 16;
 constexpr int kQuantBitNumInt8 = 8;
+constexpr int kMinSize = 0;
+constexpr int kMaxSize = 65535;
 }  // namespace
 int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
                                        quant::CommonQuantParam *common_quant) {
@@ -59,12 +61,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (common_quant->min_quant_weight_size < 0 || common_quant->min_quant_weight_size > 65535) {
+  if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
     return RET_INPUT_PARAM_INVALID;
   }

-  if (common_quant->min_quant_weight_channel < 0 || common_quant->min_quant_weight_channel > 65535) {
+  if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
     return RET_INPUT_PARAM_INVALID;
   }
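For context on the two bit-width constants kept above: kQuantBitNumInt8 and kQuantBitNumInt16 name quantization bit widths, and a bit width of this kind typically determines the representable symmetric signed range [-2^(bits-1), 2^(bits-1)-1]. A generic sketch of that relationship (illustrative only, not code from this file):

#include <cstdio>

// Hypothetical helper: symmetric signed range for a given quantization bit width.
static void PrintQuantRange(int bit_num) {
  const int quant_min = -(1 << (bit_num - 1));
  const int quant_max = (1 << (bit_num - 1)) - 1;
  printf("%2d-bit quant range: [%d, %d]\n", bit_num, quant_min, quant_max);
}

int main() {
  PrintQuantRange(8);   // [-128, 127]
  PrintQuantRange(16);  // [-32768, 32767]
  return 0;
}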
@@ -22,9 +22,10 @@
 namespace mindspore::lite::quant {
 namespace {
 constexpr int8_t kCurrentBitCount = 64;
+constexpr int8_t kTableSize = 6;
 }  // namespace
 int BitStream::Create(int bit_capacity) {
-  chunk_count_ = (bit_capacity >> 6);
+  chunk_count_ = (bit_capacity >> kTableSize);
   chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
   if (chunks_ == nullptr) {
     MS_LOG(ERROR) << "malloc memory failed.";
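For reference, the shift by kTableSize (6) converts a capacity in bits into a count of 64-bit chunks, since 2^6 = 64 is the number of bits in one uint64_t chunk (matching kCurrentBitCount above); note that the shift rounds down. A small sketch of the arithmetic, not taken from the repository:

#include <cstdint>
#include <cstdio>

int main() {
  const int bit_capacity = 1024;              // hypothetical capacity in bits
  const int chunk_count = bit_capacity >> 6;  // same as bit_capacity / 64, truncated
  printf("%d bits -> %d chunks of %zu bits\n", bit_capacity, chunk_count, sizeof(uint64_t) * 8);
  // prints: 1024 bits -> 16 chunks of 64 bits
  return 0;
}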
@@ -23,9 +23,17 @@
 #include "include/errorcode.h"

 namespace mindspore::lite::quant {
+namespace {
+constexpr int kInt32Mask = 31;
+constexpr int kInt16 = 16;
+constexpr int kFseTableExtendSize = 3;
+constexpr int kFrenqTableExtendSize = 2;
+constexpr int kAlignSize = 8;
+constexpr float kUpRoundOffSet = 0.5;
+}  // namespace
 // The function gives the index of most import `1` in the binary representation.
 // e.g. for the number 00100 it gives 2.
-int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ 31; }
+int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ kInt32Mask; }

 int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
                                            uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
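The comment above describes fse_count_bits as the index of the most significant set bit, which for a positive 32-bit x equals floor(log2(x)); XOR-ing __builtin_clz(x) with 31 is the same as computing 31 - __builtin_clz(x). A standalone sketch (assumes a GCC/Clang compiler that provides __builtin_clz; not part of the patch):

#include <cstdint>
#include <cstdio>

// Same idea as fse_count_bits above: index of the most significant set bit of a positive value.
static int msb_index(int32_t x) { return __builtin_clz(static_cast<uint32_t>(x)) ^ 31; }

int main() {
  printf("%d %d %d\n", msb_index(1), msb_index(4), msb_index(1000));  // prints: 0 2 9
  return 0;
}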
@@ -37,7 +45,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
   MS_ASSERT(coding_table != nullptr);
   int tablesize = 1 << table_log;
   int tablemask = tablesize - 1;
-  int step = ((tablesize >> 1) + (tablesize >> 3) + 3);
+  int step = ((tablesize >> 1) + (tablesize >> kFseTableExtendSize) + kFseTableExtendSize);
   int pos = 0;
   // Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
   for (int sym = 0; sym < frequency_count; sym++) {
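For intuition about the step value: with a power-of-two table size of at least 16, (tablesize >> 1) + (tablesize >> 3) + 3 is odd and therefore coprime with the table size, so stepping through the table by that amount modulo tablesize visits every slot exactly once, which is what lets the loop above spread symbols evenly. A small standalone check under that assumption (illustrative only):

#include <cstdio>

int main() {
  const int table_log = 6;
  const int tablesize = 1 << table_log;                      // 64
  const int tablemask = tablesize - 1;
  const int step = (tablesize >> 1) + (tablesize >> 3) + 3;  // 32 + 8 + 3 = 43
  int pos = 0;
  int visited = 0;
  do {
    ++visited;
    pos = (pos + step) & tablemask;  // walk the table with the spread step
  } while (pos != 0);
  printf("step=%d visits %d of %d slots\n", step, visited, tablesize);  // step=43 visits 64 of 64 slots
  return 0;
}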
@@ -49,7 +57,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
   }
   if (pos != 0) return 1;

-  std::vector<uint16_t> cfreqs(frequency_count + 2);
+  std::vector<uint16_t> cfreqs(frequency_count + kFrenqTableExtendSize);
   cfreqs[0] = 0;
   for (int i = 1; i < frequency_count + 1; i++) {
     cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
@@ -63,15 +71,15 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co

   int total = 0;
   for (int sym = 0; sym < frequency_count; sym++) {
-    if (frequency[sym] >= 2) {
+    if (frequency[sym] >= kFrenqTableExtendSize) {
       int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
       int min_state_plus = frequency[sym] << max_bits_out;
-      delta_bit_count[sym] = (max_bits_out << 16) - min_state_plus;
+      delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
       delta_state[sym] = total - frequency[sym];
       total += frequency[sym];
     } else {
       // we assume minimum `frequency` is 1
-      delta_bit_count[sym] = (table_log << 16) - (1 << table_log);
+      delta_bit_count[sym] = (table_log << kInt16) - (1 << table_log);
       delta_state[sym] = total - 1;
       total++;
     }
@@ -138,8 +146,7 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
   ConvertTensor2Quant(tensor_input, &fse_quant);
   NormalizeFrequency(&fse_quant, &table_log);
   BitStream bs;
-  int ret;
-  ret = bs.Create(16 * fse_quant.symbol_table_count);
+  auto ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "BitStream Create failed.";
     return ret;
@@ -171,7 +178,7 @@ uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uin
   MS_ASSERT(coding_table != nullptr);
   // It is to determine the number of bits to flush.
   // This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
-  uint8_t bits_out = (state + delta_bit_count[sym]) >> 16;
+  uint8_t bits_out = (state + delta_bit_count[sym]) >> kInt16;
   bs->Push(state, bits_out);
   // subrangeID = state >> nbBitsOut
   return coding_table[(state >> bits_out) + delta_state[sym]];
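To make the "n or n+1" comment concrete, here is a small arithmetic sketch with hypothetical values (table_log 6, symbol frequency 5, so delta_bit_count packs max_bits_out = 4 into its upper 16 bits and subtracts min_state_plus = 80): states below 80 flush 3 bits, states at or above 80 flush 4. Illustrative only, not code from the patch:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  const int table_log = 6;
  const int freq = 5;
  const int max_bits_out = table_log - 2;           // fse_count_bits(freq - 1) == fse_count_bits(4) == 2
  const int min_state_plus = freq << max_bits_out;  // 5 << 4 == 80
  const uint32_t delta = (max_bits_out << 16) - min_state_plus;
  for (unsigned state : {64u, 79u, 80u, 127u}) {
    printf("state=%3u -> bits_out=%u\n", state, (state + delta) >> 16);  // 3, 3, 4, 4
  }
  return 0;
}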
@@ -194,18 +201,19 @@ void FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
   MS_ASSERT(q != nullptr);
   // The higher the number, the more accurate we'll be to the shannon entropy,
   // but also the larger the table, so `+3` is a good compromise.
-  *table_log = std::min(MAX_TABLE_LOG, fse_count_bits((uint32_t)q->size) + 3);
+  *table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
   int new_table_size = 1 << (*table_log);
   int curr_table_size = 0;
   for (int i = 0; i < q->size; i++) {
     curr_table_size += q->frequency[i];
   }

+  MS_ASSERT(curr_table_size != 0);
   // normalize
   int updated_table_size = 0;
   float rat = (static_cast<float>(new_table_size)) / curr_table_size;
   for (int i = 0; i < q->size; i++) {
-    q->frequency[i] = std::max(1, static_cast<int>(floorf(0.5 + rat * q->frequency[i])));
+    q->frequency[i] = std::max(1, static_cast<int>(floorf(kUpRoundOffSet + rat * q->frequency[i])));
     updated_table_size += q->frequency[i];
   }

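As a worked example of the normalization (ignoring the MAX_TABLE_LOG clamp, whose value is not shown in this diff): with 40 distinct symbols, fse_count_bits(40) is 5, so table_log becomes 8 and the target table size 256; if the raw frequencies sum to 1000, each frequency is scaled by 256/1000, rounded to nearest via the 0.5 offset, and clamped to at least 1. A small sketch (assumes __builtin_clz; illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  const int size = 40;                                   // hypothetical number of distinct symbols
  const int table_log = (31 - __builtin_clz(size)) + 3;  // fse_count_bits(40) + 3 == 8
  const int new_table_size = 1 << table_log;             // 256
  const int curr_table_size = 1000;                      // hypothetical sum of raw frequencies
  const float rat = static_cast<float>(new_table_size) / curr_table_size;
  for (int f : {1, 4, 100}) {
    const int normalized = std::max(1, static_cast<int>(floorf(0.5f + rat * f)));
    printf("raw=%3d -> normalized=%d\n", f, normalized);  // 1, 1, 26
  }
  return 0;
}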
@@ -271,7 +279,8 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
                                  int table_log) {
   MS_ASSERT(tensor_input != nullptr);
   MS_ASSERT(bs != nullptr);
-  auto max_size = tensor_input->data.size() * 2;
+  const int extend_size = 2;
+  auto max_size = tensor_input->data.size() * extend_size;
   auto *out8 = static_cast<uint8_t *>(malloc(max_size));
   if (out8 == nullptr) {
     MS_LOG(ERROR) << "malloc memory failed.";
@@ -286,7 +295,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
   }
   *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
   offset += sizeof(uint16_t);
-  int chunksc = bs->GetCurrChunkIndex() + 2;
+  int chunksc = bs->GetCurrChunkIndex() + sizeof(uint16_t);
   if (offset + sizeof(uint32_t) > max_size) {
     MS_LOG(ERROR) << "offset over max size"
                   << " offset:" << offset << " max_size:" << max_size;
@@ -301,7 +310,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
     *(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
     offset += sizeof(uint16_t);
   }
-  while (offset % 8 != 0) {
+  while (offset % kAlignSize != 0) {
     if (offset + sizeof(uint16_t) > max_size) {
       MS_LOG(ERROR) << "offset over max size"
                     << " offset:" << offset << " max_size:" << max_size;
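The loop above keeps padding until offset is a multiple of kAlignSize (8); only the bound check is visible in this hunk, with the rest of the loop body outside it. The same target offset can also be expressed in closed form with the usual round-up trick for power-of-two alignments; a sketch of the equivalent arithmetic (illustrative only, not a proposed change):

#include <cstddef>
#include <cstdio>
#include <initializer_list>

int main() {
  const size_t kAlignSize = 8;
  for (size_t offset : {16u, 18u, 22u, 24u}) {
    const size_t aligned = (offset + kAlignSize - 1) & ~(kAlignSize - 1);  // round up to a multiple of 8
    printf("offset=%zu -> aligned=%zu\n", offset, aligned);                // 16, 24, 24, 24
  }
  return 0;
}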
@@ -317,7 +326,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
     *(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
     offset += sizeof(float);
   }
-  while (offset % 8 != 0) {
+  while (offset % kAlignSize != 0) {
     if (offset + sizeof(uint16_t) > max_size) {
       MS_LOG(ERROR) << "offset over max size"
                     << " offset:" << offset << " max_size:" << max_size;
@@ -52,7 +52,9 @@ namespace {
 static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                          prim::kPrimMatMul, prim::kPrimFullConnection,
                                                          prim::kPrimLayerNormFusion};
-}
+constexpr int kMinSize = 0;
+constexpr int kMaxSize = 65535;
+}  // namespace
 namespace {
 STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
                                     const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
@@ -732,6 +734,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
     if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
       auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
       primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
+      activation_input_index++;
     } else {
       // do input quant
       auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
@@ -1239,7 +1242,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (flags.dataPreProcessParam.calibrate_size <= 0 || flags.dataPreProcessParam.calibrate_size > 65535) {
+  if (flags.dataPreProcessParam.calibrate_size <= kMinSize || flags.dataPreProcessParam.calibrate_size > kMaxSize) {
     MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -1247,8 +1250,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
     return RET_INPUT_PARAM_INVALID;
   }
-  STATUS status;
-  status = PreProcess();
+  STATUS status = PreProcess();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "do pre process failed!";
     return status;
@@ -132,7 +132,6 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferr
     return false;
   }
   size_t shape_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
-
   if (shape_size < min_quant_weight_size_) {
     MS_LOG(INFO) << "shape_size " << shape_size << " less min_quant_weight_size_ " << shape_size;
     return false;