diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S index 5e8a7f888b6..1c6aea12574 100644 --- a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S @@ -24,8 +24,8 @@ ConvDw3x3Int8Corner: dup v26.4s, w9 // out_zp ldr x9, [sp, #8] dup v27.4s, w9 // out_multiplier - ldr x9, [sp, #16] - dup v28.4s, w9 // left_shift + ldr x8, [sp, #16] + dup v28.4s, w8 // left_shift ldr x9, [sp, #24] dup v29.4s, w9 // right_shift ldr x9, [sp, #32] @@ -85,26 +85,24 @@ ConvDw3x3Int8Corner: smlal v23.4s, v3.4h, v7.4h ld1 {v6.8h}, [x12], x13 smlal2 v24.4s, v3.8h, v7.8h - - sqshl v23.4s, v23.4s, v28.4s - sqshl v24.4s, v24.4s, v28.4s - sqrdmulh v23.4s, v23.4s, v27.4s - sqrdmulh v24.4s, v24.4s, v27.4s - - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s - ld1 {v3.8b}, [x11], x5 ssubl v3.8h, v3.8b, v25.8b ld1 {v7.8h}, [x12], x13 + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZpLoop: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s @@ -135,21 +133,20 @@ ConvDw3x3Int8Corner: smlal v23.4s, v3.4h, v7.4h smlal2 v24.4s, v3.8h, v7.8h + cbz w8, RightShift sqshl v23.4s, v23.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + AddZp: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S index e970c6bc361..a798efde262 100644 --- a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S @@ -24,8 +24,8 @@ ConvDw3x3Int8Horizontal: dup v26.4s, w9 // out_zp ldr x9, [sp, #8] dup v27.4s, w9 // out_multiplier - ldr x9, [sp, #16] - dup v28.4s, w9 // left_shift + ldr x8, [sp, #16] + dup v28.4s, w8 // left_shift ldr x9, [sp, #24] dup v29.4s, w9 // right_shift ldr x9, [sp, #32] @@ -109,26 +109,24 @@ ConvDw3x3Int8Horizontal: smlal v23.4s, v17.4h, v19.4h ld1 {v18.8h}, [x16], x13 smlal2 v24.4s, v17.8h, v19.8h - - sqshl v23.4s, v23.4s, v28.4s - sqshl v24.4s, v24.4s, v28.4s - sqrdmulh v23.4s, v23.4s, v27.4s - sqrdmulh v24.4s, v24.4s, v27.4s - - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s - ld1 {v17.8b}, [x15], x5 ssubl v17.8h, v17.8b, v25.8b ld1 {v19.8h}, [x16], x13 + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZpLoop: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s @@ -163,21 +161,20 @@ ConvDw3x3Int8Horizontal: smlal v23.4s, v17.4h, v19.4h smlal2 v24.4s, v17.8h, v19.8h + cbz w8, RightShift sqshl v23.4s, v23.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + AddZp: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S new file mode 100644 index 00000000000..cbfbbeecf32 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S @@ -0,0 +1,395 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global ConvDw3x3Int8Stride2 +#ifndef __APPLE__ +.type ConvDw3x3Int8Stride2, %function +#endif + + +// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max + +ConvDw3x3Int8Stride2: + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] // out_zp + ldr w11, [sp, #24] + ldr w12, [sp, #32] + ldr w13, [sp, #40] + ldr w14, [sp, #48] // acc_min + ldr w15, [sp, #56] // acc_max + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + dup v28.8b, w9 // in_zp + dup v30.4s, w11 // out_multiplier + dup v29.4s, w12 // left_shift + dup v31.4s, w13 // right_shift + + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x18, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ssubl v9.8h, v9.8b, v28.8b + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + ld1 {v14.8b}, [x17], x4 + ssubl v11.8h, v11.8b, v28.8b + ld1 {v15.8b}, [x17], x4 + ssubl v14.8h, v14.8b, v28.8b + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + ld1 {v19.8b}, [x18], x4 + ssubl v16.8h, v16.8b, v28.8b + ld1 {v20.8b}, [x18], x4 + ssubl v19.8h, v19.8b, v28.8b + ld1 {v21.8b}, [x18], x4 + ssubl v20.8h, v20.8b, v28.8b + ssubl v21.8h, v21.8b, v28.8b + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x19] + ld1 {v26.4s}, [x3] + ld1 {v27.4s}, [x19] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x18], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x18], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + mov v9.8h, v13.8h + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + ld1 {v10.8b}, [x16], x4 + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + smlal v26.4s, v2.4h, v13.4h + ssubl v11.8h, v11.8b, v28.8b + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + mov v14.8h, v18.8h + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + smlal v26.4s, v5.4h, v18.4h + ssubl v16.8h, v16.8b, v28.8b + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + mov v19.8h, v23.8h + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + ld1 {v20.8b}, [x18], x4 + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + ld1 {v21.8b}, [x18], x4 + ssubl v20.8h, v20.8b, v28.8b + smlal v26.4s, v8.4h, v23.4h + ssubl v21.8h, v21.8b, v28.8b + smlal2 v27.4s, v8.8h, v23.8h + + dup v12.4s, w10 // out_zp + cbz w12, SKIP_LEFTSHIFT1 + sqshl v24.4s, v24.4s, v29.4s + sqshl v25.4s, v25.4s, v29.4s + sqshl v26.4s, v26.4s, v29.4s + sqshl v27.4s, v27.4s, v29.4s + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + sqrdmulh v26.4s, v26.4s, v30.4s + sqrdmulh v27.4s, v27.4s, v30.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + sqrdmulh v26.4s, v26.4s, v30.4s + sqrdmulh v27.4s, v27.4s, v30.4s + sqrshl v24.4s, v24.4s, v31.4s + sqrshl v25.4s, v25.4s, v31.4s + sqrshl v26.4s, v26.4s, v31.4s + sqrshl v27.4s, v27.4s, v31.4s + +OUTZP1: + // Add output zero point + dup v17.4s, w14 // acc_min + sqadd v24.4s, v24.4s, v12.4s + sqadd v25.4s, v25.4s, v12.4s + sqadd v26.4s, v26.4s, v12.4s + sqadd v27.4s, v27.4s, v12.4s + + // Apply min bound + dup v22.4s, w15 // acc_max + smax v24.4s, v24.4s, v17.4s + smax v25.4s, v25.4s, v17.4s + smax v26.4s, v26.4s, v17.4s + smax v27.4s, v27.4s, v17.4s + + // Apply max bound + smin v24.4s, v24.4s, v22.4s + smin v25.4s, v25.4s, v22.4s + smin v26.4s, v26.4s, v22.4s + smin v27.4s, v27.4s, v22.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + ld1 {v25.4s}, [x19] + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + ld1 {v27.4s}, [x19] + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + ld1 {v24.4s}, [x3] + st1 {v26.8b}, [x0], x6 + ld1 {v26.4s}, [x3] + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x18], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x18], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v26.4s, v2.4h, v13.4h + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v26.4s, v5.4h, v18.4h + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + smlal v26.4s, v8.4h, v23.4h + smlal2 v27.4s, v8.8h, v23.8h + + dup v12.4s, w10 + cbz w12, SKIP_LEFTSHIFT2 + sqshl v24.4s, v24.4s, v29.4s + sqshl v25.4s, v25.4s, v29.4s + sqshl v26.4s, v26.4s, v29.4s + sqshl v27.4s, v27.4s, v29.4s + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + sqrdmulh v26.4s, v26.4s, v30.4s + sqrdmulh v27.4s, v27.4s, v30.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + sqrdmulh v26.4s, v26.4s, v30.4s + sqrdmulh v27.4s, v27.4s, v30.4s + sqrshl v24.4s, v24.4s, v31.4s + sqrshl v25.4s, v25.4s, v31.4s + sqrshl v26.4s, v26.4s, v31.4s + sqrshl v27.4s, v27.4s, v31.4s + +OUTZP2: + // Add output zero point + dup v17.4s, w14 + sqadd v24.4s, v24.4s, v12.4s + sqadd v25.4s, v25.4s, v12.4s + sqadd v26.4s, v26.4s, v12.4s + sqadd v27.4s, v27.4s, v12.4s + + // Apply min bound + dup v22.4s, w15 + smax v24.4s, v24.4s, v17.4s + smax v25.4s, v25.4s, v17.4s + smax v26.4s, v26.4s, v17.4s + smax v27.4s, v27.4s, v17.4s + + // Apply max bound + smin v24.4s, v24.4s, v22.4s + smin v25.4s, v25.4s, v22.4s + smin v26.4s, v26.4s, v22.4s + smin v27.4s, v27.4s, v22.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + st1 {v26.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v24.4s, v0.4h, v9.4h + smlal2 v25.4s, v0.8h, v9.8h + smlal v24.4s, v1.4h, v10.4h + smlal2 v25.4s, v1.8h, v10.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + + dup v12.4s, w10 + cbz w12, SKIP_LEFTSHIFT3 + sqshl v24.4s, v24.4s, v29.4s + sqshl v25.4s, v25.4s, v29.4s + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + sqrdmulh v24.4s, v24.4s, v30.4s + sqrdmulh v25.4s, v25.4s, v30.4s + sqrshl v24.4s, v24.4s, v31.4s + sqrshl v25.4s, v25.4s, v31.4s + +OUTZP3: + // Add output zero point + dup v17.4s, w14 + sqadd v24.4s, v24.4s, v12.4s + sqadd v25.4s, v25.4s, v12.4s + + // Apply min bound + dup v22.4s, w15 + smax v24.4s, v24.4s, v17.4s + smax v25.4s, v25.4s, v17.4s + + // Apply max bound + smin v24.4s, v24.4s, v22.4s + smin v25.4s, v25.4s, v22.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v24.8b, v24.8h + st1 {v24.8b}, [x0], x6 + +End: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S index d1e2dce2bbd..040b12ed503 100644 --- a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S @@ -24,8 +24,8 @@ ConvDw3x3Int8Vertical: dup v26.4s, w9 // out_zp ldr x9, [sp, #8] dup v27.4s, w9 // out_multiplier - ldr x9, [sp, #16] - dup v28.4s, w9 // left_shift + ldr x8, [sp, #16] + dup v28.4s, w8 // left_shift ldr x9, [sp, #24] dup v29.4s, w9 // right_shift ldr x9, [sp, #32] @@ -105,26 +105,24 @@ ConvDw3x3Int8Vertical: smlal v23.4s, v17.4h, v19.4h ld1 {v18.8h}, [x10], x13 smlal2 v24.4s, v17.8h, v19.8h - - sqshl v23.4s, v23.4s, v28.4s - sqshl v24.4s, v24.4s, v28.4s - sqrdmulh v23.4s, v23.4s, v27.4s - sqrdmulh v24.4s, v24.4s, v27.4s - - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s - ld1 {v17.8b}, [x11], x5 ssubl v17.8h, v17.8b, v25.8b ld1 {v19.8h}, [x12], x13 + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZpLoop: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s @@ -159,21 +157,20 @@ ConvDw3x3Int8Vertical: smlal v23.4s, v17.4h, v19.4h smlal2 v24.4s, v17.8h, v19.8h + cbz w8, RightShift sqshl v23.4s, v23.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp - and v21.16b, v29.16b, v23.16b - sshr v21.4s, v21.4s, #31 - sqadd v23.4s, v23.4s, v21.4s - srshl v23.4s, v23.4s, v29.4s - - and v22.16b, v29.16b, v24.16b - sshr v22.4s, v22.4s, #31 - sqadd v24.4s, v24.4s, v22.4s - srshl v24.4s, v24.4s, v29.4s + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + AddZp: add v23.4s, v23.4s, v26.4s add v24.4s, v24.4s, v26.4s smax v23.4s, v23.4s, v30.4s diff --git a/mindspore/lite/nnacl/int8/common_func_int8.h b/mindspore/lite/nnacl/int8/common_func_int8.h index c9a555cf75d..912200619f2 100644 --- a/mindspore/lite/nnacl/int8/common_func_int8.h +++ b/mindspore/lite/nnacl/int8/common_func_int8.h @@ -77,6 +77,10 @@ void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *wei int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max); +void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, + int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min, + int32_t acc_max); void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max); diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c index 4e46ee4d69b..41d838c0ebe 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c @@ -139,11 +139,22 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da /*conv depthwise int8 end*/ /*conv depthwise 3x3 int8 begin*/ -bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) { - bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 && - conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && +bool CheckIfUse3X3(const ConvParameter *conv_param) { + bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && + (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && + (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && + conv_param->stride_h_ == conv_param->stride_w_ && + (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ && - conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0); + conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0); + if (!use_3x3) { + return false; + } + const int out_w = conv_param->output_w_ - 1; + const int out_h = conv_param->output_h_ - 1; + const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_; + const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_; + use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_); return use_3x3; } @@ -206,8 +217,14 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei int32_t acc_max, int stride) { for (; start_c <= end_c - 8; start_c += 8) { #ifdef ENABLE_ARM64 - ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, - out_multiplier, left_shift, right_shift, acc_min, acc_max); + if (stride == 1) { + ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); + } else { + ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); + } + #else ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, stride); diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h index 06592604fc8..ae132861c63 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h @@ -24,7 +24,7 @@ extern "C" { #endif -bool CheckIfUse3X3(const ConvParameter *conv_param, int channel); +bool CheckIfUse3X3(const ConvParameter *conv_param); void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, const ConvParameter *conv_param, int task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc index 706f286524f..6b7f469c037 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -168,22 +168,26 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - kernel::LiteKernel *kernel; + kernel::LiteKernel *kernel = nullptr; auto act_quant_size = MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size()); if (act_quant_size == 1) { // per tensor - auto conv_parm = reinterpret_cast(opParameter); - auto channel = inputs[kWeightIndex]->shape()[0]; + auto conv_param = reinterpret_cast(opParameter); + if (primitive != nullptr && primitive->GetInferFlag()) { + conv_param->input_h_ = inputs[kInputIndex]->Height(); + conv_param->input_w_ = inputs[kInputIndex]->Width(); + conv_param->input_channel_ = inputs[kInputIndex]->Channel(); + conv_param->output_h_ = outputs[kOutputIndex]->Height(); + conv_param->output_w_ = outputs[kOutputIndex]->Width(); + } auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); - if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) { + if (CheckIfUse3X3(conv_param) && weight_quant_size == 1) { #ifdef ENABLE_ARM64 kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); -#else - kernel = - new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); #endif - } else { + } + if (kernel == nullptr) { kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } diff --git a/mindspore/lite/test/models_only_for_process.cfg b/mindspore/lite/test/models_only_for_process.cfg index c7f62a22999..e69de29bb2d 100644 --- a/mindspore/lite/test/models_only_for_process.cfg +++ b/mindspore/lite/test/models_only_for_process.cfg @@ -1,2 +0,0 @@ -ssd-10.onnx -efficientnet-lite4-11.onnx