!35104 [MSLITE][CPU] avx512 hardward self feel function, groupnorm op's avx512/avx/sse/neon instructions stripped to separate functions

Merge pull request !35104 from Greatpan/avx512_sf_groupnorm
This commit is contained in:
i-robot 2022-05-30 03:30:39 +00:00 committed by Gitee
commit f7c5ea1692
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 88 additions and 24 deletions

View File

@ -20,30 +20,25 @@
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_AVX512
#include "nnacl/avx512/group_norm_fp32_avx512.h"
#endif
#ifdef ENABLE_AVX
#include "nnacl/avx/group_norm_fp32_avx.h"
#endif
#ifdef ENABLE_SSE
#include "nnacl/sse/group_norm_fp32_sse.h"
#endif
#ifdef ENABLE_ARM
#include "nnacl/neon/group_norm_fp32_neon.h"
#endif
static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
int cur_groups, const GroupNormParameter *param);
#define SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u) \
do { \
MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, unit_input + u); \
MS_FLOAT_32xN(block_num) norm_val = MS_DIV_F32(block_size, MS_SUB_F32(block_size, input, mean), v_sqrt); \
MS_FLOAT_32xN(block_num) output = MS_ADD_F32(block_size, MS_MUL_F32(block_size, norm_val, scale), offset); \
MS_ST_F32(block_size, unit_output + u, output); \
} while (0)
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
#define SimdFusedGroupNormFp32CoreCalc(block_size, block_num, unit_input, s, m, o, var_sqrt, param, unit_output, u) \
do { \
MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \
MS_FLOAT_32xN(block_num) v_sqrt = MS_MOVN_F32(block_size, var_sqrt); \
MS_FLOAT_32xN(block_num) scale = MS_MOVN_F32(block_size, s); \
MS_FLOAT_32xN(block_num) offset = MS_MOVN_F32(block_size, o); \
for (int block_max_size = param->unit_ - block_num + 1; u < block_max_size; u += block_num) { \
SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u); \
} \
} while (0)
int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
const GroupNormParameter *param, int task_id, float *output) {
if (param->op_parameter_.thread_num_ == 0) {
@ -76,8 +71,7 @@ int GroupNormFp32(const float *input, const float *scale, const float *offset, f
float s = scale[c_offset + c];
float o = offset[c_offset + c];
int u = 0;
MS_SIMD_RUN_NO_SCALAR(SimdFusedGroupNormFp32CoreCalc, unit_input, s, m, o, variance_sqrt, param, unit_output,
u);
SIMD_RUN_NO_SCALAR(GroupNormFp32, u, unit_input, s, o, m, variance_sqrt, param->unit_, unit_output);
for (; u < param->unit_; u++) {
float norm_val = (unit_input[u] - m) / variance_sqrt;
unit_output[u] = norm_val * s + o;
@ -120,7 +114,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
for (int c = 0; c < num_of_ch_per_group; c++) {
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
int i = 0;
MS_SIMD_RUN_NO_SCALAR(SimdReduceSum, in, i, sum);
SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_);
for (; i < param->unit_; i++) {
sum += in[i];
}
@ -136,7 +130,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
for (int c = 0; c < num_of_ch_per_group; c++) {
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
int i = 0;
MS_SIMD_RUN_NO_SCALAR(SimdReduceVar, in, run_mean[g_idx], i, var);
SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_);
for (; i < param->unit_; i++) {
var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]);
}

View File

@ -0,0 +1,70 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
#ifdef __cplusplus
extern "C" {
#endif
@SIMD_INSTRUCTION_BEGIN@
static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean,
float var_sqrt, int unit, float *unit_output) {
SIMD_F32 mean_val = SIMD_MOV_F32(mean);
SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt);
SIMD_F32 scale_val = SIMD_MOV_F32(scale);
SIMD_F32 offset_val = SIMD_MOV_F32(offset);
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 input = SIMD_LD_F32(unit_input + index);
SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt);
SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val);
SIMD_ST_F32(unit_output + index, output);
}
return index;
}
static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) {
if (unit - index >= 4 * BLOCK_NUM) {
SIMD_F32 tmp = SIMD_MOV_F32(0);
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index));
}
*sum += SIMD_GET_SUM_F32(tmp);
}
return index;
}
static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) {
if (unit - index >= 4 * BLOCK_NUM) {
SIMD_F32 mean_val = SIMD_MOV_F32(mean);
SIMD_F32 tmp = SIMD_MOV_F32(0);
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val);
tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, input));
}
*sum += SIMD_GET_SUM_F32(tmp);
}
return index;
}
@SIMD_INSTRUCTION_END@
#ifdef __cplusplus
}
#endif
#endif