forked from mindspore-Ecosystem/mindspore
!20130 [MS][LITE][CPU] arm32 fp16 算子优化
Merge pull request !20130 from liuzhongkai/pack
This commit is contained in:
commit
f33c06abae
|
@ -0,0 +1,943 @@
|
|||
#ifdef ENABLE_ARM64
|
||||
#include "nnacl/assembly_global.h"
|
||||
|
||||
.text
|
||||
.align 5
|
||||
|
||||
// void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
// int depth, int row, int col, size_t stride, size_t writeMode)
|
||||
// x0: a
|
||||
// x1: b
|
||||
// x2: c
|
||||
// x3: bias
|
||||
// x4: act_type
|
||||
// x5: depth
|
||||
// x6: row
|
||||
// x7: col
|
||||
// x8: stride
|
||||
// x9: writeMode
|
||||
|
||||
asm_function MatmulBaseFp16Neon
|
||||
sub sp, sp, #96
|
||||
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
|
||||
ldr x8, [sp]
|
||||
ldr x9, [sp, #8] // act
|
||||
add x8, x8, x8 // stride * sizeof(float16_t)
|
||||
|
||||
add x16, x7, x7 // col * sizeof(float16_t)
|
||||
add x17, x5, x5 // depth * zieof(float16_t)
|
||||
mov x11, x2
|
||||
dup v12.8h, wzr
|
||||
movi v13.8h, #0x46, lsl #8
|
||||
LoopRowStart:
|
||||
cmp x6, #16
|
||||
bge LoopRow16
|
||||
cmp x6, #8
|
||||
bge LoopRow8
|
||||
b LoopRow4
|
||||
|
||||
LoopRow16:
|
||||
mov x15, #16
|
||||
mov x14, x1 // reload rhs ptr
|
||||
mov x13, x7 // reload rhs col
|
||||
mov x12, x3 // reload bias
|
||||
|
||||
LoopCol16:
|
||||
mov x11, x2
|
||||
mov x10, x0 // reload lhs ptr
|
||||
mov x19, x5 // reload depth
|
||||
|
||||
ld1 {v16.8h}, [x12], #16
|
||||
mov v17.8h, v16.8h
|
||||
mov v18.8h, v16.8h
|
||||
mov v19.8h, v16.8h
|
||||
mov v20.8h, v16.8h
|
||||
mov v21.8h, v16.8h
|
||||
mov v22.8h, v16.8h
|
||||
mov v23.8h, v16.8h
|
||||
mov v24.8h, v16.8h
|
||||
mov v25.8h, v16.8h
|
||||
mov v26.8h, v16.8h
|
||||
mov v27.8h, v16.8h
|
||||
mov v28.8h, v16.8h
|
||||
mov v29.8h, v16.8h
|
||||
mov v30.8h, v16.8h
|
||||
mov v31.8h, v16.8h
|
||||
|
||||
cmp x19, #4
|
||||
blt LoopDepth16One
|
||||
|
||||
LoopDepth16:
|
||||
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
|
||||
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
|
||||
fmla v16.8h, v8.8h, v0.h[0]
|
||||
fmla v17.8h, v8.8h, v0.h[1]
|
||||
fmla v18.8h, v8.8h, v0.h[2]
|
||||
fmla v19.8h, v8.8h, v0.h[3]
|
||||
fmla v20.8h, v8.8h, v0.h[4]
|
||||
fmla v21.8h, v8.8h, v0.h[5]
|
||||
fmla v22.8h, v8.8h, v0.h[6]
|
||||
fmla v23.8h, v8.8h, v0.h[7]
|
||||
fmla v24.8h, v8.8h, v1.h[0]
|
||||
fmla v25.8h, v8.8h, v1.h[1]
|
||||
fmla v26.8h, v8.8h, v1.h[2]
|
||||
fmla v27.8h, v8.8h, v1.h[3]
|
||||
fmla v28.8h, v8.8h, v1.h[4]
|
||||
fmla v29.8h, v8.8h, v1.h[5]
|
||||
fmla v30.8h, v8.8h, v1.h[6]
|
||||
fmla v31.8h, v8.8h, v1.h[7]
|
||||
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64
|
||||
fmla v16.8h, v9.8h, v2.h[0]
|
||||
fmla v17.8h, v9.8h, v2.h[1]
|
||||
fmla v18.8h, v9.8h, v2.h[2]
|
||||
fmla v19.8h, v9.8h, v2.h[3]
|
||||
fmla v20.8h, v9.8h, v2.h[4]
|
||||
fmla v21.8h, v9.8h, v2.h[5]
|
||||
fmla v22.8h, v9.8h, v2.h[6]
|
||||
fmla v23.8h, v9.8h, v2.h[7]
|
||||
fmla v24.8h, v9.8h, v3.h[0]
|
||||
fmla v25.8h, v9.8h, v3.h[1]
|
||||
fmla v26.8h, v9.8h, v3.h[2]
|
||||
fmla v27.8h, v9.8h, v3.h[3]
|
||||
fmla v28.8h, v9.8h, v3.h[4]
|
||||
fmla v29.8h, v9.8h, v3.h[5]
|
||||
fmla v30.8h, v9.8h, v3.h[6]
|
||||
fmla v31.8h, v9.8h, v3.h[7]
|
||||
fmla v16.8h, v10.8h, v4.h[0]
|
||||
fmla v17.8h, v10.8h, v4.h[1]
|
||||
fmla v18.8h, v10.8h, v4.h[2]
|
||||
fmla v19.8h, v10.8h, v4.h[3]
|
||||
fmla v20.8h, v10.8h, v4.h[4]
|
||||
fmla v21.8h, v10.8h, v4.h[5]
|
||||
fmla v22.8h, v10.8h, v4.h[6]
|
||||
fmla v23.8h, v10.8h, v4.h[7]
|
||||
fmla v24.8h, v10.8h, v5.h[0]
|
||||
fmla v25.8h, v10.8h, v5.h[1]
|
||||
fmla v26.8h, v10.8h, v5.h[2]
|
||||
fmla v27.8h, v10.8h, v5.h[3]
|
||||
fmla v28.8h, v10.8h, v5.h[4]
|
||||
fmla v29.8h, v10.8h, v5.h[5]
|
||||
fmla v30.8h, v10.8h, v5.h[6]
|
||||
fmla v31.8h, v10.8h, v5.h[7]
|
||||
fmla v16.8h, v11.8h, v6.h[0]
|
||||
fmla v17.8h, v11.8h, v6.h[1]
|
||||
fmla v18.8h, v11.8h, v6.h[2]
|
||||
fmla v19.8h, v11.8h, v6.h[3]
|
||||
fmla v20.8h, v11.8h, v6.h[4]
|
||||
fmla v21.8h, v11.8h, v6.h[5]
|
||||
fmla v22.8h, v11.8h, v6.h[6]
|
||||
fmla v23.8h, v11.8h, v6.h[7]
|
||||
fmla v24.8h, v11.8h, v7.h[0]
|
||||
fmla v25.8h, v11.8h, v7.h[1]
|
||||
fmla v26.8h, v11.8h, v7.h[2]
|
||||
fmla v27.8h, v11.8h, v7.h[3]
|
||||
fmla v28.8h, v11.8h, v7.h[4]
|
||||
fmla v29.8h, v11.8h, v7.h[5]
|
||||
fmla v30.8h, v11.8h, v7.h[6]
|
||||
fmla v31.8h, v11.8h, v7.h[7]
|
||||
|
||||
subs x19, x19, #4
|
||||
beq Activation16
|
||||
cmp x19, #4
|
||||
bge LoopDepth16
|
||||
|
||||
LoopDepth16One:
|
||||
ld1 {v0.8h, v1.8h}, [x10], #32
|
||||
ld1 {v2.8h}, [x14], #16
|
||||
fmla v16.8h, v2.8h, v0.h[0]
|
||||
fmla v17.8h, v2.8h, v0.h[1]
|
||||
fmla v18.8h, v2.8h, v0.h[2]
|
||||
fmla v19.8h, v2.8h, v0.h[3]
|
||||
fmla v20.8h, v2.8h, v0.h[4]
|
||||
fmla v21.8h, v2.8h, v0.h[5]
|
||||
fmla v22.8h, v2.8h, v0.h[6]
|
||||
fmla v23.8h, v2.8h, v0.h[7]
|
||||
fmla v24.8h, v2.8h, v1.h[0]
|
||||
fmla v25.8h, v2.8h, v1.h[1]
|
||||
fmla v26.8h, v2.8h, v1.h[2]
|
||||
fmla v27.8h, v2.8h, v1.h[3]
|
||||
fmla v28.8h, v2.8h, v1.h[4]
|
||||
fmla v29.8h, v2.8h, v1.h[5]
|
||||
fmla v30.8h, v2.8h, v1.h[6]
|
||||
fmla v31.8h, v2.8h, v1.h[7]
|
||||
subs x19, x19, #1
|
||||
bgt LoopDepth16One
|
||||
|
||||
Activation16:
|
||||
cmp x4, #3
|
||||
beq Relu616
|
||||
cmp x4, #1
|
||||
beq Relu16
|
||||
b Write16
|
||||
Relu616:
|
||||
fmin v16.8h, v16.8h, v13.8h
|
||||
fmin v17.8h, v17.8h, v13.8h
|
||||
fmin v18.8h, v18.8h, v13.8h
|
||||
fmin v19.8h, v19.8h, v13.8h
|
||||
fmin v20.8h, v20.8h, v13.8h
|
||||
fmin v21.8h, v21.8h, v13.8h
|
||||
fmin v22.8h, v22.8h, v13.8h
|
||||
fmin v23.8h, v23.8h, v13.8h
|
||||
fmin v24.8h, v24.8h, v13.8h
|
||||
fmin v25.8h, v25.8h, v13.8h
|
||||
fmin v26.8h, v26.8h, v13.8h
|
||||
fmin v27.8h, v27.8h, v13.8h
|
||||
fmin v28.8h, v28.8h, v13.8h
|
||||
fmin v29.8h, v29.8h, v13.8h
|
||||
fmin v30.8h, v30.8h, v13.8h
|
||||
fmin v31.8h, v31.8h, v13.8h
|
||||
Relu16:
|
||||
fmax v16.8h, v16.8h, v12.8h
|
||||
fmax v17.8h, v17.8h, v12.8h
|
||||
fmax v18.8h, v18.8h, v12.8h
|
||||
fmax v19.8h, v19.8h, v12.8h
|
||||
fmax v20.8h, v20.8h, v12.8h
|
||||
fmax v21.8h, v21.8h, v12.8h
|
||||
fmax v22.8h, v22.8h, v12.8h
|
||||
fmax v23.8h, v23.8h, v12.8h
|
||||
fmax v24.8h, v24.8h, v12.8h
|
||||
fmax v25.8h, v25.8h, v12.8h
|
||||
fmax v26.8h, v26.8h, v12.8h
|
||||
fmax v27.8h, v27.8h, v12.8h
|
||||
fmax v28.8h, v28.8h, v12.8h
|
||||
fmax v29.8h, v29.8h, v12.8h
|
||||
fmax v30.8h, v30.8h, v12.8h
|
||||
fmax v31.8h, v31.8h, v12.8h
|
||||
Write16:
|
||||
cmp x13, #8
|
||||
bge Write16x8
|
||||
b Write
|
||||
Write16x8:
|
||||
add x2, x2, #16
|
||||
st1 {v16.8h}, [x11], x8
|
||||
st1 {v17.8h}, [x11], x8
|
||||
st1 {v18.8h}, [x11], x8
|
||||
st1 {v19.8h}, [x11], x8
|
||||
st1 {v20.8h}, [x11], x8
|
||||
st1 {v21.8h}, [x11], x8
|
||||
st1 {v22.8h}, [x11], x8
|
||||
st1 {v23.8h}, [x11], x8
|
||||
st1 {v24.8h}, [x11], x8
|
||||
st1 {v25.8h}, [x11], x8
|
||||
st1 {v26.8h}, [x11], x8
|
||||
st1 {v27.8h}, [x11], x8
|
||||
st1 {v28.8h}, [x11], x8
|
||||
st1 {v29.8h}, [x11], x8
|
||||
st1 {v30.8h}, [x11], x8
|
||||
st1 {v31.8h}, [x11], x8
|
||||
b WriteEnd
|
||||
|
||||
LoopRow8:
|
||||
mov x15, #8
|
||||
mov x14, x1 // reload rhs ptr
|
||||
mov x13, x7 // reload rhs col
|
||||
mov x12, x3 // reload bias
|
||||
|
||||
LoopCol8:
|
||||
mov x11, x2
|
||||
mov x10, x0 // reload lhs ptr
|
||||
mov x19, x5 // reload depth
|
||||
|
||||
ld1 {v16.8h}, [x12], #16
|
||||
mov v17.8h, v16.8h
|
||||
mov v18.8h, v16.8h
|
||||
mov v19.8h, v16.8h
|
||||
mov v20.8h, v16.8h
|
||||
mov v21.8h, v16.8h
|
||||
mov v22.8h, v16.8h
|
||||
mov v23.8h, v16.8h
|
||||
|
||||
cmp x19, #4
|
||||
blt LoopDepth8One
|
||||
|
||||
LoopDepth8:
|
||||
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
|
||||
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
|
||||
fmla v16.8h, v8.8h, v0.h[0]
|
||||
fmla v17.8h, v8.8h, v0.h[1]
|
||||
fmla v18.8h, v8.8h, v0.h[2]
|
||||
fmla v19.8h, v8.8h, v0.h[3]
|
||||
fmla v20.8h, v8.8h, v0.h[4]
|
||||
fmla v21.8h, v8.8h, v0.h[5]
|
||||
fmla v22.8h, v8.8h, v0.h[6]
|
||||
fmla v23.8h, v8.8h, v0.h[7]
|
||||
fmla v16.8h, v9.8h, v1.h[0]
|
||||
fmla v17.8h, v9.8h, v1.h[1]
|
||||
fmla v18.8h, v9.8h, v1.h[2]
|
||||
fmla v19.8h, v9.8h, v1.h[3]
|
||||
fmla v20.8h, v9.8h, v1.h[4]
|
||||
fmla v21.8h, v9.8h, v1.h[5]
|
||||
fmla v22.8h, v9.8h, v1.h[6]
|
||||
fmla v23.8h, v9.8h, v1.h[7]
|
||||
fmla v16.8h, v10.8h, v2.h[0]
|
||||
fmla v17.8h, v10.8h, v2.h[1]
|
||||
fmla v18.8h, v10.8h, v2.h[2]
|
||||
fmla v19.8h, v10.8h, v2.h[3]
|
||||
fmla v20.8h, v10.8h, v2.h[4]
|
||||
fmla v21.8h, v10.8h, v2.h[5]
|
||||
fmla v22.8h, v10.8h, v2.h[6]
|
||||
fmla v23.8h, v10.8h, v2.h[7]
|
||||
fmla v16.8h, v11.8h, v3.h[0]
|
||||
fmla v17.8h, v11.8h, v3.h[1]
|
||||
fmla v18.8h, v11.8h, v3.h[2]
|
||||
fmla v19.8h, v11.8h, v3.h[3]
|
||||
fmla v20.8h, v11.8h, v3.h[4]
|
||||
fmla v21.8h, v11.8h, v3.h[5]
|
||||
fmla v22.8h, v11.8h, v3.h[6]
|
||||
fmla v23.8h, v11.8h, v3.h[7]
|
||||
subs x19, x19, #4
|
||||
beq Activation8
|
||||
cmp x19, #4
|
||||
bge LoopDepth8
|
||||
LoopDepth8One:
|
||||
ld1 {v0.8h}, [x10], #16
|
||||
ld1 {v2.8h}, [x14], #16
|
||||
fmla v16.8h, v2.8h, v0.h[0]
|
||||
fmla v17.8h, v2.8h, v0.h[1]
|
||||
fmla v18.8h, v2.8h, v0.h[2]
|
||||
fmla v19.8h, v2.8h, v0.h[3]
|
||||
fmla v20.8h, v2.8h, v0.h[4]
|
||||
fmla v21.8h, v2.8h, v0.h[5]
|
||||
fmla v22.8h, v2.8h, v0.h[6]
|
||||
fmla v23.8h, v2.8h, v0.h[7]
|
||||
subs x19, x19, #1
|
||||
bgt LoopDepth8One
|
||||
Activation8:
|
||||
cmp x4, #3
|
||||
beq Relu68
|
||||
cmp x4, #1
|
||||
beq Relu8
|
||||
b Write8_Row
|
||||
Relu68:
|
||||
fmin v16.8h, v16.8h, v13.8h
|
||||
fmin v17.8h, v17.8h, v13.8h
|
||||
fmin v18.8h, v18.8h, v13.8h
|
||||
fmin v19.8h, v19.8h, v13.8h
|
||||
fmin v20.8h, v20.8h, v13.8h
|
||||
fmin v21.8h, v21.8h, v13.8h
|
||||
fmin v22.8h, v22.8h, v13.8h
|
||||
fmin v23.8h, v23.8h, v13.8h
|
||||
Relu8:
|
||||
fmax v16.8h, v16.8h, v12.8h
|
||||
fmax v17.8h, v17.8h, v12.8h
|
||||
fmax v18.8h, v18.8h, v12.8h
|
||||
fmax v19.8h, v19.8h, v12.8h
|
||||
fmax v20.8h, v20.8h, v12.8h
|
||||
fmax v21.8h, v21.8h, v12.8h
|
||||
fmax v22.8h, v22.8h, v12.8h
|
||||
fmax v23.8h, v23.8h, v12.8h
|
||||
Write8_Row:
|
||||
cmp x13, #8 // row
|
||||
bge Write8x8
|
||||
b Write
|
||||
Write8x8:
|
||||
add x2, x2, #16
|
||||
st1 {v16.8h}, [x11], x8
|
||||
st1 {v17.8h}, [x11], x8
|
||||
st1 {v18.8h}, [x11], x8
|
||||
st1 {v19.8h}, [x11], x8
|
||||
st1 {v20.8h}, [x11], x8
|
||||
st1 {v21.8h}, [x11], x8
|
||||
st1 {v22.8h}, [x11], x8
|
||||
st1 {v23.8h}, [x11], x8
|
||||
b WriteEnd
|
||||
|
||||
LoopRow4:
|
||||
mov x15, #4
|
||||
mov x14, x1 // reload rhs ptr
|
||||
mov x13, x7 // reload rhs col
|
||||
mov x12, x3 // reload bias
|
||||
|
||||
LoopCol4:
|
||||
mov x11, x2
|
||||
mov x10, x0 // reload lhs ptr
|
||||
mov x19, x5 // reload depth
|
||||
ld1 {v16.8h}, [x12], #16
|
||||
mov v17.8h, v16.8h
|
||||
mov v18.8h, v16.8h
|
||||
mov v19.8h, v16.8h
|
||||
cmp x19, #4
|
||||
blt LoopDepth4One
|
||||
LoopDepth4:
|
||||
ld1 {v0.8h, v1.8h}, [x10], #32
|
||||
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
|
||||
fmla v16.8h, v8.8h, v0.h[0]
|
||||
fmla v17.8h, v8.8h, v0.h[1]
|
||||
fmla v18.8h, v8.8h, v0.h[2]
|
||||
fmla v19.8h, v8.8h, v0.h[3]
|
||||
fmla v16.8h, v9.8h, v0.h[4]
|
||||
fmla v17.8h, v9.8h, v0.h[5]
|
||||
fmla v18.8h, v9.8h, v0.h[6]
|
||||
fmla v19.8h, v9.8h, v0.h[7]
|
||||
fmla v16.8h, v10.8h, v1.h[0]
|
||||
fmla v17.8h, v10.8h, v1.h[1]
|
||||
fmla v18.8h, v10.8h, v1.h[2]
|
||||
fmla v19.8h, v10.8h, v1.h[3]
|
||||
fmla v16.8h, v11.8h, v1.h[4]
|
||||
fmla v17.8h, v11.8h, v1.h[5]
|
||||
fmla v18.8h, v11.8h, v1.h[6]
|
||||
fmla v19.8h, v11.8h, v1.h[7]
|
||||
subs x19, x19, #4
|
||||
beq Activation4
|
||||
cmp x19, #4
|
||||
bge LoopDepth4
|
||||
LoopDepth4One:
|
||||
ld1 {v0.4h}, [x10], #8
|
||||
ld1 {v2.8h}, [x14], #16
|
||||
fmla v16.8h, v2.8h, v0.h[0]
|
||||
fmla v17.8h, v2.8h, v0.h[1]
|
||||
fmla v18.8h, v2.8h, v0.h[2]
|
||||
fmla v19.8h, v2.8h, v0.h[3]
|
||||
subs x19, x19, #1
|
||||
bgt LoopDepth4One
|
||||
Activation4:
|
||||
cmp x4, #3
|
||||
beq Relu64
|
||||
cmp x4, #1
|
||||
beq Relu4
|
||||
b Write4_Row
|
||||
Relu64:
|
||||
fmin v16.8h, v16.8h, v13.8h
|
||||
fmin v17.8h, v17.8h, v13.8h
|
||||
fmin v18.8h, v18.8h, v13.8h
|
||||
fmin v19.8h, v19.8h, v13.8h
|
||||
Relu4:
|
||||
fmax v16.8h, v16.8h, v12.8h
|
||||
fmax v17.8h, v17.8h, v12.8h
|
||||
fmax v18.8h, v18.8h, v12.8h
|
||||
fmax v19.8h, v19.8h, v12.8h
|
||||
Write4_Row:
|
||||
cmp x6, #4
|
||||
bge Write4x8
|
||||
b Write
|
||||
Write4x8:
|
||||
cmp x13, #8
|
||||
blt Write
|
||||
add x2, x2, #16
|
||||
st1 {v16.8h}, [x11], x8
|
||||
st1 {v17.8h}, [x11], x8
|
||||
st1 {v18.8h}, [x11], x8
|
||||
st1 {v19.8h}, [x11], x8
|
||||
b WriteEnd
|
||||
|
||||
Write:
|
||||
cmp x13, #1
|
||||
beq Write1
|
||||
cmp x13, #2
|
||||
beq Write2
|
||||
cmp x13, #3
|
||||
beq Write3
|
||||
cmp x13, #4
|
||||
beq Write4
|
||||
cmp x13, #5
|
||||
beq Write5
|
||||
cmp x13, #6
|
||||
beq Write6
|
||||
cmp x13, #7
|
||||
beq Write7
|
||||
b Write8
|
||||
|
||||
Write1:
|
||||
add x2, x2, #2
|
||||
st1 {v16.h}[0], [x11], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.h}[0], [x11], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.h}[0], [x11], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.h}[0], [x11], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.h}[0], [x11], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.h}[0], [x11], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.h}[0], [x11], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.h}[0], [x11], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.h}[0], [x11], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.h}[0], [x11], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.h}[0], [x11], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.h}[0], [x11], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.h}[0], [x11], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.h}[0], [x11], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.h}[0], [x11], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.h}[0], [x11], x8
|
||||
b WriteEnd
|
||||
Write2:
|
||||
add x2, x2, #4
|
||||
st1 {v16.s}[0], [x11], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.s}[0], [x11], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.s}[0], [x11], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.s}[0], [x11], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.s}[0], [x11], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.s}[0], [x11], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.s}[0], [x11], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.s}[0], [x11], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.s}[0], [x11], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.s}[0], [x11], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.s}[0], [x11], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.s}[0], [x11], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.s}[0], [x11], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.s}[0], [x11], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.s}[0], [x11], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.s}[0], [x11], x8
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x2, x2, #6
|
||||
add x19, x11, #4
|
||||
st1 {v16.s}[0], [x11], x8
|
||||
st1 {v16.h}[2], [x19], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.s}[0], [x11], x8
|
||||
st1 {v17.h}[2], [x19], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.s}[0], [x11], x8
|
||||
st1 {v18.h}[2], [x19], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.s}[0], [x11], x8
|
||||
st1 {v19.h}[2], [x19], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.s}[0], [x11], x8
|
||||
st1 {v20.h}[2], [x19], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.s}[0], [x11], x8
|
||||
st1 {v21.h}[2], [x19], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.s}[0], [x11], x8
|
||||
st1 {v22.h}[2], [x19], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.s}[0], [x11], x8
|
||||
st1 {v23.h}[2], [x19], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.s}[0], [x11], x8
|
||||
st1 {v24.h}[2], [x19], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.s}[0], [x11], x8
|
||||
st1 {v25.h}[2], [x19], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.s}[0], [x11], x8
|
||||
st1 {v26.h}[2], [x19], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.s}[0], [x11], x8
|
||||
st1 {v27.h}[2], [x19], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.s}[0], [x11], x8
|
||||
st1 {v28.h}[2], [x19], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.s}[0], [x11], x8
|
||||
st1 {v29.h}[2], [x19], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.s}[0], [x11], x8
|
||||
st1 {v30.h}[2], [x19], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.s}[0], [x11], x8
|
||||
st1 {v31.h}[2], [x19]
|
||||
b WriteEnd
|
||||
Write4:
|
||||
add x2, x2, #8
|
||||
st1 {v16.4h}, [x11], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
b WriteEnd
|
||||
Write5:
|
||||
add x2, x2, #10
|
||||
add x19, x11, #8
|
||||
st1 {v16.4h}, [x11], x8
|
||||
st1 {v16.h}[4], [x19], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
st1 {v17.h}[4], [x19], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
st1 {v18.h}[4], [x19], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
st1 {v19.h}[4], [x19], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
st1 {v20.h}[4], [x19], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
st1 {v21.h}[4], [x19], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
st1 {v22.h}[4], [x19], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
st1 {v23.h}[4], [x19], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
st1 {v24.h}[4], [x19], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
st1 {v25.h}[4], [x19], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
st1 {v26.h}[4], [x19], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
st1 {v27.h}[4], [x19], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
st1 {v28.h}[4], [x19], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
st1 {v29.h}[4], [x19], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
st1 {v30.h}[4], [x19], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
st1 {v31.h}[4], [x19]
|
||||
b WriteEnd
|
||||
Write6:
|
||||
add x2, x2, #12
|
||||
add x19, x11, #8
|
||||
st1 {v16.4h}, [x11], x8
|
||||
st1 {v16.s}[2], [x19], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
st1 {v17.s}[2], [x19], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
st1 {v18.s}[2], [x19], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
st1 {v19.s}[2], [x19], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
st1 {v20.s}[2], [x19], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
st1 {v21.s}[2], [x19], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
st1 {v22.s}[2], [x19], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
st1 {v23.s}[2], [x19], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
st1 {v24.s}[2], [x19], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
st1 {v25.s}[2], [x19], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
st1 {v26.s}[2], [x19], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
st1 {v27.s}[2], [x19], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
st1 {v28.s}[2], [x19], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
st1 {v29.s}[2], [x19], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
st1 {v30.s}[2], [x19], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
st1 {v31.s}[2], [x19]
|
||||
b WriteEnd
|
||||
Write7:
|
||||
add x2, x2, #14
|
||||
add x19, x11, #8
|
||||
add x10, x11, #12
|
||||
st1 {v16.4h}, [x11], x8
|
||||
st1 {v16.s}[2], [x19], x8
|
||||
st1 {v16.h}[6], [x10], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
st1 {v17.s}[2], [x19], x8
|
||||
st1 {v17.h}[6], [x10], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
st1 {v18.s}[2], [x19], x8
|
||||
st1 {v18.h}[6], [x10], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
st1 {v19.s}[2], [x19], x8
|
||||
st1 {v19.h}[6], [x10], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
st1 {v20.s}[2], [x19], x8
|
||||
st1 {v20.h}[6], [x10], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
st1 {v21.s}[2], [x19], x8
|
||||
st1 {v21.h}[6], [x10], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
st1 {v22.s}[2], [x19], x8
|
||||
st1 {v22.h}[6], [x10], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
st1 {v23.s}[2], [x19], x8
|
||||
st1 {v23.h}[6], [x10], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
st1 {v24.s}[2], [x19], x8
|
||||
st1 {v24.h}[6], [x10], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
st1 {v25.s}[2], [x19], x8
|
||||
st1 {v25.h}[6], [x10], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
st1 {v26.s}[2], [x19], x8
|
||||
st1 {v26.h}[6], [x10], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
st1 {v27.s}[2], [x19], x8
|
||||
st1 {v27.h}[6], [x10], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
st1 {v28.s}[2], [x19], x8
|
||||
st1 {v28.h}[6], [x10], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
st1 {v29.s}[2], [x19], x8
|
||||
st1 {v29.h}[6], [x10], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
st1 {v30.s}[2], [x19], x8
|
||||
st1 {v30.h}[6], [x10], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
st1 {v31.s}[2], [x19]
|
||||
st1 {v31.h}[6], [x10]
|
||||
b WriteEnd
|
||||
Write8:
|
||||
add x2, x2, #16
|
||||
st1 {v16.8h}, [x11], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.8h}, [x11], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.8h}, [x11], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.8h}, [x11], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.8h}, [x11], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.8h}, [x11], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.8h}, [x11], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.8h}, [x11], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.8h}, [x11], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.8h}, [x11], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.8h}, [x11], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.8h}, [x11], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.8h}, [x11], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.8h}, [x11], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.8h}, [x11], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.8h}, [x11], x8
|
||||
|
||||
WriteEnd:
|
||||
subs x13, x13, #8 // rhs col - 8
|
||||
ble LoopColEnd
|
||||
cmp x6, #16
|
||||
bge LoopCol16
|
||||
cmp x6, #8
|
||||
bge LoopCol8
|
||||
b LoopCol4
|
||||
|
||||
LoopColEnd:
|
||||
sub x2, x2, x16 // dst - col * 2
|
||||
mul x21, x8, x15 // row_block * col * 2
|
||||
add x2, x2, x21
|
||||
subs x6, x6, x15
|
||||
mul x15, x15, x17
|
||||
add x0, x0, x15
|
||||
bgt LoopRowStart
|
||||
|
||||
sub sp, sp, #96
|
||||
ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ret
|
||||
#endif
|
|
@ -228,7 +228,7 @@ LoopRow16:
|
|||
fmin v29.8h, v29.8h, v2.8h
|
||||
fmin v30.8h, v30.8h, v2.8h
|
||||
fmin v31.8h, v31.8h, v2.8h
|
||||
|
||||
|
||||
Relu16:
|
||||
dup v2.8h, wzr
|
||||
fmax v16.8h, v16.8h, v2.8h
|
||||
|
@ -359,7 +359,7 @@ LoopRow8:
|
|||
fmin v21.8h, v21.8h, v2.8h
|
||||
fmin v22.8h, v22.8h, v2.8h
|
||||
fmin v23.8h, v23.8h, v2.8h
|
||||
|
||||
|
||||
Relu8:
|
||||
dup v2.8h, wzr
|
||||
fmax v16.8h, v16.8h, v2.8h
|
||||
|
@ -450,7 +450,7 @@ LoopRow4:
|
|||
fmin v17.8h, v17.8h, v2.8h
|
||||
fmin v18.8h, v18.8h, v2.8h
|
||||
fmin v19.8h, v19.8h, v2.8h
|
||||
|
||||
|
||||
Relu4:
|
||||
dup v2.8h, wzr
|
||||
fmax v16.8h, v16.8h, v2.8h
|
||||
|
@ -458,7 +458,7 @@ LoopRow4:
|
|||
fmax v18.8h, v18.8h, v2.8h
|
||||
fmax v19.8h, v19.8h, v2.8h
|
||||
b Write
|
||||
|
||||
|
||||
LoopRow2:
|
||||
mov x14, x1 // reload rhs ptr
|
||||
mov x13, x7 // reload rhs col
|
||||
|
@ -521,7 +521,7 @@ LoopRow2:
|
|||
movi v2.8h, #0x46, lsl #8
|
||||
fmin v16.8h, v16.8h, v2.8h
|
||||
fmin v17.8h, v17.8h, v2.8h
|
||||
|
||||
|
||||
Relu2:
|
||||
dup v2.8h, wzr
|
||||
fmax v16.8h, v16.8h, v2.8h
|
||||
|
@ -582,7 +582,7 @@ LoopRow:
|
|||
Relu6:
|
||||
movi v2.8h, #0x46, lsl #8
|
||||
fmin v16.8h, v16.8h, v2.8h
|
||||
|
||||
|
||||
Relu:
|
||||
dup v2.8h, wzr
|
||||
fmax v16.8h, v16.8h, v2.8h
|
||||
|
@ -675,152 +675,118 @@ LoopRow:
|
|||
b WriteEnd
|
||||
Write2:
|
||||
add x2, x2, #4
|
||||
add x19, x11, #2
|
||||
st1 {v16.h}[0], [x11], x8
|
||||
st1 {v16.h}[1], [x19], x8
|
||||
st1 {v16.s}[0], [x11], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.h}[0], [x11], x8
|
||||
st1 {v17.h}[1], [x19], x8
|
||||
st1 {v17.s}[0], [x11], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.h}[0], [x11], x8
|
||||
st1 {v18.h}[1], [x19], x8
|
||||
st1 {v18.s}[0], [x11], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.h}[0], [x11], x8
|
||||
st1 {v19.h}[1], [x19], x8
|
||||
st1 {v19.s}[0], [x11], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.h}[0], [x11], x8
|
||||
st1 {v20.h}[1], [x19], x8
|
||||
st1 {v20.s}[0], [x11], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.h}[0], [x11], x8
|
||||
st1 {v21.h}[1], [x19], x8
|
||||
st1 {v21.s}[0], [x11], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.h}[0], [x11], x8
|
||||
st1 {v22.h}[1], [x19], x8
|
||||
st1 {v22.s}[0], [x11], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.h}[0], [x11], x8
|
||||
st1 {v23.h}[1], [x19], x8
|
||||
st1 {v23.s}[0], [x11], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.h}[0], [x11], x8
|
||||
st1 {v24.h}[1], [x19], x8
|
||||
st1 {v24.s}[0], [x11], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.h}[0], [x11], x8
|
||||
st1 {v25.h}[1], [x19], x8
|
||||
st1 {v25.s}[0], [x11], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.h}[0], [x11], x8
|
||||
st1 {v26.h}[1], [x19], x8
|
||||
st1 {v26.s}[0], [x11], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.h}[0], [x11], x8
|
||||
st1 {v27.h}[1], [x19], x8
|
||||
st1 {v27.s}[0], [x11], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.h}[0], [x11], x8
|
||||
st1 {v28.h}[1], [x19], x8
|
||||
st1 {v28.s}[0], [x11], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.h}[0], [x11], x8
|
||||
st1 {v29.h}[1], [x19], x8
|
||||
st1 {v29.s}[0], [x11], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.h}[0], [x11], x8
|
||||
st1 {v30.h}[1], [x19], x8
|
||||
st1 {v30.s}[0], [x11], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.h}[0], [x11], x8
|
||||
st1 {v31.h}[1], [x19]
|
||||
st1 {v31.s}[0], [x11], x8
|
||||
add x11, x11, #4
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x2, x2, #6
|
||||
add x19, x11, #4
|
||||
add x20, x11, #2
|
||||
st1 {v16.h}[0], [x11], x8
|
||||
st1 {v16.h}[1], [x20], x8
|
||||
st1 {v16.s}[0], [x11], x8
|
||||
st1 {v16.h}[2], [x19], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.h}[0], [x11], x8
|
||||
st1 {v17.h}[1], [x20], x8
|
||||
st1 {v17.s}[0], [x11], x8
|
||||
st1 {v17.h}[2], [x19], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.h}[0], [x11], x8
|
||||
st1 {v18.h}[1], [x20], x8
|
||||
st1 {v18.s}[0], [x11], x8
|
||||
st1 {v18.h}[2], [x19], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.h}[0], [x11], x8
|
||||
st1 {v19.h}[1], [x20], x8
|
||||
st1 {v19.s}[0], [x11], x8
|
||||
st1 {v19.h}[2], [x19], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.h}[0], [x11], x8
|
||||
st1 {v20.h}[1], [x20], x8
|
||||
st1 {v20.s}[0], [x11], x8
|
||||
st1 {v20.h}[2], [x19], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.h}[0], [x11], x8
|
||||
st1 {v21.h}[1], [x20], x8
|
||||
st1 {v21.s}[0], [x11], x8
|
||||
st1 {v21.h}[2], [x19], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.h}[0], [x11], x8
|
||||
st1 {v22.h}[1], [x20], x8
|
||||
st1 {v22.s}[0], [x11], x8
|
||||
st1 {v22.h}[2], [x19], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.h}[0], [x11], x8
|
||||
st1 {v23.h}[1], [x20], x8
|
||||
st1 {v23.s}[0], [x11], x8
|
||||
st1 {v23.h}[2], [x19], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.h}[0], [x11], x8
|
||||
st1 {v24.h}[1], [x20], x8
|
||||
st1 {v24.s}[0], [x11], x8
|
||||
st1 {v24.h}[2], [x19], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.h}[0], [x11], x8
|
||||
st1 {v25.h}[1], [x20], x8
|
||||
st1 {v25.s}[0], [x11], x8
|
||||
st1 {v25.h}[2], [x19], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.h}[0], [x11], x8
|
||||
st1 {v26.h}[1], [x20], x8
|
||||
st1 {v26.s}[0], [x11], x8
|
||||
st1 {v26.h}[2], [x19], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.h}[0], [x11], x8
|
||||
st1 {v27.h}[1], [x20], x8
|
||||
st1 {v27.s}[0], [x11], x8
|
||||
st1 {v27.h}[2], [x19], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.h}[0], [x11], x8
|
||||
st1 {v28.h}[1], [x20], x8
|
||||
st1 {v28.s}[0], [x11], x8
|
||||
st1 {v28.h}[2], [x19], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.h}[0], [x11], x8
|
||||
st1 {v29.h}[1], [x20], x8
|
||||
st1 {v29.s}[0], [x11], x8
|
||||
st1 {v29.h}[2], [x19], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.h}[0], [x11], x8
|
||||
st1 {v30.h}[1], [x20], x8
|
||||
st1 {v30.s}[0], [x11], x8
|
||||
st1 {v30.h}[2], [x19], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.h}[0], [x11], x8
|
||||
st1 {v31.h}[1], [x20]
|
||||
st1 {v31.s}[0], [x11], x8
|
||||
st1 {v31.h}[2], [x19]
|
||||
add x11, x11, #6
|
||||
b WriteEnd
|
||||
|
@ -944,185 +910,151 @@ LoopRow:
|
|||
Write6:
|
||||
add x2, x2, #12
|
||||
add x19, x11, #8
|
||||
add x20, x11, #10
|
||||
st1 {v16.4h}, [x11], x8
|
||||
st1 {v16.h}[4], [x19], x8
|
||||
st1 {v16.h}[5], [x20], x8
|
||||
st1 {v16.s}[2], [x19], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
st1 {v17.h}[4], [x19], x8
|
||||
st1 {v17.h}[5], [x20], x8
|
||||
st1 {v17.s}[2], [x19], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
st1 {v18.h}[4], [x19], x8
|
||||
st1 {v18.h}[5], [x20], x8
|
||||
st1 {v18.s}[2], [x19], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
st1 {v19.h}[4], [x19], x8
|
||||
st1 {v19.h}[5], [x20], x8
|
||||
st1 {v19.s}[2], [x19], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
st1 {v20.h}[4], [x19], x8
|
||||
st1 {v20.h}[5], [x20], x8
|
||||
st1 {v20.s}[2], [x19], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
st1 {v21.h}[4], [x19], x8
|
||||
st1 {v21.h}[5], [x20], x8
|
||||
st1 {v21.s}[2], [x19], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
st1 {v22.h}[4], [x19], x8
|
||||
st1 {v22.h}[5], [x20], x8
|
||||
st1 {v22.s}[2], [x19], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
st1 {v23.h}[4], [x19], x8
|
||||
st1 {v23.h}[5], [x20], x8
|
||||
st1 {v23.s}[2], [x19], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
st1 {v24.h}[4], [x19], x8
|
||||
st1 {v24.h}[5], [x20], x8
|
||||
st1 {v24.s}[2], [x19], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
st1 {v25.h}[4], [x19], x8
|
||||
st1 {v25.h}[5], [x20], x8
|
||||
st1 {v25.s}[2], [x19], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
st1 {v26.h}[4], [x19], x8
|
||||
st1 {v26.h}[5], [x20], x8
|
||||
st1 {v26.s}[2], [x19], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
st1 {v27.h}[4], [x19], x8
|
||||
st1 {v27.h}[5], [x20], x8
|
||||
st1 {v27.s}[2], [x19], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
st1 {v28.h}[4], [x19], x8
|
||||
st1 {v28.h}[5], [x20], x8
|
||||
st1 {v28.s}[2], [x19], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
st1 {v29.h}[4], [x19], x8
|
||||
st1 {v29.h}[5], [x20], x8
|
||||
st1 {v29.s}[2], [x19], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
st1 {v30.h}[4], [x19], x8
|
||||
st1 {v30.h}[5], [x20], x8
|
||||
st1 {v30.s}[2], [x19], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
st1 {v31.h}[4], [x19]
|
||||
st1 {v31.h}[5], [x20]
|
||||
st1 {v31.s}[2], [x19]
|
||||
add x11, x11, #12
|
||||
b WriteEnd
|
||||
Write7:
|
||||
add x2, x2, #14
|
||||
add x19, x11, #8
|
||||
add x20, x11, #10
|
||||
add x10, x11, #12
|
||||
st1 {v16.4h}, [x11], x8
|
||||
st1 {v16.h}[4], [x19], x8
|
||||
st1 {v16.h}[5], [x20], x8
|
||||
st1 {v16.s}[2], [x19], x8
|
||||
st1 {v16.h}[6], [x10], x8
|
||||
cmp x6, #1
|
||||
beq WriteEnd
|
||||
st1 {v17.4h}, [x11], x8
|
||||
st1 {v17.h}[4], [x19], x8
|
||||
st1 {v17.h}[5], [x20], x8
|
||||
st1 {v17.s}[2], [x19], x8
|
||||
st1 {v17.h}[6], [x10], x8
|
||||
cmp x6, #2
|
||||
beq WriteEnd
|
||||
st1 {v18.4h}, [x11], x8
|
||||
st1 {v18.h}[4], [x19], x8
|
||||
st1 {v18.h}[5], [x20], x8
|
||||
st1 {v18.s}[2], [x19], x8
|
||||
st1 {v18.h}[6], [x10], x8
|
||||
cmp x6, #3
|
||||
beq WriteEnd
|
||||
st1 {v19.4h}, [x11], x8
|
||||
st1 {v19.h}[4], [x19], x8
|
||||
st1 {v19.h}[5], [x20], x8
|
||||
st1 {v19.s}[2], [x19], x8
|
||||
st1 {v19.h}[6], [x10], x8
|
||||
cmp x6, #4
|
||||
beq WriteEnd
|
||||
st1 {v20.4h}, [x11], x8
|
||||
st1 {v20.h}[4], [x19], x8
|
||||
st1 {v20.h}[5], [x20], x8
|
||||
st1 {v20.s}[2], [x19], x8
|
||||
st1 {v20.h}[6], [x10], x8
|
||||
cmp x6, #5
|
||||
beq WriteEnd
|
||||
st1 {v21.4h}, [x11], x8
|
||||
st1 {v21.h}[4], [x19], x8
|
||||
st1 {v21.h}[5], [x20], x8
|
||||
st1 {v21.s}[2], [x19], x8
|
||||
st1 {v21.h}[6], [x10], x8
|
||||
cmp x6, #6
|
||||
beq WriteEnd
|
||||
st1 {v22.4h}, [x11], x8
|
||||
st1 {v22.h}[4], [x19], x8
|
||||
st1 {v22.h}[5], [x20], x8
|
||||
st1 {v22.s}[2], [x19], x8
|
||||
st1 {v22.h}[6], [x10], x8
|
||||
cmp x6, #7
|
||||
beq WriteEnd
|
||||
st1 {v23.4h}, [x11], x8
|
||||
st1 {v23.h}[4], [x19], x8
|
||||
st1 {v23.h}[5], [x20], x8
|
||||
st1 {v23.s}[2], [x19], x8
|
||||
st1 {v23.h}[6], [x10], x8
|
||||
cmp x6, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4h}, [x11], x8
|
||||
st1 {v24.h}[4], [x19], x8
|
||||
st1 {v24.h}[5], [x20], x8
|
||||
st1 {v24.s}[2], [x19], x8
|
||||
st1 {v24.h}[6], [x10], x8
|
||||
cmp x6, #9
|
||||
beq WriteEnd
|
||||
st1 {v25.4h}, [x11], x8
|
||||
st1 {v25.h}[4], [x19], x8
|
||||
st1 {v25.h}[5], [x20], x8
|
||||
st1 {v25.s}[2], [x19], x8
|
||||
st1 {v25.h}[6], [x10], x8
|
||||
cmp x6, #10
|
||||
beq WriteEnd
|
||||
st1 {v26.4h}, [x11], x8
|
||||
st1 {v26.h}[4], [x19], x8
|
||||
st1 {v26.h}[5], [x20], x8
|
||||
st1 {v26.s}[2], [x19], x8
|
||||
st1 {v26.h}[6], [x10], x8
|
||||
cmp x6, #11
|
||||
beq WriteEnd
|
||||
st1 {v27.4h}, [x11], x8
|
||||
st1 {v27.h}[4], [x19], x8
|
||||
st1 {v27.h}[5], [x20], x8
|
||||
st1 {v27.s}[2], [x19], x8
|
||||
st1 {v27.h}[6], [x10], x8
|
||||
cmp x6, #12
|
||||
beq WriteEnd
|
||||
st1 {v28.4h}, [x11], x8
|
||||
st1 {v28.h}[4], [x19], x8
|
||||
st1 {v28.h}[5], [x20], x8
|
||||
st1 {v28.s}[2], [x19], x8
|
||||
st1 {v28.h}[6], [x10], x8
|
||||
cmp x6, #13
|
||||
beq WriteEnd
|
||||
st1 {v29.4h}, [x11], x8
|
||||
st1 {v29.h}[4], [x19], x8
|
||||
st1 {v29.h}[5], [x20], x8
|
||||
st1 {v29.s}[2], [x19], x8
|
||||
st1 {v29.h}[6], [x10], x8
|
||||
cmp x6, #14
|
||||
beq WriteEnd
|
||||
st1 {v30.4h}, [x11], x8
|
||||
st1 {v30.h}[4], [x19], x8
|
||||
st1 {v30.h}[5], [x20], x8
|
||||
st1 {v30.s}[2], [x19], x8
|
||||
st1 {v30.h}[6], [x10], x8
|
||||
cmp x6, #15
|
||||
beq WriteEnd
|
||||
st1 {v31.4h}, [x11], x8
|
||||
st1 {v31.h}[4], [x19]
|
||||
st1 {v31.h}[5], [x20]
|
||||
st1 {v31.s}[2], [x19]
|
||||
st1 {v31.h}[6], [x10]
|
||||
add x11, x11, #14
|
||||
b WriteEnd
|
||||
|
|
|
@ -662,6 +662,75 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) {
|
||||
// Col16Major ==> Col8Major ==> Col4Major
|
||||
const float16_t *src_r = src_ptr;
|
||||
float16_t *dst_r = dst_ptr;
|
||||
int ri = 0;
|
||||
size_t col8 = col / C8NUM * C8NUM;
|
||||
// find 16 block unit
|
||||
for (; ri <= row - C16NUM; ri += C16NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col8; ci += C8NUM) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C16NUM;
|
||||
Row2Col16Block16(src_c, dst_c, col);
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C16NUM;
|
||||
for (size_t i = 0; i < C16NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C16NUM * col;
|
||||
dst_r += C16NUM * col;
|
||||
}
|
||||
for (; ri <= row - C8NUM; ri += C8NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col8; ci += C8NUM) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C8NUM;
|
||||
Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t));
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C8NUM;
|
||||
for (size_t i = 0; i < C8NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C8NUM * col;
|
||||
dst_r += C8NUM * col;
|
||||
}
|
||||
for (; ri <= row - C4NUM; ri += C4NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col8; ci += C8NUM) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C4NUM;
|
||||
Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t));
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
const float16_t *src_c = src_r + ci;
|
||||
float16_t *dst_c = dst_r + ci * C4NUM;
|
||||
for (size_t i = 0; i < C4NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C4NUM * col;
|
||||
dst_r += C4NUM * col;
|
||||
}
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; ++i) {
|
||||
dst_r[i * C4NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
dst_r += 1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
||||
size_t row_up_12 = UP_ROUND(row, C12NUM);
|
||||
size_t row12 = row / C12NUM * C12NUM;
|
||||
|
@ -760,6 +829,32 @@ void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col,
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col) {
|
||||
// Row16 ==> Row8 ==> Row4
|
||||
for (int r = 0; r < row; r++) {
|
||||
int c = 0;
|
||||
for (; c <= col - C16NUM; c += C16NUM) {
|
||||
MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c);
|
||||
MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM);
|
||||
MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data);
|
||||
MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1);
|
||||
}
|
||||
for (; c <= col - C8NUM; c += C8NUM) {
|
||||
MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c);
|
||||
MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data);
|
||||
}
|
||||
for (; c <= col - C4NUM; c += C4NUM) {
|
||||
MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c);
|
||||
MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data);
|
||||
}
|
||||
for (; c < col; ++c) {
|
||||
dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) {
|
||||
int col_align = UP_ROUND(col, C16NUM);
|
||||
for (int r = 0; r < row; r++) {
|
||||
|
@ -834,3 +929,32 @@ void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, b
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(ENABLE_DEBUG) && defined(ENABLE_ARM64)
|
||||
// arm64 matmul
|
||||
void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc) {
|
||||
int r16 = row / C16NUM * C16NUM;
|
||||
int r8 = row / C8NUM * C8NUM;
|
||||
for (int r = 0; r < row; ++r) {
|
||||
int row_tile = 0;
|
||||
if (r < r16) {
|
||||
row_tile = C16NUM;
|
||||
} else if (r < r8) {
|
||||
row_tile = C8NUM;
|
||||
} else {
|
||||
row_tile = C4NUM;
|
||||
}
|
||||
int index = r / row_tile * row_tile * depth + r % row_tile;
|
||||
for (int t = 0; t < col; ++t) {
|
||||
int c_div = t / C8NUM;
|
||||
int c_mod = t % C8NUM;
|
||||
float16_t res = bias[t];
|
||||
for (int d = 0; d < depth; ++d) {
|
||||
res += a[index + d * row_tile] * b[c_div * depth * C8NUM + d * C8NUM + c_mod];
|
||||
}
|
||||
c[r * col + t] = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -45,6 +45,10 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons
|
|||
int deep, int row, int col, int stride, int write_mode);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void RowMajor2ColNMajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col);
|
||||
|
||||
void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col);
|
||||
|
||||
void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, size_t stride, size_t out_type);
|
||||
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
|
@ -53,6 +57,14 @@ void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, cons
|
|||
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
|
||||
|
||||
void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
|
||||
|
||||
#ifdef ENABLE_DEBUG
|
||||
void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
|
||||
#endif
|
||||
|
||||
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
int depth, int col);
|
||||
|
||||
|
|
|
@ -640,6 +640,108 @@ inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
inline void Transpose4x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
|
||||
dst_stride += dst_stride;
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr]\n"
|
||||
"mov x11, %[dst_ptr]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[src_stride]\n"
|
||||
"add x10, x11, %[dst_stride]\n"
|
||||
|
||||
"zip1 v4.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v5.8h, v2.8h, v3.8h\n"
|
||||
|
||||
"trn1 v6.4s, v4.4s, v5.4s\n"
|
||||
"trn2 v7.4s, v4.4s, v5.4s\n"
|
||||
|
||||
"trn1 v24.2d, v6.2d, v7.2d\n"
|
||||
"trn2 v25.2d, v6.2d, v7.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
|
||||
"trn1 v10.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v11.4s, v8.4s, v9.4s\n"
|
||||
|
||||
"trn1 v26.2d, v10.2d, v11.2d\n"
|
||||
"trn2 v27.2d, v10.2d, v11.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v25.8h}, [x10], %[tow_dst_stride]\n"
|
||||
"st1 {v26.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v27.8h}, [x10], %[tow_dst_stride]\n"
|
||||
:
|
||||
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride),
|
||||
[ dst_stride ] "r"(dst_stride), [ tow_dst_stride ] "r"(2 * dst_stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v24", "v25", "v26",
|
||||
"v27");
|
||||
}
|
||||
|
||||
inline void Transpose8x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr]\n"
|
||||
"mov x11, %[dst_ptr]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v4.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v5.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v6.8h}, [x10], %[src_stride]\n"
|
||||
"ld1 {v7.8h}, [x10], %[src_stride]\n"
|
||||
"add x10, x11, %[dst_stride]\n"
|
||||
|
||||
"zip1 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v26.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v25.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip2 v8.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v9.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v10.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v11.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v12.4s, v8.4s, v9.4s\n"
|
||||
"trn2 v13.4s, v8.4s, v9.4s\n"
|
||||
"trn1 v14.4s, v10.4s, v11.4s\n"
|
||||
"trn2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v28.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v30.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v29.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v31.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v25.8h}, [x10], %[tow_dst_stride]\n"
|
||||
"st1 {v26.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v27.8h}, [x10], %[tow_dst_stride]\n"
|
||||
"st1 {v28.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v29.8h}, [x10], %[tow_dst_stride]\n"
|
||||
"st1 {v30.8h}, [x11], %[tow_dst_stride]\n"
|
||||
"st1 {v31.8h}, [x10], %[tow_dst_stride]\n"
|
||||
:
|
||||
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride),
|
||||
[ dst_stride ] "r"(dst_stride), [ tow_dst_stride ] "r"(2 * dst_stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
|
||||
"v31");
|
||||
}
|
||||
|
||||
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
|
||||
#ifdef ENABLE_DEBUG
|
||||
for (int tr = 0; tr < C12NUM; tr++) {
|
||||
|
|
|
@ -78,6 +78,8 @@ void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_strid
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Transpose4x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
void Transpose8x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride);
|
||||
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
#endif
|
||||
|
|
|
@ -93,11 +93,14 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
|
|||
#endif
|
||||
|
||||
#define MS_FLOAT16X8 float16x8_t
|
||||
#define MS_FLOAT16X4 float16x4_t
|
||||
#define MS_MOVQ_F16 vmovq_n_f16
|
||||
#define MS_STQ_F16 vst1q_f16
|
||||
#define MS_ST_F16 vst1_f16
|
||||
#define MS_MINQ_F16 vminq_f16
|
||||
#define MS_MAXQ_F16 vmaxq_f16
|
||||
#define MS_LDQ_F16 vld1q_f16
|
||||
#define MS_LD_F16 vld1_f16
|
||||
#define MS_ADDQ_F16 vaddq_f16
|
||||
#define MS_SUBQ_F16 vsubq_f16
|
||||
#define MS_MULQ_F16 vmulq_f16
|
||||
|
|
|
@ -70,26 +70,18 @@ void MatmulBaseFP16CPUKernel::InitParameter() {
|
|||
}
|
||||
|
||||
int MatmulBaseFP16CPUKernel::InitBias() {
|
||||
if (in_tensors_.size() == 3) {
|
||||
auto bias_tensor = in_tensors_[2];
|
||||
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C8NUM);
|
||||
if (params_->col_ != 0 && bias_ptr_ == nullptr) {
|
||||
int max_bias_data = UP_ROUND(params_->col_, C8NUM);
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(max_bias_data * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
bias_ptr_ = reinterpret_cast<float16_t *>(malloc(max_bias_data * sizeof(float16_t)));
|
||||
if (bias_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
|
||||
if (bias_tensor->data_type() == kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "Matmul fp16 only support fp16 weight";
|
||||
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
|
||||
return RET_ERROR;
|
||||
} else if (bias_tensor->data_type() == kNumberTypeFloat16) {
|
||||
MS_ASSERT(bias_tensor->data_c() != nullptr);
|
||||
}
|
||||
if (in_tensors_.size() == 3) {
|
||||
auto bias_tensor = in_tensors_[2];
|
||||
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float16_t));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported bias data type : " << bias_tensor->data_type();
|
||||
return RET_ERROR;
|
||||
memset(bias_ptr_, 0, max_bias_data * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -181,13 +173,13 @@ void MatmulBaseFP16CPUKernel::InitMatrixA(void *src_ptr) {
|
|||
float16_t *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_;
|
||||
if (params_->a_transpose_) {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Row16MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32);
|
||||
RowMajor2RowNMajorFp16((const float16_t *)src, dst, params_->deep_, params_->row_);
|
||||
#else
|
||||
RowMajor2Row12MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32);
|
||||
RowMajor2ColNMajorFp16((const float16_t *)src, dst, params_->row_, params_->deep_);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32);
|
||||
#endif
|
||||
|
@ -281,7 +273,7 @@ int MatmulBaseFP16CPUKernel::RunImpl(int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + thread_stride_ * task_id;
|
||||
auto bias = bias_ptr_ + thread_stride_ * task_id;
|
||||
auto b = batch_b_ptr_ + task_id * thread_stride_ * params_->deep_;
|
||||
auto c = batch_c_ptr_ + task_id * thread_stride_;
|
||||
|
||||
|
@ -292,8 +284,13 @@ int MatmulBaseFP16CPUKernel::RunImpl(int task_id) {
|
|||
MatVecMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulBaseFp16Neon(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc,
|
||||
params_->col_, OutType_Nhwc);
|
||||
#else
|
||||
MatMulFp16(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_,
|
||||
OutType_Nhwc);
|
||||
#endif
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ void MatmulFP16CPUKernel::InitBShape() {
|
|||
|
||||
int MatmulFP16CPUKernel::Init() {
|
||||
#ifdef ENABLE_ARM64
|
||||
row_tile_ = C16NUM;
|
||||
row_tile_ = C4NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue