forked from mindspore-Ecosystem/mindspore
matmul opt
This commit is contained in:
parent
4cc0cb5b2b
commit
27ddd48788
|
@ -70,6 +70,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Con
|
|||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
|
||||
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel
|
||||
|
|
|
@ -2268,3 +2268,331 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
// Computes a 4x1 GEMV tile: output[r] = dot(input + r * deep, weight) for r = 0..3,
// plus the broadcast scalar bias[0] when bias != NULL.  The four row pointers are
// x8/x5/x6/x7; the shared weight column is streamed through x9.
// NOTE(review): fix vs. original — the asm writes through %[output] (st1 at label 9)
// and reads caller memory, so the clobber list must include "memory"; without it the
// compiler may cache/reorder accesses around the asm.  MatMul2x1Kernel already
// declares it; this makes the three kernels consistent.
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"  // row 1: input + deep floats
    "add x6, %[input], %[deep], LSL #3\n"  // row 2: input + 2 * deep floats
    "add x7, x5, %[deep], LSL #3\n"        // row 3: input + 3 * deep floats
    "dup v0.2d, xzr\n"                     // v0..v3: per-row partial sums
    "dup v1.2d, xzr\n"
    "dup v2.2d, xzr\n"
    "dup v3.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16: 16 depth elements per iteration
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
    "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v2.4s, v22.4s, v30.4s\n"
    "fmla v3.4s, v26.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "fmla v2.4s, v23.4s, v31.4s\n"
    "fmla v3.4s, v27.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12: restore count, then handle a 12-element chunk if present
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
    "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v2.4s, v22.4s, v30.4s\n"
    "fmla v3.4s, v26.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v20.4s, v21.4s}, [x6], #32\n"
    "ld1 {v24.4s, v25.4s}, [x7], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v2.4s, v21.4s, v29.4s\n"
    "fmla v3.4s, v25.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v20.4s}, [x6], #16\n"
    "ld1 {v24.4s}, [x7], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"  // LoopDTail: 1..3 elements left; build zero-padded lanes and do one fmla
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "dup v20.2d, xzr\n"
    "dup v24.2d, xzr\n"
    "dup v28.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2: load two floats into lanes 0..1
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v20.d}[0], [x6], #8\n"
    "ld1 {v24.d}[0], [x7], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1: last odd element goes to lane 2 of both operands
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v20.s}[2], [x6]\n"
    "ld1 {v24.s}[2], [x7]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"  // LoopDTailCompute
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v2.4s, v20.4s, v28.4s\n"
    "fmla v3.4s, v24.4s, v28.4s\n"
    "6:\n"  // LoopDEnd: horizontal reduce the four accumulators into v0 lanes 0..3
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v5.4s, v2.4s, v3.4s\n"
    "faddp v0.4s, v4.4s, v5.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"  // broadcast bias[0] to all four rows
    "fadd v0.4s, v0.4s, v1.4s\n"
    "9:\n"  // WriteBack
    "st1 {v0.4s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "memory", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
      "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
|
||||
|
||||
// Computes a 2x1 GEMV tile: output[r] = dot(input + r * deep, weight) for r = 0..1,
// plus the broadcast scalar bias[0] when bias != NULL.  Row 0 streams through x8,
// row 1 through x5, the shared weight column through x9.  Depth is consumed in
// chunks of 16/12/8/4, then a 2+1 zero-padded tail.
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"  // row 1: input + deep floats
    "dup v0.2d, xzr\n"                     // v0/v1: per-row partial sums
    "dup v1.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12: restore count, then handle a 12-element chunk if present
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"  // LoopDTail: 1..3 elements left; zero-pad unused lanes so one fmla finishes
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2: two floats into lanes 0..1
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1: last odd element into lane 2 of both operands
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"  // LoopDTailCompute
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "6:\n"  // LoopDEnd: pairwise reduce so v0.s[0]/v0.s[1] hold the two row sums
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"  // broadcast bias[0] to both rows
    "fadd v0.2s, v0.2s, v1.2s\n"
    "9:\n"  // WriteBack: two floats
    "st1 {v0.2s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
      "v30", "v31", "memory");
}
|
||||
|
||||
// Computes a 1x1 GEMV tile: output[0] = dot(input, weight) (+ bias[0] when bias != NULL).
// NOTE(review): fix vs. original — the asm stores through %[output] (st1 at label 9),
// so the clobber list must include "memory"; without it the compiler may cache or
// reorder memory accesses around the asm.  MatMul2x1Kernel already declares it.
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "dup v0.2d, xzr\n"  // v0: partial-sum accumulator
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12: restore count, then handle a 12-element chunk if present
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"  // LoopDTail: 1..3 elements left; zero-pad unused lanes
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2: two floats into lanes 0..1
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1: last odd element into lane 3 of both operands (any spare lane works)
    "ld1 {v4.s}[3], [x8]\n"
    "ld1 {v28.s}[3], [x9]\n"
    "8:\n"  // LoopDTailCompute
    "fmla v0.4s, v4.4s, v28.4s\n"
    "6:\n"  // LoopDEnd: pairwise reduce four lanes into v0.s[0]
    "faddp v4.4s, v0.4s, v0.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1 {v1.s}[0], [%[bias]]\n"
    "fadd s0, s0, s1\n"
    "9:\n"  // WriteBack: one float
    "st1 {v0.s}[0], [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
    : "cc", "memory", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28",
      "v29", "v30", "v31");
}
|
||||
|
||||
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
|
||||
int deep) {
|
||||
const float *input = a + start_row * deep;
|
||||
float *output = c + start_row;
|
||||
const int step = C4NUM * deep;
|
||||
for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
|
||||
MatMul4x1Kernel(input, b, output, bias, deep);
|
||||
input += step;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
|
||||
MatMul2x1Kernel(input, b, output, bias, deep);
|
||||
input += C2NUM * deep;
|
||||
output += C2NUM;
|
||||
}
|
||||
if (start_row == end_row - 1) {
|
||||
MatMul1x1Kernel(input, b, output, bias, deep);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -129,6 +129,10 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
|
|||
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
|
||||
int deep);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -228,7 +228,8 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
|
|||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
|
||||
if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr && is_pack_) {
|
||||
if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr &&
|
||||
(is_pack_ || (params_->a_transpose_ && params_->deep_ != 1))) {
|
||||
#ifdef USING_SERVING
|
||||
if (a_is_packed_ == lite::MALLOC) {
|
||||
#endif
|
||||
|
@ -293,8 +294,8 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
|
|||
bias = bias_ptr_[0];
|
||||
}
|
||||
for (int index = start_batch; index < end_batch; ++index) {
|
||||
const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
|
||||
const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
|
||||
const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_ * params_->deep_;
|
||||
const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_;
|
||||
float *c = output_data_ + index * params_->row_ * params_->col_;
|
||||
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
|
||||
}
|
||||
|
@ -340,6 +341,21 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
|
||||
int start_row = row_split_points_[task_id];
|
||||
int end_row = row_num_;
|
||||
if (task_id < (static_cast<int>(row_split_points_.size()) - 1)) {
|
||||
end_row = row_split_points_[task_id + 1];
|
||||
}
|
||||
if (start_row == end_row) {
|
||||
return RET_OK;
|
||||
}
|
||||
GemmIsNotPackByRow(a_pack_ptr_, b_pack_ptr_, output_data_, bias_ptr_, start_row, end_row, params_->deep_);
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
|
||||
int MatmulFp32BaseCPUKernel::init_global_variable() {
|
||||
#ifdef ENABLE_AVX512
|
||||
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
|
@ -372,14 +388,6 @@ int MatmulFp32BaseCPUKernel::init_global_variable() {
|
|||
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
|
||||
row_tile_ = C12NUM;
|
||||
col_tile_ = C8NUM;
|
||||
#endif
|
||||
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
|
||||
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
col_step_ = params_->col_align_;
|
||||
#else
|
||||
// need not aligned
|
||||
col_step_ = params_->col_;
|
||||
#endif
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
|
||||
|
@ -387,14 +395,21 @@ int MatmulFp32BaseCPUKernel::init_global_variable() {
|
|||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
|
||||
if (params_->col_ == 1 && !params_->a_const_) {
|
||||
is_pack_ = false;
|
||||
matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
|
||||
matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
|
||||
row_tile_ = 1;
|
||||
col_tile_ = 1;
|
||||
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
} else {
|
||||
}
|
||||
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
|
||||
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
|
||||
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
|
||||
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
|
||||
}
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
|
||||
col_step_ = params_->col_align_;
|
||||
#else
|
||||
// need not aligned
|
||||
col_step_ = params_->col_;
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -463,8 +478,12 @@ int MatmulFp32BaseCPUKernel::ReSize() {
|
|||
if (op_parameter_->is_train_session_) {
|
||||
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
|
||||
}
|
||||
GetThreadCuttingPolicy();
|
||||
auto ret = InitTmpOutBuffer();
|
||||
auto ret = GetThreadCuttingPolicy();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ThreadCuttingPolicy error!";
|
||||
return ret;
|
||||
}
|
||||
ret = InitTmpOutBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
|
||||
return ret;
|
||||
|
@ -503,7 +522,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
||||
int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
||||
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && !params_->a_const_)) {
|
||||
thread_count_ = op_parameter_->thread_num_;
|
||||
batch_stride_ = UP_DIV(params_->batch, thread_count_);
|
||||
|
@ -526,10 +545,38 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
|||
gemmIsNotPackFun = GemmIsNotPack;
|
||||
} else {
|
||||
gemmIsNotPackFun = GemmIsNotPackOptimize;
|
||||
#ifdef ENABLE_ARM64
|
||||
auto ret = GetThreadCuttingPolicyForArm64WhenVb();
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "get thread policy for arm64 failed.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
// Chooses the row-split thread policy for the ARM64 vector-times-matrix (col == 1)
// path.  Only applies when matrix B has a single batch; otherwise the previously
// selected policy is kept and RET_OK is returned unchanged.
int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicyForArm64WhenVb() {
  if (b_batch_ != 1) {
    return RET_OK;
  }
  parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow;
  MS_CHECK_FALSE(INT_MUL_OVERFLOW(a_batch_, params_->row_), RET_ERROR);
  row_num_ = a_batch_ * params_->row_;
  // Base rows per thread; clamped below by C64NUM so tiny workloads are not
  // split into slices too small to amortize the kernel overhead.
  int row_step = MSMAX(row_num_ / op_parameter_->thread_num_, C64NUM);
  // Rows left over after the even split; distributed one extra row per early
  // thread.  Zero whenever row_step was clamped up to C64NUM.
  int row_remaining = MSMAX(row_num_ - row_step * op_parameter_->thread_num_, 0);
  row_split_points_.resize(op_parameter_->thread_num_);
  for (size_t i = 0; i < row_split_points_.size(); ++i) {
    if (i == 0) {
      row_split_points_[i] = 0;
      continue;
    }
    // Each split point is clamped to row_num_, so trailing threads may receive
    // empty slices; ParallelRunByRow's last task always runs through row_num_,
    // which keeps total coverage exact.
    row_split_points_[i] =
      MSMIN(row_split_points_[i - 1] + row_step + (static_cast<int>(i) < row_remaining ? 1 : 0), row_num_);
  }
  return RET_OK;
}
||||
#endif
|
||||
|
||||
int MatmulFp32BaseCPUKernel::Run() {
|
||||
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
|
||||
CHECK_NULL_RETURN(out_data);
|
||||
|
@ -555,9 +602,7 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
if (!params_->a_const_) {
|
||||
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
|
||||
CHECK_NULL_RETURN(a_ptr);
|
||||
if (!is_pack_) {
|
||||
a_pack_ptr_ = a_ptr;
|
||||
} else {
|
||||
if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
|
||||
if (InitBufferA() != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -566,6 +611,8 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
MS_LOG(ERROR) << "InitMatrixA failed!";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
a_pack_ptr_ = a_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -592,7 +639,7 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
|
||||
}
|
||||
if (!params_->a_const_) {
|
||||
if (is_pack_) {
|
||||
if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
|
||||
FreeResizeBufA();
|
||||
} else {
|
||||
a_pack_ptr_ = nullptr;
|
||||
|
|
|
@ -52,6 +52,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
int Run() override;
|
||||
|
||||
public:
|
||||
#ifdef ENABLE_ARM64
|
||||
int ParallelRunByRow(int task_id) const;
|
||||
#endif
|
||||
int ParallelRunByOC(int task_id) const;
|
||||
int ParallelRunByBatch(int task_id) const;
|
||||
int ParallelRunIsNotPackByBatch(int task_id) const;
|
||||
|
@ -75,7 +78,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
void FreeResizeBufB();
|
||||
int CalBroadCastBiasDataElements();
|
||||
int InitTmpOutBuffer();
|
||||
void GetThreadCuttingPolicy();
|
||||
int GetThreadCuttingPolicy();
|
||||
#ifdef ENABLE_ARM64
|
||||
int GetThreadCuttingPolicyForArm64WhenVb();
|
||||
#endif
|
||||
|
||||
protected:
|
||||
MatMulParameter *params_ = nullptr;
|
||||
|
@ -115,6 +121,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
GemvFun gemvCalFun = nullptr;
|
||||
#endif
|
||||
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
|
||||
#ifdef ENABLE_ARM64
|
||||
int row_num_;
|
||||
std::vector<int> row_split_points_;
|
||||
#endif
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
|
||||
|
|
Loading…
Reference in New Issue