!6397 [MSLITE] Fix the bug of int8 matmul weight tensor init

Merge pull request !6397 from zhanyuan/dev
mindspore-ci-bot 2020-09-17 17:24:32 +08:00 committed by Gitee
commit efd0fde424
3 changed files with 103 additions and 59 deletions
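The creator-side change in this commit reduces to one pattern: decide once whether the weight tensor is a constant quantized weight, dequantize it only in that case, and restore the original data pointer on every exit path instead of failing whenever MutableData() returns null. Below is a minimal, self-contained C++ sketch of that guard-and-restore flow; FakeTensor, DequantWeight, and CreateKernel are simplified stand-ins for illustration, not the MindSpore Lite API.

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for lite::Tensor; the real class has a different interface.
struct FakeTensor {
  bool is_int8 = false;
  void *data = nullptr;            // null for a weight that is fed at run time
  void *data_c() { return data; }
  void SetData(void *d) { data = d; }
};

// Pretend dequantization: produce a temporary float buffer from const int8 data.
static float *DequantWeight(FakeTensor *t, size_t count) {
  auto *src = static_cast<int8_t *>(t->data_c());
  auto *dst = new float[count];
  for (size_t i = 0; i < count; ++i) dst[i] = static_cast<float>(src[i]) * 0.1f;
  return dst;
}

// Mirrors the fixed creator logic: compute is_const_quant_weight once and
// restore the original int8 data pointer after the kernel is set up.
bool CreateKernel(FakeTensor *weight, size_t count) {
  void *restore_data = weight->data_c();
  bool is_const_quant_weight = (restore_data != nullptr) && weight->is_int8;
  if (is_const_quant_weight) {
    weight->SetData(DequantWeight(weight, count));  // temporary float view
  }
  bool ok = true;  // kernel construction / Init() would happen here
  if (is_const_quant_weight) {
    delete[] static_cast<float *>(weight->data_c());
    weight->SetData(restore_data);                  // put the int8 data back
  }
  return ok;
}

int main() {
  std::vector<int8_t> w = {1, 2, 3, 4};
  FakeTensor const_weight{true, w.data()};
  FakeTensor runtime_weight{true, nullptr};  // no data yet: previously an error
  std::cout << CreateKernel(&const_weight, w.size()) << " "
            << CreateKernel(&runtime_weight, 0) << std::endl;
  return 0;
}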


@@ -34,12 +34,11 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (restore_data == nullptr) {
MS_LOG(ERROR) << "weight_tensor MutableData is nullptr.";
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
auto *restore_data = weight_tensor->data_c();
auto is_const_quant_weight =
(restore_data != nullptr) &&
(weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant);
if (is_const_quant_weight) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -58,7 +57,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (is_const_quant_weight) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
@@ -69,14 +68,14 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (is_const_quant_weight) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (is_const_quant_weight) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}


@@ -45,26 +45,52 @@ int MatmulInt8CPUKernel::ReSize() {
params_->row_ = o_shape[o_shape.size() - 2];
params_->col_ = o_shape[o_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->col_8_ = UP_ROUND(params_->col_, 8);
r4_ = UP_ROUND(params_->row_, 4);
c4_ = UP_ROUND(params_->col_, 4);
d16_ = UP_ROUND(params_->deep_, 16);
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
params_->row_4_ = UP_ROUND(params_->row_, 4);
params_->col_4_ = UP_ROUND(params_->col_, 4);
params_->deep_16_ = UP_ROUND(params_->deep_, 16);
a_r4x16_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
memset(a_r4x16_ptr_, 0, params_->row_4_ * params_->deep_16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, r4_ * sizeof(int));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!weight_bias_sums_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
memset(input_sums_, 0, params_->row_4_ * sizeof(int));
if (in_tensors_.size() == 3) {
auto bias_size = params_->col_4_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_size));
if (!bias_ptr_) return RET_MEMORY_FAILED;
memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_size);
} else {
bias_ptr_ = NULL;
}
params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
if (params_->b_const_) {
b_c16x4_batch_ = reinterpret_cast<int8_t *>(
ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
weight_bias_sums_batch_ =
reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
for (int i = 0; i < params_->batch; ++i) {
auto cur_b = b_ptr + i * params_->deep_ * params_->col_;
auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor);
} else {
RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor);
}
}
}
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_4_, 4));
thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_);
auto input_tensor = in_tensors_[0];
auto params = input_tensor->GetQuantParams();
@@ -89,23 +115,24 @@ int MatmulInt8CPUKernel::ReSize() {
}
int MatmulInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * params_->deep_16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
auto &p = quant_params_;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
params_->col_ * sizeof(int8_t), false);
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, params_->row_4_, cur_oc * C4NUM, params_->deep_16_, input_sums_,
cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift,
params_->row_, cur_oc_res, params_->col_ * sizeof(int8_t), false);
#else
MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, d16_, params_->col_, input_sums_, cur_bias,
&p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN, INT8_MAX, false);
MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, params_->deep_16_, params_->col_,
input_sums_, cur_bias, &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN,
INT8_MAX, false);
#endif
return RET_OK;
@@ -127,33 +154,47 @@ int MatmulInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData());
auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
auto a_stride = params_->row_ * params_->deep_;
auto b_stride = params_->deep_ * params_->col_;
auto c_stride = params_->row_ * params_->col_;
if (!params_->b_const_) {
b_c16x4_batch_ = reinterpret_cast<int8_t *>(
ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
weight_bias_sums_batch_ =
reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
for (int i = 0; i < params_->batch; ++i) {
auto cur_b = b_ptr + i * b_stride;
auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor);
} else {
RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor);
}
}
}
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
if (params_->a_transpose_) {
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, params_->deep_16_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
} else {
RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
}
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b_ptr, b_c16x4_ptr_, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
NULL, weight_bias_sums_, ColMajor);
} else {
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
NULL, weight_bias_sums_, RowMajor);
}
b_c16x4_ptr_ = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
weight_bias_sums_ = weight_bias_sums_batch_ + i * params_->col_4_;
c_ptr_ = c_ptr + i * c_stride;
ret = ParallelLaunch(this->context_->thread_pool_, MatmulInt8Run, this, thread_count_);
if (ret != RET_OK) {


@@ -43,28 +43,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
if (b_c16x4_batch_ != nullptr) {
ctx_->allocator->Free(b_c16x4_batch_);
b_c16x4_batch_ = nullptr;
}
if (weight_bias_sums_batch_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_batch_);
weight_bias_sums_batch_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
}
}
MatmulQuantArg quant_params_;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int *bias_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int r4_;
int c4_;
int d16_;
int8_t *b_c16x4_batch_ = nullptr;
int *weight_bias_sums_batch_ = nullptr;
}; // namespace mindspore::kernel
} // namespace mindspore::kernel