avx512 instruction optimal

This commit is contained in:
greatpan 2022-06-09 15:18:29 +08:00
parent c1f8bdb59f
commit 2f6615f87c
10 changed files with 19 additions and 27 deletions

View File

@ -160,7 +160,7 @@ if("${X86_64_SIMD}" STREQUAL "avx512")
${NNACL_DIR}/fp32/matmul_avx512_fp32.c) ${NNACL_DIR}/fp32/matmul_avx512_fp32.c)
set_source_files_properties(${MS_X86_AVX512_SRC} PROPERTIES LANGUAGE C set_source_files_properties(${MS_X86_AVX512_SRC} PROPERTIES LANGUAGE C
COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx512f") COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fPIC")
endif() endif()

View File

@ -25,7 +25,7 @@ extern "C" {
@SIMD_INSTRUCTION_BEGIN@ @SIMD_INSTRUCTION_BEGIN@
static inline int Fp32Relu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { static inline int Fp32Relu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
SIMD_F32 zero = SIMD_MOV_F32(0.0f); SIMD_F32 zero = SIMD_SET0_F32;
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_ST_F32(dst + index, SIMD_MAX_F32(SIMD_LD_F32(src + index), zero)); SIMD_ST_F32(dst + index, SIMD_MAX_F32(SIMD_LD_F32(src + index), zero));
} }
@ -41,7 +41,7 @@ static inline int Int32Relu@SIMD_INSTRUCTION@(int index, const int32_t *src, int
} }
static inline int Fp32Relu6@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { static inline int Fp32Relu6@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
SIMD_F32 zero = SIMD_MOV_F32(0.0f); SIMD_F32 zero = SIMD_SET0_F32;
SIMD_F32 six = SIMD_MOV_F32(6.0f); SIMD_F32 six = SIMD_MOV_F32(6.0f);
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), zero, six)); SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), zero, six));
@ -53,7 +53,7 @@ static inline int LRelu@SIMD_INSTRUCTION@(int index, const float *src, int lengt
SIMD_F32 alpha_data = SIMD_MOV_F32(alpha); SIMD_F32 alpha_data = SIMD_MOV_F32(alpha);
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 src_tmp = SIMD_LD_F32(src + index); SIMD_F32 src_tmp = SIMD_LD_F32(src + index);
SIMD_MASK mask = SIMD_CMPGT_F32(SIMD_MOV_F32(0.0f), src_tmp); SIMD_MASK mask = SIMD_CMPGT_F32(SIMD_SET0_F32, src_tmp);
SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_F32(src_tmp, alpha_data), mask)); SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_F32(src_tmp, alpha_data), mask));
} }
return index; return index;
@ -61,7 +61,7 @@ static inline int LRelu@SIMD_INSTRUCTION@(int index, const float *src, int lengt
static inline int Sigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { static inline int Sigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) {
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_MOV_F32(0.0f), (SIMD_LD_F32(src + index))), dst + index); SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, (SIMD_LD_F32(src + index))), dst + index);
SIMD_ST_F32(dst + index, SIMD_ST_F32(dst + index,
SIMD_DIV_F32(SIMD_MOV_F32(1.0f), SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); SIMD_DIV_F32(SIMD_MOV_F32(1.0f), SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index))));
} }
@ -142,7 +142,7 @@ static inline int Elu@SIMD_INSTRUCTION@(int index, const float *src, int length,
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 src_tmp = SIMD_LD_F32(src + index); SIMD_F32 src_tmp = SIMD_LD_F32(src + index);
SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(src_tmp), 1.0f); SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(src_tmp), 1.0f);
SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_MOV_F32(0.0f)); SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32);
SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask));
} }
return index; return index;
@ -152,7 +152,7 @@ static inline int Celu@SIMD_INSTRUCTION@(int index, const float *src, int length
for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 src_tmp = SIMD_LD_F32(src + index); SIMD_F32 src_tmp = SIMD_LD_F32(src + index);
SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(SIMD_DIV_N_F32(src_tmp, alpha)), 1.0f); SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(SIMD_DIV_N_F32(src_tmp, alpha)), 1.0f);
SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_MOV_F32(0.0f)); SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32);
SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask));
} }
return index; return index;

View File

@ -26,8 +26,8 @@ extern "C" {
static inline int LayerNormMeanAndSquare@SIMD_INSTRUCTION@(int index, const float *src, int num, float *mean, float *square_mean) { static inline int LayerNormMeanAndSquare@SIMD_INSTRUCTION@(int index, const float *src, int num, float *mean, float *square_mean) {
if (num >= 4 * BLOCK_NUM) { if (num >= 4 * BLOCK_NUM) {
SIMD_F32 sum_val = SIMD_MOV_F32(0.0f); SIMD_F32 sum_val = SIMD_SET0_F32;
SIMD_F32 square_sum_val = SIMD_MOV_F32(0.0f); SIMD_F32 square_sum_val = SIMD_SET0_F32;
for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 value = SIMD_LD_F32(src + index); SIMD_F32 value = SIMD_LD_F32(src + index);
SIMD_F32 square_value = SIMD_MUL_F32(value, value); SIMD_F32 square_value = SIMD_MUL_F32(value, value);

View File

@ -32,7 +32,7 @@ static inline int AvgPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_pla
for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) {
const float *src_c_ptr = src_plane_ptr + ci; const float *src_c_ptr = src_plane_ptr + ci;
float *dst_c_ptr = dst_plane_ptr + ci; float *dst_c_ptr = dst_plane_ptr + ci;
SIMD_F32 tmp_avg = SIMD_MOV_F32(0.0f); SIMD_F32 tmp_avg = SIMD_SET0_F32;
int real_count = 0; int real_count = 0;
for (int h = real_win_h_start; h < real_win_h_end; h++) { for (int h = real_win_h_start; h < real_win_h_end; h++) {
for (int w = real_win_w_start; w < real_win_w_end; w++) { for (int w = real_win_w_start; w < real_win_w_end; w++) {

View File

@ -48,7 +48,7 @@ static inline int64_t SoftmaxNormCalcNorm@SIMD_INSTRUCTION@(int64_t index, const
static inline int64_t SoftmaxLastAxisGetExpSum@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, static inline int64_t SoftmaxLastAxisGetExpSum@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst,
int cur_batch_offset, float max, float *exp_sum, int channel) { int cur_batch_offset, float max, float *exp_sum, int channel) {
SIMD_F32 sum_val = SIMD_MOV_F32(0.0f); SIMD_F32 sum_val = SIMD_SET0_F32;
for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index);
SIMD_F32 output = SIMD_SUB_F32(input, SIMD_MOV_F32(max)); SIMD_F32 output = SIMD_SUB_F32(input, SIMD_MOV_F32(max));

View File

@ -42,6 +42,7 @@
#define MS_ADD512_EPI32 _mm512_add_epi32 #define MS_ADD512_EPI32 _mm512_add_epi32
#define MS_MOV512_F32 _mm512_set1_ps #define MS_MOV512_F32 _mm512_set1_ps
#define MS_MOV512_EPI32 _mm512_set1_epi32 #define MS_MOV512_EPI32 _mm512_set1_epi32
#define MS_MOV512_VAL0_F32 _mm512_setzero_ps()
#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1) #define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1)
#define MS_ST512_F32 _mm512_storeu_ps #define MS_ST512_F32 _mm512_storeu_ps
#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2) #define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2)
@ -82,22 +83,9 @@
#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2) #define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2)
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) #define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src) #define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src)
#define MS_GET_MAX512_F32(src) _mm512_reduce_max_ps(src)
static inline float MS_GET_MAX512_F32(__m512 src) { #define MS_GET_MIN512_F32(src) _mm512_reduce_min_ps(src)
float result = MS_F32X16_GETI(src, 0); #define MS_GET_SUM512_F32(src) _mm512_reduce_add_ps(src)
for (int i = 1; i < 16; i++) { // avx512 block num : 16
result = fmaxf(result, MS_F32X16_GETI(src, i));
}
return result;
}
static inline float MS_GET_SUM512_F32(__m512 src) {
float result = MS_F32X16_GETI(src, 0);
for (int i = 1; i < 16; i++) { // avx512 block num : 16
result = result + MS_F32X16_GETI(src, i);
}
return result;
}
#define MS_DIV512_EPI32(src1, src2) \ #define MS_DIV512_EPI32(src1, src2) \
_mm512_cvttps_epi32(MS_DIV512_F32(_mm512_cvtepi32_ps(src1), _mm512_cvtepi32_ps(src2))) _mm512_cvttps_epi32(MS_DIV512_F32(_mm512_cvtepi32_ps(src1), _mm512_cvtepi32_ps(src2)))

View File

@ -37,6 +37,7 @@
#define MS_ADD256_EPI32 _mm256_add_epi32 #define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps #define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32 #define MS_MOV256_EPI32 _mm256_set1_epi32
#define MS_MOV256_VAL0_F32 _mm256_setzero_ps()
#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1) #define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1)
#define MS_ST256_F32 _mm256_storeu_ps #define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) #define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)

View File

@ -51,6 +51,7 @@
// move (float/int) data // move (float/int) data
#define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV) #define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV)
#define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV) #define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV)
#define SIMD_SET0_F32 MS_SIMD_INSTRUCTION(MS_MOV, _VAL0_F32)
// load (float/int) data // load (float/int) data
#define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD) #define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD)

View File

@ -39,6 +39,7 @@
#define MS_MOVQ_F32 vmovq_n_f32 #define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOV128_F32 vmovq_n_f32 #define MS_MOV128_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32 #define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_MOV128_VAL0_F32 vmovq_n_f32(0.0f)
#define MS_MOV128_EPI32 vmovq_n_s32 #define MS_MOV128_EPI32 vmovq_n_s32
#define MS_SUBQ_F32 vsubq_f32 #define MS_SUBQ_F32 vsubq_f32
#define MS_SUB128_F32 vsubq_f32 #define MS_SUB128_F32 vsubq_f32

View File

@ -43,6 +43,7 @@
#define MS_MOV128_F32 _mm_set1_ps #define MS_MOV128_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32 #define MS_MOVQ_EPI32 _mm_set1_epi32
#define MS_MOV128_EPI32 _mm_set1_epi32 #define MS_MOV128_EPI32 _mm_set1_epi32
#define MS_MOV128_VAL0_F32 _mm_setzero_ps()
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps #define MS_STQ_F32 _mm_storeu_ps
#define MS_ST128_F32 _mm_storeu_ps #define MS_ST128_F32 _mm_storeu_ps