forked from mindspore-Ecosystem/mindspore
!8960 [MSLITE] Optimize int8 matmul for arm32
From: @zhanyuan1 Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
26c758bfdb
|
@ -8,7 +8,7 @@
|
|||
.type MatmulInt8Neon32Opt, %function
|
||||
#endif
|
||||
|
||||
//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
|
||||
//void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
|
||||
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
|
||||
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
|
||||
// int *filter_zp);
|
||||
|
@ -21,6 +21,12 @@ MatmulInt8Neon32Opt:
|
|||
vpush {q4-q7}
|
||||
add sp, sp, #116
|
||||
|
||||
ldr r0, [sp, #-52] // load a ptr
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
ldr r1, [sp, #-48] // load b ptr
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
|
||||
ldr r4, [sp] // col
|
||||
ldr r7, [sp, #40] // output stride
|
||||
mov r8, #0 // output channels offset
|
||||
|
@ -32,15 +38,14 @@ L1:
|
|||
cmp r4, #0 // if at the end of col
|
||||
ble End1
|
||||
|
||||
ldr r0, [sp, #-52] // reload a ptr
|
||||
ldr r3, [sp, #-40] // reset row counter
|
||||
ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor
|
||||
L2:
|
||||
cmp r3, #0 // if at the end of row
|
||||
ble End2
|
||||
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
ldr r5, [sp, #4] // reset deep16
|
||||
sub r5, r5, #16
|
||||
vmov.i32 q6, #0
|
||||
vmov.i32 q7, #0
|
||||
vmov.i32 q8, #0
|
||||
|
@ -49,12 +54,10 @@ L2:
|
|||
vmov.i32 q11, #0
|
||||
vmov.i32 q12, #0
|
||||
vmov.i32 q13, #0
|
||||
L3:
|
||||
cmp r5, #0
|
||||
beq End3
|
||||
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
cmp r5, #0
|
||||
beq L3Tail
|
||||
L3:
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
|
@ -63,13 +66,47 @@ L3:
|
|||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
vpadal.s16 q6, q14
|
||||
vpadal.s16 q7, q2
|
||||
vpadal.s16 q8, q15
|
||||
vpadal.s16 q9, q3
|
||||
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
vpadal.s16 q10, q14
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
vpadal.s16 q11, q2
|
||||
vpadal.s16 q12, q15
|
||||
vpadal.s16 q13, q3
|
||||
sub r5, r5, #16 // deep16 -= 16
|
||||
cmp r5, #0
|
||||
bgt L3
|
||||
L3Tail:
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
vpadal.s16 q6, q14
|
||||
vpadal.s16 q7, q2
|
||||
vpadal.s16 q8, q15
|
||||
vpadal.s16 q9, q3
|
||||
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
|
@ -83,10 +120,7 @@ L3:
|
|||
vpadal.s16 q11, q2
|
||||
vpadal.s16 q12, q15
|
||||
vpadal.s16 q13, q3
|
||||
sub r5, r5, #16 // deep16 -= 16
|
||||
b L3
|
||||
|
||||
End3:
|
||||
vpadd.i32 d0, d12, d13
|
||||
vpadd.i32 d1, d14, d15
|
||||
vpadd.i32 d2, d16, d17
|
||||
|
@ -101,7 +135,26 @@ End3:
|
|||
vpadd.i32 d30, d4, d5
|
||||
vpadd.i32 d31, d6, d7
|
||||
|
||||
// Add weight_bias
|
||||
cmp r3, #4
|
||||
ble LAST_ROW
|
||||
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
b AddWeightBias
|
||||
|
||||
LAST_ROW:
|
||||
ldr r0, [sp, #-52] // reload a ptr
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
ldr r9, [sp, #4]
|
||||
mov r10, #2
|
||||
mul r9, r9, r10 // the stride of b
|
||||
add r1, r1, r9 // b ptr + stride
|
||||
str r1, [sp, #-48]
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
|
||||
AddWeightBias:
|
||||
ldr r9, [sp, #12] // reload weight_bias ptr
|
||||
add r9, r9, r8
|
||||
vld1.32 {d26}, [r9]!
|
||||
|
@ -148,9 +201,9 @@ PerTensor:
|
|||
vshr.s32 q6, q6, #31
|
||||
vqadd.s32 q14, q14, q6
|
||||
vrshl.s32 q14, q14, q7
|
||||
vand q5, q7, q15
|
||||
vshr.s32 q5, q5, #31
|
||||
vqadd.s32 q15, q15, q5
|
||||
vand q3, q7, q15
|
||||
vshr.s32 q3, q3, #31
|
||||
vqadd.s32 q15, q15, q3
|
||||
vrshl.s32 q15, q15, q7
|
||||
b AddDstZP
|
||||
|
||||
|
@ -214,9 +267,9 @@ PerChannel:
|
|||
AddDstZP:
|
||||
// Add the destination zero point
|
||||
ldr r10, [sp, #24]
|
||||
vdup.32 q4, r10
|
||||
vadd.i32 q14, q14, q4
|
||||
vadd.i32 q15, q15, q4
|
||||
vdup.32 q2, r10
|
||||
vadd.i32 q14, q14, q2
|
||||
vadd.i32 q15, q15, q2
|
||||
|
||||
// Apply the act_min bound
|
||||
ldr r10, [sp, #16]
|
||||
|
@ -276,12 +329,6 @@ EndWrite:
|
|||
|
||||
End2:
|
||||
sub r4, r4, #2 // b col counter -= 2
|
||||
ldr r1, [sp, #-48] // load b ptr
|
||||
ldr r9, [sp, #4]
|
||||
mov r10, #2
|
||||
mul r9, r9, r10 // the stride of b
|
||||
add r1, r1, r9 // b ptr + stride
|
||||
str r1, [sp, #-48]
|
||||
ldr r2, [sp, #-44] // load dst ptr
|
||||
add r2, r2, #2 // dst ptr + offset
|
||||
str r2, [sp, #-44]
|
||||
|
|
Loading…
Reference in New Issue