forked from mindspore-Ecosystem/mindspore
Add int8 matmul asm for arm32 platform
This commit is contained in:
parent
f16c01460f
commit
1288c4ac94
|
@ -0,0 +1,237 @@
|
|||
#ifdef __arm__
|
||||
#ifndef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global MatmulInt8Neon32
|
||||
#ifndef __APPLE__
|
||||
.type MatmulInt8Neon32, %function
|
||||
#endif
|
||||
|
||||
//void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
|
||||
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
|
||||
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
|
||||
// #-52: a, #-48: b, #-44: dst, #-40: row
|
||||
// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
|
||||
// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel
|
||||
|
||||
MatmulInt8Neon32:
|
||||
push {r0-r11, lr}
|
||||
vpush {q4-q7}
|
||||
add sp, sp, #116
|
||||
|
||||
ldr r4, [sp] // col
|
||||
mov r7, #2
|
||||
ldr r8, [sp, #4] // deep16
|
||||
mul r9, r7, r8 // the sride of b
|
||||
ldr r7, [sp, #40] // output stride
|
||||
|
||||
L1:
|
||||
cmp r4, #0 // if at the end of col
|
||||
ble End1
|
||||
|
||||
ldr r0, [sp, #-52] // reload a ptr
|
||||
ldr r3, [sp, #-40] // reset row counter
|
||||
ldr r6, [sp, #8] // reload intpu_sums ptr
|
||||
L2:
|
||||
cmp r3, #0 // if at the end of row
|
||||
ble End2
|
||||
|
||||
ldr r1, [sp, #-48] // reload b ptr
|
||||
ldr r8, [sp, #12] // reload weight_bias ptr
|
||||
ldr r5, [sp, #4] // reset deep16
|
||||
vmov.i32 q6, #0
|
||||
vmov.i32 q7, #0
|
||||
vmov.i32 q8, #0
|
||||
vmov.i32 q9, #0
|
||||
vmov.i32 q10, #0
|
||||
vmov.i32 q11, #0
|
||||
vmov.i32 q12, #0
|
||||
vmov.i32 q13, #0
|
||||
L3:
|
||||
cmp r5, #0
|
||||
beq End3
|
||||
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
vld1.8 {d8, d9, d10, d11}, [r1]!
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
|
||||
vpadal.s16 q6, q14
|
||||
vpadal.s16 q7, q2
|
||||
vpadal.s16 q8, q15
|
||||
vpadal.s16 q9, q3
|
||||
|
||||
vld1.8 {d0, d1, d2, d3}, [r0]!
|
||||
vmull.s8 q14, d0, d8
|
||||
vmull.s8 q2, d0, d10
|
||||
vmull.s8 q15, d2, d8
|
||||
vmull.s8 q3, d2, d10
|
||||
vmlal.s8 q14, d1, d9
|
||||
vmlal.s8 q2, d1, d11
|
||||
vmlal.s8 q15, d3, d9
|
||||
vmlal.s8 q3, d3, d11
|
||||
|
||||
vpadal.s16 q10, q14
|
||||
vpadal.s16 q11, q2
|
||||
vpadal.s16 q12, q15
|
||||
vpadal.s16 q13, q3
|
||||
sub r5, r5, #16 // deep16 -= 16
|
||||
b L3
|
||||
|
||||
End3:
|
||||
vpadd.i32 d0, d12, d13
|
||||
vpadd.i32 d1, d14, d15
|
||||
vpadd.i32 d2, d16, d17
|
||||
vpadd.i32 d3, d18, d19
|
||||
vpadd.i32 d4, d20, d21
|
||||
vpadd.i32 d5, d22, d23
|
||||
vpadd.i32 d6, d24, d25
|
||||
vpadd.i32 d7, d26, d27
|
||||
|
||||
vpadd.i32 d28, d0, d1
|
||||
vpadd.i32 d29, d2, d3
|
||||
vpadd.i32 d30, d4, d5
|
||||
vpadd.i32 d31, d6, d7
|
||||
|
||||
// Add weight_bias
|
||||
vld1.32 {d26}, [r8]!
|
||||
vadd.i32 d28, d28, d26
|
||||
vadd.i32 d29, d29, d26
|
||||
vadd.i32 d30, d30, d26
|
||||
vadd.i32 d31, d31, d26
|
||||
|
||||
ldr r10, [sp, #44]
|
||||
cmp r10, #0
|
||||
bgt PerChannel
|
||||
|
||||
// Substract input_sums
|
||||
vld1.32 {d24, d25}, [r6]!
|
||||
vdup.32 d20, d24[0]
|
||||
vdup.32 d21, d24[1]
|
||||
vdup.32 d22, d25[0]
|
||||
vdup.32 d23, d25[1]
|
||||
vsub.s32 d28, d28, d20
|
||||
vsub.s32 d29, d29, d21
|
||||
vsub.s32 d30, d30, d22
|
||||
vsub.s32 d31, d31, d23
|
||||
|
||||
// Apply left shift
|
||||
ldr r10, [sp, #32]
|
||||
ldr r11, [r10]
|
||||
vdup.32 q9, r11
|
||||
vshl.s32 q14, q14, q9
|
||||
vshl.s32 q15, q15, q9
|
||||
|
||||
// Apply the fixed-point part of the multiplier
|
||||
ldr r10, [sp, #28]
|
||||
ldr r11, [r10]
|
||||
vdup.32 q8, r11
|
||||
vqrdmulh.s32 q14, q14, q8
|
||||
vqrdmulh.s32 q15, q15, q8
|
||||
|
||||
// Apply right shift
|
||||
ldr r10, [sp, #36]
|
||||
ldr r11, [r10]
|
||||
vdup.32 q7, r11
|
||||
vand q6, q7, q14
|
||||
vshr.s32 q6, q6, #31
|
||||
vqadd.s32 q14, q14, q6
|
||||
vrshl.s32 q14, q14, q7
|
||||
vand q5, q7, q15
|
||||
vshr.s32 q5, q5, #31
|
||||
vqadd.s32 q15, q15, q5
|
||||
vrshl.s32 q15, q15, q7
|
||||
b AddDstZP
|
||||
|
||||
PerChannel:
|
||||
|
||||
|
||||
AddDstZP:
|
||||
// Add the destination zero point
|
||||
ldr r10, [sp, #24]
|
||||
vdup.32 q4, r10
|
||||
vadd.i32 q14, q14, q4
|
||||
vadd.i32 q15, q15, q4
|
||||
|
||||
// Apply the act_min bound
|
||||
ldr r10, [sp, #16]
|
||||
vdup.32 q3, r10
|
||||
vmax.s32 q14, q14, q3
|
||||
vmax.s32 q15, q15, q3
|
||||
|
||||
// Apply the act_max bound
|
||||
ldr r10, [sp, #20]
|
||||
vdup.32 q2, r10
|
||||
vmin.s32 q14, q14, q2
|
||||
vmin.s32 q15, q15, q2
|
||||
|
||||
// Cast-and-saturate from int32 to int16
|
||||
vqmovn.s32 d28, q14
|
||||
vqmovn.s32 d29, q15
|
||||
|
||||
// Cast-and-saturate from int16 to int8
|
||||
vqmovn.s16 d30, q14
|
||||
|
||||
// start to write
|
||||
cmp r4, #2
|
||||
bge WriteCol2
|
||||
cmp r4, #1
|
||||
beq WriteCol1
|
||||
b EndWrite
|
||||
|
||||
WriteCol2:
|
||||
vst1.16 {d30[0]}, [r2], r7
|
||||
cmp r3, #1
|
||||
beq EndWrite
|
||||
vst1.16 {d30[1]}, [r2], r7
|
||||
cmp r3, #2
|
||||
beq EndWrite
|
||||
vst1.16 {d30[2]}, [r2], r7
|
||||
cmp r3, #3
|
||||
beq EndWrite
|
||||
vst1.16 {d30[3]}, [r2], r7
|
||||
b EndWrite
|
||||
|
||||
WriteCol1:
|
||||
vst1.8 {d30[0]}, [r2], r7
|
||||
cmp r3, #1
|
||||
beq EndWrite
|
||||
vst1.8 {d30[2]}, [r2], r7
|
||||
cmp r3, #2
|
||||
beq EndWrite
|
||||
vst1.8 {d30[4]}, [r2], r7
|
||||
cmp r3, #3
|
||||
beq EndWrite
|
||||
vst1.8 {d30[6]}, [r2], r7
|
||||
b EndWrite
|
||||
|
||||
EndWrite:
|
||||
sub r3, r3, #4 // a row counter -= 4
|
||||
b L2
|
||||
|
||||
End2:
|
||||
sub r4, r4, #2 // b col counter -= 2
|
||||
ldr r1, [sp, #-48] // b ptr + stride
|
||||
add r1, r1, r9
|
||||
str r1, [sp, #-48]
|
||||
ldr r8, [sp, #12] // weight_bias + stride
|
||||
add r8, r8, #8
|
||||
str r8, [sp, #12]
|
||||
ldr r2, [sp, #-44] // dst ptr + offset
|
||||
add r2, r2, #2
|
||||
str r2, [sp, #-44]
|
||||
b L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #116
|
||||
vpop {q4-q7}
|
||||
pop {r0-r11, pc}
|
||||
#endif
|
||||
#endif
|
|
@ -738,10 +738,18 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int
|
|||
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param) {
|
||||
int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false;
|
||||
#ifdef ENABLE_ARM32
|
||||
MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
|
||||
conv_param->output_channel_, is_per_channel);
|
||||
#else
|
||||
MatMulInt8_4x2_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
|
||||
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
|
||||
is_per_channel);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,11 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i
|
|||
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
|
||||
const int *input_sum, const int *bias);
|
||||
#endif
|
||||
#ifdef ENABLE_ARM32
|
||||
void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
|
||||
const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
|
||||
int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue