forked from mindspore-Ecosystem/mindspore
!27648 npu activation support for u8
Merge pull request !27648 from yeyunpeng2020/quant2
This commit is contained in:
commit
8d3cd3ab53
|
@ -177,8 +177,8 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
|
|||
}
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status =
|
||||
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
|
||||
weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
|
||||
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
|
||||
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -197,8 +197,8 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const ValueNodePtr &weight, const
|
|||
}
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status =
|
||||
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
|
||||
weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
|
||||
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
|
||||
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -528,9 +528,9 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
|
|||
}
|
||||
|
||||
void FullQuantQuantizer::InitCpuConfig() {
|
||||
quant_data_type_ = kNumberTypeInt8;
|
||||
activation_quant_data_type_ = kNumberTypeInt8;
|
||||
activation_target_data_type_ = kNumberTypeInt8;
|
||||
weight_target_data_type_ = kNumberTypeInt8;
|
||||
weight_data_type_ = kNumberTypeInt8;
|
||||
activation_symmetry_ = false;
|
||||
weight_symmetry_ = true;
|
||||
support_int8_ops_ = {
|
||||
|
@ -573,9 +573,9 @@ void FullQuantQuantizer::InitCpuConfig() {
|
|||
|
||||
void FullQuantQuantizer::InitKirinConfig() {
|
||||
// `kTypeUnknown` represents the original data type
|
||||
quant_data_type_ = kNumberTypeInt8;
|
||||
activation_quant_data_type_ = kNumberTypeUInt8;
|
||||
activation_target_data_type_ = kTypeUnknown;
|
||||
weight_target_data_type_ = kNumberTypeInt8;
|
||||
weight_data_type_ = kNumberTypeInt8;
|
||||
activation_symmetry_ = false;
|
||||
weight_symmetry_ = true;
|
||||
support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
|
||||
|
@ -584,14 +584,21 @@ void FullQuantQuantizer::InitKirinConfig() {
|
|||
}
|
||||
|
||||
void FullQuantQuantizer::InitQMinMax() {
|
||||
if (this->quant_data_type_ == kNumberTypeInt8) {
|
||||
q_max_ = QuantMax(this->bit_num_, false); // 127
|
||||
q_min_ = QuantMin(this->bit_num_, false, true); // -127
|
||||
} else if (quant_data_type_ == kNumberTypeUInt8) {
|
||||
q_max_ = QuantMax(this->bit_num_, true); // 255
|
||||
q_min_ = QuantMin(this->bit_num_, true, false); // 0
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported quant value type: " << quant_data_type_;
|
||||
MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8);
|
||||
if (activation_quant_data_type_ == kNumberTypeInt8) {
|
||||
activation_q_min_ = QuantMin(this->bit_num_, false, true); // -127
|
||||
activation_q_max_ = QuantMax(this->bit_num_, false); // 127
|
||||
} else if (activation_quant_data_type_ == kNumberTypeUInt8) {
|
||||
activation_q_min_ = QuantMin(this->bit_num_, true, false); // 0
|
||||
activation_q_max_ = QuantMax(this->bit_num_, true); // 255
|
||||
}
|
||||
MS_ASSERT(weight_data_type_ == kNumberTypeInt8 || weight_data_type_ == kNumberTypeUInt8);
|
||||
if (weight_data_type_ == kNumberTypeInt8) {
|
||||
weight_q_max_ = QuantMax(this->bit_num_, false); // 127
|
||||
weight_q_min_ = QuantMin(this->bit_num_, false, true); // -127
|
||||
} else if (activation_quant_data_type_ == kNumberTypeUInt8) {
|
||||
weight_q_max_ = QuantMax(this->bit_num_, true); // 255
|
||||
weight_q_min_ = QuantMin(this->bit_num_, true, false); // 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -651,9 +658,9 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
|
|||
break;
|
||||
}
|
||||
InitQMinMax();
|
||||
calibrator_ =
|
||||
std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
|
||||
this->flags_.dataPreProcessParam, activation_symmetry_);
|
||||
calibrator_ = std::make_unique<Calibrator>(this->bit_num_, activation_q_max_, activation_q_min_,
|
||||
this->flags_.fullQuantParam.activation_quant_method,
|
||||
this->flags_.dataPreProcessParam, activation_symmetry_);
|
||||
MSLITE_CHECK_PTR(calibrator_);
|
||||
auto ret = MarkQuantNode(func_graph);
|
||||
if (ret != RET_OK) {
|
||||
|
@ -1086,7 +1093,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
|
|||
quant_param_t.scale = quant_params[0].scale;
|
||||
quant_param_t.zeroPoint = quant_params[0].zeroPoint;
|
||||
for (auto float_data : fp32_op_input) {
|
||||
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
|
||||
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, activation_q_max_, activation_q_min_);
|
||||
quant_datas.push_back(quant_data);
|
||||
}
|
||||
|
||||
|
|
|
@ -102,12 +102,15 @@ class FullQuantQuantizer : public Quantizer {
|
|||
|
||||
private:
|
||||
// Config
|
||||
TypeId activation_quant_data_type_{kNumberTypeInt8};
|
||||
TypeId activation_target_data_type_{kNumberTypeInt8};
|
||||
TypeId quant_data_type_{kNumberTypeInt8};
|
||||
TypeId weight_target_data_type_{kNumberTypeInt8};
|
||||
// quant and export are same data type.
|
||||
TypeId weight_data_type_{kNumberTypeInt8};
|
||||
size_t bit_num_{8};
|
||||
int q_max_{INT8_MAX};
|
||||
int q_min_{INT8_MIN};
|
||||
int activation_q_min_{INT8_MIN};
|
||||
int activation_q_max_{INT8_MAX};
|
||||
int weight_q_min_{INT8_MIN};
|
||||
int weight_q_max_{INT8_MAX};
|
||||
bool activation_symmetry_{false};
|
||||
bool weight_symmetry_{true};
|
||||
std::set<PrimitivePtr> support_int8_ops_;
|
||||
|
|
|
@ -55,7 +55,6 @@ enum WeightQuantType {
|
|||
FIXED_BIT_PER_LAYER = 1,
|
||||
MIXED_BIT_PER_LAYER = 2,
|
||||
};
|
||||
constexpr size_t kUint8Quantization = 8;
|
||||
constexpr size_t k8Bit = 8;
|
||||
constexpr size_t k16Bit = 16;
|
||||
constexpr size_t kMaxNum1024 = 1024;
|
||||
|
|
|
@ -54,12 +54,11 @@ class WeightQuantizer : public Quantizer {
|
|||
quant_max_ = QuantMax(bit_num_, false);
|
||||
quant_min_ = QuantMin(bit_num_, false, false);
|
||||
// parse type_id_
|
||||
MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
|
||||
if (bit_num_ > 0 && bit_num_ <= k8Bit) {
|
||||
type_id_ = kNumberTypeInt8;
|
||||
} else if (bit_num_ <= k16Bit) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid input bits";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue