add fp32 deconv merge assembly

This commit is contained in:
lixian 2020-10-21 15:01:13 +08:00
parent aa94e5a91e
commit 7271ac8a80
1 changed files with 146 additions and 5 deletions

View File

@ -159,12 +159,153 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t
#endif
void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) {
for (int i = 0; i < count; ++i) {
const float *s = src + i * src_stride;
float *d = dst + i * dst_stride;
for (int j = 0; j < 4; ++j) {
d[j] += s[j];
const float *src_ptr = src;
float *dst_ptr = dst;
size_t cuont8 = count / C8NUM * C8NUM;
int i = 0;
for (; i < cuont8; i += 8) {
#ifdef ENABLE_ARM64
size_t src_step = src_stride * sizeof(float);
size_t dst_step = dst_stride * sizeof(float);
asm volatile(
"mov x7, %[src_ptr]\n"
"mov x8, %[dst_ptr]\n"
"mov x10, x8\n"
"ld1 {v0.4s}, [x7], %[src_step]\n"
"ld1 {v1.4s}, [x8], %[dst_step]\n"
"ld1 {v2.4s}, [x7], %[src_step]\n"
"ld1 {v3.4s}, [x8], %[dst_step]\n"
"fadd v0.4s, v0.4s, v1.4s\n"
"ld1 {v4.4s}, [x7], %[src_step]\n"
"fadd v2.4s, v2.4s, v3.4s\n"
"st1 {v0.4s}, [x10], %[dst_step]\n"
"st1 {v2.4s}, [x10], %[dst_step]\n"
"ld1 {v5.4s}, [x8], %[dst_step]\n"
"ld1 {v6.4s}, [x7], %[src_step]\n"
"fadd v4.4s, v4.4s, v5.4s\n"
"ld1 {v7.4s}, [x8], %[dst_step]\n"
"fadd v6.4s, v6.4s, v7.4s\n"
"ld1 {v0.4s}, [x7], %[src_step]\n"
"st1 {v4.4s}, [x10], %[dst_step]\n"
"st1 {v6.4s}, [x10], %[dst_step]\n"
"ld1 {v1.4s}, [x8], %[dst_step]\n"
"ld1 {v2.4s}, [x7], %[src_step]\n"
"ld1 {v3.4s}, [x8], %[dst_step]\n"
"fadd v0.4s, v0.4s, v1.4s\n"
"fadd v2.4s, v2.4s, v3.4s\n"
"st1 {v0.4s}, [x10], %[dst_step]\n"
"st1 {v2.4s}, [x10], %[dst_step]\n"
"ld1 {v4.4s}, [x7], %[src_step]\n"
"ld1 {v5.4s}, [x8], %[dst_step]\n"
"ld1 {v6.4s}, [x7], %[src_step]\n"
"ld1 {v7.4s}, [x8], %[dst_step]\n"
"fadd v4.4s, v4.4s, v5.4s\n"
"fadd v6.4s, v6.4s, v7.4s\n"
"st1 {v4.4s}, [x10], %[dst_step]\n"
"st1 {v6.4s}, [x10], %[dst_step]\n"
:
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
: "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#elif ENABLE_ARM32
size_t src_step = src_stride * sizeof(float);
size_t dst_step = dst_stride * sizeof(float);
asm volatile(
"mov r7, %[src_ptr]\n"
"mov r8, %[dst_ptr]\n"
"mov r10, r8\n"
"vld1.32 {q0}, [r7], %[src_step]\n"
"vld1.32 {q1}, [r8], %[dst_step]\n"
"vld1.32 {q2}, [r7], %[src_step]\n"
"vld1.32 {q3}, [r8], %[dst_step]\n"
"vadd.f32 q0, q0, q1\n"
"vld1.32 {q8}, [r7], %[src_step]\n"
"vadd.f32 q2, q2, q3\n"
"vst1.32 {q0}, [r10], %[dst_step]\n"
"vst1.32 {q2}, [r10], %[dst_step]\n"
"vld1.32 {q9}, [r8], %[dst_step]\n"
"vld1.32 {q10}, [r7], %[src_step]\n"
"vadd.f32 q8, q8, q9\n"
"vld1.32 {q11}, [r8], %[dst_step]\n"
"vadd.f32 q10, q10, q11\n"
"vld1.32 {q0}, [r7], %[src_step]\n"
"vst1.32 {q8}, [r10], %[dst_step]\n"
"vst1.32 {q10}, [r10], %[dst_step]\n"
"vld1.32 {q1}, [r8], %[dst_step]\n"
"vld1.32 {q2}, [r7], %[src_step]\n"
"vld1.32 {q3}, [r8], %[dst_step]\n"
"vadd.f32 q0, q0, q1\n"
"vadd.f32 q2, q2, q3\n"
"vst1.32 {q0}, [r10], %[dst_step]\n"
"vst1.32 {q2}, [r10], %[dst_step]\n"
"vld1.32 {q8}, [r7], %[src_step]\n"
"vld1.32 {q9}, [r8], %[dst_step]\n"
"vld1.32 {q10}, [r7], %[src_step]\n"
"vld1.32 {q11}, [r8], %[dst_step]\n"
"vadd.f32 q8, q8, q9\n"
"vadd.f32 q10, q10, q11\n"
"vst1.32 {q8}, [r10], %[dst_step]\n"
"vst1.32 {q10}, [r10], %[dst_step]\n"
:
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
: "r7", "r8", "r10", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
#else
for (int j = 0; j < 8; j++) {
const float *s = src_ptr + j * src_stride;
float *d = dst_ptr + j * dst_stride;
for (int k = 0; k < 4; k++) {
d[k] += s[k];
}
}
#endif
src_ptr += 8 * src_stride;
dst_ptr += 8 * dst_stride;
}
for (; i < count; i++) {
#ifdef ENABLE_ARM
float32x4_t src_data = vld1q_f32(src_ptr);
float32x4_t dst_data = vld1q_f32(dst_ptr);
dst_data = vaddq_f32(src_data, dst_data);
vst1q_f32(dst_ptr, dst_data);
#else
for (int j = 0; j < 4; j++) {
dst_ptr[j] += src_ptr[j];
}
#endif
src_ptr += src_stride;
dst_ptr += dst_stride;
}
return;
}