[MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add stride 2 assembly

This commit is contained in:
yangruoqi713 2020-10-27 14:59:23 +08:00
parent 9fc0218c56
commit 6273bdaedb
9 changed files with 507 additions and 98 deletions

View File

@ -24,8 +24,8 @@ ConvDw3x3Int8Corner:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -85,26 +85,24 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -135,21 +133,20 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

View File

@ -24,8 +24,8 @@ ConvDw3x3Int8Horizontal:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -109,26 +109,24 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x16], x13
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -163,21 +161,20 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

View File

@ -0,0 +1,395 @@
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDw3x3Int8Stride2
#ifndef __APPLE__
    .type ConvDw3x3Int8Stride2, %function
#endif
// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size,
//                           int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
//                           int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max)
//
// 3x3 depthwise int8 convolution, stride 2, for one group of 8 channels.
// Produces output_w pixels of a single output row; every main-loop iteration
// emits two output pixels (input columns 0..4 of three input rows).
//
// Register arguments:
//   x0: output (advanced by `channel` bytes per output pixel)
//   x1: input
//   x2: weight (int16, 9 taps, consecutive taps are channel * 2 bytes apart)
//   x3: bias   (8 x int32)
//   w4: col_size (byte step between input columns)
//   w5: row_size (byte step between input rows)
//   w6: channel
//   w7: output_h (not read by this routine)
// Stack arguments, loaded from 8-byte slots starting at the entry sp:
//   w8:  output_w      [#0]
//   w9:  in_zp         [#8]
//   w10: out_zp        [#16]
//   w11: out_multiplier[#24]
//   w12: left_shift    [#32]
//   w13: right_shift   [#40]
//   w14: acc_min       [#48]
//   w15: acc_max       [#56]
// NOTE(review): the 8-byte slot layout matches the generic AAPCS64 rules; Apple's
// arm64 ABI packs stack arguments to their natural size - confirm before using on iOS.
//
// Internal register roles:
//   x16 / x17 / x21: read pointers for input rows 0 / 1 / 2
//   x19: bias + 16 (upper four bias lanes), x20: channel * 2 (weight tap stride)
//   v0-v8: weights, v9-v13 / v14-v18 / v19-v23: row 0 / 1 / 2 columns (widened to int16)
//   v24,v25: accumulators of output pixel 0, v26,v27: accumulators of output pixel 1
//   v28: in_zp, v29: left_shift, v30: out_multiplier, v31: right_shift
ConvDw3x3Int8Stride2:
    // Reserve the save area and keep sp lowered until the epilogue: AAPCS64 has no
    // red zone, so callee-saved data must never live below sp.
    sub sp, sp, #160
    mov x16, sp
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x16], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x16], #64
    stp x19, x20, [x16], #16
    stp x21, x22, [x16], #16
    // x16 now equals the entry sp, i.e. the base of the stack-passed arguments.
    ldr w8, [x16]
    ldr w9, [x16, #8]
    ldr w10, [x16, #16] // out_zp
    ldr w11, [x16, #24]
    ldr w12, [x16, #32]
    ldr w13, [x16, #40]
    ldr w14, [x16, #48] // acc_min
    ldr w15, [x16, #56] // acc_max

    // The int arguments are used as 64-bit post-index offsets below; the upper
    // halves of x4-x6 are unspecified on entry, so extend them explicitly.
    sxtw x4, w4
    sxtw x5, w5
    sxtw x6, w6

    add x19, x3, #16
    add w20, w6, w6 // channel * 2
    dup v28.8b, w9 // in_zp
    dup v30.4s, w11 // out_multiplier
    dup v29.4s, w12 // left_shift
    dup v31.4s, w13 // right_shift

    // Load weights
    ld1 {v0.8h}, [x2], x20
    ld1 {v1.8h}, [x2], x20
    ld1 {v2.8h}, [x2], x20
    ld1 {v3.8h}, [x2], x20
    ld1 {v4.8h}, [x2], x20
    ld1 {v5.8h}, [x2], x20
    ld1 {v6.8h}, [x2], x20
    ld1 {v7.8h}, [x2], x20
    ld1 {v8.8h}, [x2], x20

    // Row pointers. x21 (callee-saved, preserved above) is used instead of x18,
    // which is the platform register and reserved on some targets.
    mov x16, x1
    add x17, x16, x5
    add x21, x17, x5

    // Preload columns 0..2 of each row and subtract the input zero point.
    ld1 {v9.8b}, [x16], x4
    ld1 {v10.8b}, [x16], x4
    ssubl v9.8h, v9.8b, v28.8b
    ld1 {v11.8b}, [x16], x4
    ssubl v10.8h, v10.8b, v28.8b
    ld1 {v14.8b}, [x17], x4
    ssubl v11.8h, v11.8b, v28.8b
    ld1 {v15.8b}, [x17], x4
    ssubl v14.8h, v14.8b, v28.8b
    ld1 {v16.8b}, [x17], x4
    ssubl v15.8h, v15.8b, v28.8b
    ld1 {v19.8b}, [x21], x4
    ssubl v16.8h, v16.8b, v28.8b
    ld1 {v20.8b}, [x21], x4
    ssubl v19.8h, v19.8b, v28.8b
    ld1 {v21.8b}, [x21], x4
    ssubl v20.8h, v20.8b, v28.8b
    ssubl v21.8h, v21.8b, v28.8b

    // Seed both accumulator pairs with the bias.
    ld1 {v24.4s}, [x3]
    ld1 {v25.4s}, [x19]
    ld1 {v26.4s}, [x3]
    ld1 {v27.4s}, [x19]

    cmp w8, #2
    beq WIDTH2_LEFT
    cmp w8, #1
    beq WIDTH1_LEFT

// Main loop: two output pixels per iteration. While accumulating, column 4 of
// each row is moved into the column-0 register and the next columns 1 and 2 are
// loaded for the following iteration.
HEIGHT1_LOOP:
    smlal v24.4s, v0.4h, v9.4h
    ld1 {v12.8b}, [x16], x4
    smlal2 v25.4s, v0.8h, v9.8h
    ld1 {v17.8b}, [x17], x4
    ssubl v12.8h, v12.8b, v28.8b
    smlal v26.4s, v0.4h, v11.4h
    ld1 {v22.8b}, [x21], x4
    ssubl v17.8h, v17.8b, v28.8b
    smlal2 v27.4s, v0.8h, v11.8h
    ld1 {v13.8b}, [x16], x4
    ssubl v22.8h, v22.8b, v28.8b
    smlal v24.4s, v1.4h, v10.4h
    ld1 {v18.8b}, [x17], x4
    ssubl v13.8h, v13.8b, v28.8b
    smlal2 v25.4s, v1.8h, v10.8h
    ld1 {v23.8b}, [x21], x4
    ssubl v18.8h, v18.8b, v28.8b
    smlal v26.4s, v1.4h, v12.4h
    mov v9.16b, v13.16b // row 0: column 4 becomes next column 0
    ssubl v23.8h, v23.8b, v28.8b
    smlal2 v27.4s, v1.8h, v12.8h
    ld1 {v10.8b}, [x16], x4
    smlal v24.4s, v2.4h, v11.4h
    smlal2 v25.4s, v2.8h, v11.8h
    ld1 {v11.8b}, [x16], x4
    ssubl v10.8h, v10.8b, v28.8b
    smlal v26.4s, v2.4h, v13.4h
    ssubl v11.8h, v11.8b, v28.8b
    smlal2 v27.4s, v2.8h, v13.8h
    smlal v24.4s, v3.4h, v14.4h
    smlal2 v25.4s, v3.8h, v14.8h
    mov v14.16b, v18.16b // row 1: column 4 becomes next column 0
    smlal v26.4s, v3.4h, v16.4h
    smlal2 v27.4s, v3.8h, v16.8h
    smlal v24.4s, v4.4h, v15.4h
    smlal2 v25.4s, v4.8h, v15.8h
    ld1 {v15.8b}, [x17], x4
    smlal v26.4s, v4.4h, v17.4h
    smlal2 v27.4s, v4.8h, v17.8h
    smlal v24.4s, v5.4h, v16.4h
    smlal2 v25.4s, v5.8h, v16.8h
    ld1 {v16.8b}, [x17], x4
    ssubl v15.8h, v15.8b, v28.8b
    smlal v26.4s, v5.4h, v18.4h
    ssubl v16.8h, v16.8b, v28.8b
    smlal2 v27.4s, v5.8h, v18.8h
    smlal v24.4s, v6.4h, v19.4h
    smlal2 v25.4s, v6.8h, v19.8h
    mov v19.16b, v23.16b // row 2: column 4 becomes next column 0
    smlal v26.4s, v6.4h, v21.4h
    smlal2 v27.4s, v6.8h, v21.8h
    smlal v24.4s, v7.4h, v20.4h
    smlal2 v25.4s, v7.8h, v20.8h
    ld1 {v20.8b}, [x21], x4
    smlal v26.4s, v7.4h, v22.4h
    smlal2 v27.4s, v7.8h, v22.8h
    smlal v24.4s, v8.4h, v21.4h
    smlal2 v25.4s, v8.8h, v21.8h
    ld1 {v21.8b}, [x21], x4
    ssubl v20.8h, v20.8b, v28.8b
    smlal v26.4s, v8.4h, v23.4h
    ssubl v21.8h, v21.8b, v28.8b
    smlal2 v27.4s, v8.8h, v23.8h

    // Requantize. A non-zero left_shift takes the shift-left + multiply path and
    // skips the right shift; otherwise multiply and apply sqrshl by right_shift.
    // NOTE(review): presumably right_shift is <= 0 so sqrshl shifts right - confirm with caller.
    dup v12.4s, w10 // out_zp
    cbz w12, SKIP_LEFTSHIFT1
    sqshl v24.4s, v24.4s, v29.4s
    sqshl v25.4s, v25.4s, v29.4s
    sqshl v26.4s, v26.4s, v29.4s
    sqshl v27.4s, v27.4s, v29.4s
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    sqrdmulh v26.4s, v26.4s, v30.4s
    sqrdmulh v27.4s, v27.4s, v30.4s
    b OUTZP1

SKIP_LEFTSHIFT1:
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    sqrdmulh v26.4s, v26.4s, v30.4s
    sqrdmulh v27.4s, v27.4s, v30.4s
    sqrshl v24.4s, v24.4s, v31.4s
    sqrshl v25.4s, v25.4s, v31.4s
    sqrshl v26.4s, v26.4s, v31.4s
    sqrshl v27.4s, v27.4s, v31.4s

OUTZP1:
    // Add output zero point
    dup v17.4s, w14 // acc_min
    sqadd v24.4s, v24.4s, v12.4s
    sqadd v25.4s, v25.4s, v12.4s
    sqadd v26.4s, v26.4s, v12.4s
    sqadd v27.4s, v27.4s, v12.4s

    // Apply min bound
    dup v22.4s, w15 // acc_max
    smax v24.4s, v24.4s, v17.4s
    smax v25.4s, v25.4s, v17.4s
    smax v26.4s, v26.4s, v17.4s
    smax v27.4s, v27.4s, v17.4s

    // Apply max bound
    smin v24.4s, v24.4s, v22.4s
    smin v25.4s, v25.4s, v22.4s
    smin v26.4s, v26.4s, v22.4s
    smin v27.4s, v27.4s, v22.4s

    // Narrow to int8, store both pixels and reload the bias for the next pair.
    sqxtn v24.4h, v24.4s
    sqxtn2 v24.8h, v25.4s
    ld1 {v25.4s}, [x19]
    sqxtn v26.4h, v26.4s
    sqxtn2 v26.8h, v27.4s
    ld1 {v27.4s}, [x19]
    sqxtn v24.8b, v24.8h
    sqxtn2 v24.16b, v26.8h
    st1 {v24.8b}, [x0], x6
    mov v26.d[0], v24.d[1]
    ld1 {v24.4s}, [x3]
    st1 {v26.8b}, [x0], x6
    ld1 {v26.4s}, [x3]

    sub w8, w8, #2
    cmp w8, #2
    bgt HEIGHT1_LOOP

    cmp w8, #2
    blt WIDTH1_LEFT

// Tail: exactly two output pixels remain; no preloading for a further iteration.
WIDTH2_LEFT:
    smlal v24.4s, v0.4h, v9.4h
    ld1 {v12.8b}, [x16], x4
    smlal2 v25.4s, v0.8h, v9.8h
    ld1 {v17.8b}, [x17], x4
    ssubl v12.8h, v12.8b, v28.8b
    smlal v26.4s, v0.4h, v11.4h
    ld1 {v22.8b}, [x21], x4
    ssubl v17.8h, v17.8b, v28.8b
    smlal2 v27.4s, v0.8h, v11.8h
    ld1 {v13.8b}, [x16], x4
    ssubl v22.8h, v22.8b, v28.8b
    smlal v24.4s, v1.4h, v10.4h
    ld1 {v18.8b}, [x17], x4
    ssubl v13.8h, v13.8b, v28.8b
    smlal2 v25.4s, v1.8h, v10.8h
    ld1 {v23.8b}, [x21], x4
    ssubl v18.8h, v18.8b, v28.8b
    smlal v26.4s, v1.4h, v12.4h
    ssubl v23.8h, v23.8b, v28.8b
    smlal2 v27.4s, v1.8h, v12.8h
    smlal v24.4s, v2.4h, v11.4h
    smlal2 v25.4s, v2.8h, v11.8h
    smlal v26.4s, v2.4h, v13.4h
    smlal2 v27.4s, v2.8h, v13.8h
    smlal v24.4s, v3.4h, v14.4h
    smlal2 v25.4s, v3.8h, v14.8h
    smlal v26.4s, v3.4h, v16.4h
    smlal2 v27.4s, v3.8h, v16.8h
    smlal v24.4s, v4.4h, v15.4h
    smlal2 v25.4s, v4.8h, v15.8h
    smlal v26.4s, v4.4h, v17.4h
    smlal2 v27.4s, v4.8h, v17.8h
    smlal v24.4s, v5.4h, v16.4h
    smlal2 v25.4s, v5.8h, v16.8h
    smlal v26.4s, v5.4h, v18.4h
    smlal2 v27.4s, v5.8h, v18.8h
    smlal v24.4s, v6.4h, v19.4h
    smlal2 v25.4s, v6.8h, v19.8h
    smlal v26.4s, v6.4h, v21.4h
    smlal2 v27.4s, v6.8h, v21.8h
    smlal v24.4s, v7.4h, v20.4h
    smlal2 v25.4s, v7.8h, v20.8h
    smlal v26.4s, v7.4h, v22.4h
    smlal2 v27.4s, v7.8h, v22.8h
    smlal v24.4s, v8.4h, v21.4h
    smlal2 v25.4s, v8.8h, v21.8h
    smlal v26.4s, v8.4h, v23.4h
    smlal2 v27.4s, v8.8h, v23.8h

    dup v12.4s, w10 // out_zp
    cbz w12, SKIP_LEFTSHIFT2
    sqshl v24.4s, v24.4s, v29.4s
    sqshl v25.4s, v25.4s, v29.4s
    sqshl v26.4s, v26.4s, v29.4s
    sqshl v27.4s, v27.4s, v29.4s
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    sqrdmulh v26.4s, v26.4s, v30.4s
    sqrdmulh v27.4s, v27.4s, v30.4s
    b OUTZP2

SKIP_LEFTSHIFT2:
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    sqrdmulh v26.4s, v26.4s, v30.4s
    sqrdmulh v27.4s, v27.4s, v30.4s
    sqrshl v24.4s, v24.4s, v31.4s
    sqrshl v25.4s, v25.4s, v31.4s
    sqrshl v26.4s, v26.4s, v31.4s
    sqrshl v27.4s, v27.4s, v31.4s

OUTZP2:
    // Add output zero point
    dup v17.4s, w14 // acc_min
    sqadd v24.4s, v24.4s, v12.4s
    sqadd v25.4s, v25.4s, v12.4s
    sqadd v26.4s, v26.4s, v12.4s
    sqadd v27.4s, v27.4s, v12.4s

    // Apply min bound
    dup v22.4s, w15 // acc_max
    smax v24.4s, v24.4s, v17.4s
    smax v25.4s, v25.4s, v17.4s
    smax v26.4s, v26.4s, v17.4s
    smax v27.4s, v27.4s, v17.4s

    // Apply max bound
    smin v24.4s, v24.4s, v22.4s
    smin v25.4s, v25.4s, v22.4s
    smin v26.4s, v26.4s, v22.4s
    smin v27.4s, v27.4s, v22.4s

    sqxtn v24.4h, v24.4s
    sqxtn2 v24.8h, v25.4s
    sqxtn v26.4h, v26.4s
    sqxtn2 v26.8h, v27.4s
    sqxtn v24.8b, v24.8h
    sqxtn2 v24.16b, v26.8h
    st1 {v24.8b}, [x0], x6
    mov v26.d[0], v24.d[1]
    st1 {v26.8b}, [x0], x6
    b End

// Tail: a single output pixel remains; only the already-loaded columns 0..2 are used.
WIDTH1_LEFT:
    smlal v24.4s, v0.4h, v9.4h
    smlal2 v25.4s, v0.8h, v9.8h
    smlal v24.4s, v1.4h, v10.4h
    smlal2 v25.4s, v1.8h, v10.8h
    smlal v24.4s, v2.4h, v11.4h
    smlal2 v25.4s, v2.8h, v11.8h
    smlal v24.4s, v3.4h, v14.4h
    smlal2 v25.4s, v3.8h, v14.8h
    smlal v24.4s, v4.4h, v15.4h
    smlal2 v25.4s, v4.8h, v15.8h
    smlal v24.4s, v5.4h, v16.4h
    smlal2 v25.4s, v5.8h, v16.8h
    smlal v24.4s, v6.4h, v19.4h
    smlal2 v25.4s, v6.8h, v19.8h
    smlal v24.4s, v7.4h, v20.4h
    smlal2 v25.4s, v7.8h, v20.8h
    smlal v24.4s, v8.4h, v21.4h
    smlal2 v25.4s, v8.8h, v21.8h

    dup v12.4s, w10 // out_zp
    cbz w12, SKIP_LEFTSHIFT3
    sqshl v24.4s, v24.4s, v29.4s
    sqshl v25.4s, v25.4s, v29.4s
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    b OUTZP3

SKIP_LEFTSHIFT3:
    sqrdmulh v24.4s, v24.4s, v30.4s
    sqrdmulh v25.4s, v25.4s, v30.4s
    sqrshl v24.4s, v24.4s, v31.4s
    sqrshl v25.4s, v25.4s, v31.4s

OUTZP3:
    // Add output zero point
    dup v17.4s, w14 // acc_min
    sqadd v24.4s, v24.4s, v12.4s
    sqadd v25.4s, v25.4s, v12.4s

    // Apply min bound
    dup v22.4s, w15 // acc_max
    smax v24.4s, v24.4s, v17.4s
    smax v25.4s, v25.4s, v17.4s

    // Apply max bound
    smin v24.4s, v24.4s, v22.4s
    smin v25.4s, v25.4s, v22.4s

    sqxtn v24.4h, v24.4s
    sqxtn2 v24.8h, v25.4s
    sqxtn v24.8b, v24.8h
    st1 {v24.8b}, [x0], x6

End:
    // sp still points at the save area; the post-increments release the frame.
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ldp x19, x20, [sp], #16
    ldp x21, x22, [sp], #16
    ret
#endif

View File

@ -24,8 +24,8 @@ ConvDw3x3Int8Vertical:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -105,26 +105,24 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x10], x13
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -159,21 +157,20 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

View File

@ -77,6 +77,10 @@ void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *wei
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);

View File

@ -139,11 +139,22 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
/*conv depthwise int8 end*/
/*conv depthwise 3x3 int8 begin*/
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
bool CheckIfUse3X3(const ConvParameter *conv_param) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
(conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
conv_param->stride_h_ == conv_param->stride_w_ &&
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0);
if (!use_3x3) {
return false;
}
const int out_w = conv_param->output_w_ - 1;
const int out_h = conv_param->output_h_ - 1;
const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_;
const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_;
use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_);
return use_3x3;
}
@ -206,8 +217,14 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
int32_t acc_max, int stride) {
for (; start_c <= end_c - 8; start_c += 8) {
#ifdef ENABLE_ARM64
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
if (stride == 1) {
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
} else {
ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
#else
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);

View File

@ -24,7 +24,7 @@
extern "C" {
#endif
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel);
bool CheckIfUse3X3(const ConvParameter *conv_param);
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);

View File

@ -168,22 +168,26 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
kernel::LiteKernel *kernel;
kernel::LiteKernel *kernel = nullptr;
auto act_quant_size =
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
if (act_quant_size == 1) { // per tensor
auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter);
auto channel = inputs[kWeightIndex]->shape()[0];
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if (primitive != nullptr && primitive->GetInferFlag()) {
conv_param->input_h_ = inputs[kInputIndex]->Height();
conv_param->input_w_ = inputs[kInputIndex]->Width();
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
}
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) {
if (CheckIfUse3X3(conv_param) && weight_quant_size == 1) {
#ifdef ENABLE_ARM64
kernel =
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#else
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
} else {
}
if (kernel == nullptr) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}

View File

@ -1,2 +0,0 @@
ssd-10.onnx
efficientnet-lite4-11.onnx