diff --git a/mindspore/lite/src/common/quant_utils.cc b/mindspore/lite/src/common/quant_utils.cc
index 576e1af4faa..1a964db33f0 100644
--- a/mindspore/lite/src/common/quant_utils.cc
+++ b/mindspore/lite/src/common/quant_utils.cc
@@ -62,10 +62,12 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
       MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
       return RET_ERROR;
     }
+    MS_LOG(WARNING) << "The maximum and minimum values are equal to 0.";
     quantParam->inited = true;
     quantParam->min = mMin;
     quantParam->max = mMax;
-    quantParam->scale = 0.0f;
+    // Degenerate all-zero range: store a scale of 1 rather than 0.
+    quantParam->scale = 1;
     quantParam->zeroPoint = 0;
     quantParam->narrowRange = narrowRange;
     quantParam->numBits = num_bits;
diff --git a/mindspore/lite/src/weight_decoder.h b/mindspore/lite/src/weight_decoder.h
index d10319b9cf9..690f205f70e 100644
--- a/mindspore/lite/src/weight_decoder.h
+++ b/mindspore/lite/src/weight_decoder.h
@@ -89,9 +89,10 @@ STATUS UnSparseTensorData(const std::vector<int> &unique_values, const std::vect
       return RET_ERROR;
     }
     auto coor = coors[i];
-    auto cur_channel = data_index / elem_perchannel;
-    auto zp = quant_params->Get(cur_channel)->zeroPoint();
     for (size_t j = 0; j < coor; j++) {
+      // data_index advances every iteration, so look up the channel's zero point per element.
+      auto cur_channel = data_index / elem_perchannel;
+      auto zp = quant_params->Get(cur_channel)->zeroPoint();
       un_sparsed_data.push_back(zp);
       data_index++;
     }
diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc
index 0081fb4b063..0ad94be1d02 100644
--- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc
@@ -460,15 +460,22 @@ FullQuantQuantizer::~FullQuantQuantizer() {
   delete int8_model_;
 }
 
-STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
-                                        const PrimitivePtr &primitive) const {
+// Fills one QuantParamT from scale/zero_point/max_min and adds it to the primitive's quant holder as an input
+// quant param when is_input is true, otherwise as an output quant param. A zero scale is replaced by 1.
+STATUS FullQuantQuantizer::SetInOutQuantParam(double scale, int32_t zero_point, struct MaxMin *max_min,
+                                              const PrimitivePtr &primitive, bool is_input) const {
   MS_ASSERT(max_min != nullptr);
   MS_ASSERT(primitive != nullptr);
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
   MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
   schema::QuantParamT quant_param;
-  quant_param.scale = scale;
-  quant_param.zeroPoint = zeropoint;
+  if (scale == 0) {
+    MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
+    quant_param.scale = 1;
+  } else {
+    quant_param.scale = scale;
+  }
+  quant_param.zeroPoint = zero_point;
   quant_param.max = max_min->max;
   quant_param.min = max_min->min;
   quant_param.numBits = bit_num;
@@ -477,28 +484,11 @@ STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct
   quant_param.roundType = 1;
   quant_param.multiplier = 1;
   std::vector<schema::QuantParamT> quant_params = {quant_param};
-  quant_param_holder->AddInputQuantParam(quant_params);
-  return RET_OK;
-}
-
-STATUS FullQuantQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
-                                         const PrimitivePtr &primitive) const {
-  MS_ASSERT(max_min != nullptr);
-  MS_ASSERT(primitive != nullptr);
-  auto quant_param_holder = GetCNodeQuantHolder(primitive);
-  MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
-  schema::QuantParamT quant_param;
-  quant_param.scale = scale;
-  quant_param.zeroPoint = zeropoint;
-  quant_param.max = max_min->max;
-  quant_param.min = max_min->min;
-  quant_param.numBits = bit_num;
-  quant_param.narrowRange = false;
-  quant_param.inited = true;
-  quant_param.roundType = 1;
-  quant_param.multiplier = 1;
-  std::vector<schema::QuantParamT> quant_params = {quant_param};
-  quant_param_holder->AddOutputQuantParam(quant_params);
+  if (is_input) {
+    quant_param_holder->AddInputQuantParam(quant_params);
+  } else {
+    quant_param_holder->AddOutputQuantParam(quant_params);
+  }
   return RET_OK;
 }
 
@@ -684,7 +674,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
         struct MaxMin input_min_max {};
         input_min_max.max = info->max;
         input_min_max.min = info->min;
-        DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
+        SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
       }
     } else if (is_graph_input) {
       auto &info = (*inputs_diverg_info)[op_name][i - 1];
@@ -693,7 +683,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
       struct MaxMin input_min_max {};
       input_min_max.max = info->max;
       input_min_max.min = info->min;
-      DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
+      SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
     } else {
       MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
       // get dtype
@@ -785,7 +775,7 @@ STATUS FullQuantQuantizer::QuantNode() {
       struct MaxMin input_min_max {};
       input_min_max.max = info->max;
       input_min_max.min = info->min;
-      DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
+      SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
       // do weight quant
       auto weight = cnode->input(2);
       bool per_channel = false;
@@ -814,7 +804,7 @@ STATUS FullQuantQuantizer::QuantNode() {
         output_min_max.max = info->max;
         output_min_max.min = info->min;
 
-        DoQuantOutput(output_scale, output_zp, &output_min_max, primitive);
+        SetInOutQuantParam(output_scale, output_zp, &output_min_max, primitive, false);
         primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
       }
     }
@@ -1219,6 +1209,19 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   }
   calibrator_->full_quant_param_ = flags.fullQuantParam;
   calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
+  // Reject unusable calibration settings before PreProcess() runs.
+  if (flags.dataPreProcessParam.calibrate_path_vector.empty()) {
+    MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (flags.dataPreProcessParam.calibrate_size <= 0) {
+    MS_LOG(ERROR) << "calibrate size must pass and the size must > 0.";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  if (flags.dataPreProcessParam.input_type == preprocess::INPUT_TYPE_MAX) {
+    MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
+    return RET_INPUT_PARAM_INVALID;
+  }
   STATUS status;
   status = PreProcess();
   if (status != RET_OK) {
diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h
index 36878a07d38..27b0ca95b4b 100644
--- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h
@@ -106,9 +106,9 @@ class FullQuantQuantizer : public Quantizer {
 
   STATUS QuantNode();
 
-  STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &primitive) const;
-
-  STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &) const;
+  // Adds one quant param to `primitive`: to its input params when is_input is true, else to its output params.
+  STATUS SetInOutQuantParam(double scale, int32_t zero_point, struct MaxMin *max_min, const PrimitivePtr &primitive,
+                            bool is_input) const;
 
   STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
                        bool per_channel) const;