!4861 [MS][LITE][Develop]add conv per channel support for int8

Merge pull request !4861 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-22 14:49:38 +08:00 committed by Gitee
commit 90552c4933
8 changed files with 212 additions and 96 deletions

View File

@ -8,8 +8,8 @@
#endif
// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
// size_t shift_before, size_t shift_after);
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
// int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
IndirectGemmInt8_4x4:
@ -36,18 +36,26 @@ IndirectGemmInt8_4x4:
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// r19 ~ r29 should also be preserved
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #144
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr w17, [sp, #32]
ldr w18, [sp, #40]
ldr w19, [sp, #48]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
add x24, x6, #3
mov x23, #4
sdiv x23, x24, x23
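// presumed mapping of the stack arguments loaded above, following the C prototype:
// x15 = input_sum, w8 = act_min, w9 = act_max, w16 = out_zp,
// x17 = out_multiplier, x18 = shift_before, x19 = shift_after,
// x20 = asymmetric flag, x21 = per_channel flag
// x23 = UP_DIV(oc, 4), used below as the step when walking per-channel input_sum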
mul x5, x4, x5
mov x4, #1
@ -189,12 +197,6 @@ IndirectGemmInt8_4x4:
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
// load sum
mov x20, x15
ld1r {v8.4s}, [x20], #4
ld1r {v9.4s}, [x20], #4
ld1r {v10.4s}, [x20], #4
ld1r {v11.4s}, [x20]
// pairwise add
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
@ -212,28 +214,51 @@ IndirectGemmInt8_4x4:
addp v20.4s, v20.4s, v22.4s
addp v24.4s, v24.4s, v26.4s
addp v28.4s, v28.4s, v30.4s
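// asymmetric mode (x20 != 0) subtracts the input_sum correction from the
// accumulators; per-channel mode (x21 != 0) steps through input_sum with
// stride x23 instead of the 4-byte scalar stride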
cbz x20, NoSum
// load sum
mov x22, x15
cbz x21, SymSum
ld1r {v8.4s}, [x22], x23
ld1r {v9.4s}, [x22], x23
ld1r {v10.4s}, [x22], x23
ld1r {v11.4s}, [x22]
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22]
AddSum:
sub v16.4s, v16.4s, v8.4s
sub v20.4s, v20.4s, v9.4s
sub v24.4s, v24.4s, v10.4s
sub v28.4s, v28.4s, v11.4s
NoSum:
add v16.4s, v16.4s, v12.4s
add v20.4s, v20.4s, v12.4s
add v24.4s, v24.4s, v12.4s
add v28.4s, v28.4s, v12.4s
dup v2.4s, w18
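// per-tensor mode broadcasts one shift/multiplier to all four lanes (ld1r);
// per-channel mode loads four consecutive per-output-channel values (ld1)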
cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v16.4s, v16.4s, v2.4s
sqshl v20.4s, v20.4s, v2.4s
sqshl v24.4s, v24.4s, v2.4s
sqshl v28.4s, v28.4s, v2.4s
dup v3.4s, w17
sqrdmulh v16.4s, v16.4s, v3.4s
sqrdmulh v20.4s, v20.4s, v3.4s
sqrdmulh v24.4s, v24.4s, v3.4s
sqrdmulh v28.4s, v28.4s, v3.4s
dup v4.4s, w19
and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
@ -325,15 +350,25 @@ IndirectGemmInt8_4x4:
bne LoopKsize
subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepForward
add x3, x3, #16
NoStepForward:
bgt LoopOc
sub sp, sp, #144
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif
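
The quantization tail in both kernels (sqshl by shift_before, sqrdmulh by out_multiplier, then the and/sshr/sqadd/srshl rounding right shift by shift_after, add out_zp, clamp to act_min/act_max) is the usual fixed-point requantization. A minimal C sketch, assuming gemmlowp-style semantics; the helper names are illustrative, and shift_after is taken as a positive right-shift count here, whereas the assembly encodes it as a negative srshl amount:

#include <stdint.h>

/* saturating rounding doubling high multiply: the C analogue of sqrdmulh */
static int32_t SqRdMulH(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;
  int64_t ab = (int64_t)a * (int64_t)b;
  int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((ab + nudge) / (1LL << 31));
}

/* rounding arithmetic right shift: what the and/sshr/sqadd/srshl sequence computes */
static int32_t RoundingRightShift(int32_t x, int exponent) {
  int32_t mask = (int32_t)((1LL << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

static int8_t Requantize(int32_t acc, int32_t left_shift, int32_t multiplier,
                         int32_t right_shift, int32_t out_zp,
                         int32_t act_min, int32_t act_max) {
  acc = SqRdMulH(acc * (1 << left_shift), multiplier); /* sqshl also saturates; plain multiply shown for brevity */
  acc = RoundingRightShift(acc, right_shift);
  acc += out_zp;
  if (acc < act_min) acc = act_min;
  if (acc > act_max) acc = act_max;
  return (int8_t)acc;
}

/* per_channel picks one parameter set per output channel c; otherwise index 0: */
/* out[c] = Requantize(acc[c], shift_before[per_channel ? c : 0], ...)          */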

View File

@ -8,8 +8,8 @@
#endif
// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
// size_t shift_before, size_t shift_after);
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
// we use the sdot instruction on cores that support dotprod (Armv8.2-A with dotprod or later)
// the mrs instruction can read system register ID_AA64ISAR0_EL1 (or s3_0_c0_c6_0 on Armv8.2-A)
@ -17,35 +17,64 @@
IndirectGemmInt8_24x4_dp:
.macro INIT_BIAS
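// initializes the 24 accumulators (v8 ~ v31) to bias (v7), subtracting the
// asymmetric input_sum correction when x20 is set; x21 selects the per-channel
// (strided) vs per-tensor (scalar) input_sum loads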
mov x20, x15
ld1r {v8.4s}, [x20], #4
ld1r {v9.4s}, [x20], #4
ld1r {v10.4s}, [x20], #4
ld1r {v11.4s}, [x20], #4
ld1r {v12.4s}, [x20], #4
ld1r {v13.4s}, [x20], #4
ld1r {v14.4s}, [x20], #4
ld1r {v15.4s}, [x20], #4
ld1r {v16.4s}, [x20], #4
ld1r {v17.4s}, [x20], #4
ld1r {v18.4s}, [x20], #4
ld1r {v19.4s}, [x20], #4
ld1r {v20.4s}, [x20], #4
ld1r {v21.4s}, [x20], #4
ld1r {v22.4s}, [x20], #4
ld1r {v23.4s}, [x20], #4
ld1r {v24.4s}, [x20], #4
ld1r {v25.4s}, [x20], #4
ld1r {v26.4s}, [x20], #4
ld1r {v27.4s}, [x20], #4
ld1r {v28.4s}, [x20], #4
ld1r {v29.4s}, [x20], #4
ld1r {v30.4s}, [x20], #4
ld1r {v31.4s}, [x20], #4
dup v7.4s, wzr
cbz x3, InitBias
ld1 {v7.4s}, [x3]
InitBias:
cbz x20, NoSum
mov x22, x15
cbz x21, SymSum
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22], x23
ld1 {v12.4s}, [x22], x23
ld1 {v13.4s}, [x22], x23
ld1 {v14.4s}, [x22], x23
ld1 {v15.4s}, [x22], x23
ld1 {v16.4s}, [x22], x23
ld1 {v17.4s}, [x22], x23
ld1 {v18.4s}, [x22], x23
ld1 {v19.4s}, [x22], x23
ld1 {v20.4s}, [x22], x23
ld1 {v21.4s}, [x22], x23
ld1 {v22.4s}, [x22], x23
ld1 {v23.4s}, [x22], x23
ld1 {v24.4s}, [x22], x23
ld1 {v25.4s}, [x22], x23
ld1 {v26.4s}, [x22], x23
ld1 {v27.4s}, [x22], x23
ld1 {v28.4s}, [x22], x23
ld1 {v29.4s}, [x22], x23
ld1 {v30.4s}, [x22], x23
ld1 {v31.4s}, [x22], x23
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22], #4
ld1r {v12.4s}, [x22], #4
ld1r {v13.4s}, [x22], #4
ld1r {v14.4s}, [x22], #4
ld1r {v15.4s}, [x22], #4
ld1r {v16.4s}, [x22], #4
ld1r {v17.4s}, [x22], #4
ld1r {v18.4s}, [x22], #4
ld1r {v19.4s}, [x22], #4
ld1r {v20.4s}, [x22], #4
ld1r {v21.4s}, [x22], #4
ld1r {v22.4s}, [x22], #4
ld1r {v23.4s}, [x22], #4
ld1r {v24.4s}, [x22], #4
ld1r {v25.4s}, [x22], #4
ld1r {v26.4s}, [x22], #4
ld1r {v27.4s}, [x22], #4
ld1r {v28.4s}, [x22], #4
ld1r {v29.4s}, [x22], #4
ld1r {v30.4s}, [x22], #4
ld1r {v31.4s}, [x22], #4
AddSum:
sub v8.4s, v7.4s, v8.4s
sub v9.4s, v7.4s, v9.4s
sub v10.4s, v7.4s, v10.4s
@ -70,24 +99,59 @@ IndirectGemmInt8_24x4_dp:
sub v29.4s, v7.4s, v29.4s
sub v30.4s, v7.4s, v30.4s
sub v31.4s, v7.4s, v31.4s
b InitBiasEnd
NoSum:
mov v8.16b, v7.16b
mov v9.16b, v7.16b
mov v10.16b, v7.16b
mov v11.16b, v7.16b
mov v12.16b, v7.16b
mov v13.16b, v7.16b
mov v14.16b, v7.16b
mov v15.16b, v7.16b
mov v16.16b, v7.16b
mov v17.16b, v7.16b
mov v18.16b, v7.16b
mov v19.16b, v7.16b
mov v20.16b, v7.16b
mov v21.16b, v7.16b
mov v22.16b, v7.16b
mov v23.16b, v7.16b
mov v24.16b, v7.16b
mov v25.16b, v7.16b
mov v26.16b, v7.16b
mov v27.16b, v7.16b
mov v28.16b, v7.16b
mov v29.16b, v7.16b
mov v30.16b, v7.16b
mov v31.16b, v7.16b
InitBiasEnd:
.endm
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// r19 ~ r29 should also be preserved
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #144
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr w17, [sp, #32]
ldr w18, [sp, #40]
ldr w19, [sp, #48]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
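// same presumed stack-argument mapping as in IndirectGemmInt8_4x4 above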
add x24, x6, #3
mov x23, #4
sdiv x23, x24, x23
mul x5, x4, x5
mov x4, #1
@ -206,7 +270,7 @@ IndirectGemmInt8_24x4_dp:
b LoopIc
LoopIcEnd:
mov x20, x15
mov x22, x15
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
@ -276,7 +340,16 @@ IndirectGemmInt8_24x4_dp:
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
Quantization:
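// parameter selection mirrors the 4x4 kernel: ld1r broadcast for per-tensor,
// ld1 per-output-channel loads when x21 is set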
dup v2.4s, w18
cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v8.4s, v8.4s, v2.4s
sqshl v9.4s, v9.4s, v2.4s
sqshl v10.4s, v10.4s, v2.4s
@ -302,7 +375,6 @@ IndirectGemmInt8_24x4_dp:
sqshl v30.4s, v30.4s, v2.4s
sqshl v31.4s, v31.4s, v2.4s
dup v3.4s, w17
sqrdmulh v8.4s, v8.4s, v3.4s
sqrdmulh v9.4s, v9.4s, v3.4s
sqrdmulh v10.4s, v10.4s, v3.4s
@ -328,100 +400,99 @@ IndirectGemmInt8_24x4_dp:
sqrdmulh v30.4s, v30.4s, v3.4s
sqrdmulh v31.4s, v31.4s, v3.4s
dup v4.4s, w19
add v0.16b, v4.16b, v8.16b
and v0.16b, v4.16b, v8.16b
sshr v0.4s, v0.4s, #31
sqadd v8.4s, v8.4s, v0.4s
srshl v8.4s, v8.4s, v4.4s
add v1.16b, v4.16b, v9.16b
and v1.16b, v4.16b, v9.16b
sshr v1.4s, v1.4s, #31
sqadd v9.4s, v9.4s, v1.4s
srshl v9.4s, v9.4s, v4.4s
add v2.16b, v4.16b, v10.16b
and v2.16b, v4.16b, v10.16b
sshr v2.4s, v2.4s, #31
sqadd v10.4s, v10.4s, v2.4s
srshl v10.4s, v10.4s, v4.4s
add v3.16b, v4.16b, v11.16b
and v3.16b, v4.16b, v11.16b
sshr v3.4s, v3.4s, #31
sqadd v11.4s, v11.4s, v3.4s
srshl v11.4s, v11.4s, v4.4s
add v0.16b, v4.16b, v12.16b
and v0.16b, v4.16b, v12.16b
sshr v0.4s, v0.4s, #31
sqadd v12.4s, v12.4s, v0.4s
srshl v12.4s, v12.4s, v4.4s
add v1.16b, v4.16b, v13.16b
and v1.16b, v4.16b, v13.16b
sshr v1.4s, v1.4s, #31
sqadd v13.4s, v13.4s, v1.4s
srshl v13.4s, v13.4s, v4.4s
add v2.16b, v4.16b, v14.16b
and v2.16b, v4.16b, v14.16b
sshr v2.4s, v2.4s, #31
sqadd v14.4s, v14.4s, v2.4s
srshl v14.4s, v14.4s, v4.4s
add v3.16b, v4.16b, v15.16b
and v3.16b, v4.16b, v15.16b
sshr v3.4s, v3.4s, #31
sqadd v15.4s, v15.4s, v3.4s
srshl v15.4s, v15.4s, v4.4s
add v0.16b, v4.16b, v16.16b
and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v4.4s
add v1.16b, v4.16b, v17.16b
and v1.16b, v4.16b, v17.16b
sshr v1.4s, v1.4s, #31
sqadd v17.4s, v17.4s, v1.4s
srshl v17.4s, v17.4s, v4.4s
add v2.16b, v4.16b, v18.16b
and v2.16b, v4.16b, v18.16b
sshr v2.4s, v2.4s, #31
sqadd v18.4s, v18.4s, v2.4s
srshl v18.4s, v18.4s, v4.4s
add v3.16b, v4.16b, v19.16b
and v3.16b, v4.16b, v19.16b
sshr v3.4s, v3.4s, #31
sqadd v19.4s, v19.4s, v3.4s
srshl v19.4s, v19.4s, v4.4s
add v0.16b, v4.16b, v20.16b
and v0.16b, v4.16b, v20.16b
sshr v0.4s, v0.4s, #31
sqadd v20.4s, v20.4s, v0.4s
srshl v20.4s, v20.4s, v4.4s
add v1.16b, v4.16b, v21.16b
and v1.16b, v4.16b, v21.16b
sshr v1.4s, v1.4s, #31
sqadd v21.4s, v21.4s, v1.4s
srshl v21.4s, v21.4s, v4.4s
add v2.16b, v4.16b, v22.16b
and v2.16b, v4.16b, v22.16b
sshr v2.4s, v2.4s, #31
sqadd v22.4s, v22.4s, v2.4s
srshl v10.4s, v10.4s, v4.4s
add v3.16b, v4.16b, v23.16b
srshl v22.4s, v22.4s, v4.4s
and v3.16b, v4.16b, v23.16b
sshr v3.4s, v3.4s, #31
sqadd v23.4s, v23.4s, v3.4s
srshl v23.4s, v23.4s, v4.4s
add v0.16b, v4.16b, v24.16b
and v0.16b, v4.16b, v24.16b
sshr v0.4s, v0.4s, #31
sqadd v24.4s, v24.4s, v0.4s
srshl v24.4s, v24.4s, v4.4s
add v1.16b, v4.16b, v25.16b
and v1.16b, v4.16b, v25.16b
sshr v1.4s, v1.4s, #31
sqadd v25.4s, v25.4s, v1.4s
srshl v25.4s, v25.4s, v4.4s
add v2.16b, v4.16b, v26.16b
and v2.16b, v4.16b, v26.16b
sshr v2.4s, v2.4s, #31
sqadd v26.4s, v26.4s, v2.4s
srshl v26.4s, v26.4s, v4.4s
add v3.16b, v4.16b, v27.16b
and v3.16b, v4.16b, v27.16b
sshr v3.4s, v3.4s, #31
sqadd v27.4s, v27.4s, v3.4s
srshl v27.4s, v27.4s, v4.4s
add v0.16b, v4.16b, v28.16b
and v0.16b, v4.16b, v28.16b
sshr v0.4s, v0.4s, #31
sqadd v28.4s, v28.4s, v0.4s
srshl v28.4s, v28.4s, v4.4s
add v1.16b, v4.16b, v29.16b
and v1.16b, v4.16b, v29.16b
sshr v1.4s, v1.4s, #31
sqadd v29.4s, v29.4s, v1.4s
srshl v29.4s, v29.4s, v4.4s
add v2.16b, v4.16b, v30.16b
and v2.16b, v4.16b, v30.16b
sshr v2.4s, v2.4s, #31
sqadd v30.4s, v30.4s, v2.4s
srshl v30.4s, v30.4s, v4.4s
add v3.16b, v4.16b, v31.16b
and v3.16b, v4.16b, v31.16b
sshr v3.4s, v3.4s, #31
sqadd v31.4s, v31.4s, v3.4s
srshl v31.4s, v31.4s, v4.4s
@ -694,15 +765,24 @@ IndirectGemmInt8_24x4_dp:
bne LoopKsize
subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepForward
add x3, x3, #16
NoStepForward:
bgt LoopOc
sub sp, sp, #144
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif

View File

@ -16,7 +16,7 @@
#include "nnacl/fp32/common_func.h"
#ifndef __aarch64__
#ifndef ENABLE_ARM64
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
size_t row, size_t col) {
for (int r = 0; r < row; r++) {

View File

@ -40,8 +40,8 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
size_t oc4, size_t offset);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
size_t shift_after);
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel);
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);

View File

@ -29,14 +29,12 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
int oc4 = UP_DIV(output_channel, C4NUM);
#ifdef __aarch64__
#ifdef ENABLE_ARM64
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
shift_before, shift_after);
// #elif defined(ENABLE_ARM32)
// IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
// output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
// shift_before, shift_after);
shift_before, shift_after, asymmetric, per_channel);
#else
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@ -124,8 +122,10 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
int oc4 = UP_DIV(output_channel, C4NUM);
if (gemm_func != NULL) {
#ifdef __aarch64__
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum,
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
#endif
} else {
int tile_num = conv_param->tile_num_;
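With per_channel set, out_multiplier, shift_before, and shift_after are read as per-output-channel arrays instead of broadcast scalars. A hedged sketch of the presumed layout, inferred from the assembly advancing these pointers by 16 bytes per block of four output channels (the diff does not state it explicitly):

/* presumed: one entry per output channel, padded up to a multiple of C4NUM */
int32_t out_multiplier[UP_DIV(output_channel, C4NUM) * C4NUM];
int32_t shift_before[UP_DIV(output_channel, C4NUM) * C4NUM]; /* left shifts (sqshl) */
int32_t shift_after[UP_DIV(output_channel, C4NUM) * C4NUM];  /* right shifts (srshl) */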

View File

@ -28,8 +28,8 @@
typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
size_t shift_after);
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel);
#ifdef __cplusplus
extern "C" {

View File

@ -22,11 +22,11 @@ extern "C" {
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
size_t out_multiplier, size_t shift_before, size_t shift_after);
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel);
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
#ifdef __cplusplus
}
#endif
@ -35,9 +35,10 @@ extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, in
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
size_t out_multiplier, size_t shift_before, size_t shift_after) {
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel) {
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
act_max, out_zp, out_multiplier, shift_before, shift_after);
act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
}
void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,

View File

@ -879,8 +879,8 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
const float *src_ptr = src_batch + hw * channel + c;
float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
int srcStride = channel * 4;
int dstStride = plane * 4;
size_t srcStride = channel * sizeof(float);
size_t dstStride = plane * sizeof(float);
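// strides are byte offsets consumed by 64-bit registers in the asm below;
// size_t and sizeof(float) make that explicit (previously int and a bare 4)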
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"