From 7d97c1b90358d873f2dd5285244c4e5ca6474d10 Mon Sep 17 00:00:00 2001 From: ling Date: Tue, 20 Oct 2020 17:45:11 +0800 Subject: [PATCH] [MSLITE][Develop]deconv winograd input pack and output bias --- .../nnacl/assembly/arm32/PostFuncBiasReluC4.S | 248 ++++++++++++++ .../nnacl/assembly/arm64/PostFuncBiasReluC4.S | 305 ++++++++++++++++++ mindspore/lite/nnacl/conv_parameter.h | 1 - .../lite/nnacl/fp16/deconv_winograd_fp16.c | 11 +- mindspore/lite/nnacl/fp32/common_func.c | 7 + mindspore/lite/nnacl/fp32/common_func.h | 2 + mindspore/lite/nnacl/fp32/deconv_winograd.c | 24 +- .../arm/fp16/deconvolution_winograd_fp16.cc | 4 - .../kernel/arm/fp32/deconvolution_winograd.cc | 4 - 9 files changed, 583 insertions(+), 23 deletions(-) create mode 100644 mindspore/lite/nnacl/assembly/arm32/PostFuncBiasReluC4.S create mode 100644 mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC4.S diff --git a/mindspore/lite/nnacl/assembly/arm32/PostFuncBiasReluC4.S b/mindspore/lite/nnacl/assembly/arm32/PostFuncBiasReluC4.S new file mode 100644 index 00000000000..b125b00e26c --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/PostFuncBiasReluC4.S @@ -0,0 +1,248 @@ + + .text + .align 5 + //.p2align 5,,15 + .global PostFuncBiasReluC4 + #ifndef __APPLE__ + .type PostFuncBiasReluC4, %function + #endif + +//void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// r0 dst r1 src r2 bias +// r3 oc4div r4 oc4mod r5 plane_size (r4-r7 are loaded from the stack) +// r6 plane_stride (bytes added to src after each c4 block) r7 relu_type (1: relu, 3: relu6) + +// q0 ~ q3 value +// q12 bias data +// r10 r11 write loop tmp buf; r12 oc_stride in bytes ((oc4div + oc4mod) * 4) +// q14 relu6 #6; q15 relu #0 +// lr oc4 loop control +// r8 hw loop control + +PostFuncBiasReluC4: + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + mov lr, #4 + add r12, r3, r4 + mul r12, r12, lr + + mov lr, #0 +
+Loop_C4: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #4 + mov r8, r5 + vld1.32 {q12}, [r2]! + +Loop_4x4: + cmp r8, #4 + blt Loop_1x4 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q12 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q12 + + cmp r7, #3 + beq Relu6_4x4 + cmp r7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 +Relu_4x4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 +Write_4x4: + vst1.32 {q0}, [r11], r12 + vst1.32 {q1}, [r11], r12 + vst1.32 {q2}, [r11], r12 + vst1.32 {q3}, [r11], r12 + b Loop_4x4 + +Loop_1x4: + cmp r7, #3 + beq Relu6_1x4 + cmp r7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu6_1x4 +Relu_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu_1x4 +Write_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r11], r12 + b Write_1x4 + +HW_Add: + add r1, r1, r6 + b Loop_C4 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! 
+ vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 // step dst by oc_stride (r12), not the src gap in r6 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 // step dst by oc_stride (r12), not the src gap in r6 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]!
+ vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 // step dst by oc_stride (r12), not the src gap in r6 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} diff --git a/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC4.S b/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC4.S new file mode 100644 index 00000000000..3ba57222fa8 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/PostFuncBiasReluC4.S @@ -0,0 +1,305 @@ +#ifdef __aarch64__ + + .text + .align 5 + //.p2align 5,,15 + .global PostFuncBiasReluC4 +#ifndef __APPLE__ + .type PostFuncBiasReluC4, %function +#endif + +//void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 src x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x12 oc_stride +// x14 x15 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc4 loop control +// w13 hw loop control + + +PostFuncBiasReluC4: + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + mov x10, #4 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v16.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v16.4s + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s,
v7.4s, v26.4s +Relu_8x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_8x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + st1 {v4.4s}, [x15], x12 + st1 {v5.4s}, [x15], x12 + st1 {v6.4s}, [x15], x12 + st1 {v7.4s}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s +Relu_4x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s +Write_4x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq 
Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, 
v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret +#endif diff --git a/mindspore/lite/nnacl/conv_parameter.h b/mindspore/lite/nnacl/conv_parameter.h index ae1a60982f0..293c15cdc11 100644 --- a/mindspore/lite/nnacl/conv_parameter.h +++ b/mindspore/lite/nnacl/conv_parameter.h @@ -89,7 +89,6 @@ typedef struct DeConvWg { typedef struct DeConvWgABuffer { bool buf_init_; - bool trans_formed_; void *middle_buffer_; void *dest_buffer_; } DeConvWgABuffer; diff --git a/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c b/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c index cb8ef91109a..983c53a5d5c 100644 --- a/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c +++ b/mindspore/lite/nnacl/fp16/deconv_winograd_fp16.c @@ -79,15 +79,16 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, } void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf, - float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool a_trans, + float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transfered, float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { int winograd_plane = unit_size * unit_size; - if (!a_trans) { + if (!transfered[unit_size]) { WinogradMatrixProductLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); WinogradMatrixProductRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + transfered[unit_size] = true; /* transform done for this unit size; skip it on later calls */ } for (int index
= 0; index < winograd_plane; index++) { @@ -265,6 +266,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou } /* compute */ + bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; for (int i = 0; i < deconv_param->compute_size_; i++) { DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; if (unit->use_winograd_) { @@ -281,9 +283,8 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up4_; _deConvWinogradFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, - tmp_a->trans_formed_, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, - unit->h_start_, conv_param, deconv_param); - tmp_a->trans_formed_ = true; + transfered, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, + conv_param, deconv_param); } else { float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; diff --git a/mindspore/lite/nnacl/fp32/common_func.c b/mindspore/lite/nnacl/fp32/common_func.c index cfad0d189bf..c320a1da1bd 100644 --- a/mindspore/lite/nnacl/fp32/common_func.c +++ b/mindspore/lite/nnacl/fp32/common_func.c @@ -56,8 +56,15 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, size_t plane_size, size_t plane_stride, size_t relu_type) { +#ifdef ENABLE_ARM + size_t oc4mod = output_channel % C4NUM; + size_t oc4div = output_channel - oc4mod; + size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); + PostFuncBiasReluC4(out_ptr, c4_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type); +#else PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, C4NUM); 
+#endif return; } diff --git a/mindspore/lite/nnacl/fp32/common_func.h b/mindspore/lite/nnacl/fp32/common_func.h index 157aaecc573..55759c29586 100644 --- a/mindspore/lite/nnacl/fp32/common_func.h +++ b/mindspore/lite/nnacl/fp32/common_func.h @@ -53,6 +53,8 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, size_t plane_size, size_t stride, size_t relu_type); +void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); #endif #ifdef ENABLE_ARM64 diff --git a/mindspore/lite/nnacl/fp32/deconv_winograd.c b/mindspore/lite/nnacl/fp32/deconv_winograd.c index 24da2dc9779..5eccd59ed64 100644 --- a/mindspore/lite/nnacl/fp32/deconv_winograd.c +++ b/mindspore/lite/nnacl/fp32/deconv_winograd.c @@ -109,7 +109,11 @@ void DeConvWgInputPack(float *src_ptr, float *dst_ptr, int channel, int stride) float *dst = dst_ptr; for (int ic = 0; ic < ic4div; ic++) { +#ifdef ENABLE_ARM + vst1q_f32(dst, vld1q_f32(src)); +#else memcpy(dst, src, C4NUM * sizeof(float)); +#endif dst += stride; src += C4NUM; } @@ -159,25 +163,27 @@ void MSGemmFloatUnit_4(float *dstOrigin, const float *src, const float *weight, weight_depth_offset); } -void DeConvWgMerge(const float *source, float *dest, size_t srcStride, size_t dstStride, size_t count) { +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { for (int i = 0; i < count; ++i) { - const float *s = source + i * srcStride; - float *d = dest + i * dstStride; + const float *s = src + i * src_stride; + float *d = dst + i * dst_stride; for (int j = 0; j < 4; ++j) { d[j] += s[j]; } } + return; } void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float *tmp_buf, float 
*at_buf, - float *a_mid_buf, float *trans_a_buf, bool a_trans, float *bt_buf, float *b_tmp_buf, int unit_size, - int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { + float *a_mid_buf, float *trans_a_buf, bool *transfered, float *bt_buf, float *b_tmp_buf, + int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { int winograd_plane = unit_size * unit_size; - if (!a_trans) { + if (!transfered[unit_size]) { WinogradMatrixProductLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); WinogradMatrixProductRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + transfered[unit_size] = true; } for (int index = 0; index < winograd_plane; index++) { @@ -274,6 +280,7 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind } /* compute */ + bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; for (int i = 0; i < deconv_param->compute_size_; i++) { DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; if (unit->use_winograd_) { @@ -289,9 +296,8 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE; _deConvWinograd(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, - wg_dst_a_buf, wg_buf->trans_formed_, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, - unit->w_start_, unit->h_start_, conv_param, deconv_param); - wg_buf->trans_formed_ = true; + wg_dst_a_buf, transfered, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_, + unit->h_start_, conv_param, deconv_param); } else { float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * 
deconv_param->oc_div4_ * unit->w_size_ * unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.cc index 35f8f5f1880..2ac3aef2641 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.cc @@ -75,7 +75,6 @@ int DeConvWinogradFp16CPUKernel::InitParameter() { if (unit.use_winograd_) { if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; - deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false; size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = @@ -111,9 +110,6 @@ int DeConvWinogradFp16CPUKernel::DoDeconv(int task_id) { int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); - for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { - deconv_param_->a_buffer_[i].trans_formed_ = false; - } DeconvWgFp16(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); std::unique_lock merge_lock(lock_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc index 73c589cbe60..63ded033757 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc @@ -138,7 +138,6 @@ int DeConvolutionWinogradCPUKernel::InitParameter() { if (unit.use_winograd_) { if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; - 
deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false; size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = @@ -308,9 +307,6 @@ int DeConvolutionWinogradCPUKernel::DoDeconv(int task_id) { int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); - for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { - deconv_param_->a_buffer_[i].trans_formed_ = false; - } DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); std::unique_lock merge_lock(lock_);