!27648 NPU activation support for u8

Merge pull request !27648 from yeyunpeng2020/quant2
Authored by i-robot on 2021-12-15 06:26:59 +00:00, committed by Gitee
commit 8d3cd3ab53
4 changed files with 35 additions and 27 deletions
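This PR splits the former shared q_min_/q_max_ pair into separate activation and weight ranges so the Kirin NPU (u8) path can quantize activations as uint8 while weights stay int8. For orientation only, the following stand-alone C++ sketch (not code from this patch; the bodies of QuantMin/QuantMax are assumptions inferred from their call sites and the inline comments) reproduces the 8-bit ranges the converter ends up with:

#include <cstddef>
#include <iostream>

// Assumed semantics, mirroring the calls in InitQMinMax():
//   QuantMax(bit_num, is_unsigned)
//   QuantMin(bit_num, is_unsigned, narrow_range)
int QuantMax(size_t bit_num, bool is_unsigned) {
  return is_unsigned ? (1 << bit_num) - 1 : (1 << (bit_num - 1)) - 1;
}
int QuantMin(size_t bit_num, bool is_unsigned, bool narrow_range) {
  if (is_unsigned) {
    return 0;
  }
  return narrow_range ? -((1 << (bit_num - 1)) - 1) : -(1 << (bit_num - 1));
}

int main() {
  // CPU config: int8 activations and weights, narrow symmetric range.
  std::cout << "int8:  " << QuantMin(8, false, true) << " .. " << QuantMax(8, false) << "\n";  // -127 .. 127
  // Kirin NPU config: uint8 activations, full unsigned range.
  std::cout << "uint8: " << QuantMin(8, true, false) << " .. " << QuantMax(8, true) << "\n";   // 0 .. 255
  return 0;
}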

@@ -177,8 +177,8 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
   }
   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
   auto status =
-    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
-                                weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
+    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
+                                bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -197,8 +197,8 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const ValueNodePtr &weight, const
   }
   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
   auto status =
-    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
-                                weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
+    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
+                                bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -528,9 +528,9 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
 }
 void FullQuantQuantizer::InitCpuConfig() {
-  quant_data_type_ = kNumberTypeInt8;
+  activation_quant_data_type_ = kNumberTypeInt8;
   activation_target_data_type_ = kNumberTypeInt8;
-  weight_target_data_type_ = kNumberTypeInt8;
+  weight_data_type_ = kNumberTypeInt8;
   activation_symmetry_ = false;
   weight_symmetry_ = true;
   support_int8_ops_ = {
@@ -573,9 +573,9 @@ void FullQuantQuantizer::InitCpuConfig() {
 void FullQuantQuantizer::InitKirinConfig() {
   // `kTypeUnknown` represents the original data type
-  quant_data_type_ = kNumberTypeInt8;
+  activation_quant_data_type_ = kNumberTypeUInt8;
   activation_target_data_type_ = kTypeUnknown;
-  weight_target_data_type_ = kNumberTypeInt8;
+  weight_data_type_ = kNumberTypeInt8;
   activation_symmetry_ = false;
   weight_symmetry_ = true;
   support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
@@ -584,14 +584,21 @@ void FullQuantQuantizer::InitKirinConfig() {
 }
 void FullQuantQuantizer::InitQMinMax() {
-  if (this->quant_data_type_ == kNumberTypeInt8) {
-    q_max_ = QuantMax(this->bit_num_, false);        // 127
-    q_min_ = QuantMin(this->bit_num_, false, true);  // -127
-  } else if (quant_data_type_ == kNumberTypeUInt8) {
-    q_max_ = QuantMax(this->bit_num_, true);         // 255
-    q_min_ = QuantMin(this->bit_num_, true, false);  // 0
-  } else {
-    MS_LOG(ERROR) << "unsupported quant value type: " << quant_data_type_;
+  MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8);
+  if (activation_quant_data_type_ == kNumberTypeInt8) {
+    activation_q_min_ = QuantMin(this->bit_num_, false, true);  // -127
+    activation_q_max_ = QuantMax(this->bit_num_, false);        // 127
+  } else if (activation_quant_data_type_ == kNumberTypeUInt8) {
+    activation_q_min_ = QuantMin(this->bit_num_, true, false);  // 0
+    activation_q_max_ = QuantMax(this->bit_num_, true);         // 255
+  }
+  MS_ASSERT(weight_data_type_ == kNumberTypeInt8 || weight_data_type_ == kNumberTypeUInt8);
+  if (weight_data_type_ == kNumberTypeInt8) {
+    weight_q_max_ = QuantMax(this->bit_num_, false);        // 127
+    weight_q_min_ = QuantMin(this->bit_num_, false, true);  // -127
+  } else if (activation_quant_data_type_ == kNumberTypeUInt8) {
+    weight_q_max_ = QuantMax(this->bit_num_, true);         // 255
+    weight_q_min_ = QuantMin(this->bit_num_, true, false);  // 0
   }
 }
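For reference (not part of the patch): with the default bit_num_ of 8, the narrow signed range selected above is -(2^7 - 1) .. 2^7 - 1 = -127 .. 127 and the unsigned range is 0 .. 2^8 - 1 = 0 .. 255, matching the inline comments. After this change only the activation range switches to uint8 on the Kirin NPU path, while weights keep the signed int8 range on both backends.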
@@ -651,8 +658,8 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
       break;
   }
   InitQMinMax();
-  calibrator_ =
-    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
+  calibrator_ = std::make_unique<Calibrator>(this->bit_num_, activation_q_max_, activation_q_min_,
+                                             this->flags_.fullQuantParam.activation_quant_method,
                                              this->flags_.dataPreProcessParam, activation_symmetry_);
   MSLITE_CHECK_PTR(calibrator_);
   auto ret = MarkQuantNode(func_graph);
@@ -1086,7 +1093,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
       quant_param_t.scale = quant_params[0].scale;
       quant_param_t.zeroPoint = quant_params[0].zeroPoint;
       for (auto float_data : fp32_op_input) {
-        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
+        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, activation_q_max_, activation_q_min_);
        quant_datas.push_back(quant_data);
       }

@@ -102,12 +102,15 @@ class FullQuantQuantizer : public Quantizer {
  private:
   // Config
+  TypeId activation_quant_data_type_{kNumberTypeInt8};
   TypeId activation_target_data_type_{kNumberTypeInt8};
-  TypeId quant_data_type_{kNumberTypeInt8};  // quant and export are same data type.
-  TypeId weight_target_data_type_{kNumberTypeInt8};
+  TypeId weight_data_type_{kNumberTypeInt8};
   size_t bit_num_{8};
-  int q_max_{INT8_MAX};
-  int q_min_{INT8_MIN};
+  int activation_q_min_{INT8_MIN};
+  int activation_q_max_{INT8_MAX};
+  int weight_q_min_{INT8_MIN};
+  int weight_q_max_{INT8_MAX};
   bool activation_symmetry_{false};
   bool weight_symmetry_{true};
   std::set<PrimitivePtr> support_int8_ops_;

@@ -55,7 +55,6 @@ enum WeightQuantType {
   FIXED_BIT_PER_LAYER = 1,
   MIXED_BIT_PER_LAYER = 2,
 };
-constexpr size_t kUint8Quantization = 8;
 constexpr size_t k8Bit = 8;
 constexpr size_t k16Bit = 16;
 constexpr size_t kMaxNum1024 = 1024;

@@ -54,12 +54,11 @@ class WeightQuantizer : public Quantizer {
       quant_max_ = QuantMax(bit_num_, false);
       quant_min_ = QuantMin(bit_num_, false, false);
       // parse type_id_
+      MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
       if (bit_num_ > 0 && bit_num_ <= k8Bit) {
         type_id_ = kNumberTypeInt8;
       } else if (bit_num_ <= k16Bit) {
         type_id_ = kNumberTypeInt16;
-      } else {
-        MS_LOG(ERROR) << "invalid input bits";
       }
     }
   }