forked from mindspore-Ecosystem/mindspore
[MSLITE] layer norm fp32 optimize
This commit is contained in:
parent f02541b8ed
commit f94a0b3d87
@@ -18,30 +18,87 @@
 #include "nnacl/errorcode.h"
 #include "nnacl/op_base.h"

-int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data,
-              bool affine, float epsilon, float *dst_data, int tid, int thread_num) {
+int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
+              const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num) {
   if (src_data == NULL || dst_data == NULL) {
     return NNACL_NULL_PTR;
   }
   if (affine && (gamma_data == NULL || beta_data == NULL)) {
     return NNACL_NULL_PTR;
   }
-  for (int j = tid; j < outer_size; j += thread_num) {
+  for (size_t j = task_id; j < outer_size; j += thread_num) {
     const float *src = src_data + j * inner_size;
     float *dst = dst_data + j * inner_size;
     float mean = 0.0f;
     float square_mean = 0.0f;
-    for (int i = 0; i < inner_size; i++) {
-      mean += src[i];
-      square_mean += src[i] * src[i];
+    int index = 0;
+#ifdef ENABLE_NEON
+    float32x4_t sum = vdupq_n_f32(0);
+    float32x4_t square_sum = vdupq_n_f32(0);
+    for (; index < inner_size - C8NUM; index += C8NUM) {
+      float32x4_t srcv1 = vld1q_f32(src + index);
+      float32x4_t srcv2 = vld1q_f32(src + index + 4);
+      float32x4_t squarev1 = vmulq_f32(srcv1, srcv1);
+      float32x4_t squarev2 = vmulq_f32(srcv2, srcv2);
+      sum = vaddq_f32(sum, srcv1);
+      sum = vaddq_f32(sum, srcv2);
+      square_sum = vaddq_f32(square_sum, squarev1);
+      square_sum = vaddq_f32(square_sum, squarev2);
     }
+    mean = sum[0] + sum[1] + sum[2] + sum[3];
+    square_mean = square_sum[0] + square_sum[1] + square_sum[2] + square_sum[3];
+#endif
+    for (; index < inner_size; index++) {
+      mean += src[index];
+      square_mean += src[index] * src[index];
+    }
     mean /= (float)inner_size;
     square_mean /= (float)inner_size;
     const float deno = 1 / sqrtf(square_mean - mean * mean + epsilon);
-    for (int i = 0; i < inner_size; ++i) {
-      dst[i] = (src[i] - mean) * deno;
-      if (affine) {
-        dst[i] = dst[i] * gamma_data[i] + beta_data[i];
+    index = 0;
+#ifdef ENABLE_NEON
+    float32x4_t meanv = vdupq_n_f32(mean);
+    float32x4_t denov = vdupq_n_f32(deno);
+    if (affine) {
+      for (; index < inner_size - C8NUM; index += C8NUM) {
+        float32x4_t srcv1 = vld1q_f32(src + index);
+        float32x4_t srcv2 = vld1q_f32(src + index + 4);
+        float32x4_t outv1 = vsubq_f32(srcv1, meanv);
+        float32x4_t outv2 = vsubq_f32(srcv2, meanv);
+        outv1 = vmulq_f32(outv1, denov);
+        outv2 = vmulq_f32(outv2, denov);
+        float32x4_t gammav1 = vld1q_f32(gamma_data + index);
+        float32x4_t gammav2 = vld1q_f32(gamma_data + index + 4);
+        float32x4_t betav1 = vld1q_f32(beta_data + index);
+        float32x4_t betav2 = vld1q_f32(beta_data + index + 4);
+        outv1 = vmulq_f32(outv1, gammav1);
+        outv2 = vmulq_f32(outv2, gammav2);
+        outv1 = vaddq_f32(outv1, betav1);
+        outv2 = vaddq_f32(outv2, betav2);
+        vst1q_f32(dst + index, outv1);
+        vst1q_f32(dst + index + 4, outv2);
+      }
+    } else {
+      for (; index < inner_size - C8NUM; index += C8NUM) {
+        float32x4_t srcv1 = vld1q_f32(src + index);
+        float32x4_t srcv2 = vld1q_f32(src + index + 4);
+        float32x4_t outv1 = vsubq_f32(srcv1, meanv);
+        float32x4_t outv2 = vsubq_f32(srcv2, meanv);
+        outv1 = vmulq_f32(outv1, denov);
+        outv2 = vmulq_f32(outv2, denov);
+        vst1q_f32(dst + index, outv1);
+        vst1q_f32(dst + index + 4, outv2);
+      }
+    }
+#endif
+    for (; index < inner_size; index++) {
+      dst[index] = (src[index] - mean) * deno;
+      if (affine) {
+        dst[index] = dst[index] * gamma_data[index] + beta_data[index];
       }
     }
   }
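The task_id/thread_num parameters split the outer_size rows across workers: worker k normalizes rows k, k + thread_num, k + 2 * thread_num, and so on. A minimal driver sketch for the new signature, assuming a plain serial loop stands in for the runtime's thread-pool dispatch and with made-up buffer contents for illustration:

#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

/* Prototype as declared in the header hunk below. */
int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
              const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num);

int main(void) {
  enum { OUTER = 4, INNER = 16 }; /* INNER larger than C8NUM (8) so the NEON path is exercised */
  float src[OUTER * INNER], dst[OUTER * INNER], gamma[INNER], beta[INNER];
  for (int i = 0; i < OUTER * INNER; ++i) src[i] = (float)(i % INNER);
  for (int i = 0; i < INNER; ++i) { gamma[i] = 1.0f; beta[i] = 0.0f; }

  const size_t thread_num = 2;
  /* Each task handles rows task_id, task_id + thread_num, ...; in MindSpore Lite
     these calls would be issued concurrently, one per worker thread. */
  for (size_t task_id = 0; task_id < thread_num; ++task_id) {
    LayerNorm(OUTER, INNER, src, gamma, beta, true, 1e-5f, dst, task_id, thread_num);
  }
  printf("dst[0] = %f\n", dst[0]);
  return 0;
}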
@@ -23,8 +23,8 @@
 extern "C" {
 #endif

-int LayerNorm(int outer_size, int inner_size, const float *src_data, const float *gamma_data, const float *beta_data,
-              bool affine, float epsilon, float *dst_data, int tid, int thread_num);
+int LayerNorm(size_t outer_size, size_t inner_size, const float *src_data, const float *gamma_data,
+              const float *beta_data, bool affine, float epsilon, float *dst_data, size_t task_id, size_t thread_num);
 #ifdef __cplusplus
 }
 #endif
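The kernel gets away with a single pass over each row because the variance can be recovered from the two running sums it already accumulates: var(x) = E[x^2] - (E[x])^2, so only sum(x) and sum(x*x) are needed before the 1 / sqrtf(... + epsilon) step. A tiny standalone check of that identity, with made-up values:

#include <stdio.h>

int main(void) {
  const float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float sum = 0.0f, square_sum = 0.0f;
  for (int i = 0; i < 8; ++i) {
    sum += x[i];
    square_sum += x[i] * x[i];
  }
  const float mean = sum / 8.0f;
  const float square_mean = square_sum / 8.0f;
  /* Same quantity the kernel feeds into 1 / sqrtf(... + epsilon). */
  const float var = square_mean - mean * mean;
  printf("mean = %f, var = %f\n", mean, var); /* mean = 4.5, var = 5.25 */
  return 0;
}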