From 1288c4ac94ce2a77c0d4571e8c91b13081415b48 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Mon, 14 Sep 2020 10:24:51 +0800
Subject: [PATCH] Add int8 matmul asm for arm32 platform

---
 .../lite/nnacl/assembly/arm32/MatmulInt8.S | 237 ++++++++++++++++++
 mindspore/lite/nnacl/int8/conv_int8.c      |  10 +-
 mindspore/lite/nnacl/int8/matmul_int8.h    |   5 +
 3 files changed, 251 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S

diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S
new file mode 100644
index 00000000000..d5c66121f3d
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S
@@ -0,0 +1,237 @@
+#ifdef __arm__
+#ifndef __aarch64__
+
+.text
+.align 5
+.global MatmulInt8Neon32
+#ifndef __APPLE__
+.type MatmulInt8Neon32, %function
+#endif
+
+//void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
+//                      const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
+//                      int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
+// #-52: a, #-48: b, #-44: dst, #-40: row
+// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
+// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel
+
+MatmulInt8Neon32:
+    push {r0-r11, lr}
+    vpush {q4-q7}
+    add sp, sp, #116
+
+    ldr r4, [sp]        // col
+    mov r7, #2
+    ldr r8, [sp, #4]    // deep16
+    mul r9, r7, r8      // the stride of b: 2 * deep16 bytes
+    ldr r7, [sp, #40]   // output stride
+
+L1:
+    cmp r4, #0          // exit when no columns are left
+    ble End1
+
+    ldr r0, [sp, #-52]  // reload a ptr
+    ldr r3, [sp, #-40]  // reset row counter
+    ldr r6, [sp, #8]    // reload input_sums ptr
+L2:
+    cmp r3, #0          // exit when no rows are left
+    ble End2
+
+    ldr r1, [sp, #-48]  // reload b ptr
+    ldr r8, [sp, #12]   // reload weight_bias ptr
+    ldr r5, [sp, #4]    // reset deep16
+    vmov.i32 q6, #0
+    vmov.i32 q7, #0
+    vmov.i32 q8, #0
+    vmov.i32 q9, #0
+    vmov.i32 q10, #0
+    vmov.i32 q11, #0
+    vmov.i32 q12, #0
+    vmov.i32 q13, #0
+L3:
+    // accumulate a 4x2 tile, consuming 16 depth elements per iteration
+    cmp r5, #0
+    beq End3
+
+    vld1.8 {d0, d1, d2, d3}, [r0]!
+    vld1.8 {d8, d9, d10, d11}, [r1]!
+    vmull.s8 q14, d0, d8
+    vmull.s8 q2, d0, d10
+    vmull.s8 q15, d2, d8
+    vmull.s8 q3, d2, d10
+    vmlal.s8 q14, d1, d9
+    vmlal.s8 q2, d1, d11
+    vmlal.s8 q15, d3, d9
+    vmlal.s8 q3, d3, d11
+
+    vpadal.s16 q6, q14
+    vpadal.s16 q7, q2
+    vpadal.s16 q8, q15
+    vpadal.s16 q9, q3
+
+    vld1.8 {d0, d1, d2, d3}, [r0]!
+    vmull.s8 q14, d0, d8
+    vmull.s8 q2, d0, d10
+    vmull.s8 q15, d2, d8
+    vmull.s8 q3, d2, d10
+    vmlal.s8 q14, d1, d9
+    vmlal.s8 q2, d1, d11
+    vmlal.s8 q15, d3, d9
+    vmlal.s8 q3, d3, d11
+
+    vpadal.s16 q10, q14
+    vpadal.s16 q11, q2
+    vpadal.s16 q12, q15
+    vpadal.s16 q13, q3
+    sub r5, r5, #16     // deep16 -= 16
+    b L3
+
+End3:
+    vpadd.i32 d0, d12, d13
+    vpadd.i32 d1, d14, d15
+    vpadd.i32 d2, d16, d17
+    vpadd.i32 d3, d18, d19
+    vpadd.i32 d4, d20, d21
+    vpadd.i32 d5, d22, d23
+    vpadd.i32 d6, d24, d25
+    vpadd.i32 d7, d26, d27
+
+    vpadd.i32 d28, d0, d1
+    vpadd.i32 d29, d2, d3
+    vpadd.i32 d30, d4, d5
+    vpadd.i32 d31, d6, d7
+
+    // Add weight_bias
+    vld1.32 {d26}, [r8]!
+    vadd.i32 d28, d28, d26
+    vadd.i32 d29, d29, d26
+    vadd.i32 d30, d30, d26
+    vadd.i32 d31, d31, d26
+
+    ldr r10, [sp, #44]
+    cmp r10, #0
+    bgt PerChannel
+
+    // Subtract input_sums
+    vld1.32 {d24, d25}, [r6]!
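+    // d24/d25 hold the four per-row entries of input_sums (one per output row of this tile);
+    // broadcasting and subtracting them below applies the input-sum (zero-point) correction.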
+    vdup.32 d20, d24[0]
+    vdup.32 d21, d24[1]
+    vdup.32 d22, d25[0]
+    vdup.32 d23, d25[1]
+    vsub.s32 d28, d28, d20
+    vsub.s32 d29, d29, d21
+    vsub.s32 d30, d30, d22
+    vsub.s32 d31, d31, d23
+
+    // Apply left shift
+    ldr r10, [sp, #32]
+    ldr r11, [r10]
+    vdup.32 q9, r11
+    vshl.s32 q14, q14, q9
+    vshl.s32 q15, q15, q9
+
+    // Apply the fixed-point part of the multiplier
+    ldr r10, [sp, #28]
+    ldr r11, [r10]
+    vdup.32 q8, r11
+    vqrdmulh.s32 q14, q14, q8
+    vqrdmulh.s32 q15, q15, q8
+
+    // Apply right shift
+    ldr r10, [sp, #36]
+    ldr r11, [r10]
+    vdup.32 q7, r11
+    vand q6, q7, q14
+    vshr.s32 q6, q6, #31
+    vqadd.s32 q14, q14, q6
+    vrshl.s32 q14, q14, q7
+    vand q5, q7, q15
+    vshr.s32 q5, q5, #31
+    vqadd.s32 q15, q15, q5
+    vrshl.s32 q15, q15, q7
+    b AddDstZP
+
+PerChannel:
+    // Per-channel requantization is not implemented yet; execution falls through and only
+    // the output zero point and activation bounds below are applied.
+
+AddDstZP:
+    // Add the destination zero point
+    ldr r10, [sp, #24]
+    vdup.32 q4, r10
+    vadd.i32 q14, q14, q4
+    vadd.i32 q15, q15, q4
+
+    // Apply the act_min bound
+    ldr r10, [sp, #16]
+    vdup.32 q3, r10
+    vmax.s32 q14, q14, q3
+    vmax.s32 q15, q15, q3
+
+    // Apply the act_max bound
+    ldr r10, [sp, #20]
+    vdup.32 q2, r10
+    vmin.s32 q14, q14, q2
+    vmin.s32 q15, q15, q2
+
+    // Cast-and-saturate from int32 to int16
+    vqmovn.s32 d28, q14
+    vqmovn.s32 d29, q15
+
+    // Cast-and-saturate from int16 to int8
+    vqmovn.s16 d30, q14
+
+    // Start writing the output
+    cmp r4, #2
+    bge WriteCol2
+    cmp r4, #1
+    beq WriteCol1
+    b EndWrite
+
+WriteCol2:
+    // two int8 results per row (both columns), written 16 bits at a time
+    vst1.16 {d30[0]}, [r2], r7
+    cmp r3, #1
+    beq EndWrite
+    vst1.16 {d30[1]}, [r2], r7
+    cmp r3, #2
+    beq EndWrite
+    vst1.16 {d30[2]}, [r2], r7
+    cmp r3, #3
+    beq EndWrite
+    vst1.16 {d30[3]}, [r2], r7
+    b EndWrite
+
+WriteCol1:
+    // only the first column is valid; it sits in the even lanes of d30
+    vst1.8 {d30[0]}, [r2], r7
+    cmp r3, #1
+    beq EndWrite
+    vst1.8 {d30[2]}, [r2], r7
+    cmp r3, #2
+    beq EndWrite
+    vst1.8 {d30[4]}, [r2], r7
+    cmp r3, #3
+    beq EndWrite
+    vst1.8 {d30[6]}, [r2], r7
+    b EndWrite
+
+EndWrite:
+    sub r3, r3, #4      // row counter -= 4
+    b L2
+
+End2:
+    sub r4, r4, #2      // col counter -= 2
+    ldr r1, [sp, #-48]  // advance b ptr by 2 * deep16
+    add r1, r1, r9
+    str r1, [sp, #-48]
+    ldr r8, [sp, #12]   // advance weight_bias ptr by 2 ints
+    add r8, r8, #8
+    str r8, [sp, #12]
+    ldr r2, [sp, #-44]  // advance dst ptr by 2 columns
+    add r2, r2, #2
+    str r2, [sp, #-44]
+    b L1
+
+End1:
+    sub sp, sp, #116
+    vpop {q4-q7}
+    pop {r0-r11, pc}
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c
index 33e88fe5dc5..5dccd3b7722 100644
--- a/mindspore/lite/nnacl/int8/conv_int8.c
+++ b/mindspore/lite/nnacl/int8/conv_int8.c
@@ -738,10 +738,18 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int
 void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                       const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                       int32_t *multiplier, ConvParameter *conv_param) {
+  // filter_arg_num_ == 1 means a single shared quant parameter set; otherwise quantization is per output channel.
+  int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1;
+#ifdef ENABLE_ARM32
+  MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
+                   conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
+                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
+                   conv_param->output_channel_, is_per_channel);
+#else
   MatMulInt8_4x2_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
                    left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                    conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                   conv_param->conv_quant_arg_.filter_arg_num_ != 1);
+                   is_per_channel);
+#endif
   return;
 }

diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index d2fd4c87f67..5babb3b443e 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -65,6 +65,11 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i
 void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
                         const int *input_sum, const int *bias);
 #endif
+#ifdef ENABLE_ARM32
+void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
+                      const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
+                      int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
+#endif
 #ifdef __cplusplus
 }
 #endif
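Note on the requantization sequence: the per-tensor path in the assembly implements the usual gemmlowp-style fixed-point scheme (add weight_bias, subtract input_sums, vshl by left_shift, vqrdmulh by multiplier, rounding right shift, add out_zp, clamp to act_min/act_max). The scalar C sketch below is a reading aid only, not part of the patch: the helper names are illustrative, and it assumes the same conventions as the existing C fallback MatMulInt8_4x2_r, namely that the right_shift entries are non-positive and that input_sums already carries the filter zero-point term.

#include <stdint.h>

/* Saturating rounding doubling high multiply: what vqrdmulh.s32 computes per lane. */
static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) {
    return INT32_MAX; /* the only case that would overflow */
  }
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (INT64_C(1) << 30) : (1 - (INT64_C(1) << 30));
  return (int32_t)((ab + nudge) >> 31);
}

/* Rounding divide by 2^exponent (exponent >= 0), matching the vand/vshr/vqadd/vrshl sequence. */
static int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  int32_t mask = (int32_t)((INT64_C(1) << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

/* One output element of the per-tensor path: bias add, input-sum subtract, requantize, clamp. */
static int8_t RequantizeOne(int32_t acc, int32_t weight_bias, int32_t input_sum, int left_shift,
                            int32_t multiplier, int right_shift, int32_t out_zp, int32_t act_min,
                            int32_t act_max) {
  int32_t v = acc + weight_bias - input_sum;
  v = (int32_t)((uint32_t)v << left_shift);             /* vshl.s32 by left_shift */
  v = SaturatingRoundingDoublingHighMul(v, multiplier); /* vqrdmulh.s32 */
  v = RoundingDivideByPOT(v, -right_shift);             /* right_shift is stored as <= 0 */
  v += out_zp;                                          /* destination zero point */
  v = v < act_min ? act_min : v;                        /* activation bounds */
  v = v > act_max ? act_max : v;
  return (int8_t)v;
}

A per-channel variant would apply the same steps but index multiplier, left_shift, and right_shift by the output column instead of using entry 0, which is the part the PerChannel label still leaves to a follow-up change.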