!25383 fix matmul perchannel bug

Merge pull request !25383 from zhaozhenlong/lite/issue/matmul-perchannel-fix
This commit is contained in:
i-robot 2021-10-26 03:56:27 +00:00 committed by Gitee
commit 82b3a29c03
2 changed files with 8 additions and 3 deletions

View File

@ -96,11 +96,14 @@ void MatmulBaseInt8CPUKernel::FreeQuantParam() {
int MatmulBaseInt8CPUKernel::MallocQuantParam() {
auto weight_tensor = in_tensors_.at(1);
auto weight_quant_params = weight_tensor->quant_params();
int col = weight_tensor->shape().front();
auto w_shape = weight_tensor->shape();
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
int col = param_->b_transpose_ ? w_shape[w_shape.size() - 2] : w_shape[w_shape.size() - 1];
filter_per_channel_ = (weight_quant_params.size() > 1);
int init_size = filter_per_channel_ ? col : 1;
channel_num_ = filter_per_channel_ ? col : 1;
const int &init_size = channel_num_;
quant_param_ = reinterpret_cast<MatmulQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
if (quant_param_ == nullptr) {
@ -141,8 +144,9 @@ void MatmulBaseInt8CPUKernel::InitQuantParam() {
quant_param_->output_.scale_ = out_quant_params.front().scale;
auto weight_tensor = in_tensors_.at(1);
int weight_quant_num = filter_per_channel_ ? weight_tensor->shape().front() : 1;
const int &weight_quant_num = channel_num_;
auto weight_quant_params = weight_tensor->quant_params();
MS_CHECK_TRUE_RET_VOID(static_cast<int>(weight_quant_params.size()) == weight_quant_num);
for (int i = 0; i < weight_quant_num; i++) {
quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;

View File

@ -77,6 +77,7 @@ class MatmulBaseInt8CPUKernel : public InnerKernel {
int *batch_sums_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;
int channel_num_ = 0;
};
} // namespace mindspore::kernel