!14310 [MS][LITE][r1.2]fix exp for big-endian devices

From: @lx0095
Reviewed-by: @zhanghaibo5,@lilongfei15
Signed-off-by: @lilongfei15
This commit is contained in:
mindspore-ci-bot 2021-03-29 17:43:04 +08:00 committed by Gitee
commit 1b9a5563e5
2 changed files with 11 additions and 5 deletions

View File

@ -56,7 +56,7 @@ static inline void simd_exp(MS_FLOAT32X4 input, float *dst) {
MS_FLOAT32X4 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32((float *)(&int_exp)));
MS_STQ_F32(dst, decimal_exp * MS_CAST_F32_S32(int_exp));
}
#endif
@ -78,20 +78,23 @@ static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) {
MS_FLOAT32X8 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_ST256_F32(dst, decimal_exp * MS_LD256_F32((float *)(&int_exp)));
MS_ST256_F32(dst, decimal_exp * MS_CAST256_F32_S32(int_exp));
}
#endif
static inline void single_exp(float src, float *dst) {
typedef union {
float f;
int i;
} fi;
static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // log(2.0f)
src = MSMAX(-88.0f, MSMIN(88.0f, src));
int integer = src / param[0];
float decimal = src - integer * param[0];
int int_exp = (integer + 127) << 23;
fi int_exp = {.i = (integer + 127) << 23};
float decimal_exp =
1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
float *ptr = (float *)&int_exp;
*dst = *ptr * decimal_exp;
*dst = int_exp.f * decimal_exp;
}
#ifdef __cplusplus
}

View File

@ -65,6 +65,7 @@ inline static float32x4_t vrecp(float32x4_t v) {
// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32
#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
#define MS_CAST_F32_S32(src) vreinterpretq_f32_s32(src)
#endif
#if defined(ENABLE_AVX)
@ -97,6 +98,7 @@ inline static float32x4_t vrecp(float32x4_t v) {
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src)
#endif
#if defined(ENABLE_SSE)
@ -129,6 +131,7 @@ inline static float32x4_t vrecp(float32x4_t v) {
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
#endif
#define LOAD256X8_F32(src, input_ptr, num) \