!6038 [MSLITE][Develop] arm cpu int8 conv depthwise support activation per channel

Merge pull request !6038 from yangruoqi713/act_per_channel
This commit is contained in:
mindspore-ci-bot 2020-09-11 17:45:53 +08:00 committed by Gitee
commit 8097d6c278
10 changed files with 540 additions and 654 deletions

View File

@ -7,13 +7,15 @@
.type ConvDwInt8Center, %function
#endif
// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width,
// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
// size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift,
// int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
// int32_t *acc_min, int32_t *acc_max)
// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w,
// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step
// x14: out_multiplier, #56: left_shift, #64: right_shift, #72:out_zp, #80: acc_min, #88: acc_max
// x14: in_zp, #56: out_zp, #64: out_multiplier, #72:left_shift, #80: right_shift, #88: acc_min, #96: acc_max
ConvDwInt8Center:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
@ -33,489 +35,174 @@ ConvDwInt8Center:
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr w14, [sp, #56]
dup v26.4s, w14
ldr x14, [sp, #48] // input_zp
ld1 {v19.8b}, [x14], #8
ldr x15, [sp, #48]
dup v27.4s, w15
ldr x15, [sp, #56] // output_zp
ld1 {v20.4s}, [x15], #16
ld1 {v21.4s}, [x15], #16
ldr w16, [sp, #64]
dup v28.4s, w16
ldr x16, [sp, #64] // out_multiplier
ld1 {v22.4s}, [x16], #16
ld1 {v23.4s}, [x16], #16
ldr w17, [sp, #72]
dup v29.4s, w17
ldr w18, [sp, #80]
dup v30.4s, w18
ldr x17, [sp, #72] // left_shift
ld1 {v24.4s}, [x17], #16
ld1 {v25.4s}, [x17], #16
ldr w19, [sp, #88]
dup v31.4s, w19
ldr x18, [sp, #80] // right shift
ld1 {v26.4s}, [x18], #16
ld1 {v27.4s}, [x18], #16
ld1 {v24.4s}, [x3]
ldr x19, [sp, #88] // acc_min
ld1 {v28.4s}, [x19], #16
ld1 {v29.4s}, [x19], #16
ldr x20, [sp, #96] // acc_max
ld1 {v30.4s}, [x20], #16
ld1 {v31.4s}, [x20], #16
ld1 {v17.4s}, [x3], #16
ld1 {v18.4s}, [x3], #16
LoopH:
mov x23, x1
mov x24, x5
mov x3, x0
cmp x24, #8
blt LoopW
cmp x24, #16
blt LoopW8
LoopW16:
mov x19, #16
LoopW4:
mov x19, #4
mul x19, x19, x11
mov x25, #4
mul x25, x25, x9
mov x16, x23
mov x17, x2
mov x20, x6
mov v0.16b, v24.16b
mov v1.16b, v24.16b
mov v2.16b, v24.16b
mov v3.16b, v24.16b
mov v4.16b, v24.16b
mov v5.16b, v24.16b
mov v6.16b, v24.16b
mov v7.16b, v24.16b
mov v8.16b, v24.16b
mov v9.16b, v24.16b
mov v10.16b, v24.16b
mov v11.16b, v24.16b
mov v12.16b, v24.16b
mov v13.16b, v24.16b
mov v14.16b, v24.16b
mov v15.16b, v24.16b
LoopKh16:
mov v0.16b, v17.16b
mov v1.16b, v18.16b
mov v2.16b, v17.16b
mov v3.16b, v18.16b
mov v4.16b, v17.16b
mov v5.16b, v18.16b
mov v6.16b, v17.16b
mov v7.16b, v18.16b
LoopKh4:
mov x18, x7
mov x21, x16
LoopKw16:
LoopKw4:
mov x22, x21
ld1 {v25.4h}, [x17], #8
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v0.4s, v16.4h, v25.4h
smlal v1.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v2.4s, v18.4h, v25.4h
smlal v3.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v4.4s, v20.4h, v25.4h
smlal v5.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v6.4s, v22.4h, v25.4h
smlal v7.4s, v23.4h, v25.4h
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v8.4s, v16.4h, v25.4h
smlal v9.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v10.4s, v18.4h, v25.4h
smlal v11.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v12.4s, v20.4h, v25.4h
smlal v13.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v14.4s, v22.4h, v25.4h
smlal v15.4s, v23.4h, v25.4h
ld1 {v16.8h}, [x17], #16
ld1 {v15.8b}, [x22], x11
ssubl v14.8h, v15.8b, v19.8b
smlal v0.4s, v14.4h, v16.4h
smlal2 v1.4s, v14.8h, v16.8h
ld1 {v13.8b}, [x22], x11
ssubl v12.8h, v13.8b, v19.8b
smlal v2.4s, v12.4h, v16.4h
smlal2 v3.4s, v12.8h, v16.8h
ld1 {v11.8b}, [x22], x11
ssubl v10.8h, v11.8b, v19.8b
smlal v4.4s, v10.4h, v16.4h
smlal2 v5.4s, v10.8h, v16.8h
ld1 {v9.8b}, [x22], x11
ssubl v8.8h, v9.8b, v19.8b
smlal v6.4s, v8.4h, v16.4h
smlal2 v7.4s, v8.8h, v16.8h
subs x18, x18, #1
add x21, x21, x13
bne LoopKw16
bne LoopKw4
add x16, x16, x12
subs x20, x20, #1
bne LoopKh16
bne LoopKh4
sqshl v0.4s, v0.4s, v26.4s
sqshl v1.4s, v1.4s, v26.4s
sqshl v2.4s, v2.4s, v26.4s
sqshl v3.4s, v3.4s, v26.4s
sqshl v4.4s, v4.4s, v26.4s
sqshl v5.4s, v5.4s, v26.4s
sqshl v6.4s, v6.4s, v26.4s
sqshl v7.4s, v7.4s, v26.4s
sqshl v8.4s, v8.4s, v26.4s
sqshl v9.4s, v9.4s, v26.4s
sqshl v10.4s, v10.4s, v26.4s
sqshl v11.4s, v11.4s, v26.4s
sqshl v12.4s, v12.4s, v26.4s
sqshl v13.4s, v13.4s, v26.4s
sqshl v14.4s, v14.4s, v26.4s
sqshl v15.4s, v15.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s
sqrdmulh v2.4s, v2.4s, v27.4s
sqrdmulh v3.4s, v3.4s, v27.4s
sqrdmulh v4.4s, v4.4s, v27.4s
sqrdmulh v5.4s, v5.4s, v27.4s
sqrdmulh v6.4s, v6.4s, v27.4s
sqrdmulh v7.4s, v7.4s, v27.4s
sqrdmulh v8.4s, v8.4s, v27.4s
sqrdmulh v9.4s, v9.4s, v27.4s
sqrdmulh v10.4s, v10.4s, v27.4s
sqrdmulh v11.4s, v11.4s, v27.4s
sqrdmulh v12.4s, v12.4s, v27.4s
sqrdmulh v13.4s, v13.4s, v27.4s
sqrdmulh v14.4s, v14.4s, v27.4s
sqrdmulh v15.4s, v15.4s, v27.4s
sqshl v0.4s, v0.4s, v24.4s
sqshl v1.4s, v1.4s, v25.4s
sqshl v2.4s, v2.4s, v24.4s
sqshl v3.4s, v3.4s, v25.4s
sqshl v4.4s, v4.4s, v24.4s
sqshl v5.4s, v5.4s, v25.4s
sqshl v6.4s, v6.4s, v24.4s
sqshl v7.4s, v7.4s, v25.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
and v17.16b, v28.16b, v1.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v28.4s
and v18.16b, v28.16b, v2.16b
sshr v18.4s, v18.4s, #31
sqadd v2.4s, v2.4s, v18.4s
srshl v2.4s, v2.4s, v28.4s
and v19.16b, v28.16b, v3.16b
sshr v19.4s, v19.4s, #31
sqadd v3.4s, v3.4s, v19.4s
srshl v3.4s, v3.4s, v28.4s
and v20.16b, v28.16b, v4.16b
sshr v20.4s, v20.4s, #31
sqadd v4.4s, v4.4s, v20.4s
srshl v4.4s, v4.4s, v28.4s
and v21.16b, v28.16b, v5.16b
sshr v21.4s, v21.4s, #31
sqadd v5.4s, v5.4s, v21.4s
srshl v5.4s, v5.4s, v28.4s
and v22.16b, v28.16b, v6.16b
sshr v22.4s, v22.4s, #31
sqadd v6.4s, v6.4s, v22.4s
srshl v6.4s, v6.4s, v28.4s
and v23.16b, v28.16b, v7.16b
sshr v23.4s, v23.4s, #31
sqadd v7.4s, v7.4s, v23.4s
srshl v7.4s, v7.4s, v28.4s
and v16.16b, v28.16b, v8.16b
sshr v16.4s, v16.4s, #31
sqadd v8.4s, v8.4s, v16.4s
srshl v8.4s, v8.4s, v28.4s
and v17.16b, v28.16b, v9.16b
sshr v17.4s, v17.4s, #31
sqadd v9.4s, v9.4s, v17.4s
srshl v9.4s, v9.4s, v28.4s
and v18.16b, v28.16b, v10.16b
sshr v18.4s, v18.4s, #31
sqadd v10.4s, v10.4s, v18.4s
srshl v10.4s, v10.4s, v28.4s
and v19.16b, v28.16b, v11.16b
sshr v19.4s, v19.4s, #31
sqadd v11.4s, v11.4s, v19.4s
srshl v11.4s, v11.4s, v28.4s
and v20.16b, v28.16b, v12.16b
sshr v20.4s, v20.4s, #31
sqadd v12.4s, v12.4s, v20.4s
srshl v12.4s, v12.4s, v28.4s
and v21.16b, v28.16b, v13.16b
sshr v21.4s, v21.4s, #31
sqadd v13.4s, v13.4s, v21.4s
srshl v13.4s, v13.4s, v28.4s
and v22.16b, v28.16b, v14.16b
sshr v22.4s, v22.4s, #31
sqadd v14.4s, v14.4s, v22.4s
srshl v14.4s, v14.4s, v28.4s
and v23.16b, v28.16b, v15.16b
sshr v23.4s, v23.4s, #31
sqadd v15.4s, v15.4s, v23.4s
srshl v15.4s, v15.4s, v28.4s
sqrdmulh v0.4s, v0.4s, v22.4s
sqrdmulh v1.4s, v1.4s, v23.4s
sqrdmulh v2.4s, v2.4s, v22.4s
sqrdmulh v3.4s, v3.4s, v23.4s
sqrdmulh v4.4s, v4.4s, v22.4s
sqrdmulh v5.4s, v5.4s, v23.4s
sqrdmulh v6.4s, v6.4s, v22.4s
sqrdmulh v7.4s, v7.4s, v23.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
add v2.4s, v2.4s, v29.4s
add v3.4s, v3.4s, v29.4s
add v4.4s, v4.4s, v29.4s
add v5.4s, v5.4s, v29.4s
add v6.4s, v6.4s, v29.4s
add v7.4s, v7.4s, v29.4s
add v8.4s, v8.4s, v29.4s
add v9.4s, v9.4s, v29.4s
add v10.4s, v10.4s, v29.4s
add v11.4s, v11.4s, v29.4s
add v12.4s, v12.4s, v29.4s
add v13.4s, v13.4s, v29.4s
add v14.4s, v14.4s, v29.4s
add v15.4s, v15.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s
smax v2.4s, v2.4s, v30.4s
smax v3.4s, v3.4s, v30.4s
smax v4.4s, v4.4s, v30.4s
smax v5.4s, v5.4s, v30.4s
smax v6.4s, v6.4s, v30.4s
smax v7.4s, v7.4s, v30.4s
smax v8.4s, v8.4s, v30.4s
smax v9.4s, v9.4s, v30.4s
smax v10.4s, v10.4s, v30.4s
smax v11.4s, v11.4s, v30.4s
smax v12.4s, v12.4s, v30.4s
smax v13.4s, v13.4s, v30.4s
smax v14.4s, v14.4s, v30.4s
smax v15.4s, v15.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
and v15.16b, v26.16b, v0.16b
sshr v15.4s, v15.4s, #31
sqadd v0.4s, v0.4s, v15.4s
srshl v0.4s, v0.4s, v26.4s
and v14.16b, v27.16b, v1.16b
sshr v14.4s, v14.4s, #31
sqadd v1.4s, v1.4s, v14.4s
srshl v1.4s, v1.4s, v27.4s
and v13.16b, v26.16b, v2.16b
sshr v13.4s, v13.4s, #31
sqadd v2.4s, v2.4s, v13.4s
srshl v2.4s, v2.4s, v26.4s
and v12.16b, v27.16b, v3.16b
sshr v12.4s, v12.4s, #31
sqadd v3.4s, v3.4s, v12.4s
srshl v3.4s, v3.4s, v27.4s
and v11.16b, v26.16b, v4.16b
sshr v11.4s, v11.4s, #31
sqadd v4.4s, v4.4s, v11.4s
srshl v4.4s, v4.4s, v26.4s
and v10.16b, v27.16b, v5.16b
sshr v10.4s, v10.4s, #31
sqadd v5.4s, v5.4s, v10.4s
srshl v5.4s, v5.4s, v27.4s
and v9.16b, v26.16b, v6.16b
sshr v9.4s, v9.4s, #31
sqadd v6.4s, v6.4s, v9.4s
srshl v6.4s, v6.4s, v26.4s
and v8.16b, v27.16b, v7.16b
sshr v8.4s, v8.4s, #31
sqadd v7.4s, v7.4s, v8.4s
srshl v7.4s, v7.4s, v27.4s
add v0.4s, v0.4s, v20.4s
add v1.4s, v1.4s, v21.4s
add v2.4s, v2.4s, v20.4s
add v3.4s, v3.4s, v21.4s
add v4.4s, v4.4s, v20.4s
add v5.4s, v5.4s, v21.4s
add v6.4s, v6.4s, v20.4s
add v7.4s, v7.4s, v21.4s
smax v0.4s, v0.4s, v28.4s
smax v1.4s, v1.4s, v29.4s
smax v2.4s, v2.4s, v28.4s
smax v3.4s, v3.4s, v29.4s
smax v4.4s, v4.4s, v28.4s
smax v5.4s, v5.4s, v29.4s
smax v6.4s, v6.4s, v28.4s
smax v7.4s, v7.4s, v29.4s
smin v0.4s, v0.4s, v30.4s
smin v1.4s, v1.4s, v31.4s
smin v2.4s, v2.4s, v31.4s
smin v2.4s, v2.4s, v30.4s
smin v3.4s, v3.4s, v31.4s
smin v4.4s, v4.4s, v31.4s
smin v4.4s, v4.4s, v30.4s
smin v5.4s, v5.4s, v31.4s
smin v6.4s, v6.4s, v31.4s
smin v7.4s, v7.4s, v31.4s
smin v8.4s, v8.4s, v31.4s
smin v9.4s, v9.4s, v31.4s
smin v10.4s, v10.4s, v31.4s
smin v11.4s, v11.4s, v31.4s
smin v12.4s, v12.4s, v31.4s
smin v13.4s, v13.4s, v31.4s
smin v14.4s, v14.4s, v31.4s
smin v15.4s, v15.4s, v31.4s
sqxtn v0.4h, v0.4s
sqxtn v1.4h, v1.4s
sqxtn v2.4h, v2.4s
sqxtn v3.4h, v3.4s
sqxtn v4.4h, v4.4s
sqxtn v5.4h, v5.4s
sqxtn v6.4h, v6.4s
sqxtn v7.4h, v7.4s
sqxtn v8.4h, v8.4s
sqxtn v9.4h, v9.4s
sqxtn v10.4h, v10.4s
sqxtn v11.4h, v11.4s
sqxtn v12.4h, v12.4s
sqxtn v13.4h, v13.4s
sqxtn v14.4h, v14.4s
sqxtn v15.4h, v15.4s
sqxtn v0.8b, v0.8h
sqxtn v1.8b, v1.8h
sqxtn v2.8b, v2.8h
sqxtn v3.8b, v3.8h
sqxtn v4.8b, v4.8h
sqxtn v5.8b, v5.8h
sqxtn v6.8b, v6.8h
sqxtn v7.8b, v7.8h
sqxtn v8.8b, v8.8h
sqxtn v9.8b, v9.8h
sqxtn v10.8b, v10.8h
sqxtn v11.8b, v11.8h
sqxtn v12.8b, v12.8h
sqxtn v13.8b, v13.8h
sqxtn v14.8b, v14.8h
sqxtn v15.8b, v15.8h
add x17, x3, #1
add x18, x3, #2
add x21, x3, #3
st1 {v0.b}[0], [x3], x9
st1 {v0.b}[1], [x17], x9
st1 {v0.b}[2], [x18], x9
st1 {v0.b}[3], [x21], x9
st1 {v1.b}[0], [x3], x9
st1 {v1.b}[1], [x17], x9
st1 {v1.b}[2], [x18], x9
st1 {v1.b}[3], [x21], x9
st1 {v2.b}[0], [x3], x9
st1 {v2.b}[1], [x17], x9
st1 {v2.b}[2], [x18], x9
st1 {v2.b}[3], [x21], x9
st1 {v3.b}[0], [x3], x9
st1 {v3.b}[1], [x17], x9
st1 {v3.b}[2], [x18], x9
st1 {v3.b}[3], [x21], x9
st1 {v4.b}[0], [x3], x9
st1 {v4.b}[1], [x17], x9
st1 {v4.b}[2], [x18], x9
st1 {v4.b}[3], [x21], x9
st1 {v5.b}[0], [x3], x9
st1 {v5.b}[1], [x17], x9
st1 {v5.b}[2], [x18], x9
st1 {v5.b}[3], [x21], x9
st1 {v6.b}[0], [x3], x9
st1 {v6.b}[1], [x17], x9
st1 {v6.b}[2], [x18], x9
st1 {v6.b}[3], [x21], x9
st1 {v7.b}[0], [x3], x9
st1 {v7.b}[1], [x17], x9
st1 {v7.b}[2], [x18], x9
st1 {v7.b}[3], [x21], x9
st1 {v8.b}[0], [x3], x9
st1 {v8.b}[1], [x17], x9
st1 {v8.b}[2], [x18], x9
st1 {v8.b}[3], [x21], x9
st1 {v9.b}[0], [x3], x9
st1 {v9.b}[1], [x17], x9
st1 {v9.b}[2], [x18], x9
st1 {v9.b}[3], [x21], x9
st1 {v10.b}[0], [x3], x9
st1 {v10.b}[1], [x17], x9
st1 {v10.b}[2], [x18], x9
st1 {v10.b}[3], [x21], x9
st1 {v11.b}[0], [x3], x9
st1 {v11.b}[1], [x17], x9
st1 {v11.b}[2], [x18], x9
st1 {v11.b}[3], [x21], x9
st1 {v12.b}[0], [x3], x9
st1 {v12.b}[1], [x17], x9
st1 {v12.b}[2], [x18], x9
st1 {v12.b}[3], [x21], x9
st1 {v13.b}[0], [x3], x9
st1 {v13.b}[1], [x17], x9
st1 {v13.b}[2], [x18], x9
st1 {v13.b}[3], [x21], x9
st1 {v14.b}[0], [x3], x9
st1 {v14.b}[1], [x17], x9
st1 {v14.b}[2], [x18], x9
st1 {v14.b}[3], [x21], x9
st1 {v15.b}[0], [x3], x9
st1 {v15.b}[1], [x17], x9
st1 {v15.b}[2], [x18], x9
st1 {v15.b}[3], [x21], x9
add x23, x23, x19
sub x24, x24, #16
cmp x24, #0
ble LoopWEnd
cmp x24, #8
blt LoopW
cmp x24, #16
bge LoopW16
LoopW8:
mov x19, #8
mul x19, x19, x11
mov x16, x23
mov x17, x2
mov x20, x6
mov v0.16b, v24.16b
mov v1.16b, v24.16b
mov v2.16b, v24.16b
mov v3.16b, v24.16b
mov v4.16b, v24.16b
mov v5.16b, v24.16b
mov v6.16b, v24.16b
mov v7.16b, v24.16b
LoopKh8:
mov x18, x7
mov x21, x16
LoopKw8:
mov x22, x21
ld1 {v25.4h}, [x17], #8
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v0.4s, v16.4h, v25.4h
smlal v1.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v2.4s, v18.4h, v25.4h
smlal v3.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v4.4s, v20.4h, v25.4h
smlal v5.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v6.4s, v22.4h, v25.4h
smlal v7.4s, v23.4h, v25.4h
subs x18, x18, #1
add x21, x21, x13
bne LoopKw8
add x16, x16, x12
subs x20, x20, #1
bne LoopKh8
sqshl v0.4s, v0.4s, v26.4s
sqshl v1.4s, v1.4s, v26.4s
sqshl v2.4s, v2.4s, v26.4s
sqshl v3.4s, v3.4s, v26.4s
sqshl v4.4s, v4.4s, v26.4s
sqshl v5.4s, v5.4s, v26.4s
sqshl v6.4s, v6.4s, v26.4s
sqshl v7.4s, v7.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s
sqrdmulh v2.4s, v2.4s, v27.4s
sqrdmulh v3.4s, v3.4s, v27.4s
sqrdmulh v4.4s, v4.4s, v27.4s
sqrdmulh v5.4s, v5.4s, v27.4s
sqrdmulh v6.4s, v6.4s, v27.4s
sqrdmulh v7.4s, v7.4s, v27.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
and v17.16b, v28.16b, v1.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v28.4s
and v18.16b, v28.16b, v2.16b
sshr v18.4s, v18.4s, #31
sqadd v2.4s, v2.4s, v18.4s
srshl v2.4s, v2.4s, v28.4s
and v19.16b, v28.16b, v3.16b
sshr v19.4s, v19.4s, #31
sqadd v3.4s, v3.4s, v19.4s
srshl v3.4s, v3.4s, v28.4s
and v20.16b, v28.16b, v4.16b
sshr v20.4s, v20.4s, #31
sqadd v4.4s, v4.4s, v20.4s
srshl v4.4s, v4.4s, v28.4s
and v21.16b, v28.16b, v5.16b
sshr v21.4s, v21.4s, #31
sqadd v5.4s, v5.4s, v21.4s
srshl v5.4s, v5.4s, v28.4s
and v22.16b, v28.16b, v6.16b
sshr v22.4s, v22.4s, #31
sqadd v6.4s, v6.4s, v22.4s
srshl v6.4s, v6.4s, v28.4s
and v23.16b, v28.16b, v7.16b
sshr v23.4s, v23.4s, #31
sqadd v7.4s, v7.4s, v23.4s
srshl v7.4s, v7.4s, v28.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
add v2.4s, v2.4s, v29.4s
add v3.4s, v3.4s, v29.4s
add v4.4s, v4.4s, v29.4s
add v5.4s, v5.4s, v29.4s
add v6.4s, v6.4s, v29.4s
add v7.4s, v7.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s
smax v2.4s, v2.4s, v30.4s
smax v3.4s, v3.4s, v30.4s
smax v4.4s, v4.4s, v30.4s
smax v5.4s, v5.4s, v30.4s
smax v6.4s, v6.4s, v30.4s
smax v7.4s, v7.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
smin v1.4s, v1.4s, v31.4s
smin v2.4s, v2.4s, v31.4s
smin v3.4s, v3.4s, v31.4s
smin v4.4s, v4.4s, v31.4s
smin v5.4s, v5.4s, v31.4s
smin v6.4s, v6.4s, v31.4s
smin v6.4s, v6.4s, v30.4s
smin v7.4s, v7.4s, v31.4s
sqxtn v0.4h, v0.4s
@ -535,93 +222,81 @@ ConvDwInt8Center:
sqxtn v6.8b, v6.8h
sqxtn v7.8b, v7.8h
add x17, x3, #1
add x18, x3, #2
add x21, x3, #3
st1 {v0.b}[0], [x3], x9
st1 {v0.b}[1], [x17], x9
st1 {v0.b}[2], [x18], x9
st1 {v0.b}[3], [x21], x9
mov x16, x3
add x17, x16, x9
add x18, x17, x9
add x21, x18, x9
st1 {v1.b}[0], [x3], x9
st1 {v1.b}[1], [x17], x9
st1 {v1.b}[2], [x18], x9
st1 {v1.b}[3], [x21], x9
st1 {v2.b}[0], [x3], x9
st1 {v2.b}[1], [x17], x9
st1 {v2.b}[2], [x18], x9
st1 {v2.b}[3], [x21], x9
st1 {v3.b}[0], [x3], x9
st1 {v3.b}[1], [x17], x9
st1 {v3.b}[2], [x18], x9
st1 {v3.b}[3], [x21], x9
st1 {v4.b}[0], [x3], x9
st1 {v4.b}[1], [x17], x9
st1 {v4.b}[2], [x18], x9
st1 {v4.b}[3], [x21], x9
st1 {v5.b}[0], [x3], x9
st1 {v5.b}[1], [x17], x9
st1 {v5.b}[2], [x18], x9
st1 {v5.b}[3], [x21], x9
st1 {v6.b}[0], [x3], x9
st1 {v6.b}[1], [x17], x9
st1 {v6.b}[2], [x18], x9
st1 {v6.b}[3], [x21], x9
st1 {v7.b}[0], [x3], x9
st1 {v7.b}[1], [x17], x9
st1 {v7.b}[2], [x18], x9
st1 {v7.b}[3], [x21], x9
st1 {v0.s}[0], [x16], #4
st1 {v1.s}[0], [x16], #4
st1 {v2.s}[0], [x17], #4
st1 {v3.s}[0], [x17], #4
st1 {v4.s}[0], [x18], #4
st1 {v5.s}[0], [x18], #4
st1 {v6.s}[0], [x21], #4
st1 {v7.s}[0], [x21], #4
add x3, x3, x25
add x23, x23, x19
sub x24, x24, #8
sub x24, x24, #4
cmp x24, #0
ble LoopWEnd
cmp x24, #8
bge LoopW8
cmp x24, #4
bge LoopW4
LoopW:
mov x16, x23
mov x17, x2
mov x20, x6
mov v0.16b, v24.16b
mov v0.16b, v17.16b
mov v1.16b, v18.16b
LoopKh:
mov x18, x7
mov x22, x16
LoopKw:
ld1 {v16.4h}, [x22], x13
ld1 {v25.4h}, [x17], #8
smlal v0.4s, v16.4h, v25.4h
ld1 {v15.8b}, [x22], x13
ssubl v14.8h, v15.8b, v19.8b
ld1 {v16.8h}, [x17], #16
smlal v0.4s, v14.4h, v16.4h
smlal2 v1.4s, v14.8h, v16.8h
subs x18, x18, #1
bne LoopKw
add x16, x16, x12
subs x20, x20, #1
bne LoopKh
sqshl v0.4s, v0.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqshl v0.4s, v0.4s, v24.4s
sqrdmulh v0.4s, v0.4s, v22.4s
sqshl v1.4s, v1.4s, v25.4s
sqrdmulh v1.4s, v1.4s, v23.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
and v15.16b, v26.16b, v0.16b
sshr v15.4s, v15.4s, #31
sqadd v0.4s, v0.4s, v15.4s
srshl v0.4s, v0.4s, v26.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
and v14.16b, v27.16b, v1.16b
sshr v14.4s, v14.4s, #31
sqadd v1.4s, v1.4s, v14.4s
srshl v1.4s, v1.4s, v27.4s
add v0.4s, v0.4s, v20.4s
smax v0.4s, v0.4s, v28.4s
smin v0.4s, v0.4s, v30.4s
sqxtn v0.4h, v0.4s
sqxtn v0.8b, v0.8h
add v1.4s, v1.4s, v21.4s
smax v1.4s, v1.4s, v29.4s
smin v1.4s, v1.4s, v31.4s
sqxtn v1.4h, v1.4s
sqxtn v1.8b, v1.8h
mov x17, x3
st1 {v0.b}[0], [x17], #1
st1 {v0.b}[1], [x17], #1
st1 {v0.b}[2], [x17], #1
st1 {v0.b}[3], [x17], #1
st1 {v0.s}[0], [x17], #4
st1 {v1.s}[0], [x17], #4
add x3, x3, x9
add x23, x23, x11

View File

@ -45,10 +45,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height,
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t *acc_min, int32_t *acc_max);
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,

View File

@ -138,75 +138,67 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
}
/*conv depthwise int8 end*/
/*conv depthwise sliding window int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
bool per_channel) {
int tmp_buffer[C4NUM];
for (int i = 0; i < C4NUM; i++) {
/*conv depthwise sliding window perchannel int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int kernel_w, int8_t *input_zp,
int32_t *out_zp, int *out_multiplier, int *left_shift, int *right_shift, int32_t *acc_min,
int32_t *acc_max) {
int tmp_buffer[C8NUM];
for (int i = 0; i < C8NUM; i++) {
tmp_buffer[i] = 0;
}
const int16_t *src_kh = src;
const int8_t *src_kh = src;
const int16_t *weight_kh = weight;
for (int kh = 0; kh < height; kh++) {
const int16_t *src_kw = src_kh;
const int8_t *src_kw = src_kh;
const int16_t *weight_kw = weight_kh;
for (int kw = 0; kw < width; kw++) {
for (int c = 0; c < C4NUM; c++) {
tmp_buffer[c] += src_kw[c] * weight_kw[c];
for (int c = 0; c < C8NUM; c++) {
tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c];
}
src_kw += in_kw_step;
weight_kw += C4NUM;
weight_kw += C8NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * C4NUM;
weight_kh += kernel_w * C8NUM;
} // kernel_h loop
int32_t left = left_shift[0];
int32_t right = right_shift[0];
int32_t multiplier = out_multiplier[0];
for (int c = 0; c < C4NUM; c++) {
if (per_channel) {
left = left_shift[c];
right = right_shift[c];
multiplier = out_multiplier[c];
}
for (int c = 0; c < C8NUM; c++) {
tmp_buffer[c] += bias[c];
tmp_buffer[c] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
tmp_buffer[c] += out_zp;
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
-right_shift[c]);
tmp_buffer[c] += out_zp[c];
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
dst[c] = (tmp_buffer[c]);
}
}
void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top,
void DepthwiseBorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift,
bool per_channel) {
const SlidingWindowParam *sliding, int8_t *in_zp, int32_t *out_zp, int *out_multiplier,
int *left_shift, int *right_shift, int32_t *acc_min, int32_t *acc_max) {
int8_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const int16_t *src_h = src + ih * sliding->in_h_step_;
const int8_t *src_h = src + ih * sliding->in_h_step_;
int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const int16_t *src_w = src_h + iw * sliding->block_channel_;
const int8_t *src_w = src_h + iw * sliding->block_channel_;
const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM;
DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier,
left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
per_channel);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
dst_kernel += sliding->block_channel_;
} // width loop
@ -215,52 +207,46 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#ifndef ENABLE_ARM64
void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift,
int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) {
int tmp_buffer[C4NUM];
int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min,
int32_t *acc_max) {
int tmp_buffer[C8NUM];
int8_t *dst_h = dst;
const int16_t *src_h = src;
const int8_t *src_h = src;
for (int oh = 0; oh < height; oh++) {
int8_t *dst_w = dst_h;
const int16_t *src_w = src_h;
const int8_t *src_w = src_h;
for (int ow = 0; ow < width; ow++) {
const int16_t *src_kh = src_w;
const int8_t *src_kh = src_w;
const int16_t *weight_kh = weight;
for (int i = 0; i < C4NUM; i++) {
for (int i = 0; i < C8NUM; i++) {
tmp_buffer[i] = 0;
}
for (int kh = 0; kh < kernel_h; kh++) {
const int16_t *src_kw = src_kh;
const int8_t *src_kw = src_kh;
const int16_t *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
for (int c = 0; c < C4NUM; c++) {
tmp_buffer[c] += src_kw[c] * weight_kw[c];
for (int c = 0; c < C8NUM; c++) {
tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c];
}
src_kw += in_kw_step;
weight_kw += C4NUM;
weight_kw += C8NUM;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += kernel_w * C4NUM;
weight_kh += kernel_w * C8NUM;
} // kernel_h loop
// add bias relu
int32_t left = left_shift[0];
int32_t right = right_shift[0];
int32_t multiplier = out_multiplier[0];
for (int c = 0; c < C4NUM; c++) {
if (per_channel) {
left = left_shift[c];
right = right_shift[c];
multiplier = out_multiplier[c];
}
for (int c = 0; c < C8NUM; c++) {
tmp_buffer[c] += bias[c];
tmp_buffer[c] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
tmp_buffer[c] += out_zp;
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
-right_shift[c]);
tmp_buffer[c] += out_zp[c];
tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
dst_w[c] = (tmp_buffer[c]);
}
dst_w += block_channel;
@ -272,69 +258,65 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#endif
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
const int16_t *src = input_data;
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
const SlidingWindowParam *sliding, int task_id) {
const int8_t *src = input_data;
int8_t *dst = output_data;
bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int *left_shift = conv_param->conv_quant_arg_.left_shift_;
int *right_shift = conv_param->conv_quant_arg_.right_shift_;
for (int b = 0; b < conv_param->output_batch_; b++) {
for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
const int16_t *src_data = src + oc * C4NUM;
int8_t *dst_data = dst + oc * C4NUM;
const int8_t *src_data = src + oc * C8NUM;
int8_t *dst_data = dst + oc * C8NUM;
const int16_t *weight = weight_data + oc * sliding->kernel_step_;
const int32_t *bias = bias_data + oc * C4NUM;
const int32_t *bias = bias_data + oc * C8NUM;
if (per_channel) {
out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM;
left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM;
right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM;
}
int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM;
int *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM;
int *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM;
int *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM;
int *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM;
int8_t *in_zp = input_zp + oc * C8NUM;
int32_t *out_zp = output_zp + oc * C8NUM;
DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param,
sliding, out_multiplier, left_shift, right_shift, per_channel);
sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0,
conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
per_channel);
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
right_shift, acc_min, acc_max);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_,
conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel);
conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min,
acc_max);
DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
per_channel);
conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
right_shift, acc_min, acc_max);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t),
sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t),
sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int8_t),
sliding->in_sw_step_ * sizeof(int8_t), sliding->in_kh_step_ * sizeof(int8_t),
sliding->in_kw_step_ * sizeof(int8_t), in_zp, out_zp, out_multiplier, left_shift, right_shift,
acc_min, acc_max);
#else
DepthwiseCenterInt8(
out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier,
left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
DepthwiseCenterInt8(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
sliding->in_kh_step_, sliding->in_kw_step_, in_zp, out_zp, out_multiplier, left_shift,
right_shift, acc_min, acc_max);
#endif
}
} // output C4 loop
} // output C8 loop
src += sliding->in_step_;
dst += sliding->out_step_;
} // batch loop
// output nhwc4
// output nhwc8
}
/*conv depthwise sliding window int8 end*/
/*conv depthwise sliding window perchannel int8 end*/
/*deconv depthwise int8 begin*/
void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,

View File

@ -27,8 +27,9 @@ extern "C" {
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
const SlidingWindowParam *sliding, int task_id);
void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,

View File

@ -965,6 +965,45 @@ void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
}
}
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
int ic_remainder_ = channel % C8NUM;
if (ic_remainder_ != 0) {
int nhwc8_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; i++) {
memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel,
channel);
}
nhwc8_batch_offset += nhwc8_batch_unit_offset;
}
} else {
size_t ori_input_size = batch * plane * channel;
memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
}
}
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
int ic_remainder_ = channel % C8NUM;
if (ic_remainder_ != 0) {
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
int nhwc8_batch_offset = b * nhwc8_batch_unit_offset;
for (int i = 0; i < plane; i++) {
memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM,
channel);
}
}
} else {
size_t ori_input_size = batch * plane * channel;
memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
}
}
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
int nhwc4_batch_offset = 0;
int c4 = UP_DIV(channel, C4NUM);
@ -1270,6 +1309,25 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg) {
int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
for (int c = 0; c < channel; c++) {
if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
weight_zp = quant_qrg->filter_quant_args_[c].zp_;
}
int c8_block_num = c / C8NUM;
int c8_block_rem = c % C8NUM;
const int8_t *src_c = origin_weight + c * plane;
int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM;
for (int k = 0; k < plane; k++) {
const int8_t *src_kernel = src_c + k;
int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem;
*dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
}
}
}
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg) {
int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
for (int c = 0; c < channel; c++) {
if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
weight_zp = quant_qrg->filter_quant_args_[c].zp_;

View File

@ -96,6 +96,10 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
@ -114,6 +118,9 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
#ifdef __cplusplus
}
#endif

View File

@ -177,8 +177,17 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel::LiteKernel *kernel;
auto act_quant_size =
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
if (act_quant_size == 1) { // per tensor
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else { // per channel
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;

View File

@ -37,6 +37,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
free(packed_weight_);
packed_weight_ = nullptr;
}
FreeTmpQuant();
FreeQuantParam();
}
@ -45,8 +46,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
// o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
@ -55,35 +56,36 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
bias_data_ = reinterpret_cast<int32_t *>(malloc(C8NUM * OC8 * sizeof(int32_t)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t));
memset(bias_data_, 0, C8NUM * OC8 * sizeof(int32_t));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_.at(kBiasIndex);
auto ori_bias = reinterpret_cast<int32_t *>(bias_tensor->MutableData());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t));
}
conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
conv_param_->thread_num_ = MSMIN(thread_count_, OC8);
return RET_OK;
}
int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
UP_DIV(conv_param_->input_channel_, 4);
packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
if (conv_param_->input_channel_ % C4NUM != 0) {
if (conv_param_->input_channel_ % C8NUM != 0) {
need_align_ = true;
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM *
UP_DIV(conv_param_->output_channel_, C4NUM);
int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM *
UP_DIV(conv_param_->input_channel_, C8NUM);
packed_input_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int8_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM *
UP_DIV(conv_param_->output_channel_, C8NUM);
packed_output_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_output_size * sizeof(int8_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
@ -93,6 +95,136 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
return RET_OK;
}
void ConvolutionDepthwiseSWInt8CPUKernel::FreeTmpQuant() {
if (input_scale_ != nullptr) {
free(input_scale_);
input_scale_ = nullptr;
}
if (input_zp_ != nullptr) {
free(input_zp_);
input_zp_ = nullptr;
}
if (weight_scale_ != nullptr) {
free(weight_scale_);
weight_scale_ = nullptr;
}
if (output_scale_ != nullptr) {
free(output_scale_);
output_scale_ = nullptr;
}
if (output_zp_ != nullptr) {
free(output_zp_);
output_zp_ = nullptr;
}
}
int ConvolutionDepthwiseSWInt8CPUKernel::ReinitFreeBefore() {
FreeTmpQuant();
if (conv_quant_arg_->real_multiplier_ != nullptr) {
free(conv_quant_arg_->real_multiplier_);
conv_quant_arg_->real_multiplier_ = nullptr;
}
if (conv_quant_arg_->left_shift_ != nullptr) {
free(conv_quant_arg_->left_shift_);
conv_quant_arg_->left_shift_ = nullptr;
}
if (conv_quant_arg_->right_shift_ != nullptr) {
free(conv_quant_arg_->right_shift_);
conv_quant_arg_->right_shift_ = nullptr;
}
if (conv_quant_arg_->quant_multiplier_ != nullptr) {
free(conv_quant_arg_->quant_multiplier_);
conv_quant_arg_->quant_multiplier_ = nullptr;
}
if (conv_quant_arg_->out_act_min_ != nullptr) {
free(conv_quant_arg_->out_act_min_);
conv_quant_arg_->out_act_min_ = nullptr;
}
if (conv_quant_arg_->out_act_max_ != nullptr) {
free(conv_quant_arg_->out_act_max_);
conv_quant_arg_->out_act_max_ = nullptr;
}
return RET_OK;
}
int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
ReinitFreeBefore(); // remalloc quant param buffer
auto input_tensor = in_tensors_.at(kInputIndex);
auto channel = conv_param_->input_channel_;
input_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
input_zp_ = reinterpret_cast<int8_t *>(malloc(channel * sizeof(int8_t)));
if (input_tensor->GetQuantParams().size() == kPerTensor) {
for (int i = 0; i < channel; i++) {
auto input_quant_arg = input_tensor->GetQuantParams().front();
input_zp_[i] = input_quant_arg.zeroPoint;
input_scale_[i] = input_quant_arg.scale;
}
} else {
for (int i = 0; i < channel; i++) {
auto input_quant_arg = input_tensor->GetQuantParams()[i];
input_zp_[i] = input_quant_arg.zeroPoint;
input_scale_[i] = input_quant_arg.scale;
}
}
auto output_tensor = out_tensors_.at(kOutputIndex);
output_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
output_zp_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
if (output_tensor->GetQuantParams().size() == kPerTensor) {
for (int i = 0; i < channel; i++) {
auto output_quant_arg = output_tensor->GetQuantParams().front();
output_zp_[i] = output_quant_arg.zeroPoint;
output_scale_[i] = output_quant_arg.scale;
}
} else {
for (int i = 0; i < channel; i++) {
auto output_quant_arg = output_tensor->GetQuantParams()[i];
output_zp_[i] = output_quant_arg.zeroPoint;
output_scale_[i] = output_quant_arg.scale;
}
}
conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(channel * sizeof(double)));
conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
weight_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
auto weight_tensor = in_tensors_.at(kWeightIndex);
if (weight_tensor->GetQuantParams().size() == kPerTensor) {
for (int i = 0; i < channel; i++) {
auto weight_quant_arg = weight_tensor->GetQuantParams().front();
weight_scale_[i] = weight_quant_arg.scale;
}
} else {
for (int i = 0; i < channel; i++) {
auto weight_quant_arg = weight_tensor->GetQuantParams()[i];
weight_scale_[i] = weight_quant_arg.scale;
}
}
for (int i = 0; i < channel; ++i) {
const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
}
// now only consider per tensor for output
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
for (int i = 0; i < channel; ++i) {
CalculateActivationRangeQuantized(relu, relu6, output_zp_[i], output_scale_[i],
&conv_param_->conv_quant_arg_.out_act_min_[i],
&conv_param_->conv_quant_arg_.out_act_max_[i]);
}
return RET_OK;
}
int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
sliding = new (std::nothrow) SlidingWindowParam;
if (sliding == nullptr) {
@ -107,13 +239,19 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
ConvolutionBaseCPUKernel::Init();
InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
InitSlidingParamConvDw(sliding, conv_param_, C8NUM);
auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set quant param failed.";
return ret;
}
ret = ReinitQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "reinit quant param failed.";
return ret;
}
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
@ -123,8 +261,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
}
int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) {
ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
sliding, task_id);
ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), input_zp_,
output_zp_, conv_param_, sliding, task_id);
return RET_OK;
}
@ -157,7 +295,12 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
auto input_tensor = in_tensors_.at(kInputIndex);
auto input_addr = reinterpret_cast<int8_t *>(input_tensor->MutableData());
PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_);
if (need_align_) {
PackNHWCToNHWC8Int8(input_addr, packed_input_, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
} else {
packed_input_ = input_addr;
}
auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
if (!need_align_) {
@ -171,11 +314,11 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
}
if (need_align_) {
PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
PackNHWC8ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
context_->allocator->Free(packed_input_);
context_->allocator->Free(packed_output_);
}
context_->allocator->Free(packed_input_);
return RET_OK;
}

View File

@ -40,11 +40,21 @@ class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
int Execute(int task_id);
private:
int ReinitQuantParam();
int ReinitFreeBefore();
void FreeTmpQuant();
SlidingWindowParam *sliding = nullptr;
int16_t *packed_weight_ = nullptr;
int16_t *packed_input_ = nullptr;
int8_t *packed_input_ = nullptr;
int8_t *packed_output_ = nullptr;
bool need_align_ = false;
int8_t *input_zp_ = nullptr;
float *input_scale_ = nullptr;
float *weight_scale_ = nullptr;
int32_t *output_zp_ = nullptr;
float *output_scale_ = nullptr;
};
} // namespace mindspore::kernel

View File

@ -52,8 +52,8 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
PackDeconvDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
if (bias_data_ == nullptr) {