!48766 [BUG] Fix the issue where the quantized Matmul operator's weight is empty under dynamic shape

Merge pull request !48766 from douzhixing/fix-matmul
i-robot 2023-02-13 07:45:49 +00:00 committed by Gitee
commit d9cbe819be
4 changed files with 41 additions and 9 deletions


@@ -179,18 +179,21 @@ void MatmulBaseInt8CPUKernel::FreeQuantParam() {
free(quant_param_);
quant_param_ = nullptr;
}
+ if (save_b_const_ != nullptr) {
+   free(save_b_const_);
+   save_b_const_ = nullptr;
+ }
}
int MatmulBaseInt8CPUKernel::MallocQuantParam() {
auto weight_tensor = in_tensors_.at(1);
auto weight_quant_params = weight_tensor->quant_params();
+ auto w_shape = weight_tensor->shape();
+ MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
+ int col = param_->b_transpose_ ? w_shape[w_shape.size() - DIMENSION_2D] : w_shape[w_shape.size() - 1];
MS_CHECK_TRUE_MSG(weight_quant_params.size() >= 1, lite::RET_ERROR, "weight quant params size should >= 1");
filter_per_channel_ = (weight_quant_params.size() > 1);
- channel_num_ = weight_quant_params.size();
+ channel_num_ = filter_per_channel_ ? col : 1;
const int &init_size = channel_num_;
quant_param_ = reinterpret_cast<MatmulQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
@@ -348,7 +351,8 @@ void MatmulBaseInt8CPUKernel::FreeTmpBuffer() {
}
int MatmulBaseInt8CPUKernel::TransferB() {
- auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
+ auto weight_data = (save_b_const_ == nullptr) ? reinterpret_cast<int8_t *>(in_tensors_.at(1)->data())
+                                               : reinterpret_cast<int8_t *>(save_b_const_);
CHECK_NULL_RETURN(weight_data);
CHECK_NULL_RETURN(b_pack_func_);
for (int i = 0; i < param_->batch; i++) {
@@ -365,6 +369,10 @@ int MatmulBaseInt8CPUKernel::TransferB() {
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
}
}
+ if (save_b_const_ != nullptr) {
+   free(save_b_const_);
+   save_b_const_ = nullptr;
+ }
return RET_OK;
}
@@ -407,7 +415,7 @@ int MatmulBaseInt8CPUKernel::InitBias() {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
- bias_ptr_ = reinterpret_cast<int *>(bias_tensor->data());
+ bias_ptr_ = reinterpret_cast<int *>(malloc(bias_tensor->ElementsNum() * sizeof(int)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
@@ -438,7 +446,15 @@ int MatmulBaseInt8CPUKernel::Prepare() {
FreeQuantParam();
return ret;
}
+ if (!InferShapeDone()) {
+   if (param_->b_const_) {
+     auto weight_tensor = in_tensors_.at(1);
+     CHECK_NULL_RETURN(weight_tensor);
+     CHECK_NULL_RETURN(weight_tensor->data());
+     save_b_const_ = reinterpret_cast<int8_t *>(malloc(weight_tensor->ElementsNum() * sizeof(int8_t)));
+     (void)memcpy(save_b_const_, weight_tensor->data(), weight_tensor->ElementsNum() * sizeof(int8_t));
+   }
+ }
return RET_OK;
}
int MatmulBaseInt8CPUKernel::MatmulReSize() {
@@ -462,7 +478,7 @@ int MatmulBaseInt8CPUKernel::ReSize() {
return ret;
}
- if (param_->b_const_ == true) {
+ if (param_->b_const_) {
if (TransferB() != RET_OK) {
MS_LOG(ERROR) << "TransferB error";
return RET_ERROR;
@@ -509,7 +525,7 @@ int MatmulBaseInt8CPUKernel::Run() {
return RunArm64Sdot();
}
#endif
- if (param_->b_const_ == false) {
+ if (!param_->b_const_) {
if (TransferB() != RET_OK) {
MS_LOG(ERROR) << "TransferB error";
return RET_ERROR;
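The changes in this file all serve one pattern: with a dynamic shape, Prepare() runs before shape inference, so the kernel copies the constant int8 weight into its own buffer (save_b_const_); TransferB() packs from that copy when it exists and then releases it, and FreeQuantParam() releases it if it was never consumed. A minimal, hedged sketch of the pattern follows, with simplified names (Int8WeightCache and its methods are illustrative, not the actual MatmulBaseInt8CPUKernel interface).

// A compilable sketch of the caching pattern, assuming a simplified type;
// not MindSpore Lite code.
#include <cstdint>
#include <cstdlib>
#include <cstring>

class Int8WeightCache {
 public:
  // Prepare() path: shape inference is deferred, so keep a private copy of
  // the constant weight data while it is still available.
  bool CacheConstWeight(const int8_t *weight_data, std::size_t size) {
    if (weight_data == nullptr || size == 0) {
      return false;
    }
    save_b_const_ = static_cast<int8_t *>(malloc(size));
    if (save_b_const_ == nullptr) {
      return false;
    }
    (void)memcpy(save_b_const_, weight_data, size);
    return true;
  }

  // TransferB() path: prefer the cached copy; fall back to the tensor data.
  const int8_t *WeightForPacking(const int8_t *tensor_data) const {
    return (save_b_const_ != nullptr) ? save_b_const_ : tensor_data;
  }

  // Called after packing, and again from the quant-param cleanup, so the
  // copy is freed exactly once.
  void Release() {
    if (save_b_const_ != nullptr) {
      free(save_b_const_);
      save_b_const_ = nullptr;
    }
  }

  ~Int8WeightCache() { Release(); }

 private:
  int8_t *save_b_const_ = nullptr;
};

In the real kernel the copy is made only when param_->b_const_ is set and shape inference has not completed, which is exactly the dynamic-shape case the commit title describes.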


@@ -83,6 +83,7 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
int8_t *batch_weight_ptr_ = nullptr;
int8_t *batch_b_ptr_ = nullptr;
int8_t *batch_c_ptr_ = nullptr;
+ int8_t *save_b_const_ = nullptr;
int *batch_sums_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;


@@ -1,3 +1,4 @@
+ ml_asr_encoder_int8_202103.onnx 17 43591176
ml_face_mnet 86 832744
ml_face_landmark_2 0.8 472112
mobilenet.tflite 0.4 26040


@@ -0,0 +1,14 @@
+ [common_quant_param]
+ quant_type=FULL_QUANT
+ bit_num=8
+ [data_preprocess_param]
+ calibrate_path=featinput:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_asr_encoder_int8_202103_calibration_data
+ calibrate_size=1
+ input_type=BIN
+ [full_quant_param]
+ activation_quant_method=MAX_MIN
+ bias_correction=true
+ per_channel=true
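For context on this new converter config: quant_type=FULL_QUANT with bit_num=8 requests full 8-bit quantization, per_channel=true gives each output channel of a weight its own scale and zero point (which is why the kernel change above sizes channel_num_ by the weight's column count), and activation_quant_method=MAX_MIN derives each activation's parameters from the minimum and maximum observed on the calibration data. A small hedged sketch of that MAX_MIN mapping follows; it is generic affine int8 quantization, not the converter's actual implementation, and ComputeMaxMinQuantParam and Quantize are illustrative names.

// Generic sketch of MAX_MIN (asymmetric) int8 quantization; illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

struct QuantParam {
  float scale;
  int32_t zero_point;
};

// Map an observed calibration range [min_val, max_val] onto [-128, 127].
QuantParam ComputeMaxMinQuantParam(float min_val, float max_val) {
  min_val = std::min(min_val, 0.0f);  // the representable range must include 0
  max_val = std::max(max_val, 0.0f);
  constexpr float kQmin = -128.0f;
  constexpr float kQmax = 127.0f;
  QuantParam q;
  q.scale = (max_val - min_val) / (kQmax - kQmin);
  if (q.scale == 0.0f) {
    q.scale = 1.0f;  // degenerate all-zero range
  }
  q.zero_point = static_cast<int32_t>(std::round(kQmin - min_val / q.scale));
  return q;
}

int8_t Quantize(float x, const QuantParam &q) {
  int32_t v = static_cast<int32_t>(std::round(x / q.scale)) + q.zero_point;
  return static_cast<int8_t>(std::max(-128, std::min(127, v)));
}

int main() {
  // Example: an activation observed in [-0.5, 3.5] during calibration.
  QuantParam q = ComputeMaxMinQuantParam(-0.5f, 3.5f);
  std::printf("scale=%f zero_point=%d q(3.5)=%d\n", q.scale, q.zero_point,
              static_cast<int>(Quantize(3.5f, q)));
  return 0;
}

With per_channel=true the same computation is applied once per output channel of the weight instead of once per tensor.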