forked from mindspore-Ecosystem/mindspore
[MSLITE][Develop]deconv winograd input pack and output bias
This commit is contained in:
parent
36a8013b0a
commit
7d97c1b903
|
@ -0,0 +1,248 @@
|
|||
|
||||
.text
|
||||
.align 5
|
||||
//.p2align 5,,15
|
||||
.global PostFuncBiasReluC4
|
||||
#ifndef __APPLE__
|
||||
.type PostFuncBiasReluC4, %function
|
||||
#endif
|
||||
|
||||
// void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
//                         size_t plane_size, size_t plane_stride, size_t relu_type);
//
// AArch32 / NEON implementation.
// Reads src four floats (one channel group) at a time, adds the matching four
// bias values, optionally applies an activation and scatters the result to dst.
// Consecutive plane positions in dst are (oc4div + oc4mod) floats apart.
//   relu_type == 3 : result clamped to [0, 6]
//   relu_type == 1 : result clamped to [0, +inf)
//   otherwise      : no activation
// After the plane_size vectors of one full channel group have been consumed,
// src is advanced by plane_stride bytes.
//
// Arguments
//   r0 dst          r1 src          r2 bias         r3 oc4div
//   [sp]    oc4mod       -> r4
//   [sp+4]  plane_size   -> r5
//   [sp+8]  plane_stride -> r6   (bytes)
//   [sp+12] relu_type    -> r7
//
// Register usage
//   q0 ~ q3   values
//   q12       bias for the current channel group
//   q14       relu6 upper bound (6.0f)      q15  relu lower bound (0.0f)
//   r8        plane (hw) loop counter
//   r11       dst write pointer (in the 3-channel tail: dst + 8 for lane 2)
//   r12       dst row stride in bytes = (oc4div + oc4mod) * sizeof(float)
//   lr        output-channel loop counter (steps by 4)

PostFuncBiasReluC4:
    // 8 registers = 32 bytes. r10 is not used as scratch any more but stays in
    // the list so the frame size (and 8-byte stack alignment) is unchanged.
    push {r4-r8, r10, r11, lr}

    // The stack-passed arguments sit just above the saved registers. Read them
    // relative to the current sp instead of moving sp above the save area:
    // data below sp may be overwritten by anything that runs asynchronously on
    // this stack (e.g. a signal handler).
    ldr r4, [sp, #32]
    ldr r5, [sp, #36]
    ldr r6, [sp, #40]
    ldr r7, [sp, #44]

    vmov.i32 q14, #6
    vcvt.f32.s32 q14, q14               // q14 = {6.0f, 6.0f, 6.0f, 6.0f}
    veor q15, q15, q15                  // q15 = 0

    add r12, r3, r4
    lsl r12, r12, #2                    // r12 = (oc4div + oc4mod) * sizeof(float)

    mov lr, #0

Loop_C4:
    cmp lr, r3
    beq Loop_C1
    add r11, r0, lr, lsl #2             // r11 = dst + oc * sizeof(float)
    add lr, lr, #4
    mov r8, r5
    vld1.32 {q12}, [r2]!

Loop_4x4:
    cmp r8, #4
    blt Loop_1x4
    sub r8, r8, #4
    vld1.32 {q0-q1}, [r1]!
    vld1.32 {q2-q3}, [r1]!

    vadd.f32 q0, q0, q12
    vadd.f32 q1, q1, q12
    vadd.f32 q2, q2, q12
    vadd.f32 q3, q3, q12

    cmp r7, #3
    beq Relu6_4x4
    cmp r7, #1
    beq Relu_4x4
    b Write_4x4
Relu6_4x4:
    vmin.f32 q0, q0, q14
    vmin.f32 q1, q1, q14
    vmin.f32 q2, q2, q14
    vmin.f32 q3, q3, q14
    // relu6 falls through into relu for the lower bound
Relu_4x4:
    vmax.f32 q0, q0, q15
    vmax.f32 q1, q1, q15
    vmax.f32 q2, q2, q15
    vmax.f32 q3, q3, q15
Write_4x4:
    vst1.32 {q0}, [r11], r12
    vst1.32 {q1}, [r11], r12
    vst1.32 {q2}, [r11], r12
    vst1.32 {q3}, [r11], r12
    b Loop_4x4

    // Remaining (< 4) plane positions of a full channel group, one at a time.
Loop_1x4:
    cmp r7, #3
    beq Relu6_1x4
    cmp r7, #1
    beq Relu_1x4
    b Write_1x4
Relu6_1x4:
    cmp r8, #0
    beq HW_Add
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmin.f32 q0, q0, q14
    vmax.f32 q0, q0, q15
    vst1.32 {q0}, [r11], r12
    b Relu6_1x4
Relu_1x4:
    cmp r8, #0
    beq HW_Add
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmax.f32 q0, q0, q15
    vst1.32 {q0}, [r11], r12
    b Relu_1x4
Write_1x4:
    cmp r8, #0
    beq HW_Add
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vst1.32 {q0}, [r11], r12
    b Write_1x4

HW_Add:
    add r1, r1, r6                      // skip the gap between src channel groups
    b Loop_C4

    // Tail: the last oc4mod (1..3) channels. A full vector is still loaded from
    // src and bias; only the first oc4mod lanes are stored.
Loop_C1:
    cmp r4, #0
    beq End
    mov r8, r5
    vld1.32 {q12}, [r2]!
    add r0, r0, lr, lsl #2              // r0 = dst + oc4div * sizeof(float)

    cmp r4, #1
    beq Loop_C1_1
    cmp r4, #2
    beq Loop_C1_2
    cmp r4, #3
    beq Loop_C1_3

Loop_C1_1:
    cmp r7, #3
    beq Loop_C1_1_Relu6
    cmp r7, #1
    beq Loop_C1_1_Relu
    b Loop_C1_1_Write
Loop_C1_1_Relu6:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmin.f32 q0, q0, q14
    vmax.f32 q0, q0, q15
    vst1.32 {d0[0]}, [r0], r12
    b Loop_C1_1_Relu6
Loop_C1_1_Relu:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmax.f32 q0, q0, q15
    vst1.32 {d0[0]}, [r0], r12
    b Loop_C1_1_Relu
Loop_C1_1_Write:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vst1.32 {d0[0]}, [r0], r12
    b Loop_C1_1_Write

Loop_C1_2:
    cmp r7, #3
    beq Loop_C1_2_Relu6
    cmp r7, #1
    beq Loop_C1_2_Relu
    b Loop_C1_2_Write
Loop_C1_2_Relu6:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmin.f32 q0, q0, q14
    vmax.f32 q0, q0, q15
    vst1.32 {d0}, [r0], r12
    b Loop_C1_2_Relu6
Loop_C1_2_Relu:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmax.f32 q0, q0, q15
    vst1.32 {d0}, [r0], r12
    b Loop_C1_2_Relu
Loop_C1_2_Write:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vst1.32 {d0}, [r0], r12
    b Loop_C1_2_Write

    // Three channels: lanes 0-1 go through r0, lane 2 through r11 = r0 + 8.
    // Both pointers must advance by the dst row stride (r12) so they stay
    // 8 bytes apart for every plane position.
Loop_C1_3:
    add r11, r0, #8
    cmp r7, #3
    beq Loop_C1_3_Relu6
    cmp r7, #1
    beq Loop_C1_3_Relu
    b Loop_C1_3_Write
Loop_C1_3_Relu6:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmin.f32 q0, q0, q14
    vmax.f32 q0, q0, q15
    vst1.32 {d0}, [r0], r12
    vst1.32 {d1[0]}, [r11], r12
    b Loop_C1_3_Relu6
Loop_C1_3_Relu:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vmax.f32 q0, q0, q15
    vst1.32 {d0}, [r0], r12
    vst1.32 {d1[0]}, [r11], r12
    b Loop_C1_3_Relu
Loop_C1_3_Write:
    cmp r8, #0
    beq End
    sub r8, r8, #1
    vld1.32 {q0}, [r1]!
    vadd.f32 q0, q0, q12
    vst1.32 {d0}, [r0], r12
    vst1.32 {d1[0]}, [r11], r12
    b Loop_C1_3_Write

End:
    pop {r4-r8, r10, r11, pc}
|
|
@ -0,0 +1,305 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
//.p2align 5,,15
|
||||
.global PostFuncBiasReluC4
|
||||
#ifndef __APPLE__
|
||||
.type PostFuncBiasReluC4, %function
|
||||
#endif
|
||||
|
||||
// void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
//                         size_t plane_size, size_t plane_stride, size_t relu_type);
//
// AArch64 implementation: dst = activation(src + bias), converting the
// 4-channel-packed src into rows of (oc4div + oc4mod) floats in dst.
//   relu_type == 3 : clamp to [0, 6]     relu_type == 1 : clamp to [0, +inf)
//   anything else  : no activation
//
//   x0 dst          x1 src           x2 bias
//   w3 oc4div       x4 oc4mod        w5 plane_size
//   x6 plane_stride (bytes added to src after each full channel group)
//   x7 relu_type
//
//   v0 ~ v7   values
//   v16       bias of the current channel group
//   v26       6.0f (relu6 upper bound)      v27  0.0f (relu lower bound)
//   w10       output-channel loop counter (steps by 4)
//   x12       dst row stride in bytes
//   w13       plane (hw) loop counter
//   x15       dst write pointer

PostFuncBiasReluC4:

    fmov v26.4s, #6.0
    movi v27.4s, #0

    add x12, x3, x4
    lsl x12, x12, #2                    // (oc4div + oc4mod) * sizeof(float)

    mov w10, wzr

Loop_C4:
    cmp w10, w3
    beq Loop_C1
    add x15, x0, x10, lsl #2            // dst + oc * sizeof(float)
    add w10, w10, #4
    mov w13, w5
    ld1 {v16.4s}, [x2], #16

    // Eight plane positions per iteration.
Loop_8x4:
    cmp w13, #8
    blt Loop_4x4
    sub w13, w13, #8
    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64

    fadd v0.4s, v0.4s, v16.4s
    fadd v1.4s, v1.4s, v16.4s
    fadd v2.4s, v2.4s, v16.4s
    fadd v3.4s, v3.4s, v16.4s
    fadd v4.4s, v4.4s, v16.4s
    fadd v5.4s, v5.4s, v16.4s
    fadd v6.4s, v6.4s, v16.4s
    fadd v7.4s, v7.4s, v16.4s

    cmp x7, #3
    beq Relu6_8x4
    cmp x7, #1
    beq Relu_8x4
    b Write_8x4
Relu6_8x4:
    fmin v0.4s, v0.4s, v26.4s
    fmin v1.4s, v1.4s, v26.4s
    fmin v2.4s, v2.4s, v26.4s
    fmin v3.4s, v3.4s, v26.4s
    fmin v4.4s, v4.4s, v26.4s
    fmin v5.4s, v5.4s, v26.4s
    fmin v6.4s, v6.4s, v26.4s
    fmin v7.4s, v7.4s, v26.4s
    // relu6 continues into relu for the lower bound
Relu_8x4:
    fmax v0.4s, v0.4s, v27.4s
    fmax v1.4s, v1.4s, v27.4s
    fmax v2.4s, v2.4s, v27.4s
    fmax v3.4s, v3.4s, v27.4s
    fmax v4.4s, v4.4s, v27.4s
    fmax v5.4s, v5.4s, v27.4s
    fmax v6.4s, v6.4s, v27.4s
    fmax v7.4s, v7.4s, v27.4s
Write_8x4:
    st1 {v0.4s}, [x15], x12
    st1 {v1.4s}, [x15], x12
    st1 {v2.4s}, [x15], x12
    st1 {v3.4s}, [x15], x12
    st1 {v4.4s}, [x15], x12
    st1 {v5.4s}, [x15], x12
    st1 {v6.4s}, [x15], x12
    st1 {v7.4s}, [x15], x12
    b Loop_8x4

    // At most one group of four is left after the 8-wide loop, so this part
    // runs once and drops straight into the single-position loops.
Loop_4x4:
    cmp w13, #4
    blt Loop_1x4
    sub w13, w13, #4
    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
    fadd v0.4s, v0.4s, v16.4s
    fadd v1.4s, v1.4s, v16.4s
    fadd v2.4s, v2.4s, v16.4s
    fadd v3.4s, v3.4s, v16.4s
    cmp x7, #3
    beq Relu6_4x4
    cmp x7, #1
    beq Relu_4x4
    b Write_4x4
Relu6_4x4:
    fmin v0.4s, v0.4s, v26.4s
    fmin v1.4s, v1.4s, v26.4s
    fmin v2.4s, v2.4s, v26.4s
    fmin v3.4s, v3.4s, v26.4s
Relu_4x4:
    fmax v0.4s, v0.4s, v27.4s
    fmax v1.4s, v1.4s, v27.4s
    fmax v2.4s, v2.4s, v27.4s
    fmax v3.4s, v3.4s, v27.4s
Write_4x4:
    st1 {v0.4s}, [x15], x12
    st1 {v1.4s}, [x15], x12
    st1 {v2.4s}, [x15], x12
    st1 {v3.4s}, [x15], x12

Loop_1x4:
    cmp x7, #3
    beq Relu6_1x4
    cmp x7, #1
    beq Relu_1x4
    b Write_1x4
Relu6_1x4:
    cbz w13, HW_Add
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmin v0.4s, v0.4s, v26.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.4s}, [x15], x12
    b Relu6_1x4
Relu_1x4:
    cbz w13, HW_Add
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.4s}, [x15], x12
    b Relu_1x4
Write_1x4:
    cbz w13, HW_Add
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    st1 {v0.4s}, [x15], x12
    b Write_1x4

HW_Add:
    add x1, x1, x6                      // skip the gap between src channel groups
    b Loop_C4

    // Tail: the last oc4mod channels. A full vector is loaded; only the first
    // oc4mod lanes are written.
Loop_C1:
    cbz x4, End
    mov w13, w5
    ld1 {v16.4s}, [x2], #16
    add x0, x0, x10, lsl #2             // dst + oc4div * sizeof(float)

    cmp x4, #2
    beq Loop_C1_2
    cmp x4, #3
    beq Loop_C1_3
    // oc4mod == 1 (and any other value) continues below

Loop_C1_1:
    cmp x7, #3
    beq Loop_C1_1_Relu6
    cmp x7, #1
    beq Loop_C1_1_Relu
    b Loop_C1_1_Write
Loop_C1_1_Relu6:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmin v0.4s, v0.4s, v26.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.s}[0], [x0], x12
    b Loop_C1_1_Relu6
Loop_C1_1_Relu:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.s}[0], [x0], x12
    b Loop_C1_1_Relu
Loop_C1_1_Write:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    st1 {v0.s}[0], [x0], x12
    b Loop_C1_1_Write

Loop_C1_2:
    cmp x7, #3
    beq Loop_C1_2_Relu6
    cmp x7, #1
    beq Loop_C1_2_Relu
    b Loop_C1_2_Write
Loop_C1_2_Relu6:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmin v0.4s, v0.4s, v26.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.2s}, [x0], x12              // lanes 0 and 1
    b Loop_C1_2_Relu6
Loop_C1_2_Relu:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.2s}, [x0], x12
    b Loop_C1_2_Relu
Loop_C1_2_Write:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    st1 {v0.2s}, [x0], x12
    b Loop_C1_2_Write

    // Three channels: lanes 0-1 through x0, lane 2 through x15 = x0 + 8.
Loop_C1_3:
    add x15, x0, #8
    cmp x7, #3
    beq Loop_C1_3_Relu6
    cmp x7, #1
    beq Loop_C1_3_Relu
    b Loop_C1_3_Write
Loop_C1_3_Relu6:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmin v0.4s, v0.4s, v26.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.2s}, [x0], x12
    st1 {v0.s}[2], [x15], x12
    b Loop_C1_3_Relu6
Loop_C1_3_Relu:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    fmax v0.4s, v0.4s, v27.4s
    st1 {v0.2s}, [x0], x12
    st1 {v0.s}[2], [x15], x12
    b Loop_C1_3_Relu
Loop_C1_3_Write:
    cbz w13, End
    sub w13, w13, #1
    ld1 {v0.4s}, [x1], #16
    fadd v0.4s, v0.4s, v16.4s
    st1 {v0.2s}, [x0], x12
    st1 {v0.s}[2], [x15], x12
    b Loop_C1_3_Write

End:
    ret
|
||||
#endif
|
|
@ -89,7 +89,6 @@ typedef struct DeConvWg {
|
|||
|
||||
// Scratch-buffer record kept per winograd unit size; the kernels shown with it
// index deconv_param_->a_buffer_[] by unit.winograd_.kh_.
typedef struct DeConvWgABuffer {
|
||||
bool buf_init_;  // set to true by InitParameter the first time buffers are set up for this unit size
|
||||
// Was cleared in InitParameter/DoDeconv and set after the input transform ran.
// NOTE(review): the callers shown here now track this in a per-call `transfered`
// array instead -- confirm whether this field is still needed.
bool trans_formed_;
|
||||
void *middle_buffer_;  // assigned in InitParameter; presumably the intermediate (left-product) transform buffer -- TODO confirm
|
||||
void *dest_buffer_;  // presumably the finished transform output -- TODO confirm against callers
|
||||
} DeConvWgABuffer;
|
||||
|
|
|
@ -79,15 +79,16 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride,
|
|||
}
|
||||
|
||||
void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf,
|
||||
float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool a_trans,
|
||||
float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transfered,
|
||||
float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start,
|
||||
ConvParameter *conv_param, DeConvParam *deconv_param) {
|
||||
int winograd_plane = unit_size * unit_size;
|
||||
if (!a_trans) {
|
||||
if (!transfered[unit_size]) {
    /* Transform the input tile for this unit size only once per call of the
     * compute loop. The flag has to be raised afterwards so that later compute
     * units with the same unit_size reuse trans_a_buf instead of redoing both
     * matrix products (this matches the fp32 _deConvWinograd). */
    WinogradMatrixProductLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size,
                                  DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
    WinogradMatrixProductRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT,
                                   deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
    transfered[unit_size] = true;
}
|
||||
|
||||
for (int index = 0; index < winograd_plane; index++) {
|
||||
|
@ -265,6 +266,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou
|
|||
}
|
||||
|
||||
/* compute */
|
||||
bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false};
|
||||
for (int i = 0; i < deconv_param->compute_size_; i++) {
|
||||
DeConvComputeUnit *unit = &deconv_param->compute_units_[i];
|
||||
if (unit->use_winograd_) {
|
||||
|
@ -281,9 +283,8 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou
|
|||
DECONV_WINOGRAD_DEFAULT_TILE *
|
||||
deconv_param->oc_up4_;
|
||||
_deConvWinogradFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a,
|
||||
tmp_a->trans_formed_, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_,
|
||||
unit->h_start_, conv_param, deconv_param);
|
||||
tmp_a->trans_formed_ = true;
|
||||
transfered, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_,
|
||||
conv_param, deconv_param);
|
||||
} else {
|
||||
float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ *
|
||||
unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM;
|
||||
|
|
|
@ -56,8 +56,15 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi
|
|||
|
||||
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
|
||||
size_t plane_size, size_t plane_stride, size_t relu_type) {
|
||||
#ifdef ENABLE_ARM
|
||||
size_t oc4mod = output_channel % C4NUM;
|
||||
size_t oc4div = output_channel - oc4mod;
|
||||
size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float);
|
||||
PostFuncBiasReluC4(out_ptr, c4_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type);
|
||||
#else
|
||||
PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type,
|
||||
C4NUM);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,8 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f
|
|||
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
|
||||
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
|
||||
size_t plane_size, size_t stride, size_t relu_type);
|
||||
void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod,
|
||||
size_t plane_size, size_t plane_stride, size_t relu_type);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
|
|
|
@ -109,7 +109,11 @@ void DeConvWgInputPack(float *src_ptr, float *dst_ptr, int channel, int stride)
|
|||
float *dst = dst_ptr;
|
||||
|
||||
for (int ic = 0; ic < ic4div; ic++) {
|
||||
#ifdef ENABLE_ARM
|
||||
vst1q_f32(dst, vld1q_f32(src));
|
||||
#else
|
||||
memcpy(dst, src, C4NUM * sizeof(float));
|
||||
#endif
|
||||
dst += stride;
|
||||
src += C4NUM;
|
||||
}
|
||||
|
@ -159,25 +163,27 @@ void MSGemmFloatUnit_4(float *dstOrigin, const float *src, const float *weight,
|
|||
weight_depth_offset);
|
||||
}
|
||||
|
||||
/* Accumulates `count` groups of four floats from src into dst:
 *   dst[i * dst_stride + j] += src[i * src_stride + j]   for i in [0, count), j in [0, 4)
 * Both strides are measured in float elements. Nothing is touched when count is 0. */
void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) {
  /* size_t index: avoids the signed/unsigned comparison against `count` */
  for (size_t i = 0; i < count; ++i) {
    const float *s = src + i * src_stride;
    float *d = dst + i * dst_stride;
    for (int j = 0; j < 4; ++j) {
      d[j] += s[j];
    }
  }
}
|
||||
|
||||
void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float *tmp_buf, float *at_buf,
|
||||
float *a_mid_buf, float *trans_a_buf, bool a_trans, float *bt_buf, float *b_tmp_buf, int unit_size,
|
||||
int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) {
|
||||
float *a_mid_buf, float *trans_a_buf, bool *transfered, float *bt_buf, float *b_tmp_buf,
|
||||
int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) {
|
||||
int winograd_plane = unit_size * unit_size;
|
||||
if (!a_trans) {
|
||||
if (!transfered[unit_size]) {
|
||||
WinogradMatrixProductLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size,
|
||||
DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
|
||||
WinogradMatrixProductRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT,
|
||||
deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
|
||||
transfered[unit_size] = true;
|
||||
}
|
||||
|
||||
for (int index = 0; index < winograd_plane; index++) {
|
||||
|
@ -274,6 +280,7 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind
|
|||
}
|
||||
|
||||
/* compute */
|
||||
bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false};
|
||||
for (int i = 0; i < deconv_param->compute_size_; i++) {
|
||||
DeConvComputeUnit *unit = &deconv_param->compute_units_[i];
|
||||
if (unit->use_winograd_) {
|
||||
|
@ -289,9 +296,8 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind
|
|||
float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ *
|
||||
deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE;
|
||||
_deConvWinograd(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf,
|
||||
wg_dst_a_buf, wg_buf->trans_formed_, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_,
|
||||
unit->w_start_, unit->h_start_, conv_param, deconv_param);
|
||||
wg_buf->trans_formed_ = true;
|
||||
wg_dst_a_buf, transfered, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_,
|
||||
unit->h_start_, conv_param, deconv_param);
|
||||
} else {
|
||||
float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * unit->h_size_ *
|
||||
DECONV_WINOGRAD_DEFAULT_TILE * C4NUM;
|
||||
|
|
|
@ -75,7 +75,6 @@ int DeConvWinogradFp16CPUKernel::InitParameter() {
|
|||
if (unit.use_winograd_) {
|
||||
if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) {
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true;
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false;
|
||||
|
||||
size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_;
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ =
|
||||
|
@ -111,9 +110,6 @@ int DeConvWinogradFp16CPUKernel::DoDeconv(int task_id) {
|
|||
int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE,
|
||||
deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index);
|
||||
|
||||
for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) {
|
||||
deconv_param_->a_buffer_[i].trans_formed_ = false;
|
||||
}
|
||||
DeconvWgFp16(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id);
|
||||
|
||||
std::unique_lock<std::mutex> merge_lock(lock_);
|
||||
|
|
|
@ -138,7 +138,6 @@ int DeConvolutionWinogradCPUKernel::InitParameter() {
|
|||
if (unit.use_winograd_) {
|
||||
if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) {
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true;
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false;
|
||||
|
||||
size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_;
|
||||
deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ =
|
||||
|
@ -308,9 +307,6 @@ int DeConvolutionWinogradCPUKernel::DoDeconv(int task_id) {
|
|||
int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE,
|
||||
deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index);
|
||||
|
||||
for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) {
|
||||
deconv_param_->a_buffer_[i].trans_formed_ = false;
|
||||
}
|
||||
DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id);
|
||||
|
||||
std::unique_lock<std::mutex> merge_lock(lock_);
|
||||
|
|
Loading…
Reference in New Issue