forked from mindspore-Ecosystem/mindspore
!30356 [MSLITE][CPU] AVX512/256/SSE/NEON Advanced packaging, and Pool Op Refactoring and optimization
Merge pull request !30356 from Greatpan/avx512_pooling
This commit is contained in:
commit
26bf47487c
|
@ -19,6 +19,33 @@
|
|||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
// Vectorized average-pooling step along the channel axis for 32-bit floats.
// block_size : (512/256/128/32), block_num : (16/8/4/1) -- presumably the vector width in bits and the
// number of float lanes per vector; confirm against the MS_*_F32 wrapper definitions.
// For each group of block_num channels starting at ci (while ci < channel - block_num + 1) the macro
// sums the inputs over h in [real_win_h_start, real_win_h_end) and w in [real_win_w_start, real_win_w_end),
// divides the sum by the number of window elements visited, applies MS_MAX_F32 against minf and
// MS_MIN_F32 against maxf, and stores the result at dst_plane_ptr + ci.
// Notes for callers:
//  - minf and maxf are not macro parameters; they are taken from the scope where the macro is expanded.
//  - ci is advanced in place by block_num per iteration, so any channels left over remain for the caller.
//  - MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR) sits inside the channel loop; presumably it returns
//    NNACL_ERR from the enclosing function when the window is empty (defined elsewhere -- verify).
|
||||
#define SimdFp32AvgPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci, \
|
||||
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, \
|
||||
in_h_index, in_w, in_w_index) \
|
||||
do { \
|
||||
MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf); \
|
||||
MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf); \
|
||||
for (int block_max_size = channel - block_num + 1; ci < block_max_size; ci += block_num) { \
|
||||
const float *src_c_ptr = src_plane_ptr + ci; \
|
||||
float *dst_c_ptr = dst_plane_ptr + ci; \
|
||||
MS_FLOAT_32xN(block_num) tmp_avg = MS_MOVN_F32(block_size, 0.0f); \
|
||||
int real_count = 0; \
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) { \
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) { \
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; \
|
||||
tmp_avg = MS_ADD_F32(block_size, tmp_avg, MS_LD_F32(block_size, src_win_ptr)); \
|
||||
++real_count; \
|
||||
} \
|
||||
} \
|
||||
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); \
|
||||
tmp_avg = MS_DIV_F32(block_size, tmp_avg, MS_MOVN_F32(block_size, real_count)); \
|
||||
tmp_avg = MS_MAX_F32(block_size, tmp_avg, min_val_##block_num); \
|
||||
tmp_avg = MS_MIN_F32(block_size, tmp_avg, max_val_##block_num); \
|
||||
MS_ST_F32(block_size, dst_c_ptr, tmp_avg); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf) {
|
||||
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
|
||||
|
@ -28,16 +55,6 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
|
|||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
|
||||
#ifdef ENABLE_AVX
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
|
||||
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
int c4 = channel / C4NUM * C4NUM;
|
||||
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
|
||||
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
|
||||
#endif
|
||||
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
|
@ -57,46 +74,11 @@ int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
|
|||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
int ci = 0;
|
||||
#ifdef ENABLE_AVX
|
||||
for (; ci < c8; ci += C8NUM) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0);
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(src_win_ptr));
|
||||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
|
||||
tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count));
|
||||
tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8);
|
||||
tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8);
|
||||
MS_ST256_F32(dst_c_ptr, tmp_avg);
|
||||
} // ic8-1 loop
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
for (; ci < c4; ci += C4NUM) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0);
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(src_win_ptr));
|
||||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
|
||||
tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count));
|
||||
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value);
|
||||
tmp_avg = MS_MINQ_F32(tmp_avg, max_value);
|
||||
MS_STQ_F32(dst_c_ptr, tmp_avg);
|
||||
} // ic4-1 loop
|
||||
#endif
|
||||
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdFp32AvgPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci,
|
||||
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w,
|
||||
in_w_index);
|
||||
|
||||
for (; ci < channel; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
|
@ -377,6 +359,29 @@ int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// Vectorized max-pooling step along the channel axis for 32-bit floats.
// block_size : (512/256/128/32), block_num : (16/8/4/1) -- presumably the vector width in bits and the
// number of float lanes per vector; confirm against the MS_*_F32 wrapper definitions.
// For each group of block_num channels starting at ci (while ci < channel - block_num + 1) the macro
// starts from -FLT_MAX, folds MS_MAX_F32 over the inputs for kh in [real_win_h_start, real_win_h_end)
// and kw in [real_win_w_start, real_win_w_end), then applies MS_MAX_F32 against minf and MS_MIN_F32
// against maxf, and stores the result at dst_plane_ptr + ci.
// Notes for callers:
//  - minf and maxf are not macro parameters; they are taken from the scope where the macro is expanded.
//  - ci is advanced in place by block_num per iteration, so any channels left over remain for the caller.
//  - There is no empty-window check here: if neither window loop runs, the stored value is -FLT_MAX
//    after the minf/maxf bounds are applied.
|
||||
#define SimdFp32MaxPoolingBatchCoreCalc(block_size, block_num, src_plane_ptr, channel, dst_plane_ptr, ci, \
|
||||
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, \
|
||||
in_h_index, in_w, in_w_index) \
|
||||
do { \
|
||||
MS_FLOAT_32xN(block_num) min_val_##block_num = MS_MOVN_F32(block_size, minf); \
|
||||
MS_FLOAT_32xN(block_num) max_val_##block_num = MS_MOVN_F32(block_size, maxf); \
|
||||
for (int block_max_size = channel - block_num + 1; ci < block_max_size; ci += block_num) { \
|
||||
const float *src_c_ptr = src_plane_ptr + ci; \
|
||||
float *dst_c_ptr = dst_plane_ptr + ci; \
|
||||
MS_FLOAT_32xN(block_num) tmp_max = MS_MOVN_F32(block_size, -FLT_MAX); \
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { \
|
||||
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { \
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; \
|
||||
tmp_max = MS_MAX_F32(block_size, tmp_max, MS_LD_F32(block_size, src_win_ptr)); \
|
||||
} \
|
||||
} \
|
||||
tmp_max = MS_MAX_F32(block_size, tmp_max, min_val_##block_num); \
|
||||
tmp_max = MS_MIN_F32(block_size, tmp_max, max_val_##block_num); \
|
||||
MS_ST_F32(block_size, dst_c_ptr, tmp_max); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
|
@ -386,16 +391,6 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
|
|||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
|
||||
#ifdef ENABLE_AVX
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
|
||||
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
int c4 = channel / C4NUM * C4NUM;
|
||||
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
|
||||
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
|
||||
#endif
|
||||
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
|
@ -415,38 +410,11 @@ int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParam
|
|||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
int ci = 0;
|
||||
#ifdef ENABLE_AVX
|
||||
for (; ci < c8; ci += C8NUM) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX);
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(src_win_ptr));
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
tmp_max = MS_MAX256_F32(tmp_max, min_value_8);
|
||||
tmp_max = MS_MIN256_F32(tmp_max, max_value_8);
|
||||
MS_ST256_F32(dst_c_ptr, tmp_max);
|
||||
} // ic8 loop
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
for (; ci < c4; ci += C4NUM) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX);
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(src_win_ptr));
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
tmp_max = MS_MAXQ_F32(tmp_max, min_value);
|
||||
tmp_max = MS_MINQ_F32(tmp_max, max_value);
|
||||
MS_STQ_F32(dst_c_ptr, tmp_max);
|
||||
} // ic4 loop
|
||||
#endif
|
||||
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdFp32MaxPoolingBatchCoreCalc, src_plane_ptr, channel, dst_plane_ptr, ci,
|
||||
real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w,
|
||||
in_w_index);
|
||||
|
||||
for (; ci < channel; ci++) {
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
|
|
Loading…
Reference in New Issue