!4516 add fp16 matmul kernel

Merge pull request !4516 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-15 19:13:33 +08:00 committed by Gitee
commit a3b8353b90
5 changed files with 1086 additions and 17 deletions

View File

@ -582,22 +582,10 @@ Write7:
st1 {v31.s}[2], [x16], x17
b WriteEnd
WriteC8:
st1 {v16.4s}, [x2], #16
st1 {v17.4s}, [x2], #16
st1 {v18.4s}, [x2], #16
st1 {v19.4s}, [x2], #16
st1 {v20.4s}, [x2], #16
st1 {v21.4s}, [x2], #16
st1 {v22.4s}, [x2], #16
st1 {v23.4s}, [x2], #16
st1 {v24.4s}, [x2], #16
st1 {v25.4s}, [x2], #16
st1 {v26.4s}, [x2], #16
st1 {v27.4s}, [x2], #16
st1 {v28.4s}, [x2], #16
st1 {v29.4s}, [x2], #16
st1 {v30.4s}, [x2], #16
st1 {v31.4s}, [x2], #16
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
b WriteEnd
Write8:
st1 {v16.4s, v17.4s}, [x18], x17
@ -634,7 +622,7 @@ End2:
ldrb w13, [sp, #8]
cbz w13, NoDstStep
add x2, x2, #32 // dst ptr + stride
NoDstStep:
NoDstStep:
bgt L1
End1:

View File

@ -0,0 +1,874 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulFp16Neon64
#ifndef __APPLE__
.type MatmulFp16Neon64, %function
#endif
// void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
//                       int depth, int row, int col, int stride, bool write_nhwc)
//
// Half-precision matmul micro-kernel. Each inner iteration accumulates one tile of
// 16 rows x 8 columns in v16-v31 (one output row per vector register).
//   a: read 32 bytes (16 halves, one per tile row) per depth step; consecutive row tiles follow each other.
//   b: read 16 bytes (8 halves, one per tile column) per depth step; column tiles are depth * 16 bytes apart.
//   bias: optional (may be NULL), 8 halves per column tile, added to every row of the tile.
//   act_type: 1 clamps below at 0, 2 clamps to [0, 6], any other value leaves the result unclamped.
//   write_nhwc != 0: rows are stored stride elements apart, honouring the remaining row/col counts.
//   write_nhwc == 0: every tile is stored as 16 full rows of 8 halves, back to back.
//
// Register roles:
//   x0: a            x1: b (advanced per column tile)   x2: c (advanced per column tile)
//   x3: bias (advanced per column tile)                 w4: act_type
//   w5: depth        w6: row                            w7: remaining columns
//   x8: scratch pointer for the register spill
//   x9: dst row pointer       w10: remaining rows        x11: bias flag (original bias pointer)
//   x12: lhs read pointer     w13: depth counter / scratch
//   x14: bias read pointer / scratch store pointer
//   x15: rhs column-tile stride in bytes
//   x16: rhs read pointer / scratch store pointer
//   x17: dst row stride in bytes
// Stack arguments (standard AAPCS64 layout, after the 128-byte frame is allocated):
//   [sp, #128]: stride (int, in elements)    [sp, #136]: write_nhwc (bool)
// x18 is deliberately not used: it is a reserved platform register on several targets.
MatmulFp16Neon64:
// Spill callee-saved v8-v15 into a real frame. sp stays lowered until the epilogue so the
// spill area is never below sp.
sub sp, sp, #128
mov x8, sp
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x8], #64
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x8]
lsl w15, w5, #4 // rhs column-tile stride: depth * 8 columns * sizeof(float16_t)
mov x11, x3 // bias flag
ldrsw x17, [sp, #128] // stride is an int: load 32 bits only, the upper half of the slot is unspecified
lsl x17, x17, #1 // dst row stride in bytes: stride * sizeof(float16_t)
L1:
mov w10, w6 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x9, x2 // reload dst ptr
L2:
mov x16, x1 // reload rhs ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
// clear the 16 accumulators
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
// the unrolled loop consumes 8 depth steps at a time, so it needs at least 8 of them
cmp w13, #8
blt CommLoopMul
OptLoopMul8:
// 8 depth steps per iteration: 256 bytes of lhs and 128 bytes of rhs
ld1 {v0.8h, v1.8h}, [x12], #32
ld1 {v8.8h, v9.8h}, [x16], #32
// depth step 0: rhs v8, lhs v0/v1
fmla v16.8h, v8.8h, v0.h[0]
fmla v17.8h, v8.8h, v0.h[1]
fmla v18.8h, v8.8h, v0.h[2]
fmla v19.8h, v8.8h, v0.h[3]
fmla v20.8h, v8.8h, v0.h[4]
fmla v21.8h, v8.8h, v0.h[5]
fmla v22.8h, v8.8h, v0.h[6]
fmla v23.8h, v8.8h, v0.h[7]
ld1 {v2.8h, v3.8h}, [x12], #32
fmla v24.8h, v8.8h, v1.h[0]
fmla v25.8h, v8.8h, v1.h[1]
fmla v26.8h, v8.8h, v1.h[2]
fmla v27.8h, v8.8h, v1.h[3]
fmla v28.8h, v8.8h, v1.h[4]
fmla v29.8h, v8.8h, v1.h[5]
fmla v30.8h, v8.8h, v1.h[6]
fmla v31.8h, v8.8h, v1.h[7]
ld1 {v10.8h, v11.8h}, [x16], #32
// depth step 1: rhs v9, lhs v2/v3
fmla v16.8h, v9.8h, v2.h[0]
fmla v17.8h, v9.8h, v2.h[1]
fmla v18.8h, v9.8h, v2.h[2]
fmla v19.8h, v9.8h, v2.h[3]
fmla v20.8h, v9.8h, v2.h[4]
fmla v21.8h, v9.8h, v2.h[5]
fmla v22.8h, v9.8h, v2.h[6]
fmla v23.8h, v9.8h, v2.h[7]
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
fmla v24.8h, v9.8h, v3.h[0]
fmla v25.8h, v9.8h, v3.h[1]
fmla v26.8h, v9.8h, v3.h[2]
fmla v27.8h, v9.8h, v3.h[3]
fmla v28.8h, v9.8h, v3.h[4]
fmla v29.8h, v9.8h, v3.h[5]
fmla v30.8h, v9.8h, v3.h[6]
fmla v31.8h, v9.8h, v3.h[7]
ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x16], #64
// depth step 2: rhs v10, lhs v4/v5
fmla v16.8h, v10.8h, v4.h[0]
fmla v17.8h, v10.8h, v4.h[1]
fmla v18.8h, v10.8h, v4.h[2]
fmla v19.8h, v10.8h, v4.h[3]
fmla v20.8h, v10.8h, v4.h[4]
fmla v21.8h, v10.8h, v4.h[5]
fmla v22.8h, v10.8h, v4.h[6]
fmla v23.8h, v10.8h, v4.h[7]
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
fmla v24.8h, v10.8h, v5.h[0]
fmla v25.8h, v10.8h, v5.h[1]
fmla v26.8h, v10.8h, v5.h[2]
fmla v27.8h, v10.8h, v5.h[3]
fmla v28.8h, v10.8h, v5.h[4]
fmla v29.8h, v10.8h, v5.h[5]
fmla v30.8h, v10.8h, v5.h[6]
fmla v31.8h, v10.8h, v5.h[7]
ld1 {v4.8h, v5.8h}, [x12], #32
// depth step 3: rhs v11, lhs v6/v7
fmla v16.8h, v11.8h, v6.h[0]
fmla v17.8h, v11.8h, v6.h[1]
fmla v18.8h, v11.8h, v6.h[2]
fmla v19.8h, v11.8h, v6.h[3]
fmla v20.8h, v11.8h, v6.h[4]
fmla v21.8h, v11.8h, v6.h[5]
fmla v22.8h, v11.8h, v6.h[6]
fmla v23.8h, v11.8h, v6.h[7]
fmla v24.8h, v11.8h, v7.h[0]
fmla v25.8h, v11.8h, v7.h[1]
fmla v26.8h, v11.8h, v7.h[2]
fmla v27.8h, v11.8h, v7.h[3]
fmla v28.8h, v11.8h, v7.h[4]
fmla v29.8h, v11.8h, v7.h[5]
fmla v30.8h, v11.8h, v7.h[6]
fmla v31.8h, v11.8h, v7.h[7]
ld1 {v6.8h, v7.8h}, [x12], #32
// depth step 4: rhs v12, lhs v0/v1
fmla v16.8h, v12.8h, v0.h[0]
fmla v17.8h, v12.8h, v0.h[1]
fmla v18.8h, v12.8h, v0.h[2]
fmla v19.8h, v12.8h, v0.h[3]
fmla v20.8h, v12.8h, v0.h[4]
fmla v21.8h, v12.8h, v0.h[5]
fmla v22.8h, v12.8h, v0.h[6]
fmla v23.8h, v12.8h, v0.h[7]
fmla v24.8h, v12.8h, v1.h[0]
fmla v25.8h, v12.8h, v1.h[1]
fmla v26.8h, v12.8h, v1.h[2]
fmla v27.8h, v12.8h, v1.h[3]
fmla v28.8h, v12.8h, v1.h[4]
fmla v29.8h, v12.8h, v1.h[5]
fmla v30.8h, v12.8h, v1.h[6]
fmla v31.8h, v12.8h, v1.h[7]
// depth step 5: rhs v13, lhs v2/v3
fmla v16.8h, v13.8h, v2.h[0]
fmla v17.8h, v13.8h, v2.h[1]
fmla v18.8h, v13.8h, v2.h[2]
fmla v19.8h, v13.8h, v2.h[3]
fmla v20.8h, v13.8h, v2.h[4]
fmla v21.8h, v13.8h, v2.h[5]
fmla v22.8h, v13.8h, v2.h[6]
fmla v23.8h, v13.8h, v2.h[7]
fmla v24.8h, v13.8h, v3.h[0]
fmla v25.8h, v13.8h, v3.h[1]
fmla v26.8h, v13.8h, v3.h[2]
fmla v27.8h, v13.8h, v3.h[3]
fmla v28.8h, v13.8h, v3.h[4]
fmla v29.8h, v13.8h, v3.h[5]
fmla v30.8h, v13.8h, v3.h[6]
fmla v31.8h, v13.8h, v3.h[7]
// depth step 6: rhs v14, lhs v4/v5
fmla v16.8h, v14.8h, v4.h[0]
fmla v17.8h, v14.8h, v4.h[1]
fmla v18.8h, v14.8h, v4.h[2]
fmla v19.8h, v14.8h, v4.h[3]
fmla v20.8h, v14.8h, v4.h[4]
fmla v21.8h, v14.8h, v4.h[5]
fmla v22.8h, v14.8h, v4.h[6]
fmla v23.8h, v14.8h, v4.h[7]
fmla v24.8h, v14.8h, v5.h[0]
fmla v25.8h, v14.8h, v5.h[1]
fmla v26.8h, v14.8h, v5.h[2]
fmla v27.8h, v14.8h, v5.h[3]
fmla v28.8h, v14.8h, v5.h[4]
fmla v29.8h, v14.8h, v5.h[5]
fmla v30.8h, v14.8h, v5.h[6]
fmla v31.8h, v14.8h, v5.h[7]
// depth step 7: rhs v15, lhs v6/v7
fmla v16.8h, v15.8h, v6.h[0]
fmla v17.8h, v15.8h, v6.h[1]
fmla v18.8h, v15.8h, v6.h[2]
fmla v19.8h, v15.8h, v6.h[3]
fmla v20.8h, v15.8h, v6.h[4]
fmla v21.8h, v15.8h, v6.h[5]
fmla v22.8h, v15.8h, v6.h[6]
fmla v23.8h, v15.8h, v6.h[7]
fmla v24.8h, v15.8h, v7.h[0]
fmla v25.8h, v15.8h, v7.h[1]
fmla v26.8h, v15.8h, v7.h[2]
fmla v27.8h, v15.8h, v7.h[3]
fmla v28.8h, v15.8h, v7.h[4]
fmla v29.8h, v15.8h, v7.h[5]
fmla v30.8h, v15.8h, v7.h[6]
fmla v31.8h, v15.8h, v7.h[7]
sub w13, w13, #8
cmp w13, #0
ble Bias
cmp w13, #8
bge OptLoopMul8
CommLoopMul:
// one depth step per iteration: 32 bytes of lhs and 16 bytes of rhs
ld1 {v0.8h, v1.8h}, [x12], #32
ld1 {v8.8h}, [x16], #16
fmla v16.8h, v8.8h, v0.h[0]
fmla v17.8h, v8.8h, v0.h[1]
fmla v18.8h, v8.8h, v0.h[2]
fmla v19.8h, v8.8h, v0.h[3]
fmla v20.8h, v8.8h, v0.h[4]
fmla v21.8h, v8.8h, v0.h[5]
fmla v22.8h, v8.8h, v0.h[6]
fmla v23.8h, v8.8h, v0.h[7]
fmla v24.8h, v8.8h, v1.h[0]
fmla v25.8h, v8.8h, v1.h[1]
fmla v26.8h, v8.8h, v1.h[2]
fmla v27.8h, v8.8h, v1.h[3]
fmla v28.8h, v8.8h, v1.h[4]
fmla v29.8h, v8.8h, v1.h[5]
fmla v30.8h, v8.8h, v1.h[6]
fmla v31.8h, v8.8h, v1.h[7]
subs w13, w13, #1
bgt CommLoopMul
Bias:
cbz x11, Activation // no bias pointer was passed
ld1 {v0.8h}, [x14], #16
fadd v16.8h, v16.8h, v0.8h
fadd v17.8h, v17.8h, v0.8h
fadd v18.8h, v18.8h, v0.8h
fadd v19.8h, v19.8h, v0.8h
fadd v20.8h, v20.8h, v0.8h
fadd v21.8h, v21.8h, v0.8h
fadd v22.8h, v22.8h, v0.8h
fadd v23.8h, v23.8h, v0.8h
fadd v24.8h, v24.8h, v0.8h
fadd v25.8h, v25.8h, v0.8h
fadd v26.8h, v26.8h, v0.8h
fadd v27.8h, v27.8h, v0.8h
fadd v28.8h, v28.8h, v0.8h
fadd v29.8h, v29.8h, v0.8h
fadd v30.8h, v30.8h, v0.8h
fadd v31.8h, v31.8h, v0.8h
Activation:
cmp w4, #2
beq Relu6
cmp w4, #1
beq Relu
b Write
Relu6:
movi v15.8h, #0x46, lsl #8 // 0x4600 is 6.0 in fp16
fmin v16.8h, v16.8h, v15.8h
fmin v17.8h, v17.8h, v15.8h
fmin v18.8h, v18.8h, v15.8h
fmin v19.8h, v19.8h, v15.8h
fmin v20.8h, v20.8h, v15.8h
fmin v21.8h, v21.8h, v15.8h
fmin v22.8h, v22.8h, v15.8h
fmin v23.8h, v23.8h, v15.8h
fmin v24.8h, v24.8h, v15.8h
fmin v25.8h, v25.8h, v15.8h
fmin v26.8h, v26.8h, v15.8h
fmin v27.8h, v27.8h, v15.8h
fmin v28.8h, v28.8h, v15.8h
fmin v29.8h, v29.8h, v15.8h
fmin v30.8h, v30.8h, v15.8h
fmin v31.8h, v31.8h, v15.8h
// falls through: the upper clamp is followed by the lower clamp at 0
Relu:
dup v14.4s, wzr
fmax v16.8h, v16.8h, v14.8h
fmax v17.8h, v17.8h, v14.8h
fmax v18.8h, v18.8h, v14.8h
fmax v19.8h, v19.8h, v14.8h
fmax v20.8h, v20.8h, v14.8h
fmax v21.8h, v21.8h, v14.8h
fmax v22.8h, v22.8h, v14.8h
fmax v23.8h, v23.8h, v14.8h
fmax v24.8h, v24.8h, v14.8h
fmax v25.8h, v25.8h, v14.8h
fmax v26.8h, v26.8h, v14.8h
fmax v27.8h, v27.8h, v14.8h
fmax v28.8h, v28.8h, v14.8h
fmax v29.8h, v29.8h, v14.8h
fmax v30.8h, v30.8h, v14.8h
fmax v31.8h, v31.8h, v14.8h
Write:
ldrb w13, [sp, #136] // write_nhwc
cbz w13, WriteC8
// strided store: pick the variant for the number of valid columns in this tile
cmp w7, #1
beq Write1
cmp w7, #2
beq Write2
cmp w7, #3
beq Write3
cmp w7, #4
beq Write4
cmp w7, #5
beq Write5
cmp w7, #6
beq Write6
cmp w7, #7
beq Write7
b Write8
// Every WriteN variant stores one row, then stops as soon as w10 (remaining rows) is reached.
Write1:
st1 {v16.h}[0], [x9], x17
cmp w10, #1
beq WriteEnd
st1 {v17.h}[0], [x9], x17
cmp w10, #2
beq WriteEnd
st1 {v18.h}[0], [x9], x17
cmp w10, #3
beq WriteEnd
st1 {v19.h}[0], [x9], x17
cmp w10, #4
beq WriteEnd
st1 {v20.h}[0], [x9], x17
cmp w10, #5
beq WriteEnd
st1 {v21.h}[0], [x9], x17
cmp w10, #6
beq WriteEnd
st1 {v22.h}[0], [x9], x17
cmp w10, #7
beq WriteEnd
st1 {v23.h}[0], [x9], x17
cmp w10, #8
beq WriteEnd
st1 {v24.h}[0], [x9], x17
cmp w10, #9
beq WriteEnd
st1 {v25.h}[0], [x9], x17
cmp w10, #10
beq WriteEnd
st1 {v26.h}[0], [x9], x17
cmp w10, #11
beq WriteEnd
st1 {v27.h}[0], [x9], x17
cmp w10, #12
beq WriteEnd
st1 {v28.h}[0], [x9], x17
cmp w10, #13
beq WriteEnd
st1 {v29.h}[0], [x9], x17
cmp w10, #14
beq WriteEnd
st1 {v30.h}[0], [x9], x17
cmp w10, #15
beq WriteEnd
st1 {v31.h}[0], [x9], x17
b WriteEnd
Write2:
add x13, x9, #2 // address of column 1
st1 {v16.h}[0], [x9], x17
st1 {v16.h}[1], [x13], x17
cmp w10, #1
beq WriteEnd
st1 {v17.h}[0], [x9], x17
st1 {v17.h}[1], [x13], x17
cmp w10, #2
beq WriteEnd
st1 {v18.h}[0], [x9], x17
st1 {v18.h}[1], [x13], x17
cmp w10, #3
beq WriteEnd
st1 {v19.h}[0], [x9], x17
st1 {v19.h}[1], [x13], x17
cmp w10, #4
beq WriteEnd
st1 {v20.h}[0], [x9], x17
st1 {v20.h}[1], [x13], x17
cmp w10, #5
beq WriteEnd
st1 {v21.h}[0], [x9], x17
st1 {v21.h}[1], [x13], x17
cmp w10, #6
beq WriteEnd
st1 {v22.h}[0], [x9], x17
st1 {v22.h}[1], [x13], x17
cmp w10, #7
beq WriteEnd
st1 {v23.h}[0], [x9], x17
st1 {v23.h}[1], [x13], x17
cmp w10, #8
beq WriteEnd
st1 {v24.h}[0], [x9], x17
st1 {v24.h}[1], [x13], x17
cmp w10, #9
beq WriteEnd
st1 {v25.h}[0], [x9], x17
st1 {v25.h}[1], [x13], x17
cmp w10, #10
beq WriteEnd
st1 {v26.h}[0], [x9], x17
st1 {v26.h}[1], [x13], x17
cmp w10, #11
beq WriteEnd
st1 {v27.h}[0], [x9], x17
st1 {v27.h}[1], [x13], x17
cmp w10, #12
beq WriteEnd
st1 {v28.h}[0], [x9], x17
st1 {v28.h}[1], [x13], x17
cmp w10, #13
beq WriteEnd
st1 {v29.h}[0], [x9], x17
st1 {v29.h}[1], [x13], x17
cmp w10, #14
beq WriteEnd
st1 {v30.h}[0], [x9], x17
st1 {v30.h}[1], [x13], x17
cmp w10, #15
beq WriteEnd
st1 {v31.h}[0], [x9], x17
st1 {v31.h}[1], [x13], x17
b WriteEnd
Write3:
add x13, x9, #2 // address of column 1
add x14, x9, #4 // address of column 2
st1 {v16.h}[0], [x9], x17
st1 {v16.h}[1], [x13], x17
st1 {v16.h}[2], [x14], x17
cmp w10, #1
beq WriteEnd
st1 {v17.h}[0], [x9], x17
st1 {v17.h}[1], [x13], x17
st1 {v17.h}[2], [x14], x17
cmp w10, #2
beq WriteEnd
st1 {v18.h}[0], [x9], x17
st1 {v18.h}[1], [x13], x17
st1 {v18.h}[2], [x14], x17
cmp w10, #3
beq WriteEnd
st1 {v19.h}[0], [x9], x17
st1 {v19.h}[1], [x13], x17
st1 {v19.h}[2], [x14], x17
cmp w10, #4
beq WriteEnd
st1 {v20.h}[0], [x9], x17
st1 {v20.h}[1], [x13], x17
st1 {v20.h}[2], [x14], x17
cmp w10, #5
beq WriteEnd
st1 {v21.h}[0], [x9], x17
st1 {v21.h}[1], [x13], x17
st1 {v21.h}[2], [x14], x17
cmp w10, #6
beq WriteEnd
st1 {v22.h}[0], [x9], x17
st1 {v22.h}[1], [x13], x17
st1 {v22.h}[2], [x14], x17
cmp w10, #7
beq WriteEnd
st1 {v23.h}[0], [x9], x17
st1 {v23.h}[1], [x13], x17
st1 {v23.h}[2], [x14], x17
cmp w10, #8
beq WriteEnd
st1 {v24.h}[0], [x9], x17
st1 {v24.h}[1], [x13], x17
st1 {v24.h}[2], [x14], x17
cmp w10, #9
beq WriteEnd
st1 {v25.h}[0], [x9], x17
st1 {v25.h}[1], [x13], x17
st1 {v25.h}[2], [x14], x17
cmp w10, #10
beq WriteEnd
st1 {v26.h}[0], [x9], x17
st1 {v26.h}[1], [x13], x17
st1 {v26.h}[2], [x14], x17
cmp w10, #11
beq WriteEnd
st1 {v27.h}[0], [x9], x17
st1 {v27.h}[1], [x13], x17
st1 {v27.h}[2], [x14], x17
cmp w10, #12
beq WriteEnd
st1 {v28.h}[0], [x9], x17
st1 {v28.h}[1], [x13], x17
st1 {v28.h}[2], [x14], x17
cmp w10, #13
beq WriteEnd
st1 {v29.h}[0], [x9], x17
st1 {v29.h}[1], [x13], x17
st1 {v29.h}[2], [x14], x17
cmp w10, #14
beq WriteEnd
st1 {v30.h}[0], [x9], x17
st1 {v30.h}[1], [x13], x17
st1 {v30.h}[2], [x14], x17
cmp w10, #15
beq WriteEnd
st1 {v31.h}[0], [x9], x17
st1 {v31.h}[1], [x13], x17
st1 {v31.h}[2], [x14], x17
b WriteEnd
Write4:
st1 {v16.4h}, [x9], x17
cmp w10, #1
beq WriteEnd
st1 {v17.4h}, [x9], x17 // second row comes from v17
cmp w10, #2
beq WriteEnd
st1 {v18.4h}, [x9], x17
cmp w10, #3
beq WriteEnd
st1 {v19.4h}, [x9], x17
cmp w10, #4
beq WriteEnd
st1 {v20.4h}, [x9], x17
cmp w10, #5
beq WriteEnd
st1 {v21.4h}, [x9], x17
cmp w10, #6
beq WriteEnd
st1 {v22.4h}, [x9], x17
cmp w10, #7
beq WriteEnd
st1 {v23.4h}, [x9], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4h}, [x9], x17
cmp w10, #9
beq WriteEnd
st1 {v25.4h}, [x9], x17
cmp w10, #10
beq WriteEnd
st1 {v26.4h}, [x9], x17
cmp w10, #11
beq WriteEnd
st1 {v27.4h}, [x9], x17
cmp w10, #12
beq WriteEnd
st1 {v28.4h}, [x9], x17
cmp w10, #13
beq WriteEnd
st1 {v29.4h}, [x9], x17
cmp w10, #14
beq WriteEnd
st1 {v30.4h}, [x9], x17
cmp w10, #15
beq WriteEnd
st1 {v31.4h}, [x9], x17
b WriteEnd
Write5:
add x13, x9, #8 // address of column 4
st1 {v16.4h}, [x9], x17
st1 {v16.h}[4], [x13], x17
cmp w10, #1
beq WriteEnd
st1 {v17.4h}, [x9], x17 // second row comes from v17
st1 {v17.h}[4], [x13], x17
cmp w10, #2
beq WriteEnd
st1 {v18.4h}, [x9], x17
st1 {v18.h}[4], [x13], x17
cmp w10, #3
beq WriteEnd
st1 {v19.4h}, [x9], x17
st1 {v19.h}[4], [x13], x17
cmp w10, #4
beq WriteEnd
st1 {v20.4h}, [x9], x17
st1 {v20.h}[4], [x13], x17
cmp w10, #5
beq WriteEnd
st1 {v21.4h}, [x9], x17
st1 {v21.h}[4], [x13], x17
cmp w10, #6
beq WriteEnd
st1 {v22.4h}, [x9], x17
st1 {v22.h}[4], [x13], x17
cmp w10, #7
beq WriteEnd
st1 {v23.4h}, [x9], x17
st1 {v23.h}[4], [x13], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4h}, [x9], x17
st1 {v24.h}[4], [x13], x17
cmp w10, #9
beq WriteEnd
st1 {v25.4h}, [x9], x17
st1 {v25.h}[4], [x13], x17
cmp w10, #10
beq WriteEnd
st1 {v26.4h}, [x9], x17
st1 {v26.h}[4], [x13], x17
cmp w10, #11
beq WriteEnd
st1 {v27.4h}, [x9], x17
st1 {v27.h}[4], [x13], x17
cmp w10, #12
beq WriteEnd
st1 {v28.4h}, [x9], x17
st1 {v28.h}[4], [x13], x17
cmp w10, #13
beq WriteEnd
st1 {v29.4h}, [x9], x17
st1 {v29.h}[4], [x13], x17
cmp w10, #14
beq WriteEnd
st1 {v30.4h}, [x9], x17
st1 {v30.h}[4], [x13], x17
cmp w10, #15
beq WriteEnd
st1 {v31.4h}, [x9], x17
st1 {v31.h}[4], [x13], x17
b WriteEnd
Write6:
add x13, x9, #8 // address of column 4
add x14, x9, #10 // address of column 5
st1 {v16.4h}, [x9], x17
st1 {v16.h}[4], [x13], x17
st1 {v16.h}[5], [x14], x17
cmp w10, #1
beq WriteEnd
st1 {v17.4h}, [x9], x17 // second row comes from v17
st1 {v17.h}[4], [x13], x17
st1 {v17.h}[5], [x14], x17
cmp w10, #2
beq WriteEnd
st1 {v18.4h}, [x9], x17
st1 {v18.h}[4], [x13], x17
st1 {v18.h}[5], [x14], x17
cmp w10, #3
beq WriteEnd
st1 {v19.4h}, [x9], x17
st1 {v19.h}[4], [x13], x17
st1 {v19.h}[5], [x14], x17
cmp w10, #4
beq WriteEnd
st1 {v20.4h}, [x9], x17
st1 {v20.h}[4], [x13], x17
st1 {v20.h}[5], [x14], x17
cmp w10, #5
beq WriteEnd
st1 {v21.4h}, [x9], x17
st1 {v21.h}[4], [x13], x17
st1 {v21.h}[5], [x14], x17
cmp w10, #6
beq WriteEnd
st1 {v22.4h}, [x9], x17
st1 {v22.h}[4], [x13], x17
st1 {v22.h}[5], [x14], x17
cmp w10, #7
beq WriteEnd
st1 {v23.4h}, [x9], x17
st1 {v23.h}[4], [x13], x17
st1 {v23.h}[5], [x14], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4h}, [x9], x17
st1 {v24.h}[4], [x13], x17
st1 {v24.h}[5], [x14], x17
cmp w10, #9
beq WriteEnd
st1 {v25.4h}, [x9], x17
st1 {v25.h}[4], [x13], x17
st1 {v25.h}[5], [x14], x17
cmp w10, #10
beq WriteEnd
st1 {v26.4h}, [x9], x17
st1 {v26.h}[4], [x13], x17
st1 {v26.h}[5], [x14], x17
cmp w10, #11
beq WriteEnd
st1 {v27.4h}, [x9], x17
st1 {v27.h}[4], [x13], x17
st1 {v27.h}[5], [x14], x17
cmp w10, #12
beq WriteEnd
st1 {v28.4h}, [x9], x17
st1 {v28.h}[4], [x13], x17
st1 {v28.h}[5], [x14], x17
cmp w10, #13
beq WriteEnd
st1 {v29.4h}, [x9], x17
st1 {v29.h}[4], [x13], x17
st1 {v29.h}[5], [x14], x17
cmp w10, #14
beq WriteEnd
st1 {v30.4h}, [x9], x17
st1 {v30.h}[4], [x13], x17
st1 {v30.h}[5], [x14], x17
cmp w10, #15
beq WriteEnd
st1 {v31.4h}, [x9], x17
st1 {v31.h}[4], [x13], x17
st1 {v31.h}[5], [x14], x17
b WriteEnd
Write7:
add x13, x9, #8 // address of column 4
add x14, x9, #10 // address of column 5
add x16, x9, #12 // address of column 6
st1 {v16.4h}, [x9], x17
st1 {v16.h}[4], [x13], x17
st1 {v16.h}[5], [x14], x17
st1 {v16.h}[6], [x16], x17
cmp w10, #1
beq WriteEnd
st1 {v17.4h}, [x9], x17 // second row comes from v17
st1 {v17.h}[4], [x13], x17
st1 {v17.h}[5], [x14], x17
st1 {v17.h}[6], [x16], x17
cmp w10, #2
beq WriteEnd
st1 {v18.4h}, [x9], x17
st1 {v18.h}[4], [x13], x17
st1 {v18.h}[5], [x14], x17
st1 {v18.h}[6], [x16], x17
cmp w10, #3
beq WriteEnd
st1 {v19.4h}, [x9], x17
st1 {v19.h}[4], [x13], x17
st1 {v19.h}[5], [x14], x17
st1 {v19.h}[6], [x16], x17
cmp w10, #4
beq WriteEnd
st1 {v20.4h}, [x9], x17
st1 {v20.h}[4], [x13], x17
st1 {v20.h}[5], [x14], x17
st1 {v20.h}[6], [x16], x17
cmp w10, #5
beq WriteEnd
st1 {v21.4h}, [x9], x17
st1 {v21.h}[4], [x13], x17
st1 {v21.h}[5], [x14], x17
st1 {v21.h}[6], [x16], x17
cmp w10, #6
beq WriteEnd
st1 {v22.4h}, [x9], x17
st1 {v22.h}[4], [x13], x17
st1 {v22.h}[5], [x14], x17
st1 {v22.h}[6], [x16], x17
cmp w10, #7
beq WriteEnd
st1 {v23.4h}, [x9], x17
st1 {v23.h}[4], [x13], x17
st1 {v23.h}[5], [x14], x17
st1 {v23.h}[6], [x16], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4h}, [x9], x17
st1 {v24.h}[4], [x13], x17
st1 {v24.h}[5], [x14], x17
st1 {v24.h}[6], [x16], x17
cmp w10, #9
beq WriteEnd
st1 {v25.4h}, [x9], x17
st1 {v25.h}[4], [x13], x17
st1 {v25.h}[5], [x14], x17
st1 {v25.h}[6], [x16], x17
cmp w10, #10
beq WriteEnd
st1 {v26.4h}, [x9], x17
st1 {v26.h}[4], [x13], x17
st1 {v26.h}[5], [x14], x17
st1 {v26.h}[6], [x16], x17
cmp w10, #11
beq WriteEnd
st1 {v27.4h}, [x9], x17
st1 {v27.h}[4], [x13], x17
st1 {v27.h}[5], [x14], x17
st1 {v27.h}[6], [x16], x17
cmp w10, #12
beq WriteEnd
st1 {v28.4h}, [x9], x17
st1 {v28.h}[4], [x13], x17
st1 {v28.h}[5], [x14], x17
st1 {v28.h}[6], [x16], x17
cmp w10, #13
beq WriteEnd
st1 {v29.4h}, [x9], x17
st1 {v29.h}[4], [x13], x17
st1 {v29.h}[5], [x14], x17
st1 {v29.h}[6], [x16], x17
cmp w10, #14
beq WriteEnd
st1 {v30.4h}, [x9], x17
st1 {v30.h}[4], [x13], x17
st1 {v30.h}[5], [x14], x17
st1 {v30.h}[6], [x16], x17
cmp w10, #15
beq WriteEnd
st1 {v31.4h}, [x9], x17
st1 {v31.h}[4], [x13], x17
st1 {v31.h}[5], [x14], x17
st1 {v31.h}[6], [x16], x17
b WriteEnd
WriteC8:
// packed output: the whole 16 x 8 tile (256 bytes) is stored back to back and x2 advances past it
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
b WriteEnd
Write8:
st1 {v16.8h}, [x9], x17
cmp w10, #1
beq WriteEnd
st1 {v17.8h}, [x9], x17
cmp w10, #2
beq WriteEnd
st1 {v18.8h}, [x9], x17
cmp w10, #3
beq WriteEnd
st1 {v19.8h}, [x9], x17
cmp w10, #4
beq WriteEnd
st1 {v20.8h}, [x9], x17
cmp w10, #5
beq WriteEnd
st1 {v21.8h}, [x9], x17
cmp w10, #6
beq WriteEnd
st1 {v22.8h}, [x9], x17
cmp w10, #7
beq WriteEnd
st1 {v23.8h}, [x9], x17
cmp w10, #8
beq WriteEnd
st1 {v24.8h}, [x9], x17
cmp w10, #9
beq WriteEnd
st1 {v25.8h}, [x9], x17
cmp w10, #10
beq WriteEnd
st1 {v26.8h}, [x9], x17
cmp w10, #11
beq WriteEnd
st1 {v27.8h}, [x9], x17
cmp w10, #12
beq WriteEnd
st1 {v28.8h}, [x9], x17
cmp w10, #13
beq WriteEnd
st1 {v29.8h}, [x9], x17
cmp w10, #14
beq WriteEnd
st1 {v30.8h}, [x9], x17
cmp w10, #15
beq WriteEnd
st1 {v31.8h}, [x9], x17
WriteEnd:
subs w10, w10, #16 // lhs row - 16
bgt L2
End2:
subs w7, w7, #8 // rhs col - 8
add x1, x1, x15 // rhs ptr + stride
add x3, x3, #16 // bias ptr + stride
// None of the instructions up to bgt modify the flags set by subs.
ldrb w13, [sp, #136] // write_nhwc
cbz w13, NoDstStep // packed output: x2 was already advanced by the stores in WriteC8
add x2, x2, #16 // dst ptr + 8 columns * sizeof(float16_t)
NoDstStep:
bgt L1
End1:
// Reload v8-v15; the two post-increments also release the 128-byte frame.
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
ret
#endif

View File

@ -0,0 +1,165 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/matmul_fp16.h"
/*
 * Repacks a row-major `row` x `col` fp16 matrix into tiles of C16NUM rows.
 *
 * Inside a tile the data is column-major: source element (r, c) is written to
 * dst_ptr[(r / C16NUM) * C16NUM * col + c * C16NUM + (r % C16NUM)].
 * When `row` is not a multiple of C16NUM, the last tile is only partially filled;
 * its remaining slots are not written by this function.
 *
 * src_ptr: source matrix, `row * col` elements, row-major.
 * dst_ptr: destination buffer; must be large enough for the (possibly partial) last tile.
 * row, col: dimensions of the source matrix.
 */
void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
  size_t row16 = row / C16NUM * C16NUM;
  size_t col8 = col / C8NUM * C8NUM;
  float16_t *src_r = src_ptr;
  float16_t *dst_r = dst_ptr;

  size_t ri = 0;
  /* full tiles of C16NUM rows */
  for (; ri < row16; ri += C16NUM) {
    size_t ci = 0;
    /* C8NUM columns at a time */
    for (; ci < col8; ci += C8NUM) {
      float16_t *src_c = src_r + ci;
      float16_t *dst_c = dst_r + ci * C16NUM;

      /* 16*8 row-major to col-major */
#ifdef ENABLE_ARM64
      size_t stride = col * 2; /* source row pitch in bytes */
      asm volatile(
        "mov x10, %[src_c]\n"
        "mov x11, %[dst_c]\n"

        "ld1 {v0.8h}, [x10], %[stride]\n"
        "ld1 {v1.8h}, [x10], %[stride]\n"
        "ld1 {v2.8h}, [x10], %[stride]\n"
        "ld1 {v3.8h}, [x10], %[stride]\n"
        "ld1 {v4.8h}, [x10], %[stride]\n"
        "ld1 {v5.8h}, [x10], %[stride]\n"
        "ld1 {v6.8h}, [x10], %[stride]\n"
        "ld1 {v7.8h}, [x10], %[stride]\n"

        "zip1 v16.8h, v0.8h, v1.8h\n"
        "zip1 v17.8h, v2.8h, v3.8h\n"
        "zip1 v18.8h, v4.8h, v5.8h\n"
        "zip1 v19.8h, v6.8h, v7.8h\n"

        "ld1 {v8.8h}, [x10], %[stride]\n"
        "ld1 {v9.8h}, [x10], %[stride]\n"
        "ld1 {v10.8h}, [x10], %[stride]\n"
        "ld1 {v11.8h}, [x10], %[stride]\n"
        "ld1 {v12.8h}, [x10], %[stride]\n"
        "ld1 {v13.8h}, [x10], %[stride]\n"
        "ld1 {v14.8h}, [x10], %[stride]\n"
        "ld1 {v15.8h}, [x10], %[stride]\n"

        "trn1 v20.4s, v16.4s, v17.4s\n"
        "trn2 v21.4s, v16.4s, v17.4s\n"
        "trn1 v22.4s, v18.4s, v19.4s\n"
        "trn2 v23.4s, v18.4s, v19.4s\n"

        "trn1 v24.2d, v20.2d, v22.2d\n"
        "trn2 v25.2d, v20.2d, v22.2d\n"
        "trn1 v26.2d, v21.2d, v23.2d\n"
        "trn2 v27.2d, v21.2d, v23.2d\n"

        "zip1 v16.8h, v8.8h, v9.8h\n"
        "zip1 v17.8h, v10.8h, v11.8h\n"
        "zip1 v18.8h, v12.8h, v13.8h\n"
        "zip1 v19.8h, v14.8h, v15.8h\n"

        "trn1 v20.4s, v16.4s, v17.4s\n"
        "trn2 v21.4s, v16.4s, v17.4s\n"
        "trn1 v22.4s, v18.4s, v19.4s\n"
        "trn2 v23.4s, v18.4s, v19.4s\n"

        "trn1 v28.2d, v20.2d, v22.2d\n"
        "trn2 v29.2d, v20.2d, v22.2d\n"
        "trn1 v30.2d, v21.2d, v23.2d\n"
        "trn2 v31.2d, v21.2d, v23.2d\n"

        "st1 {v24.8h}, [x11], #16\n"
        "st1 {v28.8h}, [x11], #16\n"
        "st1 {v26.8h}, [x11], #16\n"
        "st1 {v30.8h}, [x11], #16\n"
        "st1 {v25.8h}, [x11], #16\n"
        "st1 {v29.8h}, [x11], #16\n"
        "st1 {v27.8h}, [x11], #16\n"
        "st1 {v31.8h}, [x11], #16\n"

        "zip2 v16.8h, v0.8h, v1.8h\n"
        "zip2 v17.8h, v2.8h, v3.8h\n"
        "zip2 v18.8h, v4.8h, v5.8h\n"
        "zip2 v19.8h, v6.8h, v7.8h\n"

        "trn1 v20.4s, v16.4s, v17.4s\n"
        "trn2 v21.4s, v16.4s, v17.4s\n"
        "trn1 v22.4s, v18.4s, v19.4s\n"
        "trn2 v23.4s, v18.4s, v19.4s\n"

        "trn1 v24.2d, v20.2d, v22.2d\n"
        "trn2 v25.2d, v20.2d, v22.2d\n"
        "trn1 v26.2d, v21.2d, v23.2d\n"
        "trn2 v27.2d, v21.2d, v23.2d\n"

        "zip2 v16.8h, v8.8h, v9.8h\n"
        "zip2 v17.8h, v10.8h, v11.8h\n"
        "zip2 v18.8h, v12.8h, v13.8h\n"
        "zip2 v19.8h, v14.8h, v15.8h\n"

        "trn1 v20.4s, v16.4s, v17.4s\n"
        "trn2 v21.4s, v16.4s, v17.4s\n"
        "trn1 v22.4s, v18.4s, v19.4s\n"
        "trn2 v23.4s, v18.4s, v19.4s\n"

        "trn1 v28.2d, v20.2d, v22.2d\n"
        "trn2 v29.2d, v20.2d, v22.2d\n"
        "trn1 v30.2d, v21.2d, v23.2d\n"
        "trn2 v31.2d, v21.2d, v23.2d\n"

        "st1 {v24.8h}, [x11], #16\n"
        "st1 {v28.8h}, [x11], #16\n"
        "st1 {v26.8h}, [x11], #16\n"
        "st1 {v30.8h}, [x11], #16\n"
        "st1 {v25.8h}, [x11], #16\n"
        "st1 {v29.8h}, [x11], #16\n"
        "st1 {v27.8h}, [x11], #16\n"
        "st1 {v31.8h}, [x11], #16\n"
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        /* "memory": the asm reads *src_c and writes *dst_c through plain register operands,
           so the compiler must not cache or reorder memory accesses around it. */
        : "memory", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
          "v28", "v29", "v30", "v31");
#else
      for (int tr = 0; tr < C16NUM; tr++) {
        for (int tc = 0; tc < C8NUM; tc++) {
          dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
        }
      }
#endif
    }
    /* leftover columns of this tile, one column at a time */
    for (; ci < col; ci++) {
      float16_t *src_c = src_r + ci;
      float16_t *dst_c = dst_r + ci * C16NUM;
      for (size_t i = 0; i < C16NUM; i++) {
        dst_c[i] = src_c[i * col];
      }
    }
    src_r += C16NUM * col;
    dst_r += C16NUM * col;
  }
  /* leftover rows: each one fills the next row slot of the last, partial tile */
  for (; ri < row; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C16NUM] = src_r[i];
    }
    src_r += col;
    dst_r += 1;
  }
}

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_
#include <stdbool.h>
#include <string.h>
#include <float.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
// Repacks a row-major row x col fp16 matrix into tiles of C16NUM rows, column-major inside each tile.
void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
#ifdef __aarch64__
// Hand-written AArch64 fp16 matmul kernel. The prototype must match the assembly implementation,
// which operates on half-precision data and reads `stride` and `write_nhwc` from the stack in
// addition to the eight register arguments.
//   bias may be NULL; act_type 1 clamps at 0 from below, 2 clamps to [0, 6];
//   stride is the destination row pitch in elements, used when write_nhwc is true.
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      int depth, int row, int col, int stride, bool write_nhwc);
#endif
#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_MATMUL_H_

View File

@ -23,6 +23,7 @@
#define C4NUM 4
#define C8NUM 8
#define C16NUM 16
#define BLOCK 4
#define TILE_NUM 8