!9051 [MS][LITE][Develop]optimization for int8 matmul kernel on arm32

From: @lx0095
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-11-26 15:29:56 +08:00 committed by Gitee
commit 247c610c57
1 changed files with 239 additions and 295 deletions

View File

@ -12,335 +12,279 @@
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
// int *filter_zp);
// #-52: a, #-48: b, #-44: dst, #-40: row
// #-48: a, #-44: b, #-40: dst, #-36: row
// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp
MatmulInt8Opt:
push {r0-r11, lr}
vpush {q4-q7}
add sp, sp, #116
push {r0-r8, r10, r11, lr}
vpush {q4-q7}
add sp, sp, #112
ldr r0, [sp, #-52] // load a ptr
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r5, [sp, #4]
ldr r6, [sp, #8] // reload a_sums ptr
ldr r8, [sp, #40]
mov r10, #4
mul r10, r10, r5 // lhs step
mov r11, #4
mul r11, r11, r8 // dst step
LoopRow:
ldr r1, [sp, #-44] //reload rhs ptr
ldr r4, [sp] // reload rhs col
ldr lr, [sp, #-40]
vmov.32 d4[0], lr // reload dst ptr
ldr lr, [sp, #32]
vmov.32 d4[1], lr // reload left shift
ldr lr, [sp, #28]
vmov.32 d5[0], lr // reload multiplier
ldr lr, [sp, #36]
vmov.32 d5[1], lr // reload right shift
ldr r7, [sp, #48] // reload filter_zp
ldr r12, [sp, #12] // reload bias ptr
ldr r1, [sp, #-48] // load b ptr
vld1.8 {d8, d9, d10, d11}, [r1]!
LoopCol:
vmov.32 r2, d4[0] // reload dst ptr
ldr r0, [sp, #-48] //reload lhs ptr
ldr r5, [sp, #4] // reaload depth
ldr r4, [sp] // col
ldr r7, [sp, #40] // output stride
mov r8, #0 // output channels offset
ldr r10, [sp, #44]
cmp r10, #0
beq L1
ldr r6, [sp, #8] // load intpu_sums ptr if per_channel
L1:
cmp r4, #0 // if at the end of col
ble End1
vmov.i32 q6, #0
vmov.i32 q7, #0
vmov.i32 q8, #0
vmov.i32 q9, #0
vmov.i32 q10, #0
vmov.i32 q11, #0
vmov.i32 q12, #0
vmov.i32 q13, #0
ldr r3, [sp, #-40] // reset row counter
ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor
L2:
cmp r3, #0 // if at the end of row
ble End2
LoopDepth:
vld1.8 {q0-q1}, [r0]!
vld1.8 {q4-q5}, [r1]!
vmull.s8 q14, d0, d8
vmull.s8 q15, d2, d8
vmlal.s8 q14, d1, d9
vmlal.s8 q15, d3, d9
vpadal.s16 q6, q14
vpadal.s16 q8, q15
vmull.s8 q14, d0, d10
vmull.s8 q15, d2, d10
vmlal.s8 q14, d1, d11
vmlal.s8 q15, d3, d11
vld1.8 {q0-q1}, [r0]!
ldr r5, [sp, #4] // reset deep16
sub r5, r5, #16
vmov.i32 q6, #0
vmov.i32 q7, #0
vmov.i32 q8, #0
vmov.i32 q9, #0
vmov.i32 q10, #0
vmov.i32 q11, #0
vmov.i32 q12, #0
vmov.i32 q13, #0
vpadal.s16 q7, q14
vpadal.s16 q9, q15
cmp r5, #0
beq L3Tail
L3:
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
vmull.s8 q14, d0, d8
vmull.s8 q15, d2, d8
vmlal.s8 q14, d1, d9
vmlal.s8 q15, d3, d9
vpadal.s16 q10, q14
vpadal.s16 q12, q15
vmull.s8 q14, d0, d10
vmull.s8 q15, d2, d10
vmlal.s8 q14, d1, d11
vmlal.s8 q15, d3, d11
vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3
vpadal.s16 q11, q14
vpadal.s16 q13, q15
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
cmp r5, #16
ble LoopDepthEnd
sub r5, r5, #16
b LoopDepth
vpadal.s16 q10, q14
vld1.8 {d8, d9, d10, d11}, [r1]!
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
sub r5, r5, #16 // deep16 -= 16
cmp r5, #0
bgt L3
L3Tail:
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vld1.8 {d0, d1, d2, d3}, [r0]!
LoopDepthEnd:
vpadd.i32 d12, d12, d13
vpadd.i32 d14, d14, d15
vpadd.i32 d16, d16, d17
vpadd.i32 d18, d18, d19
vpadd.i32 d20, d20, d21
vpadd.i32 d22, d22, d23
vpadd.i32 d24, d24, d25
vpadd.i32 d26, d26, d27
vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3
vpadd.i32 d28, d12, d14
vpadd.i32 d29, d16, d18
vpadd.i32 d30, d20, d22
vpadd.i32 d31, d24, d26
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
Bias:
cmp r12, #0
beq NoBias
vld1.32 {d26}, [r12]!
vadd.i32 d28, d28, d26
vadd.i32 d29, d29, d26
vadd.i32 d30, d30, d26
vadd.i32 d31, d31, d26
vpadal.s16 q10, q14
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
NoBias:
ldr lr, [sp, #44]
cmp lr, #0
bne PerChannel
vpadd.i32 d0, d12, d13
vpadd.i32 d1, d14, d15
vpadd.i32 d2, d16, d17
vpadd.i32 d3, d18, d19
vpadd.i32 d4, d20, d21
vpadd.i32 d5, d22, d23
vpadd.i32 d6, d24, d25
vpadd.i32 d7, d26, d27
PerTensor:
vld1.32 {d24, d25}, [r6]
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
vsub.s32 d28, d28, d20
vsub.s32 d29, d29, d21
vsub.s32 d30, d30, d22
vsub.s32 d31, d31, d23
vpadd.i32 d28, d0, d1
vpadd.i32 d29, d2, d3
vpadd.i32 d30, d4, d5
vpadd.i32 d31, d6, d7
vmov.32 lr, d4[1]
vld1.32 {q9[]}, [lr]
vshl.s32 q14, q14, q9
vshl.s32 q15, q15, q9
cmp r3, #4
ble LAST_ROW
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r1, [sp, #-48] // reload b ptr
vld1.8 {d8, d9, d10, d11}, [r1]!
b AddWeightBias
vmov.32 lr, d5[0]
vld1.32 {q8[]}, [lr]
vqrdmulh.s32 q14, q14, q8
vqrdmulh.s32 q15, q15, q8
LAST_ROW:
ldr r0, [sp, #-52] // reload a ptr
vld1.8 {d0, d1, d2, d3}, [r0]!
ldr r1, [sp, #-48] // reload b ptr
ldr r9, [sp, #4]
mov r10, #2
mul r9, r9, r10 // the stride of b
add r1, r1, r9 // b ptr + stride
str r1, [sp, #-48]
vld1.8 {d8, d9, d10, d11}, [r1]!
vmov.32 lr, d5[1]
vld1.32 {q7[]}, [lr]
vand q6, q7, q14
vshr.s32 q6, q6, #31
vqadd.s32 q14, q14, q6
vrshl.s32 q14, q14, q7
vand q5, q7, q15
vshr.s32 q5, q5, #31
vqadd.s32 q15, q15, q5
vrshl.s32 q15, q15, q7
b Quantize
AddWeightBias:
ldr r9, [sp, #12] // reload weight_bias ptr
add r9, r9, r8
vld1.32 {d26}, [r9]!
vadd.i32 d28, d28, d26
vadd.i32 d29, d29, d26
vadd.i32 d30, d30, d26
vadd.i32 d31, d31, d26
PerChannel:
vld1.32 {d24, d25}, [r6]
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
vld1.32 {d19}, [r7]!
vmul.s32 d24, d20, d19
vmul.s32 d25, d21, d19
vmul.s32 d26, d22, d19
vmul.s32 d27, d23, d19
vsub.s32 d28, d28, d24
vsub.s32 d29, d29, d25
vsub.s32 d30, d30, d26
vsub.s32 d31, d31, d27
ldr r10, [sp, #44]
cmp r10, #0
bgt PerChannel
vmov.32 lr, d4[1]
vld1.32 {d23}, [lr]!
vmov.32 d4[1], lr
vshl.s32 d28, d28, d23
vshl.s32 d29, d29, d23
vshl.s32 d30, d30, d23
vshl.s32 d31, d31, d23
PerTensor:
// Substract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
vsub.s32 d28, d28, d20
vsub.s32 d29, d29, d21
vsub.s32 d30, d30, d22
vsub.s32 d31, d31, d23
vmov.32 lr, d5[0]
vld1.32 {d22}, [lr]!
vmov.32 d5[0], lr
vqrdmulh.s32 d28, d28, d22
vqrdmulh.s32 d29, d29, d22
vqrdmulh.s32 d30, d30, d22
vqrdmulh.s32 d31, d31, d22
// Apply left shift
ldr r10, [sp, #32]
ldr r11, [r10]!
vdup.32 q9, r11
vshl.s32 q14, q14, q9
vshl.s32 q15, q15, q9
vmov.32 lr, d5[1]
vld1.32 {d21}, [lr]!
vmov.32 d5[1], lr
vand d20, d21, d28
vshr.s32 d20, d20, #31
vqadd.s32 d28, d28, d20
vrshl.s32 d28, d28, d21
vand d19, d21, d29
vshr.s32 d19, d19, #31
vqadd.s32 d29, d29, d19
vrshl.s32 d29, d29, d21
vand d18, d21, d30
vshr.s32 d18, d18, #31
vqadd.s32 d30, d30, d18
vrshl.s32 d30, d30, d21
vand d17, d21, d31
vshr.s32 d17, d17, #31
vqadd.s32 d31, d31, d17
vrshl.s32 d31, d31, d21
// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
ldr r11, [r10]
vdup.32 q8, r11
vqrdmulh.s32 q14, q14, q8
vqrdmulh.s32 q15, q15, q8
Quantize:
ldr lr, [sp, #24]
vdup.32 q0, lr
vadd.i32 q14, q14, q0
vadd.i32 q15, q15, q0
// Apply right shift
ldr r10, [sp, #36]
ldr r11, [r10]
vdup.32 q7, r11
vand q6, q7, q14
vshr.s32 q6, q6, #31
vqadd.s32 q14, q14, q6
vrshl.s32 q14, q14, q7
vand q3, q7, q15
vshr.s32 q3, q3, #31
vqadd.s32 q15, q15, q3
vrshl.s32 q15, q15, q7
b AddDstZP
ldr lr, [sp, #16]
vdup.32 q1, lr
vmax.s32 q14, q14, q1
vmax.s32 q15, q15, q1
PerChannel:
// Substract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
ldr r10, [sp, #48]
vld1.32 {d19}, [r10]
vmul.s32 d24, d20, d19
vmul.s32 d25, d21, d19
vmul.s32 d26, d22, d19
vmul.s32 d27, d23, d19
vsub.s32 d28, d28, d24
vsub.s32 d29, d29, d25
vsub.s32 d30, d30, d26
vsub.s32 d31, d31, d27
ldr lr, [sp, #20]
vdup.32 q0, lr
vmin.s32 q14, q14, q0
vmin.s32 q15, q15, q0
// Apply left shift
ldr r10, [sp, #32]
add r10, r10, r8
vld1.32 {d23}, [r10]
vshl.s32 d28, d28, d23
vshl.s32 d29, d29, d23
vshl.s32 d30, d30, d23
vshl.s32 d31, d31, d23
vqmovn.s32 d28, q14
vqmovn.s32 d29, q15
// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
add r10, r10, r8
vld1.32 {d22}, [r10]
vqrdmulh.s32 d28, d28, d22
vqrdmulh.s32 d29, d29, d22
vqrdmulh.s32 d30, d30, d22
vqrdmulh.s32 d31, d31, d22
vqmovn.s16 d30, q14
// Apply right shift
ldr r10, [sp, #36]
add r10, r10, r8
vld1.32 {d21}, [r10]
vand d20, d21, d28
vshr.s32 d20, d20, #31
vqadd.s32 d28, d28, d20
vrshl.s32 d28, d28, d21
vand d19, d21, d29
vshr.s32 d19, d19, #31
vqadd.s32 d29, d29, d19
vrshl.s32 d29, d29, d21
vand d18, d21, d30
vshr.s32 d18, d18, #31
vqadd.s32 d30, d30, d18
vrshl.s32 d30, d30, d21
vand d17, d21, d31
vshr.s32 d17, d17, #31
vqadd.s32 d31, d31, d17
vrshl.s32 d31, d31, d21
cmp r4, #1
beq Write1
b Write2
AddDstZP:
// Add the destination zero point
ldr r10, [sp, #24]
vdup.32 q2, r10
vadd.i32 q14, q14, q2
vadd.i32 q15, q15, q2
Write1:
vmov.32 lr, d4[0]
add lr, lr, #1
vmov.32 d4[0], lr
vst1.8 {d30[0]}, [r2], r8
cmp r3, #1
beq WriteEnd
vst1.8 {d30[2]}, [r2], r8
cmp r3, #2
beq WriteEnd
vst1.8 {d30[4]}, [r2], r8
cmp r3, #3
beq WriteEnd
vst1.8 {d30[6]}, [r2], r8
b WriteEnd
// Apply the act_min bound
ldr r10, [sp, #16]
vdup.32 q3, r10
vmax.s32 q14, q14, q3
vmax.s32 q15, q15, q3
Write2:
vmov.32 lr, d4[0]
add lr, lr, #2
vmov.32 d4[0], lr
vst1.16 {d30[0]}, [r2], r8
cmp r3, #1
beq WriteEnd
vst1.16 {d30[1]}, [r2], r8
cmp r3, #2
beq WriteEnd
vst1.16 {d30[2]}, [r2], r8
cmp r3, #3
beq WriteEnd
vst1.16 {d30[3]}, [r2], r8
// Apply the act_max bound
ldr r10, [sp, #20]
vdup.32 q2, r10
vmin.s32 q14, q14, q2
vmin.s32 q15, q15, q2
WriteEnd:
cmp r4, #2
ble LoopColEnd
sub r4, r4, #2
b LoopCol
// Cast-and-saturate from int32 to int16
vqmovn.s32 d28, q14
vqmovn.s32 d29, q15
LoopColEnd:
cmp r3, #4
ble LoopRowEnd
ldr lr, [sp, #-48]
add lr, lr, r10
str lr, [sp, #-48]
ldr lr, [sp, #-40]
add lr, lr, r11
str lr, [sp, #-40]
sub r3, r3, #4
add r6, r6, #16
b LoopRow
// Cast-and-saturate from int16 to int8
vqmovn.s16 d30, q14
// start to write
cmp r4, #2
bge WriteCol2
cmp r4, #1
beq WriteCol1
b EndWrite
WriteCol2:
vst1.16 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.16 {d30[1]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.16 {d30[2]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.16 {d30[3]}, [r2], r7
b EndWrite
WriteCol1:
vst1.8 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.8 {d30[2]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.8 {d30[4]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.8 {d30[6]}, [r2], r7
b EndWrite
EndWrite:
sub r3, r3, #4 // a row counter -= 4
b L2
End2:
sub r4, r4, #2 // b col counter -= 2
ldr r2, [sp, #-44] // load dst ptr
add r2, r2, #2 // dst ptr + offset
str r2, [sp, #-44]
ldr r10, [sp, #48]
add r10, r10, #8
str r10, [sp, #48]
add r8, r8, #8 // output channels offset + 2*sizeof(int)
b L1
End1:
sub sp, sp, #116
vpop {q4-q7}
pop {r0-r11, pc}
LoopRowEnd:
sub sp, sp, #112
vpop {q4-q7}
pop {r0-r8, r10, r11, pc}
#endif
#endif