forked from mindspore-Ecosystem/mindspore
!6102 [MSLITE][Develop]conv1x1 per oc arm64
Merge pull request !6102 from ling/sr
This commit is contained in:
commit
19874b83e7
|
@ -6,9 +6,9 @@
|
|||
.type MatmulInt8Neon64, %function
|
||||
#endif
|
||||
|
||||
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
|
||||
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
|
||||
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
|
||||
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
|
||||
// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
|
||||
// int32_t *right_shift, int row, int col, int stride, int filter_peroc)
|
||||
|
||||
// x0: a(left matrix ptr)
|
||||
// x1: b(right matrix ptr)
|
||||
|
@ -21,31 +21,34 @@
|
|||
// w8: act_min
|
||||
// w9: act_max
|
||||
// w10: out_zp
|
||||
// w11: multiplier
|
||||
// w12: left_shift
|
||||
// w13: right_shift
|
||||
// x11: multiplier
|
||||
// x12: left_shift
|
||||
// x13: right_shift
|
||||
// w14: row
|
||||
// w15: col
|
||||
// w24: stride
|
||||
// w27: filter_peroc
|
||||
|
||||
MatmulInt8Neon64:
|
||||
sub sp, sp, #192
|
||||
sub sp, sp, #208
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
stp x23, x24, [sp], #16
|
||||
stp x25, x26, [sp], #16
|
||||
stp x27, x28, [sp], #16
|
||||
|
||||
ldr w8, [sp]
|
||||
ldr w9, [sp, #8]
|
||||
ldr w10, [sp, #16]
|
||||
ldr w11, [sp, #24]
|
||||
ldr w12, [sp, #32]
|
||||
ldr w13, [sp, #40]
|
||||
ldr x11, [sp, #24]
|
||||
ldr x12, [sp, #32]
|
||||
ldr x13, [sp, #40]
|
||||
ldr w14, [sp, #48]
|
||||
ldr w15, [sp, #56]
|
||||
ldr w24, [sp, #64]
|
||||
ldr w27, [sp, #72]
|
||||
|
||||
mov w17, #4 // sizeof(int8)*4
|
||||
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
|
||||
|
@ -58,7 +61,7 @@ L1:
|
|||
mov w16, w3 // reset a row4 counter
|
||||
mov w23, w14 // reset a row counter
|
||||
mov x17, x0 // reload a ptr
|
||||
mov x22, x6 // reload a_sums ptr
|
||||
mov x22, x6 // reload a_sums ptr
|
||||
L2:
|
||||
cmp w16, #0
|
||||
beq End2
|
||||
|
@ -167,39 +170,60 @@ End3:
|
|||
addp v19.4s, v28.4s, v30.4s
|
||||
|
||||
// Add (Bias+Depth*Za*Zb-Za*Bsums)
|
||||
ld1 {v15.4s}, [x19], #16
|
||||
ld1 {v15.4s}, [x19], #16
|
||||
add v16.4s, v16.4s, v15.4s
|
||||
add v17.4s, v17.4s, v15.4s
|
||||
add v18.4s, v18.4s, v15.4s
|
||||
add v19.4s, v19.4s, v15.4s
|
||||
|
||||
// Subtract (Asums*Zb)
|
||||
cmp w27, #0
|
||||
beq PerTLoad
|
||||
PerCLoad:
|
||||
ld1 {v20.4s}, [x6], #16
|
||||
ld1 {v21.4s}, [x6], #16
|
||||
ld1 {v22.4s}, [x6], #16
|
||||
ld1 {v23.4s}, [x6], #16
|
||||
|
||||
ld1 {v13.4s}, [x12]
|
||||
ld1 {v12.4s}, [x11]
|
||||
ld1 {v11.4s}, [x13]
|
||||
b Apply
|
||||
|
||||
PerTLoad:
|
||||
ld1 {v14.4s}, [x22], #16
|
||||
dup v20.4s, v14.s[0]
|
||||
dup v21.4s, v14.s[1]
|
||||
dup v22.4s, v14.s[2]
|
||||
dup v23.4s, v14.s[3]
|
||||
|
||||
ld1 {v14.s}[0], [x12]
|
||||
dup v13.4s, v14.s[0]
|
||||
ld1 {v14.s}[0], [x11]
|
||||
dup v12.4s, v14.s[0]
|
||||
ld1 {v14.s}[0], [x13]
|
||||
dup v11.4s, v14.s[0]
|
||||
b Apply
|
||||
|
||||
Apply:
|
||||
// Subtract (Asums*Zb)
|
||||
sub v16.4s, v16.4s, v20.4s
|
||||
sub v17.4s, v17.4s, v21.4s
|
||||
sub v18.4s, v18.4s, v22.4s
|
||||
sub v19.4s, v19.4s, v23.4s
|
||||
|
||||
// Apply left shift
|
||||
dup v13.4s, w12
|
||||
sqshl v16.4s, v16.4s, v13.4s
|
||||
sqshl v17.4s, v17.4s, v13.4s
|
||||
sqshl v18.4s, v18.4s, v13.4s
|
||||
sqshl v19.4s, v19.4s, v13.4s
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
dup v12.4s, w11
|
||||
sqrdmulh v16.4s, v16.4s, v12.4s
|
||||
sqrdmulh v17.4s, v17.4s, v12.4s
|
||||
sqrdmulh v18.4s, v18.4s, v12.4s
|
||||
sqrdmulh v19.4s, v19.4s, v12.4s
|
||||
|
||||
// Apply right shift
|
||||
dup v11.4s, w13
|
||||
and v20.16b, v11.16b, v16.16b
|
||||
sshr v20.4s, v20.4s, #31
|
||||
sqadd v16.4s, v16.4s, v20.4s
|
||||
|
@ -268,7 +292,7 @@ Write:
|
|||
beq WriteCol2
|
||||
cmp w15, #1
|
||||
beq WriteCol1
|
||||
|
||||
|
||||
WriteCol4:
|
||||
st1 {v15.s}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
|
@ -349,7 +373,7 @@ WriteCol1:
|
|||
st1 {v15.b}[12], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
Endwrite:
|
||||
Endwrite:
|
||||
sub w16, w16, #4 // a row4 counter - 4
|
||||
sub w23, w23, #4 // a row counter - 4
|
||||
b L2
|
||||
|
@ -361,15 +385,23 @@ End2:
|
|||
add x7, x7, #16 // bias ptr + stride
|
||||
add x25, x25, #4 // output + stride(4 * sizeof(int8))
|
||||
mov x2, x25
|
||||
|
||||
cmp w27, #0
|
||||
beq PerTEnd2
|
||||
add x12, x12, #16
|
||||
add x11, x11, #16
|
||||
add x13, x13, #16
|
||||
PerTEnd2:
|
||||
b L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #192
|
||||
sub sp, sp, #208
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ldp x23, x24, [sp], #16
|
||||
ldp x25, x26, [sp], #16
|
||||
ldp x27, x28, [sp], #16
|
||||
ret
|
||||
#endif
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
|
||||
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
|
||||
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
|
||||
// int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, int peroc);
|
||||
|
||||
// x0: a(left matrix ptr)
|
||||
// x1: b(right matrix ptr)
|
||||
|
@ -21,31 +21,34 @@
|
|||
// w8: act_min
|
||||
// w9: act_max
|
||||
// w10: out_zp
|
||||
// w11: multiplier
|
||||
// w12: left_shift
|
||||
// w13: right_shift
|
||||
// x11: multiplier
|
||||
// x12: left_shift
|
||||
// x13: right_shift
|
||||
// w14: row
|
||||
// w15: col
|
||||
// w24: stride
|
||||
// w27: filter_peroc
|
||||
|
||||
MatmulInt8DpNeon64:
|
||||
sub sp, sp, #192
|
||||
sub sp, sp, #208
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
stp x23, x24, [sp], #16
|
||||
stp x25, x26, [sp], #16
|
||||
stp x27, x28, [sp], #16
|
||||
|
||||
ldr w8, [sp]
|
||||
ldr w9, [sp, #8]
|
||||
ldr w10, [sp, #16]
|
||||
ldr w11, [sp, #24]
|
||||
ldr w12, [sp, #32]
|
||||
ldr w13, [sp, #40]
|
||||
ldr x11, [sp, #24]
|
||||
ldr x12, [sp, #32]
|
||||
ldr x13, [sp, #40]
|
||||
ldr w14, [sp, #48]
|
||||
ldr w15, [sp, #56]
|
||||
ldr w24, [sp, #64]
|
||||
ldr w27, [sp, #72]
|
||||
|
||||
mov w17, #8 // sizeof(int8)*8
|
||||
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4
|
||||
|
@ -226,138 +229,171 @@ End3:
|
|||
add v29.4s, v29.4s, v14.4s
|
||||
add v31.4s, v31.4s, v14.4s
|
||||
|
||||
cmp w27, #0
|
||||
beq PerTSumLoad
|
||||
PerCSumLoad:
|
||||
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
|
||||
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64
|
||||
b ApplySum
|
||||
PerTSumLoad:
|
||||
ld1 {v14.4s}, [x22], #16
|
||||
ld1 {v15.4s}, [x22], #16
|
||||
dup v0.4s, v14.s[0]
|
||||
dup v1.4s, v14.s[0]
|
||||
dup v2.4s, v14.s[1]
|
||||
dup v3.4s, v14.s[1]
|
||||
dup v4.4s, v14.s[2]
|
||||
dup v5.4s, v14.s[2]
|
||||
dup v6.4s, v14.s[3]
|
||||
dup v7.4s, v14.s[3]
|
||||
dup v8.4s, v15.s[0]
|
||||
dup v9.4s, v15.s[0]
|
||||
dup v10.4s, v15.s[1]
|
||||
dup v11.4s, v15.s[1]
|
||||
dup v12.4s, v15.s[2]
|
||||
dup v13.4s, v15.s[2]
|
||||
dup v14.4s, v15.s[3]
|
||||
dup v15.4s, v14.s[0]
|
||||
ApplySum:
|
||||
// Subtract (Asums*Zb)
|
||||
ld1 {v13.4s}, [x22], #16
|
||||
ld1 {v12.4s}, [x22], #16
|
||||
dup v0.4s, v13.s[0]
|
||||
dup v1.4s, v13.s[1]
|
||||
dup v2.4s, v13.s[2]
|
||||
dup v3.4s, v13.s[3]
|
||||
dup v4.4s, v12.s[0]
|
||||
dup v5.4s, v12.s[1]
|
||||
dup v6.4s, v12.s[2]
|
||||
dup v7.4s, v12.s[3]
|
||||
sub v16.4s, v16.4s, v0.4s
|
||||
sub v17.4s, v17.4s, v0.4s
|
||||
sub v18.4s, v18.4s, v1.4s
|
||||
sub v19.4s, v19.4s, v1.4s
|
||||
sub v20.4s, v20.4s, v2.4s
|
||||
sub v21.4s, v21.4s, v2.4s
|
||||
sub v22.4s, v22.4s, v3.4s
|
||||
sub v23.4s, v23.4s, v3.4s
|
||||
sub v24.4s, v24.4s, v4.4s
|
||||
sub v25.4s, v25.4s, v4.4s
|
||||
sub v26.4s, v26.4s, v5.4s
|
||||
sub v27.4s, v27.4s, v5.4s
|
||||
sub v28.4s, v28.4s, v6.4s
|
||||
sub v29.4s, v29.4s, v6.4s
|
||||
sub v30.4s, v30.4s, v7.4s
|
||||
sub v31.4s, v31.4s, v7.4s
|
||||
sub v17.4s, v17.4s, v1.4s
|
||||
sub v18.4s, v18.4s, v2.4s
|
||||
sub v19.4s, v19.4s, v3.4s
|
||||
sub v20.4s, v20.4s, v4.4s
|
||||
sub v21.4s, v21.4s, v5.4s
|
||||
sub v22.4s, v22.4s, v6.4s
|
||||
sub v23.4s, v23.4s, v7.4s
|
||||
sub v24.4s, v24.4s, v8.4s
|
||||
sub v25.4s, v25.4s, v9.4s
|
||||
sub v26.4s, v26.4s, v10.4s
|
||||
sub v27.4s, v27.4s, v11.4s
|
||||
sub v28.4s, v28.4s, v12.4s
|
||||
sub v29.4s, v29.4s, v13.4s
|
||||
sub v30.4s, v30.4s, v14.4s
|
||||
sub v31.4s, v31.4s, v15.4s
|
||||
|
||||
cmp w27, #0
|
||||
beq PerTRoundLoad
|
||||
PerCRoundLoad:
|
||||
ld1 {v8.4s, v9.4s}, [x12]
|
||||
ld1 {v10.4s, v11.4s}, [x11]
|
||||
ld1 {v12.4s, v13.4s}, [x13]
|
||||
b ApplyRound
|
||||
PerTRoundLoad:
|
||||
ld1 {v14.s}[0], [x12]
|
||||
dup v8.4s, v14.s[0]
|
||||
dup v9.4s, v14.s[0]
|
||||
ld1 {v14.s}[0], [x11]
|
||||
dup v10.4s, v14.s[0]
|
||||
dup v11.4s, v14.s[0]
|
||||
ld1 {v14.s}[0], [x13]
|
||||
dup v12.4s, v14.s[0]
|
||||
dup v13.4s, v14.s[0]
|
||||
ApplyRound:
|
||||
// Apply left shift
|
||||
dup v11.4s, w12
|
||||
sqshl v16.4s, v16.4s, v11.4s
|
||||
sqshl v17.4s, v17.4s, v11.4s
|
||||
sqshl v18.4s, v18.4s, v11.4s
|
||||
sqshl v19.4s, v19.4s, v11.4s
|
||||
sqshl v20.4s, v20.4s, v11.4s
|
||||
sqshl v21.4s, v21.4s, v11.4s
|
||||
sqshl v22.4s, v22.4s, v11.4s
|
||||
sqshl v23.4s, v23.4s, v11.4s
|
||||
sqshl v24.4s, v24.4s, v11.4s
|
||||
sqshl v25.4s, v25.4s, v11.4s
|
||||
sqshl v26.4s, v26.4s, v11.4s
|
||||
sqshl v27.4s, v27.4s, v11.4s
|
||||
sqshl v28.4s, v28.4s, v11.4s
|
||||
sqshl v29.4s, v29.4s, v11.4s
|
||||
sqshl v30.4s, v30.4s, v11.4s
|
||||
sqshl v31.4s, v31.4s, v11.4s
|
||||
sqshl v16.4s, v16.4s, v8.4s
|
||||
sqshl v17.4s, v17.4s, v9.4s
|
||||
sqshl v18.4s, v18.4s, v8.4s
|
||||
sqshl v19.4s, v19.4s, v9.4s
|
||||
sqshl v20.4s, v20.4s, v8.4s
|
||||
sqshl v21.4s, v21.4s, v9.4s
|
||||
sqshl v22.4s, v22.4s, v8.4s
|
||||
sqshl v23.4s, v23.4s, v9.4s
|
||||
sqshl v24.4s, v24.4s, v8.4s
|
||||
sqshl v25.4s, v25.4s, v9.4s
|
||||
sqshl v26.4s, v26.4s, v8.4s
|
||||
sqshl v27.4s, v27.4s, v9.4s
|
||||
sqshl v28.4s, v28.4s, v8.4s
|
||||
sqshl v29.4s, v29.4s, v9.4s
|
||||
sqshl v30.4s, v30.4s, v8.4s
|
||||
sqshl v31.4s, v31.4s, v9.4s
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
dup v10.4s, w11
|
||||
sqrdmulh v16.4s, v16.4s, v10.4s
|
||||
sqrdmulh v17.4s, v17.4s, v10.4s
|
||||
sqrdmulh v17.4s, v17.4s, v11.4s
|
||||
sqrdmulh v18.4s, v18.4s, v10.4s
|
||||
sqrdmulh v19.4s, v19.4s, v10.4s
|
||||
sqrdmulh v19.4s, v19.4s, v11.4s
|
||||
sqrdmulh v20.4s, v20.4s, v10.4s
|
||||
sqrdmulh v21.4s, v21.4s, v10.4s
|
||||
sqrdmulh v21.4s, v21.4s, v11.4s
|
||||
sqrdmulh v22.4s, v22.4s, v10.4s
|
||||
sqrdmulh v23.4s, v23.4s, v10.4s
|
||||
sqrdmulh v23.4s, v23.4s, v11.4s
|
||||
sqrdmulh v24.4s, v24.4s, v10.4s
|
||||
sqrdmulh v25.4s, v25.4s, v10.4s
|
||||
sqrdmulh v25.4s, v25.4s, v11.4s
|
||||
sqrdmulh v26.4s, v26.4s, v10.4s
|
||||
sqrdmulh v27.4s, v27.4s, v10.4s
|
||||
sqrdmulh v27.4s, v27.4s, v11.4s
|
||||
sqrdmulh v28.4s, v28.4s, v10.4s
|
||||
sqrdmulh v29.4s, v29.4s, v10.4s
|
||||
sqrdmulh v29.4s, v29.4s, v11.4s
|
||||
sqrdmulh v30.4s, v30.4s, v10.4s
|
||||
sqrdmulh v31.4s, v31.4s, v10.4s
|
||||
sqrdmulh v31.4s, v31.4s, v11.4s
|
||||
|
||||
// Apply right shift
|
||||
dup v9.4s, w13
|
||||
and v0.16b, v9.16b, v16.16b
|
||||
and v0.16b, v12.16b, v16.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v16.4s, v16.4s, v0.4s
|
||||
srshl v16.4s, v16.4s, v9.4s
|
||||
and v1.16b, v9.16b, v17.16b
|
||||
srshl v16.4s, v16.4s, v12.4s
|
||||
and v1.16b, v13.16b, v17.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v17.4s, v17.4s, v1.4s
|
||||
srshl v17.4s, v17.4s, v9.4s
|
||||
and v2.16b, v9.16b, v18.16b
|
||||
srshl v17.4s, v17.4s, v13.4s
|
||||
and v2.16b, v12.16b, v18.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v18.4s, v18.4s, v2.4s
|
||||
srshl v18.4s, v18.4s, v9.4s
|
||||
and v3.16b, v9.16b, v19.16b
|
||||
srshl v18.4s, v18.4s, v12.4s
|
||||
and v3.16b, v13.16b, v19.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v19.4s, v19.4s, v3.4s
|
||||
srshl v19.4s, v19.4s, v9.4s
|
||||
and v0.16b, v9.16b, v20.16b
|
||||
srshl v19.4s, v19.4s, v13.4s
|
||||
and v0.16b, v12.16b, v20.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v20.4s, v20.4s, v0.4s
|
||||
srshl v20.4s, v20.4s, v9.4s
|
||||
and v1.16b, v9.16b, v21.16b
|
||||
srshl v20.4s, v20.4s, v12.4s
|
||||
and v1.16b, v13.16b, v21.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v21.4s, v21.4s, v1.4s
|
||||
srshl v21.4s, v21.4s, v9.4s
|
||||
and v2.16b, v9.16b, v22.16b
|
||||
srshl v21.4s, v21.4s, v13.4s
|
||||
and v2.16b, v12.16b, v22.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v22.4s, v22.4s, v2.4s
|
||||
srshl v22.4s, v22.4s, v9.4s
|
||||
and v3.16b, v9.16b, v23.16b
|
||||
srshl v22.4s, v22.4s, v12.4s
|
||||
and v3.16b, v13.16b, v23.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v23.4s, v23.4s, v3.4s
|
||||
srshl v23.4s, v23.4s, v9.4s
|
||||
and v0.16b, v9.16b, v24.16b
|
||||
srshl v23.4s, v23.4s, v13.4s
|
||||
and v0.16b, v12.16b, v24.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v24.4s, v24.4s, v0.4s
|
||||
srshl v24.4s, v24.4s, v9.4s
|
||||
and v1.16b, v9.16b, v25.16b
|
||||
srshl v24.4s, v24.4s, v12.4s
|
||||
and v1.16b, v13.16b, v25.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v25.4s, v25.4s, v1.4s
|
||||
srshl v25.4s, v25.4s, v9.4s
|
||||
and v2.16b, v9.16b, v26.16b
|
||||
srshl v25.4s, v25.4s, v13.4s
|
||||
and v2.16b, v12.16b, v26.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v26.4s, v26.4s, v2.4s
|
||||
srshl v26.4s, v26.4s, v9.4s
|
||||
and v3.16b, v9.16b, v27.16b
|
||||
srshl v26.4s, v26.4s, v12.4s
|
||||
and v3.16b, v13.16b, v27.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v27.4s, v27.4s, v3.4s
|
||||
srshl v27.4s, v27.4s, v9.4s
|
||||
and v0.16b, v9.16b, v28.16b
|
||||
srshl v27.4s, v27.4s, v13.4s
|
||||
and v0.16b, v12.16b, v28.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v28.4s, v28.4s, v0.4s
|
||||
srshl v28.4s, v28.4s, v9.4s
|
||||
and v1.16b, v9.16b, v29.16b
|
||||
srshl v28.4s, v28.4s, v12.4s
|
||||
and v1.16b, v13.16b, v29.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v29.4s, v29.4s, v1.4s
|
||||
srshl v29.4s, v29.4s, v9.4s
|
||||
and v2.16b, v9.16b, v30.16b
|
||||
srshl v29.4s, v29.4s, v13.4s
|
||||
and v2.16b, v12.16b, v30.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v30.4s, v30.4s, v2.4s
|
||||
srshl v30.4s, v30.4s, v9.4s
|
||||
and v3.16b, v9.16b, v31.16b
|
||||
srshl v30.4s, v30.4s, v12.4s
|
||||
and v3.16b, v13.16b, v31.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v31.4s, v31.4s, v3.4s
|
||||
srshl v31.4s, v31.4s, v9.4s
|
||||
srshl v31.4s, v31.4s, v13.4s
|
||||
|
||||
// Add the destination zero point
|
||||
dup v8.4s, w10
|
||||
|
@ -793,15 +829,23 @@ End2:
|
|||
add x7, x7, #32 // bias ptr + stride
|
||||
add x25, x25, #8 // output + stride(8 * sizeof(int8))
|
||||
mov x2, x25
|
||||
|
||||
cmp w27, #0
|
||||
beq PerTEnd2
|
||||
add x12, x12, #32
|
||||
add x11, x11, #32
|
||||
add x13, x13, #32
|
||||
PerTEnd2:
|
||||
b L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #192
|
||||
sub sp, sp, #208
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ldp x23, x24, [sp], #16
|
||||
ldp x25, x26, [sp], #16
|
||||
ldp x27, x28, [sp], #16
|
||||
ret
|
||||
#endif
|
||||
|
|
|
@ -385,12 +385,302 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
|
|||
int8_t *pack_r = packed_input;
|
||||
int32_t *input_sum_r = input_sum;
|
||||
|
||||
/* per layer */
|
||||
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
|
||||
const int8_t *src_ic = src_r;
|
||||
int8_t *pack_ic = pack_r;
|
||||
int32_t *input_sum_oc = input_sum_r;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t src_stride = input_channel;
|
||||
size_t ic_4res = input_channel - ic_4div;
|
||||
size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
|
||||
asm volatile(
|
||||
"dup v16.4s, wzr \n"
|
||||
"dup v17.4s, wzr \n"
|
||||
|
||||
"mov x10, %[src_ic] \n"
|
||||
"mov x11, %[pack_ic] \n"
|
||||
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[ic_4div] \n"
|
||||
"add x0, x0, #4\n"
|
||||
"mov x12, x10 \n"
|
||||
"add x10, x10, #4\n"
|
||||
"blt 2f \n"
|
||||
"cmp %[ic_4res], #0\n"
|
||||
"beq 6f \n"
|
||||
"cmp %[ic_4res], #1\n"
|
||||
"beq 3f \n"
|
||||
"cmp %[ic_4res], #2\n"
|
||||
"beq 4f \n"
|
||||
"cmp %[ic_4res], #3\n"
|
||||
"beq 5f \n"
|
||||
|
||||
"2: \n"
|
||||
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 1b \n"
|
||||
|
||||
"3: \n" /* col res 1 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"4: \n" /* col res 2 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"5: \n" /* col res 3 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
"add x13, x12, #2 \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"6: \n"
|
||||
"dup v0.4s, v16.s[0] \n"
|
||||
"dup v1.4s, v16.s[1] \n"
|
||||
"dup v2.4s, v16.s[2] \n"
|
||||
"dup v3.4s, v16.s[3] \n"
|
||||
"dup v4.4s, v17.s[0] \n"
|
||||
"dup v5.4s, v17.s[1] \n"
|
||||
"dup v6.4s, v17.s[2] \n"
|
||||
"dup v7.4s, v17.s[3] \n"
|
||||
"mov x4, #0 \n"
|
||||
"mov x10, %[filter_zp] \n"
|
||||
"mov x11, %[input_sum_oc] \n"
|
||||
|
||||
"7: \n"
|
||||
"cmp x4, %[oc_8div] \n"
|
||||
"beq 8f \n"
|
||||
"add x4, x4, #8\n"
|
||||
"ld1 {v16.4s}, [x10], #16\n"
|
||||
"ld1 {v17.4s}, [x10], #16\n"
|
||||
|
||||
"mul v18.4s, v16.4s, v0.4s \n"
|
||||
"mul v19.4s, v17.4s, v0.4s \n"
|
||||
"st1 {v18.4s}, [x11], #16 \n"
|
||||
"st1 {v19.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v20.4s, v16.4s, v1.4s \n"
|
||||
"mul v21.4s, v17.4s, v1.4s \n"
|
||||
"st1 {v20.4s}, [x11], #16 \n"
|
||||
"st1 {v21.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v22.4s, v16.4s, v2.4s \n"
|
||||
"mul v23.4s, v17.4s, v2.4s \n"
|
||||
"st1 {v22.4s}, [x11], #16 \n"
|
||||
"st1 {v23.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v24.4s, v16.4s, v3.4s \n"
|
||||
"mul v25.4s, v17.4s, v3.4s \n"
|
||||
"st1 {v24.4s}, [x11], #16 \n"
|
||||
"st1 {v25.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v18.4s, v16.4s, v4.4s \n"
|
||||
"mul v19.4s, v17.4s, v4.4s \n"
|
||||
"st1 {v18.4s}, [x11], #16 \n"
|
||||
"st1 {v19.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v20.4s, v16.4s, v5.4s \n"
|
||||
"mul v21.4s, v17.4s, v5.4s \n"
|
||||
"st1 {v20.4s}, [x11], #16 \n"
|
||||
"st1 {v21.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v22.4s, v16.4s, v6.4s \n"
|
||||
"mul v23.4s, v17.4s, v6.4s \n"
|
||||
"st1 {v22.4s}, [x11], #16 \n"
|
||||
"st1 {v23.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v24.4s, v16.4s, v7.4s \n"
|
||||
"mul v25.4s, v17.4s, v7.4s \n"
|
||||
"st1 {v24.4s}, [x11], #16 \n"
|
||||
"st1 {v25.4s}, [x11], #16 \n"
|
||||
|
||||
"add x11, x11, %[input_sum_stride] \n"
|
||||
"b 7b \n"
|
||||
|
||||
"8: \n"
|
||||
"cmp %[oc_8res], #0\n"
|
||||
"beq 17f \n"
|
||||
|
||||
"dup v16.4s, wzr \n"
|
||||
"dup v17.4s, wzr \n"
|
||||
"cmp %[oc_8res], #1\n"
|
||||
"beq 9f \n"
|
||||
"cmp %[oc_8res], #2\n"
|
||||
"beq 10f \n"
|
||||
"cmp %[oc_8res], #3\n"
|
||||
"beq 11f \n"
|
||||
"cmp %[oc_8res], #4\n"
|
||||
"beq 12f \n"
|
||||
"cmp %[oc_8res], #5\n"
|
||||
"beq 13f \n"
|
||||
"cmp %[oc_8res], #6\n"
|
||||
"beq 14f \n"
|
||||
"cmp %[oc_8res], #7\n"
|
||||
"beq 15f \n"
|
||||
|
||||
"9: \n"
|
||||
"ld1 {v16.s}[0], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"10: \n"
|
||||
"ld1 {v16.h}[0], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"11: \n"
|
||||
"ld1 {v16.h}[0], [x10] \n"
|
||||
"add x10, x10, #8 \n"
|
||||
"ld1 {v16.s}[2], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"12: \n"
|
||||
"ld1 {v16.4s}, [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"13: \n"
|
||||
"ld1 {v16.4s}, [x10], #16\n"
|
||||
"ld1 {v17.s}[0], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"14: \n"
|
||||
"ld1 {v16.4s}, [x10], #16\n"
|
||||
"ld1 {v17.h}[0], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"15: \n"
|
||||
"ld1 {v16.4s}, [x10], #16\n"
|
||||
"ld1 {v17.h}[0], [x10] \n"
|
||||
"add x10, x10, #8 \n"
|
||||
"ld1 {v17.s}[2], [x10] \n"
|
||||
"b 16f \n"
|
||||
|
||||
"16: \n"
|
||||
"mul v18.4s, v16.4s, v0.4s \n"
|
||||
"mul v19.4s, v17.4s, v0.4s \n"
|
||||
"mul v20.4s, v16.4s, v1.4s \n"
|
||||
"mul v21.4s, v17.4s, v1.4s \n"
|
||||
"mul v22.4s, v16.4s, v2.4s \n"
|
||||
"mul v23.4s, v17.4s, v2.4s \n"
|
||||
"mul v24.4s, v16.4s, v3.4s \n"
|
||||
"mul v25.4s, v17.4s, v3.4s \n"
|
||||
"st1 {v18.4s}, [x11], #16 \n"
|
||||
"st1 {v19.4s}, [x11], #16 \n"
|
||||
"st1 {v20.4s}, [x11], #16 \n"
|
||||
"st1 {v21.4s}, [x11], #16 \n"
|
||||
"st1 {v22.4s}, [x11], #16 \n"
|
||||
"st1 {v23.4s}, [x11], #16 \n"
|
||||
"st1 {v24.4s}, [x11], #16 \n"
|
||||
"st1 {v25.4s}, [x11], #16 \n"
|
||||
|
||||
"mul v18.4s, v16.4s, v4.4s \n"
|
||||
"mul v19.4s, v17.4s, v4.4s \n"
|
||||
"mul v20.4s, v16.4s, v5.4s \n"
|
||||
"mul v21.4s, v17.4s, v5.4s \n"
|
||||
"mul v22.4s, v16.4s, v6.4s \n"
|
||||
"mul v23.4s, v17.4s, v6.4s \n"
|
||||
"mul v24.4s, v16.4s, v7.4s \n"
|
||||
"mul v25.4s, v17.4s, v7.4s \n"
|
||||
"st1 {v18.4s}, [x11], #16 \n"
|
||||
"st1 {v19.4s}, [x11], #16 \n"
|
||||
"st1 {v20.4s}, [x11], #16 \n"
|
||||
"st1 {v21.4s}, [x11], #16 \n"
|
||||
"st1 {v22.4s}, [x11], #16 \n"
|
||||
"st1 {v23.4s}, [x11], #16 \n"
|
||||
"st1 {v24.4s}, [x11], #16 \n"
|
||||
"st1 {v25.4s}, [x11], #16 \n"
|
||||
|
||||
"17: \n"
|
||||
|
||||
:
|
||||
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
|
||||
[ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
|
||||
[ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
|
||||
: "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
|
||||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
|
||||
#else
|
||||
int32_t tmp_sum_value[8] = {0};
|
||||
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
|
@ -440,7 +730,7 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
|
|||
}
|
||||
}
|
||||
} /* oc8 res done */
|
||||
|
||||
#endif
|
||||
src_r += input_channel * C8NUM;
|
||||
pack_r += ic4 * C8NUM;
|
||||
input_sum_r += C8NUM * C8NUM;
|
||||
|
@ -520,9 +810,9 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
size_t src_stride = input_channel;
|
||||
size_t ic_4res = input_channel - ic_4div;
|
||||
asm volatile(
|
||||
"dup v10.4s, wzr \n"
|
||||
"dup v11.4s, wzr \n"
|
||||
"mov x20, %[input_sum_r] \n"
|
||||
"dup v16.4s, wzr \n"
|
||||
"dup v17.4s, wzr \n"
|
||||
"mov x14, %[input_sum_r] \n"
|
||||
"dup v20.4s, %w[filter_zp] \n"
|
||||
|
||||
"mov x10, %[src_ic] \n"
|
||||
|
@ -563,8 +853,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 1b \n"
|
||||
|
||||
"3: \n" /* col res 1 */
|
||||
|
@ -586,8 +876,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"4: \n" /* col res 2 */
|
||||
|
@ -609,8 +899,8 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"5: \n" /* col res 3 */
|
||||
|
@ -641,21 +931,21 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"add v16.4s, v16.4s, v0.4s \n"
|
||||
"add v17.4s, v17.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"6: \n"
|
||||
"mul v10.4s, v10.4s, v20.4s \n"
|
||||
"mul v11.4s, v11.4s, v20.4s \n"
|
||||
"mul v16.4s, v16.4s, v20.4s \n"
|
||||
"mul v17.4s, v17.4s, v20.4s \n"
|
||||
|
||||
"st1 {v10.4s}, [x20], #16 \n"
|
||||
"st1 {v11.4s}, [x20], #16 \n"
|
||||
"st1 {v16.4s}, [x14], #16 \n"
|
||||
"st1 {v17.4s}, [x14], #16 \n"
|
||||
|
||||
:
|
||||
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
|
||||
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
|
||||
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
|
||||
: "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
|
||||
"v20");
|
||||
#else
|
||||
int32_t tmp_sum_value[8] = {0};
|
||||
|
@ -728,10 +1018,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
|
|||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
|
||||
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -756,24 +1046,17 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
|
|||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
|
||||
int32_t *multiplier, ConvParameter *conv_param) {
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) {
|
||||
return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum,
|
||||
bias, left_shift, right_shift, multiplier,
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
|
||||
}
|
||||
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
|
||||
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
|
||||
conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col,
|
||||
conv_param->output_channel_, is_per_oc);
|
||||
#else
|
||||
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
|
||||
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
|
||||
is_per_oc);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -269,7 +269,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
|
|||
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel) {
|
||||
size_t per_channel) {
|
||||
/* row8x4-major * row4x8-major => (int8)row-major */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
|
|
|
@ -39,7 +39,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
|
|||
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel);
|
||||
size_t per_channel);
|
||||
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
|
||||
|
@ -59,8 +59,8 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
|
|||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
|
||||
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
|
||||
int right_shift, int row, int col, int stride);
|
||||
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
|
||||
int32_t *right_shift, int row, int col, int stride, int filter_peroc);
|
||||
|
||||
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
|
||||
const int *input_sum, const int *bias);
|
||||
|
|
|
@ -25,7 +25,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, i
|
|||
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
|
||||
int32_t maxi, bool per_channel);
|
||||
int32_t maxi, size_t per_channel);
|
||||
|
||||
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
|
||||
|
||||
|
|
|
@ -30,8 +30,9 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
|
|||
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
|
||||
const int *input_sum, const int *bias);
|
||||
extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
|
||||
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier,
|
||||
int left_shift, int right_shift, int row, int col, int stride);
|
||||
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
|
||||
int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
|
||||
size_t peroc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -55,8 +56,8 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
|
|||
void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
|
||||
int32_t maxi, bool per_channel) {
|
||||
int32_t maxi, size_t per_channel) {
|
||||
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
|
||||
output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, stride);
|
||||
output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -273,13 +273,13 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
|
|||
"mov x11, %[input_sum] \n"
|
||||
"mov x15, %[filter_zp_ptr] \n"
|
||||
|
||||
"mov x0, #0 \n" // row 4 count
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[hw4] \n"
|
||||
"beq 11f \n"
|
||||
"add x0, x0, #4\n"
|
||||
"dup v10.4s, wzr \n"
|
||||
"mov x2, #0 \n" // input deep count
|
||||
"mov x2, #0 \n"
|
||||
"mov x16, x15 \n"
|
||||
|
||||
"2: \n"
|
||||
|
@ -313,9 +313,9 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
|
|||
"b 2b \n"
|
||||
|
||||
"3: \n"
|
||||
"mov x12, x11 \n" // tmp inputsm inputsum hw
|
||||
"mov x12, x11 \n"
|
||||
"add x11, x11, #64 \n"
|
||||
"mov x4, #0 \n" // oc count
|
||||
"mov x4, #0 \n"
|
||||
|
||||
"dup v1.4s, v10.s[0] \n"
|
||||
"dup v2.4s, v10.s[1] \n"
|
||||
|
|
|
@ -46,6 +46,18 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
|
|||
free(filter_zp_ptr_);
|
||||
filter_zp_ptr_ = nullptr;
|
||||
}
|
||||
if (filter_peroc_ && left_shift_ != nullptr) {
|
||||
free(left_shift_);
|
||||
left_shift_ = nullptr;
|
||||
}
|
||||
if (filter_peroc_ && right_shift_ != nullptr) {
|
||||
free(right_shift_);
|
||||
right_shift_ = nullptr;
|
||||
}
|
||||
if (filter_peroc_ && multiplier_ != nullptr) {
|
||||
free(multiplier_);
|
||||
multiplier_ = nullptr;
|
||||
}
|
||||
FreeResizeBuf();
|
||||
FreeQuantParam();
|
||||
}
|
||||
|
@ -59,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
|
|||
}
|
||||
|
||||
void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
|
||||
support_optimize_ = false;
|
||||
support_optimize_ = true;
|
||||
matmul_func_ = MatMulInt8_8x8_r;
|
||||
#ifdef ENABLE_ARM64
|
||||
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
|
||||
|
@ -78,10 +90,6 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
|
|||
support_optimize_ = false;
|
||||
matmul_func_ = nullptr;
|
||||
}
|
||||
|
||||
if (filter_peroc_) {
|
||||
support_optimize_ = false;
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
@ -109,6 +117,26 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
|
|||
for (int fi = 0; fi < output_channel; fi++) {
|
||||
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
|
||||
}
|
||||
|
||||
int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM);
|
||||
left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
|
||||
if (left_shift_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t));
|
||||
memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
|
||||
right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
|
||||
if (right_shift_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t));
|
||||
memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
|
||||
multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
|
||||
if (multiplier_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t));
|
||||
memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -328,9 +356,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
|
|||
}
|
||||
if (filter_peroc_) {
|
||||
cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
|
||||
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM;
|
||||
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM;
|
||||
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM;
|
||||
cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM;
|
||||
cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM;
|
||||
cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM;
|
||||
}
|
||||
Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
|
||||
output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
|
||||
|
@ -346,9 +374,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
|
|||
}
|
||||
if (filter_peroc_) {
|
||||
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
|
||||
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM;
|
||||
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM;
|
||||
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM;
|
||||
cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM;
|
||||
cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM;
|
||||
cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM;
|
||||
}
|
||||
Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
|
||||
output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,
|
||||
|
|
|
@ -58,8 +58,11 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
int InitBiasByzp(void *src_weight, int input_channel, int output_channel);
|
||||
|
||||
private:
|
||||
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
|
||||
int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */
|
||||
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */
|
||||
int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
|
||||
int32_t *left_shift_ = nullptr; /* per-oc up round */
|
||||
int32_t *right_shift_ = nullptr; /* per-oc up round */
|
||||
int32_t *multiplier_ = nullptr; /* per-oc up round */
|
||||
int8_t *packed_weight_ = nullptr;
|
||||
int8_t *packed_input_ = nullptr;
|
||||
int8_t *input_ptr_ = nullptr;
|
||||
|
|
|
@ -108,8 +108,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
|
|||
auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
|
||||
q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res,
|
||||
p->col_ * sizeof(int8_t));
|
||||
q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res,
|
||||
p->col_ * sizeof(int8_t), 0);
|
||||
#else
|
||||
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_,
|
||||
q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_);
|
||||
|
|
|
@ -101,8 +101,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) {
|
|||
auto &p = quant_params_;
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
|
||||
p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res,
|
||||
params_->col_ * sizeof(int8_t));
|
||||
p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res,
|
||||
params_->col_ * sizeof(int8_t), false);
|
||||
#else
|
||||
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier,
|
||||
p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_);
|
||||
|
|
|
@ -120,8 +120,8 @@ TEST_F(TestMatmulInt8, simple) {
|
|||
int multiplier, ls, rs;
|
||||
QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls,
|
||||
rs, ROW, COL, COL);
|
||||
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
|
||||
&rs, ROW, COL, COL, false);
|
||||
#else
|
||||
MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL);
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue