forked from mindspore-Ecosystem/mindspore
optimize magic number
This commit is contained in:
parent
5bd7e4def9
commit
8370571b65
|
@ -27,6 +27,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
// File-local limits used by the PreprocessParser range checks in this file.
namespace {
|
||||||
|
// Lower bound for size-like options: calibrate_size is rejected when it is below this value,
// while resize and center-crop dimensions are rejected when they are not strictly above it.
constexpr int kMinSize = 0;
|
||||||
|
// Inclusive upper bound for resize and center-crop widths/heights.
constexpr int kMaxSize = 65535;
|
||||||
|
} // namespace
|
||||||
int PreprocessParser::ParseInputType(const std::string &input_type_str, preprocess::InputType *input_type) {
|
int PreprocessParser::ParseInputType(const std::string &input_type_str, preprocess::InputType *input_type) {
|
||||||
if (input_type_str == "IMAGE") {
|
if (input_type_str == "IMAGE") {
|
||||||
(*input_type) = preprocess::IMAGE;
|
(*input_type) = preprocess::IMAGE;
|
||||||
|
@ -56,7 +60,7 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
|
||||||
MS_LOG(ERROR) << "calibrate_size should be a valid number.";
|
MS_LOG(ERROR) << "calibrate_size should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (data_pre_process->calibrate_size < 0) {
|
if (data_pre_process->calibrate_size < kMinSize) {
|
||||||
MS_LOG(ERROR) << "calibrate_size must larger 0.";
|
MS_LOG(ERROR) << "calibrate_size must larger 0.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
@ -190,6 +194,7 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
|
||||||
image_dir = readdir(root);
|
image_dir = readdir(root);
|
||||||
}
|
}
|
||||||
closedir(root);
|
closedir(root);
|
||||||
|
std::sort(inputs->at(image_path.first).begin(), inputs->at(image_path.first).end());
|
||||||
if (count != limited_count) {
|
if (count != limited_count) {
|
||||||
MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
|
MS_LOG(ERROR) << " data path: " << image_path.second << " data count:" << count
|
||||||
<< " < limited_count:" << limited_count;
|
<< " < limited_count:" << limited_count;
|
||||||
|
@ -221,7 +226,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
|
||||||
MS_LOG(ERROR) << "resize_width should be a valid number.";
|
MS_LOG(ERROR) << "resize_width should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (image_pre_process->resize_width <= 0 || image_pre_process->resize_width > 65535) {
|
if (image_pre_process->resize_width <= kMinSize || image_pre_process->resize_width > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "resize_width must be in [1, 65535].";
|
MS_LOG(ERROR) << "resize_width must be in [1, 65535].";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
@ -231,7 +236,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
|
||||||
MS_LOG(ERROR) << "resize_width should be a valid number.";
|
MS_LOG(ERROR) << "resize_width should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (image_pre_process->resize_height <= 0 || image_pre_process->resize_height > 65535) {
|
if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "resize_height must be in [1, 65535].";
|
MS_LOG(ERROR) << "resize_height must be in [1, 65535].";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
@ -253,7 +258,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
|
||||||
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
|
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (image_pre_process->center_crop_width <= 0 || image_pre_process->center_crop_width > 65535) {
|
if (image_pre_process->center_crop_width <= kMinSize || image_pre_process->center_crop_width > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "center_crop_width must be in [1, 65535].";
|
MS_LOG(ERROR) << "center_crop_width must be in [1, 65535].";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
@ -263,7 +268,7 @@ int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_
|
||||||
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
|
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (image_pre_process->center_crop_height <= 0 || image_pre_process->center_crop_height > 65535) {
|
if (image_pre_process->center_crop_height <= kMinSize || image_pre_process->center_crop_height > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "center_crop_height must be in [1, 65535].";
|
MS_LOG(ERROR) << "center_crop_height must be in [1, 65535].";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@ namespace lite {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kQuantBitNumInt16 = 16;
|
constexpr int kQuantBitNumInt16 = 16;
|
||||||
constexpr int kQuantBitNumInt8 = 8;
|
constexpr int kQuantBitNumInt8 = 8;
|
||||||
|
constexpr int kMinSize = 0;
|
||||||
|
constexpr int kMaxSize = 65535;
|
||||||
} // namespace
|
} // namespace
|
||||||
int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
|
int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
|
||||||
quant::CommonQuantParam *common_quant) {
|
quant::CommonQuantParam *common_quant) {
|
||||||
|
@ -59,12 +61,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
|
||||||
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
|
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (common_quant->min_quant_weight_size < 0 || common_quant->min_quant_weight_size > 65535) {
|
if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
|
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (common_quant->min_quant_weight_channel < 0 || common_quant->min_quant_weight_channel > 65535) {
|
if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
|
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,9 +22,10 @@
|
||||||
namespace mindspore::lite::quant {
|
namespace mindspore::lite::quant {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int8_t kCurrentBitCount = 64;
|
constexpr int8_t kCurrentBitCount = 64;
|
||||||
|
constexpr int8_t kTableSize = 6;
|
||||||
} // namespace
|
} // namespace
|
||||||
int BitStream::Create(int bit_capacity) {
|
int BitStream::Create(int bit_capacity) {
|
||||||
chunk_count_ = (bit_capacity >> 6);
|
chunk_count_ = (bit_capacity >> kTableSize);
|
||||||
chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
|
chunks_ = static_cast<uint64_t *>(calloc(chunk_count_, sizeof(uint64_t)));
|
||||||
if (chunks_ == nullptr) {
|
if (chunks_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "malloc memory failed.";
|
MS_LOG(ERROR) << "malloc memory failed.";
|
||||||
|
|
|
@ -23,9 +23,17 @@
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
namespace mindspore::lite::quant {
|
namespace mindspore::lite::quant {
|
||||||
|
// File-local constants that replace the numeric literals previously written inline in the FSE encoder.
namespace {
|
||||||

// XOR-ed with the __builtin_clz() result in fse_count_bits() to turn a leading-zero count
// into the index of the highest set bit.
constexpr int kInt32Mask = 31;
|
||||||

// Shift width used when building delta_bit_count entries and when extracting bits_out from them;
// also the factor applied to symbol_table_count when sizing the BitStream in Compress().
constexpr int kInt16 = 16;
|
||||||

// Used as both the shift amount and the additive term of `step` in FSECreateStatesForEncoding(),
// and as the `+3` added to the table log in NormalizeFrequency().
constexpr int kFseTableExtendSize = 3;
|
||||||

// Number of extra slots in the cumulative-frequency vector; also reused as the
// frequency threshold (>= 2) in FSECreateStatesForEncoding().
constexpr int kFrenqTableExtendSize = 2;
|
||||||

// SerializingToOut() loops until the byte offset into the output buffer is a multiple of this value.
constexpr int kAlignSize = 8;
|
||||||

// Added to the scaled frequency before floorf() in NormalizeFrequency().
constexpr float kUpRoundOffSet = 0.5;
|
||||||

} // namespace
|
||||||
// The function gives the index of the most significant `1` in the binary representation.
|
// The function gives the index of the most significant `1` in the binary representation.
|
||||||
// e.g. for the number 00100 it gives 2.
|
// e.g. for the number 00100 it gives 2.
|
||||||
int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ 31; }
|
int fse_count_bits(int32_t x) {
  // __builtin_clz() has an undefined result for 0, so callers must pass a non-zero value.
  const int leading_zeros = __builtin_clz(x);
  // XOR with kInt32Mask maps the leading-zero count to the position of the highest set bit
  // (assumes a 32-bit unsigned int operand -- TODO confirm for all target platforms).
  return leading_zeros ^ kInt32Mask;
}
|
||||||
|
|
||||||
int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
|
int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
|
||||||
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
|
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
|
||||||
|
@ -37,7 +45,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
|
||||||
MS_ASSERT(coding_table != nullptr);
|
MS_ASSERT(coding_table != nullptr);
|
||||||
int tablesize = 1 << table_log;
|
int tablesize = 1 << table_log;
|
||||||
int tablemask = tablesize - 1;
|
int tablemask = tablesize - 1;
|
||||||
int step = ((tablesize >> 1) + (tablesize >> 3) + 3);
|
int step = ((tablesize >> 1) + (tablesize >> kFseTableExtendSize) + kFseTableExtendSize);
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
// Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
|
// Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
|
||||||
for (int sym = 0; sym < frequency_count; sym++) {
|
for (int sym = 0; sym < frequency_count; sym++) {
|
||||||
|
@ -49,7 +57,7 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
|
||||||
}
|
}
|
||||||
if (pos != 0) return 1;
|
if (pos != 0) return 1;
|
||||||
|
|
||||||
std::vector<uint16_t> cfreqs(frequency_count + 2);
|
std::vector<uint16_t> cfreqs(frequency_count + kFrenqTableExtendSize);
|
||||||
cfreqs[0] = 0;
|
cfreqs[0] = 0;
|
||||||
for (int i = 1; i < frequency_count + 1; i++) {
|
for (int i = 1; i < frequency_count + 1; i++) {
|
||||||
cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
|
cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
|
||||||
|
@ -63,15 +71,15 @@ int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_co
|
||||||
|
|
||||||
int total = 0;
|
int total = 0;
|
||||||
for (int sym = 0; sym < frequency_count; sym++) {
|
for (int sym = 0; sym < frequency_count; sym++) {
|
||||||
if (frequency[sym] >= 2) {
|
if (frequency[sym] >= kFrenqTableExtendSize) {
|
||||||
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
|
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
|
||||||
int min_state_plus = frequency[sym] << max_bits_out;
|
int min_state_plus = frequency[sym] << max_bits_out;
|
||||||
delta_bit_count[sym] = (max_bits_out << 16) - min_state_plus;
|
delta_bit_count[sym] = (max_bits_out << kInt16) - min_state_plus;
|
||||||
delta_state[sym] = total - frequency[sym];
|
delta_state[sym] = total - frequency[sym];
|
||||||
total += frequency[sym];
|
total += frequency[sym];
|
||||||
} else {
|
} else {
|
||||||
// we assume minimum `frequency` is 1
|
// we assume minimum `frequency` is 1
|
||||||
delta_bit_count[sym] = (table_log << 16) - (1 << table_log);
|
delta_bit_count[sym] = (table_log << kInt16) - (1 << table_log);
|
||||||
delta_state[sym] = total - 1;
|
delta_state[sym] = total - 1;
|
||||||
total++;
|
total++;
|
||||||
}
|
}
|
||||||
|
@ -138,8 +146,7 @@ int FSEEncoder::Compress(schema::TensorT *tensor_input) {
|
||||||
ConvertTensor2Quant(tensor_input, &fse_quant);
|
ConvertTensor2Quant(tensor_input, &fse_quant);
|
||||||
NormalizeFrequency(&fse_quant, &table_log);
|
NormalizeFrequency(&fse_quant, &table_log);
|
||||||
BitStream bs;
|
BitStream bs;
|
||||||
int ret;
|
auto ret = bs.Create(kInt16 * fse_quant.symbol_table_count);
|
||||||
ret = bs.Create(16 * fse_quant.symbol_table_count);
|
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "BitStream Create failed.";
|
MS_LOG(ERROR) << "BitStream Create failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -171,7 +178,7 @@ uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uin
|
||||||
MS_ASSERT(coding_table != nullptr);
|
MS_ASSERT(coding_table != nullptr);
|
||||||
// It is to determine the number of bits to flush.
|
// It is to determine the number of bits to flush.
|
||||||
// This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
|
// This is basically one of 2 values, n or n+1, depending on state crossing a threshold.
|
||||||
uint8_t bits_out = (state + delta_bit_count[sym]) >> 16;
|
uint8_t bits_out = (state + delta_bit_count[sym]) >> kInt16;
|
||||||
bs->Push(state, bits_out);
|
bs->Push(state, bits_out);
|
||||||
// subrangeID = state >> nbBitsOut
|
// subrangeID = state >> nbBitsOut
|
||||||
return coding_table[(state >> bits_out) + delta_state[sym]];
|
return coding_table[(state >> bits_out) + delta_state[sym]];
|
||||||
|
@ -194,18 +201,19 @@ void FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
|
||||||
MS_ASSERT(q != nullptr);
|
MS_ASSERT(q != nullptr);
|
||||||
// The higher the number, the more accurate we'll be to the shannon entropy,
|
// The higher the number, the more accurate we'll be to the shannon entropy,
|
||||||
// but also the larger the table, so `+3` is a good compromise.
|
// but also the larger the table, so `+3` is a good compromise.
|
||||||
*table_log = std::min(MAX_TABLE_LOG, fse_count_bits((uint32_t)q->size) + 3);
|
*table_log = std::min(MAX_TABLE_LOG, (fse_count_bits((uint32_t)q->size) + kFseTableExtendSize));
|
||||||
int new_table_size = 1 << (*table_log);
|
int new_table_size = 1 << (*table_log);
|
||||||
int curr_table_size = 0;
|
int curr_table_size = 0;
|
||||||
for (int i = 0; i < q->size; i++) {
|
for (int i = 0; i < q->size; i++) {
|
||||||
curr_table_size += q->frequency[i];
|
curr_table_size += q->frequency[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_ASSERT(curr_table_size != 0);
|
||||||
// normalize
|
// normalize
|
||||||
int updated_table_size = 0;
|
int updated_table_size = 0;
|
||||||
float rat = (static_cast<float>(new_table_size)) / curr_table_size;
|
float rat = (static_cast<float>(new_table_size)) / curr_table_size;
|
||||||
for (int i = 0; i < q->size; i++) {
|
for (int i = 0; i < q->size; i++) {
|
||||||
q->frequency[i] = std::max(1, static_cast<int>(floorf(0.5 + rat * q->frequency[i])));
|
q->frequency[i] = std::max(1, static_cast<int>(floorf(kUpRoundOffSet + rat * q->frequency[i])));
|
||||||
updated_table_size += q->frequency[i];
|
updated_table_size += q->frequency[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,7 +279,8 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
||||||
int table_log) {
|
int table_log) {
|
||||||
MS_ASSERT(tensor_input != nullptr);
|
MS_ASSERT(tensor_input != nullptr);
|
||||||
MS_ASSERT(bs != nullptr);
|
MS_ASSERT(bs != nullptr);
|
||||||
auto max_size = tensor_input->data.size() * 2;
|
const int extend_size = 2;
|
||||||
|
auto max_size = tensor_input->data.size() * extend_size;
|
||||||
auto *out8 = static_cast<uint8_t *>(malloc(max_size));
|
auto *out8 = static_cast<uint8_t *>(malloc(max_size));
|
||||||
if (out8 == nullptr) {
|
if (out8 == nullptr) {
|
||||||
MS_LOG(ERROR) << "malloc memory failed.";
|
MS_LOG(ERROR) << "malloc memory failed.";
|
||||||
|
@ -286,7 +295,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
||||||
}
|
}
|
||||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
|
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
|
||||||
offset += sizeof(uint16_t);
|
offset += sizeof(uint16_t);
|
||||||
int chunksc = bs->GetCurrChunkIndex() + 2;
|
int chunksc = bs->GetCurrChunkIndex() + sizeof(uint16_t);
|
||||||
if (offset + sizeof(uint32_t) > max_size) {
|
if (offset + sizeof(uint32_t) > max_size) {
|
||||||
MS_LOG(ERROR) << "offset over max size"
|
MS_LOG(ERROR) << "offset over max size"
|
||||||
<< " offset:" << offset << " max_size:" << max_size;
|
<< " offset:" << offset << " max_size:" << max_size;
|
||||||
|
@ -301,7 +310,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
||||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
|
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
|
||||||
offset += sizeof(uint16_t);
|
offset += sizeof(uint16_t);
|
||||||
}
|
}
|
||||||
while (offset % 8 != 0) {
|
while (offset % kAlignSize != 0) {
|
||||||
if (offset + sizeof(uint16_t) > max_size) {
|
if (offset + sizeof(uint16_t) > max_size) {
|
||||||
MS_LOG(ERROR) << "offset over max size"
|
MS_LOG(ERROR) << "offset over max size"
|
||||||
<< " offset:" << offset << " max_size:" << max_size;
|
<< " offset:" << offset << " max_size:" << max_size;
|
||||||
|
@ -317,7 +326,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
||||||
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
|
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
|
||||||
offset += sizeof(float);
|
offset += sizeof(float);
|
||||||
}
|
}
|
||||||
while (offset % 8 != 0) {
|
while (offset % kAlignSize != 0) {
|
||||||
if (offset + sizeof(uint16_t) > max_size) {
|
if (offset + sizeof(uint16_t) > max_size) {
|
||||||
MS_LOG(ERROR) << "offset over max size"
|
MS_LOG(ERROR) << "offset over max size"
|
||||||
<< " offset:" << offset << " max_size:" << max_size;
|
<< " offset:" << offset << " max_size:" << max_size;
|
||||||
|
|
|
@ -52,7 +52,9 @@ namespace {
|
||||||
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||||
prim::kPrimMatMul, prim::kPrimFullConnection,
|
prim::kPrimMatMul, prim::kPrimFullConnection,
|
||||||
prim::kPrimLayerNormFusion};
|
prim::kPrimLayerNormFusion};
|
||||||
}
|
constexpr int kMinSize = 0;
|
||||||
|
constexpr int kMaxSize = 65535;
|
||||||
|
} // namespace
|
||||||
namespace {
|
namespace {
|
||||||
STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
|
STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
|
||||||
const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
|
const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
|
||||||
|
@ -732,6 +734,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
||||||
if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
|
if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
|
||||||
auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
|
auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
|
||||||
primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
|
primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
|
||||||
|
activation_input_index++;
|
||||||
} else {
|
} else {
|
||||||
// do input quant
|
// do input quant
|
||||||
auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
|
auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
|
||||||
|
@ -1239,7 +1242,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||||
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
|
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
if (flags.dataPreProcessParam.calibrate_size <= 0 || flags.dataPreProcessParam.calibrate_size > 65535) {
|
if (flags.dataPreProcessParam.calibrate_size <= kMinSize || flags.dataPreProcessParam.calibrate_size > kMaxSize) {
|
||||||
MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
|
MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
@ -1247,8 +1250,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||||
MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
|
MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
STATUS status;
|
STATUS status = PreProcess();
|
||||||
status = PreProcess();
|
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "do pre process failed!";
|
MS_LOG(ERROR) << "do pre process failed!";
|
||||||
return status;
|
return status;
|
||||||
|
|
|
@ -132,7 +132,6 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
size_t shape_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
|
size_t shape_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
|
||||||
|
|
||||||
if (shape_size < min_quant_weight_size_) {
|
if (shape_size < min_quant_weight_size_) {
|
||||||
MS_LOG(INFO) << "shape_size " << shape_size << " less min_quant_weight_size_ " << shape_size;
|
MS_LOG(INFO) << "shape_size " << shape_size << " less min_quant_weight_size_ " << shape_size;
|
||||||
return false;
|
return false;
|
||||||
|
|
Loading…
Reference in New Issue