forked from mindspore-Ecosystem/mindspore
!48766 [BUG] 修复Matmul 量化算子在动态shape下权重为空的问题
Merge pull request !48766 from douzhixing/fix-matmul
This commit is contained in:
commit
d9cbe819be
|
@ -179,18 +179,21 @@ void MatmulBaseInt8CPUKernel::FreeQuantParam() {
|
|||
free(quant_param_);
|
||||
quant_param_ = nullptr;
|
||||
}
|
||||
|
||||
if (save_b_const_ != nullptr) {
|
||||
free(save_b_const_);
|
||||
save_b_const_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
int MatmulBaseInt8CPUKernel::MallocQuantParam() {
|
||||
auto weight_tensor = in_tensors_.at(1);
|
||||
auto weight_quant_params = weight_tensor->quant_params();
|
||||
auto w_shape = weight_tensor->shape();
|
||||
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
|
||||
int col = param_->b_transpose_ ? w_shape[w_shape.size() - DIMENSION_2D] : w_shape[w_shape.size() - 1];
|
||||
|
||||
MS_CHECK_TRUE_MSG(weight_quant_params.size() >= 1, lite::RET_ERROR, "weight quant params size should >= 1");
|
||||
filter_per_channel_ = (weight_quant_params.size() > 1);
|
||||
channel_num_ = weight_quant_params.size();
|
||||
|
||||
channel_num_ = filter_per_channel_ ? col : 1;
|
||||
const int &init_size = channel_num_;
|
||||
|
||||
quant_param_ = reinterpret_cast<MatmulQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
|
||||
|
@ -348,7 +351,8 @@ void MatmulBaseInt8CPUKernel::FreeTmpBuffer() {
|
|||
}
|
||||
|
||||
int MatmulBaseInt8CPUKernel::TransferB() {
|
||||
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
|
||||
auto weight_data = (save_b_const_ == nullptr) ? reinterpret_cast<int8_t *>(in_tensors_.at(1)->data())
|
||||
: reinterpret_cast<int8_t *>(save_b_const_);
|
||||
CHECK_NULL_RETURN(weight_data);
|
||||
CHECK_NULL_RETURN(b_pack_func_);
|
||||
for (int i = 0; i < param_->batch; i++) {
|
||||
|
@ -365,6 +369,10 @@ int MatmulBaseInt8CPUKernel::TransferB() {
|
|||
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
|
||||
}
|
||||
}
|
||||
if (save_b_const_ != nullptr) {
|
||||
free(save_b_const_);
|
||||
save_b_const_ = nullptr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -407,7 +415,7 @@ int MatmulBaseInt8CPUKernel::InitBias() {
|
|||
FreeTmpBuffer();
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
bias_ptr_ = reinterpret_cast<int *>(bias_tensor->data());
|
||||
bias_ptr_ = reinterpret_cast<int *>(malloc(bias_tensor->ElementsNum() * sizeof(int)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory allocation failed";
|
||||
FreeTmpBuffer();
|
||||
|
@ -438,7 +446,15 @@ int MatmulBaseInt8CPUKernel::Prepare() {
|
|||
FreeQuantParam();
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
if (param_->b_const_) {
|
||||
auto weight_tensor = in_tensors_.at(1);
|
||||
CHECK_NULL_RETURN(weight_tensor);
|
||||
CHECK_NULL_RETURN(weight_tensor->data());
|
||||
save_b_const_ = reinterpret_cast<int8_t *>(malloc(weight_tensor->ElementsNum() * sizeof(int8_t)));
|
||||
(void)memcpy(save_b_const_, weight_tensor->data(), weight_tensor->ElementsNum() * sizeof(int8_t));
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
int MatmulBaseInt8CPUKernel::MatmulReSize() {
|
||||
|
@ -462,7 +478,7 @@ int MatmulBaseInt8CPUKernel::ReSize() {
|
|||
return ret;
|
||||
}
|
||||
|
||||
if (param_->b_const_ == true) {
|
||||
if (param_->b_const_) {
|
||||
if (TransferB() != RET_OK) {
|
||||
MS_LOG(ERROR) << "TransferB error";
|
||||
return RET_ERROR;
|
||||
|
@ -509,7 +525,7 @@ int MatmulBaseInt8CPUKernel::Run() {
|
|||
return RunArm64Sdot();
|
||||
}
|
||||
#endif
|
||||
if (param_->b_const_ == false) {
|
||||
if (!param_->b_const_) {
|
||||
if (TransferB() != RET_OK) {
|
||||
MS_LOG(ERROR) << "TransferB error";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -83,6 +83,7 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
|
|||
int8_t *batch_weight_ptr_ = nullptr;
|
||||
int8_t *batch_b_ptr_ = nullptr;
|
||||
int8_t *batch_c_ptr_ = nullptr;
|
||||
int8_t *save_b_const_ = nullptr;
|
||||
int *batch_sums_ = nullptr;
|
||||
int row_tile_ = C4NUM;
|
||||
int col_tile_ = C4NUM;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
ml_asr_encoder_int8_202103.onnx 17 43591176
|
||||
ml_face_mnet 86 832744
|
||||
ml_face_landmark_2 0.8 472112
|
||||
mobilenet.tflite 0.4 26040
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
[common_quant_param]
|
||||
quant_type=FULL_QUANT
|
||||
bit_num=8
|
||||
|
||||
[data_preprocess_param]
|
||||
calibrate_path=featinput:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_asr_encoder_int8_202103_calibration_data
|
||||
calibrate_size=1
|
||||
input_type=BIN
|
||||
|
||||
|
||||
[full_quant_param]
|
||||
activation_quant_method=MAX_MIN
|
||||
bias_correction=true
|
||||
per_channel=true
|
Loading…
Reference in New Issue