!35104 [MSLITE][CPU] avx512 hardward self feel function, groupnorm op's avx512/avx/sse/neon instructions stripped to separate functions
Merge pull request !35104 from Greatpan/avx512_sf_groupnorm
This commit is contained in:
commit
f7c5ea1692
|
@ -20,30 +20,25 @@
|
|||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#ifdef ENABLE_AVX512
|
||||
#include "nnacl/avx512/group_norm_fp32_avx512.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
#include "nnacl/avx/group_norm_fp32_avx.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include "nnacl/sse/group_norm_fp32_sse.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
#include "nnacl/neon/group_norm_fp32_neon.h"
|
||||
#endif
|
||||
|
||||
static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
|
||||
int cur_groups, const GroupNormParameter *param);
|
||||
|
||||
#define SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u) \
|
||||
do { \
|
||||
MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, unit_input + u); \
|
||||
MS_FLOAT_32xN(block_num) norm_val = MS_DIV_F32(block_size, MS_SUB_F32(block_size, input, mean), v_sqrt); \
|
||||
MS_FLOAT_32xN(block_num) output = MS_ADD_F32(block_size, MS_MUL_F32(block_size, norm_val, scale), offset); \
|
||||
MS_ST_F32(block_size, unit_output + u, output); \
|
||||
} while (0)
|
||||
|
||||
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
|
||||
#define SimdFusedGroupNormFp32CoreCalc(block_size, block_num, unit_input, s, m, o, var_sqrt, param, unit_output, u) \
|
||||
do { \
|
||||
MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \
|
||||
MS_FLOAT_32xN(block_num) v_sqrt = MS_MOVN_F32(block_size, var_sqrt); \
|
||||
MS_FLOAT_32xN(block_num) scale = MS_MOVN_F32(block_size, s); \
|
||||
MS_FLOAT_32xN(block_num) offset = MS_MOVN_F32(block_size, o); \
|
||||
for (int block_max_size = param->unit_ - block_num + 1; u < block_max_size; u += block_num) { \
|
||||
SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
|
||||
const GroupNormParameter *param, int task_id, float *output) {
|
||||
if (param->op_parameter_.thread_num_ == 0) {
|
||||
|
@ -76,8 +71,7 @@ int GroupNormFp32(const float *input, const float *scale, const float *offset, f
|
|||
float s = scale[c_offset + c];
|
||||
float o = offset[c_offset + c];
|
||||
int u = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdFusedGroupNormFp32CoreCalc, unit_input, s, m, o, variance_sqrt, param, unit_output,
|
||||
u);
|
||||
SIMD_RUN_NO_SCALAR(GroupNormFp32, u, unit_input, s, o, m, variance_sqrt, param->unit_, unit_output);
|
||||
for (; u < param->unit_; u++) {
|
||||
float norm_val = (unit_input[u] - m) / variance_sqrt;
|
||||
unit_output[u] = norm_val * s + o;
|
||||
|
@ -120,7 +114,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
|
|||
for (int c = 0; c < num_of_ch_per_group; c++) {
|
||||
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
|
||||
int i = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdReduceSum, in, i, sum);
|
||||
SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_);
|
||||
for (; i < param->unit_; i++) {
|
||||
sum += in[i];
|
||||
}
|
||||
|
@ -136,7 +130,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
|
|||
for (int c = 0; c < num_of_ch_per_group; c++) {
|
||||
const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
|
||||
int i = 0;
|
||||
MS_SIMD_RUN_NO_SCALAR(SimdReduceVar, in, run_mean[g_idx], i, var);
|
||||
SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_);
|
||||
for (; i < param->unit_; i++) {
|
||||
var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
|
||||
#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
|
||||
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@SIMD_INSTRUCTION_BEGIN@
|
||||
|
||||
static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean,
|
||||
float var_sqrt, int unit, float *unit_output) {
|
||||
SIMD_F32 mean_val = SIMD_MOV_F32(mean);
|
||||
SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt);
|
||||
SIMD_F32 scale_val = SIMD_MOV_F32(scale);
|
||||
SIMD_F32 offset_val = SIMD_MOV_F32(offset);
|
||||
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
|
||||
SIMD_F32 input = SIMD_LD_F32(unit_input + index);
|
||||
SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt);
|
||||
SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val);
|
||||
SIMD_ST_F32(unit_output + index, output);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) {
|
||||
if (unit - index >= 4 * BLOCK_NUM) {
|
||||
SIMD_F32 tmp = SIMD_MOV_F32(0);
|
||||
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
|
||||
tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index));
|
||||
}
|
||||
*sum += SIMD_GET_SUM_F32(tmp);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) {
|
||||
if (unit - index >= 4 * BLOCK_NUM) {
|
||||
SIMD_F32 mean_val = SIMD_MOV_F32(mean);
|
||||
SIMD_F32 tmp = SIMD_MOV_F32(0);
|
||||
for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
|
||||
SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val);
|
||||
tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, input));
|
||||
}
|
||||
*sum += SIMD_GET_SUM_F32(tmp);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
@SIMD_INSTRUCTION_END@
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
Loading…
Reference in New Issue