forked from mindspore-Ecosystem/mindspore
add fp32 deconv merge assembly
This commit is contained in:
parent
aa94e5a91e
commit
7271ac8a80
|
@ -159,12 +159,153 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t
|
|||
#endif
|
||||
|
||||
void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) {
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const float *s = src + i * src_stride;
|
||||
float *d = dst + i * dst_stride;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
d[j] += s[j];
|
||||
const float *src_ptr = src;
|
||||
float *dst_ptr = dst;
|
||||
size_t cuont8 = count / C8NUM * C8NUM;
|
||||
int i = 0;
|
||||
for (; i < cuont8; i += 8) {
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t src_step = src_stride * sizeof(float);
|
||||
size_t dst_step = dst_stride * sizeof(float);
|
||||
asm volatile(
|
||||
"mov x7, %[src_ptr]\n"
|
||||
"mov x8, %[dst_ptr]\n"
|
||||
"mov x10, x8\n"
|
||||
|
||||
"ld1 {v0.4s}, [x7], %[src_step]\n"
|
||||
"ld1 {v1.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"ld1 {v2.4s}, [x7], %[src_step]\n"
|
||||
"ld1 {v3.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"fadd v0.4s, v0.4s, v1.4s\n"
|
||||
"ld1 {v4.4s}, [x7], %[src_step]\n"
|
||||
"fadd v2.4s, v2.4s, v3.4s\n"
|
||||
|
||||
"st1 {v0.4s}, [x10], %[dst_step]\n"
|
||||
"st1 {v2.4s}, [x10], %[dst_step]\n"
|
||||
|
||||
"ld1 {v5.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"ld1 {v6.4s}, [x7], %[src_step]\n"
|
||||
|
||||
"fadd v4.4s, v4.4s, v5.4s\n"
|
||||
"ld1 {v7.4s}, [x8], %[dst_step]\n"
|
||||
"fadd v6.4s, v6.4s, v7.4s\n"
|
||||
|
||||
"ld1 {v0.4s}, [x7], %[src_step]\n"
|
||||
"st1 {v4.4s}, [x10], %[dst_step]\n"
|
||||
"st1 {v6.4s}, [x10], %[dst_step]\n"
|
||||
|
||||
"ld1 {v1.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"ld1 {v2.4s}, [x7], %[src_step]\n"
|
||||
"ld1 {v3.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"fadd v0.4s, v0.4s, v1.4s\n"
|
||||
"fadd v2.4s, v2.4s, v3.4s\n"
|
||||
|
||||
"st1 {v0.4s}, [x10], %[dst_step]\n"
|
||||
"st1 {v2.4s}, [x10], %[dst_step]\n"
|
||||
|
||||
"ld1 {v4.4s}, [x7], %[src_step]\n"
|
||||
"ld1 {v5.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"ld1 {v6.4s}, [x7], %[src_step]\n"
|
||||
"ld1 {v7.4s}, [x8], %[dst_step]\n"
|
||||
|
||||
"fadd v4.4s, v4.4s, v5.4s\n"
|
||||
"fadd v6.4s, v6.4s, v7.4s\n"
|
||||
|
||||
"st1 {v4.4s}, [x10], %[dst_step]\n"
|
||||
"st1 {v6.4s}, [x10], %[dst_step]\n"
|
||||
|
||||
:
|
||||
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
|
||||
: "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
|
||||
#elif ENABLE_ARM32
|
||||
size_t src_step = src_stride * sizeof(float);
|
||||
size_t dst_step = dst_stride * sizeof(float);
|
||||
asm volatile(
|
||||
"mov r7, %[src_ptr]\n"
|
||||
"mov r8, %[dst_ptr]\n"
|
||||
"mov r10, r8\n"
|
||||
|
||||
"vld1.32 {q0}, [r7], %[src_step]\n"
|
||||
"vld1.32 {q1}, [r8], %[dst_step]\n"
|
||||
"vld1.32 {q2}, [r7], %[src_step]\n"
|
||||
"vld1.32 {q3}, [r8], %[dst_step]\n"
|
||||
|
||||
"vadd.f32 q0, q0, q1\n"
|
||||
"vld1.32 {q8}, [r7], %[src_step]\n"
|
||||
"vadd.f32 q2, q2, q3\n"
|
||||
|
||||
"vst1.32 {q0}, [r10], %[dst_step]\n"
|
||||
"vst1.32 {q2}, [r10], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q9}, [r8], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q10}, [r7], %[src_step]\n"
|
||||
|
||||
"vadd.f32 q8, q8, q9\n"
|
||||
"vld1.32 {q11}, [r8], %[dst_step]\n"
|
||||
"vadd.f32 q10, q10, q11\n"
|
||||
|
||||
"vld1.32 {q0}, [r7], %[src_step]\n"
|
||||
"vst1.32 {q8}, [r10], %[dst_step]\n"
|
||||
"vst1.32 {q10}, [r10], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q1}, [r8], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q2}, [r7], %[src_step]\n"
|
||||
"vld1.32 {q3}, [r8], %[dst_step]\n"
|
||||
|
||||
"vadd.f32 q0, q0, q1\n"
|
||||
"vadd.f32 q2, q2, q3\n"
|
||||
|
||||
"vst1.32 {q0}, [r10], %[dst_step]\n"
|
||||
"vst1.32 {q2}, [r10], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q8}, [r7], %[src_step]\n"
|
||||
"vld1.32 {q9}, [r8], %[dst_step]\n"
|
||||
|
||||
"vld1.32 {q10}, [r7], %[src_step]\n"
|
||||
"vld1.32 {q11}, [r8], %[dst_step]\n"
|
||||
|
||||
"vadd.f32 q8, q8, q9\n"
|
||||
"vadd.f32 q10, q10, q11\n"
|
||||
|
||||
"vst1.32 {q8}, [r10], %[dst_step]\n"
|
||||
"vst1.32 {q10}, [r10], %[dst_step]\n"
|
||||
|
||||
:
|
||||
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
|
||||
: "r7", "r8", "r10", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
|
||||
#else
|
||||
for (int j = 0; j < 8; j++) {
|
||||
const float *s = src_ptr + j * src_stride;
|
||||
float *d = dst_ptr + j * dst_stride;
|
||||
for (int k = 0; k < 4; k++) {
|
||||
d[k] += s[k];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
src_ptr += 8 * src_stride;
|
||||
dst_ptr += 8 * dst_stride;
|
||||
}
|
||||
for (; i < count; i++) {
|
||||
#ifdef ENABLE_ARM
|
||||
float32x4_t src_data = vld1q_f32(src_ptr);
|
||||
float32x4_t dst_data = vld1q_f32(dst_ptr);
|
||||
dst_data = vaddq_f32(src_data, dst_data);
|
||||
vst1q_f32(dst_ptr, dst_data);
|
||||
#else
|
||||
for (int j = 0; j < 4; j++) {
|
||||
dst_ptr[j] += src_ptr[j];
|
||||
}
|
||||
#endif
|
||||
src_ptr += src_stride;
|
||||
dst_ptr += dst_stride;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue