forked from mindspore-Ecosystem/mindspore
fix matmul bug
This commit is contained in:
parent
43ffa26a3a
commit
5e51fa9fff
|
@ -179,18 +179,21 @@ void MatmulBaseInt8CPUKernel::FreeQuantParam() {
|
||||||
free(quant_param_);
|
free(quant_param_);
|
||||||
quant_param_ = nullptr;
|
quant_param_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (save_b_const_ != nullptr) {
|
||||||
|
free(save_b_const_);
|
||||||
|
save_b_const_ = nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int MatmulBaseInt8CPUKernel::MallocQuantParam() {
|
int MatmulBaseInt8CPUKernel::MallocQuantParam() {
|
||||||
auto weight_tensor = in_tensors_.at(1);
|
auto weight_tensor = in_tensors_.at(1);
|
||||||
auto weight_quant_params = weight_tensor->quant_params();
|
auto weight_quant_params = weight_tensor->quant_params();
|
||||||
auto w_shape = weight_tensor->shape();
|
|
||||||
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
|
|
||||||
int col = param_->b_transpose_ ? w_shape[w_shape.size() - DIMENSION_2D] : w_shape[w_shape.size() - 1];
|
|
||||||
|
|
||||||
|
MS_CHECK_TRUE_MSG(weight_quant_params.size() >= 1, lite::RET_ERROR, "weight quant params size should >= 1");
|
||||||
filter_per_channel_ = (weight_quant_params.size() > 1);
|
filter_per_channel_ = (weight_quant_params.size() > 1);
|
||||||
|
channel_num_ = weight_quant_params.size();
|
||||||
|
|
||||||
channel_num_ = filter_per_channel_ ? col : 1;
|
|
||||||
const int &init_size = channel_num_;
|
const int &init_size = channel_num_;
|
||||||
|
|
||||||
quant_param_ = reinterpret_cast<MatmulQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
|
quant_param_ = reinterpret_cast<MatmulQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
|
||||||
|
@ -348,7 +351,8 @@ void MatmulBaseInt8CPUKernel::FreeTmpBuffer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int MatmulBaseInt8CPUKernel::TransferB() {
|
int MatmulBaseInt8CPUKernel::TransferB() {
|
||||||
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
|
auto weight_data = (save_b_const_ == nullptr) ? reinterpret_cast<int8_t *>(in_tensors_.at(1)->data())
|
||||||
|
: reinterpret_cast<int8_t *>(save_b_const_);
|
||||||
CHECK_NULL_RETURN(weight_data);
|
CHECK_NULL_RETURN(weight_data);
|
||||||
CHECK_NULL_RETURN(b_pack_func_);
|
CHECK_NULL_RETURN(b_pack_func_);
|
||||||
for (int i = 0; i < param_->batch; i++) {
|
for (int i = 0; i < param_->batch; i++) {
|
||||||
|
@ -365,6 +369,10 @@ int MatmulBaseInt8CPUKernel::TransferB() {
|
||||||
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
|
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (save_b_const_ != nullptr) {
|
||||||
|
free(save_b_const_);
|
||||||
|
save_b_const_ = nullptr;
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,7 +415,7 @@ int MatmulBaseInt8CPUKernel::InitBias() {
|
||||||
FreeTmpBuffer();
|
FreeTmpBuffer();
|
||||||
return RET_MEMORY_FAILED;
|
return RET_MEMORY_FAILED;
|
||||||
}
|
}
|
||||||
bias_ptr_ = reinterpret_cast<int *>(bias_tensor->data());
|
bias_ptr_ = reinterpret_cast<int *>(malloc(bias_tensor->ElementsNum() * sizeof(int)));
|
||||||
if (bias_ptr_ == nullptr) {
|
if (bias_ptr_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Memory allocation failed";
|
MS_LOG(ERROR) << "Memory allocation failed";
|
||||||
FreeTmpBuffer();
|
FreeTmpBuffer();
|
||||||
|
@ -438,7 +446,15 @@ int MatmulBaseInt8CPUKernel::Prepare() {
|
||||||
FreeQuantParam();
|
FreeQuantParam();
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
if (!InferShapeDone()) {
|
||||||
|
if (param_->b_const_) {
|
||||||
|
auto weight_tensor = in_tensors_.at(1);
|
||||||
|
CHECK_NULL_RETURN(weight_tensor);
|
||||||
|
CHECK_NULL_RETURN(weight_tensor->data());
|
||||||
|
save_b_const_ = reinterpret_cast<int8_t *>(malloc(weight_tensor->ElementsNum() * sizeof(int8_t)));
|
||||||
|
(void)memcpy(save_b_const_, weight_tensor->data(), weight_tensor->ElementsNum() * sizeof(int8_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
int MatmulBaseInt8CPUKernel::MatmulReSize() {
|
int MatmulBaseInt8CPUKernel::MatmulReSize() {
|
||||||
|
@ -462,7 +478,7 @@ int MatmulBaseInt8CPUKernel::ReSize() {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (param_->b_const_ == true) {
|
if (param_->b_const_) {
|
||||||
if (TransferB() != RET_OK) {
|
if (TransferB() != RET_OK) {
|
||||||
MS_LOG(ERROR) << "TransferB error";
|
MS_LOG(ERROR) << "TransferB error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -509,7 +525,7 @@ int MatmulBaseInt8CPUKernel::Run() {
|
||||||
return RunArm64Sdot();
|
return RunArm64Sdot();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
if (param_->b_const_ == false) {
|
if (!param_->b_const_) {
|
||||||
if (TransferB() != RET_OK) {
|
if (TransferB() != RET_OK) {
|
||||||
MS_LOG(ERROR) << "TransferB error";
|
MS_LOG(ERROR) << "TransferB error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -83,6 +83,7 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
|
||||||
int8_t *batch_weight_ptr_ = nullptr;
|
int8_t *batch_weight_ptr_ = nullptr;
|
||||||
int8_t *batch_b_ptr_ = nullptr;
|
int8_t *batch_b_ptr_ = nullptr;
|
||||||
int8_t *batch_c_ptr_ = nullptr;
|
int8_t *batch_c_ptr_ = nullptr;
|
||||||
|
int8_t *save_b_const_ = nullptr;
|
||||||
int *batch_sums_ = nullptr;
|
int *batch_sums_ = nullptr;
|
||||||
int row_tile_ = C4NUM;
|
int row_tile_ = C4NUM;
|
||||||
int col_tile_ = C4NUM;
|
int col_tile_ = C4NUM;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
ml_asr_encoder_int8_202103.onnx 17 43591176
|
||||||
ml_face_mnet 86 832744
|
ml_face_mnet 86 832744
|
||||||
ml_face_landmark_2 0.8 472112
|
ml_face_landmark_2 0.8 472112
|
||||||
mobilenet.tflite 0.4 26040
|
mobilenet.tflite 0.4 26040
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
[common_quant_param]
|
||||||
|
quant_type=FULL_QUANT
|
||||||
|
bit_num=8
|
||||||
|
|
||||||
|
[data_preprocess_param]
|
||||||
|
calibrate_path=featinput:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_asr_encoder_int8_202103_calibration_data
|
||||||
|
calibrate_size=1
|
||||||
|
input_type=BIN
|
||||||
|
|
||||||
|
|
||||||
|
[full_quant_param]
|
||||||
|
activation_quant_method=MAX_MIN
|
||||||
|
bias_correction=true
|
||||||
|
per_channel=true
|
Loading…
Reference in New Issue