!11703 [MSLITE][Develop] optimize fp16 sigmoid and tanh

From: @sunsuodong
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-01-28 14:28:08 +08:00 committed by Gitee
commit 159cd250d7
1 changed file with 42 additions and 6 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/errorcode.h"
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
@ -60,8 +60,19 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha
}
/**
 * Element-wise logistic sigmoid on fp16 data: dst[i] = 1 / (1 + exp(-src[i])).
 *
 * When ENABLE_ARM64 is defined, elements are processed C4NUM at a time: the
 * fp16 inputs are widened to a four-lane fp32 vector, negated, passed to
 * simd_exp(), and the quotient is narrowed back to fp16 on store. Any
 * remaining elements (or all of them without ENABLE_ARM64) go through the
 * scalar loop, which uses single_exp().
 *
 * NOTE(review): simd_exp()/single_exp() are declared outside this block
 * (presumably in nnacl/fp32/exp_fp32.h); they are assumed to write exp() of
 * their first argument through the pointer argument - confirm there.
 * NOTE(review): the vector loop assumes C4NUM == 4 (one float32x4_t per step).
 *
 * @param src     input buffer, at least ele_num elements
 * @param dst     output buffer, at least ele_num elements
 * @param ele_num number of elements to process
 * @return NNACL_OK (this function has no failure path)
 */
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
  int i = 0;
#ifdef ENABLE_ARM64
  /* Loop-invariant splat of 1.0f, used as both numerator and addend. */
  const float32x4_t one = vdupq_n_f32(1.0f);
  /* Largest multiple of C4NUM not exceeding ele_num. */
  int count = (ele_num / C4NUM) * C4NUM;
  for (; i < count; i += C4NUM) {
    float32x4_t tmp;
    simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp);
    vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(one, vaddq_f32(one, tmp))));
  }
#endif
  /* Scalar tail: continues from wherever the vector loop stopped. */
  for (; i < ele_num; ++i) {
    float temp;
    single_exp(-src[i], &temp);
    dst[i] = (float16_t)1.0f / ((float16_t)1.0f + temp);
  }
  return NNACL_OK;
}
@ -80,8 +91,33 @@ float16_t TanhOptFp16(float16_t src) {
}
/**
 * Element-wise tanh on fp16 data using a rational approximation evaluated in
 * fp32:
 *
 *            x * (x^6 + 378 x^4 + 17325 x^2 + 135135)
 *   t(x) = --------------------------------------------
 *           28 x^6 + 3150 x^4 + 62370 x^2 + 135135
 *
 * The numerator has a higher degree than the denominator, so for large |x|
 * the raw quotient grows like x / 28 and leaves tanh's range. The result is
 * therefore clamped to [-1, 1] before it is narrowed back to fp16.
 *
 * When ENABLE_ARM64 is defined, elements are processed C4NUM at a time with
 * four-lane fp32 NEON vectors; remaining elements (or all of them without
 * ENABLE_ARM64) go through the scalar loop, which evaluates the same formula.
 *
 * NOTE(review): the vector loop assumes C4NUM == 4 (one float32x4_t per step).
 *
 * @param src     input buffer, at least ele_num elements
 * @param dst     output buffer, at least ele_num elements
 * @param ele_num number of elements to process
 * @return NNACL_OK (this function has no failure path)
 */
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
  int i = 0;
#ifdef ENABLE_ARM64
  /* Polynomial coefficients, one splatted vector each; index 2 (135135) is
   * shared by numerator and denominator. */
  static const float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
                                       {17325.0f, 17325.0f, 17325.0f, 17325.0f},
                                       {135135.0f, 135135.0f, 135135.0f, 135135.0f},
                                       {28.0f, 28.0f, 28.0f, 28.0f},
                                       {3150.0f, 3150.0f, 3150.0f, 3150.0f},
                                       {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
  const float32x4_t neg_one = vdupq_n_f32(-1.0f);
  const float32x4_t pos_one = vdupq_n_f32(1.0f);
  /* Largest multiple of C4NUM not exceeding ele_num. */
  int count = (ele_num / C4NUM) * C4NUM;
  for (; i < count; i += C4NUM) {
    float32x4_t input = vcvt_f32_f16(vld1_f16(src + i));
    float32x4_t square = vmulq_f32(input, input);
    /* Numerator, Horner form in x^2, then multiplied by x. */
    float32x4_t a = vmulq_f32(
      vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]),
      input);
    /* Denominator, Horner form in x^2. */
    float32x4_t b = vaddq_f32(
      vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
      paramv[2]);
    /* Clamp to tanh's range; the raw quotient is unbounded for large |x|. */
    float32x4_t result = vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one);
    vst1_f16(dst + i, vcvt_f16_f32(result));
  }
#endif
  /* Scalar tail: continues from wherever the vector loop stopped. */
  for (; i < ele_num; ++i) {
    float input = src[i];
    float square = input * input;
    float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
    float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
    float result = a / b;
    /* Same clamp as the vector path. */
    if (result > 1.0f) {
      result = 1.0f;
    } else if (result < -1.0f) {
      result = -1.0f;
    }
    dst[i] = (float16_t)result;
  }
  return NNACL_OK;
}