forked from OSSInnovation/mindspore
!6397 [MSLITE] Fix the bug of int8 matmul weight tensor init
Merge pull request !6397 from zhanyuan/dev
Commit: efd0fde424
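In outline, the kernel creator previously fetched the weight data with MutableData(), bailed out when the pointer was null, and dequantized whenever the kernel was int8 or weight-quantized; the fix reads the pointer with data_c() and only dequantizes (and later restores) when the weight is really a constant quantized tensor, i.e. when it already holds data. Below is a minimal stand-alone sketch of that guard-and-restore pattern; the Tensor struct, DequantWeight() and InitKernel() here are simplified placeholders for illustration, not the MindSpore Lite API.

#include <cstdio>
#include <cstdlib>

// Simplified stand-in for a tensor that may or may not own constant int8 data.
struct Tensor {
  void *data = nullptr;        // quantized (int8) weights, or nullptr for non-const inputs
  bool int8_quantized = false;
};

// Placeholder dequantizer: the real code expands int8 weights into a float buffer.
float *DequantWeight(Tensor *t) {
  return t->data ? static_cast<float *>(malloc(16 * sizeof(float))) : nullptr;
}

bool InitKernel(Tensor *weight) {
  void *restore_data = weight->data;
  // Only dequantize when the weight is a const, quantized tensor that really has data.
  bool is_const_quant_weight = (restore_data != nullptr) && weight->int8_quantized;
  if (is_const_quant_weight) {
    float *dequant = DequantWeight(weight);
    if (dequant == nullptr) return false;
    weight->data = dequant;       // let Init() see float data
  }
  // ... kernel Init() would run here ...
  if (is_const_quant_weight) {
    free(weight->data);           // drop the temporary float copy
    weight->data = restore_data;  // restore the original quantized pointer
  }
  return true;
}

int main() {
  Tensor non_const_weight;        // no data yet: must not be dequantized
  Tensor const_weight;
  const_weight.int8_quantized = true;
  const_weight.data = malloc(16);
  printf("%d %d\n", InitKernel(&non_const_weight), InitKernel(&const_weight));
  free(const_weight.data);
  return 0;
}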
@@ -34,12 +34,11 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
 
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
-  if (restore_data == nullptr) {
-    MS_LOG(ERROR) << "weight_tensor MutableData is nullptr.";
-    return nullptr;
-  }
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  auto *restore_data = weight_tensor->data_c();
+  auto is_const_quant_weight =
+    (restore_data != nullptr) &&
+    (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant);
+  if (is_const_quant_weight) {
     auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
     if (dequant_weight == nullptr) {
       MS_LOG(ERROR) << "dequant data is nullptr.";
@@ -58,7 +57,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
   }
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    if (is_const_quant_weight) {
       weight_tensor->FreeData();
       weight_tensor->SetData(restore_data);
     }
@@ -69,14 +68,14 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    if (is_const_quant_weight) {
      weight_tensor->FreeData();
      weight_tensor->SetData(restore_data);
    }
     return nullptr;
   }
 
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (is_const_quant_weight) {
     weight_tensor->FreeData();
     weight_tensor->SetData(restore_data);
   }
@@ -45,26 +45,52 @@ int MatmulInt8CPUKernel::ReSize() {
   params_->row_ = o_shape[o_shape.size() - 2];
   params_->col_ = o_shape[o_shape.size() - 1];
   params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
-  params_->row_8_ = UP_ROUND(params_->row_, 8);
-  params_->col_8_ = UP_ROUND(params_->col_, 8);
-
-  r4_ = UP_ROUND(params_->row_, 4);
-  c4_ = UP_ROUND(params_->col_, 4);
-  d16_ = UP_ROUND(params_->deep_, 16);
-  a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
+  params_->row_4_ = UP_ROUND(params_->row_, 4);
+  params_->col_4_ = UP_ROUND(params_->col_, 4);
+  params_->deep_16_ = UP_ROUND(params_->deep_, 16);
+  a_r4x16_ptr_ =
+    reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_4_ * params_->deep_16_ * sizeof(int8_t)));
   if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
-  memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
-  b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
-  if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
-  memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
-  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
+  memset(a_r4x16_ptr_, 0, params_->row_4_ * params_->deep_16_ * sizeof(int8_t));
+  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_4_ * sizeof(int)));
   if (!input_sums_) return RET_MEMORY_FAILED;
-  memset(input_sums_, 0, r4_ * sizeof(int));
-  weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
-  if (!weight_bias_sums_) return RET_MEMORY_FAILED;
-  memset(weight_bias_sums_, 0, c4_ * sizeof(int));
-  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
-  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+  memset(input_sums_, 0, params_->row_4_ * sizeof(int));
+  if (in_tensors_.size() == 3) {
+    auto bias_size = params_->col_4_ * sizeof(int);
+    bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_size));
+    if (!bias_ptr_) return RET_MEMORY_FAILED;
+    memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_size);
+  } else {
+    bias_ptr_ = NULL;
+  }
+
+  params_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+  if (params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * params_->deep_ * params_->col_;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_4_, 4));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_4_, 4), thread_count_);
 
   auto input_tensor = in_tensors_[0];
   auto params = input_tensor->GetQuantParams();
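The resize path above sizes its scratch buffers from the padded dimensions: rows and output columns are rounded up to multiples of 4, the reduction depth to a multiple of 16, and the resulting column groups are split across threads. The small stand-alone program below reproduces that arithmetic for a made-up shape, assuming UP_ROUND(x, n) rounds x up to a multiple of n and UP_DIV(x, n) is the ceiling division, which is how such macros are commonly defined; it is an illustration, not MindSpore code.

#include <cstdio>

// Assumed definitions mirroring the UP_ROUND / UP_DIV style macros used above.
static int UpDiv(int x, int n) { return (x + n - 1) / n; }
static int UpRound(int x, int n) { return UpDiv(x, n) * n; }

int main() {
  // Hypothetical shape: a 1-batch matmul with M=10, K=20, N=6 and 4 worker threads.
  int row = 10, deep = 20, col = 6, batch = 1, threads = 4;

  int row_4 = UpRound(row, 4);      // 12: rows padded to the 4-row tile
  int col_4 = UpRound(col, 4);      // 8:  columns padded to the 4-column tile
  int deep_16 = UpRound(deep, 16);  // 32: reduction depth padded to 16

  size_t a_pack = (size_t)row_4 * deep_16;          // packed A buffer, int8 elements
  size_t b_pack = (size_t)batch * col_4 * deep_16;  // packed B buffer per batch, int8 elements
  size_t sums = (size_t)col_4 * sizeof(int);        // weight/bias sums, one int per padded column

  int col_groups = UpDiv(col_4, 4);
  int thread_count = threads < col_groups ? threads : col_groups;
  int thread_stride = UpDiv(col_groups, thread_count);

  printf("a_pack=%zu b_pack=%zu sums=%zu threads=%d stride=%d\n",
         a_pack, b_pack, sums, thread_count, thread_stride);
  return 0;
}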
@@ -89,23 +115,24 @@ int MatmulInt8CPUKernel::ReSize() {
 }
 
 int MatmulInt8CPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_4_, 4) - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
   int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
-  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
+  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * params_->deep_16_;
   auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
   auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
 
   auto &p = quant_params_;
 #ifdef ENABLE_ARM64
-  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
-                   p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
-                   params_->col_ * sizeof(int8_t), false);
+  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, params_->row_4_, cur_oc * C4NUM, params_->deep_16_, input_sums_,
+                   cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift,
+                   params_->row_, cur_oc_res, params_->col_ * sizeof(int8_t), false);
 #else
-  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, d16_, params_->col_, input_sums_, cur_bias,
-                    &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN, INT8_MAX, false);
+  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, params_->deep_16_, params_->col_,
+                    input_sums_, cur_bias, &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN,
+                    INT8_MAX, false);
 #endif
 
   return RET_OK;
@@ -127,33 +154,47 @@ int MatmulInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-  auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData());
-  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+  auto c_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   auto a_stride = params_->row_ * params_->deep_;
   auto b_stride = params_->deep_ * params_->col_;
   auto c_stride = params_->row_ * params_->col_;
 
+  if (!params_->b_const_) {
+    b_c16x4_batch_ = reinterpret_cast<int8_t *>(
+      ctx_->allocator->Malloc(params_->batch * params_->col_4_ * params_->deep_16_ * sizeof(int8_t)));
+    if (!b_c16x4_batch_) return RET_MEMORY_FAILED;
+    weight_bias_sums_batch_ =
+      reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->batch * params_->col_4_ * sizeof(int)));
+    if (!weight_bias_sums_batch_) return RET_MEMORY_FAILED;
+    auto b_ptr = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
+    for (int i = 0; i < params_->batch; ++i) {
+      auto cur_b = b_ptr + i * b_stride;
+      auto cur_b_pack = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+      auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
+      if (params_->b_transpose_) {
+        RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor);
+      } else {
+        RowMajor2Col16x4Major(cur_b, params_->deep_, params_->col_, cur_b_pack, params_->deep_16_);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor);
+      }
+    }
+  }
+
   for (int i = 0; i < params_->batch; ++i) {
     auto cur_a_ptr = a_ptr + i * a_stride;
-    auto cur_b_ptr = b_ptr + i * b_stride;
-
     if (params_->a_transpose_) {
-      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, params_->deep_16_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_);
       CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
-    if (params_->b_transpose_) {
-      RowMajor2Row16x4MajorInt8(cur_b_ptr, b_c16x4_ptr_, params_->col_, params_->deep_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, ColMajor);
-    } else {
-      RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
-      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                         NULL, weight_bias_sums_, RowMajor);
-    }
+    b_c16x4_ptr_ = b_c16x4_batch_ + i * params_->col_4_ * params_->deep_16_;
+    weight_bias_sums_ = weight_bias_sums_batch_ + i * params_->col_4_;
     c_ptr_ = c_ptr + i * c_stride;
     ret = ParallelLaunch(this->context_->thread_pool_, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
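In the Run/RunImpl path above, each batch repoints b_c16x4_ptr_ and weight_bias_sums_ at that batch's slice of the packed buffers, and each parallel task then handles thread_stride_ groups of four padded output columns, clipping the last group back to the real column count via cur_oc_res. The toy program below walks through that slicing with made-up sizes; it only mimics the index arithmetic and is not MindSpore code.

#include <algorithm>
#include <cstdio>

static int UpDiv(int x, int n) { return (x + n - 1) / n; }

int main() {
  // Hypothetical padded geometry: 10 real output columns padded to col_4 = 12.
  const int col = 10, col_4 = 12, deep_16 = 32, thread_count = 2;
  const int thread_stride = UpDiv(UpDiv(col_4, 4), thread_count);  // 4-column groups per task

  for (int task_id = 0; task_id < thread_count; ++task_id) {
    // How many 4-column groups this task owns (may be 0 for trailing tasks).
    int cur_oc = std::min(thread_stride, UpDiv(col_4, 4) - task_id * thread_stride);
    if (cur_oc <= 0) continue;
    // Real (unpadded) columns this task writes, so the padded tail is skipped.
    int cur_oc_res = std::min(thread_stride * 4, col - task_id * thread_stride * 4);
    // Offset of this task's slice inside the packed B buffer (16x4 tiles).
    int b_offset = task_id * thread_stride * 4 * deep_16;
    printf("task %d: %d col groups, %d real cols, B offset %d\n",
           task_id, cur_oc, cur_oc_res, b_offset);
  }
  return 0;
}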
@@ -43,28 +43,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
       ctx_->allocator->Free(a_r4x16_ptr_);
       a_r4x16_ptr_ = nullptr;
     }
-    if (b_c16x4_ptr_ != nullptr) {
-      ctx_->allocator->Free(b_c16x4_ptr_);
-      b_c16x4_ptr_ = nullptr;
-    }
     if (input_sums_ != nullptr) {
       ctx_->allocator->Free(input_sums_);
       input_sums_ = nullptr;
     }
-    if (weight_bias_sums_ != nullptr) {
-      ctx_->allocator->Free(weight_bias_sums_);
-      weight_bias_sums_ = nullptr;
+    if (b_c16x4_batch_ != nullptr) {
+      ctx_->allocator->Free(b_c16x4_batch_);
+      b_c16x4_batch_ = nullptr;
+    }
+    if (weight_bias_sums_batch_ != nullptr) {
+      ctx_->allocator->Free(weight_bias_sums_batch_);
+      weight_bias_sums_batch_ = nullptr;
+    }
+    if (bias_ptr_ != nullptr) {
+      ctx_->allocator->Free(bias_ptr_);
+      bias_ptr_ = nullptr;
     }
   }
   MatmulQuantArg quant_params_;
   int8_t *a_r4x16_ptr_ = nullptr;
   int8_t *b_c16x4_ptr_ = nullptr;
   int8_t *c_ptr_ = nullptr;
+  int *bias_ptr_ = nullptr;
   int *input_sums_ = nullptr;
   int *weight_bias_sums_ = nullptr;
-  int r4_;
-  int c4_;
-  int d16_;
+  int8_t *b_c16x4_batch_ = nullptr;
+  int *weight_bias_sums_batch_ = nullptr;
 };  // namespace mindspore::kernel
 }  // namespace mindspore::kernel
 