forked from mindspore-Ecosystem/mindspore
!8813 [MS][LITE][Develop]optimization for quantized convolution per oc
From: @lx0095 Reviewed-by: Signed-off-by:
commit a86c0da849
@@ -0,0 +1,299 @@
#ifdef __arm__
#ifndef __aarch64__

.text
.align 5
.global MatmulInt8Neon32Opt
#ifndef __APPLE__
.type MatmulInt8Neon32Opt, %function
#endif

//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
// int *filter_zp);
// #-52: a, #-48: b, #-44: dst, #-40: row
// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp

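As a reading aid, here is a scalar sketch of what the kernel computes for one output element, reconstructed from the comments above rather than taken from the project's reference C code. sqrdmulh32 and rounding_divide_by_pot are modeled in later sketches on this page, both shift counts are taken as non-negative here, and the per-channel path additionally scales input_sums by filter_zp and indexes multiplier/left_shift/right_shift per output channel:

#include <stdint.h>

static int32_t sqrdmulh32(int32_t a, int32_t b);          /* modeled further down this page */
static int32_t rounding_divide_by_pot(int32_t x, int n);  /* modeled further down this page */

/* Hypothetical scalar equivalent for one dst element (per-tensor path). */
static int8_t matmul_int8_one_output(const int8_t *a_row, const int8_t *b_col, int deep16,
                                     int32_t input_sum, int32_t weight_bias, int32_t multiplier,
                                     int left_shift, int right_shift, int32_t out_zp,
                                     int32_t act_min, int32_t act_max) {
  int32_t acc = weight_bias - input_sum;            /* bias term, then the a-row sum term */
  for (int k = 0; k < deep16; ++k) {
    acc += (int32_t)a_row[k] * (int32_t)b_col[k];   /* the vmull/vmlal/vpadal accumulation */
  }
  acc = sqrdmulh32(acc * (1 << left_shift), multiplier); /* vshl + vqrdmulh (no saturation here) */
  acc = rounding_divide_by_pot(acc, right_shift);        /* the vand/vshr/vqadd/vrshl idiom */
  acc += out_zp;                                         /* destination zero point */
  if (acc < act_min) acc = act_min;                      /* vmax/vmin activation clamp */
  if (acc > act_max) acc = act_max;
  return (int8_t)acc;                                    /* vqmovn narrowing */
}
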
MatmulInt8Neon32Opt:
push {r0-r11, lr}
vpush {q4-q7}
add sp, sp, #116

ldr r4, [sp] // col
ldr r7, [sp, #40] // output stride
mov r8, #0 // output channels offset
ldr r10, [sp, #44]
cmp r10, #0
beq L1
ldr r6, [sp, #8] // load input_sums ptr if per_channel
L1:
cmp r4, #0 // if at the end of col
ble End1

ldr r0, [sp, #-52] // reload a ptr
ldr r3, [sp, #-40] // reset row counter
ldr r6, [sp, #8] // reload input_sums ptr if per_tensor
L2:
cmp r3, #0 // if at the end of row
ble End2

ldr r1, [sp, #-48] // reload b ptr
ldr r5, [sp, #4] // reset deep16
vmov.i32 q6, #0
vmov.i32 q7, #0
vmov.i32 q8, #0
vmov.i32 q9, #0
vmov.i32 q10, #0
vmov.i32 q11, #0
vmov.i32 q12, #0
vmov.i32 q13, #0
L3:
cmp r5, #0
beq End3

vld1.8 {d0, d1, d2, d3}, [r0]!
vld1.8 {d8, d9, d10, d11}, [r1]!
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11

vpadal.s16 q6, q14
vpadal.s16 q7, q2
vpadal.s16 q8, q15
vpadal.s16 q9, q3

vld1.8 {d0, d1, d2, d3}, [r0]!
vmull.s8 q14, d0, d8
vmull.s8 q2, d0, d10
vmull.s8 q15, d2, d8
vmull.s8 q3, d2, d10
vmlal.s8 q14, d1, d9
vmlal.s8 q2, d1, d11
vmlal.s8 q15, d3, d9
vmlal.s8 q3, d3, d11

vpadal.s16 q10, q14
vpadal.s16 q11, q2
vpadal.s16 q12, q15
vpadal.s16 q13, q3
sub r5, r5, #16 // deep16 -= 16
b L3

End3:
vpadd.i32 d0, d12, d13
vpadd.i32 d1, d14, d15
vpadd.i32 d2, d16, d17
vpadd.i32 d3, d18, d19
vpadd.i32 d4, d20, d21
vpadd.i32 d5, d22, d23
vpadd.i32 d6, d24, d25
vpadd.i32 d7, d26, d27

vpadd.i32 d28, d0, d1
vpadd.i32 d29, d2, d3
vpadd.i32 d30, d4, d5
vpadd.i32 d31, d6, d7

// Add weight_bias
ldr r9, [sp, #12] // reload weight_bias ptr
add r9, r9, r8
vld1.32 {d26}, [r9]!
vadd.i32 d28, d28, d26
vadd.i32 d29, d29, d26
vadd.i32 d30, d30, d26
vadd.i32 d31, d31, d26

ldr r10, [sp, #44]
cmp r10, #0
bgt PerChannel

PerTensor:
// Subtract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
vsub.s32 d28, d28, d20
vsub.s32 d29, d29, d21
vsub.s32 d30, d30, d22
vsub.s32 d31, d31, d23

// Apply left shift
ldr r10, [sp, #32]
ldr r11, [r10]!
vdup.32 q9, r11
vshl.s32 q14, q14, q9
vshl.s32 q15, q15, q9

// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
ldr r11, [r10]
vdup.32 q8, r11
vqrdmulh.s32 q14, q14, q8
vqrdmulh.s32 q15, q15, q8

// Apply right shift
ldr r10, [sp, #36]
ldr r11, [r10]
vdup.32 q7, r11
vand q6, q7, q14
vshr.s32 q6, q6, #31
vqadd.s32 q14, q14, q6
vrshl.s32 q14, q14, q7
vand q5, q7, q15
vshr.s32 q5, q5, #31
vqadd.s32 q15, q15, q5
vrshl.s32 q15, q15, q7
b AddDstZP

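The vand/vshr #31/vqadd/vrshl sequence above is the usual rounding divide-by-power-of-two: vrshl with a negative shift adds 2^(n-1) before shifting (round half up), and the sign-dependent fixup turns that into round-to-nearest with ties away from zero, as in gemmlowp's RoundingDivideByPOT. A scalar model, with the shift given as a non-negative count n where the register holds -n:

#include <stdint.h>

static int32_t rounding_divide_by_pot(int32_t x, int n) { /* n >= 0 */
  if (n == 0) return x;
  int32_t fixup = (x < 0) ? -1 : 0;   /* (x & shift) >> 31 keeps only x's sign bit */
  int64_t fixed = (int64_t)x + fixup; /* vqadd (saturation cannot trigger here) */
  return (int32_t)((fixed + ((int64_t)1 << (n - 1))) >> n); /* vrshl by -n */
}
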
PerChannel:
// Subtract input_sums
vld1.32 {d24, d25}, [r6]!
vdup.32 d20, d24[0]
vdup.32 d21, d24[1]
vdup.32 d22, d25[0]
vdup.32 d23, d25[1]
ldr r10, [sp, #48]
vld1.32 {d19}, [r10]
vmul.s32 d24, d20, d19
vmul.s32 d25, d21, d19
vmul.s32 d26, d22, d19
vmul.s32 d27, d23, d19
vsub.s32 d28, d28, d24
vsub.s32 d29, d29, d25
vsub.s32 d30, d30, d26
vsub.s32 d31, d31, d27

// Apply left shift
ldr r10, [sp, #32]
add r10, r10, r8
vld1.32 {d23}, [r10]
vshl.s32 d28, d28, d23
vshl.s32 d29, d29, d23
vshl.s32 d30, d30, d23
vshl.s32 d31, d31, d23

// Apply the fixed-point part of the multiplier
ldr r10, [sp, #28]
add r10, r10, r8
vld1.32 {d22}, [r10]
vqrdmulh.s32 d28, d28, d22
vqrdmulh.s32 d29, d29, d22
vqrdmulh.s32 d30, d30, d22
vqrdmulh.s32 d31, d31, d22

// Apply right shift
ldr r10, [sp, #36]
add r10, r10, r8
vld1.32 {d21}, [r10]
vand d20, d21, d28
vshr.s32 d20, d20, #31
vqadd.s32 d28, d28, d20
vrshl.s32 d28, d28, d21
vand d19, d21, d29
vshr.s32 d19, d19, #31
vqadd.s32 d29, d29, d19
vrshl.s32 d29, d29, d21
vand d18, d21, d30
vshr.s32 d18, d18, #31
vqadd.s32 d30, d30, d18
vrshl.s32 d30, d30, d21
vand d17, d21, d31
vshr.s32 d17, d17, #31
vqadd.s32 d31, d31, d17
vrshl.s32 d31, d31, d21

AddDstZP:
// Add the destination zero point
ldr r10, [sp, #24]
vdup.32 q4, r10
vadd.i32 q14, q14, q4
vadd.i32 q15, q15, q4

// Apply the act_min bound
ldr r10, [sp, #16]
vdup.32 q3, r10
vmax.s32 q14, q14, q3
vmax.s32 q15, q15, q3

// Apply the act_max bound
ldr r10, [sp, #20]
vdup.32 q2, r10
vmin.s32 q14, q14, q2
vmin.s32 q15, q15, q2

// Cast-and-saturate from int32 to int16
vqmovn.s32 d28, q14
vqmovn.s32 d29, q15

// Cast-and-saturate from int16 to int8
vqmovn.s16 d30, q14

// start to write
cmp r4, #2
bge WriteCol2
cmp r4, #1
beq WriteCol1
b EndWrite

WriteCol2:
vst1.16 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.16 {d30[1]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.16 {d30[2]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.16 {d30[3]}, [r2], r7
b EndWrite

WriteCol1:
vst1.8 {d30[0]}, [r2], r7
cmp r3, #1
beq EndWrite
vst1.8 {d30[2]}, [r2], r7
cmp r3, #2
beq EndWrite
vst1.8 {d30[4]}, [r2], r7
cmp r3, #3
beq EndWrite
vst1.8 {d30[6]}, [r2], r7
b EndWrite

EndWrite:
sub r3, r3, #4 // a row counter -= 4
b L2

End2:
sub r4, r4, #2 // b col counter -= 2
ldr r1, [sp, #-48] // load b ptr
ldr r9, [sp, #4]
mov r10, #2
mul r9, r9, r10 // the stride of b
add r1, r1, r9 // b ptr + stride
str r1, [sp, #-48]
ldr r2, [sp, #-44] // load dst ptr
add r2, r2, #2 // dst ptr + offset
str r2, [sp, #-44]
ldr r10, [sp, #48]
add r10, r10, #8
str r10, [sp, #48]
add r8, r8, #8 // output channels offset + 2*sizeof(int)
b L1

End1:
sub sp, sp, #116
vpop {q4-q7}
pop {r0-r11, pc}
#endif
#endif
@@ -1,371 +0,0 @@
#ifdef __aarch64__

.text
.align 5
.global IndirectGemmInt8_4x4
#ifndef __APPLE__
.type IndirectGemmInt8_4x4, %function
#endif

// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
IndirectGemmInt8_4x4:

.macro INIT_BIAS
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
.endm

// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16

ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
ldr x23, [sp, #72]

mul x5, x4, x5
mov x4, #1

LoopOc:

mov x10, x4
mov x12, x1

LoopKsize:
INIT_BIAS
mov x11, x0

// as some processors do not support the sdot instruction, we use the instruction word
// dp support is still judged dynamically; the instruction word is just used to ensure compilation
// according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
// the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
// 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
// mmmmm/nnnnn/ddddd are the NEON register numbers; H/L are the high/low bits of the index

// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
// load weight
ld1 {v4.16b, v5.16b}, [x2], #32
// step for output 1-4
smull v8.8h, v0.8b, v4.8b
smull v9.8h, v0.8b, v5.8b
smlal2 v8.8h, v0.16b, v4.16b
smlal2 v9.8h, v0.16b, v5.16b
// load input for output 9-16
ld1 {v6.16b, v7.16b}, [x2], #32
// another step for output 5-8
smull v12.8h, v1.8b, v4.8b
smull v13.8h, v1.8b, v5.8b
smlal2 v12.8h, v1.16b, v4.16b
smlal2 v13.8h, v1.16b, v5.16b
ld1 {v2.16b, v3.16b}, [x12], #32
smull v10.8h, v0.8b, v6.8b
smull v11.8h, v0.8b, v7.8b
saddlp v16.4s, v8.8h
smlal2 v10.8h, v0.16b, v6.16b
smlal2 v11.8h, v0.16b, v7.16b
saddlp v17.4s, v9.8h
smull v14.8h, v1.8b, v6.8b
smull v15.8h, v1.8b, v7.8b
saddlp v18.4s, v10.8h
smlal2 v14.8h, v1.16b, v6.16b
smlal2 v15.8h, v1.16b, v7.16b

subs x13, x5, #1
beq LoopIcEnd

LoopIc:
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
sadalp v19.4s, v11.8h
smull v8.8h, v2.8b, v4.8b
smull v9.8h, v2.8b, v5.8b
sadalp v20.4s, v12.8h
smlal2 v8.8h, v2.16b, v4.16b
smlal2 v9.8h, v2.16b, v5.16b
sadalp v21.4s, v13.8h
smull v10.8h, v2.8b, v6.8b
smull v11.8h, v2.8b, v7.8b
sadalp v22.4s, v14.8h
smlal2 v10.8h, v2.16b, v6.16b
smlal2 v11.8h, v2.16b, v7.16b
sadalp v23.4s, v15.8h
smull v12.8h, v3.8b, v4.8b
smull v13.8h, v3.8b, v5.8b
sadalp v24.4s, v8.8h
smlal2 v12.8h, v3.16b, v4.16b
smlal2 v13.8h, v3.16b, v5.16b
ld1 {v4.16b, v5.16b}, [x2], #32
sadalp v25.4s, v9.8h
smull v14.8h, v3.8b, v6.8b
smull v15.8h, v3.8b, v7.8b
sadalp v26.4s, v10.8h
smlal2 v14.8h, v3.16b, v6.16b
smlal2 v15.8h, v3.16b, v7.16b
ld1 {v6.16b, v7.16b}, [x2], #32
sadalp v27.4s, v11.8h
smull v8.8h, v0.8b, v4.8b
smull v9.8h, v0.8b, v5.8b
sadalp v28.4s, v12.8h
smlal2 v8.8h, v0.16b, v4.16b
smlal2 v9.8h, v0.16b, v5.16b
ld1 {v2.16b, v3.16b}, [x12], #32
sadalp v29.4s, v13.8h
smull v12.8h, v1.8b, v4.8b
smull v13.8h, v1.8b, v5.8b
sadalp v30.4s, v14.8h
smlal2 v12.8h, v1.16b, v4.16b
smlal2 v13.8h, v1.16b, v5.16b
sadalp v31.4s, v15.8h
smull v10.8h, v0.8b, v6.8b
smull v11.8h, v0.8b, v7.8b
sadalp v16.4s, v8.8h
smlal2 v10.8h, v0.16b, v6.16b
smlal2 v11.8h, v0.16b, v7.16b
sadalp v17.4s, v9.8h
smull v14.8h, v1.8b, v6.8b
smull v15.8h, v1.8b, v7.8b
sadalp v18.4s, v10.8h
smlal2 v14.8h, v1.16b, v6.16b
smlal2 v15.8h, v1.16b, v7.16b

subs x13, x13, #1
bne LoopIc

LoopIcEnd:
sadalp v19.4s, v11.8h
smull v8.8h, v2.8b, v4.8b
smull v9.8h, v2.8b, v5.8b
sadalp v20.4s, v12.8h
smlal2 v8.8h, v2.16b, v4.16b
smlal2 v9.8h, v2.16b, v5.16b
sadalp v21.4s, v13.8h
smull v10.8h, v2.8b, v6.8b
smull v11.8h, v2.8b, v7.8b
sadalp v22.4s, v14.8h
smlal2 v10.8h, v2.16b, v6.16b
smlal2 v11.8h, v2.16b, v7.16b
sadalp v23.4s, v15.8h
smull v12.8h, v3.8b, v4.8b
smull v13.8h, v3.8b, v5.8b
sadalp v24.4s, v8.8h
smlal2 v12.8h, v3.16b, v4.16b
smlal2 v13.8h, v3.16b, v5.16b
sadalp v25.4s, v9.8h
smull v14.8h, v3.8b, v6.8b
smull v15.8h, v3.8b, v7.8b
sadalp v26.4s, v10.8h
smlal2 v14.8h, v3.16b, v6.16b
smlal2 v15.8h, v3.16b, v7.16b
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h

// pairwise add
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
dup v12.4s, wzr
cbz x3, NoReadBias
ld1 {v12.4s}, [x3]
NoReadBias:
addp v16.4s, v16.4s, v18.4s
addp v20.4s, v20.4s, v22.4s
addp v24.4s, v24.4s, v26.4s
addp v28.4s, v28.4s, v30.4s
cbz x20, NoSum
// load sum
mov x22, x15
cbz x21, SymSum
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22]
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22]
AddSum:
sub v16.4s, v16.4s, v8.4s
sub v20.4s, v20.4s, v9.4s
sub v24.4s, v24.4s, v10.4s
sub v28.4s, v28.4s, v11.4s
NoSum:
add v16.4s, v16.4s, v12.4s
add v20.4s, v20.4s, v12.4s
add v24.4s, v24.4s, v12.4s
add v28.4s, v28.4s, v12.4s

cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v16.4s, v16.4s, v2.4s
sqshl v20.4s, v20.4s, v2.4s
sqshl v24.4s, v24.4s, v2.4s
sqshl v28.4s, v28.4s, v2.4s

sqrdmulh v16.4s, v16.4s, v3.4s
sqrdmulh v20.4s, v20.4s, v3.4s
sqrdmulh v24.4s, v24.4s, v3.4s
sqrdmulh v28.4s, v28.4s, v3.4s

and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v4.4s
and v1.16b, v4.16b, v20.16b
sshr v1.4s, v1.4s, #31
sqadd v20.4s, v20.4s, v1.4s
srshl v20.4s, v20.4s, v4.4s
and v2.16b, v4.16b, v24.16b
sshr v2.4s, v2.4s, #31
sqadd v24.4s, v24.4s, v2.4s
srshl v24.4s, v24.4s, v4.4s
and v3.16b, v4.16b, v28.16b
sshr v3.4s, v3.4s, #31
sqadd v28.4s, v28.4s, v3.4s
srshl v28.4s, v28.4s, v4.4s

dup v5.4s, w16
add v16.4s, v16.4s, v5.4s
add v20.4s, v20.4s, v5.4s
add v24.4s, v24.4s, v5.4s
add v28.4s, v28.4s, v5.4s

dup v0.4s, w8
smax v16.4s, v16.4s, v0.4s
smax v20.4s, v20.4s, v0.4s
smax v24.4s, v24.4s, v0.4s
smax v28.4s, v28.4s, v0.4s

dup v1.4s, w9
smin v16.4s, v16.4s, v1.4s
smin v20.4s, v20.4s, v1.4s
smin v24.4s, v24.4s, v1.4s
smin v28.4s, v28.4s, v1.4s

sqxtn v13.4h, v16.4s
sqxtn2 v13.8h, v20.4s
sqxtn v15.8b, v13.8h
sqxtn v14.4h, v24.4s
sqxtn2 v14.8h, v28.4s
sqxtn2 v15.16b, v14.8h

// prefetching is not preferred while writing results, in spite of cache misses
// you could try prfm pstl2strm
WriteStart:
cmp x6, #1
beq Write1
cmp x6, #2
beq Write2
cmp x6, #3
beq Write3
b Write4
Write1:
st1 {v15.b}[0], [x11], x7
st1 {v15.b}[4], [x11], x7
st1 {v15.b}[8], [x11], x7
st1 {v15.b}[12], [x11]
add x0, x0, #1
b WriteEnd
Write2:
st1 {v15.h}[0], [x11], x7
st1 {v15.h}[2], [x11], x7
st1 {v15.h}[4], [x11], x7
st1 {v15.h}[6], [x11]
add x0, x0, #2
b WriteEnd
Write3:
add x14, x11, #2
st1 {v15.h}[0], [x11], x7
st1 {v15.b}[2], [x14], x7
st1 {v15.h}[2], [x11], x7
st1 {v15.b}[6], [x14], x7
st1 {v15.h}[4], [x11], x7
st1 {v15.b}[10], [x14], x7
st1 {v15.h}[6], [x11]
st1 {v15.b}[14], [x14]
add x0, x0, #3
b WriteEnd
Write4:
st1 {v15.s}[0], [x11], x7
st1 {v15.s}[1], [x11], x7
st1 {v15.s}[2], [x11], x7
st1 {v15.s}[3], [x11]
add x0, x0, #4

WriteEnd:

subs x10, x10, #1
bne LoopKsize

subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepForward
add x3, x3, #16
NoStepForward:
bgt LoopOc

sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif
@@ -0,0 +1,408 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulInt8Neon64Opt
#ifndef __APPLE__
.type MatmulInt8Neon64Opt, %function
#endif

//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
// int32_t *right_shift, int row, int col, int stride, int filter_peroc, int32_t *filter_zp)

// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
// w8: act_min
// w9: act_max
// w10: out_zp
// x11: multiplier
// x12: left_shift
// x13: right_shift
// w14: row
// w15: col
// w24: stride
// w27: filter_peroc
// x28: filter_zp

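The L1/L2/L3 labels below form the usual three-level tiling; a control-flow-only sketch (the 4x4 tile body is the assembly itself):

/* Tiling skeleton implied by the register map: w4 counts column blocks
 * (L1, exits at End1), w16 counts row blocks (L2, exits at End2), and
 * w20 counts depth down in steps of 16 (L3, exits at End3). */
static void matmul_int8_tiling_sketch(int row4, int col4, int deep16) {
  for (int c = 0; c < col4; c += 4) {        /* L1 */
    for (int r = 0; r < row4; r += 4) {      /* L2 */
      for (int d = 0; d < deep16; d += 16) { /* L3: accumulate a 4x4 int32 tile */
      }
      /* then add bias, subtract a_sums*Zb, requantize, and store the
       * valid corner of the tile: at most 4x4 int8 values */
    }
  }
}
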
MatmulInt8Neon64Opt:
sub sp, sp, #208
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
stp x27, x28, [sp], #16

ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr x11, [sp, #24]
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
ldr w27, [sp, #72]
ldr x28, [sp, #80]

mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
mov w17, #1
mov x25, x2
L1:
cmp w4, #0 // if at the end of col4
beq End1

mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, #0
beq End2

mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w20, #0
beq End3

ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x17], #16
ld1 {v2.16b}, [x17], #16
ld1 {v3.16b}, [x17], #16
ld1 {v4.16b}, [x18], #16
ld1 {v5.16b}, [x18], #16
ld1 {v6.16b}, [x18], #16
ld1 {v7.16b}, [x18], #16

smull v8.8h, v4.8b, v0.8b
smull v9.8h, v5.8b, v0.8b
smull v10.8h, v6.8b, v0.8b
smull v11.8h, v7.8b, v0.8b
smull v12.8h, v4.8b, v1.8b
smull v13.8h, v5.8b, v1.8b
smull v14.8h, v6.8b, v1.8b
smull v15.8h, v7.8b, v1.8b

smlal2 v8.8h, v4.16b, v0.16b
smlal2 v9.8h, v5.16b, v0.16b
smlal2 v10.8h, v6.16b, v0.16b
smlal2 v11.8h, v7.16b, v0.16b
smlal2 v12.8h, v4.16b, v1.16b
smlal2 v13.8h, v5.16b, v1.16b
smlal2 v14.8h, v6.16b, v1.16b
smlal2 v15.8h, v7.16b, v1.16b

sadalp v16.4s, v8.8h
sadalp v17.4s, v9.8h
sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h
sadalp v20.4s, v12.8h
sadalp v21.4s, v13.8h
sadalp v22.4s, v14.8h
sadalp v23.4s, v15.8h

smull v8.8h, v4.8b, v2.8b
smull v9.8h, v5.8b, v2.8b
smull v10.8h, v6.8b, v2.8b
smull v11.8h, v7.8b, v2.8b
smull v12.8h, v4.8b, v3.8b
smull v13.8h, v5.8b, v3.8b
smull v14.8h, v6.8b, v3.8b
smull v15.8h, v7.8b, v3.8b

smlal2 v8.8h, v4.16b, v2.16b
smlal2 v9.8h, v5.16b, v2.16b
smlal2 v10.8h, v6.16b, v2.16b
smlal2 v11.8h, v7.16b, v2.16b
smlal2 v12.8h, v4.16b, v3.16b
smlal2 v13.8h, v5.16b, v3.16b
smlal2 v14.8h, v6.16b, v3.16b
smlal2 v15.8h, v7.16b, v3.16b

sadalp v24.4s, v8.8h
sadalp v25.4s, v9.8h
sadalp v26.4s, v10.8h
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
subs w20, w20, #16 // depth -= 16
b L3

End3:
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s

addp v16.4s, v16.4s, v18.4s
addp v17.4s, v20.4s, v22.4s
addp v18.4s, v24.4s, v26.4s
addp v19.4s, v28.4s, v30.4s

// Add (Bias+Depth*Za*Zb-Za*Bsums)
ld1 {v15.4s}, [x19], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s

ld1r {v20.4s}, [x22], #4
ld1r {v21.4s}, [x22], #4
ld1r {v22.4s}, [x22], #4
ld1r {v23.4s}, [x22], #4
cmp w27, #0
beq Apply
ld1 {v14.4s}, [x28]
mul v20.4s, v20.4s, v14.4s
mul v21.4s, v21.4s, v14.4s
mul v22.4s, v22.4s, v14.4s
mul v23.4s, v23.4s, v14.4s

Apply:
// Subtract (Asums*Zb)
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s

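These two steps implement the standard zero-point decomposition of an int8 GEMM. For one output with zero points Z_a, Z_b and depth D = deep16:

\sum_{k=1}^{D} (a_k - Z_a)(b_k - Z_b) = \sum_{k} a_k b_k \;-\; Z_b \sum_{k} a_k \;-\; Z_a \sum_{k} b_k \;+\; D\, Z_a Z_b

The bias vector loaded above already folds in bias + D*Za*Zb - Za*Bsums (the "Bias+Depth*Za*Zb-Za*Bsums" comment), so at run time the kernel only subtracts Zb*Asums, i.e. the a_sums values, scaled per output channel by filter_zp when filter_peroc is set.
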
cmp w27, #1
beq PerCLoad

ld1r {v13.4s}, [x12]
ld1r {v12.4s}, [x11]
ld1r {v11.4s}, [x13]
b Quantize

PerCLoad:
ld1 {v13.4s}, [x12]
ld1 {v12.4s}, [x11]
ld1 {v11.4s}, [x13]
Quantize:

// Apply left shift
sqshl v16.4s, v16.4s, v13.4s
sqshl v17.4s, v17.4s, v13.4s
sqshl v18.4s, v18.4s, v13.4s
sqshl v19.4s, v19.4s, v13.4s

// Apply the fixed-point part of the multiplier.
sqrdmulh v16.4s, v16.4s, v12.4s
sqrdmulh v17.4s, v17.4s, v12.4s
sqrdmulh v18.4s, v18.4s, v12.4s
sqrdmulh v19.4s, v19.4s, v12.4s

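sqrdmulh is a saturating rounding doubling multiply returning the high 32 bits, i.e. round(2*a*b / 2^32) with saturation; a scalar model of that arithmetic:

#include <stdint.h>

static int32_t sqrdmulh32(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;        /* the only saturating case */
  return (int32_t)(((int64_t)a * b + ((int64_t)1 << 30)) >> 31); /* add rounding bit, take high part */
}
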
// Apply right shift
and v20.16b, v11.16b, v16.16b
sshr v20.4s, v20.4s, #31
sqadd v16.4s, v16.4s, v20.4s
srshl v16.4s, v16.4s, v11.4s
and v21.16b, v11.16b, v17.16b
sshr v21.4s, v21.4s, #31
sqadd v17.4s, v17.4s, v21.4s
srshl v17.4s, v17.4s, v11.4s
and v22.16b, v11.16b, v18.16b
sshr v22.4s, v22.4s, #31
sqadd v18.4s, v18.4s, v22.4s
srshl v18.4s, v18.4s, v11.4s
and v23.16b, v11.16b, v19.16b
sshr v23.4s, v23.4s, #31
sqadd v19.4s, v19.4s, v23.4s
srshl v19.4s, v19.4s, v11.4s

// Add the destination zero point
dup v10.4s, w10
add v16.4s, v16.4s, v10.4s
add v17.4s, v17.4s, v10.4s
add v18.4s, v18.4s, v10.4s
add v19.4s, v19.4s, v10.4s

// Apply the act_min bound
dup v9.4s, w8
smax v16.4s, v16.4s, v9.4s
smax v17.4s, v17.4s, v9.4s
smax v18.4s, v18.4s, v9.4s
smax v19.4s, v19.4s, v9.4s

// Apply the act_max bound
dup v8.4s, w9
smin v16.4s, v16.4s, v8.4s
smin v17.4s, v17.4s, v8.4s
smin v18.4s, v18.4s, v8.4s
smin v19.4s, v19.4s, v8.4s

// int32 -> int16
sqxtn v13.4h, v16.4s
sqxtn2 v13.8h, v17.4s
sqxtn v14.4h, v18.4s
sqxtn2 v14.8h, v19.4s

// int16 -> int8
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h

cmp w23, #4
blt Write // if rows < 4
cmp w15, #4
blt Write // if cols < 4

st1 {v15.s}[0], [x2], x24
st1 {v15.s}[1], [x2], x24
st1 {v15.s}[2], [x2], x24
st1 {v15.s}[3], [x2], x24
b Endwrite

Write:
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1

WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.s}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.s}[2], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.s}[3], [x2], x24
b Endwrite

WriteCol3:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
st1 {v15.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
st1 {v15.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
st1 {v15.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
st1 {v15.b}[14], [x26], #1
add x2, x2, x24
b Endwrite

WriteCol2:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
add x2, x2, x24
b Endwrite

WriteCol1:
st1 {v15.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.b}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.b}[8], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.b}[12], [x2], x24
b Endwrite

Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2

End2:
sub w4, w4, #4 // b col4 counter - 4
sub w15, w15, #4 // b col counter - 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25

cmp w27, #0
beq PerTEnd2
add x12, x12, #16
add x11, x11, #16
add x13, x13, #16
add x28, x28, #16
PerTEnd2:
b L1

End1:
sub sp, sp, #208
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ldp x27, x28, [sp], #16
ret
#endif
@@ -1,785 +0,0 @@
#ifdef __aarch64__

.text
.align 5
.global IndirectGemmInt8_24x4_dp
#ifndef __APPLE__
.type IndirectGemmInt8_24x4_dp, %function
#endif

// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
// we use the sdot instruction on cores that support dotprod (Armv8.2-A with DP, or later)
// the mrs instruction can read the system register ID_AA64ISAR0_EL1 (or s3_0_c0_c6_0 on Armv8.2-A)
// bits [47:44] (the DP field) indicate whether dotprod is supported
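For reference, a user-space sketch of that capability check. Reading ID_AA64ISAR0_EL1 directly traps at EL0 (Linux traps and emulates it), so the common route is the hwcap bit; this assumes Linux's auxv interface and is not necessarily how MindSpore Lite performs the check:

#include <sys/auxv.h>
#ifndef HWCAP_ASIMDDOTPROD
#define HWCAP_ASIMDDOTPROD (1UL << 20) /* value from <asm/hwcap.h> on AArch64 Linux */
#endif

/* Nonzero when sdot/udot are available, i.e. when ID_AA64ISAR0_EL1
 * bits [47:44] (the DP field) are nonzero. */
static int cpu_supports_dotprod(void) {
  return (getauxval(AT_HWCAP) & HWCAP_ASIMDDOTPROD) != 0;
}
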
IndirectGemmInt8_24x4_dp:

.macro INIT_BIAS
dup v7.4s, wzr
cbz x3, InitBias
ld1 {v7.4s}, [x3]
InitBias:
cbz x20, NoSum
mov x22, x15
cbz x21, SymSum
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22], x23
ld1 {v12.4s}, [x22], x23
ld1 {v13.4s}, [x22], x23
ld1 {v14.4s}, [x22], x23
ld1 {v15.4s}, [x22], x23
ld1 {v16.4s}, [x22], x23
ld1 {v17.4s}, [x22], x23
ld1 {v18.4s}, [x22], x23
ld1 {v19.4s}, [x22], x23
ld1 {v20.4s}, [x22], x23
ld1 {v21.4s}, [x22], x23
ld1 {v22.4s}, [x22], x23
ld1 {v23.4s}, [x22], x23
ld1 {v24.4s}, [x22], x23
ld1 {v25.4s}, [x22], x23
ld1 {v26.4s}, [x22], x23
ld1 {v27.4s}, [x22], x23
ld1 {v28.4s}, [x22], x23
ld1 {v29.4s}, [x22], x23
ld1 {v30.4s}, [x22], x23
ld1 {v31.4s}, [x22], x23
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22], #4
ld1r {v12.4s}, [x22], #4
ld1r {v13.4s}, [x22], #4
ld1r {v14.4s}, [x22], #4
ld1r {v15.4s}, [x22], #4
ld1r {v16.4s}, [x22], #4
ld1r {v17.4s}, [x22], #4
ld1r {v18.4s}, [x22], #4
ld1r {v19.4s}, [x22], #4
ld1r {v20.4s}, [x22], #4
ld1r {v21.4s}, [x22], #4
ld1r {v22.4s}, [x22], #4
ld1r {v23.4s}, [x22], #4
ld1r {v24.4s}, [x22], #4
ld1r {v25.4s}, [x22], #4
ld1r {v26.4s}, [x22], #4
ld1r {v27.4s}, [x22], #4
ld1r {v28.4s}, [x22], #4
ld1r {v29.4s}, [x22], #4
ld1r {v30.4s}, [x22], #4
ld1r {v31.4s}, [x22], #4
AddSum:
sub v8.4s, v7.4s, v8.4s
sub v9.4s, v7.4s, v9.4s
sub v10.4s, v7.4s, v10.4s
sub v11.4s, v7.4s, v11.4s
sub v12.4s, v7.4s, v12.4s
sub v13.4s, v7.4s, v13.4s
sub v14.4s, v7.4s, v14.4s
sub v15.4s, v7.4s, v15.4s
sub v16.4s, v7.4s, v16.4s
sub v17.4s, v7.4s, v17.4s
sub v18.4s, v7.4s, v18.4s
sub v19.4s, v7.4s, v19.4s
sub v20.4s, v7.4s, v20.4s
sub v21.4s, v7.4s, v21.4s
sub v22.4s, v7.4s, v22.4s
sub v23.4s, v7.4s, v23.4s
sub v24.4s, v7.4s, v24.4s
sub v25.4s, v7.4s, v25.4s
sub v26.4s, v7.4s, v26.4s
sub v27.4s, v7.4s, v27.4s
sub v28.4s, v7.4s, v28.4s
sub v29.4s, v7.4s, v29.4s
sub v30.4s, v7.4s, v30.4s
sub v31.4s, v7.4s, v31.4s
b InitBiasEnd
NoSum:
mov v8.16b, v7.16b
mov v9.16b, v7.16b
mov v10.16b, v7.16b
mov v11.16b, v7.16b
mov v12.16b, v7.16b
mov v13.16b, v7.16b
mov v14.16b, v7.16b
mov v15.16b, v7.16b
mov v16.16b, v7.16b
mov v17.16b, v7.16b
mov v18.16b, v7.16b
mov v19.16b, v7.16b
mov v20.16b, v7.16b
mov v21.16b, v7.16b
mov v22.16b, v7.16b
mov v23.16b, v7.16b
mov v24.16b, v7.16b
mov v25.16b, v7.16b
mov v26.16b, v7.16b
mov v27.16b, v7.16b
mov v28.16b, v7.16b
mov v29.16b, v7.16b
mov v30.16b, v7.16b
mov v31.16b, v7.16b
InitBiasEnd:
.endm
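In scalar terms the macro starts every accumulator at bias and, in the asymmetric case, folds in -input_sum; a simplified model for one 4-wide lane (it ignores the per-channel ld1-with-stride versus broadcast ld1r distinction):

#include <stdint.h>

static void init_bias_lane(int32_t acc[4], const int32_t *bias,
                           const int32_t *input_sum, int asymmetric) {
  for (int i = 0; i < 4; ++i) {
    int32_t b = (bias != 0) ? bias[i] : 0;                      /* cbz x3, InitBias */
    acc[i] = (asymmetric && input_sum) ? b - input_sum[i] : b;  /* AddSum vs NoSum */
  }
}
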

// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16

ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
ldr x23, [sp, #72]

mul x5, x4, x5
mov x4, #1

LoopOc:

mov x10, x4
mov x12, x1

LoopKsize:
INIT_BIAS
mov x11, x0

// as some processors do not support the sdot instruction, we use the instruction word
// dp support is still judged dynamically; the instruction word is just used to ensure compilation
// according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
// the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
// 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
// mmmmm/nnnnn/ddddd are the NEON register numbers; H/L are the high/low bits of the index

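As a cross-check on that encoding, a small encoder that reproduces the .inst words used below, with the lane index split into H (high bit) and L (low bit):

#include <stdint.h>
#include <stdio.h>

/* Builds "sdot vd.4s, vn.16b, vm.4b[index]" from the quoted bit pattern
 * 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd. */
static uint32_t encode_sdot_by_elem(unsigned d, unsigned n, unsigned m, unsigned idx) {
  uint32_t L = idx & 1u, H = (idx >> 1) & 1u;
  return 0x4F80E000u | (L << 21) | (m << 16) | (H << 11) | (n << 5) | d;
}

int main(void) {
  printf("0x%08x\n", encode_sdot_by_elem(8, 6, 0, 0)); /* 0x4f80e0c8, the first word below */
  printf("0x%08x\n", encode_sdot_by_elem(9, 6, 0, 1)); /* 0x4fa0e0c9 */
  return 0;
}
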
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
// load weight
ld1 {v6.16b}, [x2], #16
// step for output 1-4
.inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0]
.inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1]
.inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2]
.inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3]
// load input for output 9-16
ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x12], #64
// another step for output 5-8
.inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0]
.inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1]
.inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2]
.inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3]

subs x13, x5, #1
beq LoopIcEndOne
// load weight
ld1 {v7.16b}, [x2], #16
cmp x13, #1
beq LoopIcEnd

LoopIc:
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]
ld1 {v2.16b, v3.16b}, [x12], #32
.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32
.inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0]
.inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1]
.inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2]
.inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3]
// another step for output 5-8
.inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0]
.inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1]
.inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2]
.inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3]
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0]
.inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1]
.inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2]
.inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3]
.inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0]
.inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1]
.inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2]
.inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3]
// load weight
ld1 {v6.16b}, [x2], #16
.inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0]
.inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1]
.inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2]
.inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3]
.inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0]
.inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1]
.inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2]
.inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3]
// load input for output 9-16
ld1 {v2.4s, v3.4s}, [x12], #32
.inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0]
.inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1]
.inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2]
.inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3]
// another step for output 5-8
.inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0]
.inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1]
.inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2]
.inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32

subs x13, x13, #2
beq LoopIcEndOne
// load weight
ld1 {v7.16b}, [x2], #16
cmp x13, #1
beq LoopIcEnd
b LoopIc

LoopIcEnd:
mov x22, x15
// load input for output 1-8
ld1 {v0.16b, v1.16b}, [x12], #32
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]
ld1 {v2.16b, v3.16b}, [x12], #32
.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]
// load input for output 9-16
ld1 {v4.4s, v5.4s}, [x12], #32
.inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0]
.inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1]
.inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2]
.inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3]
.inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0]
.inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1]
.inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2]
.inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3]

.inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0]
.inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1]
.inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2]
.inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3]
.inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0]
.inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1]
.inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2]
.inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3]

.inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0]
.inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1]
.inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2]
.inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3]
.inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0]
.inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1]
.inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2]
.inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3]
b Quantization

LoopIcEndOne:
.inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0]
.inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1]
.inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2]
.inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3]
.inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0]
.inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1]
.inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2]
.inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3]

.inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0]
.inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1]
.inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2]
.inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3]
.inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0]
.inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1]
.inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2]
.inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3]

Quantization:
|
||||
cbnz x21, PerChannel
|
||||
ld1r {v2.4s}, [x18]
|
||||
ld1r {v3.4s}, [x17]
|
||||
ld1r {v4.4s}, [x19]
|
||||
b QuantizeStart
|
||||
PerChannel:
|
||||
ld1 {v2.4s}, [x18]
|
||||
ld1 {v3.4s}, [x17]
|
||||
ld1 {v4.4s}, [x19]
|
||||
QuantizeStart:
|
||||
sqshl v8.4s, v8.4s, v2.4s
|
||||
sqshl v9.4s, v9.4s, v2.4s
|
||||
sqshl v10.4s, v10.4s, v2.4s
|
||||
sqshl v11.4s, v11.4s, v2.4s
|
||||
sqshl v12.4s, v12.4s, v2.4s
|
||||
sqshl v13.4s, v13.4s, v2.4s
|
||||
sqshl v14.4s, v14.4s, v2.4s
|
||||
sqshl v15.4s, v15.4s, v2.4s
|
||||
sqshl v16.4s, v16.4s, v2.4s
|
||||
sqshl v17.4s, v17.4s, v2.4s
|
||||
sqshl v18.4s, v18.4s, v2.4s
|
||||
sqshl v19.4s, v19.4s, v2.4s
|
||||
sqshl v20.4s, v20.4s, v2.4s
|
||||
sqshl v21.4s, v21.4s, v2.4s
|
||||
sqshl v22.4s, v22.4s, v2.4s
|
||||
sqshl v23.4s, v23.4s, v2.4s
|
||||
sqshl v24.4s, v24.4s, v2.4s
|
||||
sqshl v25.4s, v25.4s, v2.4s
|
||||
sqshl v26.4s, v26.4s, v2.4s
|
||||
sqshl v27.4s, v27.4s, v2.4s
|
||||
sqshl v28.4s, v28.4s, v2.4s
|
||||
sqshl v29.4s, v29.4s, v2.4s
|
||||
sqshl v30.4s, v30.4s, v2.4s
|
||||
sqshl v31.4s, v31.4s, v2.4s
|
||||
|
||||
sqrdmulh v8.4s, v8.4s, v3.4s
|
||||
sqrdmulh v9.4s, v9.4s, v3.4s
|
||||
sqrdmulh v10.4s, v10.4s, v3.4s
|
||||
sqrdmulh v11.4s, v11.4s, v3.4s
|
||||
sqrdmulh v12.4s, v12.4s, v3.4s
|
||||
sqrdmulh v13.4s, v13.4s, v3.4s
|
||||
sqrdmulh v14.4s, v14.4s, v3.4s
|
||||
sqrdmulh v15.4s, v15.4s, v3.4s
|
||||
sqrdmulh v16.4s, v16.4s, v3.4s
|
||||
sqrdmulh v17.4s, v17.4s, v3.4s
|
||||
sqrdmulh v18.4s, v18.4s, v3.4s
|
||||
sqrdmulh v19.4s, v19.4s, v3.4s
|
||||
sqrdmulh v20.4s, v20.4s, v3.4s
|
||||
sqrdmulh v21.4s, v21.4s, v3.4s
|
||||
sqrdmulh v22.4s, v22.4s, v3.4s
|
||||
sqrdmulh v23.4s, v23.4s, v3.4s
|
||||
sqrdmulh v24.4s, v24.4s, v3.4s
|
||||
sqrdmulh v25.4s, v25.4s, v3.4s
|
||||
sqrdmulh v26.4s, v26.4s, v3.4s
|
||||
sqrdmulh v27.4s, v27.4s, v3.4s
|
||||
sqrdmulh v28.4s, v28.4s, v3.4s
|
||||
sqrdmulh v29.4s, v29.4s, v3.4s
|
||||
sqrdmulh v30.4s, v30.4s, v3.4s
|
||||
sqrdmulh v31.4s, v31.4s, v3.4s
|
||||
|
||||
and v0.16b, v4.16b, v8.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v8.4s, v8.4s, v0.4s
|
||||
srshl v8.4s, v8.4s, v4.4s
|
||||
and v1.16b, v4.16b, v9.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v9.4s, v9.4s, v1.4s
|
||||
srshl v9.4s, v9.4s, v4.4s
|
||||
and v2.16b, v4.16b, v10.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v10.4s, v10.4s, v2.4s
|
||||
srshl v10.4s, v10.4s, v4.4s
|
||||
and v3.16b, v4.16b, v11.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v11.4s, v11.4s, v3.4s
|
||||
srshl v11.4s, v11.4s, v4.4s
|
||||
and v0.16b, v4.16b, v12.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v12.4s, v12.4s, v0.4s
|
||||
srshl v12.4s, v12.4s, v4.4s
|
||||
and v1.16b, v4.16b, v13.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v13.4s, v13.4s, v1.4s
|
||||
srshl v13.4s, v13.4s, v4.4s
|
||||
and v2.16b, v4.16b, v14.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v14.4s, v14.4s, v2.4s
|
||||
srshl v14.4s, v14.4s, v4.4s
|
||||
and v3.16b, v4.16b, v15.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v15.4s, v15.4s, v3.4s
|
||||
srshl v15.4s, v15.4s, v4.4s
|
||||
and v0.16b, v4.16b, v16.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v16.4s, v16.4s, v0.4s
|
||||
srshl v16.4s, v16.4s, v4.4s
|
||||
and v1.16b, v4.16b, v17.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v17.4s, v17.4s, v1.4s
|
||||
srshl v17.4s, v17.4s, v4.4s
|
||||
and v2.16b, v4.16b, v18.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v18.4s, v18.4s, v2.4s
|
||||
srshl v18.4s, v18.4s, v4.4s
|
||||
and v3.16b, v4.16b, v19.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v19.4s, v19.4s, v3.4s
|
||||
srshl v19.4s, v19.4s, v4.4s
|
||||
and v0.16b, v4.16b, v20.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v20.4s, v20.4s, v0.4s
|
||||
srshl v20.4s, v20.4s, v4.4s
|
||||
and v1.16b, v4.16b, v21.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v21.4s, v21.4s, v1.4s
|
||||
srshl v21.4s, v21.4s, v4.4s
|
||||
and v2.16b, v4.16b, v22.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v22.4s, v22.4s, v2.4s
|
||||
srshl v22.4s, v22.4s, v4.4s
|
||||
and v3.16b, v4.16b, v23.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v23.4s, v23.4s, v3.4s
|
||||
srshl v23.4s, v23.4s, v4.4s
|
||||
and v0.16b, v4.16b, v24.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v24.4s, v24.4s, v0.4s
|
||||
srshl v24.4s, v24.4s, v4.4s
|
||||
and v1.16b, v4.16b, v25.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v25.4s, v25.4s, v1.4s
|
||||
srshl v25.4s, v25.4s, v4.4s
|
||||
and v2.16b, v4.16b, v26.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v26.4s, v26.4s, v2.4s
|
||||
srshl v26.4s, v26.4s, v4.4s
|
||||
and v3.16b, v4.16b, v27.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v27.4s, v27.4s, v3.4s
|
||||
srshl v27.4s, v27.4s, v4.4s
|
||||
and v0.16b, v4.16b, v28.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v28.4s, v28.4s, v0.4s
|
||||
srshl v28.4s, v28.4s, v4.4s
|
||||
and v1.16b, v4.16b, v29.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v29.4s, v29.4s, v1.4s
|
||||
srshl v29.4s, v29.4s, v4.4s
|
||||
and v2.16b, v4.16b, v30.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v30.4s, v30.4s, v2.4s
|
||||
srshl v30.4s, v30.4s, v4.4s
|
||||
and v3.16b, v4.16b, v31.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v31.4s, v31.4s, v3.4s
|
||||
srshl v31.4s, v31.4s, v4.4s
|
||||
|
||||
dup v5.4s, w16
|
||||
add v8.4s, v8.4s, v5.4s
|
||||
add v9.4s, v9.4s, v5.4s
|
||||
add v10.4s, v10.4s, v5.4s
|
||||
add v11.4s, v11.4s, v5.4s
|
||||
add v12.4s, v12.4s, v5.4s
|
||||
add v13.4s, v13.4s, v5.4s
|
||||
add v14.4s, v14.4s, v5.4s
|
||||
add v15.4s, v15.4s, v5.4s
|
||||
add v16.4s, v16.4s, v5.4s
|
||||
add v17.4s, v17.4s, v5.4s
|
||||
add v18.4s, v18.4s, v5.4s
|
||||
add v19.4s, v19.4s, v5.4s
|
||||
add v20.4s, v20.4s, v5.4s
|
||||
add v21.4s, v21.4s, v5.4s
|
||||
add v22.4s, v22.4s, v5.4s
|
||||
add v23.4s, v23.4s, v5.4s
|
||||
add v24.4s, v24.4s, v5.4s
|
||||
add v25.4s, v25.4s, v5.4s
|
||||
add v26.4s, v26.4s, v5.4s
|
||||
add v27.4s, v27.4s, v5.4s
|
||||
add v28.4s, v28.4s, v5.4s
|
||||
add v29.4s, v29.4s, v5.4s
|
||||
add v30.4s, v30.4s, v5.4s
|
||||
add v31.4s, v31.4s, v5.4s
|
||||
|
||||
dup v0.4s, w8
|
||||
smax v8.4s, v8.4s, v0.4s
|
||||
smax v9.4s, v9.4s, v0.4s
|
||||
smax v10.4s, v10.4s, v0.4s
|
||||
smax v11.4s, v11.4s, v0.4s
|
||||
smax v12.4s, v12.4s, v0.4s
|
||||
smax v13.4s, v13.4s, v0.4s
|
||||
smax v14.4s, v14.4s, v0.4s
|
||||
smax v15.4s, v15.4s, v0.4s
|
||||
smax v16.4s, v16.4s, v0.4s
|
||||
smax v17.4s, v17.4s, v0.4s
|
||||
smax v18.4s, v18.4s, v0.4s
|
||||
smax v19.4s, v19.4s, v0.4s
|
||||
smax v20.4s, v20.4s, v0.4s
|
||||
smax v21.4s, v21.4s, v0.4s
|
||||
smax v22.4s, v22.4s, v0.4s
|
||||
smax v23.4s, v23.4s, v0.4s
|
||||
smax v24.4s, v24.4s, v0.4s
|
||||
smax v25.4s, v25.4s, v0.4s
|
||||
smax v26.4s, v26.4s, v0.4s
|
||||
smax v27.4s, v27.4s, v0.4s
|
||||
smax v28.4s, v28.4s, v0.4s
|
||||
smax v29.4s, v29.4s, v0.4s
|
||||
smax v30.4s, v30.4s, v0.4s
|
||||
smax v31.4s, v31.4s, v0.4s
|
||||
|
||||
dup v1.4s, w9
|
||||
smin v8.4s, v8.4s, v1.4s
|
||||
smin v9.4s, v9.4s, v1.4s
|
||||
smin v10.4s, v10.4s, v1.4s
|
||||
smin v11.4s, v11.4s, v1.4s
|
||||
smin v12.4s, v12.4s, v1.4s
|
||||
smin v13.4s, v13.4s, v1.4s
|
||||
smin v14.4s, v14.4s, v1.4s
|
||||
smin v15.4s, v15.4s, v1.4s
|
||||
smin v16.4s, v16.4s, v1.4s
|
||||
smin v17.4s, v17.4s, v1.4s
|
||||
smin v18.4s, v18.4s, v1.4s
|
||||
smin v19.4s, v19.4s, v1.4s
|
||||
smin v20.4s, v20.4s, v1.4s
|
||||
smin v21.4s, v21.4s, v1.4s
|
||||
smin v22.4s, v22.4s, v1.4s
|
||||
smin v23.4s, v23.4s, v1.4s
|
||||
smin v24.4s, v24.4s, v1.4s
|
||||
smin v25.4s, v25.4s, v1.4s
|
||||
smin v26.4s, v26.4s, v1.4s
|
||||
smin v27.4s, v27.4s, v1.4s
|
||||
smin v28.4s, v28.4s, v1.4s
|
||||
smin v29.4s, v29.4s, v1.4s
|
||||
smin v30.4s, v30.4s, v1.4s
|
||||
smin v31.4s, v31.4s, v1.4s
|
||||
|
||||
sqxtn v6.4h, v8.4s
|
||||
sqxtn2 v6.8h, v9.4s
|
||||
sqxtn v0.8b, v6.8h
|
||||
sqxtn v7.4h, v10.4s
|
||||
sqxtn2 v7.8h, v11.4s
|
||||
sqxtn2 v0.16b, v7.8h
|
||||
|
||||
sqxtn v6.4h, v12.4s
|
||||
sqxtn2 v6.8h, v13.4s
|
||||
sqxtn v1.8b, v6.8h
|
||||
sqxtn v7.4h, v14.4s
|
||||
sqxtn2 v7.8h, v15.4s
|
||||
sqxtn2 v1.16b, v7.8h
|
||||
|
||||
sqxtn v6.4h, v16.4s
|
||||
sqxtn2 v6.8h, v17.4s
|
||||
sqxtn v2.8b, v6.8h
|
||||
sqxtn v7.4h, v18.4s
|
||||
sqxtn2 v7.8h, v19.4s
|
||||
sqxtn2 v2.16b, v7.8h
|
||||
|
||||
sqxtn v6.4h, v20.4s
|
||||
sqxtn2 v6.8h, v21.4s
|
||||
sqxtn v3.8b, v6.8h
|
||||
sqxtn v7.4h, v22.4s
|
||||
sqxtn2 v7.8h, v23.4s
|
||||
sqxtn2 v3.16b, v7.8h
|
||||
|
||||
sqxtn v6.4h, v24.4s
|
||||
sqxtn2 v6.8h, v25.4s
|
||||
sqxtn v4.8b, v6.8h
|
||||
sqxtn v7.4h, v26.4s
|
||||
sqxtn2 v7.8h, v27.4s
|
||||
sqxtn2 v4.16b, v7.8h
|
||||
|
||||
sqxtn v6.4h, v28.4s
|
||||
sqxtn2 v6.8h, v29.4s
|
||||
sqxtn v5.8b, v6.8h
|
||||
sqxtn v7.4h, v30.4s
|
||||
sqxtn2 v7.8h, v31.4s
|
||||
sqxtn2 v5.16b, v7.8h
|
||||
// prefetching is not prefered while writing results in spite of cache missings
|
||||
// you could try prfm pstl2strm
|
||||
WriteStart:
|
||||
cmp x6, #1
|
||||
beq Write1
|
||||
cmp x6, #2
|
||||
beq Write2
|
||||
cmp x6, #3
|
||||
beq Write3
|
||||
b Write4
|
||||
Write1:
|
||||
st1 {v0.b}[0], [x11], x7
|
||||
st1 {v0.b}[4], [x11], x7
|
||||
st1 {v0.b}[8], [x11], x7
|
||||
st1 {v0.b}[12], [x11], x7
|
||||
st1 {v1.b}[0], [x11], x7
|
||||
st1 {v1.b}[4], [x11], x7
|
||||
st1 {v1.b}[8], [x11], x7
|
||||
st1 {v1.b}[12], [x11], x7
|
||||
st1 {v2.b}[0], [x11], x7
|
||||
st1 {v2.b}[4], [x11], x7
|
||||
st1 {v2.b}[8], [x11], x7
|
||||
st1 {v2.b}[12], [x11], x7
|
||||
st1 {v3.b}[0], [x11], x7
|
||||
st1 {v3.b}[4], [x11], x7
|
||||
st1 {v3.b}[8], [x11], x7
|
||||
st1 {v3.b}[12], [x11], x7
|
||||
st1 {v4.b}[0], [x11], x7
|
||||
st1 {v4.b}[4], [x11], x7
|
||||
st1 {v4.b}[8], [x11], x7
|
||||
st1 {v4.b}[12], [x11], x7
|
||||
st1 {v5.b}[0], [x11], x7
|
||||
st1 {v5.b}[4], [x11], x7
|
||||
st1 {v5.b}[8], [x11], x7
|
||||
st1 {v5.b}[12], [x11]
|
||||
add x0, x0, #1
|
||||
b WriteEnd
|
||||
Write2:
|
||||
st1 {v0.h}[0], [x11], x7
|
||||
st1 {v0.h}[2], [x11], x7
|
||||
st1 {v0.h}[4], [x11], x7
|
||||
st1 {v0.h}[6], [x11], x7
|
||||
st1 {v1.h}[0], [x11], x7
|
||||
st1 {v1.h}[2], [x11], x7
|
||||
st1 {v1.h}[4], [x11], x7
|
||||
st1 {v1.h}[6], [x11], x7
|
||||
st1 {v2.h}[0], [x11], x7
|
||||
st1 {v2.h}[2], [x11], x7
|
||||
st1 {v2.h}[4], [x11], x7
|
||||
st1 {v2.h}[6], [x11], x7
|
||||
st1 {v3.h}[0], [x11], x7
|
||||
st1 {v3.h}[2], [x11], x7
|
||||
st1 {v3.h}[4], [x11], x7
|
||||
st1 {v3.h}[6], [x11], x7
|
||||
st1 {v4.h}[0], [x11], x7
|
||||
st1 {v4.h}[2], [x11], x7
|
||||
st1 {v4.h}[4], [x11], x7
|
||||
st1 {v4.h}[6], [x11], x7
|
||||
st1 {v5.h}[0], [x11], x7
|
||||
st1 {v5.h}[2], [x11], x7
|
||||
st1 {v5.h}[4], [x11], x7
|
||||
st1 {v5.h}[6], [x11]
|
||||
add x0, x0, #2
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x14, x11, #2
|
||||
st1 {v0.h}[0], [x11], x7
|
||||
st1 {v0.b}[2], [x14], x7
|
||||
st1 {v0.h}[2], [x11], x7
|
||||
st1 {v0.b}[6], [x14], x7
|
||||
st1 {v0.h}[4], [x11], x7
|
||||
st1 {v0.b}[10], [x14], x7
|
||||
st1 {v0.h}[6], [x11], x7
|
||||
st1 {v0.b}[14], [x14], x7
|
||||
st1 {v1.h}[0], [x11], x7
|
||||
st1 {v1.b}[2], [x14], x7
|
||||
st1 {v1.h}[2], [x11], x7
|
||||
st1 {v1.b}[6], [x14], x7
|
||||
st1 {v1.h}[4], [x11], x7
|
||||
st1 {v1.b}[10], [x14], x7
|
||||
st1 {v1.h}[6], [x11], x7
|
||||
st1 {v1.b}[14], [x14], x7
|
||||
st1 {v2.h}[0], [x11], x7
|
||||
st1 {v2.b}[2], [x14], x7
|
||||
st1 {v2.h}[2], [x11], x7
|
||||
st1 {v2.b}[6], [x14], x7
|
||||
st1 {v2.h}[4], [x11], x7
|
||||
st1 {v2.b}[10], [x14], x7
|
||||
st1 {v2.h}[6], [x11], x7
|
||||
st1 {v2.b}[14], [x14], x7
|
||||
st1 {v3.h}[0], [x11], x7
|
||||
st1 {v3.b}[2], [x14], x7
|
||||
st1 {v3.h}[2], [x11], x7
|
||||
st1 {v3.b}[6], [x14], x7
|
||||
st1 {v3.h}[4], [x11], x7
|
||||
st1 {v3.b}[10], [x14], x7
|
||||
st1 {v3.h}[6], [x11], x7
|
||||
st1 {v3.b}[14], [x14], x7
|
||||
st1 {v4.h}[0], [x11], x7
|
||||
st1 {v4.b}[2], [x14], x7
|
||||
st1 {v4.h}[2], [x11], x7
|
||||
st1 {v4.b}[6], [x14], x7
|
||||
st1 {v4.h}[4], [x11], x7
|
||||
st1 {v4.b}[10], [x14], x7
|
||||
st1 {v4.h}[6], [x11], x7
|
||||
st1 {v4.b}[14], [x14], x7
|
||||
st1 {v5.h}[0], [x11], x7
|
||||
st1 {v5.b}[2], [x14], x7
|
||||
st1 {v5.h}[2], [x11], x7
|
||||
st1 {v5.b}[6], [x14], x7
|
||||
st1 {v5.h}[4], [x11], x7
|
||||
st1 {v5.b}[10], [x14], x7
|
||||
st1 {v5.h}[6], [x11], x7
|
||||
st1 {v5.b}[14], [x14], x7
|
||||
add x0, x0, #3
|
||||
b WriteEnd
|
||||
Write4:
    st1 {v0.s}[0], [x11], x7
    st1 {v0.s}[1], [x11], x7
    st1 {v0.s}[2], [x11], x7
    st1 {v0.s}[3], [x11], x7
    st1 {v1.s}[0], [x11], x7
    st1 {v1.s}[1], [x11], x7
    st1 {v1.s}[2], [x11], x7
    st1 {v1.s}[3], [x11], x7
    st1 {v2.s}[0], [x11], x7
    st1 {v2.s}[1], [x11], x7
    st1 {v2.s}[2], [x11], x7
    st1 {v2.s}[3], [x11], x7
    st1 {v3.s}[0], [x11], x7
    st1 {v3.s}[1], [x11], x7
    st1 {v3.s}[2], [x11], x7
    st1 {v3.s}[3], [x11], x7
    st1 {v4.s}[0], [x11], x7
    st1 {v4.s}[1], [x11], x7
    st1 {v4.s}[2], [x11], x7
    st1 {v4.s}[3], [x11], x7
    st1 {v5.s}[0], [x11], x7
    st1 {v5.s}[1], [x11], x7
    st1 {v5.s}[2], [x11], x7
    st1 {v5.s}[3], [x11]
    add x0, x0, #4

WriteEnd:
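    // one output tile stored; x10 counts the remaining blocks of LoopKsize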

    subs x10, x10, #1
    bne LoopKsize

    subs x6, x6, #4
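    // four output channels finished; on the per-channel path, step the quant parameter pointers by 16 bytes (4 x int32)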
    cbz x21, NoChannelForward
    cbz x20, NoSumForward
    add x15, x15, #16
NoSumForward:
    add x17, x17, #16
    add x18, x18, #16
    add x19, x19, #16
NoChannelForward:
    cbz x3, NoStepForward
    add x3, x3, #16
NoStepForward:
    bgt LoopOc

    sub sp, sp, #176
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ldp x19, x20, [sp], #16
    ldp x21, x22, [sp], #16
    ldp x23, x24, [sp], #16
    ret
#endif

@@ -62,10 +62,6 @@ int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int3
#endif

#ifdef ENABLE_ARM32
void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
                          size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
                          int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                              int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
                              int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
@@ -76,10 +72,6 @@ void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *wei
void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
                          size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
                          int32_t zp, int32_t mini, int32_t maxi);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
                          size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
                          size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
                          int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
                         int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
                         int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,

@@ -811,38 +811,25 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int
  return;
}

void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                      const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                      int32_t *multiplier, ConvParameter *conv_param) {
  int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1;
#ifdef ENABLE_ARM32
  MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                   conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                   conv_param->output_channel_, is_per_channel);
#else
  MatMulInt8_4x2_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
                   left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                   conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                   is_per_channel);
#endif
  return;
}

void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param) {
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
  int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
#ifdef ENABLE_ARM64
  MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
                   bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col,
                   conv_param->output_channel_, is_per_oc);
#ifdef ENABLE_ARM32
  MatmulInt8Neon32Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                      conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                      conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                      conv_param->output_channel_, is_per_oc, filter_zp);
#elif defined(ENABLE_ARM64)
  MatmulInt8Neon64Opt(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
                      bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                      conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row,
                      col, conv_param->output_channel_, is_per_oc, filter_zp);
#else
  MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
                    left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                    conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                    is_per_oc);
  MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                conv_param->output_channel_, is_per_oc, filter_zp);
#endif
  return;
}
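With this change every backend path of Conv1x1Int8 receives filter_zp, so the per-output-channel zero-point correction is applied inside the matmul kernels themselves instead of being folded into a per-channel copy of the input sums.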

@@ -43,13 +43,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
                       size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                 int32_t *multiplier, ConvParameter *conv_param);
                 int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
                      const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                      int32_t *multiplier, ConvParameter *conv_param);

// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,

@@ -250,6 +250,41 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
  return;
}

#ifndef ENABLE_ARM
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                   const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) {
  int col_tile = C4NUM;
  /* supports per-tensor and per-channel weight quantization */
  /* row4x16-major * row16x4-major => (int8)row-major */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int r4div = r / C4NUM, r4mod = r % C4NUM;
      int c4div = c / col_tile, c4mod = c % col_tile;
      size_t ci = r * stride + c;
      int32_t value = 0;
      for (int d = 0; d < deep16; d++) {
        int d16div = d / C16NUM, d16mod = d % C16NUM;
        size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
        size_t bi = c4div * deep16 * col_tile + d16div * col_tile * C16NUM + c4mod * C16NUM + d16mod;
        value = value + a[ai] * b[bi];
      }
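      /* subtract the input zero-point contribution: with per-channel weights the cached
         row sum is scaled by this column's filter zero point */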
      int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r];
      value -= cur_input_sum;
      value += bias[c];
      int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0];
      int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0];
      int32_t cur_multiplier = filter_peroc ? multiplier[c] : multiplier[0];
      value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + out_zp;
      value = MSMIN(maxi, value);
      value = MSMAX(mini, value);
      dst[ci] = (int8_t)value;
    }
  }
  return;
}
#endif
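
For orientation, a minimal driver for this scalar fallback could look as follows. The buffer shapes follow the row4-by-deep16 and deep16-by-col4 packing implied by the index math above; every quantization value here is made up for illustration, and the snippet is not part of the commit.

#include <stddef.h>
#include <stdint.h>

void MatmulInt8OptExample(void) {
  static int8_t a_packed[4 * 16];   /* one input tile, row4-major, deep16 = 16 */
  static int8_t b_packed[16 * 4];   /* one weight tile, col4-major */
  static int8_t dst[4 * 4];
  static int a_sums[4];             /* per-row input sums, pre-multiplied by the filter zero point */
  static int bias[4];
  int32_t multiplier = 1073741824;  /* illustrative Q31 requantization multiplier (~0.5) */
  int32_t left_shift = 0, right_shift = 0;

  MatmulInt8Opt(a_packed, b_packed, dst, /*row=*/4, /*col=*/4, /*deep16=*/16, a_sums, bias,
                /*mini=*/-128, /*maxi=*/127, /*out_zp=*/0, &multiplier, &left_shift, &right_shift,
                /*stride=*/4, /*filter_peroc=*/0, /*filter_zp=*/NULL); /* filter_zp is unused on the per-tensor path */
}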

void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                      size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                      int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
@@ -60,6 +60,9 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
                       size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                       int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
                       size_t per_channel, int32_t *filter_zp);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                   const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);

#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
@@ -68,11 +71,18 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i

void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
                        const int *input_sum, const int *bias);
void MatmulInt8Neon64Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
                         const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
                         int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, int filter_peroc,
                         int32_t *filter_zp);
#endif
#ifdef ENABLE_ARM32
void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
                      const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
                      int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                         const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
                         int32_t *left_shift, int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
#endif
#ifdef __cplusplus
}

@@ -92,27 +92,19 @@ int Convolution1x1Int8OcOptPre(void *cdata, int task_id) {
}

int Convolution1x1Int8CPUKernel::OcRun(int task_id) {
#ifdef ENABLE_ARM32
  return RunArm32Oc(task_id);
#else
  if (support_optimize_) {
    return RunArm64OptOc(task_id);
  } else {
    return RunArm64Oc(task_id);
    return RunArmOc(task_id);
  }
#endif
}

int Convolution1x1Int8CPUKernel::HwRun(int task_id) {
#ifdef ENABLE_ARM32
  return RunArm32Hw(task_id);
#else
  if (support_optimize_) {
    return RunArm64OptHw(task_id);
  } else {
    return RunArm64Hw(task_id);
    return RunArmHw(task_id);
  }
#endif
}

int Convolution1x1Int8CPUKernel::InitRunBuf() {
@@ -124,6 +116,7 @@ int Convolution1x1Int8CPUKernel::InitRunBuf() {

  size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
                                  : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);

  packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
@@ -333,8 +326,8 @@ int Convolution1x1Int8CPUKernel::InitParam() {
  matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
  matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM);

  int row_pack_count = 0;
  int col_pack_count = 0;
  int row_pack_count;
  int col_pack_count;

#ifdef ENABLE_ARM32
  row_pack_count = C4NUM;
@@ -350,15 +343,7 @@ int Convolution1x1Int8CPUKernel::InitParam() {
#endif

  /* init input sum size */
  if (support_optimize_) {
    input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
  } else {
    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
      input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
    } else {
      input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
    }
  }
  input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
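  // per-channel zero points are now applied per column inside the matmul kernels via filter_zp,
  // so a single row-sized input-sum buffer suffices on every path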

  if (pre_trans_input_) {
    input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
@@ -404,7 +389,7 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
  return;
}

int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
int Convolution1x1Int8CPUKernel::RunArmHw(int task_id) {
  int cur_stride = thread_stride_hw_ * C4NUM;
  int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
  int cur_hw = MSMIN(cur_stride, res_stride);
@@ -415,51 +400,20 @@ int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
  int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
  int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
  int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
  int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_4_
                                        : input_sum_ + task_id * thread_stride_hw_ * C4NUM;
  int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_hw_ * C4NUM;

  RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);

  if (filter_peroc_) {
    PackInputSum16x4PerChannel(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, matmul_param_->deep_,
                               matmul_param_->col_);
    PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, 1, UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
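    // a zero point of 1 packs the plain row sums; the per-channel filter zero points are
    // multiplied in per column inside the matmul kernel via filter_zp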
  } else {
    PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
                             UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
  }

  Conv1x1Int8(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
              matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
  return RET_OK;
}

int Convolution1x1Int8CPUKernel::RunArm32Hw(int task_id) {
  int cur_stride = thread_stride_hw_ * C4NUM;
  int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
  int cur_hw = MSMIN(cur_stride, res_stride);
  if (cur_hw <= 0) {
    return RET_OK;
  }

  int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
  int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
  int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
  int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_2_
                                        : input_sum_ + task_id * thread_stride_hw_ * C4NUM;

  RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);

  if (filter_peroc_) {
    PackInputSum16x4PerChannelArm32(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, conv_param_->input_channel_,
                                    conv_param_->output_channel_);
  } else {
    PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
                             UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
  }

  Conv1x1Int8Arm32(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
                   matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);

                   matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_,
                   filter_zp_ptr_);
  return RET_OK;
}

@@ -489,26 +443,6 @@ int Convolution1x1Int8CPUKernel::RunArm64OptHw(int task_id) {
  return RET_OK;
}

int Convolution1x1Int8CPUKernel::RunArm32Oc(int task_id) {
  int stride = thread_stride_oc_ * C2NUM;
  int cur_stride = task_id * stride;
  int res_stride = matmul_param_->col_ - cur_stride;
  int cur_oc = MSMIN(stride, res_stride);
  if (cur_oc <= 0) {
    return RET_OK;
  }

  int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
  int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
  int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
  int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;

  Conv1x1Int8Arm32(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
                   cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
                   matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
  return RET_OK;
}

int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
  int stride = thread_stride_oc_ * C16NUM;
  int cur_stride = task_id * stride;
@@ -531,8 +465,13 @@ int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
  return RET_OK;
}

int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
  int stride = thread_stride_oc_ * C4NUM;
int Convolution1x1Int8CPUKernel::RunArmOc(int task_id) {
#ifdef ENABLE_ARM32
  int col_tile = C2NUM;
#else
  int col_tile = C4NUM;
#endif
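  // col_tile mirrors the matmul kernel's output tile: two columns for the Arm32 kernel, four for Arm64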
  int stride = thread_stride_oc_ * col_tile;
  int cur_stride = task_id * stride;
  int res_stride = matmul_param_->col_ - cur_stride;
  int cur_oc = MSMIN(stride, res_stride);
@@ -540,14 +479,14 @@ int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
  return RET_OK;
}

  int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
  int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
  int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
  int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
  int32_t *cur_zp = filter_peroc_ ? filter_zp_ptr_ + cur_stride : filter_zp_ptr_;

  Conv1x1Int8(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
              cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
              matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
              input_sum_, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
              matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, cur_zp);

  return RET_OK;
}

@@ -592,7 +531,12 @@ int Convolution1x1Int8CPUKernel::Run() {
    ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcOptPre, this, thread_count_hw_);
  } else {
    RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
    PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
    if (filter_peroc_) {
      PackInputSum16x4PerLayer(packed_input_, input_sum_, 1, matmul_param_->row_4_, matmul_param_->deep_16_);
    } else {
      PackInputSum16x4PerLayer(packed_input_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
                               matmul_param_->row_4_, matmul_param_->deep_16_);
    }
  }
  /* matmul parallel by oc */
  error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcRun, this, thread_count_oc_);

@@ -50,11 +50,9 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
  int OcOptPre(int task_id);

 private:
  int RunArm32Oc(int task_id);
  int RunArm64Oc(int task_id);
  int RunArmOc(int task_id);
  int RunArm64OptOc(int task_id);
  int RunArm32Hw(int task_id);
  int RunArm64Hw(int task_id);
  int RunArmHw(int task_id);
  int RunArm64OptHw(int task_id);

 private:

@@ -21,11 +21,6 @@
#ifdef __cplusplus
extern "C" {
#endif
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                     size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                     const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
                                     int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
                                     size_t asymmetric, size_t per_channel, size_t per_channel_offset);

extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                  const int *input_sum, const int *bias);

@@ -38,15 +33,6 @@ extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_
                            int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);

#ifdef ENABLE_ARM64
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
                                       size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                       const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
                                       int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
                                       size_t asymmetric, size_t per_channel, size_t per_channel_offset) {
  return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
                                  act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel,
                                  per_channel_offset);
}

void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias) {