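Fix sparse decode bug; set quant scale to 1 when it would be 0; merge DoQuantInput/DoQuantOutput into SetInOutQuantParam; validate calibration flags in DoQuantize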

This commit is contained in:
yeyunpeng2020 2021-09-09 09:44:30 +08:00
parent e17d24f802
commit 3d54a5434d
4 changed files with 36 additions and 36 deletions

View File

@ -62,10 +62,11 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
return RET_ERROR;
}
MS_LOG(WARNING) << "The maximum and minimum values are equal to 0.";
quantParam->inited = true;
quantParam->min = mMin;
quantParam->max = mMax;
quantParam->scale = 0.0f;
quantParam->scale = 1;
quantParam->zeroPoint = 0;
quantParam->narrowRange = narrowRange;
quantParam->numBits = num_bits;

View File

@ -89,9 +89,9 @@ STATUS UnSparseTensorData(const std::vector<int> &unique_values, const std::vect
return RET_ERROR;
}
auto coor = coors[i];
auto cur_channel = data_index / elem_perchannel;
auto zp = quant_params->Get(cur_channel)->zeroPoint();
for (size_t j = 0; j < coor; j++) {
auto cur_channel = data_index / elem_perchannel;
auto zp = quant_params->Get(cur_channel)->zeroPoint();
un_sparsed_data.push_back(zp);
data_index++;
}

View File

@ -460,15 +460,20 @@ FullQuantQuantizer::~FullQuantQuantizer() {
delete int8_model_;
}
STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
STATUS FullQuantQuantizer::SetInOutQuantParam(double scale, int zero_point, struct MaxMin *max_min,
const PrimitivePtr &primitive, bool is_input) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(primitive != nullptr);
auto quant_param_holder = GetCNodeQuantHolder(primitive);
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
if (scale == 0) {
MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
quant_param.scale = 1;
} else {
quant_param.scale = scale;
}
quant_param.zeroPoint = zero_point;
quant_param.max = max_min->max;
quant_param.min = max_min->min;
quant_param.numBits = bit_num;
@ -477,28 +482,11 @@ STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct
quant_param.roundType = 1;
quant_param.multiplier = 1;
std::vector<schema::QuantParamT> quant_params = {quant_param};
quant_param_holder->AddInputQuantParam(quant_params);
return RET_OK;
}
STATUS FullQuantQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(primitive != nullptr);
auto quant_param_holder = GetCNodeQuantHolder(primitive);
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
quant_param.max = max_min->max;
quant_param.min = max_min->min;
quant_param.numBits = bit_num;
quant_param.narrowRange = false;
quant_param.inited = true;
quant_param.roundType = 1;
quant_param.multiplier = 1;
std::vector<schema::QuantParamT> quant_params = {quant_param};
quant_param_holder->AddOutputQuantParam(quant_params);
if (is_input) {
quant_param_holder->AddInputQuantParam(quant_params);
} else {
quant_param_holder->AddOutputQuantParam(quant_params);
}
return RET_OK;
}
@ -684,7 +672,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
}
} else if (is_graph_input) {
auto &info = (*inputs_diverg_info)[op_name][i - 1];
@ -693,7 +681,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
} else {
MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
// get dtype
@ -785,7 +773,7 @@ STATUS FullQuantQuantizer::QuantNode() {
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
// do weight quant
auto weight = cnode->input(2);
bool per_channel = false;
@ -814,7 +802,7 @@ STATUS FullQuantQuantizer::QuantNode() {
output_min_max.max = info->max;
output_min_max.min = info->min;
DoQuantOutput(output_scale, output_zp, &output_min_max, primitive);
SetInOutQuantParam(output_scale, output_zp, &output_min_max, primitive, false);
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
}
}
@ -1219,6 +1207,18 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}
calibrator_->full_quant_param_ = flags.fullQuantParam;
calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
if (flags.dataPreProcessParam.calibrate_path_vector.empty()) {
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
return RET_INPUT_PARAM_INVALID;
}
if (flags.dataPreProcessParam.calibrate_size < 0) {
MS_LOG(ERROR) << "calibrate size must pass and the size must > 0.";
return RET_INPUT_PARAM_INVALID;
}
if (flags.dataPreProcessParam.input_type == preprocess::INPUT_TYPE_MAX) {
MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
return RET_INPUT_PARAM_INVALID;
}
STATUS status;
status = PreProcess();
if (status != RET_OK) {

View File

@ -106,9 +106,8 @@ class FullQuantQuantizer : public Quantizer {
STATUS QuantNode();
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &primitive) const;
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &) const;
STATUS SetInOutQuantParam(double scale, int32_t zero_point, struct MaxMin *max_min, const PrimitivePtr &primitive,
bool is_input) const;
STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
bool per_channel) const;