[MSLITE] fullconnection int8 suppot perchannel

This commit is contained in:
ling 2020-11-25 16:59:20 +08:00
parent c78683a411
commit 6de20db54a
6 changed files with 256 additions and 134 deletions

View File

@ -254,19 +254,22 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) {
int col_tile = C4NUM;
/* support per-layer && weight per-channel */
/* row4x16-major * row16x2-major => (int8)row-major*/
/*
* row4x16-major * row16x4-major => (int8)row-major
* support per-layer && weight per-channel
* a_sums is perT : input_row_sum * filter_zp
* perOc : input_row_sum
* */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / col_tile, c4mod = c % col_tile;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * col_tile + d16div * col_tile * C16NUM + c4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r];
@ -568,8 +571,8 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, Dat
}
// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, const int *bias, int *dst,
DataOrder order) {
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
DataOrder order, bool filter_per_channel) {
for (int c = 0; c < col; ++c) {
int sum = 0;
for (int r = 0; r < row; ++r) {
@ -579,6 +582,7 @@ void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weig
sum += weight[c * row + r];
}
}
int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0];
dst[c] = row * input_zp * weight_zp - input_zp * sum;
if (bias != NULL) {
dst[c] += bias[c];

View File

@ -35,8 +35,11 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst);
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, const int *bias, int *dst,
DataOrder order);
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int *weight_zp_ptr, const int *bias, int *dst,
DataOrder order, bool filter_per_channel);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
/* 8x4 4x8 -> 8x8 */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
@ -60,9 +63,6 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_NNACL_MATMUL_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
@ -60,4 +61,16 @@ typedef struct MatMulParameter {
ActType act_type_;
} MatMulParameter;
typedef struct MatmulQuantParameter {
QuantArg input_;
QuantArg output_;
int32_t out_act_min_;
int32_t out_act_max_;
float *filter_scale_;
int32_t *filter_zp_;
int32_t *left_shift_;
int32_t *right_shift_;
int32_t *quant_multiplier_;
} MatmulQuantParameter;
#endif // MINDSPORE_LITE_NNACL_MATMUL_H_

View File

@ -15,123 +15,230 @@
*/
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "nnacl/int8/matmul_int8.h"
#include "nnacl/common_func.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_FullConnection;
namespace mindspore::kernel {
void FullconnectionInt8CPUKernel::FreeQuantParam() {
if (quant_.filter_scale_ != nullptr) {
free(quant_.filter_scale_);
quant_.filter_scale_ = nullptr;
}
if (quant_.filter_zp_ != nullptr) {
free(quant_.filter_zp_);
quant_.filter_zp_ = nullptr;
}
if (quant_.left_shift_ != nullptr) {
free(quant_.left_shift_);
quant_.left_shift_ = nullptr;
}
if (quant_.right_shift_ != nullptr) {
free(quant_.right_shift_);
quant_.right_shift_ = nullptr;
}
if (quant_.quant_multiplier_ != nullptr) {
free(quant_.quant_multiplier_);
quant_.quant_multiplier_ = nullptr;
}
return;
}
void FullconnectionInt8CPUKernel::FreeTmpBuffer() {
if (pack_a_ptr_ != nullptr) {
free(pack_a_ptr_);
pack_a_ptr_ = nullptr;
}
if (pack_b_ptr_ != nullptr) {
free(pack_b_ptr_);
pack_b_ptr_ = nullptr;
}
if (input_sums_ != nullptr) {
free(input_sums_);
input_sums_ = nullptr;
}
if (weight_bias_sums_ != nullptr) {
free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
return;
}
int FullconnectionInt8CPUKernel::MallocQuantParam() {
auto weight_tensor = in_tensors_[1];
auto weight_quant_params = weight_tensor->quant_params();
int col = weight_tensor->shape().front();
filter_per_channel_ = (weight_quant_params.size() > 1);
int init_size = filter_per_channel_ ? col : 1;
quant_.filter_scale_ = reinterpret_cast<float *>(malloc(init_size * sizeof(float)));
if (quant_.filter_scale_ == nullptr) {
return RET_ERROR;
}
quant_.filter_zp_ = reinterpret_cast<int32_t *>(malloc(init_size * sizeof(int32_t)));
if (quant_.filter_zp_ == nullptr) {
return RET_ERROR;
}
quant_.left_shift_ = reinterpret_cast<int32_t *>(malloc(init_size * sizeof(int32_t)));
if (quant_.left_shift_ == nullptr) {
return RET_ERROR;
}
quant_.right_shift_ = reinterpret_cast<int32_t *>(malloc(init_size * sizeof(int32_t)));
if (quant_.right_shift_ == nullptr) {
return RET_ERROR;
}
quant_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(init_size * sizeof(int32_t)));
if (quant_.quant_multiplier_ == nullptr) {
return RET_ERROR;
}
return RET_OK;
}
int FullconnectionInt8CPUKernel::Init() {
auto ret = MallocQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
auto in_quant_params = in_tensors_[0]->quant_params();
quant_.input_.zp_ = in_quant_params.front().zeroPoint;
quant_.input_.scale_ = in_quant_params.front().scale;
auto out_quant_params = out_tensors_[0]->quant_params();
quant_.output_.zp_ = out_quant_params.front().zeroPoint;
quant_.output_.scale_ = out_quant_params.front().scale;
auto weight_tensor = in_tensors_[1];
fc_param_->b_const_ = (weight_tensor->data_c() != nullptr);
int weight_quant_num = filter_per_channel_ ? weight_tensor->shape().front() : 1;
auto weight_quant_params = weight_tensor->quant_params();
for (int i = 0; i < weight_quant_num; i++) {
quant_.filter_zp_[i] = weight_quant_params[i].zeroPoint;
quant_.filter_scale_[i] = weight_quant_params[i].scale;
}
for (int i = 0; i < weight_quant_num; ++i) {
const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_);
QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
&quant_.right_shift_[i]);
}
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
quant_.output_.zp_, quant_.output_.scale_, &quant_.out_act_min_,
&quant_.out_act_max_);
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int FullconnectionInt8CPUKernel::ReSize() {
FreeTmpBuffer();
void FullconnectionInt8CPUKernel::InitParam() {
int row = 1;
for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) row *= (out_tensors_[0]->shape())[i];
for (size_t i = 0; i < out_tensors_[0]->shape().size() - 1; ++i) {
row *= (out_tensors_[0]->shape())[i];
}
fc_param_->row_ = row;
fc_param_->col_ = out_tensors_[0]->shape().back();
fc_param_->deep_ = (in_tensors_[1]->shape())[1];
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
r4_ = UP_ROUND(fc_param_->row_, 4);
c4_ = UP_ROUND(fc_param_->col_, 4);
d16_ = UP_ROUND(fc_param_->deep_, 16);
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM);
fc_param_->col_2_ = UP_ROUND(fc_param_->col_, C2NUM);
fc_param_->col_4_ = UP_ROUND(fc_param_->col_, C4NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
fc_param_->col_16_ = UP_ROUND(fc_param_->col_, C16NUM);
fc_param_->deep_4_ = UP_ROUND(fc_param_->deep_, C4NUM);
fc_param_->deep_16_ = UP_ROUND(fc_param_->deep_, C16NUM);
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (a_r4x16_ptr_ == nullptr || b_c16x4_ptr_ == nullptr || input_sums_ == nullptr || weight_bias_sums_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(fc_param_->col_4_, C4NUM));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_4_, C4NUM), thread_count_);
return;
}
int FullconnectionInt8CPUKernel::ReSize() {
FreeTmpBuffer();
InitParam();
pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t)));
if (pack_a_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
return RET_ERROR;
}
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
memset(input_sums_, 0, r4_ * sizeof(int));
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
pack_b_ptr_ = reinterpret_cast<int8_t *>(malloc(fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t)));
if (pack_b_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
input_sums_ = reinterpret_cast<int *>(malloc(fc_param_->row_4_ * sizeof(int)));
if (input_sums_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
weight_bias_sums_ = reinterpret_cast<int *>(malloc(fc_param_->col_4_ * sizeof(int)));
if (weight_bias_sums_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
memset(pack_a_ptr_, 0, fc_param_->row_4_ * fc_param_->deep_16_ * sizeof(int8_t));
memset(pack_b_ptr_, 0, fc_param_->col_4_ * fc_param_->deep_16_ * sizeof(int8_t));
memset(input_sums_, 0, fc_param_->row_4_ * sizeof(int));
memset(weight_bias_sums_, 0, fc_param_->col_4_ * sizeof(int));
if (in_tensors_.size() == 3) {
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
bias_ptr_ = reinterpret_cast<int *>(malloc(fc_param_->col_4_ * sizeof(int)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memcpy(bias_ptr_, in_tensors_[2]->data_c(), bias_len);
memcpy(bias_ptr_, in_tensors_[2]->data_c(), fc_param_->col_ * sizeof(int));
} else {
bias_ptr_ = nullptr;
}
auto input_tensor = in_tensors_[0];
auto params = input_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.input.zp_ = params.front().zeroPoint;
quant_params_.input.scale_ = params.front().scale;
auto weight_tensor = in_tensors_[1];
params = weight_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.weight.zp_ = params.front().zeroPoint;
quant_params_.weight.scale_ = params.front().scale;
auto output_tensor = out_tensors_[0];
params = output_tensor->quant_params();
MS_ASSERT(params.size() == 1);
quant_params_.output.zp_ = params.front().zeroPoint;
quant_params_.output.scale_ = params.front().scale;
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
&quant_params_.out_act_max);
fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
if (fc_param_->b_const_) {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_,
quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor);
RowMajor2Row16x4MajorInt8(weight_data, pack_b_ptr_, fc_param_->col_, fc_param_->deep_);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
weight_bias_sums_, ColMajor, filter_per_channel_);
}
return RET_OK;
}
int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
int stride = thread_stride_ * C4NUM;
int cur_stride = task_id * stride;
int res_stride = fc_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, fc_param_->col_ - task_id * thread_stride_ * C4NUM);
auto &q = quant_params_;
auto &p = fc_param_;
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM;
auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res,
p->col_ * sizeof(int8_t), 0);
#else
MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, p->row_, cur_oc_res, d16_, p->col_, input_sums_, cur_bias,
&q.left_shift, &q.right_shift, &q.quant_multiplier, q.output.zp_, INT8_MIN, INT8_MAX, false);
#endif
int32_t *cur_left = filter_per_channel_ ? quant_.left_shift_ + cur_stride : quant_.left_shift_;
int32_t *cur_right = filter_per_channel_ ? quant_.right_shift_ + cur_stride : quant_.right_shift_;
int32_t *cur_mul = filter_per_channel_ ? quant_.quant_multiplier_ + cur_stride : quant_.quant_multiplier_;
int32_t *cur_zp = filter_per_channel_ ? quant_.filter_zp_ + cur_stride : quant_.filter_zp_;
MatmulInt8Opt(pack_a_ptr_, pack_b_ptr_ + cur_stride * fc_param_->deep_16_, c_ptr_ + cur_stride, fc_param_->row_,
cur_oc, fc_param_->deep_16_, input_sums_, weight_bias_sums_ + cur_stride, quant_.out_act_min_,
quant_.out_act_max_, quant_.output_.zp_, cur_mul, cur_left, cur_right, fc_param_->col_,
filter_per_channel_, cur_zp);
return RET_OK;
}
@ -148,14 +255,19 @@ int FcInt8Run(void *cdata, int task_id) {
int FullconnectionInt8CPUKernel::Run() {
auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
RowMajor2Row16x4MajorInt8(input_ptr, a_r4x16_ptr_, fc_param_->row_, fc_param_->deep_);
CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
RowMajor2Row16x4MajorInt8(input_ptr, pack_a_ptr_, fc_param_->row_, fc_param_->deep_);
int32_t tmp_weight_zp = filter_per_channel_ ? 1 : quant_.filter_zp_[0];
CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, tmp_weight_zp, input_sums_, RowMajor);
if (!fc_param_->b_const_) {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->data_c());
RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_,
quant_params_.weight.zp_, bias_ptr_, weight_bias_sums_, ColMajor);
RowMajor2Row16x4MajorInt8(weight_data, pack_b_ptr_, fc_param_->col_, fc_param_->deep_);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
weight_bias_sums_, ColMajor, filter_per_channel_);
}
c_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
auto ret = ParallelLaunch(this->context_->thread_pool_, FcInt8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch failed";
@ -163,6 +275,7 @@ int FullconnectionInt8CPUKernel::Run() {
}
return RET_OK;
}
kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,
@ -185,5 +298,4 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::T
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_FullConnection, CpuFullConnectionInt8KernelCreator)
} // namespace mindspore::kernel

View File

@ -18,59 +18,52 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/fullconnection_base.h"
#include "include/context.h"
#include "src/lite_kernel.h"
#include "include/errorcode.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/common_func_int8.h"
using mindspore::lite::InnerContext;
#include "nnacl/int8/matmul_int8.h"
namespace mindspore::kernel {
class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
class FullconnectionInt8CPUKernel : public LiteKernel {
public:
FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
const std::vector<lite::Tensor *> &outputs, const mindspore::lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~FullconnectionInt8CPUKernel() override { FreeTmpBuffer(); }
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
fc_param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
}
~FullconnectionInt8CPUKernel() override {
FreeTmpBuffer();
FreeQuantParam();
}
int Init() override;
int ReSize() override;
int Run() override;
public:
int RunImpl(int task_id);
private:
void FreeTmpBuffer() {
if (a_r4x16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
}
MatmulQuantArg quant_params_;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
void InitParam();
void FreeTmpBuffer();
void FreeQuantParam();
int MallocQuantParam();
private:
MatMulParameter *fc_param_ = nullptr;
MatmulQuantParameter quant_;
int thread_count_ = 1;
int thread_stride_ = 0;
int8_t *pack_a_ptr_ = nullptr;
int8_t *pack_b_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
int r4_ = 0;
int c4_ = 0;
int d16_ = 0;
bool filter_per_channel_ = true;
};
} // namespace mindspore::kernel

View File

@ -102,12 +102,12 @@ int MatmulInt8CPUKernel::ReSize() {
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor, false);
} else {
RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor, false);
}
}
}
@ -166,12 +166,12 @@ int MatmulInt8CPUKernel::Run() {
auto cur_sums = weight_bias_sums_batch_ + i * params_->col_4_;
if (params_->b_transpose_) {
RowMajor2Row16x4MajorInt8(cur_b, cur_b_pack, params_->col_, params_->deep_);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, ColMajor, false);
} else {
RowMajor2Col16x4MajorInt8(cur_b, params_->deep_, params_->col_, cur_b_pack);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor);
CalcWeightBiasSums(cur_b, params_->deep_, params_->col_, quant_params_.input.zp_, &quant_params_.weight.zp_,
bias_ptr_, cur_sums, RowMajor, false);
}
}
}