forked from mindspore-Ecosystem/mindspore
[MSLITE][Develop] arm cpu int8 conv depthwise support activation per channel
This commit is contained in:
parent
9ca16d3c6c
commit
7175e1921e
@@ -7,13 +7,15 @@
.type ConvDwInt8Center, %function
#endif

// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width,
//                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
//                       size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift,
//                       int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
//                       size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
//                       size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
//                       int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
//                       int32_t *acc_min, int32_t *acc_max)

// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w,
// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step
// x14: out_multiplier, #56: left_shift, #64: right_shift, #72: out_zp, #80: acc_min, #88: acc_max
// x14: in_zp, #56: out_zp, #64: out_multiplier, #72: left_shift, #80: right_shift, #88: acc_min, #96: acc_max
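For context on the register plumbing below: for each output channel the kernel applies the same fixed-point requantization that the C reference path (DepthwiseBorderPixelInt8, further down in this diff) spells out with SaturatingRoundingDoublingHighMul and RoundingDivideByPOT. A minimal C sketch of that per-channel math, with the helpers written out inline; names and the saturation handling are simplified for illustration:

```c
#include <stdint.h>

// Requantize one int32 accumulator for channel c, mirroring the
// sqshl -> sqrdmulh -> rounding-right-shift -> zero-point -> clamp chain below.
// Sketch only: saturation is elided, and the final shift rounds half up,
// where the production RoundingDivideByPOT rounds half away from zero.
static int8_t RequantizeSketch(int32_t acc, int c, const int32_t *left_shift,
                               const int32_t *out_multiplier, const int32_t *right_shift,
                               const int32_t *out_zp, const int32_t *acc_min,
                               const int32_t *acc_max) {
  int64_t v = (int64_t)acc << left_shift[c];                       // sqshl (saturating in asm)
  int32_t high = (int32_t)((v * out_multiplier[c] + (1LL << 30)) >> 31);  // sqrdmulh high half
  int n = -right_shift[c];  // right_shift is stored as a non-positive value here
  int32_t q = (n > 0) ? (int32_t)(((int64_t)high + (1LL << (n - 1))) >> n) : high;
  q += out_zp[c];                       // add the per-channel output zero point
  q = q < acc_min[c] ? acc_min[c] : q;  // clamp to the quantized activation range
  q = q > acc_max[c] ? acc_max[c] : q;
  return (int8_t)q;
}
```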
ConvDwInt8Center:
    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
@@ -33,489 +35,174 @@ ConvDwInt8Center:
    ldr x12, [sp, #32]
    ldr x13, [sp, #40]

    ldr w14, [sp, #56]
    dup v26.4s, w14
    ldr x14, [sp, #48] // input_zp
    ld1 {v19.8b}, [x14], #8

    ldr x15, [sp, #48]
    dup v27.4s, w15
    ldr x15, [sp, #56] // output_zp
    ld1 {v20.4s}, [x15], #16
    ld1 {v21.4s}, [x15], #16

    ldr w16, [sp, #64]
    dup v28.4s, w16
    ldr x16, [sp, #64] // out_multiplier
    ld1 {v22.4s}, [x16], #16
    ld1 {v23.4s}, [x16], #16

    ldr w17, [sp, #72]
    dup v29.4s, w17

    ldr w18, [sp, #80]
    dup v30.4s, w18
    ldr x17, [sp, #72] // left_shift
    ld1 {v24.4s}, [x17], #16
    ld1 {v25.4s}, [x17], #16

    ldr w19, [sp, #88]
    dup v31.4s, w19
    ldr x18, [sp, #80] // right shift
    ld1 {v26.4s}, [x18], #16
    ld1 {v27.4s}, [x18], #16

    ld1 {v24.4s}, [x3]
    ldr x19, [sp, #88] // acc_min
    ld1 {v28.4s}, [x19], #16
    ld1 {v29.4s}, [x19], #16

    ldr x20, [sp, #96] // acc_max
    ld1 {v30.4s}, [x20], #16
    ld1 {v31.4s}, [x20], #16

    ld1 {v17.4s}, [x3], #16
    ld1 {v18.4s}, [x3], #16
LoopH:
    mov x23, x1
    mov x24, x5
    mov x3, x0
    cmp x24, #8
    blt LoopW
    cmp x24, #16
    blt LoopW8

LoopW16:
    mov x19, #16
LoopW4:
    mov x19, #4
    mul x19, x19, x11
    mov x25, #4
    mul x25, x25, x9

    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v1.16b, v24.16b
    mov v2.16b, v24.16b
    mov v3.16b, v24.16b
    mov v4.16b, v24.16b
    mov v5.16b, v24.16b
    mov v6.16b, v24.16b
    mov v7.16b, v24.16b
    mov v8.16b, v24.16b
    mov v9.16b, v24.16b
    mov v10.16b, v24.16b
    mov v11.16b, v24.16b
    mov v12.16b, v24.16b
    mov v13.16b, v24.16b
    mov v14.16b, v24.16b
    mov v15.16b, v24.16b
LoopKh16:

    mov v0.16b, v17.16b
    mov v1.16b, v18.16b
    mov v2.16b, v17.16b
    mov v3.16b, v18.16b
    mov v4.16b, v17.16b
    mov v5.16b, v18.16b
    mov v6.16b, v17.16b
    mov v7.16b, v18.16b
LoopKh4:
    mov x18, x7
    mov x21, x16
LoopKw16:
LoopKw4:
    mov x22, x21
    ld1 {v25.4h}, [x17], #8
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v0.4s, v16.4h, v25.4h
    smlal v1.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v2.4s, v18.4h, v25.4h
    smlal v3.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v4.4s, v20.4h, v25.4h
    smlal v5.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v6.4s, v22.4h, v25.4h
    smlal v7.4s, v23.4h, v25.4h
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v8.4s, v16.4h, v25.4h
    smlal v9.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v10.4s, v18.4h, v25.4h
    smlal v11.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v12.4s, v20.4h, v25.4h
    smlal v13.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v14.4s, v22.4h, v25.4h
    smlal v15.4s, v23.4h, v25.4h
    ld1 {v16.8h}, [x17], #16

    ld1 {v15.8b}, [x22], x11
    ssubl v14.8h, v15.8b, v19.8b
    smlal v0.4s, v14.4h, v16.4h
    smlal2 v1.4s, v14.8h, v16.8h

    ld1 {v13.8b}, [x22], x11
    ssubl v12.8h, v13.8b, v19.8b
    smlal v2.4s, v12.4h, v16.4h
    smlal2 v3.4s, v12.8h, v16.8h

    ld1 {v11.8b}, [x22], x11
    ssubl v10.8h, v11.8b, v19.8b
    smlal v4.4s, v10.4h, v16.4h
    smlal2 v5.4s, v10.8h, v16.8h

    ld1 {v9.8b}, [x22], x11
    ssubl v8.8h, v9.8b, v19.8b
    smlal v6.4s, v8.4h, v16.4h
    smlal2 v7.4s, v8.8h, v16.8h

    subs x18, x18, #1
    add x21, x21, x13
    bne LoopKw16
    bne LoopKw4
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh16
    bne LoopKh4

    sqshl v0.4s, v0.4s, v26.4s
    sqshl v1.4s, v1.4s, v26.4s
    sqshl v2.4s, v2.4s, v26.4s
    sqshl v3.4s, v3.4s, v26.4s
    sqshl v4.4s, v4.4s, v26.4s
    sqshl v5.4s, v5.4s, v26.4s
    sqshl v6.4s, v6.4s, v26.4s
    sqshl v7.4s, v7.4s, v26.4s
    sqshl v8.4s, v8.4s, v26.4s
    sqshl v9.4s, v9.4s, v26.4s
    sqshl v10.4s, v10.4s, v26.4s
    sqshl v11.4s, v11.4s, v26.4s
    sqshl v12.4s, v12.4s, v26.4s
    sqshl v13.4s, v13.4s, v26.4s
    sqshl v14.4s, v14.4s, v26.4s
    sqshl v15.4s, v15.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqrdmulh v1.4s, v1.4s, v27.4s
    sqrdmulh v2.4s, v2.4s, v27.4s
    sqrdmulh v3.4s, v3.4s, v27.4s
    sqrdmulh v4.4s, v4.4s, v27.4s
    sqrdmulh v5.4s, v5.4s, v27.4s
    sqrdmulh v6.4s, v6.4s, v27.4s
    sqrdmulh v7.4s, v7.4s, v27.4s
    sqrdmulh v8.4s, v8.4s, v27.4s
    sqrdmulh v9.4s, v9.4s, v27.4s
    sqrdmulh v10.4s, v10.4s, v27.4s
    sqrdmulh v11.4s, v11.4s, v27.4s
    sqrdmulh v12.4s, v12.4s, v27.4s
    sqrdmulh v13.4s, v13.4s, v27.4s
    sqrdmulh v14.4s, v14.4s, v27.4s
    sqrdmulh v15.4s, v15.4s, v27.4s
    sqshl v0.4s, v0.4s, v24.4s
    sqshl v1.4s, v1.4s, v25.4s
    sqshl v2.4s, v2.4s, v24.4s
    sqshl v3.4s, v3.4s, v25.4s
    sqshl v4.4s, v4.4s, v24.4s
    sqshl v5.4s, v5.4s, v25.4s
    sqshl v6.4s, v6.4s, v24.4s
    sqshl v7.4s, v7.4s, v25.4s

    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v17.16b, v28.16b, v1.16b
    sshr v17.4s, v17.4s, #31
    sqadd v1.4s, v1.4s, v17.4s
    srshl v1.4s, v1.4s, v28.4s
    and v18.16b, v28.16b, v2.16b
    sshr v18.4s, v18.4s, #31
    sqadd v2.4s, v2.4s, v18.4s
    srshl v2.4s, v2.4s, v28.4s
    and v19.16b, v28.16b, v3.16b
    sshr v19.4s, v19.4s, #31
    sqadd v3.4s, v3.4s, v19.4s
    srshl v3.4s, v3.4s, v28.4s
    and v20.16b, v28.16b, v4.16b
    sshr v20.4s, v20.4s, #31
    sqadd v4.4s, v4.4s, v20.4s
    srshl v4.4s, v4.4s, v28.4s
    and v21.16b, v28.16b, v5.16b
    sshr v21.4s, v21.4s, #31
    sqadd v5.4s, v5.4s, v21.4s
    srshl v5.4s, v5.4s, v28.4s
    and v22.16b, v28.16b, v6.16b
    sshr v22.4s, v22.4s, #31
    sqadd v6.4s, v6.4s, v22.4s
    srshl v6.4s, v6.4s, v28.4s
    and v23.16b, v28.16b, v7.16b
    sshr v23.4s, v23.4s, #31
    sqadd v7.4s, v7.4s, v23.4s
    srshl v7.4s, v7.4s, v28.4s
    and v16.16b, v28.16b, v8.16b
    sshr v16.4s, v16.4s, #31
    sqadd v8.4s, v8.4s, v16.4s
    srshl v8.4s, v8.4s, v28.4s
    and v17.16b, v28.16b, v9.16b
    sshr v17.4s, v17.4s, #31
    sqadd v9.4s, v9.4s, v17.4s
    srshl v9.4s, v9.4s, v28.4s
    and v18.16b, v28.16b, v10.16b
    sshr v18.4s, v18.4s, #31
    sqadd v10.4s, v10.4s, v18.4s
    srshl v10.4s, v10.4s, v28.4s
    and v19.16b, v28.16b, v11.16b
    sshr v19.4s, v19.4s, #31
    sqadd v11.4s, v11.4s, v19.4s
    srshl v11.4s, v11.4s, v28.4s
    and v20.16b, v28.16b, v12.16b
    sshr v20.4s, v20.4s, #31
    sqadd v12.4s, v12.4s, v20.4s
    srshl v12.4s, v12.4s, v28.4s
    and v21.16b, v28.16b, v13.16b
    sshr v21.4s, v21.4s, #31
    sqadd v13.4s, v13.4s, v21.4s
    srshl v13.4s, v13.4s, v28.4s
    and v22.16b, v28.16b, v14.16b
    sshr v22.4s, v22.4s, #31
    sqadd v14.4s, v14.4s, v22.4s
    srshl v14.4s, v14.4s, v28.4s
    and v23.16b, v28.16b, v15.16b
    sshr v23.4s, v23.4s, #31
    sqadd v15.4s, v15.4s, v23.4s
    srshl v15.4s, v15.4s, v28.4s
    sqrdmulh v0.4s, v0.4s, v22.4s
    sqrdmulh v1.4s, v1.4s, v23.4s
    sqrdmulh v2.4s, v2.4s, v22.4s
    sqrdmulh v3.4s, v3.4s, v23.4s
    sqrdmulh v4.4s, v4.4s, v22.4s
    sqrdmulh v5.4s, v5.4s, v23.4s
    sqrdmulh v6.4s, v6.4s, v22.4s
    sqrdmulh v7.4s, v7.4s, v23.4s

    add v0.4s, v0.4s, v29.4s
    add v1.4s, v1.4s, v29.4s
    add v2.4s, v2.4s, v29.4s
    add v3.4s, v3.4s, v29.4s
    add v4.4s, v4.4s, v29.4s
    add v5.4s, v5.4s, v29.4s
    add v6.4s, v6.4s, v29.4s
    add v7.4s, v7.4s, v29.4s
    add v8.4s, v8.4s, v29.4s
    add v9.4s, v9.4s, v29.4s
    add v10.4s, v10.4s, v29.4s
    add v11.4s, v11.4s, v29.4s
    add v12.4s, v12.4s, v29.4s
    add v13.4s, v13.4s, v29.4s
    add v14.4s, v14.4s, v29.4s
    add v15.4s, v15.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smax v1.4s, v1.4s, v30.4s
    smax v2.4s, v2.4s, v30.4s
    smax v3.4s, v3.4s, v30.4s
    smax v4.4s, v4.4s, v30.4s
    smax v5.4s, v5.4s, v30.4s
    smax v6.4s, v6.4s, v30.4s
    smax v7.4s, v7.4s, v30.4s
    smax v8.4s, v8.4s, v30.4s
    smax v9.4s, v9.4s, v30.4s
    smax v10.4s, v10.4s, v30.4s
    smax v11.4s, v11.4s, v30.4s
    smax v12.4s, v12.4s, v30.4s
    smax v13.4s, v13.4s, v30.4s
    smax v14.4s, v14.4s, v30.4s
    smax v15.4s, v15.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    and v15.16b, v26.16b, v0.16b
    sshr v15.4s, v15.4s, #31
    sqadd v0.4s, v0.4s, v15.4s
    srshl v0.4s, v0.4s, v26.4s

    and v14.16b, v27.16b, v1.16b
    sshr v14.4s, v14.4s, #31
    sqadd v1.4s, v1.4s, v14.4s
    srshl v1.4s, v1.4s, v27.4s

    and v13.16b, v26.16b, v2.16b
    sshr v13.4s, v13.4s, #31
    sqadd v2.4s, v2.4s, v13.4s
    srshl v2.4s, v2.4s, v26.4s

    and v12.16b, v27.16b, v3.16b
    sshr v12.4s, v12.4s, #31
    sqadd v3.4s, v3.4s, v12.4s
    srshl v3.4s, v3.4s, v27.4s

    and v11.16b, v26.16b, v4.16b
    sshr v11.4s, v11.4s, #31
    sqadd v4.4s, v4.4s, v11.4s
    srshl v4.4s, v4.4s, v26.4s

    and v10.16b, v27.16b, v5.16b
    sshr v10.4s, v10.4s, #31
    sqadd v5.4s, v5.4s, v10.4s
    srshl v5.4s, v5.4s, v27.4s

    and v9.16b, v26.16b, v6.16b
    sshr v9.4s, v9.4s, #31
    sqadd v6.4s, v6.4s, v9.4s
    srshl v6.4s, v6.4s, v26.4s

    and v8.16b, v27.16b, v7.16b
    sshr v8.4s, v8.4s, #31
    sqadd v7.4s, v7.4s, v8.4s
    srshl v7.4s, v7.4s, v27.4s

    add v0.4s, v0.4s, v20.4s
    add v1.4s, v1.4s, v21.4s
    add v2.4s, v2.4s, v20.4s
    add v3.4s, v3.4s, v21.4s
    add v4.4s, v4.4s, v20.4s
    add v5.4s, v5.4s, v21.4s
    add v6.4s, v6.4s, v20.4s
    add v7.4s, v7.4s, v21.4s
    smax v0.4s, v0.4s, v28.4s
    smax v1.4s, v1.4s, v29.4s
    smax v2.4s, v2.4s, v28.4s
    smax v3.4s, v3.4s, v29.4s
    smax v4.4s, v4.4s, v28.4s
    smax v5.4s, v5.4s, v29.4s
    smax v6.4s, v6.4s, v28.4s
    smax v7.4s, v7.4s, v29.4s
    smin v0.4s, v0.4s, v30.4s
    smin v1.4s, v1.4s, v31.4s
    smin v2.4s, v2.4s, v31.4s
    smin v2.4s, v2.4s, v30.4s
    smin v3.4s, v3.4s, v31.4s
    smin v4.4s, v4.4s, v31.4s
    smin v4.4s, v4.4s, v30.4s
    smin v5.4s, v5.4s, v31.4s
    smin v6.4s, v6.4s, v31.4s
    smin v7.4s, v7.4s, v31.4s
    smin v8.4s, v8.4s, v31.4s
    smin v9.4s, v9.4s, v31.4s
    smin v10.4s, v10.4s, v31.4s
    smin v11.4s, v11.4s, v31.4s
    smin v12.4s, v12.4s, v31.4s
    smin v13.4s, v13.4s, v31.4s
    smin v14.4s, v14.4s, v31.4s
    smin v15.4s, v15.4s, v31.4s

    sqxtn v0.4h, v0.4s
    sqxtn v1.4h, v1.4s
    sqxtn v2.4h, v2.4s
    sqxtn v3.4h, v3.4s
    sqxtn v4.4h, v4.4s
    sqxtn v5.4h, v5.4s
    sqxtn v6.4h, v6.4s
    sqxtn v7.4h, v7.4s
    sqxtn v8.4h, v8.4s
    sqxtn v9.4h, v9.4s
    sqxtn v10.4h, v10.4s
    sqxtn v11.4h, v11.4s
    sqxtn v12.4h, v12.4s
    sqxtn v13.4h, v13.4s
    sqxtn v14.4h, v14.4s
    sqxtn v15.4h, v15.4s
    sqxtn v0.8b, v0.8h
    sqxtn v1.8b, v1.8h
    sqxtn v2.8b, v2.8h
    sqxtn v3.8b, v3.8h
    sqxtn v4.8b, v4.8h
    sqxtn v5.8b, v5.8h
    sqxtn v6.8b, v6.8h
    sqxtn v7.8b, v7.8h
    sqxtn v8.8b, v8.8h
    sqxtn v9.8b, v9.8h
    sqxtn v10.8b, v10.8h
    sqxtn v11.8b, v11.8h
    sqxtn v12.8b, v12.8h
    sqxtn v13.8b, v13.8h
    sqxtn v14.8b, v14.8h
    sqxtn v15.8b, v15.8h

    add x17, x3, #1
    add x18, x3, #2
    add x21, x3, #3
    st1 {v0.b}[0], [x3], x9
    st1 {v0.b}[1], [x17], x9
    st1 {v0.b}[2], [x18], x9
    st1 {v0.b}[3], [x21], x9

    st1 {v1.b}[0], [x3], x9
    st1 {v1.b}[1], [x17], x9
    st1 {v1.b}[2], [x18], x9
    st1 {v1.b}[3], [x21], x9

    st1 {v2.b}[0], [x3], x9
    st1 {v2.b}[1], [x17], x9
    st1 {v2.b}[2], [x18], x9
    st1 {v2.b}[3], [x21], x9

    st1 {v3.b}[0], [x3], x9
    st1 {v3.b}[1], [x17], x9
    st1 {v3.b}[2], [x18], x9
    st1 {v3.b}[3], [x21], x9

    st1 {v4.b}[0], [x3], x9
    st1 {v4.b}[1], [x17], x9
    st1 {v4.b}[2], [x18], x9
    st1 {v4.b}[3], [x21], x9

    st1 {v5.b}[0], [x3], x9
    st1 {v5.b}[1], [x17], x9
    st1 {v5.b}[2], [x18], x9
    st1 {v5.b}[3], [x21], x9

    st1 {v6.b}[0], [x3], x9
    st1 {v6.b}[1], [x17], x9
    st1 {v6.b}[2], [x18], x9
    st1 {v6.b}[3], [x21], x9

    st1 {v7.b}[0], [x3], x9
    st1 {v7.b}[1], [x17], x9
    st1 {v7.b}[2], [x18], x9
    st1 {v7.b}[3], [x21], x9

    st1 {v8.b}[0], [x3], x9
    st1 {v8.b}[1], [x17], x9
    st1 {v8.b}[2], [x18], x9
    st1 {v8.b}[3], [x21], x9

    st1 {v9.b}[0], [x3], x9
    st1 {v9.b}[1], [x17], x9
    st1 {v9.b}[2], [x18], x9
    st1 {v9.b}[3], [x21], x9

    st1 {v10.b}[0], [x3], x9
    st1 {v10.b}[1], [x17], x9
    st1 {v10.b}[2], [x18], x9
    st1 {v10.b}[3], [x21], x9

    st1 {v11.b}[0], [x3], x9
    st1 {v11.b}[1], [x17], x9
    st1 {v11.b}[2], [x18], x9
    st1 {v11.b}[3], [x21], x9

    st1 {v12.b}[0], [x3], x9
    st1 {v12.b}[1], [x17], x9
    st1 {v12.b}[2], [x18], x9
    st1 {v12.b}[3], [x21], x9

    st1 {v13.b}[0], [x3], x9
    st1 {v13.b}[1], [x17], x9
    st1 {v13.b}[2], [x18], x9
    st1 {v13.b}[3], [x21], x9

    st1 {v14.b}[0], [x3], x9
    st1 {v14.b}[1], [x17], x9
    st1 {v14.b}[2], [x18], x9
    st1 {v14.b}[3], [x21], x9

    st1 {v15.b}[0], [x3], x9
    st1 {v15.b}[1], [x17], x9
    st1 {v15.b}[2], [x18], x9
    st1 {v15.b}[3], [x21], x9

    add x23, x23, x19
    sub x24, x24, #16
    cmp x24, #0
    ble LoopWEnd
    cmp x24, #8
    blt LoopW
    cmp x24, #16
    bge LoopW16
LoopW8:
    mov x19, #8
    mul x19, x19, x11
    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v1.16b, v24.16b
    mov v2.16b, v24.16b
    mov v3.16b, v24.16b
    mov v4.16b, v24.16b
    mov v5.16b, v24.16b
    mov v6.16b, v24.16b
    mov v7.16b, v24.16b
LoopKh8:
    mov x18, x7
    mov x21, x16
LoopKw8:
    mov x22, x21
    ld1 {v25.4h}, [x17], #8
    ld1 {v16.4h}, [x22], x11
    ld1 {v17.4h}, [x22], x11
    smlal v0.4s, v16.4h, v25.4h
    smlal v1.4s, v17.4h, v25.4h
    ld1 {v18.4h}, [x22], x11
    ld1 {v19.4h}, [x22], x11
    smlal v2.4s, v18.4h, v25.4h
    smlal v3.4s, v19.4h, v25.4h
    ld1 {v20.4h}, [x22], x11
    ld1 {v21.4h}, [x22], x11
    smlal v4.4s, v20.4h, v25.4h
    smlal v5.4s, v21.4h, v25.4h
    ld1 {v22.4h}, [x22], x11
    ld1 {v23.4h}, [x22], x11
    smlal v6.4s, v22.4h, v25.4h
    smlal v7.4s, v23.4h, v25.4h
    subs x18, x18, #1
    add x21, x21, x13
    bne LoopKw8
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh8

    sqshl v0.4s, v0.4s, v26.4s
    sqshl v1.4s, v1.4s, v26.4s
    sqshl v2.4s, v2.4s, v26.4s
    sqshl v3.4s, v3.4s, v26.4s
    sqshl v4.4s, v4.4s, v26.4s
    sqshl v5.4s, v5.4s, v26.4s
    sqshl v6.4s, v6.4s, v26.4s
    sqshl v7.4s, v7.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqrdmulh v1.4s, v1.4s, v27.4s
    sqrdmulh v2.4s, v2.4s, v27.4s
    sqrdmulh v3.4s, v3.4s, v27.4s
    sqrdmulh v4.4s, v4.4s, v27.4s
    sqrdmulh v5.4s, v5.4s, v27.4s
    sqrdmulh v6.4s, v6.4s, v27.4s
    sqrdmulh v7.4s, v7.4s, v27.4s

    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v17.16b, v28.16b, v1.16b
    sshr v17.4s, v17.4s, #31
    sqadd v1.4s, v1.4s, v17.4s
    srshl v1.4s, v1.4s, v28.4s
    and v18.16b, v28.16b, v2.16b
    sshr v18.4s, v18.4s, #31
    sqadd v2.4s, v2.4s, v18.4s
    srshl v2.4s, v2.4s, v28.4s
    and v19.16b, v28.16b, v3.16b
    sshr v19.4s, v19.4s, #31
    sqadd v3.4s, v3.4s, v19.4s
    srshl v3.4s, v3.4s, v28.4s
    and v20.16b, v28.16b, v4.16b
    sshr v20.4s, v20.4s, #31
    sqadd v4.4s, v4.4s, v20.4s
    srshl v4.4s, v4.4s, v28.4s
    and v21.16b, v28.16b, v5.16b
    sshr v21.4s, v21.4s, #31
    sqadd v5.4s, v5.4s, v21.4s
    srshl v5.4s, v5.4s, v28.4s
    and v22.16b, v28.16b, v6.16b
    sshr v22.4s, v22.4s, #31
    sqadd v6.4s, v6.4s, v22.4s
    srshl v6.4s, v6.4s, v28.4s
    and v23.16b, v28.16b, v7.16b
    sshr v23.4s, v23.4s, #31
    sqadd v7.4s, v7.4s, v23.4s
    srshl v7.4s, v7.4s, v28.4s

    add v0.4s, v0.4s, v29.4s
    add v1.4s, v1.4s, v29.4s
    add v2.4s, v2.4s, v29.4s
    add v3.4s, v3.4s, v29.4s
    add v4.4s, v4.4s, v29.4s
    add v5.4s, v5.4s, v29.4s
    add v6.4s, v6.4s, v29.4s
    add v7.4s, v7.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smax v1.4s, v1.4s, v30.4s
    smax v2.4s, v2.4s, v30.4s
    smax v3.4s, v3.4s, v30.4s
    smax v4.4s, v4.4s, v30.4s
    smax v5.4s, v5.4s, v30.4s
    smax v6.4s, v6.4s, v30.4s
    smax v7.4s, v7.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    smin v1.4s, v1.4s, v31.4s
    smin v2.4s, v2.4s, v31.4s
    smin v3.4s, v3.4s, v31.4s
    smin v4.4s, v4.4s, v31.4s
    smin v5.4s, v5.4s, v31.4s
    smin v6.4s, v6.4s, v31.4s
    smin v6.4s, v6.4s, v30.4s
    smin v7.4s, v7.4s, v31.4s

    sqxtn v0.4h, v0.4s
@@ -535,93 +222,81 @@ ConvDwInt8Center:
    sqxtn v6.8b, v6.8h
    sqxtn v7.8b, v7.8h

    add x17, x3, #1
    add x18, x3, #2
    add x21, x3, #3
    st1 {v0.b}[0], [x3], x9
    st1 {v0.b}[1], [x17], x9
    st1 {v0.b}[2], [x18], x9
    st1 {v0.b}[3], [x21], x9
    mov x16, x3
    add x17, x16, x9
    add x18, x17, x9
    add x21, x18, x9

    st1 {v1.b}[0], [x3], x9
    st1 {v1.b}[1], [x17], x9
    st1 {v1.b}[2], [x18], x9
    st1 {v1.b}[3], [x21], x9

    st1 {v2.b}[0], [x3], x9
    st1 {v2.b}[1], [x17], x9
    st1 {v2.b}[2], [x18], x9
    st1 {v2.b}[3], [x21], x9

    st1 {v3.b}[0], [x3], x9
    st1 {v3.b}[1], [x17], x9
    st1 {v3.b}[2], [x18], x9
    st1 {v3.b}[3], [x21], x9

    st1 {v4.b}[0], [x3], x9
    st1 {v4.b}[1], [x17], x9
    st1 {v4.b}[2], [x18], x9
    st1 {v4.b}[3], [x21], x9

    st1 {v5.b}[0], [x3], x9
    st1 {v5.b}[1], [x17], x9
    st1 {v5.b}[2], [x18], x9
    st1 {v5.b}[3], [x21], x9

    st1 {v6.b}[0], [x3], x9
    st1 {v6.b}[1], [x17], x9
    st1 {v6.b}[2], [x18], x9
    st1 {v6.b}[3], [x21], x9

    st1 {v7.b}[0], [x3], x9
    st1 {v7.b}[1], [x17], x9
    st1 {v7.b}[2], [x18], x9
    st1 {v7.b}[3], [x21], x9
    st1 {v0.s}[0], [x16], #4
    st1 {v1.s}[0], [x16], #4
    st1 {v2.s}[0], [x17], #4
    st1 {v3.s}[0], [x17], #4
    st1 {v4.s}[0], [x18], #4
    st1 {v5.s}[0], [x18], #4
    st1 {v6.s}[0], [x21], #4
    st1 {v7.s}[0], [x21], #4

    add x3, x3, x25
    add x23, x23, x19
    sub x24, x24, #8
    sub x24, x24, #4
    cmp x24, #0
    ble LoopWEnd
    cmp x24, #8
    bge LoopW8
    cmp x24, #4
    bge LoopW4

LoopW:
    mov x16, x23
    mov x17, x2
    mov x20, x6
    mov v0.16b, v24.16b
    mov v0.16b, v17.16b
    mov v1.16b, v18.16b
LoopKh:
    mov x18, x7
    mov x22, x16
LoopKw:
    ld1 {v16.4h}, [x22], x13
    ld1 {v25.4h}, [x17], #8
    smlal v0.4s, v16.4h, v25.4h
    ld1 {v15.8b}, [x22], x13
    ssubl v14.8h, v15.8b, v19.8b
    ld1 {v16.8h}, [x17], #16
    smlal v0.4s, v14.4h, v16.4h
    smlal2 v1.4s, v14.8h, v16.8h
    subs x18, x18, #1
    bne LoopKw
    add x16, x16, x12
    subs x20, x20, #1
    bne LoopKh

    sqshl v0.4s, v0.4s, v26.4s
    sqrdmulh v0.4s, v0.4s, v27.4s
    sqshl v0.4s, v0.4s, v24.4s
    sqrdmulh v0.4s, v0.4s, v22.4s
    sqshl v1.4s, v1.4s, v25.4s
    sqrdmulh v1.4s, v1.4s, v23.4s

    and v16.16b, v28.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v28.4s
    and v15.16b, v26.16b, v0.16b
    sshr v15.4s, v15.4s, #31
    sqadd v0.4s, v0.4s, v15.4s
    srshl v0.4s, v0.4s, v26.4s

    add v0.4s, v0.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    and v14.16b, v27.16b, v1.16b
    sshr v14.4s, v14.4s, #31
    sqadd v1.4s, v1.4s, v14.4s
    srshl v1.4s, v1.4s, v27.4s

    add v0.4s, v0.4s, v20.4s
    smax v0.4s, v0.4s, v28.4s
    smin v0.4s, v0.4s, v30.4s

    sqxtn v0.4h, v0.4s
    sqxtn v0.8b, v0.8h

    add v1.4s, v1.4s, v21.4s
    smax v1.4s, v1.4s, v29.4s
    smin v1.4s, v1.4s, v31.4s

    sqxtn v1.4h, v1.4s
    sqxtn v1.8b, v1.8h

    mov x17, x3
    st1 {v0.b}[0], [x17], #1
    st1 {v0.b}[1], [x17], #1
    st1 {v0.b}[2], [x17], #1
    st1 {v0.b}[3], [x17], #1
    st1 {v0.s}[0], [x17], #4
    st1 {v1.s}[0], [x17], #4
    add x3, x3, x9

    add x23, x23, x11
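A note on the recurring four-instruction groups in the kernel above (`and`, `sshr #31`, `sqadd`, `srshl`): together they perform a rounding arithmetic right shift with the shift amount held as a negative number, matching RoundingDivideByPOT in the C path. A scalar C sketch of the same fix-up (illustrative; saturation in `sqadd` is elided):

```c
#include <stdint.h>

// Scalar model of the and/sshr/sqadd/srshl idiom. neg_shift holds -n (n >= 0),
// exactly as the vector registers do. srshl with a negative operand shifts
// right with round-half-up; the masked sign fix-up turns that into
// round-half-away-from-zero. When neg_shift is 0, the AND makes the fix-up 0.
static int32_t RoundingShiftRightSketch(int32_t x, int32_t neg_shift) {
  int n = -neg_shift;
  if (n == 0) return x;
  int32_t fixup = (x & neg_shift) >> 31;          // and + sshr #31: -1 iff x < 0
  int64_t y = (int64_t)x + fixup;                 // sqadd (saturation elided)
  return (int32_t)((y + (1LL << (n - 1))) >> n);  // srshl: rounding right shift
}
```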
@@ -45,10 +45,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                        size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                        size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height,
void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
                      size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
                      int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
                      int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
                      int32_t *acc_min, int32_t *acc_max);
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
                   int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
@@ -138,75 +138,67 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
}
/*conv depthwise int8 end*/

/*conv depthwise sliding window int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
                              int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
                              int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
                              bool per_channel) {
  int tmp_buffer[C4NUM];
  for (int i = 0; i < C4NUM; i++) {
/*conv depthwise sliding window perchannel int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                              int width, int in_kh_step, int in_kw_step, int kernel_w, int8_t *input_zp,
                              int32_t *out_zp, int *out_multiplier, int *left_shift, int *right_shift, int32_t *acc_min,
                              int32_t *acc_max) {
  int tmp_buffer[C8NUM];
  for (int i = 0; i < C8NUM; i++) {
    tmp_buffer[i] = 0;
  }
  const int16_t *src_kh = src;
  const int8_t *src_kh = src;
  const int16_t *weight_kh = weight;
  for (int kh = 0; kh < height; kh++) {
    const int16_t *src_kw = src_kh;
    const int8_t *src_kw = src_kh;
    const int16_t *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        tmp_buffer[c] += src_kw[c] * weight_kw[c];
      for (int c = 0; c < C8NUM; c++) {
        tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
      weight_kw += C8NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w * C4NUM;
    weight_kh += kernel_w * C8NUM;
  }  // kernel_h loop
  int32_t left = left_shift[0];
  int32_t right = right_shift[0];
  int32_t multiplier = out_multiplier[0];
  for (int c = 0; c < C4NUM; c++) {
    if (per_channel) {
      left = left_shift[c];
      right = right_shift[c];
      multiplier = out_multiplier[c];
    }

  for (int c = 0; c < C8NUM; c++) {
    tmp_buffer[c] += bias[c];
    tmp_buffer[c] = RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
    tmp_buffer[c] += out_zp;
    tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
    tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
      SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
      -right_shift[c]);
    tmp_buffer[c] += out_zp[c];
    tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
    tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
    dst[c] = (tmp_buffer[c]);
  }
}

void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top,
void DepthwiseBorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
                         int bottom, int left, int right, const ConvParameter *conv_param,
                         const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift,
                         bool per_channel) {
                         const SlidingWindowParam *sliding, int8_t *in_zp, int32_t *out_zp, int *out_multiplier,
                         int *left_shift, int *right_shift, int32_t *acc_min, int32_t *acc_max) {
  int8_t *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const int16_t *src_h = src + ih * sliding->in_h_step_;
    const int8_t *src_h = src + ih * sliding->in_h_step_;

    int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const int16_t *src_w = src_h + iw * sliding->block_channel_;
      const int8_t *src_w = src_h + iw * sliding->block_channel_;

      const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
      const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM;

      DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                               sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier,
                               left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                               conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                               per_channel);
                               sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp,
                               out_multiplier, left_shift, right_shift, acc_min, acc_max);

      dst_kernel += sliding->block_channel_;
    }  // width loop
@@ -215,52 +207,46 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
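As a worked example of the kernel-window clipping in DepthwiseBorderInt8 above (the parameter values below are illustrative, not taken from this diff):

```c
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMAX(a, b) ((a) > (b) ? (a) : (b))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

static void BorderClipExample(void) {
  int stride_h = 2, pad_u = 1, dilation_h = 1, kernel_h = 3, input_h = 8;
  int oh = 0;
  int ih = oh * stride_h - pad_u;  // -1: the window starts one row above the image
  int start_kh = MSMAX(0, UP_DIV(-ih, dilation_h));                // = 1: skip the padded row
  int end_kh = MSMIN(kernel_h, UP_DIV(input_h - ih, dilation_h));  // = 3: bottom edge not clipped
  (void)start_kh;
  (void)end_kh;  // only kernel rows [1, 3) feed this output pixel; src_kernel and
                 // weight_kernel are pre-advanced by start_kh, so padding is never read
}
```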

#ifndef ENABLE_ARM64
void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                         int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step,
                         int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift,
                         int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) {
  int tmp_buffer[C4NUM];
                         int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp,
                         int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min,
                         int32_t *acc_max) {
  int tmp_buffer[C8NUM];
  int8_t *dst_h = dst;
  const int16_t *src_h = src;
  const int8_t *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    int8_t *dst_w = dst_h;
    const int16_t *src_w = src_h;
    const int8_t *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const int16_t *src_kh = src_w;
      const int8_t *src_kh = src_w;
      const int16_t *weight_kh = weight;

      for (int i = 0; i < C4NUM; i++) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_buffer[i] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const int16_t *src_kw = src_kh;
        const int8_t *src_kw = src_kh;
        const int16_t *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            tmp_buffer[c] += src_kw[c] * weight_kw[c];
          for (int c = 0; c < C8NUM; c++) {
            tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
          weight_kw += C8NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
        weight_kh += kernel_w * C8NUM;
      }  // kernel_h loop
      // add bias relu
      int32_t left = left_shift[0];
      int32_t right = right_shift[0];
      int32_t multiplier = out_multiplier[0];
      for (int c = 0; c < C4NUM; c++) {
        if (per_channel) {
          left = left_shift[c];
          right = right_shift[c];
          multiplier = out_multiplier[c];
        }
      for (int c = 0; c < C8NUM; c++) {
        tmp_buffer[c] += bias[c];
        tmp_buffer[c] = RoundingDivideByPOT(
          SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right);
        tmp_buffer[c] += out_zp;
        tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min);
        tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max);
          SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
          -right_shift[c]);
        tmp_buffer[c] += out_zp[c];
        tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]);
        tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]);
        dst_w[c] = (tmp_buffer[c]);
      }
      dst_w += block_channel;
@@ -272,69 +258,65 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#endif

void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const int16_t *src = input_data;
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
                  const SlidingWindowParam *sliding, int task_id) {
  const int8_t *src = input_data;
  int8_t *dst = output_data;
  bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
  int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
  int *left_shift = conv_param->conv_quant_arg_.left_shift_;
  int *right_shift = conv_param->conv_quant_arg_.right_shift_;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const int16_t *src_data = src + oc * C4NUM;
      int8_t *dst_data = dst + oc * C4NUM;
      const int8_t *src_data = src + oc * C8NUM;
      int8_t *dst_data = dst + oc * C8NUM;
      const int16_t *weight = weight_data + oc * sliding->kernel_step_;
      const int32_t *bias = bias_data + oc * C4NUM;
      const int32_t *bias = bias_data + oc * C8NUM;

      if (per_channel) {
        out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM;
        left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM;
        right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM;
      }
      int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM;
      int *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM;
      int *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM;
      int *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM;
      int *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM;
      int8_t *in_zp = input_zp + oc * C8NUM;
      int32_t *out_zp = output_zp + oc * C8NUM;

      DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param,
                          sliding, out_multiplier, left_shift, right_shift, per_channel);
                          sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0,
                          conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
                          per_channel);
                          conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
                          right_shift, acc_min, acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_,
                          conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel);
                          conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min,
                          acc_max);
      DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                          conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift,
                          per_channel);
                          conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift,
                          right_shift, acc_min, acc_max);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
        ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t),
                         sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t),
                         sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
                         sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
                         conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
                         conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                         conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
                         sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int8_t),
                         sliding->in_sw_step_ * sizeof(int8_t), sliding->in_kh_step_ * sizeof(int8_t),
                         sliding->in_kw_step_ * sizeof(int8_t), in_zp, out_zp, out_multiplier, left_shift, right_shift,
                         acc_min, acc_max);
#else

        DepthwiseCenterInt8(
          out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
          conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
          sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier,
          left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
          conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel);
        DepthwiseCenterInt8(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
                            sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
                            sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
                            sliding->in_kh_step_, sliding->in_kw_step_, in_zp, out_zp, out_multiplier, left_shift,
                            right_shift, acc_min, acc_max);
#endif
      }
    }  // output C4 loop
    }  // output C8 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
  // output nhwc8
}
/*conv depthwise sliding window int8 end*/
/*conv depthwise sliding window perchannel int8 end*/

/*deconv depthwise int8 begin*/
void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,
@@ -27,8 +27,9 @@ extern "C" {
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
                const int32_t *bias_data, const ConvParameter *conv_param, int task_id);

void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
                  int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,
                  const SlidingWindowParam *sliding, int task_id);

void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
                  const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
@@ -869,6 +869,45 @@ void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
  }
}

void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    int nhwc8_batch_offset = 0;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel,
               channel);
      }
      nhwc8_batch_offset += nhwc8_batch_unit_offset;
    }
  } else {
    size_t ori_input_size = batch * plane * channel;
    memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  }
}
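A small usage sketch for the new pack routine (tensor shape is illustrative). Note the function only copies the valid channels per pixel; the pad lanes keep whatever was already in the destination, so zeroing `dst` first keeps them deterministic:

```c
int8_t src[1 * 2 * 2 * 3];        // NHWC input, channel = 3
int8_t dst[1 * 2 * 2 * 8] = {0};  // NHWC8 buffer, channel padded up to 8
PackNHWCToNHWC8Int8(src, dst, /*batch=*/1, /*plane=*/2 * 2, /*channel=*/3);
// pixel i's channels now sit at dst[i * 8 + 0..2]; lanes 3..7 stay zero
```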

void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      int nhwc8_batch_offset = b * nhwc8_batch_unit_offset;
      for (int i = 0; i < plane; i++) {
        memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM,
               channel);
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel;
    memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  }
}

void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
  int nhwc4_batch_offset = 0;
  int c4 = UP_DIV(channel, C4NUM);
@@ -1174,6 +1213,25 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                             ConvQuantArg *quant_qrg) {
  int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  for (int c = 0; c < channel; c++) {
    if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
      weight_zp = quant_qrg->filter_quant_args_[c].zp_;
    }
    int c8_block_num = c / C8NUM;
    int c8_block_rem = c % C8NUM;
    const int8_t *src_c = origin_weight + c * plane;
    int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM;
    for (int k = 0; k < plane; k++) {
      const int8_t *src_kernel = src_c + k;
      int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem;
      *dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
    }
  }
}
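The resulting layout groups eight consecutive output channels per kernel tap, with the filter zero point (per channel, or the shared one in per-tensor mode) already folded into the int16 value. A sketch of the index math implied by the loop above (illustrative helper, not part of the diff):

```c
#include <stddef.h>

// Destination index of weight (channel c, kernel tap k) after
// PackDepthwiseInt8Weight: blocks of 8 channels, plane-major inside a block.
static size_t PackedDwWeightIndex(int c, int k, int plane) {
  return (size_t)(c / 8) * plane * 8 + (size_t)k * 8 + (size_t)(c % 8);
}
// so: packed[PackedDwWeightIndex(c, k, plane)] == (int16_t)(origin[c * plane + k] - zp_of_channel_c)
```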

void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                                   ConvQuantArg *quant_qrg) {
  int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  for (int c = 0; c < channel; c++) {
    if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
      weight_zp = quant_qrg->filter_quant_args_[c].zp_;
@@ -96,6 +96,10 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c

void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);

void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);

void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);

void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);

void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
@@ -114,6 +118,9 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter

void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                             ConvQuantArg *quant_qrg);

void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
                                   ConvQuantArg *quant_qrg);
#ifdef __cplusplus
}
#endif
@@ -177,8 +177,17 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
                                               const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
  auto kernel =
    new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);

  kernel::LiteKernel *kernel;
  auto act_quant_size =
    MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
  if (act_quant_size == 1) {  // per tensor
    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  } else {  // per channel
    kernel =
      new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }

  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
@@ -37,6 +37,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() {
    free(packed_weight_);
    packed_weight_ = nullptr;
  }
  FreeTmpQuant();
  FreeQuantParam();
}

@@ -45,8 +46,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
  // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1
  auto weight_tensor = in_tensors_[kWeightIndex];
  auto origin_weight = reinterpret_cast<int8_t *>(weight_tensor->MutableData());
  int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM);
  int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width();
  int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM);
  int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width();
  packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
  if (packed_weight_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -55,35 +56,36 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() {
  PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                          weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));

  bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
  bias_data_ = reinterpret_cast<int32_t *>(malloc(C8NUM * OC8 * sizeof(int32_t)));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t));
  memset(bias_data_, 0, C8NUM * OC8 * sizeof(int32_t));
  if (in_tensors_.size() == kInputSize2) {
    auto bias_tensor = in_tensors_.at(kBiasIndex);
    auto ori_bias = reinterpret_cast<int32_t *>(bias_tensor->MutableData());
    memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t));
  }

  conv_param_->thread_num_ = MSMIN(thread_count_, OC4);
  conv_param_->thread_num_ = MSMIN(thread_count_, OC8);
  return RET_OK;
}

int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
  int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM *
                        UP_DIV(conv_param_->input_channel_, 4);
  packed_input_ = reinterpret_cast<int16_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int16_t)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }

  if (conv_param_->input_channel_ % C4NUM != 0) {
  if (conv_param_->input_channel_ % C8NUM != 0) {
    need_align_ = true;
    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM *
                           UP_DIV(conv_param_->output_channel_, C4NUM);

    int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM *
                          UP_DIV(conv_param_->input_channel_, C8NUM);
    packed_input_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_input_size * sizeof(int8_t)));
    if (packed_input_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return RET_ERROR;
    }

    int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM *
                           UP_DIV(conv_param_->output_channel_, C8NUM);
    packed_output_ = reinterpret_cast<int8_t *>(context_->allocator->Malloc(pack_output_size * sizeof(int8_t)));
    if (packed_output_ == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
@@ -93,6 +95,136 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() {
  return RET_OK;
}

void ConvolutionDepthwiseSWInt8CPUKernel::FreeTmpQuant() {
  if (input_scale_ != nullptr) {
    free(input_scale_);
    input_scale_ = nullptr;
  }
  if (input_zp_ != nullptr) {
    free(input_zp_);
    input_zp_ = nullptr;
  }
  if (weight_scale_ != nullptr) {
    free(weight_scale_);
    weight_scale_ = nullptr;
  }
  if (output_scale_ != nullptr) {
    free(output_scale_);
    output_scale_ = nullptr;
  }
  if (output_zp_ != nullptr) {
    free(output_zp_);
    output_zp_ = nullptr;
  }
}

int ConvolutionDepthwiseSWInt8CPUKernel::ReinitFreeBefore() {
  FreeTmpQuant();
  if (conv_quant_arg_->real_multiplier_ != nullptr) {
    free(conv_quant_arg_->real_multiplier_);
    conv_quant_arg_->real_multiplier_ = nullptr;
  }
  if (conv_quant_arg_->left_shift_ != nullptr) {
    free(conv_quant_arg_->left_shift_);
    conv_quant_arg_->left_shift_ = nullptr;
  }
  if (conv_quant_arg_->right_shift_ != nullptr) {
    free(conv_quant_arg_->right_shift_);
    conv_quant_arg_->right_shift_ = nullptr;
  }
  if (conv_quant_arg_->quant_multiplier_ != nullptr) {
    free(conv_quant_arg_->quant_multiplier_);
    conv_quant_arg_->quant_multiplier_ = nullptr;
  }
  if (conv_quant_arg_->out_act_min_ != nullptr) {
    free(conv_quant_arg_->out_act_min_);
    conv_quant_arg_->out_act_min_ = nullptr;
  }
  if (conv_quant_arg_->out_act_max_ != nullptr) {
    free(conv_quant_arg_->out_act_max_);
    conv_quant_arg_->out_act_max_ = nullptr;
  }
  return RET_OK;
}

int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
  ReinitFreeBefore();  // remalloc quant param buffer

  auto input_tensor = in_tensors_.at(kInputIndex);
  auto channel = conv_param_->input_channel_;
  input_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  input_zp_ = reinterpret_cast<int8_t *>(malloc(channel * sizeof(int8_t)));
  if (input_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto input_quant_arg = input_tensor->GetQuantParams().front();
      input_zp_[i] = input_quant_arg.zeroPoint;
      input_scale_[i] = input_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto input_quant_arg = input_tensor->GetQuantParams()[i];
      input_zp_[i] = input_quant_arg.zeroPoint;
      input_scale_[i] = input_quant_arg.scale;
    }
  }

  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  output_zp_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  if (output_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto output_quant_arg = output_tensor->GetQuantParams().front();
      output_zp_[i] = output_quant_arg.zeroPoint;
      output_scale_[i] = output_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto output_quant_arg = output_tensor->GetQuantParams()[i];
      output_zp_[i] = output_quant_arg.zeroPoint;
      output_scale_[i] = output_quant_arg.scale;
    }
  }

  conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(channel * sizeof(double)));
  conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->out_act_min_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));
  conv_quant_arg_->out_act_max_ = reinterpret_cast<int32_t *>(malloc(channel * sizeof(int32_t)));

  weight_scale_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
  auto weight_tensor = in_tensors_.at(kWeightIndex);
  if (weight_tensor->GetQuantParams().size() == kPerTensor) {
    for (int i = 0; i < channel; i++) {
      auto weight_quant_arg = weight_tensor->GetQuantParams().front();
      weight_scale_[i] = weight_quant_arg.scale;
    }
  } else {
    for (int i = 0; i < channel; i++) {
      auto weight_quant_arg = weight_tensor->GetQuantParams()[i];
      weight_scale_[i] = weight_quant_arg.scale;
    }
  }

  for (int i = 0; i < channel; ++i) {
    const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
    double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
    conv_quant_arg_->real_multiplier_[i] = real_multiplier;
    QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
                           &conv_quant_arg_->right_shift_[i]);
  }

  // now only consider per tensor for output
  bool relu = conv_param_->act_type_ == ActType_Relu;
  bool relu6 = conv_param_->act_type_ == ActType_Relu6;
  for (int i = 0; i < channel; ++i) {
    CalculateActivationRangeQuantized(relu, relu6, output_zp_[i], output_scale_[i],
                                      &conv_param_->conv_quant_arg_.out_act_min_[i],
                                      &conv_param_->conv_quant_arg_.out_act_max_[i]);
  }
  return RET_OK;
}
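For intuition on what QuantizeRoundParameter stores per channel: the real multiplier input_scale * weight_scale / output_scale is decomposed into a Q31 integer multiplier plus a power-of-two shift, in the usual gemmlowp/TFLite style. A hedged sketch of that decomposition; the actual helper may split the shift into separate left/right parts and handle edge cases differently:

```c
#include <math.h>
#include <stdint.h>

// Decompose real_multiplier into quantized_multiplier * 2^shift, with the
// Q31 multiplier in [2^30, 2^31). Illustrative only.
static void QuantizeMultiplierSketch(double real_multiplier, int32_t *quantized_multiplier,
                                     int *shift) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  int exponent = 0;
  const double q = frexp(real_multiplier, &exponent);  // q in [0.5, 1)
  int64_t q_fixed = (int64_t)llround(q * (double)(1LL << 31));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to 1.0
    q_fixed /= 2;
    ++exponent;
  }
  *quantized_multiplier = (int32_t)q_fixed;
  *shift = exponent;  // positive -> extra left shift, negative -> right shift
}
```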

int ConvolutionDepthwiseSWInt8CPUKernel::Init() {
  sliding = new (std::nothrow) SlidingWindowParam;
  if (sliding == nullptr) {
@@ -107,13 +239,19 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Init() {

int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
  ConvolutionBaseCPUKernel::Init();
  InitSlidingParamConvDw(sliding, conv_param_, C4NUM);
  InitSlidingParamConvDw(sliding, conv_param_, C8NUM);

  auto ret = ConvolutionBaseCPUKernel::SetQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set quant param failed.";
    return ret;
  }
  ret = ReinitQuantParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "reinit quant param failed.";
    return ret;
  }

  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!";
@@ -123,8 +261,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() {
}

int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) {
  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
               sliding, task_id);
  ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), input_zp_,
               output_zp_, conv_param_, sliding, task_id);
  return RET_OK;
}
@@ -157,7 +295,12 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {

  auto input_tensor = in_tensors_.at(kInputIndex);
  auto input_addr = reinterpret_cast<int8_t *>(input_tensor->MutableData());
  PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_);
  if (need_align_) {
    PackNHWCToNHWC8Int8(input_addr, packed_input_, conv_param_->input_batch_,
                        conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
  } else {
    packed_input_ = input_addr;
  }

  auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
  if (!need_align_) {
@@ -171,11 +314,11 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() {
  }

  if (need_align_) {
    PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
    PackNHWC8ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_,
                        conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
    context_->allocator->Free(packed_input_);
    context_->allocator->Free(packed_output_);
  }
  context_->allocator->Free(packed_input_);
  return RET_OK;
}
@@ -40,11 +40,21 @@ class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
  int Execute(int task_id);

 private:
  int ReinitQuantParam();
  int ReinitFreeBefore();
  void FreeTmpQuant();

  SlidingWindowParam *sliding = nullptr;
  int16_t *packed_weight_ = nullptr;
  int16_t *packed_input_ = nullptr;
  int8_t *packed_input_ = nullptr;
  int8_t *packed_output_ = nullptr;
  bool need_align_ = false;

  int8_t *input_zp_ = nullptr;
  float *input_scale_ = nullptr;
  float *weight_scale_ = nullptr;
  int32_t *output_zp_ = nullptr;
  float *output_scale_ = nullptr;
};
}  // namespace mindspore::kernel
@@ -52,8 +52,8 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return RET_ERROR;
  }
  PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                          weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));
  PackDeconvDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
                                weight_tensor->Batch(), &(conv_param_->conv_quant_arg_));

  bias_data_ = reinterpret_cast<int32_t *>(malloc(C4NUM * OC4 * sizeof(int32_t)));
  if (bias_data_ == nullptr) {