forked from mindspore-Ecosystem/mindspore
!30438 [MSLITE] check int8 weight tensor
Merge pull request !30438 from ling/clean
This commit is contained in:
commit
1345de01ed
|
@ -346,13 +346,14 @@ void MatmulBaseInt8CPUKernel::FreeTmpBuffer() {
|
|||
return;
|
||||
}
|
||||
|
||||
void MatmulBaseInt8CPUKernel::TransferB() {
|
||||
int MatmulBaseInt8CPUKernel::TransferB() {
|
||||
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
|
||||
CHECK_NULL_RETURN(weight_data);
|
||||
CHECK_NULL_RETURN(b_pack_func_);
|
||||
for (int i = 0; i < param_->batch; i++) {
|
||||
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
|
||||
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
|
||||
auto current_sums = weight_bias_sums_ + i * param_->col_align_;
|
||||
MS_CHECK_PTR_IF_NULL(b_pack_func_);
|
||||
if (param_->b_transpose_) {
|
||||
b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
|
||||
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_,
|
||||
|
@ -363,7 +364,7 @@ void MatmulBaseInt8CPUKernel::TransferB() {
|
|||
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
|
||||
}
|
||||
}
|
||||
return;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
|
||||
|
@ -447,7 +448,10 @@ int MatmulBaseInt8CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
if (param_->b_const_ == true) {
|
||||
TransferB();
|
||||
if (TransferB() != RET_OK) {
|
||||
MS_LOG(ERROR) << "TransferB error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -491,7 +495,10 @@ int MatmulBaseInt8CPUKernel::Run() {
|
|||
}
|
||||
#endif
|
||||
if (param_->b_const_ == false) {
|
||||
TransferB();
|
||||
if (TransferB() != RET_OK) {
|
||||
MS_LOG(ERROR) << "TransferB error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
|
||||
int8_t *c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data());
|
||||
|
|
|
@ -60,8 +60,7 @@ class MatmulBaseInt8CPUKernel : public InnerKernel {
|
|||
private:
|
||||
int InitTmpBuffer();
|
||||
void FreeTmpBuffer();
|
||||
void TransferA();
|
||||
void TransferB();
|
||||
int TransferB();
|
||||
|
||||
private:
|
||||
int MallocQuantParam();
|
||||
|
|
Loading…
Reference in New Issue