[MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add border assembly

This commit is contained in:
yangruoqi713 2020-10-23 18:22:14 +08:00
parent 17764803ef
commit 9e274b6468
9 changed files with 686 additions and 219 deletions

View File

@ -1,168 +0,0 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDw3x3BorderPixelInt8
#ifndef __APPLE__
.type ConvDw3x3BorderPixelInt8, %function
#endif
// void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) {
// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step,
// x8: channel, x9: in_zp, x10: out_zp, x11: out_multiplier, x12: left_shift, x13: right_shift
// x14: acc_min, x15: acc_max
ConvDw3x3BorderPixelInt8:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
ldr x8, [sp]
ldrb w9, [sp, #8]
dup v25.8b, w9 // in_zp
ldr x9, [sp, #16]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #24]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #32]
dup v28.4s, w9 // left_shift
ldr x9, [sp, #40]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #48]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #56]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x8, x9 // x8 * 2
mov x9, #3
mul x14, x13, x9 // x8 * 3 * 2
LoopC:
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
cmp x4, #2
blt LoopHW
LoopH2W2:
cmp x5, #2
blt LoopHW
ld1 {v0.8b}, [x9], x7
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x6
ld1 {v4.8h}, [x10], x13 // weight
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x7
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x15, x11, x6
ld1 {v2.8b}, [x11], x7
ssubl v2.8h, v2.8b, v25.8b
add x16, x12, x14
ld1 {v6.8h}, [x12], x13
smlal v23.4s, v2.4h, v6.4h
smlal2 v24.4s, v2.8h, v6.8h
ld1 {v3.8b}, [x11], x7
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
cmp x5, #3
beq LoopH2W3
cmp x4, #3
beq LoopH3W2
b Post
LoopH2W3:
ld1 {v16.8b}, [x9], x7
ssubl v16.8h, v16.8b, v25.8b
ld1 {v17.8h}, [x10], x13
smlal v23.4s, v16.4h, v17.4h
smlal2 v24.4s, v16.8h, v17.8h
ld1 {v18.8b}, [x11], x7
ssubl v18.8h, v18.8b, v25.8b
ld1 {v19.8h}, [x12], x13
smlal v23.4s, v18.4h, v19.4h
smlal2 v24.4s, v18.8h, v19.8h
b Post
LoopH3W2:
ld1 {v16.8b}, [x15], x7
ssubl v16.8h, v16.8b, v25.8b
ld1 {v17.8h}, [x16], x13
smlal v23.4s, v16.4h, v17.4h
smlal2 v24.4s, v16.8h, v17.8h
ld1 {v18.8b}, [x15], x7
ssubl v18.8h, v18.8b, v25.8b
ld1 {v19.8h}, [x16], x13
smlal v23.4s, v18.4h, v19.4h
smlal2 v24.4s, v18.8h, v19.8h
b Post
LoopHW:
mov x9, x1
mov x10, x2
mov x17, x4 // height
LoopH:
mov x11, x9
mov x12, x10
mov x18, x5 // width
LoopW:
ld1 {v0.8b}, [x11], x7
ssubl v1.8h, v0.8b, v25.8b
ld1 {v2.8h}, [x12], x13 // weight
smlal v23.4s, v1.4h, v2.4h
smlal2 v24.4s, v1.8h, v2.8h
subs x18, x18, #1
bne LoopW
subs x17, x17, #1
add x9, x9, x6
add x10, x10, x14
bne LoopH
Post:
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v12.16b, v29.16b, v23.16b
sshr v12.4s, v12.4s, #31
sqadd v23.4s, v23.4s, v12.4s
srshl v23.4s, v23.4s, v29.4s
and v11.16b, v29.16b, v24.16b
sshr v11.4s, v11.4s, #31
sqadd v24.4s, v24.4s, v11.4s
srshl v24.4s, v24.4s, v29.4s
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
add x1, x1, #8
add x2, x2, #16
sub x8, x8, #8
cmp x8, #8
bge LoopC
ret
#endif

View File

@ -0,0 +1,168 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDw3x3Int8Corner
#ifndef __APPLE__
.type ConvDw3x3Int8Corner, %function
#endif
// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
// size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step,
// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift
// x11: acc_min, x13: acc_max
ConvDw3x3Int8Corner:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
dup v25.8b, w7 // in_zp
ldr x9, [sp]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #40]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x6, x9 // x6 * 2
mov x9, #3
mul x14, x13, x9 // x6 * 3 * 2
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x4
ld1 {v4.8h}, [x10], x13 // weight
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
ld1 {v6.8h}, [x12], x13
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
cmp x6, #8
ble LoopC8Post
LoopC8:
add x1, x1, #8
add x2, x2, #16
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
ld1 {v4.8h}, [x10], x13 // weight
add x11, x1, x4
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
smlal v23.4s, v2.4h, v6.4h
ld1 {v5.8h}, [x10], x13
smlal2 v24.4s, v2.8h, v6.8h
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
sub x6, x6, #8
cmp x6, #8
bgt LoopC8
LoopC8Post:
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
smlal v23.4s, v2.4h, v6.4h
smlal2 v24.4s, v2.8h, v6.8h
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ret
#endif

View File

@ -0,0 +1,196 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDw3x3Int8Horizontal
#ifndef __APPLE__
.type ConvDw3x3Int8Horizontal, %function
#endif
// void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
// size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step,
// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift
// x11: acc_min, x13: acc_max
ConvDw3x3Int8Horizontal:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
dup v25.8b, w7 // in_zp
ldr x9, [sp]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #40]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x6, x9 // x6 * 2
mov x9, #3
mul x14, x13, x9 // x6 * 3 * 2
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x4
ld1 {v4.8h}, [x10], x13 // weight
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
add x15, x11, x4
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
add x16, x12, x14
ld1 {v6.8h}, [x12], x13
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
ld1 {v16.8b}, [x15], x5
ssubl v16.8h, v16.8b, v25.8b
ld1 {v18.8h}, [x16], x13
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
cmp x6, #8
ble LoopC8Post
LoopC8:
add x1, x1, #8
add x2, x2, #16
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
ld1 {v4.8h}, [x10], x13 // weight
add x11, x1, x4
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
smlal v23.4s, v2.4h, v6.4h
ld1 {v5.8h}, [x10], x13
smlal2 v24.4s, v2.8h, v6.8h
add x15, x11, x4
add x16, x12, x14
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
smlal v23.4s, v16.4h, v18.4h
ld1 {v7.8h}, [x12], x13
smlal2 v24.4s, v16.8h, v18.8h
ld1 {v16.8b}, [x15], x5
ssubl v16.8h, v16.8b, v25.8b
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x16], x13
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
sub x6, x6, #8
cmp x6, #8
bgt LoopC8
LoopC8Post:
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
smlal v23.4s, v2.4h, v6.4h
smlal2 v24.4s, v2.8h, v6.8h
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
smlal v23.4s, v16.4h, v18.4h
smlal2 v24.4s, v16.8h, v18.8h
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ret
#endif

View File

@ -0,0 +1,192 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDw3x3Int8Vertical
#ifndef __APPLE__
.type ConvDw3x3Int8Vertical, %function
#endif
// void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
// size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step,
// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift
// x11: acc_min, x13: acc_max
ConvDw3x3Int8Vertical:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
dup v25.8b, w7 // in_zp
ldr x9, [sp]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #40]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x6, x9 // x6 * 2
mov x9, #3
mul x14, x13, x9 // x6 * 3 * 2
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x4
ld1 {v4.8h}, [x10], x13 // weight
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
ld1 {v6.8h}, [x12], x13
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
ld1 {v16.8b}, [x9], x5
ssubl v16.8h, v16.8b, v25.8b
ld1 {v18.8h}, [x10], x13
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
cmp x6, #8
ble LoopC8Post
LoopC8:
add x1, x1, #8
add x2, x2, #16
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
ssubl v0.8h, v0.8b, v25.8b
ld1 {v4.8h}, [x10], x13 // weight
add x11, x1, x4
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
ssubl v1.8h, v1.8b, v25.8b
smlal v23.4s, v2.4h, v6.4h
ld1 {v5.8h}, [x10], x13
smlal2 v24.4s, v2.8h, v6.8h
ld1 {v2.8b}, [x11], x5
ssubl v2.8h, v2.8b, v25.8b
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
smlal v23.4s, v16.4h, v18.4h
ld1 {v7.8h}, [x12], x13
smlal2 v24.4s, v16.8h, v18.8h
ld1 {v16.8b}, [x9], x5
ssubl v16.8h, v16.8b, v25.8b
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x10], x13
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
sub x6, x6, #8
cmp x6, #8
bgt LoopC8
LoopC8Post:
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
smlal v23.4s, v2.4h, v6.4h
smlal2 v24.4s, v2.8h, v6.8h
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
smlal v23.4s, v16.4h, v18.4h
smlal2 v24.4s, v16.8h, v18.8h
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
ret
#endif

View File

@ -47,10 +47,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t *acc_min, int32_t *acc_max);
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
size_t acc_min, size_t acc_max);
#endif
#ifdef ENABLE_ARM32
@ -71,6 +67,21 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);
void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
size_t acc_max);
void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
size_t acc_max);
#endif
#ifdef __cplusplus
}

View File

@ -232,29 +232,31 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const
int ih_offset = 64 * block_input_w;
int w = start_w;
for (; w <= end_w - block_output_w; w += block_output_w) {
int8_t *output_ptr = output;
const int8_t *input_ptr = input;
const int16_t *weight_ptr = weight;
const int32_t *bias_ptr = bias;
int c = 0;
for (; c <= conv_param->output_channel_ - 64; c += 64) {
InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift,
acc_min, acc_max, conv_param->stride_h_);
output_ptr += 64;
input_ptr += 64;
weight_ptr += 64;
bias_ptr += 64;
if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) {
for (; w <= end_w - block_output_w; w += block_output_w) {
int8_t *output_ptr = output;
const int8_t *input_ptr = input;
const int16_t *weight_ptr = weight;
const int32_t *bias_ptr = bias;
int c = 0;
for (; c <= conv_param->output_channel_ - 64; c += 64) {
InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift,
acc_min, acc_max, conv_param->stride_h_);
output_ptr += 64;
input_ptr += 64;
weight_ptr += 64;
bias_ptr += 64;
}
// left channel
ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_);
output += block_output_w * conv_param->input_channel_;
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
}
// left channel
ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_);
output += block_output_w * conv_param->input_channel_;
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
}
// left width
int left_width = end_w - w;
@ -300,8 +302,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
}
}
#ifndef ENABLE_ARM
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
for (int c = 0; c < channel; c += 8) {
@ -337,9 +338,30 @@ void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *wei
}
}
}
#ifndef ENABLE_ARM64
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier, int left_shift,
int right_shift, int32_t acc_min, int32_t acc_max) {
ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max);
}
void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier,
int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max);
}
void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier,
int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max);
}
#endif
void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
void ConvDw3x3Int8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding, int8_t in_zp, int32_t out_zp, int out_multiplier,
int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
@ -361,7 +383,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight,
const int16_t *weight_kernel =
weight + (start_kh * conv_param->kernel_w_ + start_kw) * conv_param->input_channel_;
ConvDw3x3BorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
ConvDw3x3Int8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->input_channel_, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
@ -371,7 +393,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight,
} // height loop
}
void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
@ -380,17 +402,70 @@ void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16
int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int acc_min = conv_param->conv_quant_arg_.out_act_min_[0];
int acc_max = conv_param->conv_quant_arg_.out_act_max_[0];
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, 0, sliding->top_, 0, conv_param->output_w_,
conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->bottom_, conv_param->output_h_, 0,
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
right_shift, acc_min, acc_max);
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, 0,
sliding->left_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift,
acc_min, acc_max);
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, sliding->right_,
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
right_shift, acc_min, acc_max);
int input_row_size = conv_param->input_w_ * conv_param->input_channel_;
int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_;
int output_row_size = conv_param->output_w_ * conv_param->output_channel_;
int in_kh_step = sliding->in_kh_step_;
int in_kw_step = sliding->in_kw_step_;
// top
const int8_t *input = input_data;
const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_;
int8_t *output = output_data;
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += (conv_param->stride_w_ - 1) * conv_param->input_channel_;
weight = weight_data + weight_row_size;
output += conv_param->output_channel_;
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += conv_param->stride_w_ * conv_param->input_channel_;
output += conv_param->output_channel_;
}
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
// left
input = input_data + (conv_param->stride_h_ - 1) * input_row_size;
weight = weight_data + conv_param->input_channel_;
output = output_data + output_row_size;
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += conv_param->stride_h_ * input_row_size;
output += output_row_size;
}
// right
input =
input_data + (conv_param->input_w_ - 2) * conv_param->input_channel_ + (conv_param->stride_h_ - 1) * input_row_size;
weight = weight_data;
output = output_data + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_;
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += conv_param->stride_h_ * input_row_size;
output += output_row_size;
}
// bottom
input = input_data + (conv_param->input_h_ - 2) * input_row_size;
weight = weight_data + conv_param->input_channel_;
output = output_data + (conv_param->output_h_ - 1) * output_row_size;
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_;
weight = weight_data;
output += conv_param->output_channel_;
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
input += conv_param->stride_w_ * conv_param->input_channel_;
output += conv_param->output_channel_;
}
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
/*conv depthwise 3x3 int8 end*/

View File

@ -29,7 +29,7 @@ bool CheckIfUse3X3(const ConvParameter *conv_param, int channel);
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data,
@ -44,13 +44,6 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);
#ifdef ENABLE_ARM64
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -164,7 +164,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::Run() {
if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
sliding_->right_ < conv_param_->output_w_) {
ConvDw3x3PadInt8(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
ConvDw3x3Int8Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
sliding_);
}
ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Int8Run, this, conv_param_->thread_num_);