From 53c6862a6f37a5dc64595b9b82722eb692740803 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Mon, 21 Sep 2020 14:06:04 +0800 Subject: [PATCH] [MSLITE][Develop] support conv_depthwise arm32 int8 weight perchannel --- .../arm32/ConvDwInt8PostAlign4PerChannel.S | 113 ++++++++++++++++++ mindspore/lite/nnacl/int8/common_func.h | 6 +- .../lite/nnacl/int8/conv_depthwise_int8.c | 2 +- 3 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S diff --git a/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000..d6740355f4 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,113 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global ConvDwInt8PostAlign4PerChannel +#ifndef __APPLE__ +.type ConvDwInt8PostAlign4PerChannel, %function +#endif + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +ConvDwInt8PostAlign4PerChannel: + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + ldr r5, [sp, #4] // left_shift + ldr r6, [sp, #8] // right_shift + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q1, q1, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q1, q1, q14 + vld1.32 {q12}, [r6]! + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif +#endif diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h index 95ba01b808..c5e39d34d3 100644 --- a/mindspore/lite/nnacl/int8/common_func.h +++ b/mindspore/lite/nnacl/int8/common_func.h @@ -35,6 +35,9 @@ void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t #ifdef ENABLE_ARM void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, int output_channel, int input_step, int8_t input_zp); +void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, + int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, + int32_t acc_max); void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, @@ -64,9 +67,6 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); -void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, - int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, - int32_t acc_max); #endif #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c index 164edb387e..d29204ab0b 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c @@ -39,7 +39,7 @@ void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int // support perchannel for (int w = 0; w < output_w; w++) { int channel4 = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_ARM channel4 = channel / 4 * 4; ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);