diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulFp32.S
similarity index 97%
rename from mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S
rename to mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulFp32.S
index 3c5433f62a7..557a081b167 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulFp32.S
@@ -582,22 +582,10 @@ Write7:
   st1 {v31.s}[2], [x16], x17
   b WriteEnd
 WriteC8:
-  st1 {v16.4s}, [x2], #16
-  st1 {v17.4s}, [x2], #16
-  st1 {v18.4s}, [x2], #16
-  st1 {v19.4s}, [x2], #16
-  st1 {v20.4s}, [x2], #16
-  st1 {v21.4s}, [x2], #16
-  st1 {v22.4s}, [x2], #16
-  st1 {v23.4s}, [x2], #16
-  st1 {v24.4s}, [x2], #16
-  st1 {v25.4s}, [x2], #16
-  st1 {v26.4s}, [x2], #16
-  st1 {v27.4s}, [x2], #16
-  st1 {v28.4s}, [x2], #16
-  st1 {v29.4s}, [x2], #16
-  st1 {v30.4s}, [x2], #16
-  st1 {v31.4s}, [x2], #16
+  st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
+  st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
+  st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
+  st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
   b WriteEnd
 Write8:
   st1 {v16.4s, v17.4s}, [x18], x17
@@ -634,7 +622,7 @@ End2:
   ldrb w13, [sp, #8]
   cbz w13, NoDstStep
   add x2, x2, #32 // dst ptr + stride
-  NoDstStep:
+NoDstStep:
   bgt L1
 
 End1:
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S
new file mode 100644
index 00000000000..b667bec931c
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulFp16.S
@@ -0,0 +1,877 @@
+#ifdef __aarch64__
+    .text
+    .align 5
+    .global MatmulFp16Neon64
+#ifndef __APPLE__
+    .type MatmulFp16Neon64, %function
+#endif
+
+// void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
+//                       int depth, int row, int col, int stride, bool write_nhwc)
+// x0: a (lhs, consumed 16 rows per depth step)
+// x1: b (rhs, consumed 8 columns per depth step)
+// x2: c
+// x3: bias (may be null: no bias is added)
+// w4: act_type (1: relu, 2: relu6, anything else: none)
+// w5: depth
+// w6: row
+// w7: col
+// [sp]: stride, in fp16 elements; row pitch of c on the strided (write_nhwc) path
+// [sp, #8]: write_nhwc; 0 selects the packed WriteC8 store
+
+MatmulFp16Neon64:
+  sub sp, sp, #128
+  st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
+  st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 // sp is back at the stack arguments after the two post-increments
+
+  mov w18, #16 // sizeof(float16_t) * 8
+  mul w15, w5, w18 // rhs block stride in bytes: sizeof(float16_t) * 8 * depth
+  mov x11, x3 // bias flag
+  mov x18, #2 // sizeof(float16_t)
+  ldr w17, [sp] // stride is a 32-bit int; the w-form load zero-extends into x17
+  mul x17, x17, x18 // dst row pitch in bytes
+
+L1:
+  mov w10, w6 // reload lhs row
+  mov x12, x0 // reload lhs ptr
+  mov x18, x2 // reload dst ptr
+
+L2:
+  mov x16, x1 // reload rhs ptr
+  mov w13, w5 // reload depth
+  mov x14, x3 // reload bias ptr
+  dup v16.4s, wzr
+  dup v17.4s, wzr
+  dup v18.4s, wzr
+  dup v19.4s, wzr
+  dup v20.4s, wzr
+  dup v21.4s, wzr
+  dup v22.4s, wzr
+  dup v23.4s, wzr
+  dup v24.4s, wzr
+  dup v25.4s, wzr
+  dup v26.4s, wzr
+  dup v27.4s, wzr
+  dup v28.4s, wzr
+  dup v29.4s, wzr
+  dup v30.4s, wzr
+  dup v31.4s, wzr
+
+  cmp w13, #8 // the unrolled loop consumes 8 depth steps per pass
+  blt CommLoopMul
+
+OptLoopMul8:
+  ld1 {v0.8h, v1.8h}, [x12], #32
+  ld1 {v8.8h, v9.8h}, [x16], #32
+  fmla v16.8h, v8.8h, v0.h[0]
+  fmla v17.8h, v8.8h, v0.h[1]
+  fmla v18.8h, v8.8h, v0.h[2]
+  fmla v19.8h, v8.8h, v0.h[3]
+  fmla v20.8h, v8.8h, v0.h[4]
+  fmla v21.8h, v8.8h, v0.h[5]
+  fmla v22.8h, v8.8h, v0.h[6]
+  fmla v23.8h, v8.8h, v0.h[7]
+  ld1 {v2.8h, v3.8h}, [x12], #32
+  fmla v24.8h, v8.8h, v1.h[0]
+  fmla v25.8h, v8.8h, v1.h[1]
+  fmla v26.8h, v8.8h, v1.h[2]
+  fmla v27.8h, v8.8h, v1.h[3]
+  fmla v28.8h, v8.8h, v1.h[4]
+  fmla v29.8h, v8.8h, v1.h[5]
+  fmla v30.8h, v8.8h, v1.h[6]
+  fmla v31.8h, v8.8h, v1.h[7]
+  ld1 {v10.8h, v11.8h}, [x16], #32
+  fmla v16.8h, v9.8h, v2.h[0]
+  fmla v17.8h, v9.8h, v2.h[1]
+  fmla v18.8h, v9.8h, v2.h[2]
+  fmla v19.8h, v9.8h, v2.h[3]
+  fmla v20.8h, v9.8h, v2.h[4]
+  fmla v21.8h, v9.8h, v2.h[5]
+  fmla v22.8h, v9.8h, v2.h[6]
+  fmla v23.8h, v9.8h, v2.h[7]
+  ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
+  fmla v24.8h, v9.8h, v3.h[0]
+  fmla v25.8h, v9.8h, v3.h[1]
+  fmla v26.8h, v9.8h, v3.h[2]
+  fmla v27.8h, v9.8h, v3.h[3]
+  fmla v28.8h, v9.8h, v3.h[4]
+  fmla v29.8h, v9.8h, v3.h[5]
+  fmla v30.8h, v9.8h, v3.h[6]
+  fmla v31.8h, v9.8h, v3.h[7]
+  ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x16], #64
+  fmla v16.8h, v10.8h, v4.h[0]
+  fmla v17.8h, v10.8h, v4.h[1]
+  fmla v18.8h, v10.8h, v4.h[2]
+  fmla v19.8h, v10.8h, v4.h[3]
+  fmla v20.8h, v10.8h, v4.h[4]
+  fmla v21.8h, v10.8h, v4.h[5]
+  fmla v22.8h, v10.8h, v4.h[6]
+  fmla v23.8h, v10.8h, v4.h[7]
+  ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
+  fmla v24.8h, v10.8h, v5.h[0]
+  fmla v25.8h, v10.8h, v5.h[1]
+  fmla v26.8h, v10.8h, v5.h[2]
+  fmla v27.8h, v10.8h, v5.h[3]
+  fmla v28.8h, v10.8h, v5.h[4]
+  fmla v29.8h, v10.8h, v5.h[5]
+  fmla v30.8h, v10.8h, v5.h[6]
+  fmla v31.8h, v10.8h, v5.h[7]
+  ld1 {v4.8h, v5.8h}, [x12], #32
+  fmla v16.8h, v11.8h, v6.h[0]
+  fmla v17.8h, v11.8h, v6.h[1]
+  fmla v18.8h, v11.8h, v6.h[2]
+  fmla v19.8h, v11.8h, v6.h[3]
+  fmla v20.8h, v11.8h, v6.h[4]
+  fmla v21.8h, v11.8h, v6.h[5]
+  fmla v22.8h, v11.8h, v6.h[6]
+  fmla v23.8h, v11.8h, v6.h[7]
+  fmla v24.8h, v11.8h, v7.h[0]
+  fmla v25.8h, v11.8h, v7.h[1]
+  fmla v26.8h, v11.8h, v7.h[2]
+  fmla v27.8h, v11.8h, v7.h[3]
+  fmla v28.8h, v11.8h, v7.h[4]
+  fmla v29.8h, v11.8h, v7.h[5]
+  fmla v30.8h, v11.8h, v7.h[6]
+  fmla v31.8h, v11.8h, v7.h[7]
+  ld1 {v6.8h, v7.8h}, [x12], #32
+  fmla v16.8h, v12.8h, v0.h[0]
+  fmla v17.8h, v12.8h, v0.h[1]
+  fmla v18.8h, v12.8h, v0.h[2]
+  fmla v19.8h, v12.8h, v0.h[3]
+  fmla v20.8h, v12.8h, v0.h[4]
+  fmla v21.8h, v12.8h, v0.h[5]
+  fmla v22.8h, v12.8h, v0.h[6]
+  fmla v23.8h, v12.8h, v0.h[7]
+  fmla v24.8h, v12.8h, v1.h[0]
+  fmla v25.8h, v12.8h, v1.h[1]
+  fmla v26.8h, v12.8h, v1.h[2]
+  fmla v27.8h, v12.8h, v1.h[3]
+  fmla v28.8h, v12.8h, v1.h[4]
+  fmla v29.8h, v12.8h, v1.h[5]
+  fmla v30.8h, v12.8h, v1.h[6]
+  fmla v31.8h, v12.8h, v1.h[7]
+  fmla v16.8h, v13.8h, v2.h[0]
+  fmla v17.8h, v13.8h, v2.h[1]
+  fmla v18.8h, v13.8h, v2.h[2]
+  fmla v19.8h, v13.8h, v2.h[3]
+  fmla v20.8h, v13.8h, v2.h[4]
+  fmla v21.8h, v13.8h, v2.h[5]
+  fmla v22.8h, v13.8h, v2.h[6]
+  fmla v23.8h, v13.8h, v2.h[7]
+  fmla v24.8h, v13.8h, v3.h[0]
+  fmla v25.8h, v13.8h, v3.h[1]
+  fmla v26.8h, v13.8h, v3.h[2]
+  fmla v27.8h, v13.8h, v3.h[3]
+  fmla v28.8h, v13.8h, v3.h[4]
+  fmla v29.8h, v13.8h, v3.h[5]
+  fmla v30.8h, v13.8h, v3.h[6]
+  fmla v31.8h, v13.8h, v3.h[7]
+  fmla v16.8h, v14.8h, v4.h[0]
+  fmla v17.8h, v14.8h, v4.h[1]
+  fmla v18.8h, v14.8h, v4.h[2]
+  fmla v19.8h, v14.8h, v4.h[3]
+  fmla v20.8h, v14.8h, v4.h[4]
+  fmla v21.8h, v14.8h, v4.h[5]
+  fmla v22.8h, v14.8h, v4.h[6]
+  fmla v23.8h, v14.8h, v4.h[7]
+  fmla v24.8h, v14.8h, v5.h[0]
+  fmla v25.8h, v14.8h, v5.h[1]
+  fmla v26.8h, v14.8h, v5.h[2]
+  fmla v27.8h, v14.8h, v5.h[3]
+  fmla v28.8h, v14.8h, v5.h[4]
+  fmla v29.8h, v14.8h, v5.h[5]
+  fmla v30.8h, v14.8h, v5.h[6]
+  fmla v31.8h, v14.8h, v5.h[7]
+  fmla v16.8h, v15.8h, v6.h[0]
+  fmla v17.8h, v15.8h, v6.h[1]
+  fmla v18.8h, v15.8h, v6.h[2]
+  fmla v19.8h, v15.8h, v6.h[3]
+  fmla v20.8h, v15.8h, v6.h[4]
+  fmla v21.8h, v15.8h, v6.h[5]
+  fmla v22.8h, v15.8h, v6.h[6]
+  fmla v23.8h, v15.8h, v6.h[7]
+  fmla v24.8h, v15.8h, v7.h[0]
+  fmla v25.8h, v15.8h, v7.h[1]
+  fmla v26.8h, v15.8h, v7.h[2]
+  fmla v27.8h, v15.8h, v7.h[3]
+  fmla v28.8h, v15.8h, v7.h[4]
+  fmla v29.8h, v15.8h, v7.h[5]
+  fmla v30.8h, v15.8h, v7.h[6]
+  fmla v31.8h, v15.8h, v7.h[7]
+
+  sub w13, w13, #8
+  cmp w13, #0
+  ble Bias
+  cmp w13, #8
+  bge OptLoopMul8
+
+CommLoopMul:
+  ld1 {v0.8h, v1.8h}, [x12], #32
+  ld1 {v8.8h}, [x16], #16
+  fmla v16.8h, v8.8h, v0.h[0]
+  fmla v17.8h, v8.8h, v0.h[1]
+  fmla v18.8h, v8.8h, v0.h[2]
+  fmla v19.8h, v8.8h, v0.h[3]
+  fmla v20.8h, v8.8h, v0.h[4]
+  fmla v21.8h, v8.8h, v0.h[5]
+  fmla v22.8h, v8.8h, v0.h[6]
+  fmla v23.8h, v8.8h, v0.h[7]
+  fmla v24.8h, v8.8h, v1.h[0]
+  fmla v25.8h, v8.8h, v1.h[1]
+  fmla v26.8h, v8.8h, v1.h[2]
+  fmla v27.8h, v8.8h, v1.h[3]
+  fmla v28.8h, v8.8h, v1.h[4]
+  fmla v29.8h, v8.8h, v1.h[5]
+  fmla v30.8h, v8.8h, v1.h[6]
+  fmla v31.8h, v8.8h, v1.h[7]
+
+  subs w13, w13, #1
+  bgt CommLoopMul
+
+Bias:
+  cbz x11, Activation
+  ld1 {v0.8h}, [x14], #16
+  fadd v16.8h, v16.8h, v0.8h
+  fadd v17.8h, v17.8h, v0.8h
+  fadd v18.8h, v18.8h, v0.8h
+  fadd v19.8h, v19.8h, v0.8h
+  fadd v20.8h, v20.8h, v0.8h
+  fadd v21.8h, v21.8h, v0.8h
+  fadd v22.8h, v22.8h, v0.8h
+  fadd v23.8h, v23.8h, v0.8h
+  fadd v24.8h, v24.8h, v0.8h
+  fadd v25.8h, v25.8h, v0.8h
+  fadd v26.8h, v26.8h, v0.8h
+  fadd v27.8h, v27.8h, v0.8h
+  fadd v28.8h, v28.8h, v0.8h
+  fadd v29.8h, v29.8h, v0.8h
+  fadd v30.8h, v30.8h, v0.8h
+  fadd v31.8h, v31.8h, v0.8h
+
+Activation:
+  cmp w4, #2
+  beq Relu6
+  cmp w4, #1
+  beq Relu
+  b Write
+
+Relu6:
+  movi v15.8h, #0x46, lsl #8 // 0x4600 is 6.0 in fp16
+  fmin v16.8h, v16.8h, v15.8h
+  fmin v17.8h, v17.8h, v15.8h
+  fmin v18.8h, v18.8h, v15.8h
+  fmin v19.8h, v19.8h, v15.8h
+  fmin v20.8h, v20.8h, v15.8h
+  fmin v21.8h, v21.8h, v15.8h
+  fmin v22.8h, v22.8h, v15.8h
+  fmin v23.8h, v23.8h, v15.8h
+  fmin v24.8h, v24.8h, v15.8h
+  fmin v25.8h, v25.8h, v15.8h
+  fmin v26.8h, v26.8h, v15.8h
+  fmin v27.8h, v27.8h, v15.8h
+  fmin v28.8h, v28.8h, v15.8h
+  fmin v29.8h, v29.8h, v15.8h
+  fmin v30.8h, v30.8h, v15.8h
+  fmin v31.8h, v31.8h, v15.8h
+
+Relu:
+  dup v14.4s, wzr
+  fmax v16.8h, v16.8h, v14.8h
+  fmax v17.8h, v17.8h, v14.8h
+  fmax v18.8h, v18.8h, v14.8h
+  fmax v19.8h, v19.8h, v14.8h
+  fmax v20.8h, v20.8h, v14.8h
+  fmax v21.8h, v21.8h, v14.8h
+  fmax v22.8h, v22.8h, v14.8h
+  fmax v23.8h, v23.8h, v14.8h
+  fmax v24.8h, v24.8h, v14.8h
+  fmax v25.8h, v25.8h, v14.8h
+  fmax v26.8h, v26.8h, v14.8h
+  fmax v27.8h, v27.8h, v14.8h
+  fmax v28.8h, v28.8h, v14.8h
+  fmax v29.8h, v29.8h, v14.8h
+  fmax v30.8h, v30.8h, v14.8h
+  fmax v31.8h, v31.8h, v14.8h
+
+Write:
+  ldrb w13, [sp, #8]
+  cbz w13, WriteC8
+  cmp w7, #1
+  beq Write1
+  cmp w7, #2
+  beq Write2
+  cmp w7, #3
+  beq Write3
+  cmp w7, #4
+  beq Write4
+  cmp w7, #5
+  beq Write5
+  cmp w7, #6
+  beq Write6
+  cmp w7, #7
+  beq Write7
+  b Write8
+
+Write1:
+  st1 {v16.h}[0], [x18], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.h}[0], [x18], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.h}[0], [x18], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.h}[0], [x18], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.h}[0], [x18], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.h}[0], [x18], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.h}[0], [x18], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.h}[0], [x18], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.h}[0], [x18], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.h}[0], [x18], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.h}[0], [x18], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.h}[0], [x18], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.h}[0], [x18], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.h}[0], [x18], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.h}[0], [x18], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.h}[0], [x18], x17
+  b WriteEnd
+Write2:
+  add x13, x18, #2
+  st1 {v16.h}[0], [x18], x17
+  st1 {v16.h}[1], [x13], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.h}[0], [x18], x17
+  st1 {v17.h}[1], [x13], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.h}[0], [x18], x17
+  st1 {v18.h}[1], [x13], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.h}[0], [x18], x17
+  st1 {v19.h}[1], [x13], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.h}[0], [x18], x17
+  st1 {v20.h}[1], [x13], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.h}[0], [x18], x17
+  st1 {v21.h}[1], [x13], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.h}[0], [x18], x17
+  st1 {v22.h}[1], [x13], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.h}[0], [x18], x17
+  st1 {v23.h}[1], [x13], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.h}[0], [x18], x17
+  st1 {v24.h}[1], [x13], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.h}[0], [x18], x17
+  st1 {v25.h}[1], [x13], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.h}[0], [x18], x17
+  st1 {v26.h}[1], [x13], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.h}[0], [x18], x17
+  st1 {v27.h}[1], [x13], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.h}[0], [x18], x17
+  st1 {v28.h}[1], [x13], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.h}[0], [x18], x17
+  st1 {v29.h}[1], [x13], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.h}[0], [x18], x17
+  st1 {v30.h}[1], [x13], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.h}[0], [x18], x17
+  st1 {v31.h}[1], [x13], x17
+  b WriteEnd
+Write3:
+  add x13, x18, #2
+  add x14, x18, #4
+  st1 {v16.h}[0], [x18], x17
+  st1 {v16.h}[1], [x13], x17
+  st1 {v16.h}[2], [x14], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.h}[0], [x18], x17
+  st1 {v17.h}[1], [x13], x17
+  st1 {v17.h}[2], [x14], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.h}[0], [x18], x17
+  st1 {v18.h}[1], [x13], x17
+  st1 {v18.h}[2], [x14], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.h}[0], [x18], x17
+  st1 {v19.h}[1], [x13], x17
+  st1 {v19.h}[2], [x14], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.h}[0], [x18], x17
+  st1 {v20.h}[1], [x13], x17
+  st1 {v20.h}[2], [x14], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.h}[0], [x18], x17
+  st1 {v21.h}[1], [x13], x17
+  st1 {v21.h}[2], [x14], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.h}[0], [x18], x17
+  st1 {v22.h}[1], [x13], x17
+  st1 {v22.h}[2], [x14], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.h}[0], [x18], x17
+  st1 {v23.h}[1], [x13], x17
+  st1 {v23.h}[2], [x14], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.h}[0], [x18], x17
+  st1 {v24.h}[1], [x13], x17
+  st1 {v24.h}[2], [x14], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.h}[0], [x18], x17
+  st1 {v25.h}[1], [x13], x17
+  st1 {v25.h}[2], [x14], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.h}[0], [x18], x17
+  st1 {v26.h}[1], [x13], x17
+  st1 {v26.h}[2], [x14], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.h}[0], [x18], x17
+  st1 {v27.h}[1], [x13], x17
+  st1 {v27.h}[2], [x14], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.h}[0], [x18], x17
+  st1 {v28.h}[1], [x13], x17
+  st1 {v28.h}[2], [x14], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.h}[0], [x18], x17
+  st1 {v29.h}[1], [x13], x17
+  st1 {v29.h}[2], [x14], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.h}[0], [x18], x17
+  st1 {v30.h}[1], [x13], x17
+  st1 {v30.h}[2], [x14], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.h}[0], [x18], x17
+  st1 {v31.h}[1], [x13], x17
+  st1 {v31.h}[2], [x14], x17
+  b WriteEnd
+Write4:
+  st1 {v16.4h}, [x18], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.4h}, [x18], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.4h}, [x18], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.4h}, [x18], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.4h}, [x18], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.4h}, [x18], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.4h}, [x18], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.4h}, [x18], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.4h}, [x18], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.4h}, [x18], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.4h}, [x18], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.4h}, [x18], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.4h}, [x18], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.4h}, [x18], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.4h}, [x18], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.4h}, [x18], x17
+  b WriteEnd
+Write5:
+  add x13, x18, #8
+  st1 {v16.4h}, [x18], x17
+  st1 {v16.h}[4], [x13], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.4h}, [x18], x17
+  st1 {v17.h}[4], [x13], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.4h}, [x18], x17
+  st1 {v18.h}[4], [x13], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.4h}, [x18], x17
+  st1 {v19.h}[4], [x13], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.4h}, [x18], x17
+  st1 {v20.h}[4], [x13], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.4h}, [x18], x17
+  st1 {v21.h}[4], [x13], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.4h}, [x18], x17
+  st1 {v22.h}[4], [x13], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.4h}, [x18], x17
+  st1 {v23.h}[4], [x13], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.4h}, [x18], x17
+  st1 {v24.h}[4], [x13], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.4h}, [x18], x17
+  st1 {v25.h}[4], [x13], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.4h}, [x18], x17
+  st1 {v26.h}[4], [x13], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.4h}, [x18], x17
+  st1 {v27.h}[4], [x13], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.4h}, [x18], x17
+  st1 {v28.h}[4], [x13], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.4h}, [x18], x17
+  st1 {v29.h}[4], [x13], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.4h}, [x18], x17
+  st1 {v30.h}[4], [x13], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.4h}, [x18], x17
+  st1 {v31.h}[4], [x13], x17
+  b WriteEnd
+Write6:
+  add x13, x18, #8
+  add x14, x18, #10
+  st1 {v16.4h}, [x18], x17
+  st1 {v16.h}[4], [x13], x17
+  st1 {v16.h}[5], [x14], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.4h}, [x18], x17
+  st1 {v17.h}[4], [x13], x17
+  st1 {v17.h}[5], [x14], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.4h}, [x18], x17
+  st1 {v18.h}[4], [x13], x17
+  st1 {v18.h}[5], [x14], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.4h}, [x18], x17
+  st1 {v19.h}[4], [x13], x17
+  st1 {v19.h}[5], [x14], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.4h}, [x18], x17
+  st1 {v20.h}[4], [x13], x17
+  st1 {v20.h}[5], [x14], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.4h}, [x18], x17
+  st1 {v21.h}[4], [x13], x17
+  st1 {v21.h}[5], [x14], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.4h}, [x18], x17
+  st1 {v22.h}[4], [x13], x17
+  st1 {v22.h}[5], [x14], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.4h}, [x18], x17
+  st1 {v23.h}[4], [x13], x17
+  st1 {v23.h}[5], [x14], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.4h}, [x18], x17
+  st1 {v24.h}[4], [x13], x17
+  st1 {v24.h}[5], [x14], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.4h}, [x18], x17
+  st1 {v25.h}[4], [x13], x17
+  st1 {v25.h}[5], [x14], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.4h}, [x18], x17
+  st1 {v26.h}[4], [x13], x17
+  st1 {v26.h}[5], [x14], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.4h}, [x18], x17
+  st1 {v27.h}[4], [x13], x17
+  st1 {v27.h}[5], [x14], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.4h}, [x18], x17
+  st1 {v28.h}[4], [x13], x17
+  st1 {v28.h}[5], [x14], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.4h}, [x18], x17
+  st1 {v29.h}[4], [x13], x17
+  st1 {v29.h}[5], [x14], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.4h}, [x18], x17
+  st1 {v30.h}[4], [x13], x17
+  st1 {v30.h}[5], [x14], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.4h}, [x18], x17
+  st1 {v31.h}[4], [x13], x17
+  st1 {v31.h}[5], [x14], x17
+  b WriteEnd
+Write7:
+  add x13, x18, #8
+  add x14, x18, #10
+  add x16, x18, #12
+  st1 {v16.4h}, [x18], x17
+  st1 {v16.h}[4], [x13], x17
+  st1 {v16.h}[5], [x14], x17
+  st1 {v16.h}[6], [x16], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.4h}, [x18], x17
+  st1 {v17.h}[4], [x13], x17
+  st1 {v17.h}[5], [x14], x17
+  st1 {v17.h}[6], [x16], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.4h}, [x18], x17
+  st1 {v18.h}[4], [x13], x17
+  st1 {v18.h}[5], [x14], x17
+  st1 {v18.h}[6], [x16], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.4h}, [x18], x17
+  st1 {v19.h}[4], [x13], x17
+  st1 {v19.h}[5], [x14], x17
+  st1 {v19.h}[6], [x16], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.4h}, [x18], x17
+  st1 {v20.h}[4], [x13], x17
+  st1 {v20.h}[5], [x14], x17
+  st1 {v20.h}[6], [x16], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.4h}, [x18], x17
+  st1 {v21.h}[4], [x13], x17
+  st1 {v21.h}[5], [x14], x17
+  st1 {v21.h}[6], [x16], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.4h}, [x18], x17
+  st1 {v22.h}[4], [x13], x17
+  st1 {v22.h}[5], [x14], x17
+  st1 {v22.h}[6], [x16], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.4h}, [x18], x17
+  st1 {v23.h}[4], [x13], x17
+  st1 {v23.h}[5], [x14], x17
+  st1 {v23.h}[6], [x16], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.4h}, [x18], x17
+  st1 {v24.h}[4], [x13], x17
+  st1 {v24.h}[5], [x14], x17
+  st1 {v24.h}[6], [x16], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.4h}, [x18], x17
+  st1 {v25.h}[4], [x13], x17
+  st1 {v25.h}[5], [x14], x17
+  st1 {v25.h}[6], [x16], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.4h}, [x18], x17
+  st1 {v26.h}[4], [x13], x17
+  st1 {v26.h}[5], [x14], x17
+  st1 {v26.h}[6], [x16], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.4h}, [x18], x17
+  st1 {v27.h}[4], [x13], x17
+  st1 {v27.h}[5], [x14], x17
+  st1 {v27.h}[6], [x16], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.4h}, [x18], x17
+  st1 {v28.h}[4], [x13], x17
+  st1 {v28.h}[5], [x14], x17
+  st1 {v28.h}[6], [x16], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.4h}, [x18], x17
+  st1 {v29.h}[4], [x13], x17
+  st1 {v29.h}[5], [x14], x17
+  st1 {v29.h}[6], [x16], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.4h}, [x18], x17
+  st1 {v30.h}[4], [x13], x17
+  st1 {v30.h}[5], [x14], x17
+  st1 {v30.h}[6], [x16], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.4h}, [x18], x17
+  st1 {v31.h}[4], [x13], x17
+  st1 {v31.h}[5], [x14], x17
+  st1 {v31.h}[6], [x16], x17
+  b WriteEnd
+WriteC8:
+  st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
+  st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
+  st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
+  st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
+  b WriteEnd
+Write8:
+  st1 {v16.8h}, [x18], x17
+  cmp w10, #1
+  beq WriteEnd
+  st1 {v17.8h}, [x18], x17
+  cmp w10, #2
+  beq WriteEnd
+  st1 {v18.8h}, [x18], x17
+  cmp w10, #3
+  beq WriteEnd
+  st1 {v19.8h}, [x18], x17
+  cmp w10, #4
+  beq WriteEnd
+  st1 {v20.8h}, [x18], x17
+  cmp w10, #5
+  beq WriteEnd
+  st1 {v21.8h}, [x18], x17
+  cmp w10, #6
+  beq WriteEnd
+  st1 {v22.8h}, [x18], x17
+  cmp w10, #7
+  beq WriteEnd
+  st1 {v23.8h}, [x18], x17
+  cmp w10, #8
+  beq WriteEnd
+  st1 {v24.8h}, [x18], x17
+  cmp w10, #9
+  beq WriteEnd
+  st1 {v25.8h}, [x18], x17
+  cmp w10, #10
+  beq WriteEnd
+  st1 {v26.8h}, [x18], x17
+  cmp w10, #11
+  beq WriteEnd
+  st1 {v27.8h}, [x18], x17
+  cmp w10, #12
+  beq WriteEnd
+  st1 {v28.8h}, [x18], x17
+  cmp w10, #13
+  beq WriteEnd
+  st1 {v29.8h}, [x18], x17
+  cmp w10, #14
+  beq WriteEnd
+  st1 {v30.8h}, [x18], x17
+  cmp w10, #15
+  beq WriteEnd
+  st1 {v31.8h}, [x18], x17
+
+WriteEnd:
+  subs w10, w10, #16 // lhs row - 16
+  bgt L2
+
+End2:
+  subs w7, w7, #8 // rhs col - 8
+  add x1, x1, x15 // rhs ptr + stride
+  add x3, x3, #16 // bias ptr + stride
+  ldrb w13, [sp, #8] // ldrb/cbz/add leave the flags of the subs above intact for bgt
+  cbz w13, NoDstStep // WriteC8 already advanced x2 through its post-indexed stores
+  add x2, x2, #16 // dst ptr + stride
+NoDstStep:
+  bgt L1
+
+End1:
+  sub sp, sp, #128
+  ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
+  ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
+  ret
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
new file mode 100644
index 00000000000..82428482f3f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.c
@@ -0,0 +1,172 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/fp16/matmul_fp16.h"
+
+/*
+ * Repacks a row-major (row x col) fp16 matrix into tiles of 16 rows. Inside a tile the
+ * 16 values of one column are stored contiguously, column after column, so element
+ * (r, c) of src_ptr lands at dst_ptr[(r / 16) * 16 * col + c * 16 + r % 16].
+ * dst_ptr has to be large enough for the last (possibly partial) tile as well; the
+ * padding rows of that tile are left untouched by this function.
+ */
+void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
+  size_t row16 = row / C16NUM * C16NUM;
+  size_t col8 = col / C8NUM * C8NUM;
+  float16_t *src_r = src_ptr;
+  float16_t *dst_r = dst_ptr;
+
+  size_t ri = 0;
+  for (; ri < row16; ri += C16NUM) {
+    size_t ci = 0;
+    for (; ci < col8; ci += C8NUM) {
+      float16_t *src_c = src_r + ci;
+      float16_t *dst_c = dst_r + ci * C16NUM;
+
+      /* 16*8 row-major to col-major */
+#ifdef ENABLE_ARM64
+      size_t stride = col * sizeof(float16_t); /* source row pitch in bytes */
+      asm volatile(
+        "mov x10, %[src_c]\n"
+        "mov x11, %[dst_c]\n"
+
+        "ld1 {v0.8h}, [x10], %[stride]\n"
+        "ld1 {v1.8h}, [x10], %[stride]\n"
+        "ld1 {v2.8h}, [x10], %[stride]\n"
+        "ld1 {v3.8h}, [x10], %[stride]\n"
+        "ld1 {v4.8h}, [x10], %[stride]\n"
+        "ld1 {v5.8h}, [x10], %[stride]\n"
+        "ld1 {v6.8h}, [x10], %[stride]\n"
+        "ld1 {v7.8h}, [x10], %[stride]\n"
+
+        "zip1 v16.8h, v0.8h, v1.8h\n"
+        "zip1 v17.8h, v2.8h, v3.8h\n"
+        "zip1 v18.8h, v4.8h, v5.8h\n"
+        "zip1 v19.8h, v6.8h, v7.8h\n"
+
+        "ld1 {v8.8h}, [x10], %[stride]\n"
+        "ld1 {v9.8h}, [x10], %[stride]\n"
+        "ld1 {v10.8h}, [x10], %[stride]\n"
+        "ld1 {v11.8h}, [x10], %[stride]\n"
+        "ld1 {v12.8h}, [x10], %[stride]\n"
+        "ld1 {v13.8h}, [x10], %[stride]\n"
+        "ld1 {v14.8h}, [x10], %[stride]\n"
+        "ld1 {v15.8h}, [x10], %[stride]\n"
+
+        "trn1 v20.4s, v16.4s, v17.4s\n"
+        "trn2 v21.4s, v16.4s, v17.4s\n"
+        "trn1 v22.4s, v18.4s, v19.4s\n"
+        "trn2 v23.4s, v18.4s, v19.4s\n"
+
+        "trn1 v24.2d, v20.2d, v22.2d\n"
+        "trn2 v25.2d, v20.2d, v22.2d\n"
+        "trn1 v26.2d, v21.2d, v23.2d\n"
+        "trn2 v27.2d, v21.2d, v23.2d\n"
+
+        "zip1 v16.8h, v8.8h, v9.8h\n"
+        "zip1 v17.8h, v10.8h, v11.8h\n"
+        "zip1 v18.8h, v12.8h, v13.8h\n"
+        "zip1 v19.8h, v14.8h, v15.8h\n"
+
+        "trn1 v20.4s, v16.4s, v17.4s\n"
+        "trn2 v21.4s, v16.4s, v17.4s\n"
+        "trn1 v22.4s, v18.4s, v19.4s\n"
+        "trn2 v23.4s, v18.4s, v19.4s\n"
+
+        "trn1 v28.2d, v20.2d, v22.2d\n"
+        "trn2 v29.2d, v20.2d, v22.2d\n"
+        "trn1 v30.2d, v21.2d, v23.2d\n"
+        "trn2 v31.2d, v21.2d, v23.2d\n"
+
+        "st1 {v24.8h}, [x11], #16\n"
+        "st1 {v28.8h}, [x11], #16\n"
+        "st1 {v26.8h}, [x11], #16\n"
+        "st1 {v30.8h}, [x11], #16\n"
+        "st1 {v25.8h}, [x11], #16\n"
+        "st1 {v29.8h}, [x11], #16\n"
+        "st1 {v27.8h}, [x11], #16\n"
+        "st1 {v31.8h}, [x11], #16\n"
+
+        "zip2 v16.8h, v0.8h, v1.8h\n"
+        "zip2 v17.8h, v2.8h, v3.8h\n"
+        "zip2 v18.8h, v4.8h, v5.8h\n"
+        "zip2 v19.8h, v6.8h, v7.8h\n"
+
+        "trn1 v20.4s, v16.4s, v17.4s\n"
+        "trn2 v21.4s, v16.4s, v17.4s\n"
+        "trn1 v22.4s, v18.4s, v19.4s\n"
+        "trn2 v23.4s, v18.4s, v19.4s\n"
+
+        "trn1 v24.2d, v20.2d, v22.2d\n"
+        "trn2 v25.2d, v20.2d, v22.2d\n"
+        "trn1 v26.2d, v21.2d, v23.2d\n"
+        "trn2 v27.2d, v21.2d, v23.2d\n"
+
+        "zip2 v16.8h, v8.8h, v9.8h\n"
+        "zip2 v17.8h, v10.8h, v11.8h\n"
+        "zip2 v18.8h, v12.8h, v13.8h\n"
+        "zip2 v19.8h, v14.8h, v15.8h\n"
+
+        "trn1 v20.4s, v16.4s, v17.4s\n"
+        "trn2 v21.4s, v16.4s, v17.4s\n"
+        "trn1 v22.4s, v18.4s, v19.4s\n"
+        "trn2 v23.4s, v18.4s, v19.4s\n"
+
+        "trn1 v28.2d, v20.2d, v22.2d\n"
+        "trn2 v29.2d, v20.2d, v22.2d\n"
+        "trn1 v30.2d, v21.2d, v23.2d\n"
+        "trn2 v31.2d, v21.2d, v23.2d\n"
+
+        "st1 {v24.8h}, [x11], #16\n"
+        "st1 {v28.8h}, [x11], #16\n"
+        "st1 {v26.8h}, [x11], #16\n"
+        "st1 {v30.8h}, [x11], #16\n"
+        "st1 {v25.8h}, [x11], #16\n"
+        "st1 {v29.8h}, [x11], #16\n"
+        "st1 {v27.8h}, [x11], #16\n"
+        "st1 {v31.8h}, [x11], #16\n"
+        :
+        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+        : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+          "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+          "v30", "v31", "memory"); /* "memory": the asm reads src_c and writes dst_c behind the compiler's back */
+#else
+      for (int tr = 0; tr < C16NUM; tr++) {
+        for (int tc = 0; tc < C8NUM; tc++) {
+          dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
+        }
+      }
+#endif
+    }
+    for (; ci < col; ci++) {
+      float16_t *src_c = src_r + ci;
+      float16_t *dst_c = dst_r + ci * C16NUM;
+      for (size_t i = 0; i < C16NUM; i++) {
+        dst_c[i] = src_c[i * col];
+      }
+    }
+    src_r += C16NUM * col;
+    dst_r += C16NUM * col;
+  }
+  for (; ri < row; ri++) {
+    for (size_t i = 0; i < col; i++) {
+      dst_r[i * C16NUM] = src_r[i];
+    }
+    src_r += col;
+    dst_r += 1;
+  }
+  return;
+}
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
new file mode 100644
index 00000000000..c9a4981caf4
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_
+
+#include <float.h>
+#include <stdbool.h>
+#include <string.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl/errorcode.h"
+#include "nnacl/op_base.h"
+#include "nnacl/matmul_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
+#ifdef __aarch64__
+void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
+                      int depth, int row, int col, int stride, bool write_nhwc);
+#endif
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h
index 0ce1a5da401..8be6051ced1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/op_base.h
@@ -23,6 +23,7 @@
 
 #define C4NUM 4
 #define C8NUM 8
+#define C16NUM 16
 #define BLOCK 4
 #define TILE_NUM 8
 