diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32.c
index bc7ebdb6143..67935701ea7 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32.c
@@ -20,30 +20,25 @@
 #include "nnacl/op_base.h"
 #include "nnacl/errorcode.h"
 #include "nnacl/intrinsics/ms_simd_instructions.h"
+#ifdef ENABLE_AVX512
+#include "nnacl/avx512/group_norm_fp32_avx512.h"
+#endif
+
+#ifdef ENABLE_AVX
+#include "nnacl/avx/group_norm_fp32_avx.h"
+#endif
+
+#ifdef ENABLE_SSE
+#include "nnacl/sse/group_norm_fp32_sse.h"
+#endif
+
+#ifdef ENABLE_ARM
+#include "nnacl/neon/group_norm_fp32_neon.h"
+#endif
 
 static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
                                  int cur_groups, const GroupNormParameter *param);
 
-#define SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u) \
-  do { \
-    MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, unit_input + u); \
-    MS_FLOAT_32xN(block_num) norm_val = MS_DIV_F32(block_size, MS_SUB_F32(block_size, input, mean), v_sqrt); \
-    MS_FLOAT_32xN(block_num) output = MS_ADD_F32(block_size, MS_MUL_F32(block_size, norm_val, scale), offset); \
-    MS_ST_F32(block_size, unit_output + u, output); \
-  } while (0)
-
-// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
-#define SimdFusedGroupNormFp32CoreCalc(block_size, block_num, unit_input, s, m, o, var_sqrt, param, unit_output, u) \
-  do { \
-    MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \
-    MS_FLOAT_32xN(block_num) v_sqrt = MS_MOVN_F32(block_size, var_sqrt); \
-    MS_FLOAT_32xN(block_num) scale = MS_MOVN_F32(block_size, s); \
-    MS_FLOAT_32xN(block_num) offset = MS_MOVN_F32(block_size, o); \
-    for (int block_max_size = param->unit_ - block_num + 1; u < block_max_size; u += block_num) { \
-      SimdFusedGroupNormFp32DoWork(block_size, block_num, mean, v_sqrt, scale, offset, unit_input, unit_output, u); \
-    } \
-  } while (0)
-
 int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
                   const GroupNormParameter *param, int task_id, float *output) {
   if (param->op_parameter_.thread_num_ == 0) {
@@ -76,8 +71,7 @@ int GroupNormFp32(const float *input, const float *scale, const float *offset, f
       float s = scale[c_offset + c];
       float o = offset[c_offset + c];
       int u = 0;
-      MS_SIMD_RUN_NO_SCALAR(SimdFusedGroupNormFp32CoreCalc, unit_input, s, m, o, variance_sqrt, param, unit_output,
-                            u);
+      SIMD_RUN_NO_SCALAR(GroupNormFp32, u, unit_input, s, o, m, variance_sqrt, param->unit_, unit_output);
       for (; u < param->unit_; u++) {
         float norm_val = (unit_input[u] - m) / variance_sqrt;
         unit_output[u] = norm_val * s + o;
@@ -120,7 +114,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
     for (int c = 0; c < num_of_ch_per_group; c++) {
       const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
       int i = 0;
-      MS_SIMD_RUN_NO_SCALAR(SimdReduceSum, in, i, sum);
+      SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_);
       for (; i < param->unit_; i++) {
         sum += in[i];
       }
@@ -136,7 +130,7 @@ static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run
     for (int c = 0; c < num_of_ch_per_group; c++) {
       const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_;
       int i = 0;
-      MS_SIMD_RUN_NO_SCALAR(SimdReduceVar, in, run_mean[g_idx], i, var);
+      SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_);
       for (; i < param->unit_; i++) {
         var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]);
       }
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32_simd.h.in
new file mode 100644
index 00000000000..8fb8fa5b1a2
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/group_norm_fp32_simd.h.in
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl/intrinsics/ms_simd_instructions.h"
+#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean,
+                                                      float var_sqrt, int unit, float *unit_output) {
+  SIMD_F32 mean_val = SIMD_MOV_F32(mean);
+  SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt);
+  SIMD_F32 scale_val = SIMD_MOV_F32(scale);
+  SIMD_F32 offset_val = SIMD_MOV_F32(offset);
+  for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 input = SIMD_LD_F32(unit_input + index);
+    SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt);
+    SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val);
+    SIMD_ST_F32(unit_output + index, output);
+  }
+  return index;
+}
+
+static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) {
+  if (unit - index >= 4 * BLOCK_NUM) {
+    SIMD_F32 tmp = SIMD_MOV_F32(0);
+    for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index));
+    }
+    *sum += SIMD_GET_SUM_F32(tmp);
+  }
+  return index;
+}
+
+static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) {
+  if (unit - index >= 4 * BLOCK_NUM) {
+    SIMD_F32 mean_val = SIMD_MOV_F32(mean);
+    SIMD_F32 tmp = SIMD_MOV_F32(0);
+    for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val);
+      tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, input));
+    }
+    *sum += SIMD_GET_SUM_F32(tmp);
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
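
Note (not part of the patch): the new group_norm_fp32_simd.h.in template is instantiated once per enabled instruction set (AVX512/AVX/SSE/NEON, matching the conditional includes added above), and SIMD_RUN_NO_SCALAR advances the index argument (u or i) past the vectorized portion so the plain for-loop that follows handles the remaining tail elements scalarly. For reference only, a minimal scalar sketch of what the three generated kernels compute is given below; it simply mirrors the scalar tail loops kept in group_norm_fp32.c, and the *Ref names are illustrative, not nnacl functions.

/* Scalar reference sketch of the GroupNorm kernels (hypothetical helpers). */

/* Sum of all elements in one channel; same math as the GroupNormReduceSum tail. */
static float GroupNormReduceSumRef(const float *in, int unit) {
  float sum = 0.0f;
  for (int i = 0; i < unit; i++) {
    sum += in[i];
  }
  return sum;
}

/* Sum of squared deviations from the channel mean; same math as the GroupNormReduceVar tail. */
static float GroupNormReduceVarRef(const float *in, float mean, int unit) {
  float var = 0.0f;
  for (int i = 0; i < unit; i++) {
    var += (in[i] - mean) * (in[i] - mean);
  }
  return var;
}

/* Fused normalize step: out = (x - mean) / sqrt(var) * scale + offset, as in the GroupNormFp32 tail. */
static void GroupNormFp32Ref(const float *unit_input, float scale, float offset, float mean, float variance_sqrt,
                             int unit, float *unit_output) {
  for (int u = 0; u < unit; u++) {
    float norm_val = (unit_input[u] - mean) / variance_sqrt;
    unit_output[u] = norm_val * scale + offset;
  }
}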