forked from mindspore-Ecosystem/mindspore
!7715 [MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add border assembly
Merge pull request !7715 from yangruoqi713/conv_dw
This commit is contained in:
commit
624f6b1607
|
@ -1,168 +0,0 @@
|
|||
#ifdef __aarch64__

    .text
    .align 5
    .global ConvDw3x3BorderPixelInt8
#ifndef __APPLE__
    .type ConvDw3x3BorderPixelInt8, %function
#endif

// void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
//                               size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel,
//                               size_t in_zp, size_t out_zp, size_t out_multiplier, size_t left_shift,
//                               size_t right_shift, size_t acc_min, size_t acc_max)
//
// Computes one int8 output pixel of a depthwise convolution from a height x width window of taps,
// 8 channels per outer iteration. At least one group of 8 channels is always processed.
//
// Register arguments:
//   x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step
// Stack arguments (8 bytes each, read below):
//   [sp]: channel, [sp, #8]: in_zp, [sp, #16]: out_zp, [sp, #24]: out_multiplier,
//   [sp, #32]: left_shift, [sp, #40]: right_shift, [sp, #48]: acc_min, [sp, #56]: acc_max
//
// Scratch registers: x8 ~ x17, v0 ~ v7, v16 ~ v19, v21 ~ v31.
// v8 ~ v15 must be preserved by a callee according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// and x18 is the platform register, so neither is touched here and nothing needs to be spilled.
ConvDw3x3BorderPixelInt8:
    ldr x8, [sp]                        // channel
    ldrb w9, [sp, #8]
    dup v25.8b, w9                      // in_zp
    ldr x9, [sp, #16]
    dup v26.4s, w9                      // out_zp
    ldr x9, [sp, #24]
    dup v27.4s, w9                      // out_multiplier
    ldr x9, [sp, #32]
    dup v28.4s, w9                      // left_shift
    ldr x9, [sp, #40]
    dup v29.4s, w9                      // right_shift
    ldr x9, [sp, #48]
    dup v30.4s, w9                      // acc_min
    ldr x9, [sp, #56]
    dup v31.4s, w9                      // acc_max

    mov x9, #2
    mul x13, x8, x9                     // x13 = channel * 2: byte step between weight columns
    mov x9, #3
    mul x14, x13, x9                    // x14 = channel * 3 * 2: byte step between weight rows

LoopC:
    // v23/v24 = bias for channels 0-3 / 4-7 of this group; they double as the accumulators
    ld1 {v23.4s}, [x3], #16
    ld1 {v24.4s}, [x3], #16

    mov x9, x1                          // x9: src row 0
    mov x10, x2                         // x10: weight row 0
    cmp x4, #2
    blt LoopHW                          // height < 2: use the generic loop
LoopH2W2:
    cmp x5, #2
    blt LoopHW                          // width < 2: use the generic loop

    // unrolled 2x2 part of the window
    ld1 {v0.8b}, [x9], x7
    ssubl v0.8h, v0.8b, v25.8b          // widen to int16 and subtract in_zp
    add x11, x1, x6                     // x11: src row 1
    ld1 {v4.8h}, [x10], x13             // weight
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    add x12, x2, x14                    // x12: weight row 1
    ld1 {v1.8b}, [x9], x7
    ssubl v1.8h, v1.8b, v25.8b
    ld1 {v5.8h}, [x10], x13
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h
    add x15, x11, x6                    // x15: src row 2 (taken before x11 is advanced)
    ld1 {v2.8b}, [x11], x7
    ssubl v2.8h, v2.8b, v25.8b
    add x16, x12, x14                   // x16: weight row 2 (taken before x12 is advanced)
    ld1 {v6.8h}, [x12], x13
    smlal v23.4s, v2.4h, v6.4h
    smlal2 v24.4s, v2.8h, v6.8h
    ld1 {v3.8b}, [x11], x7
    ssubl v3.8h, v3.8b, v25.8b
    ld1 {v7.8h}, [x12], x13
    smlal v23.4s, v3.4h, v7.4h
    smlal2 v24.4s, v3.8h, v7.8h
    cmp x5, #3
    beq LoopH2W3                        // width == 3: add the third column of rows 0 and 1
    cmp x4, #3
    beq LoopH3W2                        // height == 3: add the third row
    b Post

LoopH2W3:
    ld1 {v16.8b}, [x9], x7
    ssubl v16.8h, v16.8b, v25.8b
    ld1 {v17.8h}, [x10], x13
    smlal v23.4s, v16.4h, v17.4h
    smlal2 v24.4s, v16.8h, v17.8h
    ld1 {v18.8b}, [x11], x7
    ssubl v18.8h, v18.8b, v25.8b
    ld1 {v19.8h}, [x12], x13
    smlal v23.4s, v18.4h, v19.4h
    smlal2 v24.4s, v18.8h, v19.8h
    b Post

LoopH3W2:
    ld1 {v16.8b}, [x15], x7
    ssubl v16.8h, v16.8b, v25.8b
    ld1 {v17.8h}, [x16], x13
    smlal v23.4s, v16.4h, v17.4h
    smlal2 v24.4s, v16.8h, v17.8h
    ld1 {v18.8b}, [x15], x7
    ssubl v18.8h, v18.8b, v25.8b
    ld1 {v19.8h}, [x16], x13
    smlal v23.4s, v18.4h, v19.4h
    smlal2 v24.4s, v18.8h, v19.8h
    b Post

LoopHW:
    // generic height x width loop; only reached before anything has been accumulated
    mov x9, x1
    mov x10, x2
    mov x17, x4                         // height
LoopH:
    mov x11, x9
    mov x12, x10
    // x15 is only live on the unrolled path above, so it is free to serve as the width
    // counter here (x18 is the platform register and must not be used).
    mov x15, x5                         // width
LoopW:
    ld1 {v0.8b}, [x11], x7
    ssubl v1.8h, v0.8b, v25.8b

    ld1 {v2.8h}, [x12], x13             // weight
    smlal v23.4s, v1.4h, v2.4h
    smlal2 v24.4s, v1.8h, v2.8h

    subs x15, x15, #1
    bne LoopW
    subs x17, x17, #1                   // add does not touch the flags, so bne LoopH tests this subs
    add x9, x9, x6
    add x10, x10, x14
    bne LoopH

Post:
    // requantize: left shift, multiply by out_multiplier, rounding shift by right_shift
    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    // add -1 to lanes whose sign bit is set in both the accumulator and right_shift before
    // the rounding shift. v21/v22 are caller-saved scratch registers.
    // NOTE(review): right_shift looks like it is expected to be <= 0 here - confirm against caller
    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    add v23.4s, v23.4s, v26.4s          // + out_zp
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s         // clamp to [acc_min, acc_max]
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s                // narrow int32 -> int16 -> int8
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4            // 4 int8 outputs each
    st1 {v24.s}[0], [x0], #4
    add x1, x1, #8                      // next 8 channels of src (int8)
    add x2, x2, #16                     // next 8 channels of weight (int16)
    sub x8, x8, #8
    cmp x8, #8
    bge LoopC
    ret
#endif
|
|
@ -0,0 +1,168 @@
|
|||
#ifdef __aarch64__

    .text
    .align 5
    .global ConvDw3x3Int8Corner
#ifndef __APPLE__
    .type ConvDw3x3Int8Corner, %function
#endif

// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
//                          size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
//                          size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
//                          size_t acc_max)
//
// Accumulates a window of 2 rows x 2 columns of taps into one int8 output pixel, 8 channels per
// iteration. The loads for the next channel group are interleaved with the arithmetic of the
// current one; the last group is finished in LoopCornerC8Post, so at least one group of
// 8 channels is always processed.
//
// Register arguments:
//   x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp
// Stack arguments (8 bytes each):
//   [sp]: out_zp, [sp, #8]: out_multiplier, [sp, #16]: left_shift, [sp, #24]: right_shift,
//   [sp, #32]: acc_min, [sp, #40]: acc_max
//
// Scratch registers: x9 ~ x14, v0 ~ v7, v21 ~ v31. No callee-saved register (x19 ~ x29, v8 ~ v15,
// see https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers)
// is touched, so nothing is spilled.
ConvDw3x3Int8Corner:
    dup v25.8b, w7                      // in_zp
    ldp x9, x10, [sp]                   // out_zp, out_multiplier
    ldp x11, x12, [sp, #16]             // left_shift, right_shift
    ldp x13, x14, [sp, #32]             // acc_min, acc_max
    dup v26.4s, w9                      // out_zp
    dup v27.4s, w10                     // out_multiplier
    dup v28.4s, w11                     // left_shift
    dup v29.4s, w12                     // right_shift
    dup v30.4s, w13                     // acc_min
    dup v31.4s, w14                     // acc_max

    lsl x13, x6, #1                     // x13 = channel * 2: byte step between weight columns
    add x14, x13, x13, lsl #1           // x14 = channel * 2 * 3: byte step between weight rows

    // bias of the first channel group seeds the accumulators
    ld1 {v23.4s}, [x3], #16
    ld1 {v24.4s}, [x3], #16
    mov x9, x1                          // x9: src row 0
    mov x10, x2                         // x10: weight row 0

    // preload the four taps of the first channel group
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b          // widen to int16 and subtract in_zp
    add x11, x1, x4                     // x11: src row 1
    ld1 {v4.8h}, [x10], x13             // weight
    add x12, x2, x14                    // x12: weight row 1
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    ld1 {v5.8h}, [x10], x13
    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    ld1 {v6.8h}, [x12], x13
    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    ld1 {v7.8h}, [x12], x13

    cmp x6, #8
    ble LoopCornerC8Post                // a single group: nothing to pipeline

LoopCornerC8:
    add x1, x1, #8                      // next 8 channels of src (int8)
    add x2, x2, #16                     // next 8 channels of weight (int16)
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    mov x9, x1
    mov x10, x2
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b
    ld1 {v4.8h}, [x10], x13             // weight
    add x11, x1, x4
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h
    add x12, x2, x14
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    smlal v23.4s, v2.4h, v6.4h
    ld1 {v5.8h}, [x10], x13
    smlal2 v24.4s, v2.8h, v6.8h
    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    smlal v23.4s, v3.4h, v7.4h
    ld1 {v6.8h}, [x12], x13
    smlal2 v24.4s, v3.8h, v7.8h

    // requantize the finished group
    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    // add -1 to lanes whose sign bit is set in both the accumulator and right_shift, then
    // apply the rounding shift
    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    // last tap of the next group (v3/v7 were consumed above)
    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    ld1 {v7.8h}, [x12], x13

    add v23.4s, v23.4s, v26.4s          // + out_zp
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s         // clamp to [acc_min, acc_max]
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s                // narrow int32 -> int16 -> int8
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4            // 4 int8 outputs each
    st1 {v24.s}[0], [x0], #4
    ld1 {v23.4s}, [x3], #16             // bias of the next group
    ld1 {v24.4s}, [x3], #16
    sub x6, x6, #8
    cmp x6, #8
    bgt LoopCornerC8

LoopCornerC8Post:
    // final group: taps are already loaded, only arithmetic is left
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h
    smlal v23.4s, v2.4h, v6.4h
    smlal2 v24.4s, v2.8h, v6.8h
    smlal v23.4s, v3.4h, v7.4h
    smlal2 v24.4s, v3.8h, v7.8h

    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    add v23.4s, v23.4s, v26.4s
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4
    st1 {v24.s}[0], [x0], #4
    ret
#endif
|
|
@ -0,0 +1,196 @@
|
|||
#ifdef __aarch64__

    .text
    .align 5
    .global ConvDw3x3Int8Horizontal
#ifndef __APPLE__
    .type ConvDw3x3Int8Horizontal, %function
#endif

// void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
//                              size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
//                              size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
//                              size_t acc_min, size_t acc_max)
//
// Accumulates a window of 3 rows x 2 columns of taps into one int8 output pixel, 8 channels per
// iteration. The loads for the next channel group are interleaved with the arithmetic of the
// current one; the last group is finished in LoopHorizontalC8Post, so at least one group of
// 8 channels is always processed.
//
// Register arguments:
//   x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp
// Stack arguments (8 bytes each):
//   [sp]: out_zp, [sp, #8]: out_multiplier, [sp, #16]: left_shift, [sp, #24]: right_shift,
//   [sp, #32]: acc_min, [sp, #40]: acc_max
//
// Scratch registers: x9 ~ x16, v0 ~ v7, v16 ~ v19, v21 ~ v31. No callee-saved register
// (x19 ~ x29, v8 ~ v15, see
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers)
// is touched, so nothing is spilled.
ConvDw3x3Int8Horizontal:
    dup v25.8b, w7                      // in_zp
    ldp x9, x10, [sp]                   // out_zp, out_multiplier
    ldp x11, x12, [sp, #16]             // left_shift, right_shift
    ldp x13, x14, [sp, #32]             // acc_min, acc_max
    dup v26.4s, w9                      // out_zp
    dup v27.4s, w10                     // out_multiplier
    dup v28.4s, w11                     // left_shift
    dup v29.4s, w12                     // right_shift
    dup v30.4s, w13                     // acc_min
    dup v31.4s, w14                     // acc_max

    lsl x13, x6, #1                     // x13 = channel * 2: byte step between weight columns
    add x14, x13, x13, lsl #1           // x14 = channel * 2 * 3: byte step between weight rows

    // bias of the first channel group seeds the accumulators
    ld1 {v23.4s}, [x3], #16
    ld1 {v24.4s}, [x3], #16
    mov x9, x1                          // x9: src row 0
    mov x10, x2                         // x10: weight row 0

    // preload the six taps of the first channel group
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b          // widen to int16 and subtract in_zp
    add x11, x1, x4                     // x11: src row 1
    ld1 {v4.8h}, [x10], x13             // weight
    add x12, x2, x14                    // x12: weight row 1
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    ld1 {v5.8h}, [x10], x13
    add x15, x11, x4                    // x15: src row 2 (taken before x11 is advanced)
    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    add x16, x12, x14                   // x16: weight row 2 (taken before x12 is advanced)
    ld1 {v6.8h}, [x12], x13
    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    ld1 {v7.8h}, [x12], x13
    ld1 {v16.8b}, [x15], x5
    ssubl v16.8h, v16.8b, v25.8b
    ld1 {v18.8h}, [x16], x13
    ld1 {v17.8b}, [x15], x5
    ssubl v17.8h, v17.8b, v25.8b
    ld1 {v19.8h}, [x16], x13

    cmp x6, #8
    ble LoopHorizontalC8Post            // a single group: nothing to pipeline

LoopHorizontalC8:
    add x1, x1, #8                      // next 8 channels of src (int8)
    add x2, x2, #16                     // next 8 channels of weight (int16)
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    mov x9, x1
    mov x10, x2
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b
    ld1 {v4.8h}, [x10], x13             // weight
    add x11, x1, x4
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h

    add x12, x2, x14
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    smlal v23.4s, v2.4h, v6.4h
    ld1 {v5.8h}, [x10], x13
    smlal2 v24.4s, v2.8h, v6.8h

    add x15, x11, x4
    add x16, x12, x14
    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    smlal v23.4s, v3.4h, v7.4h
    ld1 {v6.8h}, [x12], x13
    smlal2 v24.4s, v3.8h, v7.8h

    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    smlal v23.4s, v16.4h, v18.4h
    ld1 {v7.8h}, [x12], x13
    smlal2 v24.4s, v16.8h, v18.8h

    ld1 {v16.8b}, [x15], x5
    ssubl v16.8h, v16.8b, v25.8b
    smlal v23.4s, v17.4h, v19.4h
    ld1 {v18.8h}, [x16], x13
    smlal2 v24.4s, v17.8h, v19.8h

    // requantize the finished group
    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    // add -1 to lanes whose sign bit is set in both the accumulator and right_shift, then
    // apply the rounding shift
    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    // last tap of the next group (v17/v19 were consumed above)
    ld1 {v17.8b}, [x15], x5
    ssubl v17.8h, v17.8b, v25.8b
    ld1 {v19.8h}, [x16], x13

    add v23.4s, v23.4s, v26.4s          // + out_zp
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s         // clamp to [acc_min, acc_max]
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s                // narrow int32 -> int16 -> int8
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4            // 4 int8 outputs each
    st1 {v24.s}[0], [x0], #4
    ld1 {v23.4s}, [x3], #16             // bias of the next group
    ld1 {v24.4s}, [x3], #16
    sub x6, x6, #8
    cmp x6, #8
    bgt LoopHorizontalC8

LoopHorizontalC8Post:
    // final group: taps are already loaded, only arithmetic is left
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h
    smlal v23.4s, v2.4h, v6.4h
    smlal2 v24.4s, v2.8h, v6.8h
    smlal v23.4s, v3.4h, v7.4h
    smlal2 v24.4s, v3.8h, v7.8h
    smlal v23.4s, v16.4h, v18.4h
    smlal2 v24.4s, v16.8h, v18.8h
    smlal v23.4s, v17.4h, v19.4h
    smlal2 v24.4s, v17.8h, v19.8h

    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    add v23.4s, v23.4s, v26.4s
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4
    st1 {v24.s}[0], [x0], #4
    ret
#endif
|
|
@ -0,0 +1,192 @@
|
|||
#ifdef __aarch64__

    .text
    .align 5
    .global ConvDw3x3Int8Vertical
#ifndef __APPLE__
    .type ConvDw3x3Int8Vertical, %function
#endif

// void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
//                            size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
//                            size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
//                            size_t acc_max)
//
// Accumulates a window of 2 rows x 3 columns of taps into one int8 output pixel, 8 channels per
// iteration. The loads for the next channel group are interleaved with the arithmetic of the
// current one; the last group is finished in LoopVerticalC8Post, so at least one group of
// 8 channels is always processed.
//
// Register arguments:
//   x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: in_zp
// Stack arguments (8 bytes each):
//   [sp]: out_zp, [sp, #8]: out_multiplier, [sp, #16]: left_shift, [sp, #24]: right_shift,
//   [sp, #32]: acc_min, [sp, #40]: acc_max
//
// Scratch registers: x9 ~ x14, v0 ~ v7, v16 ~ v19, v21 ~ v31. No callee-saved register
// (x19 ~ x29, v8 ~ v15, see
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers)
// is touched, so nothing is spilled.
ConvDw3x3Int8Vertical:
    dup v25.8b, w7                      // in_zp
    ldp x9, x10, [sp]                   // out_zp, out_multiplier
    ldp x11, x12, [sp, #16]             // left_shift, right_shift
    ldp x13, x14, [sp, #32]             // acc_min, acc_max
    dup v26.4s, w9                      // out_zp
    dup v27.4s, w10                     // out_multiplier
    dup v28.4s, w11                     // left_shift
    dup v29.4s, w12                     // right_shift
    dup v30.4s, w13                     // acc_min
    dup v31.4s, w14                     // acc_max

    lsl x13, x6, #1                     // x13 = channel * 2: byte step between weight columns
    add x14, x13, x13, lsl #1           // x14 = channel * 2 * 3: byte step between weight rows

    // bias of the first channel group seeds the accumulators
    ld1 {v23.4s}, [x3], #16
    ld1 {v24.4s}, [x3], #16
    mov x9, x1                          // x9: src row 0
    mov x10, x2                         // x10: weight row 0

    // preload the six taps of the first channel group
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b          // widen to int16 and subtract in_zp
    add x11, x1, x4                     // x11: src row 1
    ld1 {v4.8h}, [x10], x13             // weight
    add x12, x2, x14                    // x12: weight row 1
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    ld1 {v5.8h}, [x10], x13
    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    ld1 {v6.8h}, [x12], x13
    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    ld1 {v7.8h}, [x12], x13
    ld1 {v16.8b}, [x9], x5              // third column of row 0
    ssubl v16.8h, v16.8b, v25.8b
    ld1 {v18.8h}, [x10], x13
    ld1 {v17.8b}, [x11], x5             // third column of row 1
    ssubl v17.8h, v17.8b, v25.8b
    ld1 {v19.8h}, [x12], x13

    cmp x6, #8
    ble LoopVerticalC8Post              // a single group: nothing to pipeline

LoopVerticalC8:
    add x1, x1, #8                      // next 8 channels of src (int8)
    add x2, x2, #16                     // next 8 channels of weight (int16)
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    mov x9, x1
    mov x10, x2
    ld1 {v0.8b}, [x9], x5
    ssubl v0.8h, v0.8b, v25.8b
    ld1 {v4.8h}, [x10], x13             // weight
    add x11, x1, x4
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h

    add x12, x2, x14
    ld1 {v1.8b}, [x9], x5
    ssubl v1.8h, v1.8b, v25.8b
    smlal v23.4s, v2.4h, v6.4h
    ld1 {v5.8h}, [x10], x13
    smlal2 v24.4s, v2.8h, v6.8h

    ld1 {v2.8b}, [x11], x5
    ssubl v2.8h, v2.8b, v25.8b
    smlal v23.4s, v3.4h, v7.4h
    ld1 {v6.8h}, [x12], x13
    smlal2 v24.4s, v3.8h, v7.8h

    ld1 {v3.8b}, [x11], x5
    ssubl v3.8h, v3.8b, v25.8b
    smlal v23.4s, v16.4h, v18.4h
    ld1 {v7.8h}, [x12], x13
    smlal2 v24.4s, v16.8h, v18.8h

    ld1 {v16.8b}, [x9], x5
    ssubl v16.8h, v16.8b, v25.8b
    smlal v23.4s, v17.4h, v19.4h
    ld1 {v18.8h}, [x10], x13
    smlal2 v24.4s, v17.8h, v19.8h

    // requantize the finished group
    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    // add -1 to lanes whose sign bit is set in both the accumulator and right_shift, then
    // apply the rounding shift
    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    // last tap of the next group (v17/v19 were consumed above)
    ld1 {v17.8b}, [x11], x5
    ssubl v17.8h, v17.8b, v25.8b
    ld1 {v19.8h}, [x12], x13

    add v23.4s, v23.4s, v26.4s          // + out_zp
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s         // clamp to [acc_min, acc_max]
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s                // narrow int32 -> int16 -> int8
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4            // 4 int8 outputs each
    st1 {v24.s}[0], [x0], #4
    ld1 {v23.4s}, [x3], #16             // bias of the next group
    ld1 {v24.4s}, [x3], #16
    sub x6, x6, #8
    cmp x6, #8
    bgt LoopVerticalC8

LoopVerticalC8Post:
    // final group: taps are already loaded, only arithmetic is left
    smlal v23.4s, v0.4h, v4.4h
    smlal2 v24.4s, v0.8h, v4.8h
    smlal v23.4s, v1.4h, v5.4h
    smlal2 v24.4s, v1.8h, v5.8h
    smlal v23.4s, v2.4h, v6.4h
    smlal2 v24.4s, v2.8h, v6.8h
    smlal v23.4s, v3.4h, v7.4h
    smlal2 v24.4s, v3.8h, v7.8h
    smlal v23.4s, v16.4h, v18.4h
    smlal2 v24.4s, v16.8h, v18.8h
    smlal v23.4s, v17.4h, v19.4h
    smlal2 v24.4s, v17.8h, v19.8h

    sqshl v23.4s, v23.4s, v28.4s
    sqshl v24.4s, v24.4s, v28.4s
    sqrdmulh v23.4s, v23.4s, v27.4s
    sqrdmulh v24.4s, v24.4s, v27.4s

    and v21.16b, v29.16b, v23.16b
    sshr v21.4s, v21.4s, #31
    sqadd v23.4s, v23.4s, v21.4s
    srshl v23.4s, v23.4s, v29.4s

    and v22.16b, v29.16b, v24.16b
    sshr v22.4s, v22.4s, #31
    sqadd v24.4s, v24.4s, v22.4s
    srshl v24.4s, v24.4s, v29.4s

    add v23.4s, v23.4s, v26.4s
    add v24.4s, v24.4s, v26.4s
    smax v23.4s, v23.4s, v30.4s
    smax v24.4s, v24.4s, v30.4s
    smin v23.4s, v23.4s, v31.4s
    smin v24.4s, v24.4s, v31.4s

    sqxtn v23.4h, v23.4s
    sqxtn v24.4h, v24.4s
    sqxtn v23.8b, v23.8h
    sqxtn v24.8b, v24.8h

    st1 {v23.s}[0], [x0], #4
    st1 {v24.s}[0], [x0], #4
    ret
#endif
|
|
@ -47,10 +47,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
|
|||
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
|
||||
int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *acc_min, int32_t *acc_max);
|
||||
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
|
||||
size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
|
||||
size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
|
||||
size_t acc_min, size_t acc_max);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM32
|
||||
|
@ -71,6 +67,21 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
|
|||
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
|
||||
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
|
||||
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
|
||||
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
|
||||
int32_t acc_max);
|
||||
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
|
||||
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
|
||||
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);
|
||||
void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
|
||||
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
|
||||
size_t acc_max);
|
||||
void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
|
||||
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
|
||||
size_t acc_max);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -232,29 +232,31 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const
|
|||
|
||||
int ih_offset = 64 * block_input_w;
|
||||
int w = start_w;
|
||||
for (; w <= end_w - block_output_w; w += block_output_w) {
|
||||
int8_t *output_ptr = output;
|
||||
const int8_t *input_ptr = input;
|
||||
const int16_t *weight_ptr = weight;
|
||||
const int32_t *bias_ptr = bias;
|
||||
int c = 0;
|
||||
for (; c <= conv_param->output_channel_ - 64; c += 64) {
|
||||
InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
|
||||
ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
|
||||
block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift,
|
||||
acc_min, acc_max, conv_param->stride_h_);
|
||||
output_ptr += 64;
|
||||
input_ptr += 64;
|
||||
weight_ptr += 64;
|
||||
bias_ptr += 64;
|
||||
if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) {
|
||||
for (; w <= end_w - block_output_w; w += block_output_w) {
|
||||
int8_t *output_ptr = output;
|
||||
const int8_t *input_ptr = input;
|
||||
const int16_t *weight_ptr = weight;
|
||||
const int32_t *bias_ptr = bias;
|
||||
int c = 0;
|
||||
for (; c <= conv_param->output_channel_ - 64; c += 64) {
|
||||
InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
|
||||
ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
|
||||
block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift,
|
||||
acc_min, acc_max, conv_param->stride_h_);
|
||||
output_ptr += 64;
|
||||
input_ptr += 64;
|
||||
weight_ptr += 64;
|
||||
bias_ptr += 64;
|
||||
}
|
||||
// left channel
|
||||
ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
|
||||
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
|
||||
conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier,
|
||||
left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_);
|
||||
output += block_output_w * conv_param->input_channel_;
|
||||
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
|
||||
}
|
||||
// left channel
|
||||
ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
|
||||
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
|
||||
conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier,
|
||||
left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_);
|
||||
output += block_output_w * conv_param->input_channel_;
|
||||
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
|
||||
}
|
||||
// left width
|
||||
int left_width = end_w - w;
|
||||
|
@ -300,8 +302,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ARM
|
||||
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
|
||||
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
|
||||
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
|
||||
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
|
||||
for (int c = 0; c < channel; c += 8) {
|
||||
|
@ -337,9 +338,30 @@ void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *wei
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ARM64
|
||||
// Portable fallback for the corner case: forwards to ConvDw3x3Int8BorderPixel with a fixed
// window of height 2 and width 2; every other argument is passed through unchanged.
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
                         int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier, int left_shift,
                         int right_shift, int32_t acc_min, int32_t acc_max) {
  const int window_height = 2;
  const int window_width = 2;
  ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_height, window_width, in_kh_step, in_kw_step, channel,
                           in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
|
||||
|
||||
// Portable fallback: forwards to ConvDw3x3Int8BorderPixel with a fixed window of height 2 and
// width 3; every other argument is passed through unchanged.
void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
                           int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier,
                           int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
  const int window_height = 2;
  const int window_width = 3;
  ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_height, window_width, in_kh_step, in_kw_step, channel,
                           in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
|
||||
// Portable fallback: forwards to ConvDw3x3Int8BorderPixel with a fixed window of height 3 and
// width 2; every other argument is passed through unchanged.
void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step,
                             int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int out_multiplier,
                             int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
  const int window_height = 3;
  const int window_width = 2;
  ConvDw3x3Int8BorderPixel(dst, src, weight, bias, window_height, window_width, in_kh_step, in_kw_step, channel,
                           in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
|
||||
#endif
|
||||
|
||||
void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
|
||||
void ConvDw3x3Int8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
|
||||
int bottom, int left, int right, const ConvParameter *conv_param,
|
||||
const SlidingWindowParam *sliding, int8_t in_zp, int32_t out_zp, int out_multiplier,
|
||||
int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
|
||||
|
@ -361,7 +383,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight,
|
|||
const int16_t *weight_kernel =
|
||||
weight + (start_kh * conv_param->kernel_w_ + start_kw) * conv_param->input_channel_;
|
||||
|
||||
ConvDw3x3BorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
ConvDw3x3Int8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->input_channel_, in_zp, out_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
|
||||
|
@ -371,7 +393,7 @@ void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight,
|
|||
} // height loop
|
||||
}
|
||||
|
||||
void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
|
||||
void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
|
||||
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
|
||||
|
@ -380,17 +402,70 @@ void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16
|
|||
int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
|
||||
int acc_min = conv_param->conv_quant_arg_.out_act_min_[0];
|
||||
int acc_max = conv_param->conv_quant_arg_.out_act_max_[0];
|
||||
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, 0, sliding->top_, 0, conv_param->output_w_,
|
||||
conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->bottom_, conv_param->output_h_, 0,
|
||||
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
|
||||
right_shift, acc_min, acc_max);
|
||||
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, 0,
|
||||
sliding->left_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift,
|
||||
acc_min, acc_max);
|
||||
ConvDw3x3BorderInt8(output_data, input_data, weight_data, bias_data, sliding->top_, sliding->bottom_, sliding->right_,
|
||||
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
|
||||
right_shift, acc_min, acc_max);
|
||||
int input_row_size = conv_param->input_w_ * conv_param->input_channel_;
|
||||
int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
int output_row_size = conv_param->output_w_ * conv_param->output_channel_;
|
||||
int in_kh_step = sliding->in_kh_step_;
|
||||
int in_kw_step = sliding->in_kw_step_;
|
||||
|
||||
// top
|
||||
const int8_t *input = input_data;
|
||||
const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_;
|
||||
int8_t *output = output_data;
|
||||
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += (conv_param->stride_w_ - 1) * conv_param->input_channel_;
|
||||
weight = weight_data + weight_row_size;
|
||||
output += conv_param->output_channel_;
|
||||
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
|
||||
ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += conv_param->stride_w_ * conv_param->input_channel_;
|
||||
output += conv_param->output_channel_;
|
||||
}
|
||||
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
|
||||
// left
|
||||
input = input_data + (conv_param->stride_h_ - 1) * input_row_size;
|
||||
weight = weight_data + conv_param->input_channel_;
|
||||
output = output_data + output_row_size;
|
||||
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
|
||||
ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += conv_param->stride_h_ * input_row_size;
|
||||
output += output_row_size;
|
||||
}
|
||||
|
||||
// right
|
||||
input =
|
||||
input_data + (conv_param->input_w_ - 2) * conv_param->input_channel_ + (conv_param->stride_h_ - 1) * input_row_size;
|
||||
weight = weight_data;
|
||||
output = output_data + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_;
|
||||
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
|
||||
ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += conv_param->stride_h_ * input_row_size;
|
||||
output += output_row_size;
|
||||
}
|
||||
|
||||
// bottom
|
||||
input = input_data + (conv_param->input_h_ - 2) * input_row_size;
|
||||
weight = weight_data + conv_param->input_channel_;
|
||||
output = output_data + (conv_param->output_h_ - 1) * output_row_size;
|
||||
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_;
|
||||
weight = weight_data;
|
||||
output += conv_param->output_channel_;
|
||||
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
|
||||
ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
input += conv_param->stride_w_ * conv_param->input_channel_;
|
||||
output += conv_param->output_channel_;
|
||||
}
|
||||
ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp,
|
||||
out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
}
|
||||
/*conv depthwise 3x3 int8 end*/
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ bool CheckIfUse3X3(const ConvParameter *conv_param, int channel);
|
|||
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
|
||||
void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
|
||||
|
||||
void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data,
|
||||
|
@ -44,13 +44,6 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
|
|||
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
|
||||
int task_id);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
|
||||
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
|
||||
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
|
||||
int32_t acc_max);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -164,7 +164,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::Run() {
|
|||
|
||||
if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
|
||||
sliding_->right_ < conv_param_->output_w_) {
|
||||
ConvDw3x3PadInt8(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
|
||||
ConvDw3x3Int8Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
|
||||
sliding_);
|
||||
}
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Int8Run, this, conv_param_->thread_num_);
|
||||
|
|
Loading…
Reference in New Issue