!26900 fix matmul int8 bug

Merge pull request !26900 from yeyunpeng2020/quant_bak
This commit is contained in:
i-robot 2021-11-29 02:49:17 +00:00 committed by Gitee
commit ab26378451
8 changed files with 46 additions and 20 deletions

View File

@ -116,7 +116,7 @@ int MatmulBaseInt8CPUKernel::Arm64SdotImpl(int task_id) {
int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_; int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;
MatmulInt8DpOpt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_, MatmulInt8DpOpt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_,
cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_, cur_oc, param_->deep_align_, input_sums_, batch_sums_ + cur_stride, quant_param_->out_act_min_,
quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_, quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_,
filter_per_channel_, cur_zp); filter_per_channel_, cur_zp);
@ -140,7 +140,7 @@ int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_; int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;
MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_, MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_,
cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_, cur_oc, param_->deep_align_, input_sums_, batch_sums_ + cur_stride, quant_param_->out_act_min_,
quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_, quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_,
filter_per_channel_, cur_zp); filter_per_channel_, cur_zp);

View File

@ -1,6 +1,6 @@
ml_face_mnet 105 ml_face_mnet 64.6
ml_face_landmark_2 2 ml_face_landmark_2 0.6
mobilenet.tflite 0.5 mobilenet.tflite 0.4
transformer_20200831_encoder_fp32.tflite;36 82.7 transformer_20200831_encoder_fp32.tflite;36 73.5
transformer_20200831_decoder_fp32.tflite;11 18.3 transformer_20200831_decoder_fp32.tflite;11 15.8
ml_face_mnet_image 105 ml_face_mnet_image 54.1

View File

@ -147,4 +147,4 @@ MindrtRuntimeTest.RuntimeFp16
MixDataTypeTest.mix1 MixDataTypeTest.mix1
SchedulerTest.TestScheduleInt32OpToFp16Subgraph SchedulerTest.TestScheduleInt32OpToFp16Subgraph
TestGPURegistryCustomOp.TestGPUCustomAdd TestGPURegistryCustomOp.TestGPUCustomAdd
QuantCastInt8Test.*

View File

@ -176,14 +176,29 @@ double DataDistribution::CalculateScale(float min_value, float max_value) {
min_value = -abs_max; min_value = -abs_max;
max_value = abs_max; max_value = abs_max;
} }
this->encode_min_ = min_value;
this->encode_max_ = max_value; encode_min_ = min_value;
// Optimize Handle 0. encode_max_ = max_value;
// Handling 0
// Inputs are strictly positive, set the real min to 0. e.g. input range = [1.0, 5.0] -> [0.0, 5.0]
if (encode_min_ > 0.0f) {
MS_LOG(DEBUG) << "min " << encode_min_ << " is bigger then 0, set to 0, this may course low precision";
encode_min_ = 0.0f;
}
// Inputs are strictly negative, set the real max to 0. e.g. input range = [-5.0, -1.0] -> [-5.0, 0.0]
if (encode_max_ < 0.0f) {
MS_LOG(DEBUG) << "real_max " << encode_max_ << " is smaller than 0, set to 0, this may course low precision";
encode_max_ = 0.0f;
}
// Inputs are both negative and positive, real_min and real_max are slightly shifted to make the floating point zero
// exactly representable. e.g. input range = [-5.1, 5.1] -> [-5.12, 5.08]
MS_ASSERT(quant_max_ - quant_min_ > 0); MS_ASSERT(quant_max_ - quant_min_ > 0);
return (encode_max_ - encode_min_) / (quant_max_ - quant_min_); return (encode_max_ - encode_min_) / (quant_max_ - quant_min_);
} }
double DataDistribution::CalculateKLScale() { return CalculateScale(this->best_T_, this->real_max_); } double DataDistribution::CalculateKLScale() {
return CalculateScale(-std::abs(this->best_T_), std::abs(this->best_T_));
}
double DataDistribution::GetScale() { double DataDistribution::GetScale() {
switch (this->activation_quant_method_) { switch (this->activation_quant_method_) {

View File

@ -36,8 +36,12 @@ class DataDistribution {
this->quant_max_ = quant_max; this->quant_max_ = quant_max;
this->quant_min_ = quant_min; this->quant_min_ = quant_min;
std::fill(histogram_.begin(), histogram_.end(), 1.0e-7); std::fill(histogram_.begin(), histogram_.end(), 1.0e-7);
if (this->activation_quant_method_ == KL) {
symmetry_ = true;
} else {
symmetry_ = symmetry; symmetry_ = symmetry;
} }
}
int RecordMaxMinValueArray(const std::vector<float> &data); int RecordMaxMinValueArray(const std::vector<float> &data);

View File

@ -471,6 +471,16 @@ KernelCallBack DebugInfoManager::GetAfterCallBack(const std::map<std::string, Op
// all outputs are same dtype. // all outputs are same dtype.
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = outputs.at(i); auto tensor = outputs.at(i);
if (save_flag_ && !tensor->quant_params().empty()) {
QuantParamExtend quant_param;
quant_param.node_name = call_param.node_name;
quant_param.node_type = call_param.node_type;
quant_param.quant_params = tensor->quant_params();
quant_param.tensor_name = tensor->tensor_name();
quant_param.element_num = tensor->ElementsNum();
quant_param.dims = tensor->shape();
quant_params_.push_back(quant_param);
}
AddOriginInfo(call_param, op_parameters.at(call_param.node_name), false, i, AddOriginInfo(call_param, op_parameters.at(call_param.node_name), false, i,
static_cast<mindspore::lite::Tensor *>(tensor)); static_cast<mindspore::lite::Tensor *>(tensor));
} }

View File

@ -495,8 +495,8 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
void FullQuantQuantizer::InitCpuConfig() { void FullQuantQuantizer::InitCpuConfig() {
this->target_data_type_ = kNumberTypeInt8; this->target_data_type_ = kNumberTypeInt8;
activation_symmetry_ = true; activation_symmetry_ = false;
weight_symmetry_ = false; weight_symmetry_ = true;
} }
void FullQuantQuantizer::InitQMinMax() { void FullQuantQuantizer::InitQMinMax() {

View File

@ -123,9 +123,6 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
} }
bool QuantStrategy::IsSkipOp(const AnfNodePtr &input_node) { bool QuantStrategy::IsSkipOp(const AnfNodePtr &input_node) {
if (skip_node_.find(input_node->fullname_with_scope()) == skip_node_.end()) { return !(skip_node_.find(input_node->fullname_with_scope()) == skip_node_.end());
return false;
}
return true;
} }
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant