!30438 [MSLITE] check int8 weight tensor

Merge pull request !30438 from ling/clean
This commit is contained in:
i-robot 2022-02-25 02:58:36 +00:00 committed by Gitee
commit 1345de01ed
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 7 deletions

View File

@ -346,13 +346,14 @@ void MatmulBaseInt8CPUKernel::FreeTmpBuffer() {
return;
}
void MatmulBaseInt8CPUKernel::TransferB() {
int MatmulBaseInt8CPUKernel::TransferB() {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
CHECK_NULL_RETURN(weight_data);
CHECK_NULL_RETURN(b_pack_func_);
for (int i = 0; i < param_->batch; i++) {
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
auto current_sums = weight_bias_sums_ + i * param_->col_align_;
MS_CHECK_PTR_IF_NULL(b_pack_func_);
if (param_->b_transpose_) {
b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_,
@ -363,7 +364,7 @@ void MatmulBaseInt8CPUKernel::TransferB() {
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
}
}
return;
return RET_OK;
}
int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
@ -447,7 +448,10 @@ int MatmulBaseInt8CPUKernel::ReSize() {
}
if (param_->b_const_ == true) {
TransferB();
if (TransferB() != RET_OK) {
MS_LOG(ERROR) << "TransferB error";
return RET_ERROR;
}
}
return RET_OK;
}
@ -491,7 +495,10 @@ int MatmulBaseInt8CPUKernel::Run() {
}
#endif
if (param_->b_const_ == false) {
TransferB();
if (TransferB() != RET_OK) {
MS_LOG(ERROR) << "TransferB error";
return RET_ERROR;
}
}
int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
int8_t *c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data());

View File

@ -60,8 +60,7 @@ class MatmulBaseInt8CPUKernel : public InnerKernel {
private:
int InitTmpBuffer();
void FreeTmpBuffer();
void TransferA();
void TransferB();
int TransferB();
private:
int MallocQuantParam();