!8645 [MS][LITE][Develop]optimization for quantized mobilenet_v2

From: @lx0095
Reviewed-by: @hangangqiang,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-11-18 09:44:15 +08:00 committed by Gitee
commit 3939874b67
11 changed files with 1518 additions and 159 deletions

File diff suppressed because it is too large Load Diff

View File

@ -378,11 +378,11 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
"b 16f \n"
"10: \n"
"ld1 {v16.h}[0], [x10] \n"
"ld1 {v16.d}[0], [x10] \n"
"b 16f \n"
"11: \n"
"ld1 {v16.h}[0], [x10] \n"
"ld1 {v16.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v16.s}[2], [x10] \n"
"b 16f \n"
@ -802,11 +802,12 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc);
matmul_func(packed_input, packed_weight, dst, row, col, deep4, col, input_sum, bias, left_shift, right_shift,
multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc,
filter_zp);
return;
}

View File

@ -46,7 +46,7 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t
int32_t *multiplier, ConvParameter *conv_param);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);

View File

@ -64,6 +64,21 @@ void MatrixEmptyInt8(int8_t *dst, int row, int col) {
return;
}
void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col4 = UP_ROUND(col, C4NUM);
for (int r = 0; r < row; r++) {
int rd16 = r / C16NUM;
int rm16 = r % C16NUM;
for (int c = 0; c < col; c++) {
int cd4 = c / C4NUM;
int cm4 = c % C4NUM;
int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4;
int src_index = r * col + c;
dst_ptr[dst_index] = src_ptr[src_index];
}
}
}
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM);
@ -268,6 +283,223 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
return;
}
void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp) {
/* row4x4-major * row4x16-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
size_t ci = r * col + c;
int32_t value = 0;
for (int d = 0; d < deep_4; d++) {
int d4div = d / C4NUM, d4mod = d % C4NUM;
size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod;
size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp) {
int ic4 = UP_ROUND(input_channel, C4NUM);
int hw4 = UP_ROUND(plane_size, C4NUM);
size_t hw_4div = plane_size / C4NUM * C4NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v2.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v3.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x15, #0 \n"
"1: \n"
"cmp x15, %[ic_4div] \n"
"add x15, x15, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 1b \n"
"3: \n" /* ic res 1 */
"dup v0.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"4: \n" /* ic res 2 */
"dup v0.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"5: \n" /* ic res 3 */
"dup v0.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"6: \n"
"mul v2.4s, v2.4s, v3.4s \n"
"st1 {v2.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
#else
int32_t tmp_sum_value[4] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C4NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int i = 0; i < C4NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C4NUM;
pack_r += ic4 * C4NUM;
}
if (hw_4div != plane_size) {
memset(pack_r, 0, C4NUM * ic4);
for (int hwi = hw_4div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < hw4; hwi++) {
input_sum[hwi] = 0;
}
}
return;
}
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst) {
int row_16 = UP_ROUND(row, C16NUM);
int stride = sizeof(int8_t) * 16 * 4;

View File

@ -52,6 +52,15 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool peroc);
/* 4x4 4x16 -> 4x16 */
void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp);
void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp);
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,

View File

@ -27,6 +27,11 @@ typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel);
typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int *filter_zp);
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;
typedef struct MatMulParameter {
@ -40,6 +45,7 @@ typedef struct MatMulParameter {
int col_2_;
int col_4_;
int col_8_;
int col_16_;
int deep_;
int deep_4_;
int deep_16_;

View File

@ -38,6 +38,12 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);

View File

@ -26,16 +26,6 @@ using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int Convolution1x1Int8Pre(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->RunPre(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 RunPre error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
if (matmul_param_ != nullptr) {
delete matmul_param_;
@ -73,13 +63,42 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
return;
}
int Convolution1x1Int8CPUKernel::InitRunBuf() {
input_sum_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(input_sum_size_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
: UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
return RET_ERROR;
}
return RET_OK;
}
void Convolution1x1Int8CPUKernel::FreeRunBuf() {
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
if (input_sum_ != nullptr) {
ctx_->allocator->Free(input_sum_);
input_sum_ = nullptr;
}
return;
}
void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
matmul_func_ = MatMulInt8_8x8_r;
matmul_func_ = MatMulInt8_4x16_r;
#ifdef ENABLE_ARM64
if (mindspore::lite::IsSupportSDot()) {
support_optimize_ = true;
matmul_func_ = MatMulRInt8_optimize_handler;
matmul_func_ = MatMulDpInt8_optimize_handler;
} else {
support_optimize_ = false;
matmul_func_ = nullptr;
@ -104,7 +123,8 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
}
if (filter_peroc_) {
filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t)));
/* filter zp */
filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (filter_zp_ptr_ == nullptr) {
return RET_ERROR;
}
@ -112,24 +132,33 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}
/* left shift */
left_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (left_shift_ == nullptr) {
return RET_ERROR;
}
memset(left_shift_, 0, round_oc * sizeof(int32_t));
memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
/* right shift */
right_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (right_shift_ == nullptr) {
return RET_ERROR;
}
memset(right_shift_, 0, round_oc * sizeof(int32_t));
memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
/* multiplier */
multiplier_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (multiplier_ == nullptr) {
return RET_ERROR;
}
memset(multiplier_, 0, round_oc * sizeof(int32_t));
memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
} else {
right_shift_ = conv_param_->conv_quant_arg_.right_shift_;
left_shift_ = conv_param_->conv_quant_arg_.left_shift_;
multiplier_ = conv_param_->conv_quant_arg_.quant_multiplier_;
}
return RET_OK;
}
@ -140,7 +169,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
auto output_channel = filter_tensor->Batch();
/* weight */
size_t size = support_optimize_ ? UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C8NUM) * sizeof(int8_t)
size_t size = support_optimize_ ? UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C16NUM) * sizeof(int8_t)
: UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t);
packed_weight_ = reinterpret_cast<int8_t *>(malloc(size));
if (packed_weight_ == nullptr) {
@ -149,16 +178,14 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
}
memset(packed_weight_, 0, size);
if (support_optimize_) {
RowMajor2Row8x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel,
input_channel);
RowMajor2Row4x16MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel,
input_channel);
} else {
RowMajor2Row16x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel,
input_channel);
}
int col4 = UP_ROUND(output_channel, C4NUM);
int col8 = UP_ROUND(output_channel, C8NUM);
size = support_optimize_ ? col8 : col4;
size = support_optimize_ ? UP_ROUND(output_channel, C16NUM) : UP_ROUND(output_channel, C4NUM);
bias_data_ = malloc(size * sizeof(int32_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!";
@ -166,10 +193,10 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
}
memset(bias_data_, 0, size * sizeof(int32_t));
if (in_tensors_.size() == 3) {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
memcpy(bias_data_, in_tensors_[kBiasIndex]->data_c(), output_channel * sizeof(int32_t));
}
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, size);
InitBiasByzp(filter_tensor->data_c(), input_channel, output_channel, size);
return RET_OK;
}
@ -198,7 +225,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() {
}
memset(bias_data_, 0, col2 * sizeof(int32_t));
if (in_tensors_.size() == 3) {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
memcpy(bias_data_, in_tensors_[kBiasIndex]->data_c(), output_channel * sizeof(int32_t));
}
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, col2);
@ -248,6 +275,7 @@ int Convolution1x1Int8CPUKernel::InitParam() {
matmul_param_->col_2_ = UP_ROUND(matmul_param_->col_, C2NUM);
matmul_param_->col_4_ = UP_ROUND(matmul_param_->col_, C4NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->col_16_ = UP_ROUND(matmul_param_->col_, C16NUM);
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
@ -255,13 +283,14 @@ int Convolution1x1Int8CPUKernel::InitParam() {
int row_pack_count = 0;
int col_pack_count = 0;
#ifdef ENABLE_ARM32
row_pack_count = C4NUM;
col_pack_count = C2NUM;
#else
if (support_optimize_) {
row_pack_count = C8NUM;
col_pack_count = C8NUM;
row_pack_count = C4NUM;
col_pack_count = C16NUM;
} else {
row_pack_count = C4NUM;
col_pack_count = C4NUM;
@ -269,17 +298,18 @@ int Convolution1x1Int8CPUKernel::InitParam() {
#endif
/* init input sum size */
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
} else {
if (support_optimize_) {
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
} else {
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
} else {
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
}
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_pack_count));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_pack_count), thread_count_);
thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count));
thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_hw_);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_);
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
@ -306,111 +336,116 @@ int Convolution1x1Int8CPUKernel::ReSize() {
}
void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_output) {
/* deal with pad and stride */
output_ptr_ = src_output;
if (pre_trans_input_) {
Conv1x1InputPack(src_input, input_ptr_, conv_param_, sizeof(int8_t));
} else {
input_ptr_ = src_input;
}
if (support_optimize_) {
ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Pre, this, thread_count_hw_);
} else {
RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
}
return;
}
int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
int32_t *cur_input_sum = input_sum_;
int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_;
#ifdef ENABLE_ARM32
int cur_stride = thread_stride_ * C2NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C2NUM;
int cur_oc = MSMIN(cur_stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C2NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C2NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C2NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C2NUM;
}
Conv1x1Int8Arm32(packed_input_, packed_weight_ + task_id * thread_stride_ * C2NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C2NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C2NUM, matmul_param_->row_,
cur_oc, matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
#else
if (support_optimize_) {
int cur_stride = thread_stride_ * C8NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM;
int cur_oc = MSMIN(cur_stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM;
}
Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_,
cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_,
matmul_func_);
} else {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM;
int cur_oc = MSMIN(cur_stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM;
}
Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
}
#endif
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunPre(int task_id) {
int cur_stride = thread_stride_hw_ * C8NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM;
int Convolution1x1Int8CPUKernel::RunArm64(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_4_
: input_sum_ + task_id * thread_stride_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
if (filter_peroc_) {
Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_,
matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM);
PackInputSum16x4PerChannel(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, matmul_param_->deep_,
matmul_param_->col_);
} else {
Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_);
PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
}
Conv1x1Int8(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm32(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_2_
: input_sum_ + task_id * thread_stride_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
if (filter_peroc_) {
PackInputSum16x4PerChannelArm32(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, conv_param_->input_channel_,
conv_param_->output_channel_);
} else {
PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
}
Conv1x1Int8Arm32(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
return RET_OK;
}
int Convolution1x1Int8Impl(void *cdata, int task_id) {
int Convolution1x1Int8CPUKernel::RunArm64Opt(int task_id) {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_4_;
int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_ * C4NUM;
if (filter_peroc_) {
PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw, 1);
} else {
PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw,
conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_);
}
Conv1x1Int8Opt(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
matmul_param_->col_, matmul_param_->deep_4_, left_shift_, right_shift_, multiplier_, conv_param_,
matmul_func_, filter_zp_ptr_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::DoRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32(task_id);
#else
if (support_optimize_) {
return RunArm64Opt(task_id);
} else {
return RunArm64(task_id);
}
#endif
}
int Convolution1x1Int8Run(void *cdata, int task_id) {
auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
auto error_code = conv->DoRun(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
@ -418,35 +453,6 @@ int Convolution1x1Int8Impl(void *cdata, int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::InitRunBuf() {
input_sum_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(input_sum_size_ * sizeof(int32_t)));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
}
size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
: UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
return RET_ERROR;
}
return RET_OK;
}
void Convolution1x1Int8CPUKernel::FreeRunBuf() {
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
if (input_sum_ != nullptr) {
ctx_->allocator->Free(input_sum_);
input_sum_ = nullptr;
}
return;
}
int Convolution1x1Int8CPUKernel::Run() {
int error_code = InitRunBuf();
if (error_code != RET_OK) {
@ -461,7 +467,7 @@ int Convolution1x1Int8CPUKernel::Run() {
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
auto ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Impl, this, thread_count_);
auto ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch run error error_code[" << ret << "]";
FreeRunBuf();

View File

@ -45,8 +45,12 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
void FreeRunBuf();
public:
int RunImpl(int task_id);
int RunPre(int task_id);
int DoRun(int task_id);
private:
int RunArm32(int task_id);
int RunArm64(int task_id);
int RunArm64Opt(int task_id);
private:
void FreeResizeBuf();
@ -58,8 +62,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc);
private:
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
int32_t *input_sum_ = nullptr; /* per-oc */
int32_t *filter_zp_ptr_ = nullptr; /* per-oc up round */
int32_t *left_shift_ = nullptr; /* per-oc up round */
int32_t *right_shift_ = nullptr; /* per-oc up round */
int32_t *multiplier_ = nullptr; /* per-oc up round */
@ -69,12 +73,10 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int8_t *output_ptr_ = nullptr;
size_t thread_count_ = 1;
size_t thread_stride_ = 0;
size_t thread_count_hw_ = 1;
size_t thread_stride_hw_ = 0;
bool pre_trans_input_ = false;
size_t input_sum_size_ = 0;
MatMulParameter *matmul_param_ = nullptr;
MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
MATMUL_OPT_DP_FUNC matmul_func_ = nullptr;
bool support_optimize_ = false;
bool filter_peroc_ = false;
};

View File

@ -33,6 +33,9 @@ extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, in
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
size_t peroc);
extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_t row8, size_t col8, size_t deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier,
int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
#ifdef ENABLE_ARM64
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
@ -57,6 +60,13 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst,
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
}
void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int32_t *filter_zp) {
return MatmulInt8DpOpt(a, b, dst, row, col, deep_4, input_sum, bias, mini, maxi, output_zp, multiplier, left_shift,
right_shift, stride, per_channel, filter_zp);
}
#endif
#ifdef __cplusplus

View File

@ -33,6 +33,10 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel);
void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int32_t *filter_zp);
#endif
#ifdef __cplusplus