forked from mindspore-Ecosystem/mindspore
change fp32 kernel to 12x8
This commit is contained in:
parent
befc209480
commit
811940bc55
|
@ -0,0 +1,784 @@
|
|||
#ifdef __aarch64__
|
||||
.text
|
||||
.align 5
|
||||
.global MatmulFloatNeon64Opt
|
||||
#ifndef __APPLE__
|
||||
.type MatmulFloatNeon64Opt, %function
|
||||
#endif
|
||||
|
||||
// A: LM [row_8 * depth] col_8_major
|
||||
// B: RM [depth * col_8] row_8_major
|
||||
// C: A*B [row_8 * col_8] col_8x8_major
|
||||
// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8]
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//CommLoopMul RM 1x8 block
|
||||
// /-----------------------------------------\
|
||||
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
|
||||
// \-----------------------------------------/
|
||||
// LM 8x1 block
|
||||
// /---------------------\ /-----------------------------------------\
|
||||
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
|
||||
// | ... | | ... ... |
|
||||
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
|
||||
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
|
||||
// | ... | | ... ... |
|
||||
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
|
||||
// \---------------------/ \-----------------------------------------/
|
||||
// accumulators 8x8 block
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//OptLoopMul4 RM 4x8 block
|
||||
// /--------------------------------------------\
|
||||
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
|
||||
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
|
||||
// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]|
|
||||
// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]|
|
||||
// \--------------------------------------------/
|
||||
// LM 8x4 block
|
||||
// /---------------------------------\ /--------------------------------------------\
|
||||
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
|
||||
// | ... ... ... ... | | ... ... |
|
||||
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
|
||||
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
|
||||
// | ... ... ... ... | | ... ... |
|
||||
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
|
||||
// \---------------------------------/ \--------------------------------------------/
|
||||
// accumulators 8x8 block
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
|
||||
// int row, int col, int stride, bool write_nhwc)
|
||||
// x0: a
|
||||
// x1: b
|
||||
// x2: c
|
||||
// x3: bias
|
||||
// w4: act_type
|
||||
// w5: depth
|
||||
// w6: row
|
||||
// w7: col
|
||||
// w17: stride
|
||||
// w13: writeC8
|
||||
|
||||
MatmulFloatNeon64Opt:
|
||||
sub sp, sp, #128
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
|
||||
mov w18, #32 // sizeof(float) * 8
|
||||
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
|
||||
mov x11, x3 // bias flag
|
||||
mov x18, #4
|
||||
ldr x17, [sp]
|
||||
mul x17, x17, x18
|
||||
|
||||
L1:
|
||||
mov w10, w6 // reload lhs row
|
||||
mov x12, x0 // reload lhs ptr
|
||||
mov x18, x2 // reload dst ptr
|
||||
|
||||
L2:
|
||||
mov x16, x1 // reload rhs ptr
|
||||
mov w13, w5 // reload depth
|
||||
mov x14, x3 // reload bias ptr
|
||||
dup v8.4s, wzr
|
||||
dup v9.4s, wzr
|
||||
dup v10.4s, wzr
|
||||
dup v11.4s, wzr
|
||||
dup v12.4s, wzr
|
||||
dup v13.4s, wzr
|
||||
dup v14.4s, wzr
|
||||
dup v15.4s, wzr
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
dup v18.4s, wzr
|
||||
dup v19.4s, wzr
|
||||
dup v20.4s, wzr
|
||||
dup v21.4s, wzr
|
||||
dup v22.4s, wzr
|
||||
dup v23.4s, wzr
|
||||
dup v24.4s, wzr
|
||||
dup v25.4s, wzr
|
||||
dup v26.4s, wzr
|
||||
dup v27.4s, wzr
|
||||
dup v28.4s, wzr
|
||||
dup v29.4s, wzr
|
||||
dup v30.4s, wzr
|
||||
dup v31.4s, wzr
|
||||
|
||||
LoopStart:
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
|
||||
ld1 {v3.4s, v4.4s}, [x16], #32
|
||||
fmla v8.4s, v3.4s, v0.s[0]
|
||||
fmla v10.4s, v3.4s, v0.s[1]
|
||||
fmla v12.4s, v3.4s, v0.s[2]
|
||||
fmla v14.4s, v3.4s, v0.s[3]
|
||||
fmla v9.4s, v4.4s, v0.s[0]
|
||||
fmla v11.4s, v4.4s, v0.s[1]
|
||||
fmla v13.4s, v4.4s, v0.s[2]
|
||||
fmla v15.4s, v4.4s, v0.s[3]
|
||||
|
||||
subs w13, w13, #1
|
||||
beq LoopEnd
|
||||
|
||||
Loop:
|
||||
ld1 {v0.4s}, [x12], #16
|
||||
fmla v16.4s, v3.4s, v1.s[0]
|
||||
fmla v18.4s, v3.4s, v1.s[1]
|
||||
fmla v20.4s, v3.4s, v1.s[2]
|
||||
fmla v22.4s, v3.4s, v1.s[3]
|
||||
fmla v17.4s, v4.4s, v1.s[0]
|
||||
fmla v19.4s, v4.4s, v1.s[1]
|
||||
fmla v21.4s, v4.4s, v1.s[2]
|
||||
fmla v23.4s, v4.4s, v1.s[3]
|
||||
ld1 {v1.4s}, [x12], #16
|
||||
fmla v24.4s, v3.4s, v2.s[0]
|
||||
fmla v26.4s, v3.4s, v2.s[1]
|
||||
fmla v28.4s, v3.4s, v2.s[2]
|
||||
fmla v30.4s, v3.4s, v2.s[3]
|
||||
ld1 {v3.4s}, [x16], #16
|
||||
fmla v25.4s, v4.4s, v2.s[0]
|
||||
fmla v27.4s, v4.4s, v2.s[1]
|
||||
fmla v29.4s, v4.4s, v2.s[2]
|
||||
fmla v31.4s, v4.4s, v2.s[3]
|
||||
ld1 {v4.4s}, [x16], #16
|
||||
fmla v8.4s, v3.4s, v0.s[0]
|
||||
fmla v10.4s, v3.4s, v0.s[1]
|
||||
fmla v12.4s, v3.4s, v0.s[2]
|
||||
fmla v14.4s, v3.4s, v0.s[3]
|
||||
ld1 {v2.4s}, [x12], #16
|
||||
fmla v9.4s, v4.4s, v0.s[0]
|
||||
fmla v11.4s, v4.4s, v0.s[1]
|
||||
fmla v13.4s, v4.4s, v0.s[2]
|
||||
fmla v15.4s, v4.4s, v0.s[3]
|
||||
|
||||
subs w13, w13, #1
|
||||
bgt Loop
|
||||
|
||||
LoopEnd:
|
||||
fmla v16.4s, v3.4s, v1.s[0]
|
||||
fmla v18.4s, v3.4s, v1.s[1]
|
||||
fmla v20.4s, v3.4s, v1.s[2]
|
||||
fmla v22.4s, v3.4s, v1.s[3]
|
||||
fmla v17.4s, v4.4s, v1.s[0]
|
||||
fmla v19.4s, v4.4s, v1.s[1]
|
||||
fmla v21.4s, v4.4s, v1.s[2]
|
||||
fmla v23.4s, v4.4s, v1.s[3]
|
||||
fmla v24.4s, v3.4s, v2.s[0]
|
||||
fmla v26.4s, v3.4s, v2.s[1]
|
||||
fmla v28.4s, v3.4s, v2.s[2]
|
||||
fmla v30.4s, v3.4s, v2.s[3]
|
||||
fmla v25.4s, v4.4s, v2.s[0]
|
||||
fmla v27.4s, v4.4s, v2.s[1]
|
||||
fmla v29.4s, v4.4s, v2.s[2]
|
||||
fmla v31.4s, v4.4s, v2.s[3]
|
||||
|
||||
Bias:
|
||||
cbz x11, Activation
|
||||
ld1 {v0.4s}, [x14], #16
|
||||
ld1 {v1.4s}, [x14], #16
|
||||
fadd v8.4s, v8.4s, v0.4s
|
||||
fadd v9.4s, v9.4s, v1.4s
|
||||
fadd v10.4s, v10.4s, v0.4s
|
||||
fadd v11.4s, v11.4s, v1.4s
|
||||
fadd v12.4s, v12.4s, v0.4s
|
||||
fadd v13.4s, v13.4s, v1.4s
|
||||
fadd v14.4s, v14.4s, v0.4s
|
||||
fadd v15.4s, v15.4s, v1.4s
|
||||
fadd v16.4s, v16.4s, v0.4s
|
||||
fadd v17.4s, v17.4s, v1.4s
|
||||
fadd v18.4s, v18.4s, v0.4s
|
||||
fadd v19.4s, v19.4s, v1.4s
|
||||
fadd v20.4s, v20.4s, v0.4s
|
||||
fadd v21.4s, v21.4s, v1.4s
|
||||
fadd v22.4s, v22.4s, v0.4s
|
||||
fadd v23.4s, v23.4s, v1.4s
|
||||
fadd v24.4s, v24.4s, v0.4s
|
||||
fadd v25.4s, v25.4s, v1.4s
|
||||
fadd v26.4s, v26.4s, v0.4s
|
||||
fadd v27.4s, v27.4s, v1.4s
|
||||
fadd v28.4s, v28.4s, v0.4s
|
||||
fadd v29.4s, v29.4s, v1.4s
|
||||
fadd v30.4s, v30.4s, v0.4s
|
||||
fadd v31.4s, v31.4s, v1.4s
|
||||
|
||||
Activation:
|
||||
cmp w4, #2
|
||||
beq Relu6
|
||||
cmp w4, #1
|
||||
beq Relu
|
||||
b Write
|
||||
|
||||
Relu6:
|
||||
mov w8, #6
|
||||
dup v2.4s, w8
|
||||
scvtf v2.4s, v2.4s
|
||||
fmin v8.4s, v8.4s, v2.4s
|
||||
fmin v9.4s, v9.4s, v2.4s
|
||||
fmin v10.4s, v10.4s, v2.4s
|
||||
fmin v11.4s, v11.4s, v2.4s
|
||||
fmin v12.4s, v12.4s, v2.4s
|
||||
fmin v13.4s, v13.4s, v2.4s
|
||||
fmin v14.4s, v14.4s, v2.4s
|
||||
fmin v15.4s, v15.4s, v2.4s
|
||||
fmin v16.4s, v16.4s, v2.4s
|
||||
fmin v17.4s, v17.4s, v2.4s
|
||||
fmin v18.4s, v18.4s, v2.4s
|
||||
fmin v19.4s, v19.4s, v2.4s
|
||||
fmin v20.4s, v20.4s, v2.4s
|
||||
fmin v21.4s, v21.4s, v2.4s
|
||||
fmin v22.4s, v22.4s, v2.4s
|
||||
fmin v23.4s, v23.4s, v2.4s
|
||||
fmin v24.4s, v24.4s, v2.4s
|
||||
fmin v25.4s, v25.4s, v2.4s
|
||||
fmin v26.4s, v26.4s, v2.4s
|
||||
fmin v27.4s, v27.4s, v2.4s
|
||||
fmin v28.4s, v28.4s, v2.4s
|
||||
fmin v29.4s, v29.4s, v2.4s
|
||||
fmin v30.4s, v30.4s, v2.4s
|
||||
fmin v31.4s, v31.4s, v2.4s
|
||||
|
||||
Relu:
|
||||
dup v3.4s, wzr
|
||||
fmax v8.4s, v8.4s, v3.4s
|
||||
fmax v9.4s, v9.4s, v3.4s
|
||||
fmax v10.4s, v10.4s, v3.4s
|
||||
fmax v11.4s, v11.4s, v3.4s
|
||||
fmax v12.4s, v12.4s, v3.4s
|
||||
fmax v13.4s, v13.4s, v3.4s
|
||||
fmax v14.4s, v14.4s, v3.4s
|
||||
fmax v15.4s, v15.4s, v3.4s
|
||||
fmax v16.4s, v16.4s, v3.4s
|
||||
fmax v17.4s, v17.4s, v3.4s
|
||||
fmax v18.4s, v18.4s, v3.4s
|
||||
fmax v19.4s, v19.4s, v3.4s
|
||||
fmax v20.4s, v20.4s, v3.4s
|
||||
fmax v21.4s, v21.4s, v3.4s
|
||||
fmax v22.4s, v22.4s, v3.4s
|
||||
fmax v23.4s, v23.4s, v3.4s
|
||||
fmax v24.4s, v24.4s, v3.4s
|
||||
fmax v25.4s, v25.4s, v3.4s
|
||||
fmax v26.4s, v26.4s, v3.4s
|
||||
fmax v27.4s, v27.4s, v3.4s
|
||||
fmax v28.4s, v28.4s, v3.4s
|
||||
fmax v29.4s, v29.4s, v3.4s
|
||||
fmax v30.4s, v30.4s, v3.4s
|
||||
fmax v31.4s, v31.4s, v3.4s
|
||||
|
||||
Write:
|
||||
ldrb w13, [sp, #8]
|
||||
cbz w13, WriteC8
|
||||
cmp w7, #1
|
||||
beq Write1
|
||||
cmp w7, #2
|
||||
beq Write2
|
||||
cmp w7, #3
|
||||
beq Write3
|
||||
cmp w7, #4
|
||||
beq Write4
|
||||
cmp w7, #5
|
||||
beq Write5
|
||||
cmp w7, #6
|
||||
beq Write6
|
||||
cmp w7, #7
|
||||
beq Write7
|
||||
b Write8
|
||||
|
||||
Write1:
|
||||
str s8, [x18]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s10, [x18]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s12, [x18]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s14, [x18]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s16, [x18]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s18, [x18]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s20, [x18]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s22, [x18]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s24, [x18]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s26, [x18]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s28, [x18]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s30, [x18]
|
||||
add x18, x18, x17
|
||||
b WriteEnd
|
||||
Write2:
|
||||
dup s9, v8.s[1]
|
||||
stp s8, s9, [x18]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s11, v10.s[1]
|
||||
stp s10, s11, [x18]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s13, v12.s[1]
|
||||
stp s12, s13, [x18]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s15, v14.s[1]
|
||||
stp s14, s15, [x18]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s17, v16.s[1]
|
||||
stp s16, s17, [x18]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s19, v18.s[1]
|
||||
stp s18, s19, [x18]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s21, v20.s[1]
|
||||
stp s20, s21, [x18]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s23, v22.s[1]
|
||||
stp s22, s23, [x18]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s25, v24.s[1]
|
||||
stp s24, s25, [x18]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s27, v26.s[1]
|
||||
stp s26, s27, [x18]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s29, v28.s[1]
|
||||
stp s28, s29, [x18]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s31, v30.s[1]
|
||||
stp s30, s31, [x18]
|
||||
add x18, x18, x17
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x13, x18, #8
|
||||
dup s9, v8.s[1]
|
||||
stp s8, s9, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v8.s}[2], [x13], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
dup s11, v10.s[1]
|
||||
stp s10, s11, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v10.s}[2], [x13], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
dup s13, v12.s[1]
|
||||
stp s12, s13, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v12.s}[2], [x13], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
dup s15, v14.s[1]
|
||||
stp s14, s15, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v14.s}[2], [x13], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
dup s17, v16.s[1]
|
||||
stp s16, s17, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v16.s}[2], [x13], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
dup s19, v18.s[1]
|
||||
stp s18, s19, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v18.s}[2], [x13], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
dup s21, v20.s[1]
|
||||
stp s20, s21, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v20.s}[2], [x13], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
dup s23, v22.s[1]
|
||||
stp s22, s23, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v22.s}[2], [x13], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
dup s25, v24.s[1]
|
||||
stp s24, s25, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v24.s}[2], [x13], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
dup s27, v26.s[1]
|
||||
stp s26, s27, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v26.s}[2], [x13], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
dup s29, v28.s[1]
|
||||
stp s28, s29, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v28.s}[2], [x13], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
dup s31, v30.s[1]
|
||||
stp s30, s31, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v30.s}[2], [x13]
|
||||
b WriteEnd
|
||||
Write4:
|
||||
st1 {v8.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s}, [x18], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s}, [x18], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s}, [x18], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s}, [x18], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s}, [x18], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s}, [x18], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s}, [x18], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s}, [x18], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s}, [x18], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s}, [x18], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s}, [x18], x17
|
||||
b WriteEnd
|
||||
Write5:
|
||||
add x13, x18, #16
|
||||
st1 {v8.4s}, [x18], x17
|
||||
str s9, [x13]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v10.4s}, [x18], x17
|
||||
str s11, [x13]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v12.4s}, [x18], x17
|
||||
str s13, [x13]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v14.4s}, [x18], x17
|
||||
str s15, [x13]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v16.4s}, [x18], x17
|
||||
str s17, [x13]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v18.4s}, [x18], x17
|
||||
str s19, [x13]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v20.4s}, [x18], x17
|
||||
str s21, [x13]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v22.4s}, [x18], x17
|
||||
str s23, [x13]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v24.4s}, [x18], x17
|
||||
str s25, [x13]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v26.4s}, [x18], x17
|
||||
str s27, [x13]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v28.4s}, [x18], x17
|
||||
str s29, [x13]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v30.4s}, [x18], x17
|
||||
str s31, [x13]
|
||||
b WriteEnd
|
||||
Write6:
|
||||
add x13, x18, #16
|
||||
st1 {v8.4s}, [x18], x17
|
||||
dup s8, v9.s[1]
|
||||
stp s9, s8, [x13]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v10.4s}, [x18], x17
|
||||
dup s10, v11.s[1]
|
||||
stp s11, s10, [x13]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v12.4s}, [x18], x17
|
||||
dup s12, v13.s[1]
|
||||
stp s13, s12, [x13]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v14.4s}, [x18], x17
|
||||
dup s14, v15.s[1]
|
||||
stp s15, s14, [x13]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v16.4s}, [x18], x17
|
||||
dup s16, v17.s[1]
|
||||
stp s17, s16, [x13]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v18.4s}, [x18], x17
|
||||
dup s18, v19.s[1]
|
||||
stp s19, s18, [x13]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v20.4s}, [x18], x17
|
||||
dup s20, v21.s[1]
|
||||
stp s21, s20, [x13]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v22.4s}, [x18], x17
|
||||
dup s22, v23.s[1]
|
||||
stp s23, s22, [x13]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v24.4s}, [x18], x17
|
||||
dup s24, v25.s[1]
|
||||
stp s25, s24, [x13]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v26.4s}, [x18], x17
|
||||
dup s26, v27.s[1]
|
||||
stp s27, s26, [x13]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v28.4s}, [x18], x17
|
||||
dup s28, v29.s[1]
|
||||
stp s29, s28, [x13]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v30.4s}, [x18], x17
|
||||
dup s30, v31.s[1]
|
||||
stp s31, s30, [x13]
|
||||
b WriteEnd
|
||||
Write7:
|
||||
add x13, x18, #16
|
||||
add x16, x18, #24
|
||||
st1 {v8.4s}, [x18], x17
|
||||
dup s8, v9.s[1]
|
||||
stp s9, s8, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v9.s}[2], [x16], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s}, [x18], x17
|
||||
dup s10, v11.s[1]
|
||||
stp s11, s10, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v11.s}[2], [x16], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s}, [x18], x17
|
||||
dup s12, v13.s[1]
|
||||
stp s13, s12, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v13.s}[2], [x16], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s}, [x18], x17
|
||||
dup s14, v15.s[1]
|
||||
stp s15, s14, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v15.s}[2], [x16], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s}, [x18], x17
|
||||
dup s16, v17.s[1]
|
||||
stp s17, s16, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v17.s}[2], [x16], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s}, [x18], x17
|
||||
dup s18, v19.s[1]
|
||||
stp s19, s18, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v19.s}[2], [x16], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s}, [x18], x17
|
||||
dup s20, v21.s[1]
|
||||
stp s21, s20, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v21.s}[2], [x16], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s}, [x18], x17
|
||||
dup s22, v23.s[1]
|
||||
stp s23, s22, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v23.s}[2], [x16], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s}, [x18], x17
|
||||
dup s24, v25.s[1]
|
||||
stp s25, s24, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v25.s}[2], [x16], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s}, [x18], x17
|
||||
dup s26, v27.s[1]
|
||||
stp s27, s26, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v27.s}[2], [x16], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s}, [x18], x17
|
||||
dup s28, v29.s[1]
|
||||
stp s29, s28, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v29.s}[2], [x16], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s}, [x18], x17
|
||||
dup s30, v31.s[1]
|
||||
stp s31, s30, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v31.s}[2], [x16], x17
|
||||
b WriteEnd
|
||||
WriteC8:
|
||||
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x2], #64
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
|
||||
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
|
||||
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
|
||||
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
|
||||
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
|
||||
b WriteEnd
|
||||
Write8:
|
||||
st1 {v8.4s, v9.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s, v11.4s}, [x18], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s, v13.4s}, [x18], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s, v15.4s}, [x18], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s, v17.4s}, [x18], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s, v19.4s}, [x18], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s, v21.4s}, [x18], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s, v23.4s}, [x18], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s, v25.4s}, [x18], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s, v27.4s}, [x18], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s, v29.4s}, [x18], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s, v31.4s}, [x18], x17
|
||||
|
||||
WriteEnd:
|
||||
subs w10, w10, #12 // lhs row - 12
|
||||
bgt L2
|
||||
|
||||
End2:
|
||||
subs w7, w7, #8 // rhs col - 8
|
||||
add x1, x1, x15 // rhs ptr + stride
|
||||
add x3, x3, #32 // bias ptr + stride
|
||||
ldrb w13, [sp, #8]
|
||||
cbz w13, NoDstStep
|
||||
add x2, x2, #32 // dst ptr + stride
|
||||
NoDstStep:
|
||||
bgt L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #128
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ret
|
||||
#endif
|
|
@ -28,6 +28,108 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
|
|||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row12 = row / C12NUM * C12NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
float *src_r = src_ptr;
|
||||
float *dst_r = dst_ptr;
|
||||
|
||||
size_t ri = 0;
|
||||
for (; ri < row12; ri += C12NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col4; ci += C4NUM) {
|
||||
float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C12NUM;
|
||||
|
||||
/* 12x4 row-major to col-major */
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t stride = col * sizeof(float);
|
||||
asm volatile(
|
||||
"mov x10, %[src_c]\n"
|
||||
"mov x11, %[dst_c]\n"
|
||||
|
||||
"ld1 {v0.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v1.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v2.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v3.4s}, [x10], %[stride]\n"
|
||||
|
||||
"ld1 {v4.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v5.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v6.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v7.4s}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v12.4s, v0.4s, v1.4s\n"
|
||||
"zip2 v13.4s, v0.4s, v1.4s\n"
|
||||
"zip1 v14.4s, v2.4s, v3.4s\n"
|
||||
"zip2 v15.4s, v2.4s, v3.4s\n"
|
||||
|
||||
"ld1 {v8.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v9.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v10.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v11.4s}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v16.4s, v4.4s, v5.4s\n"
|
||||
"zip2 v17.4s, v4.4s, v5.4s\n"
|
||||
"zip1 v18.4s, v6.4s, v7.4s\n"
|
||||
"zip2 v19.4s, v6.4s, v7.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v23.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v26.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v29.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"trn1 v21.2d, v16.2d, v18.2d\n"
|
||||
"trn2 v24.2d, v16.2d, v18.2d\n"
|
||||
"trn1 v27.2d, v17.2d, v19.2d\n"
|
||||
"trn2 v30.2d, v17.2d, v19.2d\n"
|
||||
|
||||
"zip1 v12.4s, v8.4s, v9.4s\n"
|
||||
"zip2 v13.4s, v8.4s, v9.4s\n"
|
||||
"zip1 v14.4s, v10.4s, v11.4s\n"
|
||||
"zip2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v25.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v28.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v31.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
|
||||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
|
||||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
|
||||
|
||||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
||||
"v30", "v31");
|
||||
#else
|
||||
for (int tr = 0; tr < C12NUM; tr++) {
|
||||
for (int tc = 0; tc < C4NUM; tc++) {
|
||||
dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C12NUM;
|
||||
for (size_t i = 0; i < C12NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C12NUM * col;
|
||||
dst_r += C12NUM * col;
|
||||
}
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C12NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
dst_r += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row8 = row / C8NUM * C8NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
|
@ -267,6 +369,31 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
|
|||
return;
|
||||
}
|
||||
|
||||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, bool write_nhwc) {
|
||||
if (write_nhwc) {
|
||||
/* col8-major * row8-major => col-major */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r12div = r / 12, r12mod = r % 12;
|
||||
int c8div = c / 8, c8mod = c % 8;
|
||||
size_t ci = r * stride + c;
|
||||
float value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
if (bias != NULL) value += bias[c];
|
||||
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
|
||||
if (act_type != ActType_No) value = MSMAX(0.0f, value);
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
|
||||
int stride, bool write_nhwc) {
|
||||
#ifdef ENABLE_ARM64
|
||||
|
@ -275,3 +402,12 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
|
|||
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, bool write_nhwc) {
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
|
||||
#else
|
||||
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -28,12 +28,17 @@ extern "C" {
|
|||
#endif
|
||||
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
|
||||
int stride, bool write_nhwc);
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
|
||||
int col, int stride, bool write_nhwc);
|
||||
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, bool write_nhwc);
|
||||
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, bool write_nhwc);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ typedef struct MatMulParameter {
|
|||
int row_;
|
||||
int col_;
|
||||
int row_8_;
|
||||
int row_12_;
|
||||
int row_16_;
|
||||
int col_8_;
|
||||
int deep_;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#define C4NUM 4
|
||||
#define C8NUM 8
|
||||
#define C12NUM 12
|
||||
#define C16NUM 16
|
||||
#define BLOCK 4
|
||||
#define TILE_NUM 8
|
||||
|
|
|
@ -59,7 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
|
|||
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
|
||||
matmul_param_->col_ = conv_param_->output_channel_;
|
||||
matmul_param_->deep_ = conv_param_->input_channel_;
|
||||
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
|
||||
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
|
||||
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
|
||||
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
|
||||
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
|
||||
|
@ -100,12 +100,12 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
|
|||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
|
||||
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
|
||||
|
||||
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
|
||||
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
|
||||
if (pack_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
|
||||
memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
|
|||
input_ptr_ = src_input;
|
||||
}
|
||||
|
||||
RowMajor2Col8Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -143,7 +143,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
|
|||
|
||||
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
|
||||
|
||||
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
|
||||
matmul_param_->row_, cur_oc, matmul_param_->col_, true);
|
||||
|
||||
|
|
Loading…
Reference in New Issue