avgpooling optimization from nc4hwc -> nhwc

This commit is contained in:
greatpanc 2021-12-15 10:13:38 +08:00
parent 7295eca19d
commit 0baf9d53b6
3 changed files with 290 additions and 27 deletions

View File

@ -140,6 +140,243 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
return NNACL_OK;
}
int AvgPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
int output_w = pooling_param->output_w_, output_h = pooling_param->output_h_;
int channel = pooling_param->input_channel_;
int out_plane = output_w * output_h;
int in_plane = in_w * in_h;
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
#ifdef ENABLE_AVX
const int c_tile = C8NUM;
const int once_calc_num = 2;
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
const int c_tile = C4NUM;
const int once_calc_num = 1;
#else
const int c_tile = 1;
const int once_calc_num = 1;
#endif
const int c_xtile = once_calc_num * c_tile;
int cur_c = (channel / c_xtile) * c_xtile;
int last_c_size = channel - cur_c;
int less_out_plane = out_plane * last_c_size;
int calc_tile = UP_DIV(less_out_plane, pooling_param->thread_num_);
int index_begin = task_id * calc_tile;
int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane;
int c_start = (index_begin / out_plane) + cur_c;
int index_less = index_begin % out_plane;
int h_start = index_less / output_h;
int w_start = index_less % output_h;
int c_end = (index_end / out_plane) + cur_c;
index_less = index_end % out_plane;
int h_end = index_less / output_h;
int w_end = index_less % output_h;
int c = c_start;
int h = h_start;
int w = w_start;
int in_w_cx_line = in_w * last_c_size;
const float *src_c_ptr = src_b_ptr + c * in_plane;
for (; c < channel; c += c_xtile) {
for (; h < output_h; h++) {
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
for (; w < output_w; w++) {
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
float tmp_avg = 0.0;
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start);
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c);
tmp_avg += cur_input_index[0];
}
}
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
tmp_avg = tmp_avg / (float)real_count;
tmp_avg = fminf(tmp_avg, maxf);
dst_c_ptr[0] = tmp_avg;
}
w = 0;
}
h = 0;
}
return NNACL_OK;
}
int AvgPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
int output_w = pooling_param->output_w_, output_h = pooling_param->output_h_;
int channel = pooling_param->input_channel_;
int out_plane = output_w * output_h;
int in_plane = in_w * in_h;
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
#ifdef ENABLE_AVX
const int c_tile = C8NUM;
const int once_calc_num = 2;
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
const int c_tile = C4NUM;
const int once_calc_num = 1;
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
#else
const int c_tile = 1;
const int once_calc_num = 1;
#endif
int in_w_cx_line = in_w * c_tile;
const int c_xtile = once_calc_num * c_tile;
int c_tile_num = channel / c_xtile;
int all_out_plane = out_plane * c_tile_num;
int calc_tile = UP_DIV(all_out_plane, pooling_param->thread_num_);
int index_begin = task_id * calc_tile;
int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane;
int c_start = (index_begin / out_plane) * c_xtile;
int index_less = index_begin % out_plane;
int h_start = index_less / output_h;
int w_start = index_less % output_h;
int c_end = (index_end / out_plane) * c_xtile;
index_less = index_end % out_plane;
int h_end = index_less / output_h;
int w_end = index_less % output_h;
int c = c_start;
int h = h_start;
int w = w_start;
for (; c < channel; c += c_xtile) {
const float *src_c_ptr = src_b_ptr + c * in_plane;
for (; h < output_h; h++) {
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
for (; w < output_w; w++) {
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
#ifdef ENABLE_AVX
MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0);
MS_FLOAT32X8 tmp_avg2 = MS_MOV256_F32(0);
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0);
#else
float tmp_avg = 0;
#endif
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start);
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile;
#ifdef ENABLE_AVX
tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(cur_input_index));
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(cur_input_index));
#else
tmp_avg += cur_input_index[0];
#endif
}
#ifdef ENABLE_AVX
const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane;
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile;
tmp_avg2 = MS_ADD256_F32(tmp_avg2, MS_LD256_F32(cur_input_index));
}
#endif
}
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
#ifdef ENABLE_AVX
float *dst_c2_ptr = dst_c_ptr + c_tile;
tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count));
tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8);
tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8);
MS_ST256_F32(dst_c_ptr, tmp_avg);
tmp_avg2 = MS_DIV256_F32(tmp_avg2, MS_MOV256_F32(real_count));
tmp_avg2 = MS_MAX256_F32(tmp_avg2, min_value_8);
tmp_avg2 = MS_MIN256_F32(tmp_avg2, max_value_8);
MS_ST256_F32(dst_c2_ptr, tmp_avg2);
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count));
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value);
tmp_avg = MS_MINQ_F32(tmp_avg, max_value);
MS_STQ_F32(dst_c_ptr, tmp_avg);
#else
tmp_avg = tmp_avg / (float)real_count;
tmp_avg = fmaxf(tmp_avg, minf);
tmp_avg = fminf(tmp_avg, maxf);
dst_c_ptr[0] = tmp_avg;
#endif
}
w = 0;
}
h = 0;
}
return NNACL_OK;
}
int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int channel = pooling_param->input_channel_;
int output_batch = pooling_param->output_batch_;
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
int ret = AvgPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
if (ret != NNACL_OK) {
return ret;
}
ret = AvgPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
if (ret != NNACL_OK) {
return ret;
}
}
return NNACL_OK;
}
int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
@ -249,7 +486,7 @@ int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
return NNACL_OK;
}
int MaxPoolingFormNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int MaxPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
@ -276,36 +513,61 @@ int MaxPoolingFormNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, co
int cur_c = (channel / c_xtile) * c_xtile;
int last_c_size = channel - cur_c;
int calc_tile = UP_DIV(out_plane, pooling_param->thread_num_);
int less_out_plane = out_plane * last_c_size;
int calc_tile = UP_DIV(less_out_plane, pooling_param->thread_num_);
int index_begin = task_id * calc_tile;
int index_end = (index_begin + calc_tile) < out_plane ? (index_begin + calc_tile) : out_plane;
int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane;
for (int c = cur_c; c < channel; c++) {
for (int index = index_begin; index < index_end; index++) {
const float *src_c_ptr = src_b_ptr + c * in_plane;
int h = index / output_h;
int w = index % output_h;
int c_start = (index_begin / out_plane) + cur_c;
int index_less = index_begin % out_plane;
int h_start = index_less / output_h;
int w_start = index_less % output_h;
float tmp_max = -FLT_MAX;
for (int kh = 0; kh < win_h; kh++) {
for (int kw = 0; kw < win_w; kw++) {
const float *src_win_ptr = src_c_ptr + (kh + win_h * h) * in_w * last_c_size + (kw + win_w * w) * last_c_size;
tmp_max = fmaxf(tmp_max, src_win_ptr[0]);
int c_end = (index_end / out_plane) + cur_c;
index_less = index_end % out_plane;
int h_end = index_less / output_h;
int w_end = index_less % output_h;
int c = c_start;
int h = h_start;
int w = w_start;
int in_w_cx_line = in_w * last_c_size;
const float *src_c_ptr = src_b_ptr + cur_c * in_plane;
for (; c < channel; c++) {
for (; h < output_h; h++) {
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
for (; w < output_w; w++) {
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
float tmp_max = -FLT_MAX;
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c);
tmp_max = fmaxf(tmp_max, cur_input_index[0]);
}
}
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
tmp_max = fmaxf(tmp_max, minf);
tmp_max = fminf(tmp_max, maxf);
dst_c_ptr[0] = tmp_max;
}
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
tmp_max = fmaxf(tmp_max, minf);
tmp_max = fminf(tmp_max, maxf);
dst_c_ptr[0] = tmp_max;
w = 0;
}
h = 0;
}
return NNACL_OK;
}
int MaxPoolingFormNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int MaxPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
@ -426,7 +688,7 @@ int MaxPoolingFormNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, co
return NNACL_OK;
}
int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf) {
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
@ -438,12 +700,12 @@ int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
int ret = MaxPoolingFormNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
int ret = MaxPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
if (ret != NNACL_OK) {
return ret;
}
ret = MaxPoolingFormNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
ret = MaxPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
if (ret != NNACL_OK) {
return ret;
}

View File

@ -29,9 +29,11 @@ extern "C" {
#endif
int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf);
int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf);
int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf);
int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
int task_id, float minf, float maxf);
#ifdef __cplusplus
}

View File

@ -67,10 +67,9 @@ int PoolingCPUKernel::RunImpl(int task_id) const {
if (in_tensors_[0]->format() == NC4HW4) {
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
ret = MaxPoolingFormNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
ret = MaxPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
} else {
// ret = AvgPoolingFormNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
MS_LOG(ERROR) << "Do not support NC4HW4 AvgPooling input format.";
ret = AvgPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
}
} else if (in_tensors_[0]->format() == NHWC) {
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {