forked from mindspore-Ecosystem/mindspore
!4861 [MS][LITE][Develop]add conv per channel support for int8
Merge pull request !4861 from lixian/master
This commit is contained in:
commit
90552c4933
@@ -8,8 +8,8 @@
 #endif

 // void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-//                           size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-//                           size_t shift_before, size_t shift_after);
+//                           size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
+//                           int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 IndirectGemmInt8_4x4:

@@ -36,18 +36,26 @@ IndirectGemmInt8_4x4:
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should also be preserved
 // whereas our coding style does not permit such an amount of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16

     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23

     mul x5, x4, x5
     mov x4, #1
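Note: the added add/mov/sdiv triple computes x23 = (x6 + 3) / 4, the number of 4-channel output groups (nnacl's UP_DIV(oc, C4NUM)); the per-channel branches below appear to use it as the post-index stride when walking input_sum. A C equivalent, with oc standing in for x6:

    size_t oc4 = (oc + 3) / 4;  /* add x24, x6, #3; mov x23, #4; sdiv x23, x24, x23 */
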
@@ -189,12 +197,6 @@ IndirectGemmInt8_4x4:
     sadalp v30.4s, v14.8h
     sadalp v31.4s, v15.8h

-    // load sum
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20]
     // pairwise add
     addp v16.4s, v16.4s, v17.4s
     addp v18.4s, v18.4s, v19.4s
@@ -212,28 +214,51 @@ IndirectGemmInt8_4x4:
     addp v20.4s, v20.4s, v22.4s
     addp v24.4s, v24.4s, v26.4s
     addp v28.4s, v28.4s, v30.4s
+    cbz x20, NoSum
+    // load sum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1r {v8.4s}, [x22], x23
+    ld1r {v9.4s}, [x22], x23
+    ld1r {v10.4s}, [x22], x23
+    ld1r {v11.4s}, [x22]
+    b AddSum
+SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22]
+AddSum:
+    sub v16.4s, v16.4s, v8.4s
+    sub v20.4s, v20.4s, v9.4s
+    sub v24.4s, v24.4s, v10.4s
+    sub v28.4s, v28.4s, v11.4s
+NoSum:
     add v16.4s, v16.4s, v12.4s
     add v20.4s, v20.4s, v12.4s
     add v24.4s, v24.4s, v12.4s
     add v28.4s, v28.4s, v12.4s

-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+QuantizeStart:
     sqshl v16.4s, v16.4s, v2.4s
     sqshl v20.4s, v20.4s, v2.4s
     sqshl v24.4s, v24.4s, v2.4s
     sqshl v28.4s, v28.4s, v2.4s

-    dup v3.4s, w17
     sqrdmulh v16.4s, v16.4s, v3.4s
     sqrdmulh v20.4s, v20.4s, v3.4s
     sqrdmulh v24.4s, v24.4s, v3.4s
     sqrdmulh v28.4s, v28.4s, v3.4s

-    dup v4.4s, w19
     and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
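Note: sqshl (left shift by shift_before), sqrdmulh (fixed-point multiply by out_multiplier), and the and/sshr/sqadd/srshl tail implement the usual gemmlowp-style requantization. A scalar sketch of one lane, assuming v2 holds the left shift, v3 the multiplier, and v4 the negative right shift; the function and parameter names are illustrative, and saturation on the left shift is elided:

    #include <stdint.h>

    /* One lane of the vector requantization above (sketch). */
    static int32_t RequantizeLane(int32_t acc, int left_shift, int32_t multiplier, int right_shift) {
      int32_t x = acc << left_shift;                          /* sqshl (saturating in the asm) */
      int64_t ab = (int64_t)x * multiplier;                   /* sqrdmulh: rounding doubling */
      int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));  /* high half of x * multiplier */
      int32_t high = (int32_t)((ab + nudge) >> 31);
      int32_t mask = ((int32_t)1 << right_shift) - 1;         /* and/sshr/sqadd/srshl: */
      int32_t remainder = high & mask;                        /* rounding divide by 2^right_shift */
      int32_t threshold = (mask >> 1) + (high < 0 ? 1 : 0);
      return (high >> right_shift) + (remainder > threshold ? 1 : 0);
    }
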
@@ -325,15 +350,25 @@ IndirectGemmInt8_4x4:
     bne LoopKsize

     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+NoChannelForward:
     cbz x3, NoStepFowrard
     add x3, x3, #16
NoStepFowrard:
     bgt LoopOc

-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif

@@ -8,8 +8,8 @@
 #endif

 // void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
-//                               size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
-//                               size_t shift_before, size_t shift_after);
+//                               size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
+//                               int32_t *shift_before, int32_t *shift_after);
 // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
 // we use the sdot instruction on cores that support dotprod (Armv8.2-A w/ dp or later)
 // the mrs instruction could read system register ID_AA64ISAR0_EL1 (or s3_0_c0_c6_0 on Armv8.2-A)
@@ -17,35 +17,64 @@
 IndirectGemmInt8_24x4_dp:

 .macro INIT_BIAS
-    mov x20, x15
-    ld1r {v8.4s}, [x20], #4
-    ld1r {v9.4s}, [x20], #4
-    ld1r {v10.4s}, [x20], #4
-    ld1r {v11.4s}, [x20], #4
-    ld1r {v12.4s}, [x20], #4
-    ld1r {v13.4s}, [x20], #4
-    ld1r {v14.4s}, [x20], #4
-    ld1r {v15.4s}, [x20], #4
-    ld1r {v16.4s}, [x20], #4
-    ld1r {v17.4s}, [x20], #4
-    ld1r {v18.4s}, [x20], #4
-    ld1r {v19.4s}, [x20], #4
-    ld1r {v20.4s}, [x20], #4
-    ld1r {v21.4s}, [x20], #4
-    ld1r {v22.4s}, [x20], #4
-    ld1r {v23.4s}, [x20], #4
-    ld1r {v24.4s}, [x20], #4
-    ld1r {v25.4s}, [x20], #4
-    ld1r {v26.4s}, [x20], #4
-    ld1r {v27.4s}, [x20], #4
-    ld1r {v28.4s}, [x20], #4
-    ld1r {v29.4s}, [x20], #4
-    ld1r {v30.4s}, [x20], #4
-    ld1r {v31.4s}, [x20], #4
     dup v7.4s, wzr
     cbz x3, InitBias
     ld1 {v7.4s}, [x3]
 InitBias:
+    cbz x20, NoSum
+    mov x22, x15
+    cbz x21, SymSum
+    ld1 {v8.4s}, [x22], x23
+    ld1 {v9.4s}, [x22], x23
+    ld1 {v10.4s}, [x22], x23
+    ld1 {v11.4s}, [x22], x23
+    ld1 {v12.4s}, [x22], x23
+    ld1 {v13.4s}, [x22], x23
+    ld1 {v14.4s}, [x22], x23
+    ld1 {v15.4s}, [x22], x23
+    ld1 {v16.4s}, [x22], x23
+    ld1 {v17.4s}, [x22], x23
+    ld1 {v18.4s}, [x22], x23
+    ld1 {v19.4s}, [x22], x23
+    ld1 {v20.4s}, [x22], x23
+    ld1 {v21.4s}, [x22], x23
+    ld1 {v22.4s}, [x22], x23
+    ld1 {v23.4s}, [x22], x23
+    ld1 {v24.4s}, [x22], x23
+    ld1 {v25.4s}, [x22], x23
+    ld1 {v26.4s}, [x22], x23
+    ld1 {v27.4s}, [x22], x23
+    ld1 {v28.4s}, [x22], x23
+    ld1 {v29.4s}, [x22], x23
+    ld1 {v30.4s}, [x22], x23
+    ld1 {v31.4s}, [x22], x23
+    b AddSum
+SymSum:
+    ld1r {v8.4s}, [x22], #4
+    ld1r {v9.4s}, [x22], #4
+    ld1r {v10.4s}, [x22], #4
+    ld1r {v11.4s}, [x22], #4
+    ld1r {v12.4s}, [x22], #4
+    ld1r {v13.4s}, [x22], #4
+    ld1r {v14.4s}, [x22], #4
+    ld1r {v15.4s}, [x22], #4
+    ld1r {v16.4s}, [x22], #4
+    ld1r {v17.4s}, [x22], #4
+    ld1r {v18.4s}, [x22], #4
+    ld1r {v19.4s}, [x22], #4
+    ld1r {v20.4s}, [x22], #4
+    ld1r {v21.4s}, [x22], #4
+    ld1r {v22.4s}, [x22], #4
+    ld1r {v23.4s}, [x22], #4
+    ld1r {v24.4s}, [x22], #4
+    ld1r {v25.4s}, [x22], #4
+    ld1r {v26.4s}, [x22], #4
+    ld1r {v27.4s}, [x22], #4
+    ld1r {v28.4s}, [x22], #4
+    ld1r {v29.4s}, [x22], #4
+    ld1r {v30.4s}, [x22], #4
+    ld1r {v31.4s}, [x22], #4
+AddSum:
     sub v8.4s, v7.4s, v8.4s
     sub v9.4s, v7.4s, v9.4s
     sub v10.4s, v7.4s, v10.4s
@@ -70,24 +99,59 @@ IndirectGemmInt8_24x4_dp:
     sub v29.4s, v7.4s, v29.4s
     sub v30.4s, v7.4s, v30.4s
     sub v31.4s, v7.4s, v31.4s
+    b InitBiasEnd
+NoSum:
+    mov v8.16b, v7.16b
+    mov v9.16b, v7.16b
+    mov v10.16b, v7.16b
+    mov v11.16b, v7.16b
+    mov v12.16b, v7.16b
+    mov v13.16b, v7.16b
+    mov v14.16b, v7.16b
+    mov v15.16b, v7.16b
+    mov v16.16b, v7.16b
+    mov v17.16b, v7.16b
+    mov v18.16b, v7.16b
+    mov v19.16b, v7.16b
+    mov v20.16b, v7.16b
+    mov v21.16b, v7.16b
+    mov v22.16b, v7.16b
+    mov v23.16b, v7.16b
+    mov v24.16b, v7.16b
+    mov v25.16b, v7.16b
+    mov v26.16b, v7.16b
+    mov v27.16b, v7.16b
+    mov v28.16b, v7.16b
+    mov v29.16b, v7.16b
+    mov v30.16b, v7.16b
+    mov v31.16b, v7.16b
+InitBiasEnd:
 .endm

 // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
 // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
 // r19 ~ r29 should also be preserved
 // whereas our coding style does not permit such an amount of parameters
-    sub sp, sp, #144
+    sub sp, sp, #176
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16

     ldr x15, [sp]
     ldr w8, [sp, #8]
     ldr w9, [sp, #16]
     ldr w16, [sp, #24]
-    ldr w17, [sp, #32]
-    ldr w18, [sp, #40]
-    ldr w19, [sp, #48]
+    ldr x17, [sp, #32]
+    ldr x18, [sp, #40]
+    ldr x19, [sp, #48]
+    ldr x20, [sp, #56]
+    ldr x21, [sp, #64]
+
+    add x24, x6, #3
+    mov x23, #4
+    sdiv x23, x24, x23

     mul x5, x4, x5
     mov x4, #1
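Note: INIT_BIAS now seeds every accumulator quad with bias minus the optional input-zero-point sum, instead of loading sums unconditionally. A scalar sketch of the intent; the array names and the sum_stride layout are assumptions for illustration:

    #include <stdint.h>

    /* Seed the 24x4 accumulator tile (sketch). */
    void InitBiasSketch(int32_t acc[24][4], const int32_t *bias, const int32_t *input_sum,
                        int sum_stride, int asymmetric, int per_channel) {
      for (int r = 0; r < 24; ++r) {
        for (int c = 0; c < 4; ++c) {
          int32_t b = (bias != 0) ? bias[c] : 0;  /* dup v7.4s, wzr / ld1 {v7.4s}, [x3] */
          int32_t s = 0;
          if (asymmetric) {                       /* cbz x20, NoSum */
            s = per_channel ? input_sum[r * sum_stride + c]  /* ld1, post-index x23 */
                            : input_sum[r];                  /* ld1r, post-index #4 */
          }
          acc[r][c] = b - s;                      /* sub vN.4s, v7.4s, vN.4s (or mov) */
        }
      }
    }
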
@@ -206,7 +270,7 @@ IndirectGemmInt8_24x4_dp:
     b LoopIc

 LoopIcEnd:
-    mov x20, x15
+    mov x22, x15
     // load input for output 1-8
     ld1 {v0.16b, v1.16b}, [x12], #32
     .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
@@ -276,7 +340,16 @@ IndirectGemmInt8_24x4_dp:
     .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]

 Quantization:
-    dup v2.4s, w18
+    cbnz x21, PerChannel
+    ld1r {v2.4s}, [x18]
+    ld1r {v3.4s}, [x17]
+    ld1r {v4.4s}, [x19]
+    b QuantizeStart
+PerChannel:
+    ld1 {v2.4s}, [x18]
+    ld1 {v3.4s}, [x17]
+    ld1 {v4.4s}, [x19]
+QuantizeStart:
     sqshl v8.4s, v8.4s, v2.4s
     sqshl v9.4s, v9.4s, v2.4s
     sqshl v10.4s, v10.4s, v2.4s
@@ -302,7 +375,6 @@ IndirectGemmInt8_24x4_dp:
     sqshl v30.4s, v30.4s, v2.4s
     sqshl v31.4s, v31.4s, v2.4s

-    dup v3.4s, w17
     sqrdmulh v8.4s, v8.4s, v3.4s
     sqrdmulh v9.4s, v9.4s, v3.4s
     sqrdmulh v10.4s, v10.4s, v3.4s
@@ -328,100 +400,99 @@ IndirectGemmInt8_24x4_dp:
     sqrdmulh v30.4s, v30.4s, v3.4s
     sqrdmulh v31.4s, v31.4s, v3.4s

-    dup v4.4s, w19
-    add v0.16b, v4.16b, v8.16b
+    and v0.16b, v4.16b, v8.16b
     sshr v0.4s, v0.4s, #31
     sqadd v8.4s, v8.4s, v0.4s
     srshl v8.4s, v8.4s, v4.4s
-    add v1.16b, v4.16b, v9.16b
+    and v1.16b, v4.16b, v9.16b
     sshr v1.4s, v1.4s, #31
     sqadd v9.4s, v9.4s, v1.4s
     srshl v9.4s, v9.4s, v4.4s
-    add v2.16b, v4.16b, v10.16b
+    and v2.16b, v4.16b, v10.16b
     sshr v2.4s, v2.4s, #31
     sqadd v10.4s, v10.4s, v2.4s
     srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v11.16b
+    and v3.16b, v4.16b, v11.16b
     sshr v3.4s, v3.4s, #31
     sqadd v11.4s, v11.4s, v3.4s
     srshl v11.4s, v11.4s, v4.4s
-    add v0.16b, v4.16b, v12.16b
+    and v0.16b, v4.16b, v12.16b
     sshr v0.4s, v0.4s, #31
     sqadd v12.4s, v12.4s, v0.4s
     srshl v12.4s, v12.4s, v4.4s
-    add v1.16b, v4.16b, v13.16b
+    and v1.16b, v4.16b, v13.16b
     sshr v1.4s, v1.4s, #31
     sqadd v13.4s, v13.4s, v1.4s
     srshl v13.4s, v13.4s, v4.4s
-    add v2.16b, v4.16b, v14.16b
+    and v2.16b, v4.16b, v14.16b
     sshr v2.4s, v2.4s, #31
     sqadd v14.4s, v14.4s, v2.4s
     srshl v14.4s, v14.4s, v4.4s
-    add v3.16b, v4.16b, v15.16b
+    and v3.16b, v4.16b, v15.16b
     sshr v3.4s, v3.4s, #31
     sqadd v15.4s, v15.4s, v3.4s
     srshl v15.4s, v15.4s, v4.4s
-    add v0.16b, v4.16b, v16.16b
+    and v0.16b, v4.16b, v16.16b
     sshr v0.4s, v0.4s, #31
     sqadd v16.4s, v16.4s, v0.4s
     srshl v16.4s, v16.4s, v4.4s
-    add v1.16b, v4.16b, v17.16b
+    and v1.16b, v4.16b, v17.16b
     sshr v1.4s, v1.4s, #31
     sqadd v17.4s, v17.4s, v1.4s
     srshl v17.4s, v17.4s, v4.4s
-    add v2.16b, v4.16b, v18.16b
+    and v2.16b, v4.16b, v18.16b
     sshr v2.4s, v2.4s, #31
     sqadd v18.4s, v18.4s, v2.4s
     srshl v18.4s, v18.4s, v4.4s
-    add v3.16b, v4.16b, v19.16b
+    and v3.16b, v4.16b, v19.16b
     sshr v3.4s, v3.4s, #31
     sqadd v19.4s, v19.4s, v3.4s
     srshl v19.4s, v19.4s, v4.4s
-    add v0.16b, v4.16b, v20.16b
+    and v0.16b, v4.16b, v20.16b
     sshr v0.4s, v0.4s, #31
     sqadd v20.4s, v20.4s, v0.4s
     srshl v20.4s, v20.4s, v4.4s
-    add v1.16b, v4.16b, v21.16b
+    and v1.16b, v4.16b, v21.16b
     sshr v1.4s, v1.4s, #31
     sqadd v21.4s, v21.4s, v1.4s
     srshl v21.4s, v21.4s, v4.4s
-    add v2.16b, v4.16b, v22.16b
+    and v2.16b, v4.16b, v22.16b
     sshr v2.4s, v2.4s, #31
     sqadd v22.4s, v22.4s, v2.4s
-    srshl v10.4s, v10.4s, v4.4s
-    add v3.16b, v4.16b, v23.16b
+    srshl v22.4s, v22.4s, v4.4s
+    and v3.16b, v4.16b, v23.16b
     sshr v3.4s, v3.4s, #31
     sqadd v23.4s, v23.4s, v3.4s
     srshl v23.4s, v23.4s, v4.4s
-    add v0.16b, v4.16b, v24.16b
+    and v0.16b, v4.16b, v24.16b
     sshr v0.4s, v0.4s, #31
     sqadd v24.4s, v24.4s, v0.4s
     srshl v24.4s, v24.4s, v4.4s
-    add v1.16b, v4.16b, v25.16b
+    and v1.16b, v4.16b, v25.16b
     sshr v1.4s, v1.4s, #31
     sqadd v25.4s, v25.4s, v1.4s
     srshl v25.4s, v25.4s, v4.4s
-    add v2.16b, v4.16b, v26.16b
+    and v2.16b, v4.16b, v26.16b
     sshr v2.4s, v2.4s, #31
     sqadd v26.4s, v26.4s, v2.4s
     srshl v26.4s, v26.4s, v4.4s
-    add v3.16b, v4.16b, v27.16b
+    and v3.16b, v4.16b, v27.16b
     sshr v3.4s, v3.4s, #31
     sqadd v27.4s, v27.4s, v3.4s
     srshl v27.4s, v27.4s, v4.4s
-    add v0.16b, v4.16b, v28.16b
+    and v0.16b, v4.16b, v28.16b
     sshr v0.4s, v0.4s, #31
     sqadd v28.4s, v28.4s, v0.4s
     srshl v28.4s, v28.4s, v4.4s
-    add v1.16b, v4.16b, v29.16b
+    and v1.16b, v4.16b, v29.16b
     sshr v1.4s, v1.4s, #31
     sqadd v29.4s, v29.4s, v1.4s
     srshl v29.4s, v29.4s, v4.4s
-    add v2.16b, v4.16b, v30.16b
+    and v2.16b, v4.16b, v30.16b
     sshr v2.4s, v2.4s, #31
     sqadd v30.4s, v30.4s, v2.4s
     srshl v30.4s, v30.4s, v4.4s
-    add v3.16b, v4.16b, v31.16b
+    and v3.16b, v4.16b, v31.16b
     sshr v3.4s, v3.4s, #31
     sqadd v31.4s, v31.4s, v3.4s
     srshl v31.4s, v31.4s, v4.4s
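Note: besides the per-channel plumbing, the add-to-and changes in this hunk fix the rounding correction: gemmlowp's NEON RoundingDivideByPOT masks the value with the shift vector (which holds the negative right-shift amount), it does not add it. For reference, the and/sshr/sqadd/srshl pattern corresponds to this intrinsics form:

    #include <arm_neon.h>

    /* Rounding divide by 2^exponent (gemmlowp-style). */
    static int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
      const int32x4_t shift_vec = vdupq_n_s32(-exponent);                /* v4 in the asm */
      const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);  /* and + sshr #31 */
      const int32x4_t fixed = vqaddq_s32(x, fixup);                      /* sqadd */
      return vrshlq_s32(fixed, shift_vec);                               /* srshl */
    }
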
@@ -694,15 +765,24 @@ IndirectGemmInt8_24x4_dp:
     bne LoopKsize

     subs x6, x6, #4
+    cbz x21, NoChannelForward
+    cbz x20, NoSumForward
+    add x15, x15, #16
+NoSumForward:
+    add x17, x17, #16
+    add x18, x18, #16
+    add x19, x19, #16
+NoChannelForward:
     cbz x3, NoStepFowrard
     add x3, x3, #16
NoStepFowrard:
     bgt LoopOc

-    sub sp, sp, #144
+    sub sp, sp, #176
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
     ret
 #endif

@@ -16,7 +16,7 @@

 #include "nnacl/fp32/common_func.h"

-#ifndef __aarch64__
+#ifndef ENABLE_ARM64
 void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
                size_t row, size_t col) {
   for (int r = 0; r < row; r++) {
@@ -40,8 +40,8 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
                                size_t oc4, size_t offset);
 void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
                           size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);
 void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                         size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
@@ -29,14 +29,12 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
   int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
   int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
   int oc4 = UP_DIV(output_channel, C4NUM);
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
+  size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+  size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
   IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
                        output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-                       shift_before, shift_after);
-// #elif defined(ENABLE_ARM32)
-//   IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
-//                        output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
-//                        shift_before, shift_after);
+                       shift_before, shift_after, asymmetric, per_channel);
 #else
   int tile_num = conv_param->tile_num_;
   int plane_c4 = UP_DIV(kernel_plane, C4NUM);
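Note: asymmetric_ and per_channel_ are bit masks, so the two new trailing arguments arrive in the assembly as plain zero/non-zero flags for the cbz/cbnz tests. The contract change on the quantization parameters can be sketched as below; the helper and its names are illustrative, not part of the patch:

    #include <stddef.h>
    #include <stdint.h>

    /* out_multiplier/shift_before/shift_after are now pointers: one broadcast
     * value when per_channel == 0, one entry per output channel otherwise. */
    static int32_t PickMultiplier(const int32_t *out_multiplier, size_t per_channel, int oc) {
      return per_channel ? out_multiplier[oc]  /* asm: ld1 {v3.4s}, [x17], four channels at a time */
                         : out_multiplier[0];  /* asm: ld1r {v3.4s}, [x17], broadcast */
    }
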
@@ -124,8 +122,10 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
   int oc4 = UP_DIV(output_channel, C4NUM);
   if (gemm_func != NULL) {
 #ifdef __aarch64__
+    size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
+    size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
     gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum,
-              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
+              act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 #endif
   } else {
     int tile_num = conv_param->tile_num_;
@@ -28,8 +28,8 @@

 typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
                           size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
-                          size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
-                          size_t shift_after);
+                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
+                          int32_t *shift_after, size_t asymmetric, size_t per_channel);

 #ifdef __cplusplus
 extern "C" {
@@ -22,11 +22,11 @@ extern "C" {
 extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                      size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                      const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                     size_t out_multiplier, size_t shift_before, size_t shift_after);
+                                     int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                     size_t asymmetric, size_t per_channel);

 extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias);

 #ifdef __cplusplus
 }
 #endif
@@ -35,9 +35,10 @@ extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, in
 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                        size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                        const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                       size_t out_multiplier, size_t shift_before, size_t shift_after) {
+                                       int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
+                                       size_t asymmetric, size_t per_channel) {
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
-                                  act_max, out_zp, out_multiplier, shift_before, shift_after);
+                                  act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 }

 void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
@@ -879,8 +879,8 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
       const float *src_ptr = src_batch + hw * channel + c;
       float *dst_ptr = dst_batch + c * plane + hw;
 #ifdef ENABLE_ARM64
-      int srcStride = channel * 4;
-      int dstStride = plane * 4;
+      size_t srcStride = channel * sizeof(float);
+      size_t dstStride = plane * sizeof(float);
       asm volatile(
         "mov x10, %[src_ptr]\n"
         "mov x11, %[dst_ptr]\n"