!27648 NPU activation support for u8

Merge pull request !27648 from yeyunpeng2020/quant2
i-robot 2021-12-15 06:26:59 +00:00 committed by Gitee
commit 8d3cd3ab53
4 changed files with 35 additions and 27 deletions

View File

@ -177,8 +177,8 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
}
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status =
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@ -197,8 +197,8 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const ValueNodePtr &weight, const
}
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status =
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
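
Note: both DoParameterWeightQuant and DoValueNodeWeightQuant now pass the weight-specific range (weight_q_min_/weight_q_max_, i.e. [-127, 127] for int8) instead of the shared q_min_/q_max_. A minimal standalone sketch of symmetric int8 weight quantization over that range; QuantizeChannelSymmetric is a hypothetical helper for illustration, not the FixedBitQuantFilter used in the patch:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: quantize one channel of float weights to int8 with a
// symmetric range [weight_q_min, weight_q_max], zero point fixed at 0.
std::vector<int8_t> QuantizeChannelSymmetric(const std::vector<float> &channel,
                                             int weight_q_min, int weight_q_max) {
  float max_abs = 0.0f;
  for (float v : channel) max_abs = std::max(max_abs, std::fabs(v));
  // Symmetric scale: 0.0f always maps to the int8 value 0.
  float scale = (max_abs == 0.0f) ? 1.0f : max_abs / weight_q_max;
  std::vector<int8_t> out;
  out.reserve(channel.size());
  for (float v : channel) {
    int q = static_cast<int>(std::round(v / scale));
    q = std::min(std::max(q, weight_q_min), weight_q_max);
    out.push_back(static_cast<int8_t>(q));
  }
  return out;
}

int main() {
  std::vector<float> channel = {-1.5f, -0.25f, 0.0f, 0.75f, 1.5f};
  auto q = QuantizeChannelSymmetric(channel, -127, 127);  // weight range from the patch
  for (int8_t v : q) std::cout << static_cast<int>(v) << " ";
  std::cout << std::endl;  // prints: -127 -21 0 64 127
  return 0;
}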
@ -528,9 +528,9 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
}
void FullQuantQuantizer::InitCpuConfig() {
quant_data_type_ = kNumberTypeInt8;
activation_quant_data_type_ = kNumberTypeInt8;
activation_target_data_type_ = kNumberTypeInt8;
weight_target_data_type_ = kNumberTypeInt8;
weight_data_type_ = kNumberTypeInt8;
activation_symmetry_ = false;
weight_symmetry_ = true;
support_int8_ops_ = {
@ -573,9 +573,9 @@ void FullQuantQuantizer::InitCpuConfig() {
void FullQuantQuantizer::InitKirinConfig() {
// `kTypeUnknown` represents the original data type
quant_data_type_ = kNumberTypeInt8;
activation_quant_data_type_ = kNumberTypeUInt8;
activation_target_data_type_ = kTypeUnknown;
weight_target_data_type_ = kNumberTypeInt8;
weight_data_type_ = kNumberTypeInt8;
activation_symmetry_ = false;
weight_symmetry_ = true;
support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
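
Note: the Kirin (NPU) config keeps activations asymmetric and, with this change, quantizes them as uint8. A rough standalone illustration of picking an asymmetric uint8 quant param from an observed activation range; QuantParam and ChooseUint8Param are made up for illustration and are not part of the patch:

#include <algorithm>
#include <cmath>
#include <iostream>

// With asymmetric uint8 the zero point can sit anywhere in [0, 255], so ranges
// not centered on zero (e.g. ReLU outputs) do not waste half of the codes.
struct QuantParam {
  float scale;
  int zero_point;
};

QuantParam ChooseUint8Param(float real_min, float real_max) {
  const int q_min = 0, q_max = 255;     // uint8 activation range
  real_min = std::min(real_min, 0.0f);  // the representable range must contain 0
  real_max = std::max(real_max, 0.0f);
  float scale = (real_max - real_min) / (q_max - q_min);
  int zp = static_cast<int>(std::round(q_min - real_min / scale));
  zp = std::min(std::max(zp, q_min), q_max);
  return {scale, zp};
}

int main() {
  auto p = ChooseUint8Param(0.0f, 6.0f);  // e.g. a ReLU6 activation
  std::cout << "scale=" << p.scale << " zero_point=" << p.zero_point << std::endl;
  // prints: scale=0.0235294 zero_point=0
  return 0;
}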
@ -584,14 +584,21 @@ void FullQuantQuantizer::InitKirinConfig() {
}
void FullQuantQuantizer::InitQMinMax() {
if (this->quant_data_type_ == kNumberTypeInt8) {
q_max_ = QuantMax(this->bit_num_, false); // 127
q_min_ = QuantMin(this->bit_num_, false, true); // -127
} else if (quant_data_type_ == kNumberTypeUInt8) {
q_max_ = QuantMax(this->bit_num_, true); // 255
q_min_ = QuantMin(this->bit_num_, true, false); // 0
} else {
MS_LOG(ERROR) << "unsupported quant value type: " << quant_data_type_;
MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8);
if (activation_quant_data_type_ == kNumberTypeInt8) {
activation_q_min_ = QuantMin(this->bit_num_, false, true); // -127
activation_q_max_ = QuantMax(this->bit_num_, false); // 127
} else if (activation_quant_data_type_ == kNumberTypeUInt8) {
activation_q_min_ = QuantMin(this->bit_num_, true, false); // 0
activation_q_max_ = QuantMax(this->bit_num_, true); // 255
}
MS_ASSERT(weight_data_type_ == kNumberTypeInt8 || weight_data_type_ == kNumberTypeUInt8);
if (weight_data_type_ == kNumberTypeInt8) {
weight_q_max_ = QuantMax(this->bit_num_, false); // 127
weight_q_min_ = QuantMin(this->bit_num_, false, true); // -127
} else if (weight_data_type_ == kNumberTypeUInt8) {
weight_q_max_ = QuantMax(this->bit_num_, true); // 255
weight_q_min_ = QuantMin(this->bit_num_, true, false); // 0
}
}
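
Note: the inline comments pin the expected ranges to [-127, 127] for int8 and [0, 255] for uint8. The sketch below reconstructs plausible QuantMax/QuantMin implementations purely from those call sites and comments; the real functions may differ:

#include <cassert>
#include <cstddef>

// Assumed behaviour, inferred only from the call sites and the
// // 127, // -127, // 255, // 0 comments in the patch. Not the real code.
int QuantMax(size_t bits, bool is_unsigned) {
  return is_unsigned ? (1 << bits) - 1 : (1 << (bits - 1)) - 1;
}

int QuantMin(size_t bits, bool is_unsigned, bool narrow_range) {
  if (is_unsigned) return 0;
  int min = -(1 << (bits - 1));
  return narrow_range ? min + 1 : min;  // narrow range drops -128, keeping the grid symmetric
}

int main() {
  assert(QuantMax(8, false) == 127);        // int8 activations/weights
  assert(QuantMin(8, false, true) == -127); // int8, narrow range
  assert(QuantMax(8, true) == 255);         // uint8 activations
  assert(QuantMin(8, true, false) == 0);
  return 0;
}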
@ -651,9 +658,9 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
break;
}
InitQMinMax();
calibrator_ =
std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
this->flags_.dataPreProcessParam, activation_symmetry_);
calibrator_ = std::make_unique<Calibrator>(this->bit_num_, activation_q_max_, activation_q_min_,
this->flags_.fullQuantParam.activation_quant_method,
this->flags_.dataPreProcessParam, activation_symmetry_);
MSLITE_CHECK_PTR(calibrator_);
auto ret = MarkQuantNode(func_graph);
if (ret != RET_OK) {
@ -1086,7 +1093,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
quant_param_t.scale = quant_params[0].scale;
quant_param_t.zeroPoint = quant_params[0].zeroPoint;
for (auto float_data : fp32_op_input) {
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, activation_q_max_, activation_q_min_);
quant_datas.push_back(quant_data);
}
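
Note: the calibration callback now clamps quantized activations to the activation range rather than the shared one. A self-contained sketch of what a QuantizeData-style helper plausibly does (scale, shift by the zero point, clamp); QuantizeFloat is hypothetical, not the template from the patch:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Scale the float, add the zero point, then clamp to [q_min, q_max].
int8_t QuantizeFloat(float value, float scale, int zero_point, int q_min, int q_max) {
  int q = static_cast<int>(std::round(value / scale)) + zero_point;
  q = std::min(std::max(q, q_min), q_max);
  return static_cast<int8_t>(q);
}

int main() {
  const float scale = 0.05f;
  const int zero_point = 0;
  std::vector<float> fp32_op_input = {-10.0f, -0.3f, 0.0f, 2.4f, 10.0f};
  for (float v : fp32_op_input) {
    // Activation range from InitQMinMax for int8: [-127, 127].
    std::cout << static_cast<int>(QuantizeFloat(v, scale, zero_point, -127, 127)) << " ";
  }
  std::cout << std::endl;  // prints: -127 -6 0 48 127
  return 0;
}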

View File

@ -102,12 +102,15 @@ class FullQuantQuantizer : public Quantizer {
private:
// Config
TypeId activation_quant_data_type_{kNumberTypeInt8};
TypeId activation_target_data_type_{kNumberTypeInt8};
TypeId quant_data_type_{kNumberTypeInt8};
TypeId weight_target_data_type_{kNumberTypeInt8};
// Quantization and export use the same data type.
TypeId weight_data_type_{kNumberTypeInt8};
size_t bit_num_{8};
int q_max_{INT8_MAX};
int q_min_{INT8_MIN};
int activation_q_min_{INT8_MIN};
int activation_q_max_{INT8_MAX};
int weight_q_min_{INT8_MIN};
int weight_q_max_{INT8_MAX};
bool activation_symmetry_{false};
bool weight_symmetry_{true};
std::set<PrimitivePtr> support_int8_ops_;

View File

@ -55,7 +55,6 @@ enum WeightQuantType {
FIXED_BIT_PER_LAYER = 1,
MIXED_BIT_PER_LAYER = 2,
};
constexpr size_t kUint8Quantization = 8;
constexpr size_t k8Bit = 8;
constexpr size_t k16Bit = 16;
constexpr size_t kMaxNum1024 = 1024;

View File

@ -54,12 +54,11 @@ class WeightQuantizer : public Quantizer {
quant_max_ = QuantMax(bit_num_, false);
quant_min_ = QuantMin(bit_num_, false, false);
// parse type_id_
MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
if (bit_num_ > 0 && bit_num_ <= k8Bit) {
type_id_ = kNumberTypeInt8;
} else if (bit_num_ <= k16Bit) {
type_id_ = kNumberTypeInt16;
} else {
MS_LOG(ERROR) << "invalid input bits";
}
}
}
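
Note: in the weight quantizer, bit_num_ selects both the integer range and type_id_. A small standalone illustration of that mapping, assuming the usual two's-complement formulas for QuantMax(bit_num_, false) and QuantMin(bit_num_, false, false); the asserted bounds 0 < bit_num_ <= 16 come from the patch:

#include <iostream>

int main() {
  for (int bit_num : {7, 8, 9, 16}) {
    int quant_max = (1 << (bit_num - 1)) - 1;  // assumed QuantMax(bit_num, false)
    int quant_min = -(1 << (bit_num - 1));     // assumed QuantMin(bit_num, false, false), full range
    const char *type_id = (bit_num <= 8) ? "kNumberTypeInt8" : "kNumberTypeInt16";
    std::cout << bit_num << " bits -> [" << quant_min << ", " << quant_max
              << "], type_id_ = " << type_id << std::endl;
  }
  return 0;
}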