optimize int8 pooling

This commit is contained in:
fuzhiye 2020-08-27 18:53:04 +08:00
parent 29070d60a1
commit d0aa719a80
14 changed files with 313 additions and 166 deletions

View File

@ -264,7 +264,8 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int output_tile_count = UP_DIV(output_count, tile_n);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int plane_block = UP_DIV(kernel_plane, C4NUM);
int unit_size = plane_block * C4NUM * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int input_sum_offset;
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {

View File

@ -89,8 +89,13 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c8 = UP_DIV(channel, C8NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
float input_scale = pooling_param->quant_args_[0][0].scale_;
int input_zp = pooling_param->quant_args_[0][0].zp_;
float output_scale = pooling_param->quant_args_[1][0].scale_;
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
int c16 = channel / C16NUM;
const int8_t out_min = INT8_MIN;
const int8_t out_max = INT8_MAX;
@ -107,89 +112,159 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c8 - 1; j++) {
int in_channel_offset = in_batch_offset + j * C8NUM;
int out_channel_offset = out_plane_offset + j * C8NUM;
int16_t tmp_avg1 = 0;
int16_t tmp_avg2 = 0;
int16_t tmp_avg3 = 0;
int16_t tmp_avg4 = 0;
int16_t tmp_avg5 = 0;
int16_t tmp_avg6 = 0;
int16_t tmp_avg7 = 0;
int16_t tmp_avg8 = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg1 += *(input_ptr + in_offset);
tmp_avg2 += *(input_ptr + in_offset + 1);
tmp_avg3 += *(input_ptr + in_offset + 2);
tmp_avg4 += *(input_ptr + in_offset + 3);
tmp_avg5 += *(input_ptr + in_offset + 4);
tmp_avg6 += *(input_ptr + in_offset + 5);
tmp_avg7 += *(input_ptr + in_offset + 6);
tmp_avg8 += *(input_ptr + in_offset + 7);
++real_count;
int input_stride = (in_h_index * in_w + in_w_index) * channel;
int kw_s = MSMAX(0, -in_w_index);
int kw_e = MSMIN(win_w, in_w - in_w_index);
int kh_s = MSMAX(0, -in_h_index);
int kh_e = MSMIN(win_h, in_h - in_h_index);
int real_count = (kw_e - kw_s) * (kh_e - kh_s);
// 16 channels
for (int j = 0; j < c16; j++) {
#ifdef ENABLE_NEON
int16x8_t tmp_avg[2];
tmp_avg[0] = vmovq_n_s16(0);
tmp_avg[1] = vmovq_n_s16(0);
#else
int16_t tmp_avg[16];
int16_t real_out[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_avg[m] = 0;
}
#endif
int in_channel_offset = in_batch_offset + j * C16NUM;
int out_channel_offset = out_plane_offset + j * C16NUM;
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
#ifdef ENABLE_NEON
int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset);
int8x8_t in_data1 = vget_low_s8(in_ptr);
int8x8_t in_data2 = vget_high_s8(in_ptr);
int16x8_t data1 = vmovl_s8(in_data1);
int16x8_t data2 = vmovl_s8(in_data2);
tmp_avg[0] = vaddq_s16(tmp_avg[0], data1);
tmp_avg[1] = vaddq_s16(tmp_avg[1], data2);
#else
for (int k = 0; k < C16NUM; ++k) {
tmp_avg[k] += input_ptr[in_offset + k];
}
#endif
} // win_w loop
} // win_h loop
int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count);
int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count);
int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count);
int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count);
int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count);
int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count);
int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count);
int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count);
int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1;
int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2;
int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3;
int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4;
int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5;
int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6;
int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7;
int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8;
real_out1 = real_out1 > out_max ? out_max : real_out1;
real_out2 = real_out2 > out_max ? out_max : real_out2;
real_out3 = real_out3 > out_max ? out_max : real_out3;
real_out4 = real_out4 > out_max ? out_max : real_out4;
real_out5 = real_out5 > out_max ? out_max : real_out5;
real_out6 = real_out6 > out_max ? out_max : real_out6;
real_out7 = real_out7 > out_max ? out_max : real_out7;
real_out8 = real_out8 > out_max ? out_max : real_out8;
*(output_ptr + out_channel_offset) = (int8_t)real_out1;
*(output_ptr + out_channel_offset + 1) = (int8_t)real_out2;
*(output_ptr + out_channel_offset + 2) = (int8_t)real_out3;
*(output_ptr + out_channel_offset + 3) = (int8_t)real_out4;
*(output_ptr + out_channel_offset + 4) = (int8_t)real_out5;
*(output_ptr + out_channel_offset + 5) = (int8_t)real_out6;
*(output_ptr + out_channel_offset + 6) = (int8_t)real_out7;
*(output_ptr + out_channel_offset + 7) = (int8_t)real_out8;
} // in_channel loop
int channel_s = (c8 - 1) * C8NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
#ifdef ENABLE_NEON
int16_t tmp_data[8];
int16_t tmp_out[8];
int16_t tmp_data1[8];
int16_t tmp_out1[8];
for (int l = 0; l < C8NUM; l++) {
tmp_data[l] = tmp_avg[0][l] + 128 * real_count;
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
tmp_out[l] -= 128;
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
}
for (int l = 0; l < C8NUM; l++) {
tmp_data1[l] = tmp_avg[1][l] + 128 * real_count;
tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count;
tmp_out1[l] -= 128;
tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp;
}
int8x8_t real_out[2];
int8x8_t output_min = vdup_n_s8(out_min);
int8x8_t output_max = vdup_n_s8(out_max);
real_out[0] = vqmovn_s16(vld1q_s16(tmp_out));
real_out[0] = vmin_s8(real_out[0], output_max);
real_out[0] = vmax_s8(real_out[0], output_min);
vst1_s8(output_ptr + out_channel_offset, real_out[0]);
real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1));
real_out[1] = vmin_s8(real_out[1], output_max);
real_out[1] = vmax_s8(real_out[1], output_min);
vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]);
#else
for (int l = 0; l < C16NUM; ++l) {
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
}
#endif
}
// 8 channels
int channel_16_res = channel - c16 * C16NUM;
int c8 = channel_16_res / C8NUM;
int in_c16_offset = in_batch_offset + c16 * C16NUM;
int out_c16_offset = out_plane_offset + c16 * C16NUM;
for (int j = 0; j < c8; j++) {
#ifdef ENABLE_NEON
int16x8_t tmp_avg = vmovq_n_s16(0);
#else
int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0};
int16_t real_out[8];
#endif
int in_channel_offset = in_c16_offset + j * C8NUM;
int out_channel_offset = out_c16_offset + j * C8NUM;
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
#ifdef ENABLE_NEON
int8x8_t in_ptr = vld1_s8(input_ptr + in_offset);
int16x8_t data = vmovl_s8(in_ptr);
tmp_avg = vaddq_s16(tmp_avg, data);
#else
for (int k = 0; k < C8NUM; ++k) {
tmp_avg[k] += input_ptr[in_offset + k];
}
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
int16_t tmp_data[8];
int16_t tmp_out[8];
for (int l = 0; l < C8NUM; l++) {
tmp_data[l] = tmp_avg[l] + 128 * real_count;
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
tmp_out[l] -= 128;
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
}
int8x8_t real_out;
int8x8_t output_min = vdup_n_s8(out_min);
int8x8_t output_max = vdup_n_s8(out_max);
real_out = vqmovn_s16(vld1q_s16(tmp_out));
real_out = vmin_s8(real_out, output_max);
real_out = vmax_s8(real_out, output_min);
vst1_s8(output_ptr + out_channel_offset, real_out);
#else
for (int l = 0; l < C8NUM; ++l) {
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
}
#endif
}
// less than 8 channel
int channel_8_res = channel_16_res - c8 * C8NUM;
int in_c8_offset = in_c16_offset + c8 * C8NUM;
int out_c8_offset = out_c16_offset + c8 * C8NUM;
for (int k = 0; k < channel_8_res; k++) {
int in_channel_offset = in_c8_offset + k;
int out_channel_offset = out_c8_offset + k;
int16_t tmp_avg = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg += *(input_ptr + in_offset);
++real_count;
}
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
tmp_avg += input_ptr[in_offset];
} // win_w loop
} // win_h loop
int16_t tmp_out = round((float)tmp_avg / (float)real_count);
int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128;
tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp);
int16_t real_out = tmp_out < out_min ? out_min : tmp_out;
real_out = real_out > out_max ? out_max : real_out;
*(output_ptr + out_channel_offset) = (int8_t)real_out;
@ -249,6 +324,109 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete
} // out_batch loop
}
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
int c16 = UP_DIV(channel, 16);
// input channel is equal to output channel
float input_scale = pooling_param->quant_args_[0][0].scale_;
int input_zp = pooling_param->quant_args_[0][0].zp_;
float output_scale = pooling_param->quant_args_[1][0].scale_;
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c16 - 1; j++) {
int in_channel_offset = in_batch_offset + j * 16;
int out_channel_offset = out_plane_offset + j * 16;
#ifdef ENABLE_NEON
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
#else
int8_t tmp_max[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_max[m] = INT8_MIN;
}
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
#else
for (int k = 0; k < C16NUM; ++k) {
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
for (int l = 0; l < C16NUM; ++l) {
tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
}
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
#else
for (int l = 0; l < C16NUM; ++l) {
*(output_ptr + out_channel_offset + l) =
(int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
}
#endif
} // in_channel loop
// res channel
int channel_s = (c16 - 1) * 16;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
int8_t tmp_max = INT8_MIN;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset));
}
} // win_w loop
} // win_h loop
*(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp);
} // channel_res loop
} // out_plane loop
} // out_batch loop
}
}
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
@ -264,7 +442,7 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
int c16 = UP_DIV(channel, 16);
for (int batch = 0; batch < output_batch; batch++) {
@ -286,22 +464,10 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
#else
int8_t tmp_max1 = INT8_MIN;
int8_t tmp_max2 = INT8_MIN;
int8_t tmp_max3 = INT8_MIN;
int8_t tmp_max4 = INT8_MIN;
int8_t tmp_max5 = INT8_MIN;
int8_t tmp_max6 = INT8_MIN;
int8_t tmp_max7 = INT8_MIN;
int8_t tmp_max8 = INT8_MIN;
int8_t tmp_max9 = INT8_MIN;
int8_t tmp_max10 = INT8_MIN;
int8_t tmp_max11 = INT8_MIN;
int8_t tmp_max12 = INT8_MIN;
int8_t tmp_max13 = INT8_MIN;
int8_t tmp_max14 = INT8_MIN;
int8_t tmp_max15 = INT8_MIN;
int8_t tmp_max16 = INT8_MIN;
int8_t tmp_max[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_max[m] = INT8_MIN;
}
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
@ -313,22 +479,9 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
#else
tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset));
tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1));
tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2));
tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3));
tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4));
tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5));
tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6));
tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7));
tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8));
tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9));
tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10));
tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11));
tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12));
tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13));
tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14));
tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15));
for (int k = 0; k < C16NUM; ++k) {
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
@ -336,24 +489,13 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
#else
*(output_ptr + out_channel_offset) = tmp_max1;
*(output_ptr + out_channel_offset + 1) = tmp_max2;
*(output_ptr + out_channel_offset + 2) = tmp_max3;
*(output_ptr + out_channel_offset + 3) = tmp_max4;
*(output_ptr + out_channel_offset + 4) = tmp_max5;
*(output_ptr + out_channel_offset + 5) = tmp_max6;
*(output_ptr + out_channel_offset + 6) = tmp_max7;
*(output_ptr + out_channel_offset + 7) = tmp_max8;
*(output_ptr + out_channel_offset + 8) = tmp_max9;
*(output_ptr + out_channel_offset + 9) = tmp_max10;
*(output_ptr + out_channel_offset + 10) = tmp_max11;
*(output_ptr + out_channel_offset + 11) = tmp_max12;
*(output_ptr + out_channel_offset + 12) = tmp_max13;
*(output_ptr + out_channel_offset + 13) = tmp_max14;
*(output_ptr + out_channel_offset + 14) = tmp_max15;
*(output_ptr + out_channel_offset + 15) = tmp_max16;
for (int l = 0; l < C16NUM; ++l) {
*(output_ptr + out_channel_offset + l) = tmp_max[l];
}
#endif
} // in_channel loop
// res channel
int channel_s = (c16 - 1) * 16;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;

View File

@ -32,6 +32,8 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
#ifdef __cplusplus
}

View File

@ -19,14 +19,16 @@
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode;
typedef enum RoundMode { RoundMode_No, RoundMode_Ceil, RoundMode_Floor } RoundMode;
typedef struct PoolingParameter {
OpParameter op_parameter_;
PoolMode pool_mode_;
RoundMode round_mode_;
ActType act_type_;
QuantArg **quant_args_;
bool global_;
bool max_pooling_;
bool avg_pooling_;
bool round_ceil_;
bool round_floor_;
int window_w_;
int window_h_;
int input_w_;
@ -44,7 +46,8 @@ typedef struct PoolingParameter {
int stride_w_;
int stride_h_;
int thread_num_;
ActType act_type_;
bool global_;
bool quantize_;
} PoolingParameter;
#endif // MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_

View File

@ -294,32 +294,26 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti
auto pool_mode = pooling_primitive->GetPoolingMode();
switch (pool_mode) {
case schema::PoolMode_MAX_POOLING:
pooling_param->max_pooling_ = true;
pooling_param->avg_pooling_ = false;
pooling_param->pool_mode_ = PoolMode_MaxPool;
break;
case schema::PoolMode_MEAN_POOLING:
pooling_param->max_pooling_ = false;
pooling_param->avg_pooling_ = true;
pooling_param->pool_mode_ = PoolMode_AvgPool;
break;
default:
pooling_param->max_pooling_ = false;
pooling_param->avg_pooling_ = false;
pooling_param->pool_mode_ = PoolMode_No;
break;
}
auto round_mode = pooling_primitive->GetRoundMode();
switch (round_mode) {
case schema::RoundMode_FLOOR:
pooling_param->round_floor_ = true;
pooling_param->round_ceil_ = false;
pooling_param->round_mode_ = RoundMode_Floor;
break;
case schema::RoundMode_CEIL:
pooling_param->round_floor_ = false;
pooling_param->round_ceil_ = true;
pooling_param->round_mode_ = RoundMode_Ceil;
break;
default:
pooling_param->round_floor_ = false;
pooling_param->round_ceil_ = false;
pooling_param->round_mode_ = RoundMode_No;
break;
}

View File

@ -42,6 +42,12 @@ int PoolingBaseCPUKernel::SetQuantParam() {
pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale;
pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint;
pooling_param_->quant_args_ = pooling_quant_arg_;
if (pooling_quant_arg_[0][0].scale_ == pooling_quant_arg_[1][0].scale_ &&
pooling_quant_arg_[0][0].zp_ == pooling_quant_arg_[1][0].zp_) {
pooling_param_->quantize_ = false;
} else {
pooling_param_->quantize_ = true;
}
return RET_OK;
}

View File

@ -53,7 +53,7 @@ int PoolingFp16CPUKernel::ReSize() {
}
int PoolingFp16CPUKernel::RunImpl(int task_id) {
if (pooling_param_->max_pooling_) {
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
MaxPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);
} else {
AvgPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id);

View File

@ -52,7 +52,7 @@ int PoolingCPUKernel::ReSize() {
int PoolingCPUKernel::RunImpl(int task_id) {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
if (pooling_param_->max_pooling_) {
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
switch (pooling_param_->act_type_) {
case ActType_Relu:
MaxPoolingRelu(input_ptr, output_ptr, pooling_param_, task_id);

View File

@ -163,7 +163,7 @@ int PoolingGradCPUKernel::Run() {
auto input_ptr = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto output_ptr = reinterpret_cast<float *>(outputs_.at(0)->Data());
if (pool_param->max_pooling_) {
if (pool_param->pool_mode_ == PoolMode_MaxPool) {
auto ind = reinterpret_cast<int *>(inputs_.at(1)->Data());
MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param);
} else {

View File

@ -61,10 +61,14 @@ int PoolingInt8CPUKernel::ReSize() {
int PoolingInt8CPUKernel::RunImpl(int task_id) {
auto input_data = reinterpret_cast<int8_t *>(in_tensors_.at(kInputIndex)->Data());
auto output_data = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->Data());
if (pooling_param_->max_pooling_) {
MaxPoolingInt8(input_data, output_data, pooling_param_, task_id);
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
if (pooling_param_->quantize_) {
MaxPoolingWithQuantInt8(input_data, output_data, pooling_param_, task_id);
} else {
MaxPoolingOptInt8(input_data, output_data, pooling_param_, task_id);
}
} else {
AvgPoolingInt8(input_data, output_data, pooling_param_, task_id);
AvgPoolingOptInt8(input_data, output_data, pooling_param_, task_id);
}
return RET_OK;
}

View File

@ -43,13 +43,13 @@ int PoolingOpenCLKernel::Init() {
std::string source;
std::string program_name;
#endif
if (parameter_->max_pooling_) {
if (parameter_->pool_mode_ == PoolMode_MaxPool) {
kernel_name = "MaxPooling2d";
#ifndef PROGRAM_WITH_IL
source = max_pool2d_source;
program_name = "MaxPooling2d";
#endif
} else if (parameter_->avg_pooling_) {
} else if (parameter_->pool_mode_ == PoolMode_AvgPool) {
kernel_name = "AvgPooling2d";
#ifndef PROGRAM_WITH_IL
source = avg_pool2d_source;

View File

@ -26,7 +26,7 @@
#include "nnacl/fp32_grad/pooling_grad.h"
namespace mindspore {
class TestPoolingGradFp32 : public mindspore::CommonTest {
class TestPoolingGradFp32 : public mindspore::CommonTest {
public:
TestPoolingGradFp32() {}
};
@ -161,8 +161,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) {
auto pooling_param = new PoolingParameter();
InitPoolingParamFP32(pooling_param);
pooling_param->output_channel_ = 3;
pooling_param->avg_pooling_ = false;
pooling_param->max_pooling_ = true;
pooling_param->pool_mode_ = PoolMode_MaxPool;
// runtime part
printf("Calculating runtime cost...\n");
uint64_t time_avg = 0;
@ -215,8 +214,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) {
// prepare stage
auto maxpool = new PoolingParameter();
InitPoolingParamFP32(maxpool);
maxpool->avg_pooling_ = false;
maxpool->max_pooling_ = true;
maxpool->pool_mode_ = PoolMode_MaxPool;
maxpool->input_h_ = 30;
maxpool->input_w_ = 30;
maxpool->input_channel_ = 3;
@ -268,8 +266,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) {
auto pooling_param = new PoolingParameter();
InitPoolingParamFP32(pooling_param);
pooling_param->avg_pooling_ = false;
pooling_param->max_pooling_ = true;
pooling_param->pool_mode_ = PoolMode_MaxPool;
pooling_param->input_h_ = 10;
pooling_param->input_w_ = 10;
pooling_param->input_channel_ = 3;

View File

@ -48,8 +48,7 @@ void InitAvgPoolingParam(PoolingParameter *param) {
param->pad_l_ = 0;
param->pad_r_ = 0;
param->max_pooling_ = false;
param->avg_pooling_ = true;
param->pool_mode_ = PoolMode_AvgPool;
}
TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) {

View File

@ -35,8 +35,7 @@ void InitParameter(PoolingParameter *param) {
param->pad_d_ = 0;
param->pad_l_ = 0;
param->pad_r_ = 0;
param->avg_pooling_ = false;
param->max_pooling_ = true;
param->pool_mode_ = PoolMode_MaxPool;
}
TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {