From 7b4ec7b351ab57d51aa8bdf5939a0f6cca0bdd40 Mon Sep 17 00:00:00 2001
From: xuanyue
Date: Mon, 9 May 2022 17:32:50 +0800
Subject: [PATCH] matmul-avx support split-by-row when B is a vector

---
 .jenkins/check/config/whitelizard.txt | 2 +-
 .../cpu/kernel/nnacl/fp32/matmul_fp32.c | 182 +++++++++++++-----
 .../cpu/kernel/nnacl/fp32/matmul_fp32.h | 6 +-
 .../intrinsics/ms_simd_avx512_instructions.h | 1 +
 .../intrinsics/ms_simd_avx_instructions.h | 2 +
 .../kernel/cpu/fp32/matmul_fp32_arm64.cc | 2 +-
 .../kernel/cpu/fp32/matmul_fp32_avx.cc | 19 +-
 .../kernel/cpu/fp32/matmul_fp32_avx512.cc | 19 +-
 .../kernel/cpu/fp32/matmul_fp32_base.cc | 7 +-
 .../kernel/cpu/fp32/matmul_fp32_base.h | 5 +-
 10 files changed, 182 insertions(+), 63 deletions(-)

diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 2300032e6fc..e9c10d91a26 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -71,7 +71,7 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/pooling_fp32.c:AvgPooling
-mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
+mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
 mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
 mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
index b0dd2998465..35f396e8404 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c
@@ -1269,52 +1269,71 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
 #endif
 #endif

-void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
+#define ActCompute(bit_num, down_threshold, up_threshold) \
+  if (act_type != 0) {                                    \
+    dst = MS_MAX##bit_num##_F32(dst, down_threshold);     \
+    if (act_type == 3) {                                  \
+      dst = MS_MIN##bit_num##_F32(dst, up_threshold);     \
+    }                                                     \
+  }
+
+// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
   int index = 0;
 #ifdef ENABLE_AVX512
+  __m512 down_threshold512 = _mm512_setzero_ps();
+  __m512 up_threshold512 = _mm512_set1_ps(C6NUM);
   __m512 b_data16 = _mm512_set1_ps(b[0]);
   __m512 bias_data16 = _mm512_set1_ps(bias[0]);
-#endif
-#ifdef ENABLE_AVX
-  __m256 b_data8 = _mm256_set1_ps(b[0]);
-  __m256 bias_data8 = _mm256_set1_ps(bias[0]);
-#endif
-#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
-  MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
-  MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
-#endif
-
-#ifdef ENABLE_AVX512
   for (; index < row - C16NUM; index += C16NUM) {
     __m512 a_data = _mm512_loadu_ps(a + index);
-    _mm512_storeu_ps(c + index, b_data16 * a_data + bias_data16);
+    __m512 dst = b_data16 * a_data + bias_data16;
+    ActCompute(512, down_threshold512, up_threshold512);
+    _mm512_storeu_ps(c + index, dst);
   }
 #endif
 #ifdef ENABLE_AVX
+  __m256 down_threshold256 = _mm256_setzero_ps();
+  __m256 up_threshold256 = _mm256_set1_ps(C6NUM);
+  __m256 b_data8 = _mm256_set1_ps(b[0]);
+  __m256 bias_data8 = _mm256_set1_ps(bias[0]);
   for (; index < row - C8NUM; index += C8NUM) {
     __m256 a_data = _mm256_loadu_ps(a + index);
-    _mm256_storeu_ps(c + index, b_data8 * a_data + bias_data8);
+    __m256 dst = b_data8 * a_data + bias_data8;
+    ActCompute(256, down_threshold256, up_threshold256);
+    _mm256_storeu_ps(c + index, dst);
   }
 #endif
 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
+  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
+  MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
+  MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
   for (; index < row - C4NUM; index += C4NUM) {
     MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
-    MS_STQ_F32(c + index, MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4));
+    MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
+    ActCompute(128, down_threshold128, up_threshold128);
+    MS_STQ_F32(c + index, dst);
   }
 #endif
   for (; index < row; ++index) {
-    c[index] = a[index] * b[0] + bias[0];
+    float dst = a[index] * b[0] + bias[0];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
   }
 }

-void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
+// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) { // gemm dot is [m, k] * [k, 1] ==>> [m, 1] int m_index = 0; #ifdef ENABLE_AVX512 // block 8 + MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps(); + MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM); for (; m_index <= m - C8NUM; m_index += C8NUM) { int k_index = 0; MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]); @@ -1324,31 +1343,59 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float MS_LOAD512X8_F32(src, a + m_index * k + k_index, k) MS_FMADD512X8_F32(src, weight, dst16_) } - MS_F32X8_GETI(dst, 0) += _mm512_reduce_add_ps(dst16_1); - MS_F32X8_GETI(dst, 1) += _mm512_reduce_add_ps(dst16_2); - MS_F32X8_GETI(dst, 2) += _mm512_reduce_add_ps(dst16_3); - MS_F32X8_GETI(dst, 3) += _mm512_reduce_add_ps(dst16_4); - MS_F32X8_GETI(dst, 4) += _mm512_reduce_add_ps(dst16_5); - MS_F32X8_GETI(dst, 5) += _mm512_reduce_add_ps(dst16_6); - MS_F32X8_GETI(dst, 6) += _mm512_reduce_add_ps(dst16_7); - MS_F32X8_GETI(dst, 7) += _mm512_reduce_add_ps(dst16_8); + MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1); + MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2); + MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3); + MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4); + MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5); + MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6); + MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7); + MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8); for (; k_index < k; k_index++) { MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index]; MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k]; - MS_F32X8_GETI(dst, 2) += b[k_index] * a[m_index * k + k_index + 2 * k]; - MS_F32X8_GETI(dst, 3) += b[k_index] * a[m_index * k + k_index + 3 * k]; - MS_F32X8_GETI(dst, 4) += b[k_index] * a[m_index * k + k_index + 4 * k]; - MS_F32X8_GETI(dst, 5) += b[k_index] * a[m_index * k + k_index + 5 * k]; - MS_F32X8_GETI(dst, 6) += b[k_index] * a[m_index * k + k_index + 6 * k]; - MS_F32X8_GETI(dst, 7) += b[k_index] * a[m_index * k + k_index + 7 * k]; + MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k]; + MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k]; + MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k]; + MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k]; + MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k]; + MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k]; } + ActCompute(256, down_threshold256, up_threshold256); MS_ST256_F32(c + m_index, dst); } #endif +#ifdef ENABLE_AVX + // block 4 + MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0); + MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM); + for (; m_index <= m - C4NUM; m_index += C4NUM) { + int k_index = 0; + MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]); + MS_SET_ZERO256X4_F32(dst_) + for (; k_index <= k - C8NUM; k_index += C8NUM) { + MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index); + MS_LOAD256X4_F32(src, a + m_index * k + k_index, k); + MS_FMADD256X4_F32(src, weight, dst_); + } + MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1); + MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2); + MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3); + MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4); + for (; k_index < k; ++k_index) { + MS_F32X8_GETI(dst, 0) 
+= b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
+      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
+    }
+    ActCompute(128, down_threshold128, up_threshold128);
+    MS_ST128_F32(c + m_index, dst);
+  }
+#endif
   // block 1
   for (; m_index < m; m_index++) {
-    c[m_index] = bias[0];
+    float dst = bias[0];
     int k_index = 0;
 #ifdef ENABLE_AVX512
     __m512 dst1 = _mm512_setzero_ps();
@@ -1357,16 +1404,29 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
       __m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
       dst1 = _mm512_fmadd_ps(weight, a1, dst1);
     }
-    c[m_index] += _mm512_reduce_add_ps(dst1);
+    dst += _mm512_reduce_add_ps(dst1);
+#endif
+#ifdef ENABLE_AVX
+    __m256 dst2 = _mm256_setzero_ps();
+    for (; k_index <= k - C8NUM; k_index += C8NUM) {
+      __m256 weight = _mm256_loadu_ps(b + k_index);
+      __m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
+      dst2 = _mm256_fmadd_ps(weight, src, dst2);
+    }
+    dst += MS_REDUCE_ADD256_F32(dst2);
 #endif
     for (; k_index < k; k_index++) {
-      c[m_index] += b[k_index] * a[m_index * k + k_index];
+      dst += b[k_index] * a[m_index * k + k_index];
    }
+    ActCompute(32, 0, C6NUM);
+    c[m_index] = dst;
   }
 }

 #ifdef ENABLE_ARM64
-void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
+// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
+void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
+                     size_t act_type) {
   // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
   // 9: WriteBack
   asm volatile(
@@ -1495,15 +1555,26 @@ void MatMul4x1Kernel(const float *input, const float *weight, float *output, con
     "ld1r {v1.4s}, [%[bias]]\n"
     "fadd v0.4s, v0.4s, v1.4s\n"
     "9:\n"
+    "cbz %[act], 10f\n"
+    "dup v1.2d, xzr\n"
+    "fmax v0.4s, v0.4s, v1.4s\n"
+    "cmp %[act], #3\n"
+    "bne 10f\n"
+    "movi v1.4s, #6\n"
+    "scvtf v1.4s, v1.4s\n"
+    "fmin v0.4s, v0.4s, v1.4s\n"
+    "10:\n"
     "st1 {v0.4s}, [%[output]]\n"
     :
-    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
+    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
+      [ act ] "r"(act_type)
    : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
      "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
 }

-void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
+void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
+                     size_t act_type) {
   // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
   // 9: WriteBack
   asm volatile(
@@ -1590,15 +1661,26 @@ void MatMul2x1Kernel(const float *input, const float *weight, float *output, con
     "ld1r {v1.4s}, [%[bias]]\n"
     "fadd v0.2s, v0.2s, v1.2s\n"
     "9:\n"
+    "cbz %[act], 10f\n"
+    "fmov d1, xzr\n"
+    "fmax v0.2s, v0.2s, v1.2s\n"
+    "cmp %[act], #3\n"
+    "bne 10f\n"
+    "movi v1.2s, #6\n"
+    "scvtf v1.2s, v1.2s\n"
+    "fmin v0.2s, v0.2s, v1.2s\n"
+    "10:\n"
     "st1 {v0.2s}, [%[output]]\n"
     :
-    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
+    : [ input ] 
"r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep), + [ act ] "r"(act_type) : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31", "memory"); } -void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute // 9: WriteBack asm volatile( @@ -1665,30 +1747,40 @@ void MatMul1x1Kernel(const float *input, const float *weight, float *output, con "ld1 {v1.s}[0], [%[bias]]\n" "fadd s0, s0, s1\n" "9:\n" - "st1 {v0.s}[0], [%[output]]\n" + "cbz %[act], 10f\n" + "fmov s1, wzr\n" + "fmax s0, s0, s1\n" + "cmp %[act], #3\n" + "bne 10f\n" + "mov x10, #6\n" + "scvtf s1, x10\n" + "fmin s0, s0, s1\n" + "10:\n" + "str s0, [%[output]]\n" : - : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep) + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep), + [ act ] "r"(act_type) : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31"); } void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, - int deep) { + int deep, int act_type) { const float *input = a + start_row * deep; float *output = c + start_row; const int step = C4NUM * deep; for (; start_row <= end_row - C4NUM; start_row += C4NUM) { - MatMul4x1Kernel(input, b, output, bias, deep); + MatMul4x1Kernel(input, b, output, bias, deep, act_type); input += step; output += C4NUM; } for (; start_row <= end_row - C2NUM; start_row += C2NUM) { - MatMul2x1Kernel(input, b, output, bias, deep); + MatMul2x1Kernel(input, b, output, bias, deep, act_type); input += C2NUM * deep; output += C2NUM; } if (start_row == end_row - 1) { - MatMul1x1Kernel(input, b, output, bias, deep); + MatMul1x1Kernel(input, b, output, bias, deep, act_type); } } #endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.h index f94ed0db317..f532bbc81da 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.h @@ -108,13 +108,13 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float * void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, int col, int stride, int out_type); -void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep); +void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type); -void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k); +void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type); #ifdef ENABLE_ARM64 void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, - int deep); + int deep, int act_type); #endif #ifdef __cplusplus } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx512_instructions.h 
b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx512_instructions.h index e9c78836a36..4e310e7fdad 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx512_instructions.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx512_instructions.h @@ -74,6 +74,7 @@ #define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2) #define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2) #define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) +#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src) static inline float MS_GET_MAX512_F32(__m512 src) { float result = MS_F32X16_GETI(src, 0); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx_instructions.h index 7b4f6516d0d..665f3d274a3 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx_instructions.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_avx_instructions.h @@ -247,4 +247,6 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) { MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \ MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); + +#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]); #endif // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_arm64.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_arm64.cc index e3be26ad369..444ff3acb3a 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_arm64.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_arm64.cc @@ -110,7 +110,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const { return RET_OK; } GemmIsNotPackByRow(matrix_a_.pack_ptr, matrix_b_.pack_ptr, output_data_, matrix_c_.pack_ptr, start_row, end_row, - params_->deep_); + params_->deep_, params_->act_type_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx.cc index c66d424f0ab..dd332272908 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx.cc @@ -69,8 +69,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const { } const float *input = matrix_a_.pack_ptr + start_row * params_->deep_; float *output = output_data_ + start_row * params_->col_align_; - MatMulAvxFp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_, - params_->col_align_, params_->col_align_, row_num); + if (params_->col_ == 1) { + float bias = 0; + if (matrix_c_.pack_ptr != nullptr) { + bias = matrix_c_.pack_ptr[0]; + } + gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_); + } else { + MatMulAvxFp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_, + params_->col_align_, params_->col_align_, row_num); + } return RET_OK; } @@ -102,12 +110,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() { if (b_batch_ != C1NUM) { return false; } - if (params_->col_ == 1 && !params_->a_const_) { - return false; - } if (row_num_ < op_parameter_->thread_num_) { return false; } + if (params_->col_ == 1) { + row_min_unit_ = C4NUM; + 
return true; + } row_min_unit_ = C3NUM; if (col_step_ < C16NUM) { row_min_unit_ = C8NUM; diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx512.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx512.cc index 002552dc523..056ce6a6850 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx512.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_avx512.cc @@ -70,8 +70,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const { } const float *input = matrix_a_.pack_ptr + start_row * params_->deep_; float *output = output_data_ + start_row * params_->col_align_; - MatMulAvx512Fp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_, - params_->col_align_, params_->col_align_, row_num); + if (params_->col_ == 1) { + float bias = 0; + if (matrix_c_.pack_ptr != nullptr) { + bias = matrix_c_.pack_ptr[0]; + } + gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_); + } else { + MatMulAvx512Fp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_, + params_->col_align_, params_->col_align_, row_num); + } return RET_OK; } @@ -104,12 +112,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() { if (b_batch_ != C1NUM) { return false; } - if (params_->batch >= op_parameter_->thread_num_ || params_->col_ == 1) { - return false; - } if (row_num_ < op_parameter_->thread_num_) { return false; } + if (params_->col_ == 1) { + row_min_unit_ = C8NUM; + return true; + } row_min_unit_ = C6NUM; if (col_step_ < C48NUM) { row_min_unit_ = C12NUM; diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc index 465d5d3d188..9966e40b403 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.cc @@ -230,7 +230,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const { const float *a = matrix_a_.pack_ptr + a_offset_[index] * params_->row_ * params_->deep_; const float *b = matrix_b_.pack_ptr + b_offset_[index] * params_->deep_ * params_->col_; float *c = output_data_ + index * params_->row_ * params_->col_; - gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_); + gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_, params_->act_type_); } return RET_OK; } @@ -247,6 +247,11 @@ int MatmulFp32BaseCPUKernel::Prepare() { MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->data_type() == kNumberTypeFloat32, RET_ERROR, "matrix-c's data type is invalid."); } + auto act_type = params_->act_type_; + if (act_type != ActType_No && act_type != ActType_Relu && act_type != ActType_Relu6) { + MS_LOG(ERROR) << "matmul don't support the act-type: " << act_type; + return RET_ERROR; + } auto ret = InitParameter(); MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Init parameters failed."); if (params_->a_const_) { diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.h index d47c4b512f8..76574d21ca7 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.h +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/matmul_fp32_base.h @@ -30,7 +30,8 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { using MatrixPackFun = void (*)(const float *src_ptr, float *dst_ptr, int row, int col); -using GemmIsNotPackFun = void (*)(const float *a, const 
float *b, float *c, const float *bias, int m, int k);
+using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k,
+                                  int act_type);

 class MatmulFp32BaseCPUKernel : public LiteKernel {
  public:
@@ -79,6 +80,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
  protected:
   MatMulParameter *params_ = nullptr;
+  GemmIsNotPackFun gemmIsNotPackFun = nullptr;
   int a_batch_ = 1;
   int b_batch_ = 1;
   std::vector<int> a_offset_;
   std::vector<int> b_offset_;
@@ -95,7 +97,6 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
   float *output_data_ = nullptr;
   bool out_need_aligned_ = false;
   int col_step_ = 0;
-  GemmIsNotPackFun gemmIsNotPackFun = nullptr;
   std::vector<int> split_points_;
   MatrixInfo matrix_a_;
   MatrixInfo matrix_b_;
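
For reference, below is a minimal scalar sketch (not taken from the patch) of what the new act_type argument computes in the unpacked GEMV path: C[m, 1] = A[m, k] * B[k, 1] + bias, followed by the optional clamp that the ActCompute macro applies (act_type 0: no activation, 1: relu, 3: relu6). The helper name ReferenceGemmVecWithAct is hypothetical and exists only for this illustration.

#include <stdio.h>

/* Scalar reference of the [m, k] x [k, 1] product with bias and optional relu/relu6. */
static void ReferenceGemmVecWithAct(const float *a, const float *b, float *c, float bias,
                                    int m, int k, int act_type) {
  for (int i = 0; i < m; ++i) {
    float dst = bias;
    for (int j = 0; j < k; ++j) {
      dst += a[i * k + j] * b[j];
    }
    if (act_type != 0) {   /* mirrors ActCompute: lower bound for relu and relu6 */
      dst = dst > 0.0f ? dst : 0.0f;
      if (act_type == 3) { /* relu6 adds the upper bound of 6 */
        dst = dst < 6.0f ? dst : 6.0f;
      }
    }
    c[i] = dst;
  }
}

int main(void) {
  const float a[2 * 3] = {1.0f, 2.0f, 3.0f, -4.0f, -5.0f, -6.0f}; /* A is [2, 3] */
  const float b[3] = {1.0f, 1.0f, 1.0f};                          /* B is a [3, 1] vector */
  float c[2] = {0.0f, 0.0f};
  ReferenceGemmVecWithAct(a, b, c, 0.5f, 2, 3, 3); /* act_type 3 -> relu6 */
  printf("%f %f\n", c[0], c[1]);                   /* 6.000000 (clamped from 6.5) and 0.000000 */
  return 0;
}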