[MS][LITE] arm cpu int8 op conv depthwise: support filter per channel

This commit is contained in:
yangruoqi713 2020-08-14 15:39:51 +08:00
parent 5ca5c346bb
commit 7aef961358
2 changed files with 54 additions and 22 deletions

View File

@ -21,8 +21,9 @@
/*conv depthwise int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int kernel_w, int out_multiplier,
int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) {
int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
bool per_channel) {
int tmp_buffer[C4NUM];
for (int i = 0; i < C4NUM; i++) {
tmp_buffer[i] = 0;
@ -42,10 +43,18 @@ void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *we
src_kh += in_kh_step;
weight_kh += kernel_w * C4NUM;
} // kernel_h loop
int32_t left = left_shift[0];
int32_t right = right_shift[0];
int32_t multiplier = out_multiplier[0];
for (int c = 0; c < C4NUM; c++) {
if (per_channel) {
left = left_shift[c];
right = right_shift[c];
multiplier = out_multiplier[c];
}
tmp_buffer[c] += bias[c];
tmp_buffer[c] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift);
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), right);
tmp_buffer[c] += out_zp;
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
@ -55,7 +64,8 @@ void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *we
void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding) {
const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift,
bool per_channel) {
int8_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
@ -73,12 +83,11 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
DepthwiseBorderPixelInt8(
dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_,
sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier,
left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
per_channel);
dst_kernel += sliding->block_channel_;
} // width loop
@ -89,8 +98,8 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
#ifndef ENABLE_ARM64
void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
int in_sw_step, int in_kh_step, int in_kw_step, int out_multiplier, int left_shift,
int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) {
int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift,
int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) {
int tmp_buffer[C4NUM];
int8_t *dst_h = dst;
const int16_t *src_h = src;
@ -118,11 +127,18 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
weight_kh += kernel_w * C4NUM;
} // kernel_h loop
// add bias relu
int32_t left = left_shift[0];
int32_t right = right_shift[0];
int32_t multiplier = out_multiplier[0];
for (int c = 0; c < C4NUM; c++) {
if (per_channel) {
left = left_shift[c];
right = right_shift[c];
multiplier = out_multiplier[c];
}
tmp_buffer[c] += bias[c];
tmp_buffer[c] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier),
-right_shift);
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
tmp_buffer[c] += out_zp;
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
@ -141,20 +157,33 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
const int16_t *src = input_data;
int8_t *dst = output_data;
bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int *left_shift = conv_param->conv_quant_arg_.left_shift_;
int *right_shift = conv_param->conv_quant_arg_.right_shift_;
for (int b = 0; b < conv_param->output_batch_; b++) {
for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
const int16_t *src_data = src + oc * C4NUM;
int8_t *dst_data = dst + oc * C4NUM;
const int16_t *weight = weight_data + oc * sliding->kernel_step_;
const int32_t *bias = bias_data + oc * C4NUM;
if (per_channel) {
out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc;
left_shift = conv_param->conv_quant_arg_.left_shift_ + oc;
right_shift = conv_param->conv_quant_arg_.right_shift_ + oc;
}
DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param,
sliding);
sliding, out_multiplier, left_shift, right_shift, per_channel);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0,
conv_param->output_w_, conv_param, sliding);
conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
per_channel);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_,
conv_param, sliding);
conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
conv_param->output_w_, conv_param, sliding);
conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
per_channel);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
@ -171,13 +200,13 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
#else
DepthwiseCenterInt8(
out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier,
left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
#endif
}
} // output C4 loop

View File

@ -847,6 +847,9 @@ void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight
int weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
int unit = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int c = 0; c < conv_param->output_channel_; c++) {
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[c].zp_;
}
int c4_block_num = c / C4NUM;
int c4_block_rem = c % C4NUM;
int8_t *src_c = origin_weight + c * unit;