forked from mindspore-Ecosystem/mindspore
add avx int8 add
parent 52953f16fc
commit ca84fa1a00
@@ -9,6 +9,10 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections -fdata-sections -ffast-math")
     endif()
 endif ()
+if ("${X86_64_SIMD}" STREQUAL "avx")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2")
+endif ()
 
 ########################### files ###########################
 file(GLOB KERNEL_SRC
@@ -39,6 +43,7 @@ endif()
 
 if ("${X86_64_SIMD}" STREQUAL "avx")
     file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c
+                           ${NNACL_DIR}/x86_64_avx/*.c
                            ${NNACL_DIR}/assembly/avx/*.S)
     set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
 endif()
@@ -18,16 +18,19 @@
 #ifdef ENABLE_NEON
 #include <arm_neon.h>
 #endif
+#ifdef ENABLE_AVX
+#include <x86intrin.h>
+#include "nnacl/x86_64_avx/common_utils.h"
+#endif
 #include "nnacl/int8/fixed_point.h"
 
 void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
   int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
   int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_);
   int index = 0;
 
 #ifdef ENABLE_ARM
   const int8x16_t min_vec = vdupq_n_s8(params->min_);
-  const int8x16_t max_vac = vdupq_n_s8(params->max_);
+  const int8x16_t max_vec = vdupq_n_s8(params->max_);
 
   const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_);
   const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_);
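The in0_left_shift and in1_left_shift factors computed at the top of AddInt8 simply combine the shared left_shift_ with each input's own left_shift_ into a single power of two (the exponents add). A tiny sketch with made-up shift values, for illustration only and not taken from the commit:

#include <assert.h>

/* Hypothetical check with illustrative values, not from the commit. */
static void LeftShiftExample(void) {
  const int global_left_shift = 20;  /* stands in for params->left_shift_ */
  const int in0_arg_left_shift = 2;  /* stands in for params->in0_args_.left_shift_ */
  const int in0_left_shift = (1 << global_left_shift) * (1 << in0_arg_left_shift);
  assert(in0_left_shift == 1 << 22);  /* the two factors collapse into one power of two */
}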
@@ -142,12 +145,11 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
     const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec);
 
     const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2));
-    const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out));
+    const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out));
 
     vst1q_s8(output + index, int8_out);
   }
 #endif
 
   for (; index < size; index++) {
     const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift;
     const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift;
@@ -173,7 +175,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i
 #ifdef ENABLE_ARM
   /* const value init */
   const int8x16_t min_vec = vdupq_n_s8(params->min_);
-  const int8x16_t max_vac = vdupq_n_s8(params->max_);
+  const int8x16_t max_vec = vdupq_n_s8(params->max_);
 
   const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_);
   const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_);
@@ -293,7 +295,7 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i
     const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec);
 
     const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2));
-    const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out));
+    const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out));
 
     vst1q_s8(output + index, int8_out);
   }
@@ -325,3 +327,357 @@ int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int
   TileDimensionsInt8(in0, in1, tile_in0, tile_in1, param);
   return ElementAddInt8(tile_in0, tile_in1, out, size);
 }
+
+#ifdef ENABLE_AVX
+void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) {
+  const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_);
+  const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_);
+  const __m128i min_vec = _mm_set1_epi8(params->min_);
+  const __m128i max_vec = _mm_set1_epi8(params->max_);
+  const __m128i in0_zp_vec = _mm_set1_epi16(params->in0_args_.zp_);
+  const __m128i in1_zp_vec = _mm_set1_epi16(params->in1_args_.zp_);
+  const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_);
+  const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift);
+  const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift);
+  const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_);
+  const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_);
+  const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_);
+  int index = 0;
+  for (; index <= size - 16; index += 16) {
+    const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(input0 + index));
+    const __m128i in1_src = _mm_loadu_si128((__m128i_u *)(input1 + index));
+
+    const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src);
+    const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0);
+    const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1);
+    const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src);
+    const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0);
+    const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1);
+
+    const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec);
+    const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec);
+    const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec);
+    const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec);
+
+    __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low);
+    __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1);
+    tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high);
+    __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1);
+    __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low);
+    __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0);
+    __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1);
+    tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high);
+    __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0);
+    __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1);
+
+    // Apply left shift
+    in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec);
+    in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec);
+    in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec);
+    in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec);
+    in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec);
+    in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec);
+    in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec);
+    in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec);
+
+    // Apply the fixed-point part of the multiplier.
+    in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier);
+    in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier);
+    in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier);
+    in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier);
+    in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier);
+    in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier);
+    in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier);
+    in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier);
+
+    // Apply right shift
+    int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1;
+    int32_t in0_remainder_threshold = in0_remainder_mask >> 1;
+    const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask);
+    const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold);
+    const __m128i in0_1_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1));
+    in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold));
+    const __m128i in0_2_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2));
+    in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold));
+    const __m128i in0_3_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3));
+    in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold));
+    const __m128i in0_4_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4));
+    in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold));
+
+    int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1;
+    int32_t in1_remainder_threshold = in1_remainder_mask >> 1;
+    const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask);
+    const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold);
+    const __m128i in1_1_remainder =
+      _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1));
+    in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_),
+                          _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold));
+    const __m128i in1_2_remainder =
+      _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2));
+    in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_),
+                          _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold));
+    const __m128i in1_3_remainder =
+      _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3));
+    in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_),
+                          _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold));
+    const __m128i in1_4_remainder =
+      _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4));
+    in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_),
+                          _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold));
+
+    /* calculate output */
+    __m128i out1 = _mm_add_epi32(in0_1, in1_1);
+    __m128i out2 = _mm_add_epi32(in0_2, in1_2);
+    __m128i out3 = _mm_add_epi32(in0_3, in1_3);
+    __m128i out4 = _mm_add_epi32(in0_4, in1_4);
+
+    // Apply left shift
+    out1 = _mm_slli_epi32(out1, params->out_left_shift_);
+    out2 = _mm_slli_epi32(out2, params->out_left_shift_);
+    out3 = _mm_slli_epi32(out3, params->out_left_shift_);
+    out4 = _mm_slli_epi32(out4, params->out_left_shift_);
+
+    // Apply the fixed-point part of the multiplier.
+    out1 = _mm_qrdmulh_epi32(out1, out_multiplier);
+    out2 = _mm_qrdmulh_epi32(out2, out_multiplier);
+    out3 = _mm_qrdmulh_epi32(out3, out_multiplier);
+    out4 = _mm_qrdmulh_epi32(out4, out_multiplier);
+
+    // Apply right shift
+    int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1;
+    int32_t out_remainder_threshold = out_remainder_mask >> 1;
+    const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask);
+    const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold);
+    const __m128i out1_remainder =
+      _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1));
+    out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold));
+    const __m128i out2_remainder =
+      _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2));
+    out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold));
+    const __m128i out3_remainder =
+      _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3));
+    out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold));
+    const __m128i out4_remainder =
+      _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4));
+    out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold));
+
+    __m128i out1_s16 = _mm_packs_epi32(out1, out2);
+    __m128i out2_s16 = _mm_packs_epi32(out3, out4);
+
+    __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec);
+    __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec);
+    __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2);
+    __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out));
+
+    _mm_storeu_si128((__m128i_u *)(output + index), int8_out);
+  }
+  for (; index < size; index++) {
+    const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift;
+    const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift;
+    const int32_t in0 =
+      MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_);
+    const int32_t in1 =
+      MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_);
+
+    int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
+                                                -params->out_right_shift_);
+    out += params->out_zp_;
+    output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
+  }
+  return;
+}
+
+void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
+                     AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args) {
+  // input0: ptr_in
+  // input1: element_in
+  // load quant parameters of input0 and input1
+  const int in0_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_);
+  const int in1_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_);
+  const __m128i min_vec = _mm_set1_epi8(params->min_);
+  const __m128i max_vec = _mm_set1_epi8(params->max_);
+  const __m128i in0_zp_vec = _mm_set1_epi16(ptr_args->zp_);
+  const __m128i in1_zp_vec = _mm_set1_epi16(ele_args->zp_);
+  const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_);
+  const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift);
+  const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift);
+  const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_);
+  const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_);
+  const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_);
+
+  // input1 can be processed once because it is const
+  const __m128i in1_src = _mm_set1_epi8(element_in);
+  const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src);
+  const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0);
+  const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1);
+  const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec);
+  const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec);
+  __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low);
+  __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0);
+  __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1);
+  tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high);
+  __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0);
+  __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1);
+
+  // Apply left shift
+  in1_1 = _mm_mullo_epi32(in1_1, in1_left_vec);
+  in1_2 = _mm_mullo_epi32(in1_2, in1_left_vec);
+  in1_3 = _mm_mullo_epi32(in1_3, in1_left_vec);
+  in1_4 = _mm_mullo_epi32(in1_4, in1_left_vec);
+
+  // Apply the fixed-point part of the multiplier.
+  in1_1 = _mm_qrdmulh_epi32(in1_1, in1_multiplier);
+  in1_2 = _mm_qrdmulh_epi32(in1_2, in1_multiplier);
+  in1_3 = _mm_qrdmulh_epi32(in1_3, in1_multiplier);
+  in1_4 = _mm_qrdmulh_epi32(in1_4, in1_multiplier);
+
+  // Apply right shift
+  int32_t in1_remainder_mask = (1ll << (params->in1_args_.right_shift_)) - 1;
+  int32_t in1_remainder_threshold = in1_remainder_mask >> 1;
+  const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask);
+  const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold);
+  const __m128i in1_1_remainder =
+    _mm_add_epi32(_mm_and_si128(in1_1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_1));
+  in1_1 = _mm_sub_epi32(_mm_rshr_epi32(in1_1, params->in1_args_.right_shift_),
+                        _mm_cmpgt_epi32(in1_1_remainder, vin1_remainder_threshold));
+  const __m128i in1_2_remainder =
+    _mm_add_epi32(_mm_and_si128(in1_2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_2));
+  in1_2 = _mm_sub_epi32(_mm_rshr_epi32(in1_2, params->in1_args_.right_shift_),
+                        _mm_cmpgt_epi32(in1_2_remainder, vin1_remainder_threshold));
+  const __m128i in1_3_remainder =
+    _mm_add_epi32(_mm_and_si128(in1_3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_3));
+  in1_3 = _mm_sub_epi32(_mm_rshr_epi32(in1_3, params->in1_args_.right_shift_),
+                        _mm_cmpgt_epi32(in1_3_remainder, vin1_remainder_threshold));
+  const __m128i in1_4_remainder =
+    _mm_add_epi32(_mm_and_si128(in1_4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in1_4));
+  in1_4 = _mm_sub_epi32(_mm_rshr_epi32(in1_4, params->in1_args_.right_shift_),
+                        _mm_cmpgt_epi32(in1_4_remainder, vin1_remainder_threshold));
+
+  int index = 0;
+  for (; index <= size - 16; index += 16) {
+    const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(ptr_in + index));
+    const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src);
+    const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0);
+    const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1);
+    const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec);
+    const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec);
+
+    __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low);
+    __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1);
+    tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high);
+    __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0);
+    __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1);
+
+    // Apply left shift
+    in0_1 = _mm_mullo_epi32(in0_1, in0_left_vec);
+    in0_2 = _mm_mullo_epi32(in0_2, in0_left_vec);
+    in0_3 = _mm_mullo_epi32(in0_3, in0_left_vec);
+    in0_4 = _mm_mullo_epi32(in0_4, in0_left_vec);
+
+    // Apply the fixed-point part of the multiplier.
+    in0_1 = _mm_qrdmulh_epi32(in0_1, in0_multiplier);
+    in0_2 = _mm_qrdmulh_epi32(in0_2, in0_multiplier);
+    in0_3 = _mm_qrdmulh_epi32(in0_3, in0_multiplier);
+    in0_4 = _mm_qrdmulh_epi32(in0_4, in0_multiplier);
+
+    // Apply right shift
+    int32_t in0_remainder_mask = (1ll << (params->in0_args_.right_shift_)) - 1;
+    int32_t in0_remainder_threshold = in0_remainder_mask >> 1;
+    const __m128i vin0_remainder_mask = _mm_set1_epi32(in0_remainder_mask);
+    const __m128i vin0_remainder_threshold = _mm_set1_epi32(in0_remainder_threshold);
+    const __m128i in0_1_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_1, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_1));
+    in0_1 = _mm_sub_epi32(_mm_rshr_epi32(in0_1, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_1_remainder, vin0_remainder_threshold));
+    const __m128i in0_2_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_2, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_2));
+    in0_2 = _mm_sub_epi32(_mm_rshr_epi32(in0_2, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_2_remainder, vin0_remainder_threshold));
+    const __m128i in0_3_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_3, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_3));
+    in0_3 = _mm_sub_epi32(_mm_rshr_epi32(in0_3, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_3_remainder, vin0_remainder_threshold));
+    const __m128i in0_4_remainder =
+      _mm_add_epi32(_mm_and_si128(in0_4, vin0_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), in0_4));
+    in0_4 = _mm_sub_epi32(_mm_rshr_epi32(in0_4, params->in0_args_.right_shift_),
+                          _mm_cmpgt_epi32(in0_4_remainder, vin0_remainder_threshold));
+
+    /* calculate output */
+    __m128i out1 = _mm_add_epi32(in0_1, in1_1);
+    __m128i out2 = _mm_add_epi32(in0_2, in1_2);
+    __m128i out3 = _mm_add_epi32(in0_3, in1_3);
+    __m128i out4 = _mm_add_epi32(in0_4, in1_4);
+
+    // Apply left shift
+    out1 = _mm_slli_epi32(out1, params->out_left_shift_);
+    out2 = _mm_slli_epi32(out2, params->out_left_shift_);
+    out3 = _mm_slli_epi32(out3, params->out_left_shift_);
+    out4 = _mm_slli_epi32(out4, params->out_left_shift_);
+
+    // Apply the fixed-point part of the multiplier.
+    out1 = _mm_qrdmulh_epi32(out1, out_multiplier);
+    out2 = _mm_qrdmulh_epi32(out2, out_multiplier);
+    out3 = _mm_qrdmulh_epi32(out3, out_multiplier);
+    out4 = _mm_qrdmulh_epi32(out4, out_multiplier);
+
+    // Apply right shift
+    int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1;
+    int32_t out_remainder_threshold = out_remainder_mask >> 1;
+    const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask);
+    const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold);
+    const __m128i out1_remainder =
+      _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1));
+    out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold));
+    const __m128i out2_remainder =
+      _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2));
+    out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold));
+    const __m128i out3_remainder =
+      _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3));
+    out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold));
+    const __m128i out4_remainder =
+      _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4));
+    out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_),
+                         _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold));
+
+    __m128i out1_s16 = _mm_packs_epi32(out1, out2);
+    __m128i out2_s16 = _mm_packs_epi32(out3, out4);
+
+    __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec);
+    __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec);
+    __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2);
+    __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out));
+
+    _mm_storeu_si128((__m128i_u *)(output + index), int8_out);
+  }
+  for (; index < size; index++) {
+    const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift;
+    const int32_t in1_left = (element_in + ele_args->zp_) * in1_left_shift;
+    const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, ptr_args->multiplier_, ptr_args->right_shift_);
+    const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, ele_args->multiplier_, ele_args->right_shift_);
+
+    int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
+                                                -params->out_right_shift_);
+    out += params->out_zp_;
+    output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
+  }
+  return;
+}
+#endif
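The three "Apply right shift" blocks above all follow the same pattern: a floor-style shift via _mm_rshr_epi32 plus a remainder/threshold comparison that adds one when the discarded low bits are closer to the next integer. A scalar sketch of that rounding divide-by-power-of-two, with a hypothetical helper name and assuming arithmetic right shift on negative values (an orientation aid, not the commit's code):

/* Hypothetical scalar reference: round x / 2^exponent to the nearest integer. */
static int RoundingRightShift(int x, int exponent) {
  const int mask = (1 << exponent) - 1;
  const int remainder = x & mask;                       /* low bits that will be discarded */
  const int threshold = (mask >> 1) + (x < 0 ? 1 : 0);  /* nearest-rounding cutoff */
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}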
@@ -17,6 +17,9 @@
 #ifndef MINDSPORE_LITE_NNACL_ADD_INT8_H_
 #define MINDSPORE_LITE_NNACL_ADD_INT8_H_
 
+#ifdef ENABLE_AVX
+#include <x86intrin.h>
+#endif
 #include "nnacl/op_base.h"
 #include "nnacl/errorcode.h"
 #include "nnacl/arithmetic.h"
@@ -48,13 +51,21 @@ extern "C" {
 #endif
 
 void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params);
 
 void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
                 AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args);
 
 int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size);
 
 int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size,
                      ArithmeticParameter *param);
 
+#ifdef ENABLE_AVX
+void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params);
+
+void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params,
+                     AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args);
+#endif
 #ifdef __cplusplus
 }
 #endif
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/x86_64_avx/common_utils.h"
+#ifdef WIN32
+#ifdef ENABLE_AVX
+#include <stdint.h>
+#endif
+#endif
+
+__m128i _mm_adds_epi32(__m128i a, __m128i b) {
+  __m128i int_min = _mm_set1_epi32(0x80000000);
+  __m128i int_max = _mm_set1_epi32(0x7FFFFFFF);
+
+  const __m128i res = _mm_add_epi32(a, b);
+  const __m128i sign_and = _mm_and_si128(a, b);
+  const __m128i sign_or = _mm_or_si128(a, b);
+
+  const __m128i min_sat_mask = _mm_andnot_si128(res, sign_and);
+  const __m128i max_sat_mask = _mm_andnot_si128(sign_or, res);
+  const __m128 res_temp =
+    _mm_blendv_ps(_mm_castsi128_ps(res), _mm_castsi128_ps(int_min), _mm_castsi128_ps(min_sat_mask));
+  return _mm_castps_si128(_mm_blendv_ps(res_temp, _mm_castsi128_ps(int_max), _mm_castsi128_ps(max_sat_mask)));
+}
+
+__m128i _mm_rshr_epi32(__m128i a, int shift) {
+  const __m128i vmask = _mm_cmpgt_epi32(_mm_setzero_si128(), a);
+  const __m128i vabs_a = _mm_sub_epi32(_mm_xor_si128(a, vmask), vmask);
+  const __m128i tmp_res = _mm_srli_epi32(vabs_a, shift);
+  return _mm_xor_si128(tmp_res, vmask);
+}
+
+__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b) {
+  const __m128i tmp_a_lo = _mm_unpacklo_epi32(a, _mm_setzero_si128());
+  const __m128i tmp_a_hi = _mm_unpackhi_epi32(a, _mm_setzero_si128());
+  const __m256i tmp_a_256 = _mm256_set_m128i(tmp_a_hi, tmp_a_lo);
+  const __m128i tmp_b_lo = _mm_unpacklo_epi32(b, _mm_setzero_si128());
+  const __m128i tmp_b_hi = _mm_unpackhi_epi32(b, _mm_setzero_si128());
+  const __m256i tmp_b_256 = _mm256_set_m128i(tmp_b_hi, tmp_b_lo);
+  __m256i tmp_out = _mm256_mul_epi32(tmp_a_256, tmp_b_256);
+  tmp_out = _mm256_add_epi64(tmp_out, _mm256_set1_epi64x(1ll << 30));
+  const __m256i vmask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), tmp_out);
+  const __m256i vabs_tmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask);
+  tmp_out = _mm256_srli_epi64(vabs_tmp_out, 31);
+  const __m256i vtmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask);
+  const int32_t max_32bit = (1ll << 31) - 1;
+  const int32_t min_32bit = -(1ll << 31);
+  int64_t *tmp_out_ptr = (int64_t *)(&vtmp_out);
+  int32_t r1 = tmp_out_ptr[0] > max_32bit ? max_32bit : tmp_out_ptr[0];
+  r1 = r1 < min_32bit ? min_32bit : r1;
+  int32_t r2 = tmp_out_ptr[1] > max_32bit ? max_32bit : tmp_out_ptr[1];
+  r2 = r2 < min_32bit ? min_32bit : r2;
+  int32_t r3 = tmp_out_ptr[2] > max_32bit ? max_32bit : tmp_out_ptr[2];
+  r3 = r3 < min_32bit ? min_32bit : r3;
+  int32_t r4 = tmp_out_ptr[3] > max_32bit ? max_32bit : tmp_out_ptr[3];
+  r4 = r4 < min_32bit ? min_32bit : r4;
+  return _mm_set_epi32(r4, r3, r2, r1);
+}
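The _mm_qrdmulh_epi32 helper mirrors, lane by lane, a saturating rounding doubling multiply that returns the high half, the same operation named by ARM's SQRDMULH instruction. A scalar sketch of what one 32-bit lane computes, following the steps above, with a hypothetical function name and not part of the commit:

#include <stdint.h>

/* Hypothetical scalar reference for one 32-bit lane of _mm_qrdmulh_epi32. */
static int32_t QrdmulhLane(int32_t a, int32_t b) {
  int64_t t = (int64_t)a * (int64_t)b + (1ll << 30);       /* product plus rounding term */
  int64_t sign = t < 0 ? -1 : 0;
  int64_t magnitude = (t ^ sign) - sign;                   /* absolute value */
  int64_t shifted = (int64_t)((uint64_t)magnitude >> 31);  /* keep the high half of 2*a*b */
  shifted = (shifted ^ sign) - sign;                       /* restore the sign */
  if (shifted > INT32_MAX) shifted = INT32_MAX;            /* clamp to the int32 range */
  if (shifted < INT32_MIN) shifted = INT32_MIN;
  return (int32_t)shifted;
}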
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_X86_64_AVX_COMMON_UTILS_H_
+#define MINDSPORE_LITE_NNACL_X86_64_AVX_COMMON_UTILS_H_
+
+#include <x86intrin.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#ifdef __GNUC__
+#if __GNUC__ < 8
+#define _mm256_set_m128i(xmm1, xmm2) \
+  _mm256_permute2f128_si256(_mm256_castsi128_si256(xmm1), _mm256_castsi128_si256(xmm2), 2)
+#define _mm256_set_m128f(xmm1, xmm2) \
+  _mm256_permute2f128_ps(_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2)
+#endif
+#endif
+
+// Signed saturating Add
+__m128i _mm_adds_epi32(__m128i a, __m128i b);
+
+// Signed rounding shift right
+__m128i _mm_rshr_epi32(__m128i a, int shift);
+
+// Signed saturating Rounding Doubling Multiply return High half
+__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_LITE_NNACL_X86_64_AVX_COMMON_UTILS_H_
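The first declaration, _mm_adds_epi32, emulates a signed saturating add for 32-bit lanes, which SSE/AVX2 only provide natively for 8- and 16-bit lanes. A scalar sketch of the per-lane semantics (hypothetical reference, not part of the commit):

#include <stdint.h>

/* Hypothetical scalar reference for one lane of _mm_adds_epi32. */
static int32_t AddsLane(int32_t a, int32_t b) {
  int64_t sum = (int64_t)a + (int64_t)b;  /* widen so the sum cannot overflow */
  if (sum > INT32_MAX) return INT32_MAX;  /* saturate instead of wrapping */
  if (sum < INT32_MIN) return INT32_MIN;
  return (int32_t)sum;
}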
@@ -153,7 +153,11 @@ void QuantizedAddCPUKernel::BroadcastRun(int task_id) {
       cur_in1 = input1_data_ + task_id * stride * in_size_ + i * in_size_;
       cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_;
     }
+#ifdef ENABLE_AVX
+    AddInt8_AVX2(cur_in0, cur_in1, cur_out, in_size_, &para_);
+#else
     AddInt8(cur_in0, cur_in1, cur_out, in_size_, &para_);
+#endif
   }
   return;
 }
@@ -180,9 +184,17 @@ int QuantizedAddCPUKernel::DoExecute(int task_id) {
     int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0];
     AddQuantQrgs *ptr_args = arith_para_->in_elements_num0_ == 1 ? &para_.in1_args_ : &para_.in0_args_;
     AddQuantQrgs *ele_args = arith_para_->in_elements_num0_ == 1 ? &para_.in0_args_ : &para_.in1_args_;
+#ifdef ENABLE_AVX
+    AddOptInt8_AVX2(ptr_in, element_in, cur_out, rest_count, &para_, ptr_args, ele_args);
+#else
     AddOptInt8(ptr_in, element_in, cur_out, rest_count, &para_, ptr_args, ele_args);
+#endif
   } else {
+#ifdef ENABLE_AVX
+    AddInt8_AVX2(cur_in0, cur_in1, cur_out, rest_count, &para_);
+#else
     AddInt8(cur_in0, cur_in1, cur_out, rest_count, &para_);
+#endif
   }
 
   return RET_OK;
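Both call sites pick the AVX2 kernels purely at compile time via ENABLE_AVX, and the new functions keep the same contract as the portable ones: any length is handled, with the last size % 16 elements finished by the scalar tail loop. A minimal standalone usage sketch; the Example wrapper, buffer length, zeroed parameter struct, and header path are illustrative assumptions, not taken from the commit:

#include <stdint.h>
#include "nnacl/int8/add_int8.h"  /* header path assumed */

void Example(void) {
  int8_t in0[20] = {0}, in1[20] = {0}, out[20];
  AddQuantParameter params = {0};  /* fill zp_/multiplier_/shift fields from the model's quantization */
#ifdef ENABLE_AVX
  AddInt8_AVX2(in0, in1, out, 20, &params);  /* 16 lanes vectorized, 4 left to the scalar tail */
#else
  AddInt8(in0, in1, out, 20, &params);
#endif
}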
@@ -75,7 +75,10 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
 endif()
 
 if ("${X86_64_SIMD}" STREQUAL "avx")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1 -mavx -mavx2")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -mavx -mavx2")
     file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
+                                ${LITE_DIR}/nnacl/x86_64_avx/*.c
                                 ${LITE_DIR}/nnacl/assembly/avx/*.S)
     set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
     set(KERNEL_OP_SRC