!8960 [MSLITE] Optimize int8 matmul for arm32

From: @zhanyuan1
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-11-24 20:11:21 +08:00 committed by Gitee
commit 26c758bfdb
1 changed files with 74 additions and 27 deletions

View File

@ -8,7 +8,7 @@
.type MatmulInt8Neon32Opt, %function
#endif
//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
//void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
// int *filter_zp);
@ -21,6 +21,12 @@ MatmulInt8Neon32Opt:
vpush {q4-q7}
add sp, sp, #116
ldr r0, [sp, #-52] // load a ptr
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r1, [sp, #-48] // load b ptr
vld1.8 {d8, d9, d10, d11}, [r1]!
ldr r4, [sp] // col
ldr r7, [sp, #40] // output stride
mov r8, #0 // output channels offset
@ -32,15 +38,14 @@ L1:
cmp r4, #0 // if at the end of col
ble End1
ldr r0, [sp, #-52] // reload a ptr
ldr r3, [sp, #-40] // reset row counter
ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor
L2:
cmp r3, #0 // if at the end of row
ble End2
ldr r1, [sp, #-48] // reload b ptr
ldr r5, [sp, #4] // reset deep16
sub r5, r5, #16
vmov.i32 q6, #0
vmov.i32 q7, #0
vmov.i32 q8, #0
@ -49,12 +54,10 @@ L2:
vmov.i32 q11, #0
vmov.i32 q12, #0
vmov.i32 q13, #0
L3:
cmp r5, #0
beq End3
vld1.8 {d0, d1, d2, d3}, [r0]!
vld1.8 {d8, d9, d10, d11}, [r1]!
cmp r5, #0
beq L3Tail
L3:
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
@ -63,13 +66,47 @@ L3:
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
vpadal.s16 q10, q14
vld1.8 {d8, d9, d10, d11}, [r1]!
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
sub r5, r5, #16 // deep16 -= 16
cmp r5, #0
bgt L3
L3Tail:
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
@ -83,10 +120,7 @@ L3:
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
sub r5, r5, #16 // deep16 -= 16
b L3
End3:
vpadd.i32 d0, d12, d13
vpadd.i32 d1, d14, d15
vpadd.i32 d2, d16, d17
@ -101,7 +135,26 @@ End3:
vpadd.i32 d30, d4, d5
vpadd.i32 d31, d6, d7
// Add weight_bias
cmp r3, #4
ble LAST_ROW
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r1, [sp, #-48] // reload b ptr
vld1.8 {d8, d9, d10, d11}, [r1]!
b AddWeightBias
LAST_ROW:
ldr r0, [sp, #-52] // reload a ptr
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r1, [sp, #-48] // reload b ptr
ldr r9, [sp, #4]
mov r10, #2
mul r9, r9, r10 // the stride of b
add r1, r1, r9 // b ptr + stride
str r1, [sp, #-48]
vld1.8 {d8, d9, d10, d11}, [r1]!
AddWeightBias:
ldr r9, [sp, #12] // reload weight_bias ptr
add r9, r9, r8
vld1.32 {d26}, [r9]!
@ -148,9 +201,9 @@ PerTensor:
vshr.s32 q6, q6, #31
vqadd.s32 q14, q14, q6
vrshl.s32 q14, q14, q7
vand q5, q7, q15
vshr.s32 q5, q5, #31
vqadd.s32 q15, q15, q5
vand q3, q7, q15
vshr.s32 q3, q3, #31
vqadd.s32 q15, q15, q3
vrshl.s32 q15, q15, q7
b AddDstZP
@ -214,9 +267,9 @@ PerChannel:
AddDstZP:
// Add the destination zero point
ldr r10, [sp, #24]
vdup.32 q4, r10
vadd.i32 q14, q14, q4
vadd.i32 q15, q15, q4
vdup.32 q2, r10
vadd.i32 q14, q14, q2
vadd.i32 q15, q15, q2
// Apply the act_min bound
ldr r10, [sp, #16]
@ -276,12 +329,6 @@ EndWrite:
End2:
sub r4, r4, #2 // b col counter -= 2
ldr r1, [sp, #-48] // load b ptr
ldr r9, [sp, #4]
mov r10, #2
mul r9, r9, r10 // the stride of b
add r1, r1, r9 // b ptr + stride
str r1, [sp, #-48]
ldr r2, [sp, #-44] // load dst ptr
add r2, r2, #2 // dst ptr + offset
str r2, [sp, #-44]