forked from mindspore-Ecosystem/mindspore
!9051 [MS][LITE][Develop]optimization for int8 matmul kernel on arm32
From: @lx0095 Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
247c610c57
|
@ -12,40 +12,41 @@
|
|||
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
|
||||
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
|
||||
// int *filter_zp);
|
||||
// #-52: a, #-48: b, #-44: dst, #-40: row
|
||||
// #-48: a, #-44: b, #-40: dst, #-36: row
|
||||
// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
|
||||
// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp
|
||||
|
||||
MatmulInt8Opt:
|
||||
push {r0-r11, lr}
|
||||
push {r0-r8, r10, r11, lr}
|
||||
vpush {q4-q7}
|
||||
add sp, sp, #116
|
||||
add sp, sp, #112
|
||||
|
||||
ldr r0, [sp, #-52] // load a ptr
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
ldr r5, [sp, #4]
|
||||
ldr r6, [sp, #8] // reload a_sums ptr
|
||||
ldr r8, [sp, #40]
|
||||
mov r10, #4
|
||||
mul r10, r10, r5 // lhs step
|
||||
mov r11, #4
|
||||
mul r11, r11, r8 // dst step
|
||||
LoopRow:
|
||||
ldr r1, [sp, #-44] //reload rhs ptr
|
||||
ldr r4, [sp] // reload rhs col
|
||||
ldr lr, [sp, #-40]
|
||||
vmov.32 d4[0], lr // reload dst ptr
|
||||
ldr lr, [sp, #32]
|
||||
vmov.32 d4[1], lr // reload left shift
|
||||
ldr lr, [sp, #28]
|
||||
vmov.32 d5[0], lr // reload multiplier
|
||||
ldr lr, [sp, #36]
|
||||
vmov.32 d5[1], lr // reload right shift
|
||||
ldr r7, [sp, #48] // reload filter_zp
|
||||
ldr r12, [sp, #12] // reload bias ptr
|
||||
|
||||
ldr r1, [sp, #-48] // load b ptr
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
LoopCol:
|
||||
vmov.32 r2, d4[0] // reload dst ptr
|
||||
ldr r0, [sp, #-48] //reload lhs ptr
|
||||
ldr r5, [sp, #4] // reaload depth
|
||||
|
||||
ldr r4, [sp] // col
|
||||
ldr r7, [sp, #40] // output stride
|
||||
mov r8, #0 // output channels offset
|
||||
ldr r10, [sp, #44]
|
||||
cmp r10, #0
|
||||
beq L1
|
||||
ldr r6, [sp, #8] // load intpu_sums ptr if per_channel
|
||||
L1:
|
||||
cmp r4, #0 // if at the end of col
|
||||
ble End1
|
||||
|
||||
ldr r3, [sp, #-40] // reset row counter
|
||||
ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor
|
||||
L2:
|
||||
cmp r3, #0 // if at the end of row
|
||||
ble End2
|
||||
|
||||
ldr r5, [sp, #4] // reset deep16
|
||||
sub r5, r5, #16
|
||||
vmov.i32 q6, #0
|
||||
vmov.i32 q7, #0
|
||||
vmov.i32 q8, #0
|
||||
|
@ -55,121 +56,74 @@ L2:
|
|||
vmov.i32 q12, #0
|
||||
vmov.i32 q13, #0
|
||||
|
||||
cmp r5, #0
|
||||
beq L3Tail
|
||||
L3:
|
||||
LoopDepth:
|
||||
vld1.8 {q0-q1}, [r0]!
|
||||
vld1.8 {q4-q5}, [r1]!
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
vpadal.s16 q6, q14
|
||||
vpadal.s16 q7, q2
|
||||
vpadal.s16 q8, q15
|
||||
vpadal.s16 q9, q3
|
||||
vmull.s8 q14, d0, d10
|
||||
vmull.s8 q15, d2, d10
|
||||
vmlal.s8 q14, d1, d11
|
||||
vmlal.s8 q15, d3, d11
|
||||
vld1.8 {q0-q1}, [r0]!
|
||||
|
||||
vpadal.s16 q7, q14
|
||||
vpadal.s16 q9, q15
|
||||
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
|
||||
vpadal.s16 q10, q14
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
vpadal.s16 q11, q2
|
||||
vpadal.s16 q12, q15
|
||||
vpadal.s16 q13, q3
|
||||
sub r5, r5, #16 // deep16 -= 16
|
||||
cmp r5, #0
|
||||
bgt L3
|
||||
L3Tail:
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
vmull.s8 q14, d0, d10
|
||||
vmull.s8 q15, d2, d10
|
||||
vmlal.s8 q14, d1, d11
|
||||
vmlal.s8 q15, d3, d11
|
||||
|
||||
vpadal.s16 q6, q14
|
||||
vpadal.s16 q7, q2
|
||||
vpadal.s16 q8, q15
|
||||
vpadal.s16 q9, q3
|
||||
vpadal.s16 q11, q14
|
||||
vpadal.s16 q13, q15
|
||||
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
cmp r5, #16
|
||||
ble LoopDepthEnd
|
||||
sub r5, r5, #16
|
||||
b LoopDepth
|
||||
|
||||
vpadal.s16 q10, q14
|
||||
vpadal.s16 q11, q2
|
||||
vpadal.s16 q12, q15
|
||||
vpadal.s16 q13, q3
|
||||
LoopDepthEnd:
|
||||
vpadd.i32 d12, d12, d13
|
||||
vpadd.i32 d14, d14, d15
|
||||
vpadd.i32 d16, d16, d17
|
||||
vpadd.i32 d18, d18, d19
|
||||
vpadd.i32 d20, d20, d21
|
||||
vpadd.i32 d22, d22, d23
|
||||
vpadd.i32 d24, d24, d25
|
||||
vpadd.i32 d26, d26, d27
|
||||
|
||||
vpadd.i32 d0, d12, d13
|
||||
vpadd.i32 d1, d14, d15
|
||||
vpadd.i32 d2, d16, d17
|
||||
vpadd.i32 d3, d18, d19
|
||||
vpadd.i32 d4, d20, d21
|
||||
vpadd.i32 d5, d22, d23
|
||||
vpadd.i32 d6, d24, d25
|
||||
vpadd.i32 d7, d26, d27
|
||||
vpadd.i32 d28, d12, d14
|
||||
vpadd.i32 d29, d16, d18
|
||||
vpadd.i32 d30, d20, d22
|
||||
vpadd.i32 d31, d24, d26
|
||||
|
||||
vpadd.i32 d28, d0, d1
|
||||
vpadd.i32 d29, d2, d3
|
||||
vpadd.i32 d30, d4, d5
|
||||
vpadd.i32 d31, d6, d7
|
||||
|
||||
cmp r3, #4
|
||||
ble LAST_ROW
|
||||
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
b AddWeightBias
|
||||
|
||||
LAST_ROW:
|
||||
ldr r0, [sp, #-52] // reload a ptr
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
ldr r9, [sp, #4]
|
||||
mov r10, #2
|
||||
mul r9, r9, r10 // the stride of b
|
||||
add r1, r1, r9 // b ptr + stride
|
||||
str r1, [sp, #-48]
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
|
||||
AddWeightBias:
|
||||
ldr r9, [sp, #12] // reload weight_bias ptr
|
||||
add r9, r9, r8
|
||||
vld1.32 {d26}, [r9]!
|
||||
Bias:
|
||||
cmp r12, #0
|
||||
beq NoBias
|
||||
vld1.32 {d26}, [r12]!
|
||||
vadd.i32 d28, d28, d26
|
||||
vadd.i32 d29, d29, d26
|
||||
vadd.i32 d30, d30, d26
|
||||
vadd.i32 d31, d31, d26
|
||||
|
||||
ldr r10, [sp, #44]
|
||||
cmp r10, #0
|
||||
bgt PerChannel
|
||||
NoBias:
|
||||
ldr lr, [sp, #44]
|
||||
cmp lr, #0
|
||||
bne PerChannel
|
||||
|
||||
PerTensor:
|
||||
// Substract input_sums
|
||||
vld1.32 {d24, d25}, [r6]!
|
||||
vld1.32 {d24, d25}, [r6]
|
||||
vdup.32 d20, d24[0]
|
||||
vdup.32 d21, d24[1]
|
||||
vdup.32 d22, d25[0]
|
||||
|
@ -179,43 +133,35 @@ PerTensor:
|
|||
vsub.s32 d30, d30, d22
|
||||
vsub.s32 d31, d31, d23
|
||||
|
||||
// Apply left shift
|
||||
ldr r10, [sp, #32]
|
||||
ldr r11, [r10]!
|
||||
vdup.32 q9, r11
|
||||
vmov.32 lr, d4[1]
|
||||
vld1.32 {q9[]}, [lr]
|
||||
vshl.s32 q14, q14, q9
|
||||
vshl.s32 q15, q15, q9
|
||||
|
||||
// Apply the fixed-point part of the multiplier
|
||||
ldr r10, [sp, #28]
|
||||
ldr r11, [r10]
|
||||
vdup.32 q8, r11
|
||||
vmov.32 lr, d5[0]
|
||||
vld1.32 {q8[]}, [lr]
|
||||
vqrdmulh.s32 q14, q14, q8
|
||||
vqrdmulh.s32 q15, q15, q8
|
||||
|
||||
// Apply right shift
|
||||
ldr r10, [sp, #36]
|
||||
ldr r11, [r10]
|
||||
vdup.32 q7, r11
|
||||
vmov.32 lr, d5[1]
|
||||
vld1.32 {q7[]}, [lr]
|
||||
vand q6, q7, q14
|
||||
vshr.s32 q6, q6, #31
|
||||
vqadd.s32 q14, q14, q6
|
||||
vrshl.s32 q14, q14, q7
|
||||
vand q3, q7, q15
|
||||
vshr.s32 q3, q3, #31
|
||||
vqadd.s32 q15, q15, q3
|
||||
vand q5, q7, q15
|
||||
vshr.s32 q5, q5, #31
|
||||
vqadd.s32 q15, q15, q5
|
||||
vrshl.s32 q15, q15, q7
|
||||
b AddDstZP
|
||||
b Quantize
|
||||
|
||||
PerChannel:
|
||||
// Substract input_sums
|
||||
vld1.32 {d24, d25}, [r6]!
|
||||
vld1.32 {d24, d25}, [r6]
|
||||
vdup.32 d20, d24[0]
|
||||
vdup.32 d21, d24[1]
|
||||
vdup.32 d22, d25[0]
|
||||
vdup.32 d23, d25[1]
|
||||
ldr r10, [sp, #48]
|
||||
vld1.32 {d19}, [r10]
|
||||
vld1.32 {d19}, [r7]!
|
||||
vmul.s32 d24, d20, d19
|
||||
vmul.s32 d25, d21, d19
|
||||
vmul.s32 d26, d22, d19
|
||||
|
@ -225,28 +171,25 @@ PerChannel:
|
|||
vsub.s32 d30, d30, d26
|
||||
vsub.s32 d31, d31, d27
|
||||
|
||||
// Apply left shift
|
||||
ldr r10, [sp, #32]
|
||||
add r10, r10, r8
|
||||
vld1.32 {d23}, [r10]
|
||||
vmov.32 lr, d4[1]
|
||||
vld1.32 {d23}, [lr]!
|
||||
vmov.32 d4[1], lr
|
||||
vshl.s32 d28, d28, d23
|
||||
vshl.s32 d29, d29, d23
|
||||
vshl.s32 d30, d30, d23
|
||||
vshl.s32 d31, d31, d23
|
||||
|
||||
// Apply the fixed-point part of the multiplier
|
||||
ldr r10, [sp, #28]
|
||||
add r10, r10, r8
|
||||
vld1.32 {d22}, [r10]
|
||||
vmov.32 lr, d5[0]
|
||||
vld1.32 {d22}, [lr]!
|
||||
vmov.32 d5[0], lr
|
||||
vqrdmulh.s32 d28, d28, d22
|
||||
vqrdmulh.s32 d29, d29, d22
|
||||
vqrdmulh.s32 d30, d30, d22
|
||||
vqrdmulh.s32 d31, d31, d22
|
||||
|
||||
// Apply right shift
|
||||
ldr r10, [sp, #36]
|
||||
add r10, r10, r8
|
||||
vld1.32 {d21}, [r10]
|
||||
vmov.32 lr, d5[1]
|
||||
vld1.32 {d21}, [lr]!
|
||||
vmov.32 d5[1], lr
|
||||
vand d20, d21, d28
|
||||
vshr.s32 d20, d20, #31
|
||||
vqadd.s32 d28, d28, d20
|
||||
|
@ -264,83 +207,84 @@ PerChannel:
|
|||
vqadd.s32 d31, d31, d17
|
||||
vrshl.s32 d31, d31, d21
|
||||
|
||||
AddDstZP:
|
||||
// Add the destination zero point
|
||||
ldr r10, [sp, #24]
|
||||
vdup.32 q2, r10
|
||||
vadd.i32 q14, q14, q2
|
||||
vadd.i32 q15, q15, q2
|
||||
Quantize:
|
||||
ldr lr, [sp, #24]
|
||||
vdup.32 q0, lr
|
||||
vadd.i32 q14, q14, q0
|
||||
vadd.i32 q15, q15, q0
|
||||
|
||||
// Apply the act_min bound
|
||||
ldr r10, [sp, #16]
|
||||
vdup.32 q3, r10
|
||||
vmax.s32 q14, q14, q3
|
||||
vmax.s32 q15, q15, q3
|
||||
ldr lr, [sp, #16]
|
||||
vdup.32 q1, lr
|
||||
vmax.s32 q14, q14, q1
|
||||
vmax.s32 q15, q15, q1
|
||||
|
||||
// Apply the act_max bound
|
||||
ldr r10, [sp, #20]
|
||||
vdup.32 q2, r10
|
||||
vmin.s32 q14, q14, q2
|
||||
vmin.s32 q15, q15, q2
|
||||
ldr lr, [sp, #20]
|
||||
vdup.32 q0, lr
|
||||
vmin.s32 q14, q14, q0
|
||||
vmin.s32 q15, q15, q0
|
||||
|
||||
// Cast-and-saturate from int32 to int16
|
||||
vqmovn.s32 d28, q14
|
||||
vqmovn.s32 d29, q15
|
||||
|
||||
// Cast-and-saturate from int16 to int8
|
||||
vqmovn.s16 d30, q14
|
||||
|
||||
// start to write
|
||||
cmp r4, #2
|
||||
bge WriteCol2
|
||||
cmp r4, #1
|
||||
beq WriteCol1
|
||||
b EndWrite
|
||||
beq Write1
|
||||
b Write2
|
||||
|
||||
WriteCol2:
|
||||
vst1.16 {d30[0]}, [r2], r7
|
||||
Write1:
|
||||
vmov.32 lr, d4[0]
|
||||
add lr, lr, #1
|
||||
vmov.32 d4[0], lr
|
||||
vst1.8 {d30[0]}, [r2], r8
|
||||
cmp r3, #1
|
||||
beq EndWrite
|
||||
vst1.16 {d30[1]}, [r2], r7
|
||||
beq WriteEnd
|
||||
vst1.8 {d30[2]}, [r2], r8
|
||||
cmp r3, #2
|
||||
beq EndWrite
|
||||
vst1.16 {d30[2]}, [r2], r7
|
||||
beq WriteEnd
|
||||
vst1.8 {d30[4]}, [r2], r8
|
||||
cmp r3, #3
|
||||
beq EndWrite
|
||||
vst1.16 {d30[3]}, [r2], r7
|
||||
b EndWrite
|
||||
beq WriteEnd
|
||||
vst1.8 {d30[6]}, [r2], r8
|
||||
b WriteEnd
|
||||
|
||||
WriteCol1:
|
||||
vst1.8 {d30[0]}, [r2], r7
|
||||
Write2:
|
||||
vmov.32 lr, d4[0]
|
||||
add lr, lr, #2
|
||||
vmov.32 d4[0], lr
|
||||
vst1.16 {d30[0]}, [r2], r8
|
||||
cmp r3, #1
|
||||
beq EndWrite
|
||||
vst1.8 {d30[2]}, [r2], r7
|
||||
beq WriteEnd
|
||||
vst1.16 {d30[1]}, [r2], r8
|
||||
cmp r3, #2
|
||||
beq EndWrite
|
||||
vst1.8 {d30[4]}, [r2], r7
|
||||
beq WriteEnd
|
||||
vst1.16 {d30[2]}, [r2], r8
|
||||
cmp r3, #3
|
||||
beq EndWrite
|
||||
vst1.8 {d30[6]}, [r2], r7
|
||||
b EndWrite
|
||||
beq WriteEnd
|
||||
vst1.16 {d30[3]}, [r2], r8
|
||||
|
||||
EndWrite:
|
||||
sub r3, r3, #4 // a row counter -= 4
|
||||
b L2
|
||||
WriteEnd:
|
||||
cmp r4, #2
|
||||
ble LoopColEnd
|
||||
sub r4, r4, #2
|
||||
b LoopCol
|
||||
|
||||
End2:
|
||||
sub r4, r4, #2 // b col counter -= 2
|
||||
ldr r2, [sp, #-44] // load dst ptr
|
||||
add r2, r2, #2 // dst ptr + offset
|
||||
str r2, [sp, #-44]
|
||||
ldr r10, [sp, #48]
|
||||
add r10, r10, #8
|
||||
str r10, [sp, #48]
|
||||
add r8, r8, #8 // output channels offset + 2*sizeof(int)
|
||||
b L1
|
||||
LoopColEnd:
|
||||
cmp r3, #4
|
||||
ble LoopRowEnd
|
||||
ldr lr, [sp, #-48]
|
||||
add lr, lr, r10
|
||||
str lr, [sp, #-48]
|
||||
ldr lr, [sp, #-40]
|
||||
add lr, lr, r11
|
||||
str lr, [sp, #-40]
|
||||
sub r3, r3, #4
|
||||
add r6, r6, #16
|
||||
b LoopRow
|
||||
|
||||
End1:
|
||||
sub sp, sp, #116
|
||||
LoopRowEnd:
|
||||
sub sp, sp, #112
|
||||
vpop {q4-q7}
|
||||
pop {r0-r11, pc}
|
||||
pop {r0-r8, r10, r11, pc}
|
||||
#endif
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue