!6102 [MSLITE][Develop]conv1x1 per oc arm64

Merge pull request !6102 from ling/sr
This commit is contained in:
mindspore-ci-bot 2020-09-14 16:29:32 +08:00 committed by Gitee
commit 19874b83e7
13 changed files with 566 additions and 175 deletions

View File

@ -6,9 +6,9 @@
.type MatmulInt8Neon64, %function
#endif
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
// int32_t *right_shift, int row, int col, int stride, int filter_peroc)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
@ -21,31 +21,34 @@
// w8: act_min
// w9: act_max
// w10: out_zp
// w11: multiplier
// w12: left_shift
// w13: right_shift
// x11: multiplier
// x12: left_shift
// x13: right_shift
// w14: row
// w15: col
// w24: stride
// w27: filter_peroc
MatmulInt8Neon64:
sub sp, sp, #192
sub sp, sp, #208
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
stp x27, x28, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr x11, [sp, #24]
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
ldr w27, [sp, #72]
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
@ -58,7 +61,7 @@ L1:
mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, #0
beq End2
@ -167,39 +170,60 @@ End3:
addp v19.4s, v28.4s, v30.4s
// Add (Bias+Depth*Za*Zb-Za*Bsums)
ld1 {v15.4s}, [x19], #16
ld1 {v15.4s}, [x19], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s
// Subtract (Asums*Zb)
cmp w27, #0
beq PerTLoad
PerCLoad:
ld1 {v20.4s}, [x6], #16
ld1 {v21.4s}, [x6], #16
ld1 {v22.4s}, [x6], #16
ld1 {v23.4s}, [x6], #16
ld1 {v13.4s}, [x12]
ld1 {v12.4s}, [x11]
ld1 {v11.4s}, [x13]
b Apply
PerTLoad:
ld1 {v14.4s}, [x22], #16
dup v20.4s, v14.s[0]
dup v21.4s, v14.s[1]
dup v22.4s, v14.s[2]
dup v23.4s, v14.s[3]
ld1 {v14.s}[0], [x12]
dup v13.4s, v14.s[0]
ld1 {v14.s}[0], [x11]
dup v12.4s, v14.s[0]
ld1 {v14.s}[0], [x13]
dup v11.4s, v14.s[0]
b Apply
Apply:
// Subtract (Asums*Zb)
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s
// Apply left shift
dup v13.4s, w12
sqshl v16.4s, v16.4s, v13.4s
sqshl v17.4s, v17.4s, v13.4s
sqshl v18.4s, v18.4s, v13.4s
sqshl v19.4s, v19.4s, v13.4s
// Apply the fixed-point part of the multiplier.
dup v12.4s, w11
sqrdmulh v16.4s, v16.4s, v12.4s
sqrdmulh v17.4s, v17.4s, v12.4s
sqrdmulh v18.4s, v18.4s, v12.4s
sqrdmulh v19.4s, v19.4s, v12.4s
// Apply right shift
dup v11.4s, w13
and v20.16b, v11.16b, v16.16b
sshr v20.4s, v20.4s, #31
sqadd v16.4s, v16.4s, v20.4s
@ -268,7 +292,7 @@ Write:
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
@ -349,7 +373,7 @@ WriteCol1:
st1 {v15.b}[12], [x2], x24
b Endwrite
Endwrite:
Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2
@ -361,15 +385,23 @@ End2:
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25
cmp w27, #0
beq PerTEnd2
add x12, x12, #16
add x11, x11, #16
add x13, x13, #16
PerTEnd2:
b L1
End1:
sub sp, sp, #192
sub sp, sp, #208
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ldp x27, x28, [sp], #16
ret
#endif

View File

@ -8,7 +8,7 @@
//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
// int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, int peroc);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
@ -21,31 +21,34 @@
// w8: act_min
// w9: act_max
// w10: out_zp
// w11: multiplier
// w12: left_shift
// w13: right_shift
// x11: multiplier
// x12: left_shift
// x13: right_shift
// w14: row
// w15: col
// w24: stride
// w27: filter_peroc
MatmulInt8DpNeon64:
sub sp, sp, #192
sub sp, sp, #208
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
stp x27, x28, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr x11, [sp, #24]
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
ldr w27, [sp, #72]
mov w17, #8 // sizeof(int8)*8
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4
@ -226,138 +229,171 @@ End3:
add v29.4s, v29.4s, v14.4s
add v31.4s, v31.4s, v14.4s
cmp w27, #0
beq PerTSumLoad
PerCSumLoad:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64
b ApplySum
PerTSumLoad:
ld1 {v14.4s}, [x22], #16
ld1 {v15.4s}, [x22], #16
dup v0.4s, v14.s[0]
dup v1.4s, v14.s[0]
dup v2.4s, v14.s[1]
dup v3.4s, v14.s[1]
dup v4.4s, v14.s[2]
dup v5.4s, v14.s[2]
dup v6.4s, v14.s[3]
dup v7.4s, v14.s[3]
dup v8.4s, v15.s[0]
dup v9.4s, v15.s[0]
dup v10.4s, v15.s[1]
dup v11.4s, v15.s[1]
dup v12.4s, v15.s[2]
dup v13.4s, v15.s[2]
dup v14.4s, v15.s[3]
dup v15.4s, v14.s[0]
ApplySum:
// Subtract (Asums*Zb)
ld1 {v13.4s}, [x22], #16
ld1 {v12.4s}, [x22], #16
dup v0.4s, v13.s[0]
dup v1.4s, v13.s[1]
dup v2.4s, v13.s[2]
dup v3.4s, v13.s[3]
dup v4.4s, v12.s[0]
dup v5.4s, v12.s[1]
dup v6.4s, v12.s[2]
dup v7.4s, v12.s[3]
sub v16.4s, v16.4s, v0.4s
sub v17.4s, v17.4s, v0.4s
sub v18.4s, v18.4s, v1.4s
sub v19.4s, v19.4s, v1.4s
sub v20.4s, v20.4s, v2.4s
sub v21.4s, v21.4s, v2.4s
sub v22.4s, v22.4s, v3.4s
sub v23.4s, v23.4s, v3.4s
sub v24.4s, v24.4s, v4.4s
sub v25.4s, v25.4s, v4.4s
sub v26.4s, v26.4s, v5.4s
sub v27.4s, v27.4s, v5.4s
sub v28.4s, v28.4s, v6.4s
sub v29.4s, v29.4s, v6.4s
sub v30.4s, v30.4s, v7.4s
sub v31.4s, v31.4s, v7.4s
sub v17.4s, v17.4s, v1.4s
sub v18.4s, v18.4s, v2.4s
sub v19.4s, v19.4s, v3.4s
sub v20.4s, v20.4s, v4.4s
sub v21.4s, v21.4s, v5.4s
sub v22.4s, v22.4s, v6.4s
sub v23.4s, v23.4s, v7.4s
sub v24.4s, v24.4s, v8.4s
sub v25.4s, v25.4s, v9.4s
sub v26.4s, v26.4s, v10.4s
sub v27.4s, v27.4s, v11.4s
sub v28.4s, v28.4s, v12.4s
sub v29.4s, v29.4s, v13.4s
sub v30.4s, v30.4s, v14.4s
sub v31.4s, v31.4s, v15.4s
cmp w27, #0
beq PerTRoundLoad
PerCRoundLoad:
ld1 {v8.4s, v9.4s}, [x12]
ld1 {v10.4s, v11.4s}, [x11]
ld1 {v12.4s, v13.4s}, [x13]
b ApplyRound
PerTRoundLoad:
ld1 {v14.s}[0], [x12]
dup v8.4s, v14.s[0]
dup v9.4s, v14.s[0]
ld1 {v14.s}[0], [x11]
dup v10.4s, v14.s[0]
dup v11.4s, v14.s[0]
ld1 {v14.s}[0], [x13]
dup v12.4s, v14.s[0]
dup v13.4s, v14.s[0]
ApplyRound:
// Apply left shift
dup v11.4s, w12
sqshl v16.4s, v16.4s, v11.4s
sqshl v17.4s, v17.4s, v11.4s
sqshl v18.4s, v18.4s, v11.4s
sqshl v19.4s, v19.4s, v11.4s
sqshl v20.4s, v20.4s, v11.4s
sqshl v21.4s, v21.4s, v11.4s
sqshl v22.4s, v22.4s, v11.4s
sqshl v23.4s, v23.4s, v11.4s
sqshl v24.4s, v24.4s, v11.4s
sqshl v25.4s, v25.4s, v11.4s
sqshl v26.4s, v26.4s, v11.4s
sqshl v27.4s, v27.4s, v11.4s
sqshl v28.4s, v28.4s, v11.4s
sqshl v29.4s, v29.4s, v11.4s
sqshl v30.4s, v30.4s, v11.4s
sqshl v31.4s, v31.4s, v11.4s
sqshl v16.4s, v16.4s, v8.4s
sqshl v17.4s, v17.4s, v9.4s
sqshl v18.4s, v18.4s, v8.4s
sqshl v19.4s, v19.4s, v9.4s
sqshl v20.4s, v20.4s, v8.4s
sqshl v21.4s, v21.4s, v9.4s
sqshl v22.4s, v22.4s, v8.4s
sqshl v23.4s, v23.4s, v9.4s
sqshl v24.4s, v24.4s, v8.4s
sqshl v25.4s, v25.4s, v9.4s
sqshl v26.4s, v26.4s, v8.4s
sqshl v27.4s, v27.4s, v9.4s
sqshl v28.4s, v28.4s, v8.4s
sqshl v29.4s, v29.4s, v9.4s
sqshl v30.4s, v30.4s, v8.4s
sqshl v31.4s, v31.4s, v9.4s
// Apply the fixed-point part of the multiplier.
dup v10.4s, w11
sqrdmulh v16.4s, v16.4s, v10.4s
sqrdmulh v17.4s, v17.4s, v10.4s
sqrdmulh v17.4s, v17.4s, v11.4s
sqrdmulh v18.4s, v18.4s, v10.4s
sqrdmulh v19.4s, v19.4s, v10.4s
sqrdmulh v19.4s, v19.4s, v11.4s
sqrdmulh v20.4s, v20.4s, v10.4s
sqrdmulh v21.4s, v21.4s, v10.4s
sqrdmulh v21.4s, v21.4s, v11.4s
sqrdmulh v22.4s, v22.4s, v10.4s
sqrdmulh v23.4s, v23.4s, v10.4s
sqrdmulh v23.4s, v23.4s, v11.4s
sqrdmulh v24.4s, v24.4s, v10.4s
sqrdmulh v25.4s, v25.4s, v10.4s
sqrdmulh v25.4s, v25.4s, v11.4s
sqrdmulh v26.4s, v26.4s, v10.4s
sqrdmulh v27.4s, v27.4s, v10.4s
sqrdmulh v27.4s, v27.4s, v11.4s
sqrdmulh v28.4s, v28.4s, v10.4s
sqrdmulh v29.4s, v29.4s, v10.4s
sqrdmulh v29.4s, v29.4s, v11.4s
sqrdmulh v30.4s, v30.4s, v10.4s
sqrdmulh v31.4s, v31.4s, v10.4s
sqrdmulh v31.4s, v31.4s, v11.4s
// Apply right shift
dup v9.4s, w13
and v0.16b, v9.16b, v16.16b
and v0.16b, v12.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v9.4s
and v1.16b, v9.16b, v17.16b
srshl v16.4s, v16.4s, v12.4s
and v1.16b, v13.16b, v17.16b
sshr v1.4s, v1.4s, #31
sqadd v17.4s, v17.4s, v1.4s
srshl v17.4s, v17.4s, v9.4s
and v2.16b, v9.16b, v18.16b
srshl v17.4s, v17.4s, v13.4s
and v2.16b, v12.16b, v18.16b
sshr v2.4s, v2.4s, #31
sqadd v18.4s, v18.4s, v2.4s
srshl v18.4s, v18.4s, v9.4s
and v3.16b, v9.16b, v19.16b
srshl v18.4s, v18.4s, v12.4s
and v3.16b, v13.16b, v19.16b
sshr v3.4s, v3.4s, #31
sqadd v19.4s, v19.4s, v3.4s
srshl v19.4s, v19.4s, v9.4s
and v0.16b, v9.16b, v20.16b
srshl v19.4s, v19.4s, v13.4s
and v0.16b, v12.16b, v20.16b
sshr v0.4s, v0.4s, #31
sqadd v20.4s, v20.4s, v0.4s
srshl v20.4s, v20.4s, v9.4s
and v1.16b, v9.16b, v21.16b
srshl v20.4s, v20.4s, v12.4s
and v1.16b, v13.16b, v21.16b
sshr v1.4s, v1.4s, #31
sqadd v21.4s, v21.4s, v1.4s
srshl v21.4s, v21.4s, v9.4s
and v2.16b, v9.16b, v22.16b
srshl v21.4s, v21.4s, v13.4s
and v2.16b, v12.16b, v22.16b
sshr v2.4s, v2.4s, #31
sqadd v22.4s, v22.4s, v2.4s
srshl v22.4s, v22.4s, v9.4s
and v3.16b, v9.16b, v23.16b
srshl v22.4s, v22.4s, v12.4s
and v3.16b, v13.16b, v23.16b
sshr v3.4s, v3.4s, #31
sqadd v23.4s, v23.4s, v3.4s
srshl v23.4s, v23.4s, v9.4s
and v0.16b, v9.16b, v24.16b
srshl v23.4s, v23.4s, v13.4s
and v0.16b, v12.16b, v24.16b
sshr v0.4s, v0.4s, #31
sqadd v24.4s, v24.4s, v0.4s
srshl v24.4s, v24.4s, v9.4s
and v1.16b, v9.16b, v25.16b
srshl v24.4s, v24.4s, v12.4s
and v1.16b, v13.16b, v25.16b
sshr v1.4s, v1.4s, #31
sqadd v25.4s, v25.4s, v1.4s
srshl v25.4s, v25.4s, v9.4s
and v2.16b, v9.16b, v26.16b
srshl v25.4s, v25.4s, v13.4s
and v2.16b, v12.16b, v26.16b
sshr v2.4s, v2.4s, #31
sqadd v26.4s, v26.4s, v2.4s
srshl v26.4s, v26.4s, v9.4s
and v3.16b, v9.16b, v27.16b
srshl v26.4s, v26.4s, v12.4s
and v3.16b, v13.16b, v27.16b
sshr v3.4s, v3.4s, #31
sqadd v27.4s, v27.4s, v3.4s
srshl v27.4s, v27.4s, v9.4s
and v0.16b, v9.16b, v28.16b
srshl v27.4s, v27.4s, v13.4s
and v0.16b, v12.16b, v28.16b
sshr v0.4s, v0.4s, #31
sqadd v28.4s, v28.4s, v0.4s
srshl v28.4s, v28.4s, v9.4s
and v1.16b, v9.16b, v29.16b
srshl v28.4s, v28.4s, v12.4s
and v1.16b, v13.16b, v29.16b
sshr v1.4s, v1.4s, #31
sqadd v29.4s, v29.4s, v1.4s
srshl v29.4s, v29.4s, v9.4s
and v2.16b, v9.16b, v30.16b
srshl v29.4s, v29.4s, v13.4s
and v2.16b, v12.16b, v30.16b
sshr v2.4s, v2.4s, #31
sqadd v30.4s, v30.4s, v2.4s
srshl v30.4s, v30.4s, v9.4s
and v3.16b, v9.16b, v31.16b
srshl v30.4s, v30.4s, v12.4s
and v3.16b, v13.16b, v31.16b
sshr v3.4s, v3.4s, #31
sqadd v31.4s, v31.4s, v3.4s
srshl v31.4s, v31.4s, v9.4s
srshl v31.4s, v31.4s, v13.4s
// Add the destination zero point
dup v8.4s, w10
@ -793,15 +829,23 @@ End2:
add x7, x7, #32 // bias ptr + stride
add x25, x25, #8 // output + stride(8 * sizeof(int8))
mov x2, x25
cmp w27, #0
beq PerTEnd2
add x12, x12, #32
add x11, x11, #32
add x13, x13, #32
PerTEnd2:
b L1
End1:
sub sp, sp, #192
sub sp, sp, #208
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ldp x27, x28, [sp], #16
ret
#endif

View File

@ -385,12 +385,302 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
int8_t *pack_r = packed_input;
int32_t *input_sum_r = input_sum;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"dup v0.4s, v16.s[0] \n"
"dup v1.4s, v16.s[1] \n"
"dup v2.4s, v16.s[2] \n"
"dup v3.4s, v16.s[3] \n"
"dup v4.4s, v17.s[0] \n"
"dup v5.4s, v17.s[1] \n"
"dup v6.4s, v17.s[2] \n"
"dup v7.4s, v17.s[3] \n"
"mov x4, #0 \n"
"mov x10, %[filter_zp] \n"
"mov x11, %[input_sum_oc] \n"
"7: \n"
"cmp x4, %[oc_8div] \n"
"beq 8f \n"
"add x4, x4, #8\n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.4s}, [x10], #16\n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"add x11, x11, %[input_sum_stride] \n"
"b 7b \n"
"8: \n"
"cmp %[oc_8res], #0\n"
"beq 17f \n"
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"cmp %[oc_8res], #1\n"
"beq 9f \n"
"cmp %[oc_8res], #2\n"
"beq 10f \n"
"cmp %[oc_8res], #3\n"
"beq 11f \n"
"cmp %[oc_8res], #4\n"
"beq 12f \n"
"cmp %[oc_8res], #5\n"
"beq 13f \n"
"cmp %[oc_8res], #6\n"
"beq 14f \n"
"cmp %[oc_8res], #7\n"
"beq 15f \n"
"9: \n"
"ld1 {v16.s}[0], [x10] \n"
"b 16f \n"
"10: \n"
"ld1 {v16.h}[0], [x10] \n"
"b 16f \n"
"11: \n"
"ld1 {v16.h}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v16.s}[2], [x10] \n"
"b 16f \n"
"12: \n"
"ld1 {v16.4s}, [x10] \n"
"b 16f \n"
"13: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.s}[0], [x10] \n"
"b 16f \n"
"14: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.h}[0], [x10] \n"
"b 16f \n"
"15: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.h}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v17.s}[2], [x10] \n"
"b 16f \n"
"16: \n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"17: \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
[ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
[ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
: "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
@ -440,7 +730,7 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
}
}
} /* oc8 res done */
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
input_sum_r += C8NUM * C8NUM;
@ -520,9 +810,9 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v10.4s, wzr \n"
"dup v11.4s, wzr \n"
"mov x20, %[input_sum_r] \n"
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
@ -563,8 +853,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
@ -586,8 +876,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
@ -609,8 +899,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
@ -641,21 +931,21 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"mul v10.4s, v10.4s, v20.4s \n"
"mul v11.4s, v11.4s, v20.4s \n"
"mul v16.4s, v16.4s, v20.4s \n"
"mul v17.4s, v17.4s, v20.4s \n"
"st1 {v10.4s}, [x20], #16 \n"
"st1 {v11.4s}, [x20], #16 \n"
"st1 {v16.4s}, [x14], #16 \n"
"st1 {v17.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
: "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
@ -728,10 +1018,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc);
return;
}
@ -756,24 +1046,17 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) {
return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum,
bias, left_shift, right_shift, multiplier,
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
}
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col,
conv_param->output_channel_, is_per_oc);
#else
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
is_per_oc);
#endif
return;
}

View File

@ -269,7 +269,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel) {
size_t per_channel) {
/* row8x4-major * row4x8-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {

View File

@ -39,7 +39,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel);
size_t per_channel);
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
@ -59,8 +59,8 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
int right_shift, int row, int col, int stride);
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, int row, int col, int stride, int filter_peroc);
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);

View File

@ -25,7 +25,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, i
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, bool per_channel);
int32_t maxi, size_t per_channel);
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);

View File

@ -30,8 +30,9 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier,
int left_shift, int right_shift, int row, int col, int stride);
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
size_t peroc);
#ifdef __cplusplus
}
@ -55,8 +56,8 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, bool per_channel) {
int32_t maxi, size_t per_channel) {
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, stride);
output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
}
#endif

View File

@ -273,13 +273,13 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
"mov x11, %[input_sum] \n"
"mov x15, %[filter_zp_ptr] \n"
"mov x0, #0 \n" // row 4 count
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[hw4] \n"
"beq 11f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n" // input deep count
"mov x2, #0 \n"
"mov x16, x15 \n"
"2: \n"
@ -313,9 +313,9 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
"b 2b \n"
"3: \n"
"mov x12, x11 \n" // tmp inputsm inputsum hw
"mov x12, x11 \n"
"add x11, x11, #64 \n"
"mov x4, #0 \n" // oc count
"mov x4, #0 \n"
"dup v1.4s, v10.s[0] \n"
"dup v2.4s, v10.s[1] \n"

View File

@ -46,6 +46,18 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
free(filter_zp_ptr_);
filter_zp_ptr_ = nullptr;
}
if (filter_peroc_ && left_shift_ != nullptr) {
free(left_shift_);
left_shift_ = nullptr;
}
if (filter_peroc_ && right_shift_ != nullptr) {
free(right_shift_);
right_shift_ = nullptr;
}
if (filter_peroc_ && multiplier_ != nullptr) {
free(multiplier_);
multiplier_ = nullptr;
}
FreeResizeBuf();
FreeQuantParam();
}
@ -59,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
}
void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
support_optimize_ = true;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@ -78,10 +90,6 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
matmul_func_ = nullptr;
}
if (filter_peroc_) {
support_optimize_ = false;
}
#endif
return;
}
@ -109,6 +117,26 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
for (int fi = 0; fi < output_channel; fi++) {
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}
int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM);
left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
if (left_shift_ == nullptr) {
return RET_ERROR;
}
memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t));
memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
if (right_shift_ == nullptr) {
return RET_ERROR;
}
memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t));
memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
if (multiplier_ == nullptr) {
return RET_ERROR;
}
memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t));
memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
}
return RET_OK;
}
@ -328,9 +356,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM;
}
Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
@ -346,9 +374,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM;
}
Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,

View File

@ -58,8 +58,11 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int InitBiasByzp(void *src_weight, int input_channel, int output_channel);
private:
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
int32_t *left_shift_ = nullptr; /* per-oc up round */
int32_t *right_shift_ = nullptr; /* per-oc up round */
int32_t *multiplier_ = nullptr; /* per-oc up round */
int8_t *packed_weight_ = nullptr;
int8_t *packed_input_ = nullptr;
int8_t *input_ptr_ = nullptr;

View File

@ -108,8 +108,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res,
p->col_ * sizeof(int8_t));
q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res,
p->col_ * sizeof(int8_t), 0);
#else
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_,
q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_);

View File

@ -101,8 +101,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) {
auto &p = quant_params_;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res,
params_->col_ * sizeof(int8_t));
p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
params_->col_ * sizeof(int8_t), false);
#else
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier,
p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_);

View File

@ -120,8 +120,8 @@ TEST_F(TestMatmulInt8, simple) {
int multiplier, ls, rs;
QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls,
rs, ROW, COL, COL);
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
&rs, ROW, COL, COL, false);
#else
MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL);
#endif