fix sparse decode bug
This commit is contained in:
parent
e17d24f802
commit
3d54a5434d
|
@ -62,10 +62,11 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
|
|||
MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(WARNING) << "The maximum and minimum values are equal to 0.";
|
||||
quantParam->inited = true;
|
||||
quantParam->min = mMin;
|
||||
quantParam->max = mMax;
|
||||
quantParam->scale = 0.0f;
|
||||
quantParam->scale = 1;
|
||||
quantParam->zeroPoint = 0;
|
||||
quantParam->narrowRange = narrowRange;
|
||||
quantParam->numBits = num_bits;
|
||||
|
|
|
@ -89,9 +89,9 @@ STATUS UnSparseTensorData(const std::vector<int> &unique_values, const std::vect
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto coor = coors[i];
|
||||
auto cur_channel = data_index / elem_perchannel;
|
||||
auto zp = quant_params->Get(cur_channel)->zeroPoint();
|
||||
for (size_t j = 0; j < coor; j++) {
|
||||
auto cur_channel = data_index / elem_perchannel;
|
||||
auto zp = quant_params->Get(cur_channel)->zeroPoint();
|
||||
un_sparsed_data.push_back(zp);
|
||||
data_index++;
|
||||
}
|
||||
|
|
|
@ -460,15 +460,20 @@ FullQuantQuantizer::~FullQuantQuantizer() {
|
|||
delete int8_model_;
|
||||
}
|
||||
|
||||
STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
STATUS FullQuantQuantizer::SetInOutQuantParam(double scale, int zero_point, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive, bool is_input) const {
|
||||
MS_ASSERT(max_min != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
|
||||
schema::QuantParamT quant_param;
|
||||
quant_param.scale = scale;
|
||||
quant_param.zeroPoint = zeropoint;
|
||||
if (scale == 0) {
|
||||
MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
|
||||
quant_param.scale = 1;
|
||||
} else {
|
||||
quant_param.scale = scale;
|
||||
}
|
||||
quant_param.zeroPoint = zero_point;
|
||||
quant_param.max = max_min->max;
|
||||
quant_param.min = max_min->min;
|
||||
quant_param.numBits = bit_num;
|
||||
|
@ -477,28 +482,11 @@ STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct
|
|||
quant_param.roundType = 1;
|
||||
quant_param.multiplier = 1;
|
||||
std::vector<schema::QuantParamT> quant_params = {quant_param};
|
||||
quant_param_holder->AddInputQuantParam(quant_params);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS FullQuantQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
MS_ASSERT(max_min != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
|
||||
schema::QuantParamT quant_param;
|
||||
quant_param.scale = scale;
|
||||
quant_param.zeroPoint = zeropoint;
|
||||
quant_param.max = max_min->max;
|
||||
quant_param.min = max_min->min;
|
||||
quant_param.numBits = bit_num;
|
||||
quant_param.narrowRange = false;
|
||||
quant_param.inited = true;
|
||||
quant_param.roundType = 1;
|
||||
quant_param.multiplier = 1;
|
||||
std::vector<schema::QuantParamT> quant_params = {quant_param};
|
||||
quant_param_holder->AddOutputQuantParam(quant_params);
|
||||
if (is_input) {
|
||||
quant_param_holder->AddInputQuantParam(quant_params);
|
||||
} else {
|
||||
quant_param_holder->AddOutputQuantParam(quant_params);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -684,7 +672,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
|||
struct MaxMin input_min_max {};
|
||||
input_min_max.max = info->max;
|
||||
input_min_max.min = info->min;
|
||||
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
|
||||
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
|
||||
}
|
||||
} else if (is_graph_input) {
|
||||
auto &info = (*inputs_diverg_info)[op_name][i - 1];
|
||||
|
@ -693,7 +681,7 @@ STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
|||
struct MaxMin input_min_max {};
|
||||
input_min_max.max = info->max;
|
||||
input_min_max.min = info->min;
|
||||
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
|
||||
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
|
||||
// get dtype
|
||||
|
@ -785,7 +773,7 @@ STATUS FullQuantQuantizer::QuantNode() {
|
|||
struct MaxMin input_min_max {};
|
||||
input_min_max.max = info->max;
|
||||
input_min_max.min = info->min;
|
||||
DoQuantInput(input_scale, input_zp, &input_min_max, primitive);
|
||||
SetInOutQuantParam(input_scale, input_zp, &input_min_max, primitive, true);
|
||||
// do weight quant
|
||||
auto weight = cnode->input(2);
|
||||
bool per_channel = false;
|
||||
|
@ -814,7 +802,7 @@ STATUS FullQuantQuantizer::QuantNode() {
|
|||
output_min_max.max = info->max;
|
||||
output_min_max.min = info->min;
|
||||
|
||||
DoQuantOutput(output_scale, output_zp, &output_min_max, primitive);
|
||||
SetInOutQuantParam(output_scale, output_zp, &output_min_max, primitive, false);
|
||||
primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
}
|
||||
}
|
||||
|
@ -1219,6 +1207,18 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
}
|
||||
calibrator_->full_quant_param_ = flags.fullQuantParam;
|
||||
calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
|
||||
if (flags.dataPreProcessParam.calibrate_path_vector.empty()) {
|
||||
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (flags.dataPreProcessParam.calibrate_size < 0) {
|
||||
MS_LOG(ERROR) << "calibrate size must pass and the size must > 0.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (flags.dataPreProcessParam.input_type == preprocess::INPUT_TYPE_MAX) {
|
||||
MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
STATUS status;
|
||||
status = PreProcess();
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -106,9 +106,8 @@ class FullQuantQuantizer : public Quantizer {
|
|||
|
||||
STATUS QuantNode();
|
||||
|
||||
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &primitive) const;
|
||||
|
||||
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, const PrimitivePtr &) const;
|
||||
STATUS SetInOutQuantParam(double scale, int32_t zero_point, struct MaxMin *max_min, const PrimitivePtr &primitive,
|
||||
bool is_input) const;
|
||||
|
||||
STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
|
||||
bool per_channel) const;
|
||||
|
|
Loading…
Reference in New Issue