diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c
index 0382fb53f64..bc45ea4a766 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.c
+++ b/mindspore/lite/nnacl/int8/matmul_int8.c
@@ -254,19 +254,22 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
 void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                    const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
                    int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) {
-  int col_tile = C4NUM;
-  /* support per-layer && weight per-channel */
-  /* row4x16-major * row16x2-major => (int8)row-major*/
+  /*
+   * row4x16-major * row16x4-major => (int8)row-major
+   * support per-layer && weight per-channel
+   * a_sums is perT : input_row_sum * filter_zp
+   *           perOc : input_row_sum
+   * */
   for (int r = 0; r < row; r++) {
     for (int c = 0; c < col; c++) {
       int r4div = r / C4NUM, r4mod = r % C4NUM;
-      int c4div = c / col_tile, c4mod = c % col_tile;
+      int c4div = c / C4NUM, c4mod = c % C4NUM;
       size_t ci = r * stride + c;
       int32_t value = 0;
       for (int d = 0; d < deep16; d++) {
         int d16div = d / C16NUM, d16mod = d % C16NUM;
         size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
-        size_t bi = c4div * deep16 * col_tile + d16div * col_tile * C16NUM + c4mod * C16NUM + d16mod;
+        size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
         value = value + a[ai] * b[bi];
       }
       int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r];
@@ -568,8 +571,8 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, Dat
 }
 
 // dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, const int *bias, int *dst,
-                        DataOrder order) {
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
+                        DataOrder order, bool filter_per_channel) {
   for (int c = 0; c < col; ++c) {
     int sum = 0;
     for (int r = 0; r < row; ++r) {
@@ -579,6 +582,7 @@ void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weig
         sum += weight[c * row + r];
      }
    }
+    int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0];
     dst[c] = row * input_zp * weight_zp - input_zp * sum;
     if (bias != NULL) {
       dst[c] += bias[c];
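/*
 * Illustrative sketch (not part of the patch): the zero-point algebra that the
 * a_sums / weight_bias_sums pair implements.  For one output element,
 *   sum_d (a[d] - a_zp) * (w[d] - w_zp)
 *     = sum_d a[d] * w[d]        <- the raw int8 dot product ("value" above)
 *     - w_zp * sum_d a[d]        <- cur_input_sum (a_sums[r], or a_sums[r] * filter_zp[c] per-channel)
 *     - a_zp * sum_d w[d]
 *     + deep * a_zp * w_zp       <- the last two terms are precomputed by CalcWeightBiasSums.
 * The standalone check below verifies the identity on fixed data; every name in
 * it is local to this sketch.
 */
#include <assert.h>
#include <stdint.h>

static void ZeroPointFoldingCheck(void) {
  const int deep = 4;
  const int8_t a[4] = {12, -7, 3, 90};
  const int8_t w[4] = {-5, 33, 17, -2};
  const int32_t a_zp = 3, w_zp = -1;

  int32_t direct = 0, raw = 0, a_sum = 0, w_sum = 0;
  for (int d = 0; d < deep; ++d) {
    direct += (a[d] - a_zp) * (w[d] - w_zp);
    raw += a[d] * w[d];
    a_sum += a[d];
    w_sum += w[d];
  }
  const int32_t input_sum = a_sum * w_zp;                            /* subtracted per output element */
  const int32_t weight_bias_sum = deep * a_zp * w_zp - a_zp * w_sum; /* precomputed once per column   */
  assert(direct == raw - input_sum + weight_bias_sum);
}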
diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index 3a0b008fbd2..65316338a16 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -35,8 +35,11 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
 void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
 void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, const int *bias, int *dst,
-                        DataOrder order);
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
+                        DataOrder order, bool filter_per_channel);
+void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
+                   const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
+                   int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
 
 /* 8x4 4x8 -> 8x8 */
 void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
@@ -60,9 +63,6 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
                        size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                        int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
                        size_t per_channel, int32_t *filter_zp);
-void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
-                   const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
-                   int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
 
 #ifdef ENABLE_ARM64
 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h
index 51aadae35e7..17a8a12955f 100644
--- a/mindspore/lite/nnacl/matmul_parameter.h
+++ b/mindspore/lite/nnacl/matmul_parameter.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_LITE_NNACL_MATMUL_H_
 
 #include "nnacl/op_base.h"
+#include "nnacl/quantization/quantize.h"
 
 typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
                                    const int *input_sum, const int *bias);
@@ -60,4 +61,16 @@ typedef struct MatMulParameter {
   ActType act_type_;
 } MatMulParameter;
 
+typedef struct MatmulQuantParameter {
+  QuantArg input_;
+  QuantArg output_;
+  int32_t out_act_min_;
+  int32_t out_act_max_;
+  float *filter_scale_;
+  int32_t *filter_zp_;
+  int32_t *left_shift_;
+  int32_t *right_shift_;
+  int32_t *quant_multiplier_;
+} MatmulQuantParameter;
+
 #endif  // MINDSPORE_LITE_NNACL_MATMUL_H_
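/*
 * Illustrative sketch (not part of the patch): how the new MatmulQuantParameter
 * is meant to be sized.  The pointer members hold one entry for per-layer
 * quantization and one entry per output channel for per-channel quantization,
 * mirroring MallocQuantParam() in the fullconnection kernel below.  The helper
 * name and the simplified error handling are assumptions of this sketch.
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include "nnacl/matmul_parameter.h"

static int AllocMatmulQuantSketch(MatmulQuantParameter *quant, int col, bool filter_per_channel) {
  int init_size = filter_per_channel ? col : 1; /* one requant set per channel, or a single shared set */
  quant->filter_scale_ = (float *)malloc(init_size * sizeof(float));
  quant->filter_zp_ = (int32_t *)malloc(init_size * sizeof(int32_t));
  quant->left_shift_ = (int32_t *)malloc(init_size * sizeof(int32_t));
  quant->right_shift_ = (int32_t *)malloc(init_size * sizeof(int32_t));
  quant->quant_multiplier_ = (int32_t *)malloc(init_size * sizeof(int32_t));
  if (quant->filter_scale_ == NULL || quant->filter_zp_ == NULL || quant->left_shift_ == NULL ||
      quant->right_shift_ == NULL || quant->quant_multiplier_ == NULL) {
    return -1; /* caller is expected to free whatever was allocated */
  }
  return 0;
}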
"nnacl/common_func.h" #include "src/runtime/runtime_api.h" -#include "include/errorcode.h" #include "src/kernel_registry.h" -using mindspore::lite::RET_MEMORY_FAILED; -using mindspore::lite::RET_OK; - using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_FullConnection; namespace mindspore::kernel { +void FullconnectionInt8CPUKernel::FreeQuantParam() { + if (quant_.filter_scale_ != nullptr) { + free(quant_.filter_scale_); + quant_.filter_scale_ = nullptr; + } + if (quant_.filter_zp_ != nullptr) { + free(quant_.filter_zp_); + quant_.filter_zp_ = nullptr; + } + if (quant_.left_shift_ != nullptr) { + free(quant_.left_shift_); + quant_.left_shift_ = nullptr; + } + if (quant_.right_shift_ != nullptr) { + free(quant_.right_shift_); + quant_.right_shift_ = nullptr; + } + if (quant_.quant_multiplier_ != nullptr) { + free(quant_.quant_multiplier_); + quant_.quant_multiplier_ = nullptr; + } + return; +} + +void FullconnectionInt8CPUKernel::FreeTmpBuffer() { + if (pack_a_ptr_ != nullptr) { + free(pack_a_ptr_); + pack_a_ptr_ = nullptr; + } + if (pack_b_ptr_ != nullptr) { + free(pack_b_ptr_); + pack_b_ptr_ = nullptr; + } + if (input_sums_ != nullptr) { + free(input_sums_); + input_sums_ = nullptr; + } + if (weight_bias_sums_ != nullptr) { + free(weight_bias_sums_); + weight_bias_sums_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } + return; +} + +int FullconnectionInt8CPUKernel::MallocQuantParam() { + auto weight_tensor = in_tensors_[1]; + auto weight_quant_params = weight_tensor->quant_params(); + int col = weight_tensor->shape().front(); + filter_per_channel_ = (weight_quant_params.size() > 1); + + int init_size = filter_per_channel_ ? col : 1; + + quant_.filter_scale_ = reinterpret_cast(malloc(init_size * sizeof(float))); + if (quant_.filter_scale_ == nullptr) { + return RET_ERROR; + } + quant_.filter_zp_ = reinterpret_cast(malloc(init_size * sizeof(int32_t))); + if (quant_.filter_zp_ == nullptr) { + return RET_ERROR; + } + quant_.left_shift_ = reinterpret_cast(malloc(init_size * sizeof(int32_t))); + if (quant_.left_shift_ == nullptr) { + return RET_ERROR; + } + quant_.right_shift_ = reinterpret_cast(malloc(init_size * sizeof(int32_t))); + if (quant_.right_shift_ == nullptr) { + return RET_ERROR; + } + quant_.quant_multiplier_ = reinterpret_cast(malloc(init_size * sizeof(int32_t))); + if (quant_.quant_multiplier_ == nullptr) { + return RET_ERROR; + } + return RET_OK; +} + int FullconnectionInt8CPUKernel::Init() { + auto ret = MallocQuantParam(); + if (ret != RET_OK) { + FreeQuantParam(); + return ret; + } + + auto in_quant_params = in_tensors_[0]->quant_params(); + quant_.input_.zp_ = in_quant_params.front().zeroPoint; + quant_.input_.scale_ = in_quant_params.front().scale; + + auto out_quant_params = out_tensors_[0]->quant_params(); + quant_.output_.zp_ = out_quant_params.front().zeroPoint; + quant_.output_.scale_ = out_quant_params.front().scale; + + auto weight_tensor = in_tensors_[1]; + fc_param_->b_const_ = (weight_tensor->data_c() != nullptr); + int weight_quant_num = filter_per_channel_ ? 
+  auto weight_quant_params = weight_tensor->quant_params();
+
+  for (int i = 0; i < weight_quant_num; i++) {
+    quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint;
+    quant_.filter_scale_[i] = weight_quant_params[i].scale;
+  }
+
+  for (int i = 0; i < weight_quant_num; ++i) {
+    const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]);
+    double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_);
+    QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
+                           &quant_.right_shift_[i]);
+  }
+
+  CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
+                                    quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_,
+                                    &quant_.out_act_max_);
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
   return ReSize();
 }
 
-int FullconnectionInt8CPUKernel::ReSize() {
-  FreeTmpBuffer();
+void FullconnectionInt8CPUKernel::InitParam() {
   int row = 1;
-  for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) row *= (out_tensors_[0]->shape())[i];
+  for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) {
+    row *= (out_tensors_[0]->shape())[i];
+  }
   fc_param_->row_ = row;
   fc_param_->col_ = out_tensors_[0]->shape().back();
   fc_param_->deep_ = (in_tensors_[1]->shape())[1];
-  fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
-  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
-  r4_ = UP_ROUND(fc_param_->row_, 4);
-  c4_ = UP_ROUND(fc_param_->col_, 4);
-  d16_ = UP_ROUND(fc_param_->deep_, 16);
-  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
-  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+  fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
+  fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM);
+  fc_param_->col_2_ = UP_ROUND(fc_param_->col_, C2NUM);
+  fc_param_->col_4_ = UP_ROUND(fc_param_->col_, C4NUM);
+  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
+  fc_param_->col_16_ = UP_ROUND(fc_param_->col_, C16NUM);
+  fc_param_->deep_4_ = UP_ROUND(fc_param_->deep_, C4NUM);
+  fc_param_->deep_16_ = UP_ROUND(fc_param_->deep_, C16NUM);
 
-  a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
-  b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
-  input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
-  weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
-  if (a_r4x16_ptr_ == nullptr || b_c16x4_ptr_ == nullptr || input_sums_ == nullptr || weight_bias_sums_ == nullptr) {
-    MS_LOG(ERROR) << "Memory allocation failed";
+  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(fc_param_->col_4_, C4NUM));
+  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_4_, C4NUM), thread_count_);
+  return;
+}
+
+int FullconnectionInt8CPUKernel::ReSize() {
+  FreeTmpBuffer();
+
+  InitParam();
+
+  pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t)));
+  if (pack_a_ptr_ == nullptr) {
     FreeTmpBuffer();
-    return RET_MEMORY_FAILED;
+    return RET_ERROR;
   }
-  memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
-  memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
-  memset(input_sums_, 0, r4_ * sizeof(int));
-  memset(weight_bias_sums_, 0, c4_ * sizeof(int));
+  pack_b_ptr_ = reinterpret_cast<int8_t *>(malloc(fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t)));
+  if (pack_b_ptr_ == nullptr) {
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  input_sums_ = reinterpret_cast<int *>(malloc(fc_param_->row_4_ * sizeof(int)));
+  if (input_sums_ == nullptr) {
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+  weight_bias_sums_ = reinterpret_cast<int *>(malloc(fc_param_->col_4_ * sizeof(int)));
+  if (weight_bias_sums_ == nullptr) {
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+
+  memset(pack_a_ptr_, 0, fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t));
+  memset(pack_b_ptr_, 0, fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t));
+  memset(input_sums_, 0, fc_param_->row_4_ * sizeof(int));
+  memset(weight_bias_sums_, 0, fc_param_->col_4_ * sizeof(int));
 
   if (in_tensors_.size() == 3) {
-    auto bias_len = fc_param_->col_8_ * sizeof(int);
-    bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
+    bias_ptr_ = reinterpret_cast<int *>(malloc(fc_param_->col_4_ * sizeof(int)));
     if (bias_ptr_ == nullptr) {
       MS_LOG(ERROR) << "Memory allocation failed";
       FreeTmpBuffer();
       return RET_MEMORY_FAILED;
     }
-    memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_len);
+    memcpy(bias_ptr_, in_tensors_[2]->data_c(), fc_param_->col_ * sizeof(int));
   } else {
     bias_ptr_ = nullptr;
   }
 
-  auto input_tensor = in_tensors_[0];
-  auto params = input_tensor->quant_params();
-  MS_ASSERT(params.size() == 1);
-  quant_params_.input.zp_ = params.front().zeroPoint;
-  quant_params_.input.scale_ = params.front().scale;
-  auto weight_tensor = in_tensors_[1];
-  params = weight_tensor->quant_params();
-  MS_ASSERT(params.size() == 1);
-  quant_params_.weight.zp_ = params.front().zeroPoint;
-  quant_params_.weight.scale_ = params.front().scale;
-  auto output_tensor = out_tensors_[0];
-  params = output_tensor->quant_params();
-  MS_ASSERT(params.size() == 1);
-  quant_params_.output.zp_ = params.front().zeroPoint;
-  quant_params_.output.scale_ = params.front().scale;
-
-  double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
-  QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
-                         &quant_params_.right_shift);
-  CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
-                                    quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
-                                    &quant_params_.out_act_max);
-  fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
   if (fc_param_->b_const_) {
     auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
-    RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_);
-    CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_,
-                       quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor);
+    RowMajor2Row16x4MajorInt8(weight_data, pack_b_ptr_, fc_param_->col_, fc_param_->deep_);
+    CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
+                       weight_bias_sums_, ColMajor, filter_per_channel_);
   }
   return RET_OK;
 }
 
 int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  int stride = thread_stride_ * C4NUM;
+  int cur_stride = task_id * stride;
+  int res_stride = fc_param_->col_ - cur_stride;
+  int cur_oc = MSMIN(stride, res_stride);
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  int cur_oc_res = MSMIN(thread_stride_ * C4NUM, fc_param_->col_ - task_id * thread_stride_ * C4NUM);
-  auto &q = quant_params_;
-  auto &p = fc_param_;
-  auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_;
-  auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM;
-  auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
-  auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
-#ifdef ENABLE_ARM64
-  MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
-                   q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res,
-                   p->col_ * sizeof(int8_t), 0);
-#else
-  MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, p->row_, cur_oc_res, d16_, p->col_, input_sums_, cur_bias,
-                    &q.left_shift, &q.right_shift, &q.quant_multiplier, q.output.zp_, INT8_MIN, INT8_MAX, false);
-#endif
+
+  int32_t *cur_left = filter_per_channel_ ? quant_.left_shift_ + cur_stride : quant_.left_shift_;
+  int32_t *cur_right = filter_per_channel_ ? quant_.right_shift_ + cur_stride : quant_.right_shift_;
+  int32_t *cur_mul = filter_per_channel_ ? quant_.quant_multiplier_ + cur_stride : quant_.quant_multiplier_;
+  int32_t *cur_zp = filter_per_channel_ ? quant_.filter_zp_ + cur_stride : quant_.filter_zp_;
+
+  MatmulInt8Opt(pack_a_ptr_, pack_b_ptr_ + cur_stride * fc_param_->deep_16_, c_ptr_ + cur_stride, fc_param_->row_,
+                cur_oc, fc_param_->deep_16_, input_sums_, weight_bias_sums_ + cur_stride, quant_.out_act_min_,
+                quant_.out_act_max_, quant_.output_.zp_, cur_mul, cur_left, cur_right, fc_param_->col_,
+                filter_per_channel_, cur_zp);
   return RET_OK;
 }
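/*
 * Illustrative sketch (not part of the patch): how RunImpl above splits the
 * output columns across threads.  Each task owns thread_stride_ * C4NUM
 * columns, and the per-channel requant arrays (multiplier, shifts, filter_zp)
 * are advanced by the same column offset only when filter_per_channel_ is set.
 * The helper, macro and example values below are made up for illustration.
 */
#include <stdint.h>

#define SKETCH_C4NUM 4

static int SliceColumnsSketch(int task_id, int thread_stride, int col, int *start) {
  int stride = thread_stride * SKETCH_C4NUM;
  *start = task_id * stride;
  int remain = col - *start;
  int count = remain < stride ? remain : stride;
  return count > 0 ? count : 0; /* number of columns this task computes */
}
/* e.g. col = 10, thread_stride = 1: task 0 -> start 0, 4 cols; task 1 -> 4, 4 cols; task 2 -> 8, 2 cols */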
@@ -148,14 +255,19 @@ int FcInt8Run(void *cdata, int task_id) {
 
 int FullconnectionInt8CPUKernel::Run() {
   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
-  RowMajor2Row16x4MajorInt8(input_ptr, a_r4x16_ptr_, fc_param_->row_, fc_param_->deep_);
-  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
+  RowMajor2Row16x4MajorInt8(input_ptr, pack_a_ptr_, fc_param_->row_, fc_param_->deep_);
+
+  int32_t tmp_weight_zp = filter_per_channel_ ? 1 : quant_.filter_zp_[0];
+  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, tmp_weight_zp, input_sums_, RowMajor);
+
   if (!fc_param_->b_const_) {
     auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
-    RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_);
-    CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_,
-                       quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor);
+    RowMajor2Row16x4MajorInt8(weight_data, pack_b_ptr_, fc_param_->col_, fc_param_->deep_);
+    CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
+                       weight_bias_sums_, ColMajor, filter_per_channel_);
   }
+
+  c_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   auto ret = ParallelLaunch(this->context_->thread_pool_, FcInt8Run, this, thread_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ParallelLaunch failed";
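/*
 * Illustrative sketch (not part of the patch): what input_sums_ holds in the
 * two quantization modes set up in Run() above.  Per-layer weights fold the
 * single weight zero point into the row sums up front (tmp_weight_zp =
 * filter_zp_[0]); per-channel weights keep plain row sums (tmp_weight_zp = 1)
 * and MatmulInt8Opt multiplies by filter_zp[c] per output column.  The helper
 * is a standalone re-statement, not the library's CalcInputSums.
 */
#include <stdbool.h>
#include <stdint.h>

static void InputRowSumsSketch(const int8_t *input, int row, int deep, int32_t weight_zp, bool filter_per_channel,
                               int32_t *a_sums) {
  for (int r = 0; r < row; ++r) {
    int32_t sum = 0;
    for (int d = 0; d < deep; ++d) {
      sum += input[r * deep + d]; /* row-major activations */
    }
    a_sums[r] = filter_per_channel ? sum : sum * weight_zp; /* per-channel defers the zp to the kernel */
  }
}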
@@ -163,6 +275,7 @@ int FullconnectionInt8CPUKernel::Run() {
   }
   return RET_OK;
 }
+
 kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                        const std::vector<lite::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::InnerContext *ctx,
@@ -185,5 +298,4 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
-#include "src/runtime/kernel/arm/base/fullconnection_base.h"
-#include "include/context.h"
+#include "src/lite_kernel.h"
+#include "include/errorcode.h"
 #include "nnacl/quantization/quantize.h"
+#include "nnacl/common_func.h"
 #include "nnacl/int8/common_func_int8.h"
-
-using mindspore::lite::InnerContext;
+#include "nnacl/int8/matmul_int8.h"
 
 namespace mindspore::kernel {
-class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
+class FullconnectionInt8CPUKernel : public LiteKernel {
  public:
   FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                              const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
+                              const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx,
                               const mindspore::lite::PrimitiveC *primitive)
-      : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~FullconnectionInt8CPUKernel() override { FreeTmpBuffer(); }
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    fc_param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
+  }
+  ~FullconnectionInt8CPUKernel() override {
+    FreeTmpBuffer();
+    FreeQuantParam();
+  }
 
   int Init() override;
   int ReSize() override;
   int Run() override;
+
+ public:
   int RunImpl(int task_id);
 
  private:
-  void FreeTmpBuffer() {
-    if (a_r4x16_ptr_ != nullptr) {
-      ctx_->allocator->Free(a_r4x16_ptr_);
-      a_r4x16_ptr_ = nullptr;
-    }
-    if (b_c16x4_ptr_ != nullptr) {
-      ctx_->allocator->Free(b_c16x4_ptr_);
-      b_c16x4_ptr_ = nullptr;
-    }
-    if (input_sums_ != nullptr) {
-      ctx_->allocator->Free(input_sums_);
-      input_sums_ = nullptr;
-    }
-    if (weight_bias_sums_ != nullptr) {
-      ctx_->allocator->Free(weight_bias_sums_);
-      weight_bias_sums_ = nullptr;
-    }
-    if (bias_ptr_ != nullptr) {
-      ctx_->allocator->Free(weight_bias_sums_);
-      weight_bias_sums_ = nullptr;
-    }
-  }
-  MatmulQuantArg quant_params_;
-  int8_t *a_r4x16_ptr_ = nullptr;
-  int8_t *b_c16x4_ptr_ = nullptr;
+  void InitParam();
+  void FreeTmpBuffer();
+  void FreeQuantParam();
+  int MallocQuantParam();
+
+ private:
+  MatMulParameter *fc_param_ = nullptr;
+  MatmulQuantParameter quant_;
+  int thread_count_ = 1;
+  int thread_stride_ = 0;
+  int8_t *pack_a_ptr_ = nullptr;
+  int8_t *pack_b_ptr_ = nullptr;
+  int8_t *c_ptr_ = nullptr;
   int *input_sums_ = nullptr;
   int *weight_bias_sums_ = nullptr;
   int *bias_ptr_ = nullptr;
-  int r4_ = 0;
-  int c4_ = 0;
-  int d16_ = 0;
+  bool filter_per_channel_ = true;
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index ebea63b8a3a..a8e88edebe3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -102,12 +102,12 @@ int MatmulInt8CPUKernel::ReSize() {
       auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
       if (params_->b_transpose_) {
         RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
-        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                           bias_ptr_, cur_sums, ColMajor);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor, false);
       } else {
         RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
-        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                           bias_ptr_, cur_sums, RowMajor);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor, false);
       }
     }
   }
@@ -166,12 +166,12 @@ int MatmulInt8CPUKernel::Run() {
       auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
       if (params_->b_transpose_) {
         RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
-        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                           bias_ptr_, cur_sums, ColMajor);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, ColMajor, false);
       } else {
         RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
-        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                           bias_ptr_, cur_sums, RowMajor);
+        CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
+                           bias_ptr_, cur_sums, RowMajor, false);
       }
     }
   }
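/*
 * Illustrative sketch (not part of the patch): the per-layer calling pattern
 * that matmul_int8.cc keeps after the signature change.  A caller with a single
 * weight zero point passes its address and filter_per_channel = false, so
 * CalcWeightBiasSums reads weight_zp_ptr[0] for every output column.  The
 * wrapper name and parameters below are assumptions of this sketch.
 */
#include <stdint.h>
#include "nnacl/int8/matmul_int8.h"

void PerLayerWeightBiasSumsSketch(int8_t *weight, int deep, int col, int input_zp, int weight_zp, const int *bias,
                                  int *weight_bias_sums) {
  /* one shared zero point: pass &weight_zp and filter_per_channel = false */
  CalcWeightBiasSums(weight, deep, col, input_zp, &weight_zp, bias, weight_bias_sums, ColMajor, false);
}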