!7648 [MSLITE] deconv winograd fp16 neon

Merge pull request !7648 from ling/sr
This commit is contained in:
mindspore-ci-bot 2020-10-23 17:44:44 +08:00 committed by Gitee
commit 17764803ef
8 changed files with 616 additions and 101 deletions

View File

@ -0,0 +1,259 @@
.text
.align 5
.global TiledC4MatmulFp16
#ifndef __APPLE__
.type TiledC4MatmulFp16, %function
#endif
//-----------------------------------------------------------------------------
// TiledC4MatmulFp16 — fp16 NEON matmul over 4-channel blocks (AArch64, AAPCS64)
//
// In (register roles as this routine uses them):
//   x0 = dst     output pointer
//   x1 = src     input pointer; every inner step consumes eight 4-lane fp16
//                vectors (64 bytes)
//   x2 = weight  weight pointer; every inner step consumes 16 fp16 (32 bytes)
//   x3 = 4th argument, in fp16 elements; scaled to bytes below and used as the
//        distance between the outputs of consecutive output blocks in dst
//   x4 = 5th argument; number of inner (accumulation) steps, must be >= 1
//   x5 = 6th argument; number of output blocks
// Out:   for every output block, eight 4-lane fp16 vectors (64 bytes) at dst
// Clobb: x0, x2, x3, x5-x10, v0-v7, v16-v31, flags (v8-v15 are saved/restored)
// Stack: 128 bytes, held for the whole call
//-----------------------------------------------------------------------------
TiledC4MatmulFp16:
    // v8-v15 are callee-saved. sp stays lowered until the epilogue so the
    // spill area is never below sp (memory below sp may be overwritten
    // asynchronously, e.g. by a signal frame).
    sub sp, sp, #128
    mov x9, sp
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x9], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]

    mov x7, #2                          // sizeof(float16_t)
    mul x3, x3, x7                      // x3 = dst block stride in bytes
    mov x7, #32                         // 16 fp16 weights per inner step
    mul x10, x4, x7                     // x10 = weight bytes per output block
    cbz x5, LoopOcEnd                   // nothing to compute for 0 output blocks
    cmp x5, #2
    blt LoopOcHalf

// Two output blocks per pass: v16-v23 accumulate block n, v24-v31 block n+1.
LoopOc:
    mov x8, x1                          // x8 = src cursor, rewound every pass
    subs x9, x4, #1                     // x9 = remaining inner steps; flags feed the beq below
    add x6, x2, x10                     // x6 = weights of block n+1
    ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32
    ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32
    fmul v16.4h, v8.4h, v0.h[0]
    fmul v17.4h, v8.4h, v1.h[0]
    ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32
    fmul v18.4h, v8.4h, v2.h[0]
    fmul v19.4h, v8.4h, v3.h[0]
    ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32
    fmul v20.4h, v8.4h, v4.h[0]
    fmul v21.4h, v8.4h, v5.h[0]
    fmul v22.4h, v8.4h, v6.h[0]
    fmul v23.4h, v8.4h, v7.h[0]
    fmul v24.4h, v12.4h, v0.h[0]
    fmul v25.4h, v12.4h, v1.h[0]
    fmul v26.4h, v12.4h, v2.h[0]
    fmul v27.4h, v12.4h, v3.h[0]
    fmul v28.4h, v12.4h, v4.h[0]
    fmul v29.4h, v12.4h, v5.h[0]
    fmul v30.4h, v12.4h, v6.h[0]
    fmul v31.4h, v12.4h, v7.h[0]
    beq LoopIcEnd                       // single inner step: only the tail is left

// Software-pipelined body: finish lanes 1..3 of the current step, then load
// the next step and issue its lane-0 products.
LoopIc:
    add x2, x2, #64
    prfm pldl1keep, [x2]
    prfm pldl1keep, [x2, x10]
    sub x2, x2, #64
    prfm pldl1keep, [x8, #64]
    prfm pldl1keep, [x8, #96]
    fmla v16.4h, v9.4h, v0.h[1]
    fmla v17.4h, v9.4h, v1.h[1]
    fmla v18.4h, v9.4h, v2.h[1]
    fmla v19.4h, v9.4h, v3.h[1]
    fmla v20.4h, v9.4h, v4.h[1]
    fmla v21.4h, v9.4h, v5.h[1]
    fmla v22.4h, v9.4h, v6.h[1]
    fmla v23.4h, v9.4h, v7.h[1]
    fmla v24.4h, v13.4h, v0.h[1]
    fmla v25.4h, v13.4h, v1.h[1]
    fmla v26.4h, v13.4h, v2.h[1]
    fmla v27.4h, v13.4h, v3.h[1]
    fmla v28.4h, v13.4h, v4.h[1]
    fmla v29.4h, v13.4h, v5.h[1]
    fmla v30.4h, v13.4h, v6.h[1]
    fmla v31.4h, v13.4h, v7.h[1]
    fmla v16.4h, v10.4h, v0.h[2]
    fmla v17.4h, v10.4h, v1.h[2]
    fmla v18.4h, v10.4h, v2.h[2]
    fmla v19.4h, v10.4h, v3.h[2]
    fmla v20.4h, v10.4h, v4.h[2]
    fmla v21.4h, v10.4h, v5.h[2]
    fmla v22.4h, v10.4h, v6.h[2]
    fmla v23.4h, v10.4h, v7.h[2]
    fmla v24.4h, v14.4h, v0.h[2]
    fmla v25.4h, v14.4h, v1.h[2]
    fmla v26.4h, v14.4h, v2.h[2]
    fmla v27.4h, v14.4h, v3.h[2]
    fmla v28.4h, v14.4h, v4.h[2]
    fmla v29.4h, v14.4h, v5.h[2]
    fmla v30.4h, v14.4h, v6.h[2]
    fmla v31.4h, v14.4h, v7.h[2]
    fmla v16.4h, v11.4h, v0.h[3]
    fmla v17.4h, v11.4h, v1.h[3]
    fmla v18.4h, v11.4h, v2.h[3]
    fmla v19.4h, v11.4h, v3.h[3]
    fmla v20.4h, v11.4h, v4.h[3]
    fmla v21.4h, v11.4h, v5.h[3]
    fmla v22.4h, v11.4h, v6.h[3]
    fmla v23.4h, v11.4h, v7.h[3]
    fmla v24.4h, v15.4h, v0.h[3]
    fmla v25.4h, v15.4h, v1.h[3]
    fmla v26.4h, v15.4h, v2.h[3]
    fmla v27.4h, v15.4h, v3.h[3]
    fmla v28.4h, v15.4h, v4.h[3]
    fmla v29.4h, v15.4h, v5.h[3]
    fmla v30.4h, v15.4h, v6.h[3]
    fmla v31.4h, v15.4h, v7.h[3]
    ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32
    ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32
    fmla v16.4h, v8.4h, v0.h[0]
    fmla v17.4h, v8.4h, v1.h[0]
    ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32
    fmla v18.4h, v8.4h, v2.h[0]
    fmla v19.4h, v8.4h, v3.h[0]
    ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32
    fmla v20.4h, v8.4h, v4.h[0]
    fmla v21.4h, v8.4h, v5.h[0]
    fmla v22.4h, v8.4h, v6.h[0]
    fmla v23.4h, v8.4h, v7.h[0]
    fmla v24.4h, v12.4h, v0.h[0]
    fmla v25.4h, v12.4h, v1.h[0]
    fmla v26.4h, v12.4h, v2.h[0]
    fmla v27.4h, v12.4h, v3.h[0]
    fmla v28.4h, v12.4h, v4.h[0]
    fmla v29.4h, v12.4h, v5.h[0]
    fmla v30.4h, v12.4h, v6.h[0]
    fmla v31.4h, v12.4h, v7.h[0]
    subs x9, x9, #1
    bne LoopIc

// Pipeline tail: lanes 1..3 of the last step, interleaved with the stores.
LoopIcEnd:
    fmla v16.4h, v9.4h, v0.h[1]
    fmla v17.4h, v9.4h, v1.h[1]
    fmla v18.4h, v9.4h, v2.h[1]
    fmla v19.4h, v9.4h, v3.h[1]
    fmla v20.4h, v9.4h, v4.h[1]
    fmla v21.4h, v9.4h, v5.h[1]
    fmla v22.4h, v9.4h, v6.h[1]
    fmla v23.4h, v9.4h, v7.h[1]
    fmla v24.4h, v13.4h, v0.h[1]
    fmla v25.4h, v13.4h, v1.h[1]
    fmla v26.4h, v13.4h, v2.h[1]
    fmla v27.4h, v13.4h, v3.h[1]
    fmla v28.4h, v13.4h, v4.h[1]
    fmla v29.4h, v13.4h, v5.h[1]
    fmla v30.4h, v13.4h, v6.h[1]
    fmla v31.4h, v13.4h, v7.h[1]
    fmla v16.4h, v10.4h, v0.h[2]
    fmla v17.4h, v10.4h, v1.h[2]
    fmla v18.4h, v10.4h, v2.h[2]
    fmla v19.4h, v10.4h, v3.h[2]
    fmla v20.4h, v10.4h, v4.h[2]
    fmla v21.4h, v10.4h, v5.h[2]
    fmla v22.4h, v10.4h, v6.h[2]
    fmla v23.4h, v10.4h, v7.h[2]
    fmla v24.4h, v14.4h, v0.h[2]
    fmla v25.4h, v14.4h, v1.h[2]
    fmla v26.4h, v14.4h, v2.h[2]
    fmla v27.4h, v14.4h, v3.h[2]
    fmla v28.4h, v14.4h, v4.h[2]
    fmla v29.4h, v14.4h, v5.h[2]
    fmla v30.4h, v14.4h, v6.h[2]
    fmla v31.4h, v14.4h, v7.h[2]
    add x7, x0, #32                     // x7 = second 32-byte half of the current output block
    fmla v16.4h, v11.4h, v0.h[3]
    fmla v17.4h, v11.4h, v1.h[3]
    fmla v18.4h, v11.4h, v2.h[3]
    fmla v19.4h, v11.4h, v3.h[3]
    fmla v20.4h, v11.4h, v4.h[3]
    fmla v21.4h, v11.4h, v5.h[3]
    fmla v22.4h, v11.4h, v6.h[3]
    fmla v23.4h, v11.4h, v7.h[3]
    fmla v24.4h, v15.4h, v0.h[3]
    fmla v25.4h, v15.4h, v1.h[3]
    fmla v26.4h, v15.4h, v2.h[3]
    fmla v27.4h, v15.4h, v3.h[3]
    fmla v28.4h, v15.4h, v4.h[3]
    st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x3
    fmla v29.4h, v15.4h, v5.h[3]
    st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x7], x3
    fmla v30.4h, v15.4h, v6.h[3]
    st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], x3
    mov x2, x6                          // x6 ended right after block n+1's weights
    fmla v31.4h, v15.4h, v7.h[3]
    st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x7]
    subs x5, x5, #2
    beq LoopOcEnd
    cmp x5, #2
    bge LoopOc

// One output block left over (or only one requested).
LoopOcHalf:
    mov x8, x1
    mov x9, x4
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
LoopIcHalf:
    ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32
    ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32
    fmla v16.4h, v8.4h, v0.h[0]
    fmla v17.4h, v8.4h, v1.h[0]
    ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32
    fmla v18.4h, v8.4h, v2.h[0]
    fmla v19.4h, v8.4h, v3.h[0]
    fmla v20.4h, v8.4h, v4.h[0]
    fmla v21.4h, v8.4h, v5.h[0]
    fmla v22.4h, v8.4h, v6.h[0]
    fmla v23.4h, v8.4h, v7.h[0]
    fmla v16.4h, v9.4h, v0.h[1]
    fmla v17.4h, v9.4h, v1.h[1]
    fmla v18.4h, v9.4h, v2.h[1]
    fmla v19.4h, v9.4h, v3.h[1]
    fmla v20.4h, v9.4h, v4.h[1]
    fmla v21.4h, v9.4h, v5.h[1]
    fmla v22.4h, v9.4h, v6.h[1]
    fmla v23.4h, v9.4h, v7.h[1]
    fmla v16.4h, v10.4h, v0.h[2]
    fmla v17.4h, v10.4h, v1.h[2]
    fmla v18.4h, v10.4h, v2.h[2]
    fmla v19.4h, v10.4h, v3.h[2]
    fmla v20.4h, v10.4h, v4.h[2]
    fmla v21.4h, v10.4h, v5.h[2]
    fmla v22.4h, v10.4h, v6.h[2]
    fmla v23.4h, v10.4h, v7.h[2]
    fmla v16.4h, v11.4h, v0.h[3]
    fmla v17.4h, v11.4h, v1.h[3]
    fmla v18.4h, v11.4h, v2.h[3]
    fmla v19.4h, v11.4h, v3.h[3]
    fmla v20.4h, v11.4h, v4.h[3]
    fmla v21.4h, v11.4h, v5.h[3]
    fmla v22.4h, v11.4h, v6.h[3]
    fmla v23.4h, v11.4h, v7.h[3]
    subs x9, x9, #1
    bne LoopIcHalf
    st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32
    st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32

LoopOcEnd:
    // Reload v8-v15 from the still-reserved frame, then release it.
    mov x9, sp
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x9], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
    add sp, sp, #128
    ret

View File

@ -0,0 +1,136 @@
.text
.align 5
.global WinogradTransLeftFp16
#ifndef __APPLE__
.type WinogradTransLeftFp16, %function
#endif
//-----------------------------------------------------------------------------
// WinogradTransLeftFp16 — fp16 NEON "left" matrix transform (AArch64, AAPCS64)
//
// In:
//   x0 = S       source; units of (length * 4) fp16, consecutive source rows
//                are w units apart
//   x1 = B       fp16 coefficients, read with a stride of h elements
//   x2 = M       destination; h * w units written contiguously
//   x3 = w, x4 = h, x5 = k, x6 = length   (w, h and length must be non-zero:
//                their loop counters are tested only after the decrement)
// Each destination unit is zeroed and then accumulates k source units, each
// scaled by one coefficient from B. Coefficients are consumed 4 at a time,
// then 3 at a time, then singly.
// Clobb: x1, x2, x4, x7-x17, v0, v1, v16, v17, v20, v21, v30, v31, flags
//        (x19/x20 are saved/restored)
// Stack: 16 bytes
//-----------------------------------------------------------------------------
WinogradTransLeftFp16:
    // Pre-indexed push keeps the save slot at or above sp for the whole call.
    // x20 takes the role the platform register x18 would otherwise have.
    stp x19, x20, [sp, #-16]!
    mov x8, #8                          // 4 * sizeof(float16)
    mul x8, x6, x8                      // x8 = bytes per unit
    mul x9, x3, x8
    sub x9, x9, x8                      // x9 = (w - 1) * unit bytes
    add x7, x9, x8                      // step for S: w * unit bytes
    mov x10, #2
    mul x10, x4, x10                    // step for B: h * sizeof(float16)

LoopH:
    mov x13, x0                         // x13 = source unit for column 0
    mov x15, x3                         // x15 = columns left in this row
LoopW:
    mov x14, x13                        // x14 = source cursor along k
    mov x17, x1                         // x17 = coefficient cursor
    dup v30.4h, wzr
    mov x11, x6
InitZero:
    st1 {v30.4h}, [x2], #8
    subs x11, x11, #1
    bne InitZero
    sub x2, x2, x8                      // rewind to the start of the unit
    mov x12, x5                         // x12 = coefficients still to apply

LoopKStart4:
    cmp x12, #4
    blt LoopKStart3
LoopK4:
    ld1 {v0.h}[0], [x17], x10
    ld1 {v0.h}[1], [x17], x10
    ld1 {v0.h}[2], [x17], x10
    ld1 {v0.h}[3], [x17], x10
    mov x11, x6
    add x20, x14, x7                    // x14/x20/x16/x19 = four consecutive source rows
    add x16, x20, x7
    add x19, x16, x7
LoopLength4:
    ld1 {v16.4h}, [x2]
    ld1 {v20.4h}, [x14], #8
    fmla v16.4h, v20.4h, v0.h[0]
    ld1 {v21.4h}, [x20], #8
    fmul v17.4h, v21.4h, v0.h[1]
    ld1 {v20.4h}, [x16], #8
    fmla v16.4h, v20.4h, v0.h[2]
    ld1 {v21.4h}, [x19], #8
    fmla v17.4h, v21.4h, v0.h[3]
    fadd v17.4h, v16.4h, v17.4h
    st1 {v17.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength4
    sub x2, x2, x8
    sub x12, x12, #4
    add x14, x19, x9                    // x19 ran one unit past row 3; + (w-1) units = row 4
    cmp x12, #4
    bge LoopK4

LoopKStart3:
    cmp x12, #3
    blt LoopKStart
LoopK3:
    ld1 {v0.h}[0], [x17], x10
    ld1 {v0.h}[1], [x17], x10
    ld1 {v0.h}[2], [x17], x10
    mov x11, x6
    add x20, x14, x7
    add x16, x20, x7
LoopLength3:
    ld1 {v16.4h}, [x2]
    ld1 {v20.4h}, [x14], #8
    fmla v16.4h, v20.4h, v0.h[0]
    ld1 {v21.4h}, [x20], #8
    fmul v17.4h, v21.4h, v0.h[1]
    ld1 {v20.4h}, [x16], #8
    fmla v16.4h, v20.4h, v0.h[2]
    fadd v17.4h, v16.4h, v17.4h
    st1 {v17.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength3
    sub x2, x2, x8
    sub x12, x12, #3
    add x14, x16, x9                    // x16 ran one unit past row 2; + (w-1) units = row 3
    cmp x12, #3
    bge LoopK3

LoopKStart:
    cmp x12, #0
    beq LKEnd
LoopK:
    ld1r {v31.4h}, [x17], x10
    mov x11, x6
LoopLength:
    ld1 {v0.4h}, [x2]
    ld1 {v1.4h}, [x14], #8
    fmla v0.4h, v1.4h, v31.4h
    st1 {v0.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength
    subs x12, x12, #1                   // flags consumed by the bne below; sub/add keep them
    sub x2, x2, x8
    add x14, x14, x9
    bne LoopK

LKEnd:
    subs x15, x15, #1
    add x13, x13, x8                    // next source column
    add x2, x2, x8                      // next destination unit
    bne LoopW
    add x1, x1, #2                      // next coefficient column: sizeof(float16)
    subs x4, x4, #1
    bne LoopH

    ldp x19, x20, [sp], #16
    ret

View File

@ -0,0 +1,134 @@
.text
.align 5
.global WinogradTransRightFp16
#ifndef __APPLE__
.type WinogradTransRightFp16, %function
#endif
//-----------------------------------------------------------------------------
// WinogradTransRightFp16 — fp16 NEON "right" matrix transform (AArch64, AAPCS64)
//
// In:
//   x0 = S       source; each row holds k consecutive units of (length * 4) fp16
//   x1 = B       fp16 coefficients, read with a stride of h elements
//   x2 = M       destination; h * w units written contiguously
//   x3 = w, x4 = h, x5 = k, x6 = length   (w, h and length must be non-zero:
//                their loop counters are tested only after the decrement)
// Each destination unit is zeroed and then accumulates the k units of the
// current source row, each scaled by one coefficient from B. Coefficients are
// consumed 4 at a time, then 3 at a time, then singly.
// Clobb: x0, x2, x4, x7-x17, v0, v1, v16, v17, v20, v21, v30, v31, flags
//        (x19 is saved/restored)
// Stack: 16 bytes
//-----------------------------------------------------------------------------
WinogradTransRightFp16:
    // x19 serves as the fourth source cursor so that the platform register
    // x18 is left untouched; 16 bytes keep sp 16-byte aligned.
    str x19, [sp, #-16]!
    mov x8, #8                          // 4 * sizeof(float16)
    mul x8, x6, x8                      // x8 = bytes per unit
    mul x9, x5, x8                      // step for S: k * unit bytes (one source row)
    mov x10, #2
    mul x10, x4, x10                    // step for B: h * sizeof(float16)

LoopH:
    mov x7, x1                          // x7 = coefficient column for x == 0
    mov x15, x3                         // x15 = columns left in this row
LoopW:
    mov x17, x0                         // x17 = source cursor along k
    mov x13, x7                         // x13 = coefficient cursor
    dup v30.4h, wzr
    mov x11, x6
InitZero:
    st1 {v30.4h}, [x2], #8
    subs x11, x11, #1
    bne InitZero
    sub x2, x2, x8                      // rewind to the start of the unit
    mov x12, x5                         // x12 = coefficients still to apply

LoopKStart4:
    cmp x12, #4
    blt LoopKStart3
LoopK4:
    ld1 {v0.h}[0], [x13], x10
    ld1 {v0.h}[1], [x13], x10
    ld1 {v0.h}[2], [x13], x10
    ld1 {v0.h}[3], [x13], x10
    mov x11, x6
    add x14, x17, x8                    // x17/x14/x16/x19 = four consecutive source units
    add x16, x14, x8
    add x19, x16, x8
LoopLength4:
    ld1 {v16.4h}, [x2]
    ld1 {v20.4h}, [x17], #8
    fmla v16.4h, v20.4h, v0.h[0]
    ld1 {v21.4h}, [x14], #8
    fmul v17.4h, v21.4h, v0.h[1]
    ld1 {v20.4h}, [x16], #8
    fmla v16.4h, v20.4h, v0.h[2]
    ld1 {v21.4h}, [x19], #8
    fmla v17.4h, v21.4h, v0.h[3]
    fadd v17.4h, v16.4h, v17.4h
    st1 {v17.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength4
    sub x2, x2, x8
    sub x12, x12, #4
    mov x17, x19                        // x19 now points at the unit after the four consumed
    cmp x12, #4
    bge LoopK4

LoopKStart3:
    cmp x12, #3
    blt LoopKStart
LoopK3:
    ld1 {v0.h}[0], [x13], x10
    ld1 {v0.h}[1], [x13], x10
    ld1 {v0.h}[2], [x13], x10
    mov x11, x6
    add x14, x17, x8
    add x16, x14, x8
LoopLength3:
    ld1 {v16.4h}, [x2]
    ld1 {v20.4h}, [x17], #8
    fmla v16.4h, v20.4h, v0.h[0]
    ld1 {v21.4h}, [x14], #8
    fmul v17.4h, v21.4h, v0.h[1]
    ld1 {v20.4h}, [x16], #8
    fmla v16.4h, v20.4h, v0.h[2]
    fadd v17.4h, v16.4h, v17.4h
    st1 {v17.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength3
    sub x2, x2, x8
    sub x12, x12, #3
    mov x17, x16                        // x16 (last cursor of this path) is one past the three consumed units
    cmp x12, #3
    bge LoopK3

LoopKStart:
    cmp x12, #0
    beq LoopKEnd
LoopK:
    ld1r {v31.4h}, [x13], x10
    mov x11, x6
LoopLength:
    ld1 {v0.4h}, [x2]
    ld1 {v1.4h}, [x17], #8
    fmla v0.4h, v1.4h, v31.4h
    st1 {v0.4h}, [x2], #8
    subs x11, x11, #1
    bne LoopLength
    subs x12, x12, #1                   // flags consumed by the bne below; sub keeps them
    sub x2, x2, x8
    bne LoopK

LoopKEnd:
    subs x15, x15, #1
    add x2, x2, x8                      // next destination unit
    add x7, x7, #2                      // next coefficient column: sizeof(float16)
    bne LoopW
    add x0, x0, x9                      // next source row
    subs x4, x4, #1
    bne LoopH

    ldr x19, [sp], #16
    ret

View File

@ -41,41 +41,73 @@ void DeConvWgInputPackFp16(float16_t *src_ptr, float16_t *dst_ptr, int channel,
return;
}
// Scalar reference GEMM on 4-channel blocks.
//   dst                  output; block dz starts at dst + dz * dst_step, and holds `width` groups of 4 values
//   src                  input; block sz starts at src + sz * 4 * width, 4 values per position
//   weight               16 values (4x4) per (dz, sz) pair; consecutive dz are
//                        src_depth_quad * 16 + weight_depth_offset elements apart
//   src_depth_quad       number of input blocks accumulated into each output group
//   dst_step             distance between output blocks, in elements
//   dst_depth_quad       number of output blocks
//   width                number of positions per block
//   weight_depth_offset  extra elements between the weights of consecutive output blocks
// All counters are size_t so they match the type of their bounds (no signed/unsigned comparison, no narrowing of
// 4 * width into an int).
void C4GemmFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t src_depth_quad, size_t dst_step,
                size_t dst_depth_quad, size_t width, size_t weight_depth_offset) {
  const size_t src_depth_step = 4 * width;
  const size_t weight_dz_step = src_depth_quad * 16 + weight_depth_offset;
  for (size_t dz = 0; dz < dst_depth_quad; ++dz) {
    float16_t *dst_z = dst + dz * dst_step;
    const float16_t *weight_dz = weight + dz * weight_dz_step;
    for (size_t dx = 0; dx < width; ++dx) {
      float16_t *dst_x = dst_z + dx * 4;
      dst_x[0] = 0.0f;
      dst_x[1] = 0.0f;
      dst_x[2] = 0.0f;
      dst_x[3] = 0.0f;
      const float16_t *src_dx = src + 4 * dx;
      for (size_t sz = 0; sz < src_depth_quad; ++sz) {
        const float16_t *src_z = src_dx + sz * src_depth_step;
        const float16_t *weight_z = weight_dz + sz * 16;
        // Accumulation order (i outer, j inner) is kept so rounding matches the original.
        for (int i = 0; i < 4; ++i) {
          for (int j = 0; j < 4; ++j) {
            dst_x[j] += src_z[i] * weight_z[4 * i + j];
          }
        }
      }
    }
  }
}
// Adds `count` groups of 4 fp16 values from src into dst: dst[i * dst_stride + j] += src[i * src_stride + j]
// for j in [0, 4). Strides are in fp16 elements. Groups are processed C8NUM at a time by the inline-assembly
// block; the remainder is handled with NEON intrinsics.
void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride, size_t count) {
  const float16_t *src_ptr = src;
  float16_t *dst_ptr = dst;
  size_t count8 = count / C8NUM * C8NUM;
  // Byte steps for the post-indexed loads/stores; loop-invariant.
  size_t src_step = src_stride * sizeof(float16_t);
  size_t dst_step = dst_stride * sizeof(float16_t);
  size_t i = 0;
  for (; i < count8; i += C8NUM) {
    // x7 walks src, x8 walks dst for loads, x10 walks dst for stores.
    asm volatile(
      "mov x7, %[src_ptr]\n"
      "mov x8, %[dst_ptr]\n"
      "mov x10, x8\n"

      "ld1 {v0.4h}, [x7], %[src_step]\n"
      "ld1 {v1.4h}, [x8], %[dst_step]\n"
      "ld1 {v2.4h}, [x7], %[src_step]\n"
      "ld1 {v3.4h}, [x8], %[dst_step]\n"
      "fadd v0.4h, v0.4h, v1.4h\n"
      "ld1 {v4.4h}, [x7], %[src_step]\n"
      "fadd v2.4h, v2.4h, v3.4h\n"
      "st1 {v0.4h}, [x10], %[dst_step]\n"
      "st1 {v2.4h}, [x10], %[dst_step]\n"

      "ld1 {v5.4h}, [x8], %[dst_step]\n"
      "ld1 {v6.4h}, [x7], %[src_step]\n"
      "fadd v4.4h, v4.4h, v5.4h\n"
      "ld1 {v7.4h}, [x8], %[dst_step]\n"
      "fadd v6.4h, v6.4h, v7.4h\n"
      "ld1 {v0.4h}, [x7], %[src_step]\n"
      "st1 {v4.4h}, [x10], %[dst_step]\n"
      "st1 {v6.4h}, [x10], %[dst_step]\n"

      "ld1 {v1.4h}, [x8], %[dst_step]\n"
      "ld1 {v2.4h}, [x7], %[src_step]\n"
      "ld1 {v3.4h}, [x8], %[dst_step]\n"
      "fadd v0.4h, v0.4h, v1.4h\n"
      "fadd v2.4h, v2.4h, v3.4h\n"
      "st1 {v0.4h}, [x10], %[dst_step]\n"
      "st1 {v2.4h}, [x10], %[dst_step]\n"

      "ld1 {v4.4h}, [x7], %[src_step]\n"
      "ld1 {v5.4h}, [x8], %[dst_step]\n"
      "ld1 {v6.4h}, [x7], %[src_step]\n"
      "ld1 {v7.4h}, [x8], %[dst_step]\n"
      "fadd v4.4h, v4.4h, v5.4h\n"
      "fadd v6.4h, v6.4h, v7.4h\n"
      "st1 {v4.4h}, [x10], %[dst_step]\n"
      "st1 {v6.4h}, [x10], %[dst_step]\n"

      :
      : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
      // "memory": the block reads src and writes dst through pointers the compiler cannot see being dereferenced.
      : "x7", "x8", "x10", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
    src_ptr += C8NUM * src_stride;
    dst_ptr += C8NUM * dst_stride;
  }
  for (; i < count; i++) {
    float16x4_t src_data = vld1_f16(src_ptr);
    float16x4_t dst_data = vld1_f16(dst_ptr);
    dst_data = vadd_f16(src_data, dst_data);
    vst1_f16(dst_ptr, dst_data);
    src_ptr += src_stride;
    dst_ptr += dst_stride;
  }
  return;
}
void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf,
@ -84,25 +116,25 @@ void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *wei
ConvParameter *conv_param, DeConvParam *deconv_param) {
int winograd_plane = unit_size * unit_size;
if (!transfered[unit_size]) {
WinogradMatrixProductLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size,
DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
WinogradMatrixProductRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT,
deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
transfered[unit_size] = false;
WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size,
DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT,
deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
transfered[unit_size] = true;
}
for (int index = 0; index < winograd_plane; index++) {
float16_t *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up4_;
float16_t *dst = tmp_buf + index * deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE;
float16_t *weight = weight_buf + index * deconv_param->ic_up4_ * deconv_param->oc_up4_;
C4GemmFp16(dst, src, weight, deconv_param->ic_div4_, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->oc_div4_,
DECONV_WINOGRAD_DEFAULT_TILE, 0);
TiledC4MatmulFp16(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div4_,
deconv_param->oc_div4_);
}
WinogradMatrixProductLeftFp16(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size,
deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
WinogradMatrixProductRightFp16(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size,
deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
WinogradTransLeftFp16(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size,
deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
WinogradTransRightFp16(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size,
deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE);
// Add to dest
for (int uhi = 0; uhi < unit_size; uhi++) {
@ -128,8 +160,7 @@ void _deConvCommonFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weigh
for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) {
for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) {
float16_t *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride;
C4GemmFp16(tmp_buf, src_in, weight, deconv_param->ic_div4_, DECONV_WINOGRAD_DEFAULT_TILE * 4, count,
DECONV_WINOGRAD_DEFAULT_TILE, 0);
TiledC4MatmulFp16(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * 4, deconv_param->ic_div4_, count);
for (int uhi = 0; uhi < h_size; uhi++) {
for (int uwi = 0; uwi < w_size; uwi++) {

View File

@ -32,6 +32,15 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou
void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParameter *conv_param,
DeConvParam *deconv_param, int calculate_count, int tile_index);
// Assembly kernel: fp16 matmul over 4-channel blocks.
// NOTE(review): the assembly scales the 4th argument by sizeof(float16_t) and uses it as the dst stride between
// output blocks, uses the 5th argument as the inner accumulation count, and the 6th as the output block count; the
// callers pass (DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, ic_div4_, oc_div4_ or count). The parameter names `ic4` and
// `cal_num` therefore look out of date — confirm and rename.
void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num,
size_t oc4);
// Assembly kernel: zeroes each destination unit of M (length groups of 4 fp16) and accumulates k source units of S
// into it, each scaled by one coefficient read from B with a stride of h elements; consecutive source units along k
// are w units apart.
void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
// Assembly kernel: same accumulation scheme as the left transform, but the k source units of S are consecutive and
// S advances by k units per output row.
void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
#ifdef __cplusplus
}
#endif

View File

@ -81,51 +81,3 @@ void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matri
}
}
}
// Scalar reference for the "left" transform. M holds h * w units of 4 * length values. Unit (row, col) is cleared
// and then receives, for every step in [0, k), the S unit at (step, col) scaled by B[step * h + row]. Steps whose
// coefficient equals zero are skipped.
void WinogradMatrixProductLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
                                   size_t length) {
  int unit_len = 4 * length;
  for (int row = 0; row < h; ++row) {
    float16_t *out_row = M + row * w * unit_len;
    for (int col = 0; col < w; ++col) {
      float16_t *out = out_row + col * unit_len;
      const float16_t *in_col = S + col * unit_len;
      memset(out, 0, unit_len * sizeof(float16_t));
      for (int step = 0; step < k; ++step) {
        float16_t coeff = B[step * h + row];
        if (!(0.0f == coeff)) {
          const float16_t *in = in_col + step * w * unit_len;
          int e = 0;
          while (e < unit_len) {
            out[e] += in[e] * coeff;
            ++e;
          }
        }
      }
    }
  }
}
// Scalar reference for the "right" transform: M = S * B, with M = w*h * l, S = k*h * l, B = w*k.
// Unit (row, col) of M is cleared and then receives, for every step in [0, k), the S unit at (row, step) scaled by
// B[step * h + col]. Steps whose coefficient equals zero are skipped.
void WinogradMatrixProductRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
                                    size_t length) {
  int unit_len = 4 * length;
  for (int row = 0; row < h; ++row) {
    float16_t *out_row = M + row * w * unit_len;
    const float16_t *in_row = S + row * k * unit_len;
    for (int col = 0; col < w; ++col) {
      float16_t *out = out_row + col * unit_len;
      memset(out, 0, unit_len * sizeof(float16_t));
      for (int step = 0; step < k; ++step) {
        float16_t coeff = B[step * h + col];
        if (!(0.0f == coeff)) {
          const float16_t *in = in_row + step * unit_len;
          int e = 0;
          while (e < unit_len) {
            out[e] += in[e] * coeff;
            ++e;
          }
        }
      }
    }
  }
}

View File

@ -29,13 +29,6 @@ void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matri
const float16_t *bias, int m, int k, int n);
void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
int n, int in_channel);
void WinogradMatrixProductLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
void WinogradMatrixProductRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length);
#ifdef __cplusplus
}
#endif

View File

@ -265,4 +265,5 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)
} // namespace mindspore::kernel