!8813 [MS][LITE][Develop]optimization for quantized convolution per oc

From: @lx0095
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-11-24 11:28:39 +08:00 committed by Gitee
commit a86c0da849
12 changed files with 797 additions and 1297 deletions

View File

@ -0,0 +1,299 @@
#ifdef __arm__
#ifndef __aarch64__
.text
.align 5
.global MatmulInt8Neon32Opt
#ifndef __APPLE__
.type MatmulInt8Neon32Opt, %function
#endif
//void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
// int *filter_zp);
// #-52: a, #-48: b, #-44: dst, #-40: row
// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp
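// Per 4-row x 2-column output tile, the code below computes (mirroring the C
// fallback MatmulInt8Opt added in this commit):
//   acc  = sum over deep16 of a[r][d] * b[d][c]
//   acc -= per_channel ? input_sums[r] * filter_zp[c] : input_sums[r]
//   acc += weight_bias[c]
//   acc  = requantize(acc): left shift, fixed-point multiply, rounding right shift
//   dst  = (int8)clamp(acc + out_zp, act_min, act_max)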
MatmulInt8Neon32Opt:
push {r0-r11, lr}
vpush {q4-q7}
add sp, sp, #116
ldr r4, [sp] // col
ldr r7, [sp, #40] // output stride
mov r8, #0 // output channels offset
ldr r10, [sp, #44]
cmp r10, #0
beq L1
ldr r6, [sp, #8] // load input_sums ptr if per_channel
L1:
cmp r4, #0 // if at the end of col
ble End1
ldr r0, [sp, #-52] // reload a ptr
ldr r3, [sp, #-40] // reset row counter
ldr r6, [sp, #8] // reload input_sums ptr if per_tensor
L2:
cmp r3, #0 // if at the end of row
ble End2
ldr r1, [sp, #-48] // reload b ptr
ldr r5, [sp, #4] // reset deep16
vmov.i32 q6, #0
vmov.i32 q7, #0
vmov.i32 q8, #0
vmov.i32 q9, #0
vmov.i32 q10, #0
vmov.i32 q11, #0
vmov.i32 q12, #0
vmov.i32 q13, #0
L3:
cmp r5, #0
beq End3
vld1.8 {d0, d1, d2, d3}, [r0]!
vld1.8 {d8, d9, d10, d11}, [r1]!
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3
vld1.8 {d0, d1, d2, d3}, [r0]!
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11
vpadal.s16 q10, q14
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
sub r5, r5, #16 // deep16 -= 16
b L3
End3:
vpadd.i32 d0, d12, d13
vpadd.i32 d1, d14, d15
vpadd.i32 d2, d16, d17
vpadd.i32 d3, d18, d19
vpadd.i32 d4, d20, d21
vpadd.i32 d5, d22, d23
vpadd.i32 d6, d24, d25
vpadd.i32 d7, d26, d27
vpadd.i32 d28, d0, d1
vpadd.i32 d29, d2, d3
vpadd.i32 d30, d4, d5
vpadd.i32 d31, d6, d7
// Add weight_bias
ldr r9, [sp, #12] // reload weight_bias ptr
add r9, r9, r8
vld1.32 {d26}, [r9]!
vadd.i32 d28, d28, d26
vadd.i32 d29, d29, d26
vadd.i32 d30, d30, d26
vadd.i32 d31, d31, d26
ldr r10, [sp, #44]
cmp r10, #0
bgt PerChannel
PerTensor:
// Subtract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
vsub.s32 d28, d28, d20
vsub.s32 d29, d29, d21
vsub.s32 d30, d30, d22
vsub.s32 d31, d31, d23
// Apply left shift
ldr r10, [sp, #32]
ldr r11, [r10]!
vdup.32 q9, r11
vshl.s32 q14, q14, q9
vshl.s32 q15, q15, q9
// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
ldr r11, [r10]
vdup.32 q8, r11
vqrdmulh.s32 q14, q14, q8
vqrdmulh.s32 q15, q15, q8
// Apply right shift
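// (the vand/vshr/vqadd sequence below applies the sign-dependent fixup so that
//  vrshl with a negative shift matches gemmlowp-style rounding division by a power of two)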
ldr r10, [sp, #36]
ldr r11, [r10]
vdup.32 q7, r11
vand q6, q7, q14
vshr.s32 q6, q6, #31
vqadd.s32 q14, q14, q6
vrshl.s32 q14, q14, q7
vand q5, q7, q15
vshr.s32 q5, q5, #31
vqadd.s32 q15, q15, q5
vrshl.s32 q15, q15, q7
b AddDstZP
PerChannel:
// Subtract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
ldr r10, [sp, #48]
vld1.32 {d19}, [r10]
vmul.s32 d24, d20, d19
vmul.s32 d25, d21, d19
vmul.s32 d26, d22, d19
vmul.s32 d27, d23, d19
vsub.s32 d28, d28, d24
vsub.s32 d29, d29, d25
vsub.s32 d30, d30, d26
vsub.s32 d31, d31, d27
// Apply left shift
ldr r10, [sp, #32]
add r10, r10, r8
vld1.32 {d23}, [r10]
vshl.s32 d28, d28, d23
vshl.s32 d29, d29, d23
vshl.s32 d30, d30, d23
vshl.s32 d31, d31, d23
// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
add r10, r10, r8
vld1.32 {d22}, [r10]
vqrdmulh.s32 d28, d28, d22
vqrdmulh.s32 d29, d29, d22
vqrdmulh.s32 d30, d30, d22
vqrdmulh.s32 d31, d31, d22
// Apply right shift
ldr r10, [sp, #36]
add r10, r10, r8
vld1.32 {d21}, [r10]
vand d20, d21, d28
vshr.s32 d20, d20, #31
vqadd.s32 d28, d28, d20
vrshl.s32 d28, d28, d21
vand d19, d21, d29
vshr.s32 d19, d19, #31
vqadd.s32 d29, d29, d19
vrshl.s32 d29, d29, d21
vand d18, d21, d30
vshr.s32 d18, d18, #31
vqadd.s32 d30, d30, d18
vrshl.s32 d30, d30, d21
vand d17, d21, d31
vshr.s32 d17, d17, #31
vqadd.s32 d31, d31, d17
vrshl.s32 d31, d31, d21
AddDstZP:
// Add the destination zero point
ldr r10, [sp, #24]
vdup.32 q4, r10
vadd.i32 q14, q14, q4
vadd.i32 q15, q15, q4
// Apply the act_min bound
ldr r10, [sp, #16]
vdup.32 q3, r10
vmax.s32 q14, q14, q3
vmax.s32 q15, q15, q3
// Apply the act_max bound
ldr r10, [sp, #20]
vdup.32 q2, r10
vmin.s32 q14, q14, q2
vmin.s32 q15, q15, q2
// Cast-and-saturate from int32 to int16
vqmovn.s32 d28, q14
vqmovn.s32 d29, q15
// Cast-and-saturate from int16 to int8
vqmovn.s16 d30, q14
// start to write
cmp r4, #2
bge WriteCol2
cmp r4, #1
beq WriteCol1
b EndWrite
WriteCol2:
vst1.16 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.16 {d30[1]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.16 {d30[2]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.16 {d30[3]}, [r2], r7
b EndWrite
WriteCol1:
vst1.8 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.8 {d30[2]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.8 {d30[4]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.8 {d30[6]}, [r2], r7
b EndWrite
EndWrite:
sub r3, r3, #4 // a row counter -= 4
b L2
End2:
sub r4, r4, #2 // b col counter -= 2
ldr r1, [sp, #-48] // load b ptr
ldr r9, [sp, #4]
mov r10, #2
mul r9, r9, r10 // the stride of b
add r1, r1, r9 // b ptr + stride
str r1, [sp, #-48]
ldr r2, [sp, #-44] // load dst ptr
add r2, r2, #2 // dst ptr + offset
str r2, [sp, #-44]
ldr r10, [sp, #48]
add r10, r10, #8
str r10, [sp, #48]
add r8, r8, #8 // output channels offset + 2*sizeof(int)
b L1
End1:
sub sp, sp, #116
vpop {q4-q7}
pop {r0-r11, pc}
#endif
#endif

View File

@ -1,371 +0,0 @@
#ifdef __aarch64__
.text
.align 5
.global IndirectGemmInt8_4x4
#ifndef __APPLE__
.type IndirectGemmInt8_4x4, %function
#endif
// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
IndirectGemmInt8_4x4:
.macro INIT_BIAS
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
.endm
// registers v8 ~ v15 must be preserved by the callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved
// whereas our coding style does not permit that many parameters
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
ldr x23, [sp, #72]
mul x5, x4, x5
mov x4, #1
LoopOc:
mov x10, x4
mov x12, x1
LoopKsize:
INIT_BIAS
mov x11, x0
// as some processors do not support the sdot instruction, we emit its raw instruction word
// dotprod support is still detected dynamically at runtime; the instruction word only ensures the file assembles
// according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
// the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
// 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
// mmmmm/nnnnn/ddddd are the NEON register numbers, H/L are the high/low bits of the index
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
// load weight
ld1 {v4.16b, v5.16b}, [x2], #32
// step for output 1-4
smull v8.8h, v0.8b, v4.8b
smull v9.8h, v0.8b, v5.8b
smlal2 v8.8h, v0.16b, v4.16b
smlal2 v9.8h, v0.16b, v5.16b
// load input for output 9-16
ld1 {v6.16b, v7.16b}, [x2], #32
// another step for output 5-8
smull v12.8h, v1.8b, v4.8b
smull v13.8h, v1.8b, v5.8b
smlal2 v12.8h, v1.16b, v4.16b
smlal2 v13.8h, v1.16b, v5.16b
ld1 {v2.16b, v3.16b}, [x12], #32
smull v10.8h, v0.8b, v6.8b
smull v11.8h, v0.8b, v7.8b
saddlp v16.4s, v8.8h
smlal2 v10.8h, v0.16b, v6.16b
smlal2 v11.8h, v0.16b, v7.16b
saddlp v17.4s, v9.8h
smull v14.8h, v1.8b, v6.8b
smull v15.8h, v1.8b, v7.8b
saddlp v18.4s, v10.8h
smlal2 v14.8h, v1.16b, v6.16b
smlal2 v15.8h, v1.16b, v7.16b
subs x13, x5, #1
beq LoopIcEnd
LoopIc:
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
sadalp v19.4s, v11.8h
smull v8.8h, v2.8b, v4.8b
smull v9.8h, v2.8b, v5.8b
sadalp v20.4s, v12.8h
smlal2 v8.8h, v2.16b, v4.16b
smlal2 v9.8h, v2.16b, v5.16b
sadalp v21.4s, v13.8h
smull v10.8h, v2.8b, v6.8b
smull v11.8h, v2.8b, v7.8b
sadalp v22.4s, v14.8h
smlal2 v10.8h, v2.16b, v6.16b
smlal2 v11.8h, v2.16b, v7.16b
sadalp v23.4s, v15.8h
smull v12.8h, v3.8b, v4.8b
smull v13.8h, v3.8b, v5.8b
sadalp v24.4s, v8.8h
smlal2 v12.8h, v3.16b, v4.16b
smlal2 v13.8h, v3.16b, v5.16b
ld1 {v4.16b, v5.16b}, [x2], #32
sadalp v25.4s, v9.8h
smull v14.8h, v3.8b, v6.8b
smull v15.8h, v3.8b, v7.8b
sadalp v26.4s, v10.8h
smlal2 v14.8h, v3.16b, v6.16b
smlal2 v15.8h, v3.16b, v7.16b
ld1 {v6.16b, v7.16b}, [x2], #32
sadalp v27.4s, v11.8h
smull v8.8h, v0.8b, v4.8b
smull v9.8h, v0.8b, v5.8b
sadalp v28.4s, v12.8h
smlal2 v8.8h, v0.16b, v4.16b
smlal2 v9.8h, v0.16b, v5.16b
ld1 {v2.16b, v3.16b}, [x12], #32
sadalp v29.4s, v13.8h
smull v12.8h, v1.8b, v4.8b
smull v13.8h, v1.8b, v5.8b
sadalp v30.4s, v14.8h
smlal2 v12.8h, v1.16b, v4.16b
smlal2 v13.8h, v1.16b, v5.16b
sadalp v31.4s, v15.8h
smull v10.8h, v0.8b, v6.8b
smull v11.8h, v0.8b, v7.8b
sadalp v16.4s, v8.8h
smlal2 v10.8h, v0.16b, v6.16b
smlal2 v11.8h, v0.16b, v7.16b
sadalp v17.4s, v9.8h
smull v14.8h, v1.8b, v6.8b
smull v15.8h, v1.8b, v7.8b
sadalp v18.4s, v10.8h
smlal2 v14.8h, v1.16b, v6.16b
smlal2 v15.8h, v1.16b, v7.16b
subs x13, x13, #1
bne LoopIc
LoopIcEnd:
sadalp v19.4s, v11.8h
smull v8.8h, v2.8b, v4.8b
smull v9.8h, v2.8b, v5.8b
sadalp v20.4s, v12.8h
smlal2 v8.8h, v2.16b, v4.16b
smlal2 v9.8h, v2.16b, v5.16b
sadalp v21.4s, v13.8h
smull v10.8h, v2.8b, v6.8b
smull v11.8h, v2.8b, v7.8b
sadalp v22.4s, v14.8h
smlal2 v10.8h, v2.16b, v6.16b
smlal2 v11.8h, v2.16b, v7.16b
sadalp v23.4s, v15.8h
smull v12.8h, v3.8b, v4.8b
smull v13.8h, v3.8b, v5.8b
sadalp v24.4s, v8.8h
smlal2 v12.8h, v3.16b, v4.16b
smlal2 v13.8h, v3.16b, v5.16b
sadalp v25.4s, v9.8h
smull v14.8h, v3.8b, v6.8b
smull v15.8h, v3.8b, v7.8b
sadalp v26.4s, v10.8h
smlal2 v14.8h, v3.16b, v6.16b
smlal2 v15.8h, v3.16b, v7.16b
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
// pairwise add
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
dup v12.4s, wzr
cbz x3, NoReadBias
ld1 {v12.4s}, [x3]
NoReadBias:
addp v16.4s, v16.4s, v18.4s
addp v20.4s, v20.4s, v22.4s
addp v24.4s, v24.4s, v26.4s
addp v28.4s, v28.4s, v30.4s
cbz x20, NoSum
// load sum
mov x22, x15
cbz x21, SymSum
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22]
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22]
AddSum:
sub v16.4s, v16.4s, v8.4s
sub v20.4s, v20.4s, v9.4s
sub v24.4s, v24.4s, v10.4s
sub v28.4s, v28.4s, v11.4s
NoSum:
add v16.4s, v16.4s, v12.4s
add v20.4s, v20.4s, v12.4s
add v24.4s, v24.4s, v12.4s
add v28.4s, v28.4s, v12.4s
cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v16.4s, v16.4s, v2.4s
sqshl v20.4s, v20.4s, v2.4s
sqshl v24.4s, v24.4s, v2.4s
sqshl v28.4s, v28.4s, v2.4s
sqrdmulh v16.4s, v16.4s, v3.4s
sqrdmulh v20.4s, v20.4s, v3.4s
sqrdmulh v24.4s, v24.4s, v3.4s
sqrdmulh v28.4s, v28.4s, v3.4s
and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v4.4s
and v1.16b, v4.16b, v20.16b
sshr v1.4s, v1.4s, #31
sqadd v20.4s, v20.4s, v1.4s
srshl v20.4s, v20.4s, v4.4s
and v2.16b, v4.16b, v24.16b
sshr v2.4s, v2.4s, #31
sqadd v24.4s, v24.4s, v2.4s
srshl v24.4s, v24.4s, v4.4s
and v3.16b, v4.16b, v28.16b
sshr v3.4s, v3.4s, #31
sqadd v28.4s, v28.4s, v3.4s
srshl v28.4s, v28.4s, v4.4s
dup v5.4s, w16
add v16.4s, v16.4s, v5.4s
add v20.4s, v20.4s, v5.4s
add v24.4s, v24.4s, v5.4s
add v28.4s, v28.4s, v5.4s
dup v0.4s, w8
smax v16.4s, v16.4s, v0.4s
smax v20.4s, v20.4s, v0.4s
smax v24.4s, v24.4s, v0.4s
smax v28.4s, v28.4s, v0.4s
dup v1.4s, w9
smin v16.4s, v16.4s, v1.4s
smin v20.4s, v20.4s, v1.4s
smin v24.4s, v24.4s, v1.4s
smin v28.4s, v28.4s, v1.4s
sqxtn v13.4h, v16.4s
sqxtn2 v13.8h, v20.4s
sqxtn v15.8b, v13.8h
sqxtn v14.4h, v24.4s
sqxtn2 v14.8h, v28.4s
sqxtn2 v15.16b, v14.8h
// prefetching is not preferred while writing results, despite possible cache misses
// you could try prfm pstl2strm
WriteStart:
cmp x6, #1
beq Write1
cmp x6, #2
beq Write2
cmp x6, #3
beq Write3
b Write4
Write1:
st1 {v15.b}[0], [x11], x7
st1 {v15.b}[4], [x11], x7
st1 {v15.b}[8], [x11], x7
st1 {v15.b}[12], [x11]
add x0, x0, #1
b WriteEnd
Write2:
st1 {v15.h}[0], [x11], x7
st1 {v15.h}[2], [x11], x7
st1 {v15.h}[4], [x11], x7
st1 {v15.h}[6], [x11]
add x0, x0, #2
b WriteEnd
Write3:
add x14, x11, #2
st1 {v15.h}[0], [x11], x7
st1 {v15.b}[2], [x14], x7
st1 {v15.h}[2], [x11], x7
st1 {v15.b}[6], [x14], x7
st1 {v15.h}[4], [x11], x7
st1 {v15.b}[10], [x14], x7
st1 {v15.h}[6], [x11]
st1 {v15.b}[14], [x14]
add x0, x0, #3
b WriteEnd
Write4:
st1 {v15.s}[0], [x11], x7
st1 {v15.s}[1], [x11], x7
st1 {v15.s}[2], [x11], x7
st1 {v15.s}[3], [x11]
add x0, x0, #4
WriteEnd:
subs x10, x10, #1
bne LoopKsize
subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepFowrard
add x3, x3, #16
NoStepFowrard:
bgt LoopOc
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif

View File

@ -0,0 +1,408 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulInt8Neon64Opt
#ifndef __APPLE__
.type MatmulInt8Neon64Opt, %function
#endif
//void MatmulInt8Neon64Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
// int32_t *right_shift, int row, int col, int stride, int filter_peroc, int32_t *filter_zp)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
// w8: act_min
// w9: act_max
// w10: out_zp
// x11: multiplier
// x12: left_shift
// x13: right_shift
// w14: row
// w15: col
// w24: stride
// w27: filter_peroc
// x28: filter_zp
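// Loop structure (summary): L1 walks the output columns four at a time, L2 the
// output rows four at a time, and L3 accumulates the depth in chunks of 16, so
// every L2 iteration produces one 4x4 tile of int8 outputs.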
MatmulInt8Neon64Opt:
sub sp, sp, #208
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
stp x27, x28, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr x11, [sp, #24]
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
ldr w27, [sp, #72]
ldr x28, [sp, #80]
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
mov w17, #1
mov x25, x2
L1:
cmp w4, #0 // if at the end of col4
beq End1
mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, #0
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w20, #0
beq End3
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x17], #16
ld1 {v2.16b}, [x17], #16
ld1 {v3.16b}, [x17], #16
ld1 {v4.16b}, [x18], #16
ld1 {v5.16b}, [x18], #16
ld1 {v6.16b}, [x18], #16
ld1 {v7.16b}, [x18], #16
smull v8.8h, v4.8b, v0.8b
smull v9.8h, v5.8b, v0.8b
smull v10.8h, v6.8b, v0.8b
smull v11.8h, v7.8b, v0.8b
smull v12.8h, v4.8b, v1.8b
smull v13.8h, v5.8b, v1.8b
smull v14.8h, v6.8b, v1.8b
smull v15.8h, v7.8b, v1.8b
smlal2 v8.8h, v4.16b, v0.16b
smlal2 v9.8h, v5.16b, v0.16b
smlal2 v10.8h, v6.16b, v0.16b
smlal2 v11.8h, v7.16b, v0.16b
smlal2 v12.8h, v4.16b, v1.16b
smlal2 v13.8h, v5.16b, v1.16b
smlal2 v14.8h, v6.16b, v1.16b
smlal2 v15.8h, v7.16b, v1.16b
sadalp v16.4s, v8.8h
sadalp v17.4s, v9.8h
sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h
sadalp v20.4s, v12.8h
sadalp v21.4s, v13.8h
sadalp v22.4s, v14.8h
sadalp v23.4s, v15.8h
smull v8.8h, v4.8b, v2.8b
smull v9.8h, v5.8b, v2.8b
smull v10.8h, v6.8b, v2.8b
smull v11.8h, v7.8b, v2.8b
smull v12.8h, v4.8b, v3.8b
smull v13.8h, v5.8b, v3.8b
smull v14.8h, v6.8b, v3.8b
smull v15.8h, v7.8b, v3.8b
smlal2 v8.8h, v4.16b, v2.16b
smlal2 v9.8h, v5.16b, v2.16b
smlal2 v10.8h, v6.16b, v2.16b
smlal2 v11.8h, v7.16b, v2.16b
smlal2 v12.8h, v4.16b, v3.16b
smlal2 v13.8h, v5.16b, v3.16b
smlal2 v14.8h, v6.16b, v3.16b
smlal2 v15.8h, v7.16b, v3.16b
sadalp v24.4s, v8.8h
sadalp v25.4s, v9.8h
sadalp v26.4s, v10.8h
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
subs w20, w20, #16 // depth -= 16
b L3
End3:
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
addp v16.4s, v16.4s, v18.4s
addp v17.4s, v20.4s, v22.4s
addp v18.4s, v24.4s, v26.4s
addp v19.4s, v28.4s, v30.4s
// Add (Bias+Depth*Za*Zb-Za*Bsums)
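// (x19 is assumed to point at a bias vector pre-folded by the caller with the
//  deep16*Za*Zb and Za*Bsums zero-point terms; only the Asums*Zb correction is applied below)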
ld1 {v15.4s}, [x19], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s
ld1r {v20.4s}, [x22], #4
ld1r {v21.4s}, [x22], #4
ld1r {v22.4s}, [x22], #4
ld1r {v23.4s}, [x22], #4
cmp w27, #0
beq Apply
ld1 {v14.4s}, [x28]
mul v20.4s, v20.4s, v14.4s
mul v21.4s, v21.4s, v14.4s
mul v22.4s, v22.4s, v14.4s
mul v23.4s, v23.4s, v14.4s
Apply:
// Subtract (Asums*Zb)
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s
cmp w27, #1
beq PerCLoad
ld1r {v13.4s}, [x12]
ld1r {v12.4s}, [x11]
ld1r {v11.4s}, [x13]
b Quantize
PerCLoad:
ld1 {v13.4s}, [x12]
ld1 {v12.4s}, [x11]
ld1 {v11.4s}, [x13]
Quantize:
// Apply left shift
sqshl v16.4s, v16.4s, v13.4s
sqshl v17.4s, v17.4s, v13.4s
sqshl v18.4s, v18.4s, v13.4s
sqshl v19.4s, v19.4s, v13.4s
// Apply the fixed-point part of the multiplier.
sqrdmulh v16.4s, v16.4s, v12.4s
sqrdmulh v17.4s, v17.4s, v12.4s
sqrdmulh v18.4s, v18.4s, v12.4s
sqrdmulh v19.4s, v19.4s, v12.4s
// Apply right shift
and v20.16b, v11.16b, v16.16b
sshr v20.4s, v20.4s, #31
sqadd v16.4s, v16.4s, v20.4s
srshl v16.4s, v16.4s, v11.4s
and v21.16b, v11.16b, v17.16b
sshr v21.4s, v21.4s, #31
sqadd v17.4s, v17.4s, v21.4s
srshl v17.4s, v17.4s, v11.4s
and v22.16b, v11.16b, v18.16b
sshr v22.4s, v22.4s, #31
sqadd v18.4s, v18.4s, v22.4s
srshl v18.4s, v18.4s, v11.4s
and v23.16b, v11.16b, v19.16b
sshr v23.4s, v23.4s, #31
sqadd v19.4s, v19.4s, v23.4s
srshl v19.4s, v19.4s, v11.4s
// Add the destination zero point
dup v10.4s, w10
add v16.4s, v16.4s, v10.4s
add v17.4s, v17.4s, v10.4s
add v18.4s, v18.4s, v10.4s
add v19.4s, v19.4s, v10.4s
// Apply the act_min bound
dup v9.4s, w8
smax v16.4s, v16.4s, v9.4s
smax v17.4s, v17.4s, v9.4s
smax v18.4s, v18.4s, v9.4s
smax v19.4s, v19.4s, v9.4s
// Apply the act_max bound
dup v8.4s, w9
smin v16.4s, v16.4s, v8.4s
smin v17.4s, v17.4s, v8.4s
smin v18.4s, v18.4s, v8.4s
smin v19.4s, v19.4s, v8.4s
// int32 -> int16
sqxtn v13.4h, v16.4s
sqxtn2 v13.8h, v17.4s
sqxtn v14.4h, v18.4s
sqxtn2 v14.8h, v19.4s
// int16 -> int8
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h
cmp w23, #4
blt Write // if rows < 4
cmp w15, #4
blt Write // if cols < 4
st1 {v15.s}[0], [x2], x24
st1 {v15.s}[1], [x2], x24
st1 {v15.s}[2], [x2], x24
st1 {v15.s}[3], [x2], x24
b Endwrite
Write:
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.s}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.s}[2], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.s}[3], [x2], x24
b Endwrite
WriteCol3:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
st1 {v15.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
st1 {v15.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
st1 {v15.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
st1 {v15.b}[14], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol2:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol1:
st1 {v15.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.b}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.b}[8], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.b}[12], [x2], x24
b Endwrite
Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2
End2:
sub w4, w4, #4 // b col4 counter - 4
sub w15, w15, #4 // b col counter - 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25
cmp w27, #0
beq PerTEnd2
add x12, x12, #16
add x11, x11, #16
add x13, x13, #16
add x28, x28, #16
PerTEnd2:
b L1
End1:
sub sp, sp, #208
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ldp x27, x28, [sp], #16
ret
#endif

View File

@ -1,785 +0,0 @@
#ifdef __aarch64__
.text
.align 5
.global IndirectGemmInt8_24x4_dp
#ifndef __APPLE__
.type IndirectGemmInt8_24x4_dp, %function
#endif
// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
// we use the sdot instruction on cores that support dotprod (Armv8.2-A with the DP extension, or later)
// the mrs instruction can read system register ID_AA64ISAR0_EL1 (encoded as s3_0_c0_c6_0)
// bits [47:44] (the DP field) indicate whether dotprod is supported
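// a hedged sketch of such a runtime check (not part of this file; on Linux the
// kernel traps and emulates this EL0 read):
//   mrs  x16, ID_AA64ISAR0_EL1
//   ubfx x16, x16, #44, #4     // extract the DP field
//   cbnz x16, HasDotProd       // non-zero means sdot/udot are implemented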
IndirectGemmInt8_24x4_dp:
.macro INIT_BIAS
dup v7.4s, wzr
cbz x3, InitBias
ld1 {v7.4s}, [x3]
InitBias:
cbz x20, NoSum
mov x22, x15
cbz x21, SymSum
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22], x23
ld1 {v12.4s}, [x22], x23
ld1 {v13.4s}, [x22], x23
ld1 {v14.4s}, [x22], x23
ld1 {v15.4s}, [x22], x23
ld1 {v16.4s}, [x22], x23
ld1 {v17.4s}, [x22], x23
ld1 {v18.4s}, [x22], x23
ld1 {v19.4s}, [x22], x23
ld1 {v20.4s}, [x22], x23
ld1 {v21.4s}, [x22], x23
ld1 {v22.4s}, [x22], x23
ld1 {v23.4s}, [x22], x23
ld1 {v24.4s}, [x22], x23
ld1 {v25.4s}, [x22], x23
ld1 {v26.4s}, [x22], x23
ld1 {v27.4s}, [x22], x23
ld1 {v28.4s}, [x22], x23
ld1 {v29.4s}, [x22], x23
ld1 {v30.4s}, [x22], x23
ld1 {v31.4s}, [x22], x23
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22], #4
ld1r {v12.4s}, [x22], #4
ld1r {v13.4s}, [x22], #4
ld1r {v14.4s}, [x22], #4
ld1r {v15.4s}, [x22], #4
ld1r {v16.4s}, [x22], #4
ld1r {v17.4s}, [x22], #4
ld1r {v18.4s}, [x22], #4
ld1r {v19.4s}, [x22], #4
ld1r {v20.4s}, [x22], #4
ld1r {v21.4s}, [x22], #4
ld1r {v22.4s}, [x22], #4
ld1r {v23.4s}, [x22], #4
ld1r {v24.4s}, [x22], #4
ld1r {v25.4s}, [x22], #4
ld1r {v26.4s}, [x22], #4
ld1r {v27.4s}, [x22], #4
ld1r {v28.4s}, [x22], #4
ld1r {v29.4s}, [x22], #4
ld1r {v30.4s}, [x22], #4
ld1r {v31.4s}, [x22], #4
AddSum:
sub v8.4s, v7.4s, v8.4s
sub v9.4s, v7.4s, v9.4s
sub v10.4s, v7.4s, v10.4s
sub v11.4s, v7.4s, v11.4s
sub v12.4s, v7.4s, v12.4s
sub v13.4s, v7.4s, v13.4s
sub v14.4s, v7.4s, v14.4s
sub v15.4s, v7.4s, v15.4s
sub v16.4s, v7.4s, v16.4s
sub v17.4s, v7.4s, v17.4s
sub v18.4s, v7.4s, v18.4s
sub v19.4s, v7.4s, v19.4s
sub v20.4s, v7.4s, v20.4s
sub v21.4s, v7.4s, v21.4s
sub v22.4s, v7.4s, v22.4s
sub v23.4s, v7.4s, v23.4s
sub v24.4s, v7.4s, v24.4s
sub v25.4s, v7.4s, v25.4s
sub v26.4s, v7.4s, v26.4s
sub v27.4s, v7.4s, v27.4s
sub v28.4s, v7.4s, v28.4s
sub v29.4s, v7.4s, v29.4s
sub v30.4s, v7.4s, v30.4s
sub v31.4s, v7.4s, v31.4s
b InitBiasEnd
NoSum:
mov v8.16b, v7.16b
mov v9.16b, v7.16b
mov v10.16b, v7.16b
mov v11.16b, v7.16b
mov v12.16b, v7.16b
mov v13.16b, v7.16b
mov v14.16b, v7.16b
mov v15.16b, v7.16b
mov v16.16b, v7.16b
mov v17.16b, v7.16b
mov v18.16b, v7.16b
mov v19.16b, v7.16b
mov v20.16b, v7.16b
mov v21.16b, v7.16b
mov v22.16b, v7.16b
mov v23.16b, v7.16b
mov v24.16b, v7.16b
mov v25.16b, v7.16b
mov v26.16b, v7.16b
mov v27.16b, v7.16b
mov v28.16b, v7.16b
mov v29.16b, v7.16b
mov v30.16b, v7.16b
mov v31.16b, v7.16b
InitBiasEnd:
.endm
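// INIT_BIAS seeds every accumulator with bias[c] minus the (optionally
// per-channel-loaded) input sum, so the zero-point corrections are folded in
// before the multiply-accumulate loop starts.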
// registers v8 ~ v15 must be preserved by the callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved
// whereas our coding style does not permit that many parameters
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
ldr x23, [sp, #72]
mul x5, x4, x5
mov x4, #1
LoopOc:
mov x10, x4
mov x12, x1
LoopKsize:
INIT_BIAS
mov x11, x0
// as some processors do not support the sdot instruction, we emit its raw instruction word
// dotprod support is still detected dynamically at runtime; the instruction word only ensures the file assembles
// according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
// the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
// 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
// mmmmm/nnnnn/ddddd are the NEON register numbers, H/L are the high/low bits of the index
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
// load weight
ld1 {v6.16b}, [x2], #16
// step for output 1-4
.inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0]
.inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1]
.inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2]
.inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3]
// load input for output 9-16
ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x12], #64
// another step for output 5-8
.inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0]
.inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1]
.inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2]
.inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3]
subs x13, x5, #1
beq LoopIcEndOne
// load weight
ld1 {v7.16b}, [x2], #16
cmp x13, #1
beq LoopIcEnd
LoopIc:
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]
ld1 {v2.16b, v3.16b}, [x12], #32
.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32
.inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0]
.inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1]
.inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2]
.inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3]
// another step for output 5-8
.inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0]
.inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1]
.inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2]
.inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3]
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0]
.inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1]
.inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2]
.inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3]
.inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0]
.inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1]
.inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2]
.inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3]
// load weight
ld1 {v6.16b}, [x2], #16
.inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0]
.inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1]
.inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2]
.inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3]
.inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0]
.inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1]
.inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2]
.inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3]
// load input for output 9-16
ld1 {v2.4s, v3.4s}, [x12], #32
.inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0]
.inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1]
.inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2]
.inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3]
// another step for output 5-8
.inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0]
.inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1]
.inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2]
.inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32
subs x13, x13, #2
beq LoopIcEndOne
// load weight
ld1 {v7.16b}, [x2], #16
cmp x13, #1
beq LoopIcEnd
b LoopIc
LoopIcEnd:
mov x22, x15
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]
ld1 {v2.16b, v3.16b}, [x12], #32
.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32
.inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0]
.inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1]
.inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2]
.inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3]
.inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0]
.inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1]
.inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2]
.inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3]
.inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0]
.inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1]
.inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2]
.inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3]
.inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0]
.inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1]
.inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2]
.inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3]
.inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0]
.inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1]
.inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2]
.inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3]
.inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0]
.inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1]
.inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2]
.inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3]
b Quantization
LoopIcEndOne:
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]
.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
Quantization:
cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v8.4s, v8.4s, v2.4s
sqshl v9.4s, v9.4s, v2.4s
sqshl v10.4s, v10.4s, v2.4s
sqshl v11.4s, v11.4s, v2.4s
sqshl v12.4s, v12.4s, v2.4s
sqshl v13.4s, v13.4s, v2.4s
sqshl v14.4s, v14.4s, v2.4s
sqshl v15.4s, v15.4s, v2.4s
sqshl v16.4s, v16.4s, v2.4s
sqshl v17.4s, v17.4s, v2.4s
sqshl v18.4s, v18.4s, v2.4s
sqshl v19.4s, v19.4s, v2.4s
sqshl v20.4s, v20.4s, v2.4s
sqshl v21.4s, v21.4s, v2.4s
sqshl v22.4s, v22.4s, v2.4s
sqshl v23.4s, v23.4s, v2.4s
sqshl v24.4s, v24.4s, v2.4s
sqshl v25.4s, v25.4s, v2.4s
sqshl v26.4s, v26.4s, v2.4s
sqshl v27.4s, v27.4s, v2.4s
sqshl v28.4s, v28.4s, v2.4s
sqshl v29.4s, v29.4s, v2.4s
sqshl v30.4s, v30.4s, v2.4s
sqshl v31.4s, v31.4s, v2.4s
sqrdmulh v8.4s, v8.4s, v3.4s
sqrdmulh v9.4s, v9.4s, v3.4s
sqrdmulh v10.4s, v10.4s, v3.4s
sqrdmulh v11.4s, v11.4s, v3.4s
sqrdmulh v12.4s, v12.4s, v3.4s
sqrdmulh v13.4s, v13.4s, v3.4s
sqrdmulh v14.4s, v14.4s, v3.4s
sqrdmulh v15.4s, v15.4s, v3.4s
sqrdmulh v16.4s, v16.4s, v3.4s
sqrdmulh v17.4s, v17.4s, v3.4s
sqrdmulh v18.4s, v18.4s, v3.4s
sqrdmulh v19.4s, v19.4s, v3.4s
sqrdmulh v20.4s, v20.4s, v3.4s
sqrdmulh v21.4s, v21.4s, v3.4s
sqrdmulh v22.4s, v22.4s, v3.4s
sqrdmulh v23.4s, v23.4s, v3.4s
sqrdmulh v24.4s, v24.4s, v3.4s
sqrdmulh v25.4s, v25.4s, v3.4s
sqrdmulh v26.4s, v26.4s, v3.4s
sqrdmulh v27.4s, v27.4s, v3.4s
sqrdmulh v28.4s, v28.4s, v3.4s
sqrdmulh v29.4s, v29.4s, v3.4s
sqrdmulh v30.4s, v30.4s, v3.4s
sqrdmulh v31.4s, v31.4s, v3.4s
and v0.16b, v4.16b, v8.16b
sshr v0.4s, v0.4s, #31
sqadd v8.4s, v8.4s, v0.4s
srshl v8.4s, v8.4s, v4.4s
and v1.16b, v4.16b, v9.16b
sshr v1.4s, v1.4s, #31
sqadd v9.4s, v9.4s, v1.4s
srshl v9.4s, v9.4s, v4.4s
and v2.16b, v4.16b, v10.16b
sshr v2.4s, v2.4s, #31
sqadd v10.4s, v10.4s, v2.4s
srshl v10.4s, v10.4s, v4.4s
and v3.16b, v4.16b, v11.16b
sshr v3.4s, v3.4s, #31
sqadd v11.4s, v11.4s, v3.4s
srshl v11.4s, v11.4s, v4.4s
and v0.16b, v4.16b, v12.16b
sshr v0.4s, v0.4s, #31
sqadd v12.4s, v12.4s, v0.4s
srshl v12.4s, v12.4s, v4.4s
and v1.16b, v4.16b, v13.16b
sshr v1.4s, v1.4s, #31
sqadd v13.4s, v13.4s, v1.4s
srshl v13.4s, v13.4s, v4.4s
and v2.16b, v4.16b, v14.16b
sshr v2.4s, v2.4s, #31
sqadd v14.4s, v14.4s, v2.4s
srshl v14.4s, v14.4s, v4.4s
and v3.16b, v4.16b, v15.16b
sshr v3.4s, v3.4s, #31
sqadd v15.4s, v15.4s, v3.4s
srshl v15.4s, v15.4s, v4.4s
and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v4.4s
and v1.16b, v4.16b, v17.16b
sshr v1.4s, v1.4s, #31
sqadd v17.4s, v17.4s, v1.4s
srshl v17.4s, v17.4s, v4.4s
and v2.16b, v4.16b, v18.16b
sshr v2.4s, v2.4s, #31
sqadd v18.4s, v18.4s, v2.4s
srshl v18.4s, v18.4s, v4.4s
and v3.16b, v4.16b, v19.16b
sshr v3.4s, v3.4s, #31
sqadd v19.4s, v19.4s, v3.4s
srshl v19.4s, v19.4s, v4.4s
and v0.16b, v4.16b, v20.16b
sshr v0.4s, v0.4s, #31
sqadd v20.4s, v20.4s, v0.4s
srshl v20.4s, v20.4s, v4.4s
and v1.16b, v4.16b, v21.16b
sshr v1.4s, v1.4s, #31
sqadd v21.4s, v21.4s, v1.4s
srshl v21.4s, v21.4s, v4.4s
and v2.16b, v4.16b, v22.16b
sshr v2.4s, v2.4s, #31
sqadd v22.4s, v22.4s, v2.4s
srshl v22.4s, v22.4s, v4.4s
and v3.16b, v4.16b, v23.16b
sshr v3.4s, v3.4s, #31
sqadd v23.4s, v23.4s, v3.4s
srshl v23.4s, v23.4s, v4.4s
and v0.16b, v4.16b, v24.16b
sshr v0.4s, v0.4s, #31
sqadd v24.4s, v24.4s, v0.4s
srshl v24.4s, v24.4s, v4.4s
and v1.16b, v4.16b, v25.16b
sshr v1.4s, v1.4s, #31
sqadd v25.4s, v25.4s, v1.4s
srshl v25.4s, v25.4s, v4.4s
and v2.16b, v4.16b, v26.16b
sshr v2.4s, v2.4s, #31
sqadd v26.4s, v26.4s, v2.4s
srshl v26.4s, v26.4s, v4.4s
and v3.16b, v4.16b, v27.16b
sshr v3.4s, v3.4s, #31
sqadd v27.4s, v27.4s, v3.4s
srshl v27.4s, v27.4s, v4.4s
and v0.16b, v4.16b, v28.16b
sshr v0.4s, v0.4s, #31
sqadd v28.4s, v28.4s, v0.4s
srshl v28.4s, v28.4s, v4.4s
and v1.16b, v4.16b, v29.16b
sshr v1.4s, v1.4s, #31
sqadd v29.4s, v29.4s, v1.4s
srshl v29.4s, v29.4s, v4.4s
and v2.16b, v4.16b, v30.16b
sshr v2.4s, v2.4s, #31
sqadd v30.4s, v30.4s, v2.4s
srshl v30.4s, v30.4s, v4.4s
and v3.16b, v4.16b, v31.16b
sshr v3.4s, v3.4s, #31
sqadd v31.4s, v31.4s, v3.4s
srshl v31.4s, v31.4s, v4.4s
dup v5.4s, w16
add v8.4s, v8.4s, v5.4s
add v9.4s, v9.4s, v5.4s
add v10.4s, v10.4s, v5.4s
add v11.4s, v11.4s, v5.4s
add v12.4s, v12.4s, v5.4s
add v13.4s, v13.4s, v5.4s
add v14.4s, v14.4s, v5.4s
add v15.4s, v15.4s, v5.4s
add v16.4s, v16.4s, v5.4s
add v17.4s, v17.4s, v5.4s
add v18.4s, v18.4s, v5.4s
add v19.4s, v19.4s, v5.4s
add v20.4s, v20.4s, v5.4s
add v21.4s, v21.4s, v5.4s
add v22.4s, v22.4s, v5.4s
add v23.4s, v23.4s, v5.4s
add v24.4s, v24.4s, v5.4s
add v25.4s, v25.4s, v5.4s
add v26.4s, v26.4s, v5.4s
add v27.4s, v27.4s, v5.4s
add v28.4s, v28.4s, v5.4s
add v29.4s, v29.4s, v5.4s
add v30.4s, v30.4s, v5.4s
add v31.4s, v31.4s, v5.4s
dup v0.4s, w8
smax v8.4s, v8.4s, v0.4s
smax v9.4s, v9.4s, v0.4s
smax v10.4s, v10.4s, v0.4s
smax v11.4s, v11.4s, v0.4s
smax v12.4s, v12.4s, v0.4s
smax v13.4s, v13.4s, v0.4s
smax v14.4s, v14.4s, v0.4s
smax v15.4s, v15.4s, v0.4s
smax v16.4s, v16.4s, v0.4s
smax v17.4s, v17.4s, v0.4s
smax v18.4s, v18.4s, v0.4s
smax v19.4s, v19.4s, v0.4s
smax v20.4s, v20.4s, v0.4s
smax v21.4s, v21.4s, v0.4s
smax v22.4s, v22.4s, v0.4s
smax v23.4s, v23.4s, v0.4s
smax v24.4s, v24.4s, v0.4s
smax v25.4s, v25.4s, v0.4s
smax v26.4s, v26.4s, v0.4s
smax v27.4s, v27.4s, v0.4s
smax v28.4s, v28.4s, v0.4s
smax v29.4s, v29.4s, v0.4s
smax v30.4s, v30.4s, v0.4s
smax v31.4s, v31.4s, v0.4s
dup v1.4s, w9
smin v8.4s, v8.4s, v1.4s
smin v9.4s, v9.4s, v1.4s
smin v10.4s, v10.4s, v1.4s
smin v11.4s, v11.4s, v1.4s
smin v12.4s, v12.4s, v1.4s
smin v13.4s, v13.4s, v1.4s
smin v14.4s, v14.4s, v1.4s
smin v15.4s, v15.4s, v1.4s
smin v16.4s, v16.4s, v1.4s
smin v17.4s, v17.4s, v1.4s
smin v18.4s, v18.4s, v1.4s
smin v19.4s, v19.4s, v1.4s
smin v20.4s, v20.4s, v1.4s
smin v21.4s, v21.4s, v1.4s
smin v22.4s, v22.4s, v1.4s
smin v23.4s, v23.4s, v1.4s
smin v24.4s, v24.4s, v1.4s
smin v25.4s, v25.4s, v1.4s
smin v26.4s, v26.4s, v1.4s
smin v27.4s, v27.4s, v1.4s
smin v28.4s, v28.4s, v1.4s
smin v29.4s, v29.4s, v1.4s
smin v30.4s, v30.4s, v1.4s
smin v31.4s, v31.4s, v1.4s
sqxtn v6.4h, v8.4s
sqxtn2 v6.8h, v9.4s
sqxtn v0.8b, v6.8h
sqxtn v7.4h, v10.4s
sqxtn2 v7.8h, v11.4s
sqxtn2 v0.16b, v7.8h
sqxtn v6.4h, v12.4s
sqxtn2 v6.8h, v13.4s
sqxtn v1.8b, v6.8h
sqxtn v7.4h, v14.4s
sqxtn2 v7.8h, v15.4s
sqxtn2 v1.16b, v7.8h
sqxtn v6.4h, v16.4s
sqxtn2 v6.8h, v17.4s
sqxtn v2.8b, v6.8h
sqxtn v7.4h, v18.4s
sqxtn2 v7.8h, v19.4s
sqxtn2 v2.16b, v7.8h
sqxtn v6.4h, v20.4s
sqxtn2 v6.8h, v21.4s
sqxtn v3.8b, v6.8h
sqxtn v7.4h, v22.4s
sqxtn2 v7.8h, v23.4s
sqxtn2 v3.16b, v7.8h
sqxtn v6.4h, v24.4s
sqxtn2 v6.8h, v25.4s
sqxtn v4.8b, v6.8h
sqxtn v7.4h, v26.4s
sqxtn2 v7.8h, v27.4s
sqxtn2 v4.16b, v7.8h
sqxtn v6.4h, v28.4s
sqxtn2 v6.8h, v29.4s
sqxtn v5.8b, v6.8h
sqxtn v7.4h, v30.4s
sqxtn2 v7.8h, v31.4s
sqxtn2 v5.16b, v7.8h
// prefetching is not preferred while writing results, despite possible cache misses
// you could try prfm pstl2strm
WriteStart:
cmp x6, #1
beq Write1
cmp x6, #2
beq Write2
cmp x6, #3
beq Write3
b Write4
Write1:
st1 {v0.b}[0], [x11], x7
st1 {v0.b}[4], [x11], x7
st1 {v0.b}[8], [x11], x7
st1 {v0.b}[12], [x11], x7
st1 {v1.b}[0], [x11], x7
st1 {v1.b}[4], [x11], x7
st1 {v1.b}[8], [x11], x7
st1 {v1.b}[12], [x11], x7
st1 {v2.b}[0], [x11], x7
st1 {v2.b}[4], [x11], x7
st1 {v2.b}[8], [x11], x7
st1 {v2.b}[12], [x11], x7
st1 {v3.b}[0], [x11], x7
st1 {v3.b}[4], [x11], x7
st1 {v3.b}[8], [x11], x7
st1 {v3.b}[12], [x11], x7
st1 {v4.b}[0], [x11], x7
st1 {v4.b}[4], [x11], x7
st1 {v4.b}[8], [x11], x7
st1 {v4.b}[12], [x11], x7
st1 {v5.b}[0], [x11], x7
st1 {v5.b}[4], [x11], x7
st1 {v5.b}[8], [x11], x7
st1 {v5.b}[12], [x11]
add x0, x0, #1
b WriteEnd
Write2:
st1 {v0.h}[0], [x11], x7
st1 {v0.h}[2], [x11], x7
st1 {v0.h}[4], [x11], x7
st1 {v0.h}[6], [x11], x7
st1 {v1.h}[0], [x11], x7
st1 {v1.h}[2], [x11], x7
st1 {v1.h}[4], [x11], x7
st1 {v1.h}[6], [x11], x7
st1 {v2.h}[0], [x11], x7
st1 {v2.h}[2], [x11], x7
st1 {v2.h}[4], [x11], x7
st1 {v2.h}[6], [x11], x7
st1 {v3.h}[0], [x11], x7
st1 {v3.h}[2], [x11], x7
st1 {v3.h}[4], [x11], x7
st1 {v3.h}[6], [x11], x7
st1 {v4.h}[0], [x11], x7
st1 {v4.h}[2], [x11], x7
st1 {v4.h}[4], [x11], x7
st1 {v4.h}[6], [x11], x7
st1 {v5.h}[0], [x11], x7
st1 {v5.h}[2], [x11], x7
st1 {v5.h}[4], [x11], x7
st1 {v5.h}[6], [x11]
add x0, x0, #2
b WriteEnd
Write3:
add x14, x11, #2
st1 {v0.h}[0], [x11], x7
st1 {v0.b}[2], [x14], x7
st1 {v0.h}[2], [x11], x7
st1 {v0.b}[6], [x14], x7
st1 {v0.h}[4], [x11], x7
st1 {v0.b}[10], [x14], x7
st1 {v0.h}[6], [x11], x7
st1 {v0.b}[14], [x14], x7
st1 {v1.h}[0], [x11], x7
st1 {v1.b}[2], [x14], x7
st1 {v1.h}[2], [x11], x7
st1 {v1.b}[6], [x14], x7
st1 {v1.h}[4], [x11], x7
st1 {v1.b}[10], [x14], x7
st1 {v1.h}[6], [x11], x7
st1 {v1.b}[14], [x14], x7
st1 {v2.h}[0], [x11], x7
st1 {v2.b}[2], [x14], x7
st1 {v2.h}[2], [x11], x7
st1 {v2.b}[6], [x14], x7
st1 {v2.h}[4], [x11], x7
st1 {v2.b}[10], [x14], x7
st1 {v2.h}[6], [x11], x7
st1 {v2.b}[14], [x14], x7
st1 {v3.h}[0], [x11], x7
st1 {v3.b}[2], [x14], x7
st1 {v3.h}[2], [x11], x7
st1 {v3.b}[6], [x14], x7
st1 {v3.h}[4], [x11], x7
st1 {v3.b}[10], [x14], x7
st1 {v3.h}[6], [x11], x7
st1 {v3.b}[14], [x14], x7
st1 {v4.h}[0], [x11], x7
st1 {v4.b}[2], [x14], x7
st1 {v4.h}[2], [x11], x7
st1 {v4.b}[6], [x14], x7
st1 {v4.h}[4], [x11], x7
st1 {v4.b}[10], [x14], x7
st1 {v4.h}[6], [x11], x7
st1 {v4.b}[14], [x14], x7
st1 {v5.h}[0], [x11], x7
st1 {v5.b}[2], [x14], x7
st1 {v5.h}[2], [x11], x7
st1 {v5.b}[6], [x14], x7
st1 {v5.h}[4], [x11], x7
st1 {v5.b}[10], [x14], x7
st1 {v5.h}[6], [x11], x7
st1 {v5.b}[14], [x14], x7
add x0, x0, #3
b WriteEnd
Write4:
st1 {v0.s}[0], [x11], x7
st1 {v0.s}[1], [x11], x7
st1 {v0.s}[2], [x11], x7
st1 {v0.s}[3], [x11], x7
st1 {v1.s}[0], [x11], x7
st1 {v1.s}[1], [x11], x7
st1 {v1.s}[2], [x11], x7
st1 {v1.s}[3], [x11], x7
st1 {v2.s}[0], [x11], x7
st1 {v2.s}[1], [x11], x7
st1 {v2.s}[2], [x11], x7
st1 {v2.s}[3], [x11], x7
st1 {v3.s}[0], [x11], x7
st1 {v3.s}[1], [x11], x7
st1 {v3.s}[2], [x11], x7
st1 {v3.s}[3], [x11], x7
st1 {v4.s}[0], [x11], x7
st1 {v4.s}[1], [x11], x7
st1 {v4.s}[2], [x11], x7
st1 {v4.s}[3], [x11], x7
st1 {v5.s}[0], [x11], x7
st1 {v5.s}[1], [x11], x7
st1 {v5.s}[2], [x11], x7
st1 {v5.s}[3], [x11]
add x0, x0, #4
WriteEnd:
subs x10, x10, #1
bne LoopKsize
subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepFowrard
add x3, x3, #16
NoStepFowrard:
bgt LoopOc
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif

View File

@ -62,10 +62,6 @@ int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int3
#endif
#ifdef ENABLE_ARM32
void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
@ -76,10 +72,6 @@ void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *wei
void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
int32_t zp, int32_t mini, int32_t maxi);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,

View File

@ -811,38 +811,25 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int
return;
}
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false;
#ifdef ENABLE_ARM32
MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
conv_param->output_channel_, is_per_channel);
#else
MatMulInt8_4x2_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
is_per_channel);
#endif
return;
}
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col,
conv_param->output_channel_, is_per_oc);
#ifdef ENABLE_ARM32
MatmulInt8Neon32Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
conv_param->output_channel_, is_per_oc, filter_zp);
#elif ENABLE_ARM64
MatmulInt8Neon64Opt(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row,
col, conv_param->output_channel_, is_per_oc, filter_zp);
#else
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
is_per_oc);
MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
conv_param->output_channel_, is_per_oc, filter_zp);
#endif
return;
}

View File

@ -43,13 +43,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);
int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,

View File

@ -250,6 +250,41 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
return;
}
#ifndef ENABLE_ARM
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) {
int col_tile = C4NUM;
/* supports both per-layer and per-channel weight quantization */
/* row4x16-major * row16x4-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / col_tile, c4mod = c % col_tile;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * col_tile + d16div * col_tile * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = filter_peroc ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + out_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
#endif
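For reference, MultiplyByQuantizedMultiplier above performs the usual gemmlowp-style requantization. A minimal sketch with hypothetical helper names (not the project's implementation; it assumes left_shift >= 0 and right_shift <= 0, matching the rounding shifts used by the assembly kernels):
static inline int32_t SaturatingRoundingDoublingHighMulSketch(int32_t a, int32_t b) {
  // ignores the single saturating case a == b == INT32_MIN
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  return (int32_t)((ab + nudge) >> 31);
}
static inline int32_t RoundingDivideByPOTSketch(int32_t x, int exponent) {
  // round-half-away-from-zero division by 2^exponent
  int32_t mask = (int32_t)((1ll << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
static inline int32_t MultiplyByQuantizedMultiplierSketch(int32_t value, int32_t multiplier, int32_t left_shift,
                                                          int32_t right_shift) {
  return RoundingDivideByPOTSketch(SaturatingRoundingDoublingHighMulSketch(value * (1 << left_shift), multiplier),
                                   -right_shift);
}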
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,

View File

@ -60,6 +60,9 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
@ -68,11 +71,18 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
void MatmulInt8Neon64Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, int filter_peroc,
int32_t *filter_zp);
#endif
#ifdef ENABLE_ARM32
void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
int32_t *left_shift, int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
#endif
#ifdef __cplusplus
}

View File

@ -92,27 +92,19 @@ int Convolution1x1Int8OcOptPre(void *cdata, int task_id) {
}
int Convolution1x1Int8CPUKernel::OcRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32Oc(task_id);
#else
if (support_optimize_) {
return RunArm64OptOc(task_id);
} else {
return RunArm64Oc(task_id);
return RunArmOc(task_id);
}
#endif
}
int Convolution1x1Int8CPUKernel::HwRun(int task_id) {
#ifdef ENABLE_ARM32
return RunArm32Hw(task_id);
#else
if (support_optimize_) {
return RunArm64OptHw(task_id);
} else {
return RunArm64Hw(task_id);
return RunArmHw(task_id);
}
#endif
}
int Convolution1x1Int8CPUKernel::InitRunBuf() {
@ -124,6 +116,7 @@ int Convolution1x1Int8CPUKernel::InitRunBuf() {
size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
: UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
@ -333,8 +326,8 @@ int Convolution1x1Int8CPUKernel::InitParam() {
matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM);
int row_pack_count = 0;
int col_pack_count = 0;
int row_pack_count;
int col_pack_count;
#ifdef ENABLE_ARM32
row_pack_count = C4NUM;
@ -350,15 +343,7 @@ int Convolution1x1Int8CPUKernel::InitParam() {
#endif
/* init input sum size */
if (support_optimize_) {
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
} else {
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
} else {
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
}
}
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
if (pre_trans_input_) {
input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
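The single row-sized buffer is sufficient because the packed sums no longer depend on the output channel: with per-channel filters the zero point is applied later, inside the matmul kernel, via the new filter_zp argument. A minimal sketch of what the buffer holds, with illustrative names:

#include <stdint.h>
#include <stdbool.h>

// Illustrative contents of the packed input sums: one int32 per row (row padding omitted).
//   per-tensor:  input_sum[r] = filter_zp * sum_k a[r][k]   (zero point folded in while packing)
//   per-channel: input_sum[r] = sum_k a[r][k]               (zero point applied per oc inside the kernel)
static void PackInputSumSketch(const int8_t *a, int32_t *input_sum, int row, int deep,
                               bool per_channel, int32_t filter_zp) {
  for (int r = 0; r < row; ++r) {
    int32_t s = 0;
    for (int k = 0; k < deep; ++k) {
      s += a[r * deep + k];
    }
    input_sum[r] = per_channel ? s : s * filter_zp;
  }
}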
@ -404,7 +389,7 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
return;
}
int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
int Convolution1x1Int8CPUKernel::RunArmHw(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
@ -415,51 +400,20 @@ int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_4_
: input_sum_ + task_id * thread_stride_hw_ * C4NUM;
int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_hw_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
if (filter_peroc_) {
PackInputSum16x4PerChannel(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, matmul_param_->deep_,
matmul_param_->col_);
PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, 1, UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
} else {
PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
}
Conv1x1Int8(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm32Hw(int task_id) {
int cur_stride = thread_stride_hw_ * C4NUM;
int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
int cur_hw = MSMIN(cur_stride, res_stride);
if (cur_hw <= 0) {
return RET_OK;
}
int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_2_
: input_sum_ + task_id * thread_stride_hw_ * C4NUM;
RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
if (filter_peroc_) {
PackInputSum16x4PerChannelArm32(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, conv_param_->input_channel_,
conv_param_->output_channel_);
} else {
PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
}
Conv1x1Int8Arm32(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_,
filter_zp_ptr_);
return RET_OK;
}
@ -489,26 +443,6 @@ int Convolution1x1Int8CPUKernel::RunArm64OptHw(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm32Oc(int task_id) {
int stride = thread_stride_oc_ * C2NUM;
int cur_stride = task_id * stride;
int res_stride = matmul_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
return RET_OK;
}
int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
Conv1x1Int8Arm32(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
int stride = thread_stride_oc_ * C16NUM;
int cur_stride = task_id * stride;
@ -531,8 +465,13 @@ int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
return RET_OK;
}
int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
int stride = thread_stride_oc_ * C4NUM;
int Convolution1x1Int8CPUKernel::RunArmOc(int task_id) {
#ifdef ENABLE_ARM32
int col_tile = C2NUM;
#else
int col_tile = C4NUM;
#endif
int stride = thread_stride_oc_ * col_tile;
int cur_stride = task_id * stride;
int res_stride = matmul_param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
@ -540,14 +479,14 @@ int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
return RET_OK;
}
int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
int32_t *cur_zp = filter_peroc_ ? filter_zp_ptr_ + cur_stride : filter_zp_ptr_;
Conv1x1Int8(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
input_sum_, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, cur_zp);
return RET_OK;
}
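The per-channel branch also stops offsetting into a column-replicated sum buffer; only the quantization parameters and filter zero points are sliced per task, schematically:

// Schematic of the per-task slicing in RunArmOc after this change (per-channel case):
//   cur_left_shift  = left_shift_    + cur_stride;   // per-oc shifts
//   cur_right_shift = right_shift_   + cur_stride;
//   cur_multiplier  = multiplier_    + cur_stride;   // per-oc fixed-point multipliers
//   cur_zp          = filter_zp_ptr_ + cur_stride;   // per-oc filter zero points
//   input_sum_ is passed unsliced: the row sums are shared by every oc slice,
//   unlike the old input_sum_ + cur_stride * row_4_ layout.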
@ -592,7 +531,12 @@ int Convolution1x1Int8CPUKernel::Run() {
ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcOptPre, this, thread_count_hw_);
} else {
RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
if (filter_peroc_) {
PackInputSum16x4PerLayer(packed_input_, input_sum_, 1, matmul_param_->row_4_, matmul_param_->deep_16_);
} else {
PackInputSum16x4PerLayer(packed_input_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
matmul_param_->row_4_, matmul_param_->deep_16_);
}
}
/* matmul parallel by oc */
error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcRun, this, thread_count_oc_);
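In the per-channel branch above, passing 1 as the zero point makes PackInputSum16x4PerLayer emit plain row sums; the matmul kernel then combines each row sum with the corresponding channel's zero point. A scalar sketch of that correction, with illustrative names:

#include <stdint.h>
#include <stdbool.h>

// Illustrative accumulator correction for row r and output channel oc
// (acc is the widened int8 dot product sum_k a[r][k] * b[k][oc]):
static inline int32_t CorrectAccSketch(int32_t acc, int32_t row_sum, int32_t bias_oc,
                                       int32_t filter_zp_oc, bool per_channel) {
  // per-tensor:  the filter zero point is already folded into row_sum while packing.
  // per-channel: row_sum is a plain sum, so multiply by this channel's zero point here.
  int32_t correction = per_channel ? row_sum * filter_zp_oc : row_sum;
  return acc - correction + bias_oc;
}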


@ -50,11 +50,9 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int OcOptPre(int task_id);
private:
int RunArm32Oc(int task_id);
int RunArm64Oc(int task_id);
int RunArmOc(int task_id);
int RunArm64OptOc(int task_id);
int RunArm32Hw(int task_id);
int RunArm64Hw(int task_id);
int RunArmHw(int task_id);
int RunArm64OptHw(int task_id);
private:


@ -21,11 +21,6 @@
#ifdef __cplusplus
extern "C" {
#endif
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel, size_t per_channel_offset);
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
@ -38,15 +33,6 @@ extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_
int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
#ifdef ENABLE_ARM64
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel, size_t per_channel_offset) {
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel,
per_channel_offset);
}
void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias) {