From 27ddd48788943ff9bc6be63f3019e54b5ca509e0 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 19 Jan 2022 17:12:53 +0800 Subject: [PATCH] matmul opt --- .jenkins/check/config/whitelizard.txt | 1 + .../cpu/nnacl/fp32/matmul_fp32.c | 328 ++++++++++++++++++ .../cpu/nnacl/fp32/matmul_fp32.h | 4 + .../kernel/arm/fp32/matmul_fp32_base.cc | 93 +++-- .../kernel/arm/fp32/matmul_fp32_base.h | 12 +- 5 files changed, 414 insertions(+), 24 deletions(-) diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index e4e14e26643..05b77108e23 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -70,6 +70,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Con mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling +mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c index 57da9bb6e14..a1088f9c3c0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c @@ -2268,3 +2268,331 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float } } } + +#ifdef ENABLE_ARM64 +void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "add x6, %[input], %[deep], LSL #3\n" + "add x7, x5, %[deep], LSL #3\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "dup v2.2d, xzr\n" + "dup v3.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "fmla v2.4s, v23.4s, v31.4s\n" + "fmla v3.4s, v27.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // 
LoopD12
+    "adds x10, x10, #16\n"
+    "cbz x10, 6f\n"
+    "cmp x10, #12\n"
+    "blt 3f\n"
+    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
+    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
+    "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
+    "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
+    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "fmla v1.4s, v16.4s, v28.4s\n"
+    "fmla v2.4s, v20.4s, v28.4s\n"
+    "fmla v3.4s, v24.4s, v28.4s\n"
+    "fmla v0.4s, v5.4s, v29.4s\n"
+    "fmla v1.4s, v17.4s, v29.4s\n"
+    "fmla v2.4s, v21.4s, v29.4s\n"
+    "fmla v3.4s, v25.4s, v29.4s\n"
+    "fmla v0.4s, v6.4s, v30.4s\n"
+    "fmla v1.4s, v18.4s, v30.4s\n"
+    "fmla v2.4s, v22.4s, v30.4s\n"
+    "fmla v3.4s, v26.4s, v30.4s\n"
+    "sub x10, x10, #12\n"
+    "b 7f\n"
+    "3:\n"  // LoopD8
+    "cmp x10, #8\n"
+    "blt 4f\n"
+    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
+    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
+    "ld1 {v20.4s, v21.4s}, [x6], #32\n"
+    "ld1 {v24.4s, v25.4s}, [x7], #32\n"
+    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "fmla v1.4s, v16.4s, v28.4s\n"
+    "fmla v2.4s, v20.4s, v28.4s\n"
+    "fmla v3.4s, v24.4s, v28.4s\n"
+    "fmla v0.4s, v5.4s, v29.4s\n"
+    "fmla v1.4s, v17.4s, v29.4s\n"
+    "fmla v2.4s, v21.4s, v29.4s\n"
+    "fmla v3.4s, v25.4s, v29.4s\n"
+    "sub x10, x10, #8\n"
+    "b 7f\n"
+    "4:\n"  // LoopD4
+    "cmp x10, #4\n"
+    "blt 7f\n"
+    "ld1 {v4.4s}, [x8], #16\n"
+    "ld1 {v16.4s}, [x5], #16\n"
+    "ld1 {v20.4s}, [x6], #16\n"
+    "ld1 {v24.4s}, [x7], #16\n"
+    "ld1 {v28.4s}, [x9], #16\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "fmla v1.4s, v16.4s, v28.4s\n"
+    "fmla v2.4s, v20.4s, v28.4s\n"
+    "fmla v3.4s, v24.4s, v28.4s\n"
+    "sub x10, x10, #4\n"
+    "7:\n"
+    "cbz x10, 6f\n"
+    "dup v4.2d, xzr\n"
+    "dup v16.2d, xzr\n"
+    "dup v20.2d, xzr\n"
+    "dup v24.2d, xzr\n"
+    "dup v28.2d, xzr\n"
+    "subs x10, x10, #2\n"
+    "blt 5f\n"
+    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
+    "ld1 {v16.d}[0], [x5], #8\n"
+    "ld1 {v20.d}[0], [x6], #8\n"
+    "ld1 {v24.d}[0], [x7], #8\n"
+    "ld1 {v28.d}[0], [x9], #8\n"
+    "cbz x10, 8f\n"
+    "5:\n"  // LoopD1
+    "ld1 {v4.s}[2], [x8]\n"
+    "ld1 {v16.s}[2], [x5]\n"
+    "ld1 {v20.s}[2], [x6]\n"
+    "ld1 {v24.s}[2], [x7]\n"
+    "ld1 {v28.s}[2], [x9]\n"
+    "8:\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "fmla v1.4s, v16.4s, v28.4s\n"
+    "fmla v2.4s, v20.4s, v28.4s\n"
+    "fmla v3.4s, v24.4s, v28.4s\n"
+    "6:\n"
+    "faddp v4.4s, v0.4s, v1.4s\n"
+    "faddp v5.4s, v2.4s, v3.4s\n"
+    "faddp v0.4s, v4.4s, v5.4s\n"
+    "cbz %[bias], 9f\n"
+    "ld1r {v1.4s}, [%[bias]]\n"
+    "fadd v0.4s, v0.4s, v1.4s\n"
+    "9:\n"
+    "st1 {v0.4s}, [%[output]]\n"
+
+    :
+    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
+    : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
+      "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "memory");
+}
+
+void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
+  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
+  // 9: WriteBack
+  asm volatile(
+    "mov x8, %[input]\n"
+    "mov x9, %[weight]\n"
+    "mov x10, %[deep]\n"
+    "add x5, %[input], %[deep], LSL #2\n"
+    "dup v0.2d, xzr\n"
+    "dup v1.2d, xzr\n"
+    "subs x10, x10, #16\n"
+    "blt 2f\n"
+    "1:\n"  // LoopD16
+    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
+    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
+    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "fmla v1.4s, v16.4s, v28.4s\n"
+    "fmla v0.4s, v5.4s, v29.4s\n"
+ "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.2s, v0.2s, v1.2s\n" + "9:\n" + "st1 {v0.2s}, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep) + : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31", "memory"); +} + +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "dup v0.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], 
[x8], #8\n"  // LoopD2
+    "ld1 {v28.d}[0], [x9], #8\n"
+    "cbz x10, 8f\n"
+    "5:\n"  // LoopD1
+    "ld1 {v4.s}[3], [x8]\n"
+    "ld1 {v28.s}[3], [x9]\n"
+    "8:\n"
+    "fmla v0.4s, v4.4s, v28.4s\n"
+    "6:\n"
+    "faddp v4.4s, v0.4s, v0.4s\n"
+    "faddp v0.4s, v4.4s, v4.4s\n"
+    "cbz %[bias], 9f\n"
+    "ld1 {v1.s}[0], [%[bias]]\n"
+    "fadd s0, s0, s1\n"
+    "9:\n"
+    "st1 {v0.s}[0], [%[output]]\n"
+
+    :
+    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
+    : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31",
+      "memory");
+}
+
+void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
+                        int deep) {
+  const float *input = a + start_row * deep;
+  float *output = c + start_row;
+  const int step = C4NUM * deep;
+  for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
+    MatMul4x1Kernel(input, b, output, bias, deep);
+    input += step;
+    output += C4NUM;
+  }
+  for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
+    MatMul2x1Kernel(input, b, output, bias, deep);
+    input += C2NUM * deep;
+    output += C2NUM;
+  }
+  if (start_row == end_row - 1) {
+    MatMul1x1Kernel(input, b, output, bias, deep);
+  }
+}
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h
index 2d7543a8e6b..e03def587a6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h
@@ -129,6 +129,10 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
 
 void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
 
+#ifdef ENABLE_ARM64
+void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
+                        int deep);
+#endif
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
index b7965f7ea54..321909e0be2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc
@@ -228,7 +228,8 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
 }
 
 void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
-  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr && is_pack_) {
+  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr &&
+      (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1))) {
 #ifdef USING_SERVING
     if (a_is_packed_ == lite::MALLOC) {
 #endif
@@ -293,8 +294,8 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
     bias = bias_ptr_[0];
   }
   for (int index = start_batch; index < end_batch; ++index) {
-    const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
-    const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
+    const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_ * params_->deep_;
+    const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_;
     float *c = output_data_ + index * params_->row_ * params_->col_;
     gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
   }
@@ -340,6 +341,21 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
   return RET_OK;
 }
 
+#ifdef ENABLE_ARM64
+int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
+  int start_row = row_split_points_[task_id];
+  int end_row = row_num_;
+  if (task_id < (static_cast<int>(row_split_points_.size()) - 1)) {
+    end_row = row_split_points_[task_id + 1];
+  }
+  if (start_row == end_row) {
+    return RET_OK;
+  }
+  GemmIsNotPackByRow(a_pack_ptr_, b_pack_ptr_, output_data_, bias_ptr_, start_row, end_row, params_->deep_);
+  return RET_OK;
+}
+#endif
+
 int MatmulFp32BaseCPUKernel::init_global_variable() {
 #ifdef ENABLE_AVX512
   matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
@@ -372,14 +388,6 @@ int MatmulFp32BaseCPUKernel::init_global_variable() {
   matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
   row_tile_ = C12NUM;
   col_tile_ = C8NUM;
-#endif
-  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
-  params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
-#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
-  col_step_ = params_->col_align_;
-#else
-  // need not aligned
-  col_step_ = params_->col_;
 #endif
   MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
   MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
@@ -387,14 +395,21 @@
   MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
   if (params_->col_ == 1 && !params_->a_const_) {
     is_pack_ = false;
-    matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
-    matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
+    row_tile_ = 1;
+    col_tile_ = 1;
     matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
     matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
-  } else {
-    matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
-    matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
   }
+  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
+  params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
+  matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
+  matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
+#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
+  col_step_ = params_->col_align_;
+#else
+  // need not aligned
+  col_step_ = params_->col_;
+#endif
   return RET_OK;
 }
 
@@ -463,8 +478,12 @@ int MatmulFp32BaseCPUKernel::ReSize() {
   if (op_parameter_->is_train_session_) {
     set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<size_t>(sizeof(float)));
   }
-  GetThreadCuttingPolicy();
-  auto ret = InitTmpOutBuffer();
+  auto ret = GetThreadCuttingPolicy();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ThreadCuttingPolicy error!";
+    return ret;
+  }
+  ret = InitTmpOutBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "InitTmpOutBuffer error!";
     return ret;
@@ -503,7 +522,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
   return RET_OK;
 }
 
-void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
+int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
   if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && !params_->a_const_)) {
     thread_count_ = op_parameter_->thread_num_;
    batch_stride_ = UP_DIV(params_->batch, thread_count_);
@@ -526,10 +545,38 @@
       gemmIsNotPackFun = GemmIsNotPack;
     } else {
       gemmIsNotPackFun = GemmIsNotPackOptimize;
+#ifdef ENABLE_ARM64
+      auto ret = GetThreadCuttingPolicyForArm64WhenVb();
+      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "get thread policy for arm64 failed.");
+#endif
     }
   }
+  return RET_OK;
 }
 
+#ifdef ENABLE_ARM64
+int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicyForArm64WhenVb() {
+  if (b_batch_ != 1) {
+    return RET_OK;
+  }
+  parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow;
+  MS_CHECK_FALSE(INT_MUL_OVERFLOW(a_batch_, params_->row_), RET_ERROR);
+  row_num_ = a_batch_ * params_->row_;
+  int row_step = MSMAX(row_num_ / op_parameter_->thread_num_, C64NUM);
+  int row_remaining = MSMAX(row_num_ - row_step * op_parameter_->thread_num_, 0);
+  row_split_points_.resize(op_parameter_->thread_num_);
+  for (size_t i = 0; i < row_split_points_.size(); ++i) {
+    if (i == 0) {
+      row_split_points_[i] = 0;
+      continue;
+    }
+    row_split_points_[i] =
+      MSMIN(row_split_points_[i - 1] + row_step + (static_cast<int>(i) < row_remaining ? 1 : 0), row_num_);
+  }
+  return RET_OK;
+}
+#endif
+
 int MatmulFp32BaseCPUKernel::Run() {
   auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
   CHECK_NULL_RETURN(out_data);
@@ -555,9 +602,7 @@
   if (!params_->a_const_) {
     auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
     CHECK_NULL_RETURN(a_ptr);
-    if (!is_pack_) {
-      a_pack_ptr_ = a_ptr;
-    } else {
+    if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
      if (InitBufferA() != RET_OK) {
        return RET_ERROR;
      }
@@ -566,6 +611,8 @@
       MS_LOG(ERROR) << "InitMatrixA failed!";
       return ret;
     }
+    } else {
+      a_pack_ptr_ = a_ptr;
+    }
   }
 
@@ -592,7 +639,7 @@
     PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
   }
   if (!params_->a_const_) {
-    if (is_pack_) {
+    if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
       FreeResizeBufA();
     } else {
       a_pack_ptr_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
index 93cf32829b3..7e15371c699 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h
@@ -52,6 +52,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   int Run() override;
 
  public:
+#ifdef ENABLE_ARM64
+  int ParallelRunByRow(int task_id) const;
+#endif
   int ParallelRunByOC(int task_id) const;
   int ParallelRunByBatch(int task_id) const;
   int ParallelRunIsNotPackByBatch(int task_id) const;
@@ -75,7 +78,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   void FreeResizeBufB();
   int CalBroadCastBiasDataElements();
   int InitTmpOutBuffer();
-  void GetThreadCuttingPolicy();
+  int GetThreadCuttingPolicy();
+#ifdef ENABLE_ARM64
+  int GetThreadCuttingPolicyForArm64WhenVb();
+#endif
 
  protected:
   MatMulParameter *params_ = nullptr;
@@ -115,6 +121,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   GemvFun gemvCalFun = nullptr;
 #endif
   GemmIsNotPackFun gemmIsNotPackFun = nullptr;
+#ifdef ENABLE_ARM64
+  int row_num_ = 0;
+  std::vector<int> row_split_points_;
+#endif
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
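
Reviewer note, not part of the patch: a minimal scalar sketch of what the new ARM64 path computes, for reference. GemmIsNotPackByRow walks rows in blocks of 4/2/1, and each MatMulNx1Kernel reduces N rows of `input` against the single weight column (col_ == 1), adding bias[0] when a bias is present. The function name below is illustrative only.

/* Scalar reference for MatMul4x1Kernel: output[r] is the dot product of row r
 * of `input` (length `deep`) with the single weight column, plus bias[0]. */
static void MatMul4x1Ref(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  for (size_t r = 0; r < 4; ++r) {
    float acc = 0.0f;
    const float *row = input + r * deep;
    for (size_t d = 0; d < deep; ++d) {
      acc += row[d] * weight[d];  /* the fmla loops vectorize this reduction 16/12/8/4 elements at a time */
    }
    output[r] = (bias != NULL) ? acc + bias[0] : acc;
  }
}

The asm keeps four independent accumulators (v0..v3) so the fused multiply-adds for the four rows do not serialize on one register, and the final faddp pairwise adds collapse them into a single 4-lane store.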
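A second sketch, also illustrative only, of the row partitioning done by GetThreadCuttingPolicyForArm64WhenVb (assuming nnacl's C64NUM == 64): each split point advances by at least 64 rows and is clamped to the total row count, so trailing threads may receive empty ranges, which ParallelRunByRow skips via the start_row == end_row check.

/* Hypothetical values; mirrors the split-point computation above. */
#include <stdio.h>

int main(void) {
  const int row_num = 100, thread_num = 4;  /* hypothetical shape */
  int row_step = row_num / thread_num < 64 ? 64 : row_num / thread_num;
  int row_remaining = row_num - row_step * thread_num;
  if (row_remaining < 0) row_remaining = 0;
  int points[4] = {0};
  for (int i = 1; i < thread_num; ++i) {
    int next = points[i - 1] + row_step + (i < row_remaining ? 1 : 0);
    points[i] = next < row_num ? next : row_num;  /* clamp: later threads may get no rows */
  }
  /* row_num=100 -> points = {0, 64, 100, 100}: threads 2 and 3 are idle. */
  for (int i = 0; i < thread_num; ++i) printf("thread %d: start %d\n", i, points[i]);
  return 0;
}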