forked from mindspore-Ecosystem/mindspore
!27648 npu activation support for u8
Merge pull request !27648 from yeyunpeng2020/quant2
commit 8d3cd3ab53
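Summary of the change, as shown by the hunks below: the single quantization range shared by activations and weights in FullQuantQuantizer is split into separate fields. The old quant_data_type_, q_max_ and q_min_ members become activation_quant_data_type_ / weight_data_type_ and activation_q_max_ / activation_q_min_ / weight_q_max_ / weight_q_min_, InitQMinMax() now initializes both ranges, and the Kirin (NPU) config sets the activation quant data type to kNumberTypeUInt8 so activations can be quantized as uint8 while weights stay int8. The kUint8Quantization constant is removed, and WeightQuantizer guards its bit width with MS_ASSERT instead of logging an error.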
@@ -177,8 +177,8 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
   }
   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
   auto status =
-    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
-                                weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
+    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
+                                bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -197,8 +197,8 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const ValueNodePtr &weight, const
   }
   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
   auto status =
-    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_, bit_num_,
-                                weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
+    FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
+                                bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -528,9 +528,9 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
 }
 
 void FullQuantQuantizer::InitCpuConfig() {
-  quant_data_type_ = kNumberTypeInt8;
+  activation_quant_data_type_ = kNumberTypeInt8;
   activation_target_data_type_ = kNumberTypeInt8;
-  weight_target_data_type_ = kNumberTypeInt8;
+  weight_data_type_ = kNumberTypeInt8;
   activation_symmetry_ = false;
   weight_symmetry_ = true;
   support_int8_ops_ = {
@@ -573,9 +573,9 @@ void FullQuantQuantizer::InitCpuConfig() {
 
 void FullQuantQuantizer::InitKirinConfig() {
   // `kTypeUnknown` represents the original data type
-  quant_data_type_ = kNumberTypeInt8;
+  activation_quant_data_type_ = kNumberTypeUInt8;
   activation_target_data_type_ = kTypeUnknown;
-  weight_target_data_type_ = kNumberTypeInt8;
+  weight_data_type_ = kNumberTypeInt8;
   activation_symmetry_ = false;
   weight_symmetry_ = true;
   support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
@@ -584,14 +584,21 @@ void FullQuantQuantizer::InitKirinConfig() {
 }
 
 void FullQuantQuantizer::InitQMinMax() {
-  if (this->quant_data_type_ == kNumberTypeInt8) {
-    q_max_ = QuantMax(this->bit_num_, false);        // 127
-    q_min_ = QuantMin(this->bit_num_, false, true);  // -127
-  } else if (quant_data_type_ == kNumberTypeUInt8) {
-    q_max_ = QuantMax(this->bit_num_, true);         // 255
-    q_min_ = QuantMin(this->bit_num_, true, false);  // 0
-  } else {
-    MS_LOG(ERROR) << "unsupported quant value type: " << quant_data_type_;
+  MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8);
+  if (activation_quant_data_type_ == kNumberTypeInt8) {
+    activation_q_min_ = QuantMin(this->bit_num_, false, true);  // -127
+    activation_q_max_ = QuantMax(this->bit_num_, false);        // 127
+  } else if (activation_quant_data_type_ == kNumberTypeUInt8) {
+    activation_q_min_ = QuantMin(this->bit_num_, true, false);  // 0
+    activation_q_max_ = QuantMax(this->bit_num_, true);         // 255
+  }
+  MS_ASSERT(weight_data_type_ == kNumberTypeInt8 || weight_data_type_ == kNumberTypeUInt8);
+  if (weight_data_type_ == kNumberTypeInt8) {
+    weight_q_max_ = QuantMax(this->bit_num_, false);        // 127
+    weight_q_min_ = QuantMin(this->bit_num_, false, true);  // -127
+  } else if (activation_quant_data_type_ == kNumberTypeUInt8) {
+    weight_q_max_ = QuantMax(this->bit_num_, true);         // 255
+    weight_q_min_ = QuantMin(this->bit_num_, true, false);  // 0
   }
 }
 
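For reference, the -127/127 and 0/255 values noted in the comments above are the usual narrow-range int8 and full-range uint8 limits for an 8-bit quantizer. The snippet below is a standalone sketch, not the library's QuantMax/QuantMin implementation; the helper names and formulas are stand-ins that merely reproduce the commented values for bit_num_ = 8.

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-ins for QuantMax/QuantMin, reproducing only the values
// noted in the diff comments (127 / -127 for int8, 255 / 0 for uint8).
int QuantMaxSketch(size_t bit_num, bool is_unsigned) {
  return is_unsigned ? (1 << bit_num) - 1         // uint8 -> 255
                     : (1 << (bit_num - 1)) - 1;  // int8  -> 127
}

int QuantMinSketch(size_t bit_num, bool is_unsigned, bool narrow_range) {
  if (is_unsigned) {
    return 0;  // uint8 -> 0
  }
  int full_min = -(1 << (bit_num - 1));           // int8 -> -128
  return narrow_range ? full_min + 1 : full_min;  // narrow range -> -127
}

int main() {
  // Activation range when activation_quant_data_type_ is kNumberTypeUInt8 (Kirin/NPU config).
  std::cout << QuantMinSketch(8, true, false) << " " << QuantMaxSketch(8, true) << "\n";   // 0 255
  // Weight range when weight_data_type_ is kNumberTypeInt8 (narrow, symmetric).
  std::cout << QuantMinSketch(8, false, true) << " " << QuantMaxSketch(8, false) << "\n";  // -127 127
  return 0;
}
```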
@@ -651,8 +658,8 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
       break;
   }
   InitQMinMax();
-  calibrator_ =
-    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
-                                 this->flags_.dataPreProcessParam, activation_symmetry_);
+  calibrator_ = std::make_unique<Calibrator>(this->bit_num_, activation_q_max_, activation_q_min_,
+                                             this->flags_.fullQuantParam.activation_quant_method,
+                                             this->flags_.dataPreProcessParam, activation_symmetry_);
   MSLITE_CHECK_PTR(calibrator_);
   auto ret = MarkQuantNode(func_graph);
@@ -1086,7 +1093,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
       quant_param_t.scale = quant_params[0].scale;
      quant_param_t.zeroPoint = quant_params[0].zeroPoint;
       for (auto float_data : fp32_op_input) {
-        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
+        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, activation_q_max_, activation_q_min_);
         quant_datas.push_back(quant_data);
       }
 
@@ -102,12 +102,15 @@ class FullQuantQuantizer : public Quantizer {
 
  private:
   // Config
+  TypeId activation_quant_data_type_{kNumberTypeInt8};
   TypeId activation_target_data_type_{kNumberTypeInt8};
-  TypeId quant_data_type_{kNumberTypeInt8};
-  TypeId weight_target_data_type_{kNumberTypeInt8};
+  // quant and export are same data type.
+  TypeId weight_data_type_{kNumberTypeInt8};
   size_t bit_num_{8};
-  int q_max_{INT8_MAX};
-  int q_min_{INT8_MIN};
+  int activation_q_min_{INT8_MIN};
+  int activation_q_max_{INT8_MAX};
+  int weight_q_min_{INT8_MIN};
+  int weight_q_max_{INT8_MAX};
   bool activation_symmetry_{false};
   bool weight_symmetry_{true};
   std::set<PrimitivePtr> support_int8_ops_;
@@ -55,7 +55,6 @@ enum WeightQuantType {
   FIXED_BIT_PER_LAYER = 1,
   MIXED_BIT_PER_LAYER = 2,
 };
-constexpr size_t kUint8Quantization = 8;
 constexpr size_t k8Bit = 8;
 constexpr size_t k16Bit = 16;
 constexpr size_t kMaxNum1024 = 1024;
@@ -54,12 +54,11 @@ class WeightQuantizer : public Quantizer {
       quant_max_ = QuantMax(bit_num_, false);
       quant_min_ = QuantMin(bit_num_, false, false);
       // parse type_id_
+      MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
       if (bit_num_ > 0 && bit_num_ <= k8Bit) {
         type_id_ = kNumberTypeInt8;
       } else if (bit_num_ <= k16Bit) {
         type_id_ = kNumberTypeInt16;
-      } else {
-        MS_LOG(ERROR) << "invalid input bits";
       }
     }
   }
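As a side note, the WeightQuantizer hunk above swaps the `} else { MS_LOG(ERROR) ... }` branch for an MS_ASSERT precondition on the bit width. Below is a minimal standalone sketch of that bit-width-to-storage-type mapping; the helper and type names are hypothetical and not taken from the repository.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>

// Sketch of the mapping now guarded by MS_ASSERT instead of an error branch:
// 1-8 bits quantize to int8 storage, 9-16 bits to int16 storage.
enum class QuantStorageType { kInt8, kInt16 };

QuantStorageType SelectStorageType(size_t bit_num) {
  constexpr size_t k8Bit = 8;
  constexpr size_t k16Bit = 16;
  assert(bit_num > 0 && bit_num <= k16Bit);  // mirrors the new MS_ASSERT
  return (bit_num <= k8Bit) ? QuantStorageType::kInt8 : QuantStorageType::kInt16;
}

int main() {
  std::cout << (SelectStorageType(8) == QuantStorageType::kInt8) << "\n";    // 1
  std::cout << (SelectStorageType(16) == QuantStorageType::kInt16) << "\n";  // 1
  return 0;
}
```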