From a2e35010825959e9c5dbc1d29a1d77a6a371c418 Mon Sep 17 00:00:00 2001
From: yeyunpeng2020
Date: Thu, 20 Jan 2022 11:51:28 +0800
Subject: [PATCH] fix dynamic quant kernel bug && dynamic matmul multi-thread bug

---
 .../runtime/kernel/arm/int8/dynamic_quant.cc  |  2 +-
 .../kernel/arm/int8/matmul_dynamic_int8.cc    | 48 ++++---------
 .../quantizer/config/dynamic_quant.cfg        |  4 --
 3 files changed, 11 insertions(+), 43 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/dynamic_quant.cc b/mindspore/lite/src/runtime/kernel/arm/int8/dynamic_quant.cc
index 9e69fb47bdd..6015bca3e2a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/dynamic_quant.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/dynamic_quant.cc
@@ -121,7 +121,7 @@ void DynamicQuantCPUKernel::CalculateScaleZp() {
   quant_parm.zeroPoint = zp;
   quant_parm.bitNum = k8Bit;
   quant_parm.inited = true;
-  this->out_tensors_.front()->AddQuantParam(quant_parm);
+  this->out_tensors_.front()->set_quant_params({quant_parm});
   return;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_dynamic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_dynamic_int8.cc
index 1025a647f10..ee394006179 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_dynamic_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_dynamic_int8.cc
@@ -50,9 +50,17 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
+  float *bias_ptr = fp32_bias_ptr_;
+  if (fp32_bias_ptr_ != nullptr) {
+    bias_ptr += cur_stride;
+  }
+  float *filter_scale = quant_param_->filter_scale_;
+  if (filter_per_channel_) {
+    filter_scale += cur_stride;
+  }
+  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
                         batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
-                        quant_param_->input_scale_, quant_param_->filter_scale_, param_->col_, filter_per_channel_);
+                        quant_param_->input_scale_, filter_scale, param_->col_, filter_per_channel_);
   return RET_OK;
 }
 
@@ -139,54 +147,18 @@ int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
 void MatmulDynamicInt8CPUKernel::InitParameter() {
   param_->a_const_ = (in_tensors_[kInputIndex]->data() != nullptr);
   param_->b_const_ = (in_tensors_[kWeightIndex]->data() != nullptr);
-#ifdef ENABLE_ARM32
-  row_tile_ = C4NUM;
-  col_tile_ = C2NUM;
-  deep_tile_ = C16NUM;
-#elif ENABLE_ARM64
-  support_sdot_ = mindspore::lite::IsSupportSDot();
-  row_tile_ = C4NUM;
-  if (support_sdot_) {
-    col_tile_ = C16NUM;
-    deep_tile_ = C4NUM;
-  } else {
-    col_tile_ = C4NUM;
-    deep_tile_ = C16NUM;
-  }
-#else
   row_tile_ = C4NUM;
   col_tile_ = C4NUM;
   deep_tile_ = C16NUM;
-#endif
   if (param_->a_transpose_) {
     a_pack_func_ = RowMajor2Col16x4MajorInt8;
   } else {
     a_pack_func_ = RowMajor2Row16x4MajorInt8;
  }
   if (param_->b_transpose_) {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Row2x16MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Row4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Row16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Row16x4MajorInt8;
-#endif
   } else {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Col16x2MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Col4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Col16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Col16x4MajorInt8;
-#endif
   }
   return;
 }
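
[Editor's note, not part of the patch] The matmul hunks above are the "multi-thread" fix named in the subject: RunImpl() splits the output columns across tasks, and while the B-matrix and output pointers were already offset by cur_stride, the bias and the per-channel filter scale were not, so every task after the first reused the first task's bias and scale columns. The InitParameter() hunk separately drops the ARM32/ARM64-SDOT tile and pack-function selection, so every build now uses the generic C4/C4/C16 tiling. Below is a minimal stand-alone sketch of the corrected partitioning; RunTask() and its parameters are illustrative stand-ins, not the kernel's real signature:

  #include <algorithm>
  #include <cstdio>
  #include <vector>

  // Hypothetical model of RunImpl()'s column partitioning; the real kernel
  // calls DynamicMatmulInt8AIWI() where this sketch prints.
  void RunTask(int task_id, int col, int thread_stride, const float *bias,
               const float *filter_scale, bool filter_per_channel) {
    int cur_stride = task_id * thread_stride;  // first output column owned by this task
    int cur_oc = std::min(thread_stride, col - cur_stride);
    if (cur_oc <= 0) {
      return;  // past the last column: nothing for this task to do
    }
    // The fix: advance the per-column arrays by the same cur_stride as the
    // output pointer; otherwise every task reads columns [0, cur_oc) of them.
    const float *bias_ptr = (bias != nullptr) ? bias + cur_stride : nullptr;
    const float *scale_ptr = filter_per_channel ? filter_scale + cur_stride : filter_scale;
    for (int oc = 0; oc < cur_oc; ++oc) {
      float b = (bias_ptr != nullptr) ? bias_ptr[oc] : 0.0f;
      float s = filter_per_channel ? scale_ptr[oc] : scale_ptr[0];  // per-tensor: one shared scale
      std::printf("task %d: col %d bias %.1f scale %.2f\n", task_id, cur_stride + oc, b, s);
    }
  }

  int main() {
    std::vector<float> bias = {1, 2, 3, 4, 5, 6};
    std::vector<float> scale = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
    for (int task = 0; task < 2; ++task) {
      RunTask(task, /*col=*/6, /*thread_stride=*/4, bias.data(), scale.data(), true);
    }
    return 0;
  }

Note the asymmetry the patch keeps: bias is offset whenever it is non-null, but filter_scale is offset only when filter_per_channel_ is set, because a per-tensor scale is a single value shared by every column. [End note]
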
diff --git a/mindspore/lite/tools/converter/quantizer/config/dynamic_quant.cfg b/mindspore/lite/tools/converter/quantizer/config/dynamic_quant.cfg
index 164a31d86d2..d11325e778c 100644
--- a/mindspore/lite/tools/converter/quantizer/config/dynamic_quant.cfg
+++ b/mindspore/lite/tools/converter/quantizer/config/dynamic_quant.cfg
@@ -1,7 +1,3 @@
 [common_quant_param]
 quant_type=DYNAMIC_QUANT
 bit_num=8
-# Layers with size of weights exceeds threshold `min_quant_weight_size` will be quantized.
-min_quant_weight_size=0
-# Layers with channel size of weights exceeds threshold `min_quant_weight_channel` will be quantized.
-min_quant_weight_channel=16
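
[Editor's note, not part of the patch] The dynamic_quant.cc hunk is the "dynamic quant kernel" fix: CalculateScaleZp() runs on every inference of a dynamically quantized model, so the appending AddQuantParam() grew the output tensor's quant-param list run after run while readers of the front() entry kept seeing the first run's stale scale; set_quant_params() replaces the list and makes the update idempotent. A toy reproduction, assuming (hypothetically) that the tensor stores its params in a plain vector, as the method names suggest:

  #include <cassert>
  #include <vector>

  // Toy stand-ins for lite::Tensor's quant-param storage; the real class in
  // mindspore/lite is more involved.
  struct LiteQuantParam {
    double scale = 1.0;
    int zeroPoint = 0;
  };

  struct Tensor {
    std::vector<LiteQuantParam> quant_params_;
    // Old call: appends, so repeated CalculateScaleZp() runs pile up entries.
    void AddQuantParam(const LiteQuantParam &p) { quant_params_.push_back(p); }
    // New call: replaces the whole list, so reruns overwrite instead of append.
    void set_quant_params(const std::vector<LiteQuantParam> &p) { quant_params_ = p; }
  };

  int main() {
    Tensor out;
    // Three simulated inferences on the old path: the list keeps growing and
    // front() still reports the first run's scale.
    for (int run = 1; run <= 3; ++run) {
      out.AddQuantParam({0.1 * run, 0});
    }
    assert(out.quant_params_.size() == 3);
    assert(out.quant_params_.front().scale == 0.1 * 1);  // stale first-run entry
    // Three simulated inferences on the fixed path: always exactly one entry,
    // and it is the latest one.
    for (int run = 1; run <= 3; ++run) {
      out.set_quant_params({{0.1 * run, 0}});
    }
    assert(out.quant_params_.size() == 1);
    assert(out.quant_params_.front().scale == 0.1 * 3);
    return 0;
  }

The dynamic_quant.cfg hunk just removes the min_quant_weight_size / min_quant_weight_channel knobs from the sample config, presumably so dynamic quantization relies on the converter's built-in thresholds instead. [End note]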