fix dynamic quant kernel bug && dynamic matmul multi-thread bug
parent 6fd806b621
commit a2e3501082
@@ -121,7 +121,7 @@ void DynamicQuantCPUKernel::CalculateScaleZp() {
   quant_parm.zeroPoint = zp;
   quant_parm.bitNum = k8Bit;
   quant_parm.inited = true;
-  this->out_tensors_.front()->AddQuantParam(quant_parm);
+  this->out_tensors_.front()->set_quant_params({quant_parm});
   return;
 }
 
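The DynamicQuantCPUKernel hunk swaps an append (AddQuantParam) for an overwrite (set_quant_params). Dynamic quantization recomputes the scale and zero point on every inference pass, so appending would pile up stale quant params on the output tensor across runs. A minimal standalone sketch of the difference, using a hypothetical Tensor model rather than the real mindspore::lite::Tensor API:

#include <cassert>
#include <vector>

struct QuantArg {
  double scale{1.0};
  int zeroPoint{0};
};

// Hypothetical stand-in for the output tensor: AddQuantParam appends,
// set_quant_params replaces the whole list.
struct Tensor {
  std::vector<QuantArg> quant_params_;
  void AddQuantParam(const QuantArg &q) { quant_params_.push_back(q); }
  void set_quant_params(const std::vector<QuantArg> &q) { quant_params_ = q; }
};

int main() {
  Tensor out;
  for (int run = 0; run < 3; ++run) {  // simulate three inference passes
    QuantArg q{0.5 / (run + 1), run};  // freshly computed scale/zp each pass
    // The old call, out.AddQuantParam(q), would leave three entries here,
    // two of them stale; the fix keeps exactly one, always the current pass's.
    out.set_quant_params({q});
  }
  assert(out.quant_params_.size() == 1);
  return 0;
}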
@@ -50,9 +50,17 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
+  float *bias_ptr = fp32_bias_ptr_;
+  if (fp32_bias_ptr_ != nullptr) {
+    bias_ptr += cur_stride;
+  }
+  float *filter_scale = quant_param_->filter_scale_;
+  if (filter_per_channel_) {
+    filter_scale += cur_stride;
+  }
+  DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
                         batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
-                        quant_param_->input_scale_, quant_param_->filter_scale_, param_->col_, filter_per_channel_);
+                        quant_param_->input_scale_, filter_scale, param_->col_, filter_per_channel_);
   return RET_OK;
 }
 
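This is the multi-thread fix: each RunImpl(task_id) computes the output-channel slice [cur_stride, cur_stride + cur_oc), but the old call passed the un-offset bias pointer and per-channel filter scales, so every worker thread applied channel 0's bias and scales to its own slice. Note that the B-matrix pointer was already advanced by cur_stride * deep_align_; only bias and scales were missing the offset. A standalone sketch of the corrected per-task offsetting (simplified signatures; ComputeSlice is a stand-in for the packed-int8 DynamicMatmulInt8AIWI call):

#include <algorithm>

// Stand-in for the real kernel: dequantize-and-bias one slice of output columns.
void ComputeSlice(const float *bias, const float *filter_scale, bool per_channel,
                  float *out, int cur_oc) {
  for (int c = 0; c < cur_oc; ++c) {
    float scale = per_channel ? filter_scale[c] : filter_scale[0];
    out[c] = out[c] * scale + (bias != nullptr ? bias[c] : 0.0f);
  }
}

void RunImplFixed(int task_id, int oc_per_task, int col, float *fp32_bias,
                  float *filter_scales, bool per_channel, float *output) {
  int cur_stride = task_id * oc_per_task;
  int cur_oc = std::min(oc_per_task, col - cur_stride);
  if (cur_oc <= 0) {
    return;
  }
  float *bias_ptr = fp32_bias;
  if (bias_ptr != nullptr) {
    bias_ptr += cur_stride;  // the fix: advance the bias to this task's channels
  }
  float *filter_scale = filter_scales;
  if (per_channel) {
    filter_scale += cur_stride;  // the fix: advance per-channel scales the same way
  }
  ComputeSlice(bias_ptr, filter_scale, per_channel, output + cur_stride, cur_oc);
}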
@@ -139,54 +147,18 @@ int MatmulDynamicInt8CPUKernel::InitInputQuantParam() {
 void MatmulDynamicInt8CPUKernel::InitParameter() {
   param_->a_const_ = (in_tensors_[kInputIndex]->data() != nullptr);
   param_->b_const_ = (in_tensors_[kWeightIndex]->data() != nullptr);
-#ifdef ENABLE_ARM32
-  row_tile_ = C4NUM;
-  col_tile_ = C2NUM;
-  deep_tile_ = C16NUM;
-#elif ENABLE_ARM64
-  support_sdot_ = mindspore::lite::IsSupportSDot();
-  row_tile_ = C4NUM;
-  if (support_sdot_) {
-    col_tile_ = C16NUM;
-    deep_tile_ = C4NUM;
-  } else {
-    col_tile_ = C4NUM;
-    deep_tile_ = C16NUM;
-  }
-#else
   row_tile_ = C4NUM;
   col_tile_ = C4NUM;
   deep_tile_ = C16NUM;
-#endif
   if (param_->a_transpose_) {
     a_pack_func_ = RowMajor2Col16x4MajorInt8;
   } else {
     a_pack_func_ = RowMajor2Row16x4MajorInt8;
   }
   if (param_->b_transpose_) {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Row2x16MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Row4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Row16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Row16x4MajorInt8;
-#endif
   } else {
-#ifdef ENABLE_ARM32
-    b_pack_func_ = RowMajor2Col16x2MajorInt8;
-#elif ENABLE_ARM64
-    if (support_sdot_) {
-      b_pack_func_ = RowMajor2Col4x16MajorInt8;
-    } else {
-      b_pack_func_ = RowMajor2Col16x4MajorInt8;
-    }
-#else
     b_pack_func_ = RowMajor2Col16x4MajorInt8;
-#endif
   }
   return;
 }
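The InitParameter hunk deletes the ENABLE_ARM32 and ENABLE_ARM64/sdot tile specializations, so every build now uses row_tile_ = 4, col_tile_ = 4, deep_tile_ = 16, matching the only remaining pack functions (RowMajor2Row16x4MajorInt8 / RowMajor2Col16x4MajorInt8, presumably 16-deep by 4-wide tiles). A small sketch of what those tiles imply for the packed buffer shapes, with UpRound as a local stand-in for the kernel's alignment macro:

#include <cstdio>

constexpr int kRowTile = 4;    // row_tile_  (C4NUM)
constexpr int kColTile = 4;    // col_tile_  (C4NUM)
constexpr int kDeepTile = 16;  // deep_tile_ (C16NUM)

// Round x up to the next multiple of tile.
constexpr int UpRound(int x, int tile) { return ((x + tile - 1) / tile) * tile; }

int main() {
  // Example: a 6x20 (A) by 20x10 (B) int8 matmul packs into aligned buffers;
  // the deep alignment is the same deep_align_ used as the B stride in RunImpl.
  std::printf("row_align=%d col_align=%d deep_align=%d\n",
              UpRound(6, kRowTile), UpRound(10, kColTile), UpRound(20, kDeepTile));
  // Prints: row_align=8 col_align=12 deep_align=32
  return 0;
}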
@@ -1,7 +1,3 @@
 [common_quant_param]
 quant_type=DYNAMIC_QUANT
 bit_num=8
-# Layers with size of weights exceeds threshold `min_quant_weight_size` will be quantized.
-min_quant_weight_size=0
-# Layers with channel size of weights exceeds threshold `min_quant_weight_channel` will be quantized.
-min_quant_weight_channel=16
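The last hunk trims the dynamic-quant example config: min_quant_weight_size and min_quant_weight_channel are weight-size filtering thresholds, dropped here presumably because they do not apply to dynamic quantization. The remaining three-line config would be passed to the converter through its config-file option, roughly like this (model and file names are hypothetical; flags follow the usual converter_lite usage):

./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model_dynamic_quant --configFile=dynamic_quant.cfg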