forked from mindspore-Ecosystem/mindspore
[MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add stride 2 assembly
This commit is contained in:
parent
9fc0218c56
commit
6273bdaedb
|
@ -24,8 +24,8 @@ ConvDw3x3Int8Corner:
|
|||
dup v26.4s, w9 // out_zp
|
||||
ldr x9, [sp, #8]
|
||||
dup v27.4s, w9 // out_multiplier
|
||||
ldr x9, [sp, #16]
|
||||
dup v28.4s, w9 // left_shift
|
||||
ldr x8, [sp, #16]
|
||||
dup v28.4s, w8 // left_shift
|
||||
ldr x9, [sp, #24]
|
||||
dup v29.4s, w9 // right_shift
|
||||
ldr x9, [sp, #32]
|
||||
|
@ -85,26 +85,24 @@ ConvDw3x3Int8Corner:
|
|||
smlal v23.4s, v3.4h, v7.4h
|
||||
ld1 {v6.8h}, [x12], x13
|
||||
smlal2 v24.4s, v3.8h, v7.8h
|
||||
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
ld1 {v3.8b}, [x11], x5
|
||||
ssubl v3.8h, v3.8b, v25.8b
|
||||
ld1 {v7.8h}, [x12], x13
|
||||
|
||||
cbz w8, RightShiftLoop
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZpLoop
|
||||
|
||||
RightShiftLoop:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZpLoop:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
@ -135,21 +133,20 @@ ConvDw3x3Int8Corner:
|
|||
smlal v23.4s, v3.4h, v7.4h
|
||||
smlal2 v24.4s, v3.8h, v7.8h
|
||||
|
||||
cbz w8, RightShift
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZp
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
RightShift:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZp:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
|
|
@ -24,8 +24,8 @@ ConvDw3x3Int8Horizontal:
|
|||
dup v26.4s, w9 // out_zp
|
||||
ldr x9, [sp, #8]
|
||||
dup v27.4s, w9 // out_multiplier
|
||||
ldr x9, [sp, #16]
|
||||
dup v28.4s, w9 // left_shift
|
||||
ldr x8, [sp, #16]
|
||||
dup v28.4s, w8 // left_shift
|
||||
ldr x9, [sp, #24]
|
||||
dup v29.4s, w9 // right_shift
|
||||
ldr x9, [sp, #32]
|
||||
|
@ -109,26 +109,24 @@ ConvDw3x3Int8Horizontal:
|
|||
smlal v23.4s, v17.4h, v19.4h
|
||||
ld1 {v18.8h}, [x16], x13
|
||||
smlal2 v24.4s, v17.8h, v19.8h
|
||||
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
ld1 {v17.8b}, [x15], x5
|
||||
ssubl v17.8h, v17.8b, v25.8b
|
||||
ld1 {v19.8h}, [x16], x13
|
||||
|
||||
cbz w8, RightShiftLoop
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZpLoop
|
||||
|
||||
RightShiftLoop:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZpLoop:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
@ -163,21 +161,20 @@ ConvDw3x3Int8Horizontal:
|
|||
smlal v23.4s, v17.4h, v19.4h
|
||||
smlal2 v24.4s, v17.8h, v19.8h
|
||||
|
||||
cbz w8, RightShift
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZp
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
RightShift:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZp:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
|
|
@ -0,0 +1,395 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global ConvDw3x3Int8Stride2
|
||||
#ifndef __APPLE__
|
||||
.type ConvDw3x3Int8Stride2, %function
|
||||
#endif
|
||||
|
||||
|
||||
// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size,
|
||||
// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
|
||||
// int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max)
|
||||
//
|
||||
// x0: output
|
||||
// x1: input
|
||||
// x2: weight
|
||||
// x3: bias
|
||||
// w4: col_size
|
||||
// w5: row_size
|
||||
// w6: channel
|
||||
// w7: output_h
|
||||
// w8: output_w
|
||||
// w9: in_zp
|
||||
// w10: out_zp
|
||||
// w11: out_multiplier
|
||||
// w12: left_shift
|
||||
// w13: right_shift
|
||||
// w14: acc_min
|
||||
// w15: acc_max
|
||||
|
||||
ConvDw3x3Int8Stride2:
|
||||
sub sp, sp, #160
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
|
||||
ldr w8, [sp]
|
||||
ldr w9, [sp, #8]
|
||||
ldr w10, [sp, #16] // out_zp
|
||||
ldr w11, [sp, #24]
|
||||
ldr w12, [sp, #32]
|
||||
ldr w13, [sp, #40]
|
||||
ldr w14, [sp, #48] // acc_min
|
||||
ldr w15, [sp, #56] // acc_max
|
||||
|
||||
add x19, x3, #16
|
||||
add w20, w6, w6 // channel * 2
|
||||
dup v28.8b, w9 // in_zp
|
||||
dup v30.4s, w11 // out_multiplier
|
||||
dup v29.4s, w12 // left_shift
|
||||
dup v31.4s, w13 // right_shift
|
||||
|
||||
// Load weights
|
||||
ld1 {v0.8h}, [x2], x20
|
||||
ld1 {v1.8h}, [x2], x20
|
||||
ld1 {v2.8h}, [x2], x20
|
||||
ld1 {v3.8h}, [x2], x20
|
||||
ld1 {v4.8h}, [x2], x20
|
||||
ld1 {v5.8h}, [x2], x20
|
||||
ld1 {v6.8h}, [x2], x20
|
||||
ld1 {v7.8h}, [x2], x20
|
||||
ld1 {v8.8h}, [x2], x20
|
||||
|
||||
mov x16, x1
|
||||
add x17, x16, x5
|
||||
add x18, x17, x5
|
||||
ld1 {v9.8b}, [x16], x4
|
||||
ld1 {v10.8b}, [x16], x4
|
||||
ssubl v9.8h, v9.8b, v28.8b
|
||||
ld1 {v11.8b}, [x16], x4
|
||||
ssubl v10.8h, v10.8b, v28.8b
|
||||
ld1 {v14.8b}, [x17], x4
|
||||
ssubl v11.8h, v11.8b, v28.8b
|
||||
ld1 {v15.8b}, [x17], x4
|
||||
ssubl v14.8h, v14.8b, v28.8b
|
||||
ld1 {v16.8b}, [x17], x4
|
||||
ssubl v15.8h, v15.8b, v28.8b
|
||||
ld1 {v19.8b}, [x18], x4
|
||||
ssubl v16.8h, v16.8b, v28.8b
|
||||
ld1 {v20.8b}, [x18], x4
|
||||
ssubl v19.8h, v19.8b, v28.8b
|
||||
ld1 {v21.8b}, [x18], x4
|
||||
ssubl v20.8h, v20.8b, v28.8b
|
||||
ssubl v21.8h, v21.8b, v28.8b
|
||||
|
||||
ld1 {v24.4s}, [x3]
|
||||
ld1 {v25.4s}, [x19]
|
||||
ld1 {v26.4s}, [x3]
|
||||
ld1 {v27.4s}, [x19]
|
||||
|
||||
cmp w8, #2
|
||||
beq WIDTH2_LEFT
|
||||
cmp w8, #1
|
||||
beq WIDTH1_LEFT
|
||||
|
||||
HEIGHT1_LOOP:
|
||||
smlal v24.4s, v0.4h, v9.4h
|
||||
ld1 {v12.8b}, [x16], x4
|
||||
smlal2 v25.4s, v0.8h, v9.8h
|
||||
ld1 {v17.8b}, [x17], x4
|
||||
ssubl v12.8h, v12.8b, v28.8b
|
||||
smlal v26.4s, v0.4h, v11.4h
|
||||
ld1 {v22.8b}, [x18], x4
|
||||
ssubl v17.8h, v17.8b, v28.8b
|
||||
smlal2 v27.4s, v0.8h, v11.8h
|
||||
ld1 {v13.8b}, [x16], x4
|
||||
ssubl v22.8h, v22.8b, v28.8b
|
||||
smlal v24.4s, v1.4h, v10.4h
|
||||
ld1 {v18.8b}, [x17], x4
|
||||
ssubl v13.8h, v13.8b, v28.8b
|
||||
smlal2 v25.4s, v1.8h, v10.8h
|
||||
ld1 {v23.8b}, [x18], x4
|
||||
ssubl v18.8h, v18.8b, v28.8b
|
||||
smlal v26.4s, v1.4h, v12.4h
|
||||
mov v9.8h, v13.8h
|
||||
ssubl v23.8h, v23.8b, v28.8b
|
||||
smlal2 v27.4s, v1.8h, v12.8h
|
||||
ld1 {v10.8b}, [x16], x4
|
||||
smlal v24.4s, v2.4h, v11.4h
|
||||
smlal2 v25.4s, v2.8h, v11.8h
|
||||
ld1 {v11.8b}, [x16], x4
|
||||
ssubl v10.8h, v10.8b, v28.8b
|
||||
smlal v26.4s, v2.4h, v13.4h
|
||||
ssubl v11.8h, v11.8b, v28.8b
|
||||
smlal2 v27.4s, v2.8h, v13.8h
|
||||
|
||||
smlal v24.4s, v3.4h, v14.4h
|
||||
smlal2 v25.4s, v3.8h, v14.8h
|
||||
mov v14.8h, v18.8h
|
||||
smlal v26.4s, v3.4h, v16.4h
|
||||
smlal2 v27.4s, v3.8h, v16.8h
|
||||
smlal v24.4s, v4.4h, v15.4h
|
||||
smlal2 v25.4s, v4.8h, v15.8h
|
||||
ld1 {v15.8b}, [x17], x4
|
||||
smlal v26.4s, v4.4h, v17.4h
|
||||
smlal2 v27.4s, v4.8h, v17.8h
|
||||
smlal v24.4s, v5.4h, v16.4h
|
||||
smlal2 v25.4s, v5.8h, v16.8h
|
||||
ld1 {v16.8b}, [x17], x4
|
||||
ssubl v15.8h, v15.8b, v28.8b
|
||||
smlal v26.4s, v5.4h, v18.4h
|
||||
ssubl v16.8h, v16.8b, v28.8b
|
||||
smlal2 v27.4s, v5.8h, v18.8h
|
||||
|
||||
smlal v24.4s, v6.4h, v19.4h
|
||||
smlal2 v25.4s, v6.8h, v19.8h
|
||||
mov v19.8h, v23.8h
|
||||
smlal v26.4s, v6.4h, v21.4h
|
||||
smlal2 v27.4s, v6.8h, v21.8h
|
||||
smlal v24.4s, v7.4h, v20.4h
|
||||
smlal2 v25.4s, v7.8h, v20.8h
|
||||
ld1 {v20.8b}, [x18], x4
|
||||
smlal v26.4s, v7.4h, v22.4h
|
||||
smlal2 v27.4s, v7.8h, v22.8h
|
||||
smlal v24.4s, v8.4h, v21.4h
|
||||
smlal2 v25.4s, v8.8h, v21.8h
|
||||
ld1 {v21.8b}, [x18], x4
|
||||
ssubl v20.8h, v20.8b, v28.8b
|
||||
smlal v26.4s, v8.4h, v23.4h
|
||||
ssubl v21.8h, v21.8b, v28.8b
|
||||
smlal2 v27.4s, v8.8h, v23.8h
|
||||
|
||||
dup v12.4s, w10 // out_zp
|
||||
cbz w12, SKIP_LEFTSHIFT1
|
||||
sqshl v24.4s, v24.4s, v29.4s
|
||||
sqshl v25.4s, v25.4s, v29.4s
|
||||
sqshl v26.4s, v26.4s, v29.4s
|
||||
sqshl v27.4s, v27.4s, v29.4s
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
sqrdmulh v26.4s, v26.4s, v30.4s
|
||||
sqrdmulh v27.4s, v27.4s, v30.4s
|
||||
b OUTZP1
|
||||
|
||||
SKIP_LEFTSHIFT1:
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
sqrdmulh v26.4s, v26.4s, v30.4s
|
||||
sqrdmulh v27.4s, v27.4s, v30.4s
|
||||
sqrshl v24.4s, v24.4s, v31.4s
|
||||
sqrshl v25.4s, v25.4s, v31.4s
|
||||
sqrshl v26.4s, v26.4s, v31.4s
|
||||
sqrshl v27.4s, v27.4s, v31.4s
|
||||
|
||||
OUTZP1:
|
||||
// Add output zero point
|
||||
dup v17.4s, w14 // acc_min
|
||||
sqadd v24.4s, v24.4s, v12.4s
|
||||
sqadd v25.4s, v25.4s, v12.4s
|
||||
sqadd v26.4s, v26.4s, v12.4s
|
||||
sqadd v27.4s, v27.4s, v12.4s
|
||||
|
||||
// Apply min bound
|
||||
dup v22.4s, w15 // acc_max
|
||||
smax v24.4s, v24.4s, v17.4s
|
||||
smax v25.4s, v25.4s, v17.4s
|
||||
smax v26.4s, v26.4s, v17.4s
|
||||
smax v27.4s, v27.4s, v17.4s
|
||||
|
||||
// Apply max bound
|
||||
smin v24.4s, v24.4s, v22.4s
|
||||
smin v25.4s, v25.4s, v22.4s
|
||||
smin v26.4s, v26.4s, v22.4s
|
||||
smin v27.4s, v27.4s, v22.4s
|
||||
|
||||
sqxtn v24.4h, v24.4s
|
||||
sqxtn2 v24.8h, v25.4s
|
||||
ld1 {v25.4s}, [x19]
|
||||
sqxtn v26.4h, v26.4s
|
||||
sqxtn2 v26.8h, v27.4s
|
||||
ld1 {v27.4s}, [x19]
|
||||
sqxtn v24.8b, v24.8h
|
||||
sqxtn2 v24.16b, v26.8h
|
||||
st1 {v24.8b}, [x0], x6
|
||||
mov v26.d[0], v24.d[1]
|
||||
ld1 {v24.4s}, [x3]
|
||||
st1 {v26.8b}, [x0], x6
|
||||
ld1 {v26.4s}, [x3]
|
||||
sub w8, w8, #2
|
||||
cmp w8, #2
|
||||
bgt HEIGHT1_LOOP
|
||||
|
||||
cmp w8, #2
|
||||
blt WIDTH1_LEFT
|
||||
|
||||
WIDTH2_LEFT:
|
||||
smlal v24.4s, v0.4h, v9.4h
|
||||
ld1 {v12.8b}, [x16], x4
|
||||
smlal2 v25.4s, v0.8h, v9.8h
|
||||
ld1 {v17.8b}, [x17], x4
|
||||
ssubl v12.8h, v12.8b, v28.8b
|
||||
smlal v26.4s, v0.4h, v11.4h
|
||||
ld1 {v22.8b}, [x18], x4
|
||||
ssubl v17.8h, v17.8b, v28.8b
|
||||
smlal2 v27.4s, v0.8h, v11.8h
|
||||
ld1 {v13.8b}, [x16], x4
|
||||
ssubl v22.8h, v22.8b, v28.8b
|
||||
smlal v24.4s, v1.4h, v10.4h
|
||||
ld1 {v18.8b}, [x17], x4
|
||||
ssubl v13.8h, v13.8b, v28.8b
|
||||
smlal2 v25.4s, v1.8h, v10.8h
|
||||
ld1 {v23.8b}, [x18], x4
|
||||
ssubl v18.8h, v18.8b, v28.8b
|
||||
smlal v26.4s, v1.4h, v12.4h
|
||||
ssubl v23.8h, v23.8b, v28.8b
|
||||
smlal2 v27.4s, v1.8h, v12.8h
|
||||
smlal v24.4s, v2.4h, v11.4h
|
||||
smlal2 v25.4s, v2.8h, v11.8h
|
||||
smlal v26.4s, v2.4h, v13.4h
|
||||
smlal2 v27.4s, v2.8h, v13.8h
|
||||
|
||||
smlal v24.4s, v3.4h, v14.4h
|
||||
smlal2 v25.4s, v3.8h, v14.8h
|
||||
smlal v26.4s, v3.4h, v16.4h
|
||||
smlal2 v27.4s, v3.8h, v16.8h
|
||||
smlal v24.4s, v4.4h, v15.4h
|
||||
smlal2 v25.4s, v4.8h, v15.8h
|
||||
smlal v26.4s, v4.4h, v17.4h
|
||||
smlal2 v27.4s, v4.8h, v17.8h
|
||||
smlal v24.4s, v5.4h, v16.4h
|
||||
smlal2 v25.4s, v5.8h, v16.8h
|
||||
smlal v26.4s, v5.4h, v18.4h
|
||||
smlal2 v27.4s, v5.8h, v18.8h
|
||||
|
||||
smlal v24.4s, v6.4h, v19.4h
|
||||
smlal2 v25.4s, v6.8h, v19.8h
|
||||
smlal v26.4s, v6.4h, v21.4h
|
||||
smlal2 v27.4s, v6.8h, v21.8h
|
||||
smlal v24.4s, v7.4h, v20.4h
|
||||
smlal2 v25.4s, v7.8h, v20.8h
|
||||
smlal v26.4s, v7.4h, v22.4h
|
||||
smlal2 v27.4s, v7.8h, v22.8h
|
||||
smlal v24.4s, v8.4h, v21.4h
|
||||
smlal2 v25.4s, v8.8h, v21.8h
|
||||
smlal v26.4s, v8.4h, v23.4h
|
||||
smlal2 v27.4s, v8.8h, v23.8h
|
||||
|
||||
dup v12.4s, w10
|
||||
cbz w12, SKIP_LEFTSHIFT2
|
||||
sqshl v24.4s, v24.4s, v29.4s
|
||||
sqshl v25.4s, v25.4s, v29.4s
|
||||
sqshl v26.4s, v26.4s, v29.4s
|
||||
sqshl v27.4s, v27.4s, v29.4s
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
sqrdmulh v26.4s, v26.4s, v30.4s
|
||||
sqrdmulh v27.4s, v27.4s, v30.4s
|
||||
b OUTZP2
|
||||
|
||||
SKIP_LEFTSHIFT2:
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
sqrdmulh v26.4s, v26.4s, v30.4s
|
||||
sqrdmulh v27.4s, v27.4s, v30.4s
|
||||
sqrshl v24.4s, v24.4s, v31.4s
|
||||
sqrshl v25.4s, v25.4s, v31.4s
|
||||
sqrshl v26.4s, v26.4s, v31.4s
|
||||
sqrshl v27.4s, v27.4s, v31.4s
|
||||
|
||||
OUTZP2:
|
||||
// Add output zero point
|
||||
dup v17.4s, w14
|
||||
sqadd v24.4s, v24.4s, v12.4s
|
||||
sqadd v25.4s, v25.4s, v12.4s
|
||||
sqadd v26.4s, v26.4s, v12.4s
|
||||
sqadd v27.4s, v27.4s, v12.4s
|
||||
|
||||
// Apply min bound
|
||||
dup v22.4s, w15
|
||||
smax v24.4s, v24.4s, v17.4s
|
||||
smax v25.4s, v25.4s, v17.4s
|
||||
smax v26.4s, v26.4s, v17.4s
|
||||
smax v27.4s, v27.4s, v17.4s
|
||||
|
||||
// Apply max bound
|
||||
smin v24.4s, v24.4s, v22.4s
|
||||
smin v25.4s, v25.4s, v22.4s
|
||||
smin v26.4s, v26.4s, v22.4s
|
||||
smin v27.4s, v27.4s, v22.4s
|
||||
|
||||
sqxtn v24.4h, v24.4s
|
||||
sqxtn2 v24.8h, v25.4s
|
||||
sqxtn v26.4h, v26.4s
|
||||
sqxtn2 v26.8h, v27.4s
|
||||
sqxtn v24.8b, v24.8h
|
||||
sqxtn2 v24.16b, v26.8h
|
||||
st1 {v24.8b}, [x0], x6
|
||||
mov v26.d[0], v24.d[1]
|
||||
st1 {v26.8b}, [x0], x6
|
||||
b End
|
||||
|
||||
WIDTH1_LEFT:
|
||||
smlal v24.4s, v0.4h, v9.4h
|
||||
smlal2 v25.4s, v0.8h, v9.8h
|
||||
smlal v24.4s, v1.4h, v10.4h
|
||||
smlal2 v25.4s, v1.8h, v10.8h
|
||||
smlal v24.4s, v2.4h, v11.4h
|
||||
smlal2 v25.4s, v2.8h, v11.8h
|
||||
smlal v24.4s, v3.4h, v14.4h
|
||||
smlal2 v25.4s, v3.8h, v14.8h
|
||||
smlal v24.4s, v4.4h, v15.4h
|
||||
smlal2 v25.4s, v4.8h, v15.8h
|
||||
smlal v24.4s, v5.4h, v16.4h
|
||||
smlal2 v25.4s, v5.8h, v16.8h
|
||||
smlal v24.4s, v6.4h, v19.4h
|
||||
smlal2 v25.4s, v6.8h, v19.8h
|
||||
smlal v24.4s, v7.4h, v20.4h
|
||||
smlal2 v25.4s, v7.8h, v20.8h
|
||||
smlal v24.4s, v8.4h, v21.4h
|
||||
smlal2 v25.4s, v8.8h, v21.8h
|
||||
|
||||
dup v12.4s, w10
|
||||
cbz w12, SKIP_LEFTSHIFT3
|
||||
sqshl v24.4s, v24.4s, v29.4s
|
||||
sqshl v25.4s, v25.4s, v29.4s
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
b OUTZP3
|
||||
|
||||
SKIP_LEFTSHIFT3:
|
||||
sqrdmulh v24.4s, v24.4s, v30.4s
|
||||
sqrdmulh v25.4s, v25.4s, v30.4s
|
||||
sqrshl v24.4s, v24.4s, v31.4s
|
||||
sqrshl v25.4s, v25.4s, v31.4s
|
||||
|
||||
OUTZP3:
|
||||
// Add output zero point
|
||||
dup v17.4s, w14
|
||||
sqadd v24.4s, v24.4s, v12.4s
|
||||
sqadd v25.4s, v25.4s, v12.4s
|
||||
|
||||
// Apply min bound
|
||||
dup v22.4s, w15
|
||||
smax v24.4s, v24.4s, v17.4s
|
||||
smax v25.4s, v25.4s, v17.4s
|
||||
|
||||
// Apply max bound
|
||||
smin v24.4s, v24.4s, v22.4s
|
||||
smin v25.4s, v25.4s, v22.4s
|
||||
|
||||
sqxtn v24.4h, v24.4s
|
||||
sqxtn2 v24.8h, v25.4s
|
||||
sqxtn v24.8b, v24.8h
|
||||
st1 {v24.8b}, [x0], x6
|
||||
|
||||
End:
|
||||
sub sp, sp, #160
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ret
|
||||
#endif
|
|
@ -24,8 +24,8 @@ ConvDw3x3Int8Vertical:
|
|||
dup v26.4s, w9 // out_zp
|
||||
ldr x9, [sp, #8]
|
||||
dup v27.4s, w9 // out_multiplier
|
||||
ldr x9, [sp, #16]
|
||||
dup v28.4s, w9 // left_shift
|
||||
ldr x8, [sp, #16]
|
||||
dup v28.4s, w8 // left_shift
|
||||
ldr x9, [sp, #24]
|
||||
dup v29.4s, w9 // right_shift
|
||||
ldr x9, [sp, #32]
|
||||
|
@ -105,26 +105,24 @@ ConvDw3x3Int8Vertical:
|
|||
smlal v23.4s, v17.4h, v19.4h
|
||||
ld1 {v18.8h}, [x10], x13
|
||||
smlal2 v24.4s, v17.8h, v19.8h
|
||||
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
ld1 {v17.8b}, [x11], x5
|
||||
ssubl v17.8h, v17.8b, v25.8b
|
||||
ld1 {v19.8h}, [x12], x13
|
||||
|
||||
cbz w8, RightShiftLoop
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZpLoop
|
||||
|
||||
RightShiftLoop:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZpLoop:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
@ -159,21 +157,20 @@ ConvDw3x3Int8Vertical:
|
|||
smlal v23.4s, v17.4h, v19.4h
|
||||
smlal2 v24.4s, v17.8h, v19.8h
|
||||
|
||||
cbz w8, RightShift
|
||||
sqshl v23.4s, v23.4s, v28.4s
|
||||
sqshl v24.4s, v24.4s, v28.4s
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
b AddZp
|
||||
|
||||
and v21.16b, v29.16b, v23.16b
|
||||
sshr v21.4s, v21.4s, #31
|
||||
sqadd v23.4s, v23.4s, v21.4s
|
||||
srshl v23.4s, v23.4s, v29.4s
|
||||
|
||||
and v22.16b, v29.16b, v24.16b
|
||||
sshr v22.4s, v22.4s, #31
|
||||
sqadd v24.4s, v24.4s, v22.4s
|
||||
srshl v24.4s, v24.4s, v29.4s
|
||||
RightShift:
|
||||
sqrdmulh v23.4s, v23.4s, v27.4s
|
||||
sqrdmulh v24.4s, v24.4s, v27.4s
|
||||
sqrshl v23.4s, v23.4s, v29.4s
|
||||
sqrshl v24.4s, v24.4s, v29.4s
|
||||
|
||||
AddZp:
|
||||
add v23.4s, v23.4s, v26.4s
|
||||
add v24.4s, v24.4s, v26.4s
|
||||
smax v23.4s, v23.4s, v30.4s
|
||||
|
|
|
@ -77,6 +77,10 @@ void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *wei
|
|||
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
|
||||
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
|
||||
int32_t acc_max);
|
||||
void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
|
||||
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
|
||||
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
|
||||
int32_t acc_max);
|
||||
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
|
||||
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
|
||||
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);
|
||||
|
|
|
@ -139,11 +139,22 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
|
|||
/*conv depthwise int8 end*/
|
||||
|
||||
/*conv depthwise 3x3 int8 begin*/
|
||||
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
|
||||
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
|
||||
conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
|
||||
bool CheckIfUse3X3(const ConvParameter *conv_param) {
|
||||
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
|
||||
(conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
|
||||
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
|
||||
conv_param->stride_h_ == conv_param->stride_w_ &&
|
||||
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
|
||||
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
|
||||
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
|
||||
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0);
|
||||
if (!use_3x3) {
|
||||
return false;
|
||||
}
|
||||
const int out_w = conv_param->output_w_ - 1;
|
||||
const int out_h = conv_param->output_h_ - 1;
|
||||
const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_;
|
||||
const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_;
|
||||
use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_);
|
||||
return use_3x3;
|
||||
}
|
||||
|
||||
|
@ -206,8 +217,14 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
|
|||
int32_t acc_max, int stride) {
|
||||
for (; start_c <= end_c - 8; start_c += 8) {
|
||||
#ifdef ENABLE_ARM64
|
||||
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
if (stride == 1) {
|
||||
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
} else {
|
||||
ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
}
|
||||
|
||||
#else
|
||||
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel);
|
||||
bool CheckIfUse3X3(const ConvParameter *conv_param);
|
||||
|
||||
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
|
||||
|
|
|
@ -168,22 +168,26 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
|
|||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
|
||||
kernel::LiteKernel *kernel;
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
auto act_quant_size =
|
||||
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
|
||||
if (act_quant_size == 1) { // per tensor
|
||||
auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
auto channel = inputs[kWeightIndex]->shape()[0];
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
if (primitive != nullptr && primitive->GetInferFlag()) {
|
||||
conv_param->input_h_ = inputs[kInputIndex]->Height();
|
||||
conv_param->input_w_ = inputs[kInputIndex]->Width();
|
||||
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
|
||||
conv_param->output_h_ = outputs[kOutputIndex]->Height();
|
||||
conv_param->output_w_ = outputs[kOutputIndex]->Width();
|
||||
}
|
||||
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
|
||||
if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) {
|
||||
if (CheckIfUse3X3(conv_param) && weight_quant_size == 1) {
|
||||
#ifdef ENABLE_ARM64
|
||||
kernel =
|
||||
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
#else
|
||||
kernel =
|
||||
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
#endif
|
||||
} else {
|
||||
}
|
||||
if (kernel == nullptr) {
|
||||
kernel =
|
||||
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
}
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
ssd-10.onnx
|
||||
efficientnet-lite4-11.onnx
|
Loading…
Reference in New Issue