forked from mindspore-Ecosystem/mindspore
gelu optimize
This commit is contained in:
parent
d436b9ee98
commit
6090bf0849
|
@ -168,3 +168,40 @@ int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) {
|
||||
if (src == NULL || dst == NULL) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int i = 0;
|
||||
if (approximate) {
|
||||
// dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
|
||||
#ifdef ENABLE_NEON
|
||||
int C8 = UP_ROUND(length, C8NUM);
|
||||
for (; i < C8; i += C8NUM) {
|
||||
float16x8_t in = vld1q_f16(src + i);
|
||||
float16x8_t res =
|
||||
0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in));
|
||||
vst1q_f16(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] =
|
||||
0.5 * src[i] *
|
||||
(1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i]));
|
||||
}
|
||||
} else {
|
||||
#ifdef ENABLE_NEON
|
||||
int C8 = UP_ROUND(length, C8NUM);
|
||||
for (; i < C8; i += C8NUM) {
|
||||
float16x8_t in = vld1q_f16(src + i);
|
||||
float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
|
||||
vst1q_f16(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
|
|||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int SwishFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val);
|
||||
int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -134,50 +134,21 @@ float TanhOpt(float src) {
|
|||
|
||||
int Tanh(const float *src, int length, float *dst) {
|
||||
int i = 0;
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
|
||||
const int cnt = 6;
|
||||
float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_AVX)
|
||||
MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
|
||||
MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
MS_FLOAT32X8 param256[6];
|
||||
for (int j = 0; j < cnt; ++j) {
|
||||
param256[j] = MS_MOV256_F32(data[j]);
|
||||
}
|
||||
for (; i < length - 8; i += 8) {
|
||||
MS_FLOAT32X8 input = MS_LD256_F32(src + i);
|
||||
MS_FLOAT32X8 square = input * input;
|
||||
MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
|
||||
MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
|
||||
MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
|
||||
MS_ST256_F32(dst + i, MS_TANHX8_F32(input));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 param[6];
|
||||
MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
|
||||
MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
|
||||
for (int j = 0; j < cnt; ++j) {
|
||||
param[j] = MS_MOVQ_F32(data[j]);
|
||||
}
|
||||
for (; i < length - 4; i += 4) {
|
||||
MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
|
||||
MS_FLOAT32X4 square = input * input;
|
||||
MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
|
||||
MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
|
||||
MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
|
||||
MS_STQ_F32(dst + i, MS_TANHX4_F32(input));
|
||||
}
|
||||
#endif
|
||||
for (; i < length; ++i) {
|
||||
float input = src[i];
|
||||
float square = input * input;
|
||||
float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
|
||||
float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
|
||||
dst[i] = a / b;
|
||||
dst[i] = MSMAX(dst[i], -1);
|
||||
dst[i] = MSMIN(dst[i], 1);
|
||||
dst[i] = TanhOpt(src[i]);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
@ -249,10 +220,44 @@ int HardTanh(const float *src, int length, float *dst, float min_val, float max_
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int Gelu(const float *src, int length, float *dst) {
|
||||
for (int i = 0; i < length; ++i) {
|
||||
float tanh_res = TanhOpt(sqrt(2 / M_PI) * (src[i] + 0.044715 * pow(src[i], 3)));
|
||||
dst[i] = 0.5f * src[i] * (1 + tanh_res);
|
||||
int Gelu(const float *src, int length, float *dst, bool approximate) {
|
||||
if (src == NULL || dst == NULL) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int i = 0;
|
||||
if (approximate) {
|
||||
// dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
|
||||
#if defined(ENABLE_AVX)
|
||||
int C8 = UP_ROUND(length, C8NUM);
|
||||
for (; i < C8; i += C8NUM) {
|
||||
MS_FLOAT32X8 in = MS_LD256_F32(src + i);
|
||||
MS_FLOAT32X8 res = 0.5 * in * (1.0 + MS_TANHX8_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
|
||||
MS_ST256_F32(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
int C4 = UP_ROUND(length, C4NUM);
|
||||
for (; i < C4; i += C4NUM) {
|
||||
MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
|
||||
MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
|
||||
MS_STQ_F32(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i]));
|
||||
}
|
||||
} else {
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
int C4 = UP_ROUND(length, C4NUM);
|
||||
for (; i < C4; i += C4NUM) {
|
||||
MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
|
||||
MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_ERFX4_F32(in / 1.4142135623730951f));
|
||||
MS_STQ_F32(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ int HSigmoid(const float *src, int length, float *dst);
|
|||
int Swish(const float *src, int length, float *dst);
|
||||
int HSwish(const float *src, int length, float *dst);
|
||||
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
|
||||
int Gelu(const float *src, int length, float *dst);
|
||||
int Gelu(const float *src, int length, float *dst, bool approximate);
|
||||
|
||||
float TanhOpt(float src);
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -17,11 +17,6 @@
|
|||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/fp32/winograd_transform.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/gelu_fp32.h"
|
||||
#include "nnacl/gelu_parameter.h"
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param) {
|
||||
if (src == NULL || out == NULL) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
|
||||
if (param->approximate_) {
|
||||
for (int i = 0; i < real_dst_count; i++) {
|
||||
out[i] = 0.5 * src[i] * (1.0 + tanh(0.7978845608028654 * (src[i] + 0.044715 * pow(src[i], 3))));
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < real_dst_count; i++) {
|
||||
out[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951));
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP32_GELU_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_GELU_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/gelu_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP32_GELU_H_
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
#define MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
#include <math.h>
|
||||
#ifdef ENABLE_ARM
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
@ -170,4 +171,56 @@ inline static float32x4_t vrecp(float32x4_t v) {
|
|||
MS_STQ_F32(output_ptr + 6 * num, dst##7); \
|
||||
MS_STQ_F32(output_ptr + 7 * num, dst##8);
|
||||
|
||||
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
|
||||
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
|
||||
static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
|
||||
static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
|
||||
MS_FLOAT32X4 square = src * src;
|
||||
MS_FLOAT32X4 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
|
||||
MS_FLOAT32X4 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
|
||||
return MS_MINQ_F32(MS_MAXQ_F32(a / b, neg), pos);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
||||
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
|
||||
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
|
||||
static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
MS_FLOAT32X8 square = src * src;
|
||||
MS_FLOAT32X8 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
|
||||
MS_FLOAT32X8 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
|
||||
return MS_MIN256_F32(MS_MAX256_F32(a / b, neg), pos);
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
|
||||
MS_FLOAT32X4 dst;
|
||||
dst[0] = erff(src[0]);
|
||||
dst[1] = erff(src[1]);
|
||||
dst[2] = erff(src[2]);
|
||||
dst[3] = erff(src[3]);
|
||||
return dst;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
|
||||
float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src));
|
||||
float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src));
|
||||
return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high)));
|
||||
}
|
||||
|
||||
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
|
||||
float16x8_t dst;
|
||||
dst[0] = erff(src[0]);
|
||||
dst[1] = erff(src[1]);
|
||||
dst[2] = erff(src[2]);
|
||||
dst[3] = erff(src[3]);
|
||||
dst[4] = erff(src[4]);
|
||||
dst[5] = erff(src[5]);
|
||||
dst[6] = erff(src[6]);
|
||||
dst[7] = erff(src[7]);
|
||||
return dst;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
#if defined(ENBALE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#endif
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType_GELU;
|
||||
using mindspore::schema::ActivationType_HSWISH;
|
||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
|
@ -73,6 +74,8 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
|
|||
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
||||
error_code =
|
||||
HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
|
||||
} else if (type_ == schema::ActivationType_GELU) {
|
||||
error_code = GeluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -79,7 +79,7 @@ int ActivationCPUKernel::DoActivation(int task_id) {
|
|||
} else if (type_ == schema::ActivationType_HARD_TANH) {
|
||||
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
|
||||
} else if (type_ == schema::ActivationType_GELU) {
|
||||
ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id);
|
||||
ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id, true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation type error";
|
||||
return RET_ERROR;
|
||||
|
|
Loading…
Reference in New Issue