From 2523ed734df5ea41791f7971e1408d6e7281ae6d Mon Sep 17 00:00:00 2001 From: yeyunpeng2020 Date: Tue, 14 Dec 2021 20:53:37 +0800 Subject: [PATCH] npu activation support for u8 --- .../quantizer/full_quant_quantizer.cc | 47 +++++++++++-------- .../quantizer/full_quant_quantizer.h | 11 +++-- .../tools/converter/quantizer/quantize_util.h | 1 - .../converter/quantizer/weight_quantizer.h | 3 +- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc index 55a5fc23efe..98ae4d5081d 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc @@ -177,8 +177,8 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const } auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER; auto status = - FixedBitQuantFilter(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_, - weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true); + FixedBitQuantFilter(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_, + bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; @@ -197,8 +197,8 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const ValueNodePtr &weight, const } auto weight_quant_type = per_channel ? 
WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER; auto status = - FixedBitQuantFilter(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_, - weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true); + FixedBitQuantFilter(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_, + bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; @@ -528,9 +528,9 @@ int FullQuantQuantizer::UpdateDivergeInterval() { } void FullQuantQuantizer::InitCpuConfig() { - quant_data_type_ = kNumberTypeInt8; + activation_quant_data_type_ = kNumberTypeInt8; activation_target_data_type_ = kNumberTypeInt8; - weight_target_data_type_ = kNumberTypeInt8; + weight_data_type_ = kNumberTypeInt8; activation_symmetry_ = false; weight_symmetry_ = true; support_int8_ops_ = { @@ -573,9 +573,9 @@ void FullQuantQuantizer::InitCpuConfig() { void FullQuantQuantizer::InitKirinConfig() { // `kTypeUnknown` represents the original data type - quant_data_type_ = kNumberTypeInt8; + activation_quant_data_type_ = kNumberTypeUInt8; activation_target_data_type_ = kTypeUnknown; - weight_target_data_type_ = kNumberTypeInt8; + weight_data_type_ = kNumberTypeInt8; activation_symmetry_ = false; weight_symmetry_ = true; support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection}; @@ -584,14 +584,21 @@ void FullQuantQuantizer::InitKirinConfig() { } void FullQuantQuantizer::InitQMinMax() { - if (this->quant_data_type_ == kNumberTypeInt8) { - q_max_ = QuantMax(this->bit_num_, false); // 127 - q_min_ = QuantMin(this->bit_num_, false, true); // -127 - } else if (quant_data_type_ == kNumberTypeUInt8) { - q_max_ = QuantMax(this->bit_num_, true); // 255 - q_min_ = QuantMin(this->bit_num_, true, false); // 0 - } else { - MS_LOG(ERROR) << "unsupported quant value type: " << quant_data_type_; + 
MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8); + if (activation_quant_data_type_ == kNumberTypeInt8) { + activation_q_min_ = QuantMin(this->bit_num_, false, true); // -127 + activation_q_max_ = QuantMax(this->bit_num_, false); // 127 + } else if (activation_quant_data_type_ == kNumberTypeUInt8) { + activation_q_min_ = QuantMin(this->bit_num_, true, false); // 0 + activation_q_max_ = QuantMax(this->bit_num_, true); // 255 + } + MS_ASSERT(weight_data_type_ == kNumberTypeInt8 || weight_data_type_ == kNumberTypeUInt8); + if (weight_data_type_ == kNumberTypeInt8) { + weight_q_max_ = QuantMax(this->bit_num_, false); // 127 + weight_q_min_ = QuantMin(this->bit_num_, false, true); // -127 + } else if (weight_data_type_ == kNumberTypeUInt8) { + weight_q_max_ = QuantMax(this->bit_num_, true); // 255 + weight_q_min_ = QuantMin(this->bit_num_, true, false); // 0 } } @@ -651,9 +658,9 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) { break; } InitQMinMax(); - calibrator_ = - std::make_unique(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method, - this->flags_.dataPreProcessParam, activation_symmetry_); + calibrator_ = std::make_unique(this->bit_num_, activation_q_max_, activation_q_min_, + this->flags_.fullQuantParam.activation_quant_method, + this->flags_.dataPreProcessParam, activation_symmetry_); MSLITE_CHECK_PTR(calibrator_); auto ret = MarkQuantNode(func_graph); if (ret != RET_OK) { @@ -1086,7 +1093,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) { quant_param_t.scale = quant_params[0].scale; quant_param_t.zeroPoint = quant_params[0].zeroPoint; for (auto float_data : fp32_op_input) { - auto quant_data = QuantizeData(float_data, &quant_param_t, q_max_, q_min_); + auto quant_data = QuantizeData(float_data, &quant_param_t, activation_q_max_, activation_q_min_); quant_datas.push_back(quant_data); } diff --git
a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h index 42d56fcd0af..4178b99da94 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h @@ -102,12 +102,15 @@ class FullQuantQuantizer : public Quantizer { private: // Config + TypeId activation_quant_data_type_{kNumberTypeInt8}; TypeId activation_target_data_type_{kNumberTypeInt8}; - TypeId quant_data_type_{kNumberTypeInt8}; - TypeId weight_target_data_type_{kNumberTypeInt8}; + // quant and export are same data type. + TypeId weight_data_type_{kNumberTypeInt8}; size_t bit_num_{8}; - int q_max_{INT8_MAX}; - int q_min_{INT8_MIN}; + int activation_q_min_{INT8_MIN}; + int activation_q_max_{INT8_MAX}; + int weight_q_min_{INT8_MIN}; + int weight_q_max_{INT8_MAX}; bool activation_symmetry_{false}; bool weight_symmetry_{true}; std::set support_int8_ops_; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 4e13358ec62..80a31ba6ed6 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -55,7 +55,6 @@ enum WeightQuantType { FIXED_BIT_PER_LAYER = 1, MIXED_BIT_PER_LAYER = 2, }; -constexpr size_t kUint8Quantization = 8; constexpr size_t k8Bit = 8; constexpr size_t k16Bit = 16; constexpr size_t kMaxNum1024 = 1024; diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index 7cd39beffca..2b7bad2a215 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -54,12 +54,11 @@ class WeightQuantizer : public Quantizer { quant_max_ = QuantMax(bit_num_, false); quant_min_ = QuantMin(bit_num_, false, false); // parse type_id_ + MS_ASSERT(bit_num_ > 
0 && bit_num_ <= k16Bit); if (bit_num_ > 0 && bit_num_ <= k8Bit) { type_id_ = kNumberTypeInt8; } else if (bit_num_ <= k16Bit) { type_id_ = kNumberTypeInt16; - } else { - MS_LOG(ERROR) << "invalid input bits"; } } }