diff --git a/mindspore/lite/nnacl/int8/concat_int8.c b/mindspore/lite/nnacl/int8/concat_int8.c
index 00fb5f23091..d6f0c0a8542 100644
--- a/mindspore/lite/nnacl/int8/concat_int8.c
+++ b/mindspore/lite/nnacl/int8/concat_int8.c
@@ -46,13 +46,9 @@ void Int8Concat(int8_t **inputs, int8_t *output, ConcatParameter *para, int axis
         float bias = -input_quant[i].zp_ * scale;
         for (int j = 0; j < in_copy_size; j++) {
           int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
-          if (output_tmp > max_int8) {
-            output[j] = max_int8;
-          } else if (output_tmp < min_int8) {
-            output[j] = min_int8;
-          } else {
-            output[j] = (int8_t)output_tmp;
-          }
+          output_tmp = output_tmp > min_int8 ? output_tmp : min_int8;
+          output_tmp = output_tmp < max_int8 ? output_tmp : max_int8;
+          output[j] = (int8_t)output_tmp;
         }
       }
       output += in_copy_size;
diff --git a/mindspore/lite/nnacl/int8/pooling_int8.c b/mindspore/lite/nnacl/int8/pooling_int8.c
index 57920879cda..383fe21ee4b 100644
--- a/mindspore/lite/nnacl/int8/pooling_int8.c
+++ b/mindspore/lite/nnacl/int8/pooling_int8.c
@@ -451,118 +451,83 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
   int output_batch = pooling_param->output_batch_;
   int out_plane = output_w * output_h;
   int out_tile_count = UP_DIV(out_plane, TILE_NUM);
-  int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
-  int c16 = channel / 16;
+  int thread_num = MSMIN(out_tile_count, pooling_param->thread_num_);
+  int8_t out_array[MAX_MAXPOOL_SIZE];  // running per-channel maxima for one chunk of channels
 
   for (int batch = 0; batch < output_batch; batch++) {
     int in_batch_offset = batch * in_h * in_w * channel;
     int out_batch_offset = batch * output_h * output_w * channel;
     for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
       int cal_start_index = thread_id * TILE_NUM;
-      int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
+      int real_cal_num = out_plane - cal_start_index;
+      real_cal_num = MSMIN(real_cal_num, TILE_NUM);
       for (int i = 0; i < real_cal_num; i++) {
         int index = cal_start_index + i;
         int out_w_index = index % output_w;
         int out_h_index = index / output_w;
         int in_w_index = out_w_index * stride_w - pad_w;
         int in_h_index = out_h_index * stride_h - pad_h;
+        int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index);
+        int ky_e = MSMIN(win_h, in_h - in_h_index);
+        int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index);
+        int kx_e = MSMIN(win_w, in_w - in_w_index);
+        int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset;
         int out_plane_offset = out_batch_offset + index * channel;
 
-        for (int j = 0; j < c16; j++) {
-          int in_channel_offset = in_batch_offset + j * 16;
-          int out_channel_offset = out_plane_offset + j * 16;
-#ifdef ENABLE_NEON
-          int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
-#else
-          int8_t tmp_max[16];
-          for (int m = 0; m < C16NUM; ++m) {
-            tmp_max[m] = INT8_MIN;
-          }
-#endif
-          for (int h = 0; h < win_h; h++) {
-            for (int w = 0; w < win_w; w++) {
-              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
-                  (in_w_index + w) >= in_w) {
-                continue;
-              } else {
-                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
-#ifdef ENABLE_NEON
-                tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
-#else
-                for (int k = 0; k < C16NUM; ++k) {
-                  tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
-                }
-#endif
-              }
-            }  // win_w loop
-          }    // win_h loop
-#ifdef ENABLE_NEON
-          vst1q_s8(output_ptr + out_channel_offset, tmp_max);
-#else
-          for (int l = 0; l < C16NUM; ++l) {
-            *(output_ptr + out_channel_offset + l) = tmp_max[l];
-          }
-#endif
-        }  // in_channel loop
-        // 8 channel
-        int tmp_c = c16 * 16;
-        int c8 = (channel - c16 * 16) / 8;
-        for (int k = 0; k < c8; k++) {
-          int in_channel_offset = in_batch_offset + tmp_c + k * 8;
-          int out_channel_offset = out_plane_offset + tmp_c + k * 8;
+        int c = 0;
+        for (; c < channel; c += MAX_MAXPOOL_SIZE) {
+          int real_channel = channel - c;
+          real_channel = MSMIN(real_channel, MAX_MAXPOOL_SIZE);
+          memset(out_array, INT8_MIN, real_channel);  // every byte becomes INT8_MIN, the identity for max
+          int8_t *out_data = output_ptr + out_plane_offset + c;
+          for (int h = ky_s; h < ky_e; ++h) {
+            int in_h_offset = input_stride + h * in_w * channel + c;
+            for (int w = kx_s; w < kx_e; ++w) {
+              const int8_t *in_data = input_ptr + in_h_offset + w * channel;
+              int j = 0;
 #ifdef ENABLE_NEON
-          int8x8_t tmp_max = vdup_n_s8(INT8_MIN);
-#else
-          int8_t tmp_max[8];
-          for (int m = 0; m < C8NUM; ++m) {
-            tmp_max[m] = INT8_MIN;
-          }
-#endif
-          for (int h = 0; h < win_h; h++) {
-            for (int w = 0; w < win_w; w++) {
-              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
-                  (in_w_index + w) >= in_w) {
-                continue;
-              } else {
-                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
-#ifdef ENABLE_NEON
-                tmp_max = vmax_s8(tmp_max, vld1_s8(input_ptr + in_offset));
-#else
-                for (int l = 0; l < C8NUM; ++l) {
-                  tmp_max[l] = MaxInt8(tmp_max[l], *(input_ptr + in_offset + l));
-                }
-#endif
-              }
-            }  // win_w loop
-          }    // win_h loop
-#ifdef ENABLE_NEON
-          vst1_s8(output_ptr + out_channel_offset, tmp_max);
-#else
-          for (int l = 0; l < C8NUM; ++l) {
-            *(output_ptr + out_channel_offset + l) = tmp_max[l];
-          }
-#endif
-        }  // 8 channel loop
+              int c16 = real_channel / 16 * 16;
+              int c8 = real_channel / 8 * 8;
+              for (; j < c16; j += 16) {
+                int8x16_t ori_in = vld1q_s8(in_data);
+                int8x16_t out_array16 = vld1q_s8(out_array + j);
+                in_data += 16;
+                out_array16 = vmaxq_s8(ori_in, out_array16);
+                vst1q_s8(out_array + j, out_array16);
+              }  // 16 channel loop
 
-        // res channel
-        int channel_s = c16 * 16 + c8 * 8;
-        for (int k = channel_s; k < channel; k++) {
-          int in_channel_offset = in_batch_offset + k;
-          int out_channel_offset = out_plane_offset + k;
-          int8_t tmp_max = INT8_MIN;
-          for (int h = 0; h < win_h; h++) {
-            for (int w = 0; w < win_w; w++) {
-              if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
-                  (in_w_index + w) >= in_w) {
-                continue;
-              } else {
-                int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
-                tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset));
+              for (; j < c8; j += 8) {
+                int8x8_t ori_in = vld1_s8(in_data);
+                int8x8_t out_array8 = vld1_s8(out_array + j);
+                in_data += 8;
+                out_array8 = vmax_s8(ori_in, out_array8);
+                vst1_s8(out_array + j, out_array8);
+              }  // 8 channel loop
+#endif
+              for (; j < real_channel; ++j, ++in_data) {  // in_data tracks channel j (NEON loops advanced it)
+                out_array[j] = out_array[j] > *in_data ? out_array[j] : *in_data;
               }
-            }  // win_w loop
-          }    // win_h loop
-          *(output_ptr + out_channel_offset) = tmp_max;
-        }  // channel_res loop
+            }  // kw loop
+          }    // kh loop
+
+          int j = 0;
+#ifdef ENABLE_NEON
+          int c16 = real_channel / 16 * 16;
+          int c8 = real_channel / 8 * 8;
+          for (; j < c16; j += 16) {
+            vst1q_s8(out_data, vld1q_s8(out_array + j));
+            out_data += 16;
+          }  // 16 channel loop
+
+          for (; j < c8; j += 8) {
+            vst1_s8(out_data, vld1_s8(out_array + j));
+            out_data += 8;
+          }  // 8 channel loop
+#endif
+          for (; j < real_channel; ++j, ++out_data) {  // out_data tracks channel j (NEON loops advanced it)
+            *out_data = out_array[j];
+          }
+        }  // 256 channel loop
       }    // out_plane loop
     }      // out_batch loop
   }
diff --git a/mindspore/lite/nnacl/int8/pooling_int8.h b/mindspore/lite/nnacl/int8/pooling_int8.h
index 065de17645d..7bf013dcdb8 100644
--- a/mindspore/lite/nnacl/int8/pooling_int8.h
+++ b/mindspore/lite/nnacl/int8/pooling_int8.h
@@ -26,6 +26,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+#define MAX_MAXPOOL_SIZE 256
+
 int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
 
 int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);