[MSLITE] layer norm fp32 optimize

This commit is contained in:
ling 2020-12-08 14:08:38 +08:00
parent f02541b8ed
commit f94a0b3d87
2 changed files with 68 additions and 11 deletions

View File

@ -18,30 +18,87 @@
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data, int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
bool affine, float epsilon, float *dst_data, int tid, int thread_num) { const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num) {
if (src_data == NULL || dst_data == NULL) { if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR; return NNACL_NULL_PTR;
} }
if (affine && (gamma_data == NULL || beta_data == NULL)) { if (affine && (gamma_data == NULL || beta_data == NULL)) {
return NNACL_NULL_PTR; return NNACL_NULL_PTR;
} }
for (int j = tid; j < outer_size; j += thread_num) {
for (size_t j = task_id; j < outer_size; j += thread_num) {
const float *src = src_data + j * inner_size; const float *src = src_data + j * inner_size;
float *dst = dst_data + j * inner_size; float *dst = dst_data + j * inner_size;
float mean = 0.0f; float mean = 0.0f;
float square_mean = 0.0f; float square_mean = 0.0f;
for (int i = 0; i < inner_size; i++) {
mean += src[i]; int index = 0;
square_mean += src[i] * src[i]; #ifdef ENABLE_NEON
float32x4_t sum = vdupq_n_f32(0);
float32x4_t square_sum = vdupq_n_f32(0);
for (; index < inner_size - C8NUM; index += C8NUM) {
float32x4_t srcv1 = vld1q_f32(src + index);
float32x4_t srcv2 = vld1q_f32(src + index + 4);
float32x4_t squarev1 = vmulq_f32(srcv1, srcv1);
float32x4_t squarev2 = vmulq_f32(srcv2, srcv2);
sum = vaddq_f32(sum, srcv1);
sum = vaddq_f32(sum, srcv2);
square_sum = vaddq_f32(square_sum, squarev1);
square_sum = vaddq_f32(square_sum, squarev2);
} }
mean = sum[0] + sum[1] + sum[2] + sum[3];
square_mean = square_sum[0] + square_sum[1] + square_sum[2] + square_sum[3];
#endif
for (; index < inner_size; index++) {
mean += src[index];
square_mean += src[index] * src[index];
}
mean /= (float)inner_size; mean /= (float)inner_size;
square_mean /= (float)inner_size; square_mean /= (float)inner_size;
const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon); const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon);
for (int i = 0; i < inner_size; ++i) {
dst[i] = (src[i] - mean) * deno; index = 0;
#ifdef ENABLE_NEON
float32x4_t meanv = vdupq_n_f32(mean);
float32x4_t denov = vdupq_n_f32(deno);
if (affine) { if (affine) {
dst[i] = dst[i] * gamma_data[i] + beta_data[i]; for (; index < inner_size - C8NUM; index += C8NUM) {
float32x4_t srcv1 = vld1q_f32(src + index);
float32x4_t srcv2 = vld1q_f32(src + index + 4);
float32x4_t outv1 = vsubq_f32(srcv1, meanv);
float32x4_t outv2 = vsubq_f32(srcv2, meanv);
outv1 = vmulq_f32(outv1, denov);
outv2 = vmulq_f32(outv2, denov);
float32x4_t gammav1 = vld1q_f32(gamma_data + index);
float32x4_t gammav2 = vld1q_f32(gamma_data + index + 4);
float32x4_t betav1 = vld1q_f32(beta_data + index);
float32x4_t betav2 = vld1q_f32(beta_data + index + 4);
outv1 = vmulq_f32(outv1, gammav1);
outv2 = vmulq_f32(outv2, gammav2);
outv1 = vaddq_f32(outv1, betav1);
outv2 = vaddq_f32(outv2, betav2);
vst1q_f32(dst + index, outv1);
vst1q_f32(dst + index + 4, outv2);
}
} else {
for (; index < inner_size - C8NUM; index += C8NUM) {
float32x4_t srcv1 = vld1q_f32(src + index);
float32x4_t srcv2 = vld1q_f32(src + index + 4);
float32x4_t outv1 = vsubq_f32(srcv1, meanv);
float32x4_t outv2 = vsubq_f32(srcv2, meanv);
outv1 = vmulq_f32(outv1, denov);
outv2 = vmulq_f32(outv2, denov);
vst1q_f32(dst + index, outv1);
vst1q_f32(dst + index + 4, outv2);
}
}
#endif
for (; index < inner_size; index++) {
dst[index] = (src[index] - mean) * deno;
if (affine) {
dst[index] = dst[index] * gamma_data[index] + beta_data[index];
} }
} }
} }

View File

@ -23,8 +23,8 @@
extern "C" { extern "C" {
#endif #endif
int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data, int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
bool affine, float epsilon, float *dst_data, int tid, int thread_num); const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif