!30356 [MSLITE][CPU] AVX512/256/SSE/NEON Advanced packaging, and Pool Op Refactoring and optimization

Merge pull request !30356 from Greatpan/avx512_pooling
This commit is contained in:
i-robot 2022-02-22 02:05:35 +00:00 committed by Gitee
commit 26bf47487c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 60 additions and 92 deletions

View File

@ -19,6 +19,33 @@
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
/*
 * SimdFp32AvgPoolingBatchCoreCalc: average-pooling channel loop for one output position.
 *
 * 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
 *
 * Starting at channel index `ci`, handles `block_num` channels per iteration while a full
 * block still fits in `channel`. For each block it sums the inputs over the window
 * h in [real_win_h_start, real_win_h_end), w in [real_win_w_start, real_win_w_end), read at
 * src_plane_ptr + ci + ((in_h_index + h) * in_w + in_w_index + w) * channel, divides the sum
 * by the number of window taps visited, clamps the result to [minf, maxf] and stores it at
 * dst_plane_ptr + ci.
 *
 * Notes for callers:
 *  - `ci` must be a modifiable int lvalue; it is advanced in place and is left at the first
 *    channel index that was not processed, so a narrower or scalar loop can continue from it.
 *  - `minf` and `maxf` are NOT parameters: they are taken from the scope where the macro is
 *    expanded and must be visible there.
 *  - An empty window (zero taps) goes through MS_CHECK_TRUE_RET(..., NNACL_ERR), which is
 *    defined elsewhere and presumably returns NNACL_ERR from the enclosing function -
 *    TODO confirm against its definition.
 *  - `block_size` and `block_num` are forwarded to the MS_* helper macros and token-pasted, so
 *    they must be plain literal tokens. Every other argument is parenthesized on use and may
 *    be an arbitrary expression.
 */
#define SimdFp32AvgPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci,       \
                                        real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end,     \
                                        in_h_index, in_w, in_w_index)                                           \
  do {                                                                                                          \
    /* Clamp bounds broadcast to every lane; minf/maxf come from the expansion site. */                         \
    MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf);                               \
    MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf);                               \
    for (int block_max_size = (channel) - (block_num) + 1; (ci) < block_max_size; (ci) += (block_num)) {        \
      const float *src_c_ptr = (src_plane_ptr) + (ci);                                                          \
      float *dst_c_ptr = (dst_plane_ptr) + (ci);                                                                \
      MS_FLOAT_32xN(block_num) tmp_avg = MS_MOVN_F32(block_size, 0.0f);                                         \
      int real_count = 0;                                                                                       \
      for (int h = (real_win_h_start); h < (real_win_h_end); h++) {                                             \
        for (int w = (real_win_w_start); w < (real_win_w_end); w++) {                                           \
          const float *src_win_ptr =                                                                            \
            src_c_ptr + (((in_h_index) + h) * (in_w) + (in_w_index) + w) * (channel);                           \
          tmp_avg = MS_ADD_F32(block_size, tmp_avg, MS_LD_F32(block_size, src_win_ptr));                        \
          ++real_count;                                                                                         \
        }                                                                                                       \
      }                                                                                                         \
      /* Guard the division below against an empty window. */                                                   \
      MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);                                                            \
      tmp_avg = MS_DIV_F32(block_size, tmp_avg, MS_MOVN_F32(block_size, (float)real_count));                    \
      tmp_avg = MS_MAX_F32(block_size, tmp_avg, min_val_##block_num);                                           \
      tmp_avg = MS_MIN_F32(block_size, tmp_avg, max_val_##block_num);                                           \
      MS_ST_F32(block_size, dst_c_ptr, tmp_avg);                                                                \
    }                                                                                                           \
  } while (0)
int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf) {
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
@ -28,16 +55,6 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
#ifdef ENABLE_AVX
int c8 = channel / C8NUM * C8NUM;
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
int c4 = channel / C4NUM * C4NUM;
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
#endif
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
int cal_start_index = thread_id * TILE_NUM;
@ -57,46 +74,11 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
int real_win_w_start = MSMAX(0, -in_w_index);
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
int ci = 0;
#ifdef ENABLE_AVX
for (; ci < c8; ci += C8NUM) {
const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci;
MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0);
int real_count = 0;
for (int h = real_win_h_start; h < real_win_h_end; h++) {
for (int w = real_win_w_start; w < real_win_w_end; w++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(src_win_ptr));
++real_count;
} // win_w loop
} // win_h loop
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count));
tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8);
tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8);
MS_ST256_F32(dst_c_ptr, tmp_avg);
} // ic8-1 loop
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; ci < c4; ci += C4NUM) {
const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci;
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0);
int real_count = 0;
for (int h = real_win_h_start; h < real_win_h_end; h++) {
for (int w = real_win_w_start; w < real_win_w_end; w++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(src_win_ptr));
++real_count;
} // win_w loop
} // win_h loop
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count));
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value);
tmp_avg = MS_MINQ_F32(tmp_avg, max_value);
MS_STQ_F32(dst_c_ptr, tmp_avg);
} // ic4-1 loop
#endif
MS_SIMD_RUN_NO_SCALAR(SimdFp32AvgPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci,
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w,
in_w_index);
for (; ci < channel; ci++) {
const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci;
@ -377,6 +359,29 @@ int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
return NNACL_OK;
}
/*
 * SimdFp32MaxPoolingBatchCoreCalc: max-pooling channel loop for one output position.
 *
 * 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
 *
 * Starting at channel index `ci`, handles `block_num` channels per iteration while a full
 * block still fits in `channel`. For each block it takes the lane-wise maximum, seeded with
 * -FLT_MAX, of the inputs over the window kh in [real_win_h_start, real_win_h_end),
 * kw in [real_win_w_start, real_win_w_end), read at
 * src_plane_ptr + ci + ((in_h_index + kh) * in_w + in_w_index + kw) * channel, clamps the
 * result to [minf, maxf] and stores it at dst_plane_ptr + ci.
 *
 * Notes for callers:
 *  - `ci` must be a modifiable int lvalue; it is advanced in place and is left at the first
 *    channel index that was not processed, so a narrower or scalar loop can continue from it.
 *  - `minf` and `maxf` are NOT parameters: they are taken from the scope where the macro is
 *    expanded and must be visible there.
 *  - With an empty window the stored value is -FLT_MAX clamped to [minf, maxf].
 *  - `block_size` and `block_num` are forwarded to the MS_* helper macros and token-pasted, so
 *    they must be plain literal tokens. Every other argument is parenthesized on use and may
 *    be an arbitrary expression.
 */
#define SimdFp32MaxPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci,       \
                                        real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end,     \
                                        in_h_index, in_w, in_w_index)                                           \
  do {                                                                                                          \
    /* Clamp bounds broadcast to every lane; minf/maxf come from the expansion site. */                         \
    MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf);                               \
    MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf);                               \
    for (int block_max_size = (channel) - (block_num) + 1; (ci) < block_max_size; (ci) += (block_num)) {        \
      const float *src_c_ptr = (src_plane_ptr) + (ci);                                                          \
      float *dst_c_ptr = (dst_plane_ptr) + (ci);                                                                \
      MS_FLOAT_32xN(block_num) tmp_max = MS_MOVN_F32(block_size, -FLT_MAX);                                     \
      for (int kh = (real_win_h_start); kh < (real_win_h_end); kh++) {                                          \
        for (int kw = (real_win_w_start); kw < (real_win_w_end); kw++) {                                        \
          const float *src_win_ptr =                                                                            \
            src_c_ptr + (((in_h_index) + kh) * (in_w) + (in_w_index) + kw) * (channel);                         \
          tmp_max = MS_MAX_F32(block_size, tmp_max, MS_LD_F32(block_size, src_win_ptr));                        \
        }                                                                                                       \
      }                                                                                                         \
      tmp_max = MS_MAX_F32(block_size, tmp_max, min_val_##block_num);                                           \
      tmp_max = MS_MIN_F32(block_size, tmp_max, max_val_##block_num);                                           \
      MS_ST_F32(block_size, dst_c_ptr, tmp_max);                                                                \
    }                                                                                                           \
  } while (0)
int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
float minf, float maxf) {
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
@ -386,16 +391,6 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
#ifdef ENABLE_AVX
int c8 = channel / C8NUM * C8NUM;
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
int c4 = channel / C4NUM * C4NUM;
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
#endif
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
int cal_start_index = thread_id * TILE_NUM;
@ -415,38 +410,11 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
int real_win_w_start = MSMAX(0, -in_w_index);
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
int ci = 0;
#ifdef ENABLE_AVX
for (; ci < c8; ci += C8NUM) {
const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci;
MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX);
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(src_win_ptr));
} // win_w loop
} // win_h loop
tmp_max = MS_MAX256_F32(tmp_max, min_value_8);
tmp_max = MS_MIN256_F32(tmp_max, max_value_8);
MS_ST256_F32(dst_c_ptr, tmp_max);
} // ic8 loop
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
for (; ci < c4; ci += C4NUM) {
const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci;
MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX);
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(src_win_ptr));
} // win_w loop
} // win_h loop
tmp_max = MS_MAXQ_F32(tmp_max, min_value);
tmp_max = MS_MINQ_F32(tmp_max, max_value);
MS_STQ_F32(dst_c_ptr, tmp_max);
} // ic4 loop
#endif
MS_SIMD_RUN_NO_SCALAR(SimdFp32MaxPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci,
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w,
in_w_index);
for (; ci < channel; ci++) {
float *dst_c_ptr = dst_plane_ptr + ci;
const float *src_c_ptr = src_plane_ptr + ci;