!7146 [MS][LITE][CPU] add arm32 fp32 matmul asm optimize

Merge pull request !7146 from liuzhongkai/winograd
This commit is contained in:
mindspore-ci-bot 2020-10-10 15:02:17 +08:00 committed by Gitee
commit 4500df104e
1 changed files with 41 additions and 1 deletions

View File

@ -233,7 +233,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
/* 8x4 row-major to col-major */ /* 8x4 row-major to col-major */
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
size_t stride = col * 4; size_t stride = col * sizeof(float);
asm volatile( asm volatile(
"mov x10, %[src_c]\n" "mov x10, %[src_c]\n"
"mov x11, %[dst_c]\n" "mov x11, %[dst_c]\n"
@ -281,6 +281,46 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15"); "v15");
#elif ENABLE_ARM32
size_t stride = col * sizeof(float);
asm volatile(
"mov r10, %[src_c]\n"
"mov r11, %[dst_c]\n"
"vld1.32 {q0}, [r10], %[stride]\n"
"vld1.32 {q2}, [r10], %[stride]\n"
"vld1.32 {q4}, [r10], %[stride]\n"
"vld1.32 {q6}, [r10], %[stride]\n"
"vtrn.32 d0, d4\n"
"vtrn.32 d1, d5\n"
"vtrn.32 d8, d12\n"
"vtrn.32 d9, d13\n"
"vld1.32 {q1}, [r10], %[stride]\n"
"vld1.32 {q3}, [r10], %[stride]\n"
"vld1.32 {q5}, [r10], %[stride]\n"
"vld1.32 {q7}, [r10], %[stride]\n"
"vswp d1, d8\n"
"vswp d5, d12\n"
"vtrn.32 d2, d6\n"
"vtrn.32 d3, d7\n"
"vtrn.32 d10, d14\n"
"vtrn.32 d11, d15\n"
"vswp d3, d10\n"
"vswp d7, d14\n"
"vst1.32 {q0, q1}, [r11]!\n"
"vst1.32 {q2, q3}, [r11]!\n"
"vst1.32 {q4, q5}, [r11]!\n"
"vst1.32 {q6, q7}, [r11]!\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#else #else
for (int tr = 0; tr < 8; tr++) { for (int tr = 0; tr < 8; tr++) {
for (int tc = 0; tc < 4; tc++) { for (int tc = 0; tc < 4; tc++) {