!20130 [MS][LITE][CPU] arm32 fp16 算子优化

Merge pull request !20130 from liuzhongkai/pack
This commit is contained in:
i-robot 2021-07-15 01:10:43 +00:00 committed by Gitee
commit f33c06abae
9 changed files with 1273 additions and 158 deletions

View File

@ -0,0 +1,943 @@
#ifdef ENABLE_ARM64
#include "nnacl/assembly_global.h"
.text
.align 5
// void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
//                         size_t depth, size_t row, size_t col, size_t stride, size_t writeMode)
//
// fp16 matmul micro-kernel: c = act(a * b + bias).
// Row tiles of 16, 8 or 4 rows are picked greedily from the remaining row count; columns are
// processed 8 at a time. Per depth step the kernel reads `tile` halfwords of a (one per row of
// the tile) and 8 halfwords of b (one per column of the 8-wide column block).
//
// Arguments (AAPCS64):
//   x0: a          (advanced by tile * depth halfwords after each row tile)
//   x1: b
//   x2: c          (start of the current row tile / column block)
//   x3: bias       (always loaded, 8 halfwords per column block)
//   x4: act_type   (1 = Relu, 3 = Relu6, anything else = no activation)
//   x5: depth
//   x6: row        (rows still to be produced)
//   x7: col
//   [stack + 0]: stride     -> x8, converted to bytes; row pitch of c
//   [stack + 8]: writeMode  -> x9, loaded but not used by this kernel
//
// Working registers:
//   x10: lhs cursor (also a scratch pointer in prologue/epilogue/Write7)
//   x11: store cursor into c          x12: bias cursor
//   x13: columns still to be written  x14: rhs cursor
//   x15: current row tile (16/8/4)    x16: col * 2     x17: depth * 2
//   x19: depth counter (also a scratch store pointer in Write3/5/6/7)
//   x21: scratch
//   v12: 0.0 (Relu lower bound)       v13: 6.0 (Relu6 upper bound)
//   v16-v31: accumulators, one 8-column vector per row of the tile
//
// Frame: 128 bytes stay allocated for the whole function so the saved registers are never below
// sp. v8-v13 and x19-x22 are saved there (v8-v15 low halves and x19-x28 are callee-saved).
asm_function MatmulBaseFp16Neon
    sub sp, sp, #128
    mov x10, sp
    st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x10], #64
    st1 {v12.8h, v13.8h}, [x10], #32
    stp x19, x20, [x10], #16
    stp x21, x22, [x10]

    ldr x8, [sp, #128]  // stride (9th argument, first stack slot of the caller)
    ldr x9, [sp, #136]  // writeMode (10th argument)
    add x8, x8, x8      // stride * sizeof(float16_t)
    add x16, x7, x7     // col * sizeof(float16_t)
    add x17, x5, x5     // depth * sizeof(float16_t)
    mov x11, x2

    dup v12.8h, wzr                 // 0.0 in every lane
    movi v13.8h, #0x46, lsl #8      // 0x4600 = 6.0 (fp16) in every lane

LoopRowStart:
    cmp x6, #16
    bge LoopRow16
    cmp x6, #8
    bge LoopRow8
    b LoopRow4

// ---------------------------------------------------------------------------------------------
// 16-row tile
// ---------------------------------------------------------------------------------------------
LoopRow16:
    mov x15, #16
    mov x14, x1   // reload rhs ptr
    mov x13, x7   // reload rhs col
    mov x12, x3   // reload bias

    LoopCol16:
        mov x11, x2
        mov x10, x0   // reload lhs ptr
        mov x19, x5   // reload depth
        // every accumulator row starts from the bias of this column block
        ld1 {v16.8h}, [x12], #16
        mov v17.8h, v16.8h
        mov v18.8h, v16.8h
        mov v19.8h, v16.8h
        mov v20.8h, v16.8h
        mov v21.8h, v16.8h
        mov v22.8h, v16.8h
        mov v23.8h, v16.8h
        mov v24.8h, v16.8h
        mov v25.8h, v16.8h
        mov v26.8h, v16.8h
        mov v27.8h, v16.8h
        mov v28.8h, v16.8h
        mov v29.8h, v16.8h
        mov v30.8h, v16.8h
        mov v31.8h, v16.8h

        cmp x19, #4
        blt LoopDepth16One

        // four depth steps per iteration
        LoopDepth16:
            ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
            ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
            fmla v16.8h, v8.8h, v0.h[0]
            fmla v17.8h, v8.8h, v0.h[1]
            fmla v18.8h, v8.8h, v0.h[2]
            fmla v19.8h, v8.8h, v0.h[3]
            fmla v20.8h, v8.8h, v0.h[4]
            fmla v21.8h, v8.8h, v0.h[5]
            fmla v22.8h, v8.8h, v0.h[6]
            fmla v23.8h, v8.8h, v0.h[7]
            fmla v24.8h, v8.8h, v1.h[0]
            fmla v25.8h, v8.8h, v1.h[1]
            fmla v26.8h, v8.8h, v1.h[2]
            fmla v27.8h, v8.8h, v1.h[3]
            fmla v28.8h, v8.8h, v1.h[4]
            fmla v29.8h, v8.8h, v1.h[5]
            fmla v30.8h, v8.8h, v1.h[6]
            fmla v31.8h, v8.8h, v1.h[7]
            ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64
            fmla v16.8h, v9.8h, v2.h[0]
            fmla v17.8h, v9.8h, v2.h[1]
            fmla v18.8h, v9.8h, v2.h[2]
            fmla v19.8h, v9.8h, v2.h[3]
            fmla v20.8h, v9.8h, v2.h[4]
            fmla v21.8h, v9.8h, v2.h[5]
            fmla v22.8h, v9.8h, v2.h[6]
            fmla v23.8h, v9.8h, v2.h[7]
            fmla v24.8h, v9.8h, v3.h[0]
            fmla v25.8h, v9.8h, v3.h[1]
            fmla v26.8h, v9.8h, v3.h[2]
            fmla v27.8h, v9.8h, v3.h[3]
            fmla v28.8h, v9.8h, v3.h[4]
            fmla v29.8h, v9.8h, v3.h[5]
            fmla v30.8h, v9.8h, v3.h[6]
            fmla v31.8h, v9.8h, v3.h[7]
            fmla v16.8h, v10.8h, v4.h[0]
            fmla v17.8h, v10.8h, v4.h[1]
            fmla v18.8h, v10.8h, v4.h[2]
            fmla v19.8h, v10.8h, v4.h[3]
            fmla v20.8h, v10.8h, v4.h[4]
            fmla v21.8h, v10.8h, v4.h[5]
            fmla v22.8h, v10.8h, v4.h[6]
            fmla v23.8h, v10.8h, v4.h[7]
            fmla v24.8h, v10.8h, v5.h[0]
            fmla v25.8h, v10.8h, v5.h[1]
            fmla v26.8h, v10.8h, v5.h[2]
            fmla v27.8h, v10.8h, v5.h[3]
            fmla v28.8h, v10.8h, v5.h[4]
            fmla v29.8h, v10.8h, v5.h[5]
            fmla v30.8h, v10.8h, v5.h[6]
            fmla v31.8h, v10.8h, v5.h[7]
            fmla v16.8h, v11.8h, v6.h[0]
            fmla v17.8h, v11.8h, v6.h[1]
            fmla v18.8h, v11.8h, v6.h[2]
            fmla v19.8h, v11.8h, v6.h[3]
            fmla v20.8h, v11.8h, v6.h[4]
            fmla v21.8h, v11.8h, v6.h[5]
            fmla v22.8h, v11.8h, v6.h[6]
            fmla v23.8h, v11.8h, v6.h[7]
            fmla v24.8h, v11.8h, v7.h[0]
            fmla v25.8h, v11.8h, v7.h[1]
            fmla v26.8h, v11.8h, v7.h[2]
            fmla v27.8h, v11.8h, v7.h[3]
            fmla v28.8h, v11.8h, v7.h[4]
            fmla v29.8h, v11.8h, v7.h[5]
            fmla v30.8h, v11.8h, v7.h[6]
            fmla v31.8h, v11.8h, v7.h[7]

            subs x19, x19, #4
            beq Activation16
            cmp x19, #4
            bge LoopDepth16

        // one depth step per iteration (depth % 4 remainder, or depth < 4)
        LoopDepth16One:
            ld1 {v0.8h, v1.8h}, [x10], #32
            ld1 {v2.8h}, [x14], #16
            fmla v16.8h, v2.8h, v0.h[0]
            fmla v17.8h, v2.8h, v0.h[1]
            fmla v18.8h, v2.8h, v0.h[2]
            fmla v19.8h, v2.8h, v0.h[3]
            fmla v20.8h, v2.8h, v0.h[4]
            fmla v21.8h, v2.8h, v0.h[5]
            fmla v22.8h, v2.8h, v0.h[6]
            fmla v23.8h, v2.8h, v0.h[7]
            fmla v24.8h, v2.8h, v1.h[0]
            fmla v25.8h, v2.8h, v1.h[1]
            fmla v26.8h, v2.8h, v1.h[2]
            fmla v27.8h, v2.8h, v1.h[3]
            fmla v28.8h, v2.8h, v1.h[4]
            fmla v29.8h, v2.8h, v1.h[5]
            fmla v30.8h, v2.8h, v1.h[6]
            fmla v31.8h, v2.8h, v1.h[7]

            subs x19, x19, #1
            bgt LoopDepth16One

        Activation16:
            cmp x4, #3
            beq Relu616
            cmp x4, #1
            beq Relu16
            b Write16
            // Relu6 clamps to 6.0 and then falls through into the Relu lower clamp
            Relu616:
                fmin v16.8h, v16.8h, v13.8h
                fmin v17.8h, v17.8h, v13.8h
                fmin v18.8h, v18.8h, v13.8h
                fmin v19.8h, v19.8h, v13.8h
                fmin v20.8h, v20.8h, v13.8h
                fmin v21.8h, v21.8h, v13.8h
                fmin v22.8h, v22.8h, v13.8h
                fmin v23.8h, v23.8h, v13.8h
                fmin v24.8h, v24.8h, v13.8h
                fmin v25.8h, v25.8h, v13.8h
                fmin v26.8h, v26.8h, v13.8h
                fmin v27.8h, v27.8h, v13.8h
                fmin v28.8h, v28.8h, v13.8h
                fmin v29.8h, v29.8h, v13.8h
                fmin v30.8h, v30.8h, v13.8h
                fmin v31.8h, v31.8h, v13.8h

            Relu16:
                fmax v16.8h, v16.8h, v12.8h
                fmax v17.8h, v17.8h, v12.8h
                fmax v18.8h, v18.8h, v12.8h
                fmax v19.8h, v19.8h, v12.8h
                fmax v20.8h, v20.8h, v12.8h
                fmax v21.8h, v21.8h, v12.8h
                fmax v22.8h, v22.8h, v12.8h
                fmax v23.8h, v23.8h, v12.8h
                fmax v24.8h, v24.8h, v12.8h
                fmax v25.8h, v25.8h, v12.8h
                fmax v26.8h, v26.8h, v12.8h
                fmax v27.8h, v27.8h, v12.8h
                fmax v28.8h, v28.8h, v12.8h
                fmax v29.8h, v29.8h, v12.8h
                fmax v30.8h, v30.8h, v12.8h
                fmax v31.8h, v31.8h, v12.8h

        Write16:
            cmp x13, #8
            bge Write16x8
            b Write

        // full 16x8 block: no per-row / per-column checks needed
        Write16x8:
            add x2, x2, #16
            st1 {v16.8h}, [x11], x8
            st1 {v17.8h}, [x11], x8
            st1 {v18.8h}, [x11], x8
            st1 {v19.8h}, [x11], x8
            st1 {v20.8h}, [x11], x8
            st1 {v21.8h}, [x11], x8
            st1 {v22.8h}, [x11], x8
            st1 {v23.8h}, [x11], x8
            st1 {v24.8h}, [x11], x8
            st1 {v25.8h}, [x11], x8
            st1 {v26.8h}, [x11], x8
            st1 {v27.8h}, [x11], x8
            st1 {v28.8h}, [x11], x8
            st1 {v29.8h}, [x11], x8
            st1 {v30.8h}, [x11], x8
            st1 {v31.8h}, [x11], x8
            b WriteEnd

// ---------------------------------------------------------------------------------------------
// 8-row tile
// ---------------------------------------------------------------------------------------------
LoopRow8:
    mov x15, #8
    mov x14, x1   // reload rhs ptr
    mov x13, x7   // reload rhs col
    mov x12, x3   // reload bias

    LoopCol8:
        mov x11, x2
        mov x10, x0   // reload lhs ptr
        mov x19, x5   // reload depth
        ld1 {v16.8h}, [x12], #16
        mov v17.8h, v16.8h
        mov v18.8h, v16.8h
        mov v19.8h, v16.8h
        mov v20.8h, v16.8h
        mov v21.8h, v16.8h
        mov v22.8h, v16.8h
        mov v23.8h, v16.8h

        cmp x19, #4
        blt LoopDepth8One

        LoopDepth8:
            ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
            ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
            fmla v16.8h, v8.8h, v0.h[0]
            fmla v17.8h, v8.8h, v0.h[1]
            fmla v18.8h, v8.8h, v0.h[2]
            fmla v19.8h, v8.8h, v0.h[3]
            fmla v20.8h, v8.8h, v0.h[4]
            fmla v21.8h, v8.8h, v0.h[5]
            fmla v22.8h, v8.8h, v0.h[6]
            fmla v23.8h, v8.8h, v0.h[7]
            fmla v16.8h, v9.8h, v1.h[0]
            fmla v17.8h, v9.8h, v1.h[1]
            fmla v18.8h, v9.8h, v1.h[2]
            fmla v19.8h, v9.8h, v1.h[3]
            fmla v20.8h, v9.8h, v1.h[4]
            fmla v21.8h, v9.8h, v1.h[5]
            fmla v22.8h, v9.8h, v1.h[6]
            fmla v23.8h, v9.8h, v1.h[7]
            fmla v16.8h, v10.8h, v2.h[0]
            fmla v17.8h, v10.8h, v2.h[1]
            fmla v18.8h, v10.8h, v2.h[2]
            fmla v19.8h, v10.8h, v2.h[3]
            fmla v20.8h, v10.8h, v2.h[4]
            fmla v21.8h, v10.8h, v2.h[5]
            fmla v22.8h, v10.8h, v2.h[6]
            fmla v23.8h, v10.8h, v2.h[7]
            fmla v16.8h, v11.8h, v3.h[0]
            fmla v17.8h, v11.8h, v3.h[1]
            fmla v18.8h, v11.8h, v3.h[2]
            fmla v19.8h, v11.8h, v3.h[3]
            fmla v20.8h, v11.8h, v3.h[4]
            fmla v21.8h, v11.8h, v3.h[5]
            fmla v22.8h, v11.8h, v3.h[6]
            fmla v23.8h, v11.8h, v3.h[7]

            subs x19, x19, #4
            beq Activation8
            cmp x19, #4
            bge LoopDepth8

        LoopDepth8One:
            ld1 {v0.8h}, [x10], #16
            ld1 {v2.8h}, [x14], #16
            fmla v16.8h, v2.8h, v0.h[0]
            fmla v17.8h, v2.8h, v0.h[1]
            fmla v18.8h, v2.8h, v0.h[2]
            fmla v19.8h, v2.8h, v0.h[3]
            fmla v20.8h, v2.8h, v0.h[4]
            fmla v21.8h, v2.8h, v0.h[5]
            fmla v22.8h, v2.8h, v0.h[6]
            fmla v23.8h, v2.8h, v0.h[7]

            subs x19, x19, #1
            bgt LoopDepth8One

        Activation8:
            cmp x4, #3
            beq Relu68
            cmp x4, #1
            beq Relu8
            b Write8_Row
            Relu68:
                fmin v16.8h, v16.8h, v13.8h
                fmin v17.8h, v17.8h, v13.8h
                fmin v18.8h, v18.8h, v13.8h
                fmin v19.8h, v19.8h, v13.8h
                fmin v20.8h, v20.8h, v13.8h
                fmin v21.8h, v21.8h, v13.8h
                fmin v22.8h, v22.8h, v13.8h
                fmin v23.8h, v23.8h, v13.8h

            Relu8:
                fmax v16.8h, v16.8h, v12.8h
                fmax v17.8h, v17.8h, v12.8h
                fmax v18.8h, v18.8h, v12.8h
                fmax v19.8h, v19.8h, v12.8h
                fmax v20.8h, v20.8h, v12.8h
                fmax v21.8h, v21.8h, v12.8h
                fmax v22.8h, v22.8h, v12.8h
                fmax v23.8h, v23.8h, v12.8h

        Write8_Row:
            cmp x13, #8   // at least 8 columns left?
            bge Write8x8
            b Write

        Write8x8:
            add x2, x2, #16
            st1 {v16.8h}, [x11], x8
            st1 {v17.8h}, [x11], x8
            st1 {v18.8h}, [x11], x8
            st1 {v19.8h}, [x11], x8
            st1 {v20.8h}, [x11], x8
            st1 {v21.8h}, [x11], x8
            st1 {v22.8h}, [x11], x8
            st1 {v23.8h}, [x11], x8
            b WriteEnd

// ---------------------------------------------------------------------------------------------
// 4-row tile (also used when fewer than 4 rows remain)
// ---------------------------------------------------------------------------------------------
LoopRow4:
    mov x15, #4
    mov x14, x1   // reload rhs ptr
    mov x13, x7   // reload rhs col
    mov x12, x3   // reload bias

    LoopCol4:
        mov x11, x2
        mov x10, x0   // reload lhs ptr
        mov x19, x5   // reload depth
        ld1 {v16.8h}, [x12], #16
        mov v17.8h, v16.8h
        mov v18.8h, v16.8h
        mov v19.8h, v16.8h

        cmp x19, #4
        blt LoopDepth4One

        LoopDepth4:
            ld1 {v0.8h, v1.8h}, [x10], #32
            ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
            fmla v16.8h, v8.8h, v0.h[0]
            fmla v17.8h, v8.8h, v0.h[1]
            fmla v18.8h, v8.8h, v0.h[2]
            fmla v19.8h, v8.8h, v0.h[3]
            fmla v16.8h, v9.8h, v0.h[4]
            fmla v17.8h, v9.8h, v0.h[5]
            fmla v18.8h, v9.8h, v0.h[6]
            fmla v19.8h, v9.8h, v0.h[7]
            fmla v16.8h, v10.8h, v1.h[0]
            fmla v17.8h, v10.8h, v1.h[1]
            fmla v18.8h, v10.8h, v1.h[2]
            fmla v19.8h, v10.8h, v1.h[3]
            fmla v16.8h, v11.8h, v1.h[4]
            fmla v17.8h, v11.8h, v1.h[5]
            fmla v18.8h, v11.8h, v1.h[6]
            fmla v19.8h, v11.8h, v1.h[7]

            subs x19, x19, #4
            beq Activation4
            cmp x19, #4
            bge LoopDepth4

        LoopDepth4One:
            ld1 {v0.4h}, [x10], #8
            ld1 {v2.8h}, [x14], #16
            fmla v16.8h, v2.8h, v0.h[0]
            fmla v17.8h, v2.8h, v0.h[1]
            fmla v18.8h, v2.8h, v0.h[2]
            fmla v19.8h, v2.8h, v0.h[3]

            subs x19, x19, #1
            bgt LoopDepth4One

        Activation4:
            cmp x4, #3
            beq Relu64
            cmp x4, #1
            beq Relu4
            b Write4_Row
            Relu64:
                fmin v16.8h, v16.8h, v13.8h
                fmin v17.8h, v17.8h, v13.8h
                fmin v18.8h, v18.8h, v13.8h
                fmin v19.8h, v19.8h, v13.8h

            Relu4:
                fmax v16.8h, v16.8h, v12.8h
                fmax v17.8h, v17.8h, v12.8h
                fmax v18.8h, v18.8h, v12.8h
                fmax v19.8h, v19.8h, v12.8h

        Write4_Row:
            cmp x6, #4    // a full 4-row tile?
            bge Write4x8
            b Write

        Write4x8:
            cmp x13, #8
            blt Write
            add x2, x2, #16
            st1 {v16.8h}, [x11], x8
            st1 {v17.8h}, [x11], x8
            st1 {v18.8h}, [x11], x8
            st1 {v19.8h}, [x11], x8
            b WriteEnd

// ---------------------------------------------------------------------------------------------
// Partial-block store, shared by all three row tiles.
// Dispatches on the number of columns left (x13: 1..7, otherwise 8) and stops after the row
// whose 1-based index equals the remaining row count x6. When x6 exceeds the current tile the
// chain keeps storing accumulators beyond the tile; those rows are inside the remaining x6 rows
// and are rewritten when the following row tile processes the same columns.
// ---------------------------------------------------------------------------------------------
Write:
    cmp x13, #1
    beq Write1
    cmp x13, #2
    beq Write2
    cmp x13, #3
    beq Write3
    cmp x13, #4
    beq Write4
    cmp x13, #5
    beq Write5
    cmp x13, #6
    beq Write6
    cmp x13, #7
    beq Write7
    b Write8

    // 1 column: one halfword per row
    Write1:
        add x2, x2, #2
        st1 {v16.h}[0], [x11], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.h}[0], [x11], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.h}[0], [x11], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.h}[0], [x11], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.h}[0], [x11], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.h}[0], [x11], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.h}[0], [x11], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.h}[0], [x11], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.h}[0], [x11], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.h}[0], [x11], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.h}[0], [x11], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.h}[0], [x11], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.h}[0], [x11], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.h}[0], [x11], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.h}[0], [x11], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.h}[0], [x11], x8
        b WriteEnd

    // 2 columns: one 32-bit lane per row
    Write2:
        add x2, x2, #4
        st1 {v16.s}[0], [x11], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.s}[0], [x11], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.s}[0], [x11], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.s}[0], [x11], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.s}[0], [x11], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.s}[0], [x11], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.s}[0], [x11], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.s}[0], [x11], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.s}[0], [x11], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.s}[0], [x11], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.s}[0], [x11], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.s}[0], [x11], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.s}[0], [x11], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.s}[0], [x11], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.s}[0], [x11], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.s}[0], [x11], x8
        b WriteEnd

    // 3 columns: 32-bit lane via x11 plus halfword 2 via x19 (= x11 + 4)
    Write3:
        add x2, x2, #6
        add x19, x11, #4
        st1 {v16.s}[0], [x11], x8
        st1 {v16.h}[2], [x19], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.s}[0], [x11], x8
        st1 {v17.h}[2], [x19], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.s}[0], [x11], x8
        st1 {v18.h}[2], [x19], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.s}[0], [x11], x8
        st1 {v19.h}[2], [x19], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.s}[0], [x11], x8
        st1 {v20.h}[2], [x19], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.s}[0], [x11], x8
        st1 {v21.h}[2], [x19], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.s}[0], [x11], x8
        st1 {v22.h}[2], [x19], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.s}[0], [x11], x8
        st1 {v23.h}[2], [x19], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.s}[0], [x11], x8
        st1 {v24.h}[2], [x19], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.s}[0], [x11], x8
        st1 {v25.h}[2], [x19], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.s}[0], [x11], x8
        st1 {v26.h}[2], [x19], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.s}[0], [x11], x8
        st1 {v27.h}[2], [x19], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.s}[0], [x11], x8
        st1 {v28.h}[2], [x19], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.s}[0], [x11], x8
        st1 {v29.h}[2], [x19], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.s}[0], [x11], x8
        st1 {v30.h}[2], [x19], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.s}[0], [x11], x8
        st1 {v31.h}[2], [x19]
        b WriteEnd

    // 4 columns: low 64 bits per row
    Write4:
        add x2, x2, #8
        st1 {v16.4h}, [x11], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.4h}, [x11], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.4h}, [x11], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.4h}, [x11], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.4h}, [x11], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.4h}, [x11], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.4h}, [x11], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.4h}, [x11], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.4h}, [x11], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.4h}, [x11], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.4h}, [x11], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.4h}, [x11], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.4h}, [x11], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.4h}, [x11], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.4h}, [x11], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.4h}, [x11], x8
        b WriteEnd

    // 5 columns: low 64 bits via x11 plus halfword 4 via x19 (= x11 + 8)
    Write5:
        add x2, x2, #10
        add x19, x11, #8
        st1 {v16.4h}, [x11], x8
        st1 {v16.h}[4], [x19], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.4h}, [x11], x8
        st1 {v17.h}[4], [x19], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.4h}, [x11], x8
        st1 {v18.h}[4], [x19], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.4h}, [x11], x8
        st1 {v19.h}[4], [x19], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.4h}, [x11], x8
        st1 {v20.h}[4], [x19], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.4h}, [x11], x8
        st1 {v21.h}[4], [x19], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.4h}, [x11], x8
        st1 {v22.h}[4], [x19], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.4h}, [x11], x8
        st1 {v23.h}[4], [x19], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.4h}, [x11], x8
        st1 {v24.h}[4], [x19], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.4h}, [x11], x8
        st1 {v25.h}[4], [x19], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.4h}, [x11], x8
        st1 {v26.h}[4], [x19], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.4h}, [x11], x8
        st1 {v27.h}[4], [x19], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.4h}, [x11], x8
        st1 {v28.h}[4], [x19], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.4h}, [x11], x8
        st1 {v29.h}[4], [x19], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.4h}, [x11], x8
        st1 {v30.h}[4], [x19], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.4h}, [x11], x8
        st1 {v31.h}[4], [x19]
        b WriteEnd

    // 6 columns: low 64 bits via x11 plus 32-bit lane 2 via x19 (= x11 + 8)
    Write6:
        add x2, x2, #12
        add x19, x11, #8
        st1 {v16.4h}, [x11], x8
        st1 {v16.s}[2], [x19], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.4h}, [x11], x8
        st1 {v17.s}[2], [x19], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.4h}, [x11], x8
        st1 {v18.s}[2], [x19], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.4h}, [x11], x8
        st1 {v19.s}[2], [x19], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.4h}, [x11], x8
        st1 {v20.s}[2], [x19], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.4h}, [x11], x8
        st1 {v21.s}[2], [x19], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.4h}, [x11], x8
        st1 {v22.s}[2], [x19], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.4h}, [x11], x8
        st1 {v23.s}[2], [x19], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.4h}, [x11], x8
        st1 {v24.s}[2], [x19], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.4h}, [x11], x8
        st1 {v25.s}[2], [x19], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.4h}, [x11], x8
        st1 {v26.s}[2], [x19], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.4h}, [x11], x8
        st1 {v27.s}[2], [x19], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.4h}, [x11], x8
        st1 {v28.s}[2], [x19], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.4h}, [x11], x8
        st1 {v29.s}[2], [x19], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.4h}, [x11], x8
        st1 {v30.s}[2], [x19], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.4h}, [x11], x8
        st1 {v31.s}[2], [x19]
        b WriteEnd

    // 7 columns: low 64 bits via x11, 32-bit lane 2 via x19 (= x11 + 8), halfword 6 via x10
    // (= x11 + 12). x10 is free here because LoopCol* reloads the lhs cursor.
    Write7:
        add x2, x2, #14
        add x19, x11, #8
        add x10, x11, #12
        st1 {v16.4h}, [x11], x8
        st1 {v16.s}[2], [x19], x8
        st1 {v16.h}[6], [x10], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.4h}, [x11], x8
        st1 {v17.s}[2], [x19], x8
        st1 {v17.h}[6], [x10], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.4h}, [x11], x8
        st1 {v18.s}[2], [x19], x8
        st1 {v18.h}[6], [x10], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.4h}, [x11], x8
        st1 {v19.s}[2], [x19], x8
        st1 {v19.h}[6], [x10], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.4h}, [x11], x8
        st1 {v20.s}[2], [x19], x8
        st1 {v20.h}[6], [x10], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.4h}, [x11], x8
        st1 {v21.s}[2], [x19], x8
        st1 {v21.h}[6], [x10], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.4h}, [x11], x8
        st1 {v22.s}[2], [x19], x8
        st1 {v22.h}[6], [x10], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.4h}, [x11], x8
        st1 {v23.s}[2], [x19], x8
        st1 {v23.h}[6], [x10], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.4h}, [x11], x8
        st1 {v24.s}[2], [x19], x8
        st1 {v24.h}[6], [x10], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.4h}, [x11], x8
        st1 {v25.s}[2], [x19], x8
        st1 {v25.h}[6], [x10], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.4h}, [x11], x8
        st1 {v26.s}[2], [x19], x8
        st1 {v26.h}[6], [x10], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.4h}, [x11], x8
        st1 {v27.s}[2], [x19], x8
        st1 {v27.h}[6], [x10], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.4h}, [x11], x8
        st1 {v28.s}[2], [x19], x8
        st1 {v28.h}[6], [x10], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.4h}, [x11], x8
        st1 {v29.s}[2], [x19], x8
        st1 {v29.h}[6], [x10], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.4h}, [x11], x8
        st1 {v30.s}[2], [x19], x8
        st1 {v30.h}[6], [x10], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.4h}, [x11], x8
        st1 {v31.s}[2], [x19]
        st1 {v31.h}[6], [x10]
        b WriteEnd

    // 8 columns but fewer rows than a full 4-row tile: full vectors with a per-row check
    Write8:
        add x2, x2, #16
        st1 {v16.8h}, [x11], x8
        cmp x6, #1
        beq WriteEnd
        st1 {v17.8h}, [x11], x8
        cmp x6, #2
        beq WriteEnd
        st1 {v18.8h}, [x11], x8
        cmp x6, #3
        beq WriteEnd
        st1 {v19.8h}, [x11], x8
        cmp x6, #4
        beq WriteEnd
        st1 {v20.8h}, [x11], x8
        cmp x6, #5
        beq WriteEnd
        st1 {v21.8h}, [x11], x8
        cmp x6, #6
        beq WriteEnd
        st1 {v22.8h}, [x11], x8
        cmp x6, #7
        beq WriteEnd
        st1 {v23.8h}, [x11], x8
        cmp x6, #8
        beq WriteEnd
        st1 {v24.8h}, [x11], x8
        cmp x6, #9
        beq WriteEnd
        st1 {v25.8h}, [x11], x8
        cmp x6, #10
        beq WriteEnd
        st1 {v26.8h}, [x11], x8
        cmp x6, #11
        beq WriteEnd
        st1 {v27.8h}, [x11], x8
        cmp x6, #12
        beq WriteEnd
        st1 {v28.8h}, [x11], x8
        cmp x6, #13
        beq WriteEnd
        st1 {v29.8h}, [x11], x8
        cmp x6, #14
        beq WriteEnd
        st1 {v30.8h}, [x11], x8
        cmp x6, #15
        beq WriteEnd
        st1 {v31.8h}, [x11], x8

WriteEnd:
    subs x13, x13, #8   // rhs col - 8
    ble LoopColEnd
    // more columns: re-enter the column loop of the current row tile
    cmp x6, #16
    bge LoopCol16
    cmp x6, #8
    bge LoopCol8
    b LoopCol4

LoopColEnd:
    sub x2, x2, x16     // dst - col * 2: back to the first column
    mul x21, x8, x15    // row_block * stride * 2
    add x2, x2, x21     // dst: first row of the next tile
    subs x6, x6, x15    // remaining rows; flags consumed by the bgt below
    mul x15, x15, x17   // row_block * depth * 2 (mul/add leave the flags untouched)
    add x0, x0, x15     // lhs: next row tile
    bgt LoopRowStart

    // restore callee-saved registers and release the frame
    mov x10, sp
    ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x10], #64
    ld1 {v12.8h, v13.8h}, [x10], #32
    ldp x19, x20, [x10], #16
    ldp x21, x22, [x10]
    add sp, sp, #128
    ret
#endif

View File

@ -675,152 +675,118 @@ LoopRow:
b WriteEnd
Write2:
add x2, x2, #4
add x19, x11, #2
st1 {v16.h}[0], [x11], x8
st1 {v16.h}[1], [x19], x8
st1 {v16.s}[0], [x11], x8
cmp x6, #1
beq WriteEnd
st1 {v17.h}[0], [x11], x8
st1 {v17.h}[1], [x19], x8
st1 {v17.s}[0], [x11], x8
cmp x6, #2
beq WriteEnd
st1 {v18.h}[0], [x11], x8
st1 {v18.h}[1], [x19], x8
st1 {v18.s}[0], [x11], x8
cmp x6, #3
beq WriteEnd
st1 {v19.h}[0], [x11], x8
st1 {v19.h}[1], [x19], x8
st1 {v19.s}[0], [x11], x8
cmp x6, #4
beq WriteEnd
st1 {v20.h}[0], [x11], x8
st1 {v20.h}[1], [x19], x8
st1 {v20.s}[0], [x11], x8
cmp x6, #5
beq WriteEnd
st1 {v21.h}[0], [x11], x8
st1 {v21.h}[1], [x19], x8
st1 {v21.s}[0], [x11], x8
cmp x6, #6
beq WriteEnd
st1 {v22.h}[0], [x11], x8
st1 {v22.h}[1], [x19], x8
st1 {v22.s}[0], [x11], x8
cmp x6, #7
beq WriteEnd
st1 {v23.h}[0], [x11], x8
st1 {v23.h}[1], [x19], x8
st1 {v23.s}[0], [x11], x8
cmp x6, #8
beq WriteEnd
st1 {v24.h}[0], [x11], x8
st1 {v24.h}[1], [x19], x8
st1 {v24.s}[0], [x11], x8
cmp x6, #9
beq WriteEnd
st1 {v25.h}[0], [x11], x8
st1 {v25.h}[1], [x19], x8
st1 {v25.s}[0], [x11], x8
cmp x6, #10
beq WriteEnd
st1 {v26.h}[0], [x11], x8
st1 {v26.h}[1], [x19], x8
st1 {v26.s}[0], [x11], x8
cmp x6, #11
beq WriteEnd
st1 {v27.h}[0], [x11], x8
st1 {v27.h}[1], [x19], x8
st1 {v27.s}[0], [x11], x8
cmp x6, #12
beq WriteEnd
st1 {v28.h}[0], [x11], x8
st1 {v28.h}[1], [x19], x8
st1 {v28.s}[0], [x11], x8
cmp x6, #13
beq WriteEnd
st1 {v29.h}[0], [x11], x8
st1 {v29.h}[1], [x19], x8
st1 {v29.s}[0], [x11], x8
cmp x6, #14
beq WriteEnd
st1 {v30.h}[0], [x11], x8
st1 {v30.h}[1], [x19], x8
st1 {v30.s}[0], [x11], x8
cmp x6, #15
beq WriteEnd
st1 {v31.h}[0], [x11], x8
st1 {v31.h}[1], [x19]
st1 {v31.s}[0], [x11], x8
add x11, x11, #4
b WriteEnd
Write3:
add x2, x2, #6
add x19, x11, #4
add x20, x11, #2
st1 {v16.h}[0], [x11], x8
st1 {v16.h}[1], [x20], x8
st1 {v16.s}[0], [x11], x8
st1 {v16.h}[2], [x19], x8
cmp x6, #1
beq WriteEnd
st1 {v17.h}[0], [x11], x8
st1 {v17.h}[1], [x20], x8
st1 {v17.s}[0], [x11], x8
st1 {v17.h}[2], [x19], x8
cmp x6, #2
beq WriteEnd
st1 {v18.h}[0], [x11], x8
st1 {v18.h}[1], [x20], x8
st1 {v18.s}[0], [x11], x8
st1 {v18.h}[2], [x19], x8
cmp x6, #3
beq WriteEnd
st1 {v19.h}[0], [x11], x8
st1 {v19.h}[1], [x20], x8
st1 {v19.s}[0], [x11], x8
st1 {v19.h}[2], [x19], x8
cmp x6, #4
beq WriteEnd
st1 {v20.h}[0], [x11], x8
st1 {v20.h}[1], [x20], x8
st1 {v20.s}[0], [x11], x8
st1 {v20.h}[2], [x19], x8
cmp x6, #5
beq WriteEnd
st1 {v21.h}[0], [x11], x8
st1 {v21.h}[1], [x20], x8
st1 {v21.s}[0], [x11], x8
st1 {v21.h}[2], [x19], x8
cmp x6, #6
beq WriteEnd
st1 {v22.h}[0], [x11], x8
st1 {v22.h}[1], [x20], x8
st1 {v22.s}[0], [x11], x8
st1 {v22.h}[2], [x19], x8
cmp x6, #7
beq WriteEnd
st1 {v23.h}[0], [x11], x8
st1 {v23.h}[1], [x20], x8
st1 {v23.s}[0], [x11], x8
st1 {v23.h}[2], [x19], x8
cmp x6, #8
beq WriteEnd
st1 {v24.h}[0], [x11], x8
st1 {v24.h}[1], [x20], x8
st1 {v24.s}[0], [x11], x8
st1 {v24.h}[2], [x19], x8
cmp x6, #9
beq WriteEnd
st1 {v25.h}[0], [x11], x8
st1 {v25.h}[1], [x20], x8
st1 {v25.s}[0], [x11], x8
st1 {v25.h}[2], [x19], x8
cmp x6, #10
beq WriteEnd
st1 {v26.h}[0], [x11], x8
st1 {v26.h}[1], [x20], x8
st1 {v26.s}[0], [x11], x8
st1 {v26.h}[2], [x19], x8
cmp x6, #11
beq WriteEnd
st1 {v27.h}[0], [x11], x8
st1 {v27.h}[1], [x20], x8
st1 {v27.s}[0], [x11], x8
st1 {v27.h}[2], [x19], x8
cmp x6, #12
beq WriteEnd
st1 {v28.h}[0], [x11], x8
st1 {v28.h}[1], [x20], x8
st1 {v28.s}[0], [x11], x8
st1 {v28.h}[2], [x19], x8
cmp x6, #13
beq WriteEnd
st1 {v29.h}[0], [x11], x8
st1 {v29.h}[1], [x20], x8
st1 {v29.s}[0], [x11], x8
st1 {v29.h}[2], [x19], x8
cmp x6, #14
beq WriteEnd
st1 {v30.h}[0], [x11], x8
st1 {v30.h}[1], [x20], x8
st1 {v30.s}[0], [x11], x8
st1 {v30.h}[2], [x19], x8
cmp x6, #15
beq WriteEnd
st1 {v31.h}[0], [x11], x8
st1 {v31.h}[1], [x20]
st1 {v31.s}[0], [x11], x8
st1 {v31.h}[2], [x19]
add x11, x11, #6
b WriteEnd
@ -944,185 +910,151 @@ LoopRow:
Write6:
add x2, x2, #12
add x19, x11, #8
add x20, x11, #10
st1 {v16.4h}, [x11], x8
st1 {v16.h}[4], [x19], x8
st1 {v16.h}[5], [x20], x8
st1 {v16.s}[2], [x19], x8
cmp x6, #1
beq WriteEnd
st1 {v17.4h}, [x11], x8
st1 {v17.h}[4], [x19], x8
st1 {v17.h}[5], [x20], x8
st1 {v17.s}[2], [x19], x8
cmp x6, #2
beq WriteEnd
st1 {v18.4h}, [x11], x8
st1 {v18.h}[4], [x19], x8
st1 {v18.h}[5], [x20], x8
st1 {v18.s}[2], [x19], x8
cmp x6, #3
beq WriteEnd
st1 {v19.4h}, [x11], x8
st1 {v19.h}[4], [x19], x8
st1 {v19.h}[5], [x20], x8
st1 {v19.s}[2], [x19], x8
cmp x6, #4
beq WriteEnd
st1 {v20.4h}, [x11], x8
st1 {v20.h}[4], [x19], x8
st1 {v20.h}[5], [x20], x8
st1 {v20.s}[2], [x19], x8
cmp x6, #5
beq WriteEnd
st1 {v21.4h}, [x11], x8
st1 {v21.h}[4], [x19], x8
st1 {v21.h}[5], [x20], x8
st1 {v21.s}[2], [x19], x8
cmp x6, #6
beq WriteEnd
st1 {v22.4h}, [x11], x8
st1 {v22.h}[4], [x19], x8
st1 {v22.h}[5], [x20], x8
st1 {v22.s}[2], [x19], x8
cmp x6, #7
beq WriteEnd
st1 {v23.4h}, [x11], x8
st1 {v23.h}[4], [x19], x8
st1 {v23.h}[5], [x20], x8
st1 {v23.s}[2], [x19], x8
cmp x6, #8
beq WriteEnd
st1 {v24.4h}, [x11], x8
st1 {v24.h}[4], [x19], x8
st1 {v24.h}[5], [x20], x8
st1 {v24.s}[2], [x19], x8
cmp x6, #9
beq WriteEnd
st1 {v25.4h}, [x11], x8
st1 {v25.h}[4], [x19], x8
st1 {v25.h}[5], [x20], x8
st1 {v25.s}[2], [x19], x8
cmp x6, #10
beq WriteEnd
st1 {v26.4h}, [x11], x8
st1 {v26.h}[4], [x19], x8
st1 {v26.h}[5], [x20], x8
st1 {v26.s}[2], [x19], x8
cmp x6, #11
beq WriteEnd
st1 {v27.4h}, [x11], x8
st1 {v27.h}[4], [x19], x8
st1 {v27.h}[5], [x20], x8
st1 {v27.s}[2], [x19], x8
cmp x6, #12
beq WriteEnd
st1 {v28.4h}, [x11], x8
st1 {v28.h}[4], [x19], x8
st1 {v28.h}[5], [x20], x8
st1 {v28.s}[2], [x19], x8
cmp x6, #13
beq WriteEnd
st1 {v29.4h}, [x11], x8
st1 {v29.h}[4], [x19], x8
st1 {v29.h}[5], [x20], x8
st1 {v29.s}[2], [x19], x8
cmp x6, #14
beq WriteEnd
st1 {v30.4h}, [x11], x8
st1 {v30.h}[4], [x19], x8
st1 {v30.h}[5], [x20], x8
st1 {v30.s}[2], [x19], x8
cmp x6, #15
beq WriteEnd
st1 {v31.4h}, [x11], x8
st1 {v31.h}[4], [x19]
st1 {v31.h}[5], [x20]
st1 {v31.s}[2], [x19]
add x11, x11, #12
b WriteEnd
Write7:
add x2, x2, #14
add x19, x11, #8
add x20, x11, #10
add x10, x11, #12
st1 {v16.4h}, [x11], x8
st1 {v16.h}[4], [x19], x8
st1 {v16.h}[5], [x20], x8
st1 {v16.s}[2], [x19], x8
st1 {v16.h}[6], [x10], x8
cmp x6, #1
beq WriteEnd
st1 {v17.4h}, [x11], x8
st1 {v17.h}[4], [x19], x8
st1 {v17.h}[5], [x20], x8
st1 {v17.s}[2], [x19], x8
st1 {v17.h}[6], [x10], x8
cmp x6, #2
beq WriteEnd
st1 {v18.4h}, [x11], x8
st1 {v18.h}[4], [x19], x8
st1 {v18.h}[5], [x20], x8
st1 {v18.s}[2], [x19], x8
st1 {v18.h}[6], [x10], x8
cmp x6, #3
beq WriteEnd
st1 {v19.4h}, [x11], x8
st1 {v19.h}[4], [x19], x8
st1 {v19.h}[5], [x20], x8
st1 {v19.s}[2], [x19], x8
st1 {v19.h}[6], [x10], x8
cmp x6, #4
beq WriteEnd
st1 {v20.4h}, [x11], x8
st1 {v20.h}[4], [x19], x8
st1 {v20.h}[5], [x20], x8
st1 {v20.s}[2], [x19], x8
st1 {v20.h}[6], [x10], x8
cmp x6, #5
beq WriteEnd
st1 {v21.4h}, [x11], x8
st1 {v21.h}[4], [x19], x8
st1 {v21.h}[5], [x20], x8
st1 {v21.s}[2], [x19], x8
st1 {v21.h}[6], [x10], x8
cmp x6, #6
beq WriteEnd
st1 {v22.4h}, [x11], x8
st1 {v22.h}[4], [x19], x8
st1 {v22.h}[5], [x20], x8
st1 {v22.s}[2], [x19], x8
st1 {v22.h}[6], [x10], x8
cmp x6, #7
beq WriteEnd
st1 {v23.4h}, [x11], x8
st1 {v23.h}[4], [x19], x8
st1 {v23.h}[5], [x20], x8
st1 {v23.s}[2], [x19], x8
st1 {v23.h}[6], [x10], x8
cmp x6, #8
beq WriteEnd
st1 {v24.4h}, [x11], x8
st1 {v24.h}[4], [x19], x8
st1 {v24.h}[5], [x20], x8
st1 {v24.s}[2], [x19], x8
st1 {v24.h}[6], [x10], x8
cmp x6, #9
beq WriteEnd
st1 {v25.4h}, [x11], x8
st1 {v25.h}[4], [x19], x8
st1 {v25.h}[5], [x20], x8
st1 {v25.s}[2], [x19], x8
st1 {v25.h}[6], [x10], x8
cmp x6, #10
beq WriteEnd
st1 {v26.4h}, [x11], x8
st1 {v26.h}[4], [x19], x8
st1 {v26.h}[5], [x20], x8
st1 {v26.s}[2], [x19], x8
st1 {v26.h}[6], [x10], x8
cmp x6, #11
beq WriteEnd
st1 {v27.4h}, [x11], x8
st1 {v27.h}[4], [x19], x8
st1 {v27.h}[5], [x20], x8
st1 {v27.s}[2], [x19], x8
st1 {v27.h}[6], [x10], x8
cmp x6, #12
beq WriteEnd
st1 {v28.4h}, [x11], x8
st1 {v28.h}[4], [x19], x8
st1 {v28.h}[5], [x20], x8
st1 {v28.s}[2], [x19], x8
st1 {v28.h}[6], [x10], x8
cmp x6, #13
beq WriteEnd
st1 {v29.4h}, [x11], x8
st1 {v29.h}[4], [x19], x8
st1 {v29.h}[5], [x20], x8
st1 {v29.s}[2], [x19], x8
st1 {v29.h}[6], [x10], x8
cmp x6, #14
beq WriteEnd
st1 {v30.4h}, [x11], x8
st1 {v30.h}[4], [x19], x8
st1 {v30.h}[5], [x20], x8
st1 {v30.s}[2], [x19], x8
st1 {v30.h}[6], [x10], x8
cmp x6, #15
beq WriteEnd
st1 {v31.4h}, [x11], x8
st1 {v31.h}[4], [x19]
st1 {v31.h}[5], [x20]
st1 {v31.s}[2], [x19]
st1 {v31.h}[6], [x10]
add x11, x11, #14
b WriteEnd

View File

@ -662,6 +662,75 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
return;
}
#ifdef ENABLE_ARM64
// Scalar fallback for the columns [ci, col) of one row tile: element (i, c) of the tile is
// written to dst_r[c * tile + i].
static void PackColTailNMajorFp16(const float16_t *src_r, float16_t *dst_r, size_t ci, int col, size_t tile) {
  while (ci < col) {
    const float16_t *src_c = src_r + ci;
    float16_t *dst_c = dst_r + ci * tile;
    for (size_t i = 0; i < tile; i++) {
      dst_c[i] = src_c[i * col];
    }
    ci++;
  }
}

// Repack a row-major (row x col) matrix into column-major row tiles, taking tiles greedily:
// as many 16-row tiles as fit, then 8-row tiles, then 4-row tiles. Rows left over after that
// are scattered with a column pitch of C4NUM. Within a tile, columns are handled eight at a
// time by the transpose helpers; the last col % 8 columns use the scalar fallback.
void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) {
  const float16_t *src_tile = src_ptr;
  float16_t *dst_tile = dst_ptr;
  size_t col8 = col / C8NUM * C8NUM;
  int ri = 0;

  // 16-row tiles
  while (ri <= row - C16NUM) {
    size_t ci = 0;
    while (ci < col8) {
      Row2Col16Block16(src_tile + ci, dst_tile + ci * C16NUM, col);
      ci += C8NUM;
    }
    PackColTailNMajorFp16(src_tile, dst_tile, ci, col, C16NUM);
    src_tile += C16NUM * col;
    dst_tile += C16NUM * col;
    ri += C16NUM;
  }

  // 8-row tiles
  while (ri <= row - C8NUM) {
    size_t ci = 0;
    while (ci < col8) {
      Transpose8x8ARM64Fp16(src_tile + ci, dst_tile + ci * C8NUM, col * sizeof(float16_t),
                            C8NUM * sizeof(float16_t));
      ci += C8NUM;
    }
    PackColTailNMajorFp16(src_tile, dst_tile, ci, col, C8NUM);
    src_tile += C8NUM * col;
    dst_tile += C8NUM * col;
    ri += C8NUM;
  }

  // 4-row tiles
  while (ri <= row - C4NUM) {
    size_t ci = 0;
    while (ci < col8) {
      Transpose4x8ARM64Fp16(src_tile + ci, dst_tile + ci * C4NUM, col * sizeof(float16_t),
                            C4NUM * sizeof(float16_t));
      ci += C8NUM;
    }
    PackColTailNMajorFp16(src_tile, dst_tile, ci, col, C4NUM);
    src_tile += C4NUM * col;
    dst_tile += C4NUM * col;
    ri += C4NUM;
  }

  // leftover rows (fewer than four): row k of the remainder lands in slot k of each
  // C4NUM-wide column group
  while (ri < row) {
    for (size_t i = 0; i < col; ++i) {
      dst_tile[i * C4NUM] = src_tile[i];
    }
    src_tile += col;
    dst_tile += 1;
    ri++;
  }
}
#endif
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row_up_12 = UP_ROUND(row, C12NUM);
size_t row12 = row / C12NUM * C12NUM;
@ -760,6 +829,32 @@ void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col,
}
}
#ifdef ENABLE_ARM64
// Repack a row-major (row x col) matrix into row-major column blocks, taking blocks greedily
// per row: 16-wide blocks first, then 8-wide, then 4-wide, and finally single elements that
// are placed inside a 4-wide block. A block that starts at column c begins at dst + c * row,
// and row r occupies `width` consecutive elements at offset r * width inside it.
void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col) {
  for (int r = 0; r < row; r++) {
    const float16_t *src_row = src + r * col;
    int c = 0;

    // 16-wide blocks, moved as two 8-lane vectors (both loads happen before the stores)
    while (c <= col - C16NUM) {
      float16_t *dst_blk = dst + c / C16NUM * C16NUM * row + r * C16NUM;
      MS_FLOAT16X8 lo = MS_LDQ_F16(src_row + c);
      MS_FLOAT16X8 hi = MS_LDQ_F16(src_row + c + C8NUM);
      MS_STQ_F16(dst_blk, lo);
      MS_STQ_F16(dst_blk + C8NUM, hi);
      c += C16NUM;
    }

    // 8-wide blocks
    while (c <= col - C8NUM) {
      MS_FLOAT16X8 vec8 = MS_LDQ_F16(src_row + c);
      MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, vec8);
      c += C8NUM;
    }

    // 4-wide blocks
    while (c <= col - C4NUM) {
      MS_FLOAT16X4 vec4 = MS_LD_F16(src_row + c);
      MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, vec4);
      c += C4NUM;
    }

    // remaining columns, one element at a time
    while (c < col) {
      dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src_row[c];
      ++c;
    }
  }
}
#endif
void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) {
int col_align = UP_ROUND(col, C16NUM);
for (int r = 0; r < row; r++) {
@ -834,3 +929,32 @@ void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, b
}
}
}
#if defined(ENABLE_DEBUG) && defined(ENABLE_ARM64)
// arm64 matmul: scalar reference for MatmulBaseFp16Neon (debug builds only).
//
// a:        lhs packed in row tiles; rows below row/16*16 use 16-row tiles, rows below row/8*8
//           use 8-row tiles, all other rows use 4-row tiles. Inside a tile the element of row i
//           at depth d is at offset d * tile + i.
// b:        rhs packed in 8-column blocks; element (d, t) is at
//           (t / 8) * depth * 8 + d * 8 + t % 8.
// c:        output, row-major with a row pitch of `stride` elements.
// bias:     one value per output column; must not be NULL (it is read unconditionally).
// act_type: 1 applies Relu, 3 applies Relu6, any other value applies no activation. These are
//           the same values MatmulBaseFp16Neon compares against.
// stride:   row pitch of c in elements, as used by MatmulBaseFp16Neon.
// write_nhwc: unused.
void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                    size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc) {
  (void)write_nhwc;
  const size_t r16 = row / C16NUM * C16NUM;
  const size_t r8 = row / C8NUM * C8NUM;
  const float16_t zero = (float16_t)0.0f;
  const float16_t six = (float16_t)6.0f;
  for (size_t r = 0; r < row; ++r) {
    size_t row_tile = C4NUM;
    if (r < r16) {
      row_tile = C16NUM;
    } else if (r < r8) {
      row_tile = C8NUM;
    }
    // first element of row r inside its packed tile
    const float16_t *a_row = a + r / row_tile * row_tile * depth + r % row_tile;
    for (size_t t = 0; t < col; ++t) {
      // first element of column t inside its packed 8-column block
      const float16_t *b_col = b + t / C8NUM * depth * C8NUM + t % C8NUM;
      float16_t res = bias[t];
      for (size_t d = 0; d < depth; ++d) {
        res += a_row[d * row_tile] * b_col[d * C8NUM];
      }
      // same order as the kernel: upper clamp for Relu6 first, then the lower clamp
      if (act_type == 3 && res > six) {
        res = six;
      }
      if ((act_type == 3 || act_type == 1) && res < zero) {
        res = zero;
      }
      c[r * stride + t] = res;
    }
  }
}
#endif

View File

@ -45,6 +45,10 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
int deep, int row, int col, int stride, int write_mode);
#ifdef ENABLE_ARM64
// Packing helpers for a row-major fp16 matrix of `row` x `col` elements.
// NOTE(review): the "N" presumably refers to the mixed 16/8/4-row tiling consumed by the arm64 matmul kernel -
// confirm against the definitions.
void RowMajor2ColNMajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col);
void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col);
void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, size_t stride, size_t out_type);
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
@ -53,6 +57,14 @@ void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, cons
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
// Hand-written arm64 assembly fp16 matmul.
//   a, b        lhs / rhs input matrices (NOTE(review): presumably in the packed layouts produced by the
//               packing helpers - confirm); c is the output.
//   bias        per-column bias; the 16-row path of the kernel loads it unconditionally, so pass a valid pointer.
//   depth/row/col  element counts.
//   stride      output row pitch in elements (the kernel multiplies it by sizeof(float16_t) itself).
//   write_nhwc  output write mode.
void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
#ifdef ENABLE_DEBUG
// Scalar C reference with the same signature as MatmulBaseFp16Neon; only available in ENABLE_DEBUG builds.
void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
#endif
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col);

View File

@ -640,6 +640,108 @@ inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_
#endif
#ifdef ENABLE_ARM64
// Transposes a 4-row x 8-column fp16 tile into eight 4-element rows.
//   src_ptr     first element of the source tile; 8 halves (16 bytes) are read from each of 4 rows.
//   src_stride  distance in BYTES between consecutive source rows (used directly as the ld1 post-increment).
//   dst_ptr     destination; four 16-byte stores are issued, each holding two consecutive output rows
//               (4 halves each) back to back.
//   dst_stride  consecutive stores are 2 * dst_stride BYTES apart, so a value of 8 yields a fully contiguous
//               64-byte result. NOTE(review): callers presumably pass 4 * sizeof(float16_t) - confirm.
// Uses x10/x11 and v0-v11, v24-v27 as scratch; all are declared as clobbers.
inline void Transpose4x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
  // Byte distance between two consecutive 16-byte stores.
  const size_t store_pitch = dst_stride + dst_stride;
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"

    // Load the four source rows a, b, c, d.
    "ld1 {v0.8h}, [x10], %[src_stride]\n"
    "ld1 {v1.8h}, [x10], %[src_stride]\n"
    "ld1 {v2.8h}, [x10], %[src_stride]\n"
    "ld1 {v3.8h}, [x10], %[src_stride]\n"

    // x10 is reused as the second store pointer, one store_pitch behind x11's next target.
    "add x10, x11, %[store_pitch]\n"

    // Low halves: columns 0-3 -> v24 = {col0, col1}, v25 = {col2, col3}.
    "zip1 v4.8h, v0.8h, v1.8h\n"
    "zip1 v5.8h, v2.8h, v3.8h\n"
    "trn1 v6.4s, v4.4s, v5.4s\n"
    "trn2 v7.4s, v4.4s, v5.4s\n"
    "trn1 v24.2d, v6.2d, v7.2d\n"
    "trn2 v25.2d, v6.2d, v7.2d\n"

    // High halves: columns 4-7 -> v26 = {col4, col5}, v27 = {col6, col7}.
    "zip2 v8.8h, v0.8h, v1.8h\n"
    "zip2 v9.8h, v2.8h, v3.8h\n"
    "trn1 v10.4s, v8.4s, v9.4s\n"
    "trn2 v11.4s, v8.4s, v9.4s\n"
    "trn1 v26.2d, v10.2d, v11.2d\n"
    "trn2 v27.2d, v10.2d, v11.2d\n"

    // Stores alternate between the two pointers, each advancing by two store pitches.
    "st1 {v24.8h}, [x11], %[two_store_pitch]\n"
    "st1 {v25.8h}, [x10], %[two_store_pitch]\n"
    "st1 {v26.8h}, [x11], %[two_store_pitch]\n"
    "st1 {v27.8h}, [x10], %[two_store_pitch]\n"
    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride),
      [ store_pitch ] "r"(store_pitch), [ two_store_pitch ] "r"(2 * store_pitch)
    // "memory": the asm reads *src_ptr and writes *dst_ptr behind the compiler's back.
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v24", "v25", "v26",
      "v27", "memory");
}
// Transposes an 8 x 8 fp16 tile.
//   src_ptr     first element of the source tile; 8 halves (16 bytes) are read from each of 8 rows.
//   src_stride  distance in BYTES between consecutive source rows (used directly as the ld1 post-increment).
//   dst_ptr     destination; output row i (8 halves) is stored at dst_ptr + i * dst_stride bytes.
//   dst_stride  distance in BYTES between consecutive output rows.
// Uses x10/x11 and v0-v31 as scratch; all are declared as clobbers.
inline void Transpose8x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"

    // Load the eight source rows.
    "ld1 {v0.8h}, [x10], %[src_stride]\n"
    "ld1 {v1.8h}, [x10], %[src_stride]\n"
    "ld1 {v2.8h}, [x10], %[src_stride]\n"
    "ld1 {v3.8h}, [x10], %[src_stride]\n"
    "ld1 {v4.8h}, [x10], %[src_stride]\n"
    "ld1 {v5.8h}, [x10], %[src_stride]\n"
    "ld1 {v6.8h}, [x10], %[src_stride]\n"
    "ld1 {v7.8h}, [x10], %[src_stride]\n"

    // x11 writes the even output rows, x10 (reused) the odd ones.
    "add x10, x11, %[dst_stride]\n"

    // Low halves of the inputs -> output rows 0-3 in v24-v27.
    "zip1 v16.8h, v0.8h, v1.8h\n"
    "zip1 v17.8h, v2.8h, v3.8h\n"
    "zip1 v18.8h, v4.8h, v5.8h\n"
    "zip1 v19.8h, v6.8h, v7.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v26.2d, v20.2d, v22.2d\n"
    "trn1 v25.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"

    // High halves of the inputs -> output rows 4-7 in v28-v31.
    "zip2 v8.8h, v0.8h, v1.8h\n"
    "zip2 v9.8h, v2.8h, v3.8h\n"
    "zip2 v10.8h, v4.8h, v5.8h\n"
    "zip2 v11.8h, v6.8h, v7.8h\n"

    "trn1 v12.4s, v8.4s, v9.4s\n"
    "trn2 v13.4s, v8.4s, v9.4s\n"
    "trn1 v14.4s, v10.4s, v11.4s\n"
    "trn2 v15.4s, v10.4s, v11.4s\n"

    "trn1 v28.2d, v12.2d, v14.2d\n"
    "trn2 v30.2d, v12.2d, v14.2d\n"
    "trn1 v29.2d, v13.2d, v15.2d\n"
    "trn2 v31.2d, v13.2d, v15.2d\n"

    // Each pointer advances two rows per store, so the rows come out in order 0..7.
    "st1 {v24.8h}, [x11], %[two_dst_stride]\n"
    "st1 {v25.8h}, [x10], %[two_dst_stride]\n"
    "st1 {v26.8h}, [x11], %[two_dst_stride]\n"
    "st1 {v27.8h}, [x10], %[two_dst_stride]\n"
    "st1 {v28.8h}, [x11], %[two_dst_stride]\n"
    "st1 {v29.8h}, [x10], %[two_dst_stride]\n"
    "st1 {v30.8h}, [x11], %[two_dst_stride]\n"
    "st1 {v31.8h}, [x10], %[two_dst_stride]\n"
    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride),
      [ dst_stride ] "r"(dst_stride), [ two_dst_stride ] "r"(2 * dst_stride)
    // "memory": the asm reads *src_ptr and writes *dst_ptr behind the compiler's back.
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31", "memory");
}
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
#ifdef ENABLE_DEBUG
for (int tr = 0; tr < C12NUM; tr++) {

View File

@ -78,6 +78,8 @@ void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_strid
#endif
#ifdef ENABLE_ARM64
// arm64 fp16 tile transposes (rows x columns of the source tile are given in the name).
// For the 4x8 and 8x8 variants src_stride and dst_stride are BYTE strides: they are used directly as
// post-index increments of the NEON loads/stores.
// NOTE(review): presumably the 12x8 and 16x8 variants follow the same convention - confirm against their
// definitions.
void Transpose4x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
void Transpose8x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride);
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
#endif

View File

@ -93,11 +93,14 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
#endif
// MS_* aliases for the fp16 NEON vector types and intrinsics.
// Names containing "Q" map to the 8-lane (float16x8_t) forms; MS_LD_F16 / MS_ST_F16 map to the 4-lane
// (float16x4_t) load and store.
#define MS_FLOAT16X8 float16x8_t
#define MS_FLOAT16X4 float16x4_t
#define MS_MOVQ_F16 vmovq_n_f16
#define MS_STQ_F16 vst1q_f16
#define MS_ST_F16 vst1_f16
#define MS_MINQ_F16 vminq_f16
#define MS_MAXQ_F16 vmaxq_f16
#define MS_LDQ_F16 vld1q_f16
#define MS_LD_F16 vld1_f16
#define MS_ADDQ_F16 vaddq_f16
#define MS_SUBQ_F16 vsubq_f16
#define MS_MULQ_F16 vmulq_f16

View File

@ -70,26 +70,18 @@ void MatmulBaseFP16CPUKernel::InitParameter() {
}
int MatmulBaseFP16CPUKernel::InitBias() {
if (in_tensors_.size() == 3) {
auto bias_tensor = in_tensors_[2];
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C8NUM);
if (bias_ptr_ == nullptr) {
if (params_->col_ != 0 && bias_ptr_ == nullptr) {
int max_bias_data = UP_ROUND(params_->col_, C8NUM);
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(max_bias_data * sizeof(float16_t)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
}
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
if (bias_tensor->data_type() == kNumberTypeFloat32) {
MS_LOG(ERROR) << "Matmul fp16 only support fp16 weight";
return RET_ERROR;
} else if (bias_tensor->data_type() == kNumberTypeFloat16) {
MS_ASSERT(bias_tensor->data_c() != nullptr);
if (in_tensors_.size() == 3) {
auto bias_tensor = in_tensors_[2];
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float16_t));
} else {
MS_LOG(ERROR) << "Unsupported bias data type : " << bias_tensor->data_type();
return RET_ERROR;
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
}
}
return RET_OK;
@ -181,13 +173,13 @@ void MatmulBaseFP16CPUKernel::InitMatrixA(void *src_ptr) {
float16_t *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_;
if (params_->a_transpose_) {
#ifdef ENABLE_ARM64
RowMajor2Row16MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32);
RowMajor2RowNMajorFp16((const float16_t *)src, dst, params_->deep_, params_->row_);
#else
RowMajor2Row12MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32);
#endif
} else {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32);
RowMajor2ColNMajorFp16((const float16_t *)src, dst, params_->row_, params_->deep_);
#else
RowMajor2Col12MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32);
#endif
@ -281,7 +273,7 @@ int MatmulBaseFP16CPUKernel::RunImpl(int task_id) {
return RET_OK;
}
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + thread_stride_ * task_id;
auto bias = bias_ptr_ + thread_stride_ * task_id;
auto b = batch_b_ptr_ + task_id * thread_stride_ * params_->deep_;
auto c = batch_c_ptr_ + task_id * thread_stride_;
@ -292,8 +284,13 @@ int MatmulBaseFP16CPUKernel::RunImpl(int task_id) {
MatVecMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif
} else {
#ifdef ENABLE_ARM64
MatmulBaseFp16Neon(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc,
params_->col_, OutType_Nhwc);
#else
MatMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_,
OutType_Nhwc);
#endif
}
return RET_OK;
}

View File

@ -55,7 +55,7 @@ void MatmulFP16CPUKernel::InitBShape() {
int MatmulFP16CPUKernel::Init() {
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
row_tile_ = C4NUM;
#else
row_tile_ = C12NUM;
#endif