diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.c
index 673182f2bbf..a39afc3a34c 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.c
@@ -19,6 +19,7 @@
 #include "nnacl/fp32/matmul_avx512_fp32.h"
 #include "nnacl/op_base.h"
+#include "nnacl/intrinsics/ms_simd_instructions.h"
 
 void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag,
                            const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride,
@@ -202,5 +203,53 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
     }
   }
 }
+
+// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
+int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
+                                    int k, int act_type) {
+  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
+  // block 8
+  MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
+  MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
+  for (; m_index <= m - C8NUM; m_index += C8NUM) {
+    int k_index = 0;
+    MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
+    MS_SET_ZERO512X8_F32(dst16_)
+    for (; k_index <= k - C16NUM; k_index += C16NUM) {
+      __m512 weight = _mm512_loadu_ps(b + k_index);
+      MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
+      MS_FMADD512X8_F32(src, weight, dst16_)
+    }
+    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
+    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
+    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
+    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
+    MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
+    MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
+    MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
+    MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
+    for (; k_index < k; k_index++) {
+      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
+      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
+      MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
+      MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
+      MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
+      MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
+    }
+
+    if (act_type != 0) {
+      dst = MS_MAX256_F32(dst, down_threshold256);
+      if (act_type == 3) {  // 3: relu6
+        dst = MS_MIN256_F32(dst, up_threshold256);
+      }
+    }
+
+    MS_ST256_F32(c + m_index, dst);
+  }
+  return m_index;
+}
+
 #pragma GCC pop_options
 #endif
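Note: for readers not fluent in the MS_* intrinsics, the new AVX512 kernel above is a blocked matrix-vector product: eight output rows per outer iteration, each accumulated in a 512-bit register sixteen depth elements at a time and then horizontally reduced. A scalar reference with the same semantics (illustrative only; the helper name is hypothetical and not part of this patch) would be:

// Scalar reference for GemmIsNotPackOptimizeAVX512: c[i] = act(bias[0] + dot(a[i, :], b)).
static void GemmRowColRefFp32(const float *a, const float *b, float *c, const float *bias, int m, int k,
                              int act_type) {
  for (int i = 0; i < m; ++i) {
    float acc = bias[0];
    for (int j = 0; j < k; ++j) {
      acc += a[i * k + j] * b[j];  // dot product of row i of a with the single column b
    }
    if (act_type != 0) {
      acc = acc > 0.0f ? acc : 0.0f;      // 1: relu
      if (act_type == 3 && acc > 6.0f) {  // 3: relu6
        acc = 6.0f;
      }
    }
    c[i] = acc;
  }
}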
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.h
index 11efdc28a49..59cab81d8a8 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_avx512_fp32.h
@@ -30,6 +30,9 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
 void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                       int cur_col, int col_align, int row);
 
+int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
+                                    int k, int act_type);
+
 // 64 block
 void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                              const size_t act_flag, const size_t row_block, const size_t col_block,
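Note: the exported kernel returns the first row index it did not process (it only handles full blocks of 8 rows), so the caller is expected to finish the remainder itself. A minimal sketch of that contract, assuming an ENABLE_AVX512 build and a hypothetical wrapper name:

#include "nnacl/fp32/matmul_avx512_fp32.h"

// Hypothetical caller: vector kernel first, then a scalar tail for the last m % 8 rows.
static void GemvNotPackedSketch(const float *a, const float *b, float *c, const float *bias, int m, int k,
                                int act_type) {
  int64_t m_index = 0;
#ifdef ENABLE_AVX512
  m_index = GemmIsNotPackOptimizeAVX512(m_index, a, b, c, bias, m, k, act_type);
#endif
  for (; m_index < m; m_index++) {  // scalar tail, same math as the vector kernel
    float acc = bias[0];
    for (int j = 0; j < k; ++j) {
      acc += a[m_index * k + j] * b[j];
    }
    if (act_type != 0) {
      acc = acc > 0.0f ? acc : 0.0f;
      if (act_type == 3 && acc > 6.0f) {
        acc = 6.0f;
      }
    }
    c[m_index] = acc;
  }
}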
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
index 65aa5142b68..24722a663c4 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
@@ -16,8 +16,25 @@
 #include "nnacl/fp32/matmul_fp32.h"
 #include "nnacl/fp32/pack_fp32.h"
+#include "nnacl/fp32/matmul_avx512_fp32.h"
 #include "nnacl/intrinsics/ms_simd_instructions.h"
+#ifdef ENABLE_AVX512
+#include "nnacl/avx512/matmul_fp32_avx512.h"
+#endif
+
+#ifdef ENABLE_AVX
+#include "nnacl/avx/matmul_fp32_avx.h"
+#endif
+
+#ifdef ENABLE_SSE
+#include "nnacl/sse/matmul_fp32_sse.h"
+#endif
+
+#ifdef ENABLE_ARM
+#include "nnacl/neon/matmul_fp32_neon.h"
+#endif
+
 #ifndef ENABLE_ARM
 void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
   for (int ci = 0; ci < col; ci++) {
@@ -1271,44 +1288,8 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
 // act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
 void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
   int index = 0;
-#ifdef ENABLE_AVX512
-  __m512 down_threshold512 = _mm512_setzero_ps();
-  __m512 up_threshold512 = _mm512_set1_ps(C6NUM);
-  __m512 b_data16 = _mm512_set1_ps(b[0]);
-  __m512 bias_data16 = _mm512_set1_ps(bias[0]);
-  for (; index < row - C16NUM; index += C16NUM) {
-    __m512 a_data = _mm512_loadu_ps(a + index);
-    __m512 dst = b_data16 * a_data + bias_data16;
-    ActCompute(512, down_threshold512, up_threshold512);
-    _mm512_storeu_ps(c + index, dst);
-  }
-#endif
-#ifdef ENABLE_AVX
-  __m256 down_threshold256 = _mm256_setzero_ps();
-  __m256 up_threshold256 = _mm256_set1_ps(C6NUM);
-  __m256 b_data8 = _mm256_set1_ps(b[0]);
-  __m256 bias_data8 = _mm256_set1_ps(bias[0]);
-  for (; index < row - C8NUM; index += C8NUM) {
-    __m256 a_data = _mm256_loadu_ps(a + index);
-    __m256 dst = b_data8 * a_data + bias_data8;
-    ActCompute(256, down_threshold256, up_threshold256);
-    _mm256_storeu_ps(c + index, dst);
-  }
-#endif
-
-#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
-  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
-  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
-  MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
-  MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
-  for (; index < row - C4NUM; index += C4NUM) {
-    MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
-    MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
-    ActCompute(128, down_threshold128, up_threshold128);
-    MS_STQ_F32(c + index, dst);
-  }
-#endif
+  SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
 
   for (; index < row; ++index) {
     float dst = a[index] * b[0] + bias[0];
@@ -1321,41 +1302,9 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
 void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
   // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
   int m_index = 0;
-#ifdef ENABLE_AVX512
-  // block 8
-  MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
-  MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
-  for (; m_index <= m - C8NUM; m_index += C8NUM) {
-    int k_index = 0;
-    MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
-    MS_SET_ZERO512X8_F32(dst16_)
-    for (; k_index <= k - C16NUM; k_index += C16NUM) {
-      __m512 weight = _mm512_loadu_ps(b + k_index);
-      MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
-      MS_FMADD512X8_F32(src, weight, dst16_)
-    }
-    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
-    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
-    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
-    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
-    MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
-    MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
-    MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
-    MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
-    for (; k_index < k; k_index++) {
-      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
-      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
-      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
-      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
-      MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
-      MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
-      MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
-      MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
-    }
-    ActCompute(256, down_threshold256, up_threshold256);
-    MS_ST256_F32(c + m_index, dst);
-  }
-#endif
+
+  SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);
+
 #ifdef ENABLE_AVX
   // block 4
   MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
@@ -1388,24 +1337,10 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
   for (; m_index < m; m_index++) {
     float dst = bias[0];
     int k_index = 0;
-#ifdef ENABLE_AVX512
-    __m512 dst1 = _mm512_setzero_ps();
-    for (; k_index <= k - C16NUM; k_index += C16NUM) {
-      __m512 weight = _mm512_loadu_ps(b + k_index);
-      __m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
-      dst1 = _mm512_fmadd_ps(weight, a1, dst1);
-    }
-    dst += _mm512_reduce_add_ps(dst1);
-#endif
-#ifdef ENABLE_AVX
-    __m256 dst2 = _mm256_setzero_ps();
-    for (; k_index <= k - C8NUM; k_index += C8NUM) {
-      __m256 weight = _mm256_loadu_ps(b + k_index);
-      __m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
-      dst2 = _mm256_fmadd_ps(weight, src, dst2);
-    }
-    dst += MS_REDUCE_ADD256_F32(dst2);
-#endif
+
+    SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+    SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+
     for (; k_index < k; k_index++) {
       dst += b[k_index] * a[m_index * k + k_index];
     }
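Note: the SIMD_RUN_* dispatch macros used above are not defined in this patch; they come from the existing nnacl SIMD infrastructure. As a rough mental model only (an assumption; the real macros in ms_simd_instructions.h differ in detail and may also gate on run-time CPU feature checks), they paste the ISA suffix onto the base function name and advance the running index, leaving just the tail for the plain C loop that follows:

/* Simplified model of the dispatch macros (assumption, not the real nnacl definitions). */
#ifdef ENABLE_AVX512
#define SIMD_RUN_AVX512(function, index, ...) ((index) = function##AVX512((index), __VA_ARGS__))
#else
#define SIMD_RUN_AVX512(function, index, ...)
#endif

#ifdef ENABLE_AVX
#define SIMD_RUN_AVX(function, index, ...) ((index) = function##AVX((index), __VA_ARGS__))
#else
#define SIMD_RUN_AVX(function, index, ...)
#endif

/* SIMD_RUN_NO_SCALAR: try every vector ISA that was compiled in, widest first; the scalar
 * remainder is handled by the loop written right after the macro call. */

Under that reading, SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, ...) resolves to the hand-written GemmIsNotPackOptimizeAVX512 added in matmul_avx512_fp32.c (hence the new include at the top of this file), while SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, ...) and SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, ...) resolve to the functions generated from matmul_fp32_simd.h.in below.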
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32_simd.h.in
new file mode 100644
index 00000000000..383d936b619
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32_simd.h.in
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl/intrinsics/ms_simd_instructions.h"
+#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
+                                                      int deep, int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]);
+  SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]);
+  for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 a_data = SIMD_LD_F32(a + index);
+    SIMD_F32 dst = b_data16 * a_data + bias_data16;
+    if (act_type != 0) {
+      dst = SIMD_MAX_F32(dst, down_threshold);
+      if (act_type == 3) {
+        dst = SIMD_MIN_F32(dst, up_threshold);
+      }
+    }
+    SIMD_ST_F32(c + index, dst);
+  }
+
+  return index;
+}
+
+#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
+static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) {
+  SIMD_F32 dst1 = SIMD_MOV_F32(0.0f);
+  for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 weight = SIMD_LD_F32(b + index);
+    SIMD_F32 a1 = SIMD_LD_F32(a + index);
+    dst1 = SIMD_FMADD_F32(weight, a1, dst1);
+  }
+  *dst += SIMD_REDUCE_ADD_F32(dst1);
+  return index;
+}
+#endif
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
}
+#endif
+#endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h
index 9852da34d48..28a88024f02 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h
@@ -134,6 +134,7 @@
 
 // get max (float/int) op
 #define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM)
+#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32)
 
 // clamp (float/int) op
 #define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val)
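Note: the @SIMD_INSTRUCTION@ placeholders in matmul_fp32_simd.h.in are filled in at build time to produce one header per ISA (the per-ISA includes added to matmul_fp32.c above, such as nnacl/avx512/matmul_fp32_avx512.h and nnacl/avx/matmul_fp32_avx.h). For illustration, substituting AVX should yield roughly the following instantiation; the exact generated names are an assumption (SIMD_F32 to MS_FLOAT32X8, SIMD_LD_F32 to MS_LD256_F32, BLOCK_NUM to 8), since the authoritative mapping lives in the nnacl SIMD code generator, not in this patch:

// Hypothetical AVX expansion of GemmIsNotPack@SIMD_INSTRUCTION@ from the template above.
static inline int64_t GemmIsNotPackAVX(int64_t index, const float *a, const float *b, float *c, const float *bias,
                                       int row, int deep, int act_type) {
  MS_FLOAT32X8 down_threshold = MS_MOV256_F32(0.0f);
  MS_FLOAT32X8 up_threshold = MS_MOV256_F32(6);
  MS_FLOAT32X8 b_data16 = MS_MOV256_F32(b[0]);
  MS_FLOAT32X8 bias_data16 = MS_MOV256_F32(bias[0]);
  for (int block_max_size = row - 8 + 1; index < block_max_size; index += 8) {
    MS_FLOAT32X8 a_data = MS_LD256_F32(a + index);
    MS_FLOAT32X8 dst = b_data16 * a_data + bias_data16;  // c = b[0] * a + bias[0], eight rows at once
    if (act_type != 0) {
      dst = MS_MAX256_F32(dst, down_threshold);  // relu
      if (act_type == 3) {
        dst = MS_MIN256_F32(dst, up_threshold);  // relu6
      }
    }
    MS_ST256_F32(c + index, dst);
  }
  return index;
}

The SIMD_REDUCE_ADD_F32 macro registered in ms_simd_instructions.h at the end of the patch is what lets the templated GemmIsNotPackOptimizeCore perform its horizontal sum (*dst += SIMD_REDUCE_ADD_F32(dst1)) without caring which register width it was instantiated for.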