fix dynamic quant kernel bug && dynamic matmul multi threads bug

This commit is contained in:
yeyunpeng2020 2022-01-20 11:51:28 +08:00
parent 6fd806b621
commit a2e3501082
3 changed files with 11 additions and 43 deletions


@@ -121,7 +121,7 @@ void DynamicQuantCPUKernel::CalculateScaleZp() {
   quant_parm.zeroPoint = zp;
   quant_parm.bitNum = k8Bit;
   quant_parm.inited = true;
-  this->out_tensors_.front()->AddQuantParam(quant_parm);
+  this->out_tensors_.front()->set_quant_params({quant_parm});
   return;
 }
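The dynamic-quant kernel recomputes the scale and zero point on every inference, so appending a fresh quant param to the output tensor each run would presumably pile up stale entries; set_quant_params({quant_parm}) overwrites them instead. A minimal standalone sketch of that append-vs-replace difference, using a plain std::vector rather than the MindSpore Lite tensor API:

// Illustrative only: QuantParam here is a hypothetical stand-in struct.
#include <cstdio>
#include <vector>

struct QuantParam {
  double scale = 1.0;
  int zeroPoint = 0;
};

int main() {
  std::vector<QuantParam> appended;   // AddQuantParam-style: push_back once per run
  std::vector<QuantParam> replaced;   // set_quant_params-style: overwrite once per run
  for (int run = 0; run < 3; ++run) {
    QuantParam p{0.1 * (run + 1), run};
    appended.push_back(p);            // grows: 1, 2, 3 entries
    replaced = {p};                   // always exactly one entry, the latest
  }
  std::printf("appended=%zu replaced=%zu\n", appended.size(), replaced.size());
  return 0;
}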


@@ -50,9 +50,17 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
+  float *bias_ptr = fp32_bias_ptr_;
+  if (fp32_bias_ptr_ != nullptr) {
+    bias_ptr += cur_stride;
+  }
+  float *filter_scale = quant_param_->filter_scale_;
+  if (filter_per_channel_) {
+    filter_scale += cur_stride;
+  }
+  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
                         batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
-                        quant_param_->input_scale_, quant_param_->filter_scale_, param_->col_, filter_per_channel_);
+                        quant_param_->input_scale_, filter_scale, param_->col_, filter_per_channel_);
   return RET_OK;
 }
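This is the multi-thread fix: each task computes the block of output channels starting at cur_stride, but the bias and per-channel filter-scale pointers were previously passed unshifted, so every task after the first presumably read channel 0's bias and scales. A minimal sketch of the per-task offset arithmetic, with hypothetical names and shapes rather than the kernel's real partitioning code:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int col = 10;          // total output channels
  const int col_tile = 4;      // channels handled per tile
  const int thread_count = 2;
  std::vector<float> bias(col), filter_scale(col);
  for (int c = 0; c < col; ++c) {
    bias[c] = 100.f + c;       // easy to recognize: channel c owns bias 100 + c
    filter_scale[c] = 0.01f * (c + 1);
  }
  const int tile_count = (col + col_tile - 1) / col_tile;
  const int thread_stride = ((tile_count + thread_count - 1) / thread_count) * col_tile;
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    const int cur_stride = task_id * thread_stride;
    const int cur_oc = std::min(thread_stride, col - cur_stride);
    if (cur_oc <= 0) continue;
    // The fix: advance both per-channel pointers by cur_stride for this task's block.
    const float *bias_ptr = bias.data() + cur_stride;
    const float *scale_ptr = filter_scale.data() + cur_stride;
    std::printf("task %d handles channels [%d, %d): first bias %.0f, first scale %.2f\n",
                task_id, cur_stride, cur_stride + cur_oc, bias_ptr[0], scale_ptr[0]);
  }
  return 0;
}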
@@ -139,54 +147,18 @@ int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
 void MatmulDynamicInt8CPUKernel::InitParameter() {
   param_->a_const_ = (in_tensors_[kInputIndex]->data() != nullptr);
   param_->b_const_ = (in_tensors_[kWeightIndex]->data() != nullptr);
-#ifdef ENABLE_ARM32
-  row_tile_ = C4NUM;
-  col_tile_ = C2NUM;
-  deep_tile_ = C16NUM;
-#elif ENABLE_ARM64
-  support_sdot_ = mindspore::lite::IsSupportSDot();
-  row_tile_ = C4NUM;
-  if (support_sdot_) {
-    col_tile_ = C16NUM;
-    deep_tile_ = C4NUM;
-  } else {
-    col_tile_ = C4NUM;
-    deep_tile_ = C16NUM;
-  }
-#else
   row_tile_ = C4NUM;
   col_tile_ = C4NUM;
   deep_tile_ = C16NUM;
-#endif
   if (param_->a_transpose_) {
     a_pack_func_ = RowMajor2Col16x4MajorInt8;
   } else {
     a_pack_func_ = RowMajor2Row16x4MajorInt8;
   }
   if (param_->b_transpose_) {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Row2x16MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Row4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Row16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Row16x4MajorInt8;
-#endif
   } else {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Col16x2MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Col4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Col16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Col16x4MajorInt8;
-#endif
   }
   return;
 }
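With the ARM32/ARM64 branches dropped, InitParameter() always uses row_tile_ = C4NUM, col_tile_ = C4NUM, deep_tile_ = C16NUM, which matches the RowMajor2Row16x4MajorInt8 / RowMajor2Col16x4MajorInt8 pack functions it keeps. A small sketch of what those tile constants imply for packing, assuming C4NUM = 4 and C16NUM = 16 and that dimensions are rounded up to the tile sizes before packing; round_up and the shapes below are illustrative, not the MindSpore Lite helpers:

#include <cstdio>

// Hypothetical helper: round value up to the next multiple of tile.
static int round_up(int value, int tile) { return ((value + tile - 1) / tile) * tile; }

int main() {
  const int row_tile = 4, col_tile = 4, deep_tile = 16;  // assumed values of C4NUM/C16NUM
  const int row = 6, col = 10, deep = 20;                // example matmul shape
  const int row_align = round_up(row, row_tile);         // 8
  const int col_align = round_up(col, col_tile);         // 12
  const int deep_align = round_up(deep, deep_tile);      // 32
  std::printf("packed A: %d x %d, packed B: %d x %d\n",
              row_align, deep_align, deep_align, col_align);
  return 0;
}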


@@ -1,7 +1,3 @@
 [common_quant_param]
 quant_type=DYNAMIC_QUANT
 bit_num=8
-# Layers with size of weights exceeds threshold `min_quant_weight_size` will be quantized.
-min_quant_weight_size=0
-# Layers with channel size of weights exceeds threshold `min_quant_weight_channel` will be quantized.
-min_quant_weight_channel=16