!7305 [MS][LITE][CPU] fp16 conv1x1 asm optimize

Merge pull request !7305 from liuzhongkai/conv1x1_asmop
This commit is contained in:
mindspore-ci-bot 2020-10-14 20:12:12 +08:00 committed by Gitee
commit 55bcc0d7cf
1 changed files with 184 additions and 8 deletions

View File

@ -17,22 +17,198 @@
#include "nnacl/fp16/matmul_fp16.h"
void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
int row_c8 = row / C8NUM * C8NUM;
int col_c8 = col / C8NUM * C8NUM;
int ci = 0;
if (src_float16) {
float16_t *src = (float16_t *)src_ptr;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
float16_t *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 2;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v0.8h}, [x10], x12\n"
"ld1 {v1.8h}, [x10], x12\n"
"ld1 {v2.8h}, [x10], x12\n"
"ld1 {v3.8h}, [x10], x12\n"
"ld1 {v4.8h}, [x10], x12\n"
"ld1 {v5.8h}, [x10], x12\n"
"ld1 {v6.8h}, [x10], x12\n"
"ld1 {v7.8h}, [x10], x12\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
}
}
#endif
}
for (; ri < row; ++ri) {
float16_t *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
}
}
}
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int cd8 = c / 8;
int cm8 = c % 8;
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = (float16_t)(src[c * row + r]);
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
}
}
} else {
float *src = (float *)src_ptr;
for (; ci < col_c8; ci += C8NUM) {
int ri = 0;
for (; ri < row_c8; ri += C8NUM) {
float *src_ptr1 = src + ci * row + ri;
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
#ifdef ENABLE_ARM64
size_t strid_row = row * 4;
asm volatile(
"mov x10, %[src_ptr1]\n"
"mov x11, %[dst_ptr1]\n"
"mov x12, %[strid_row]\n"
"ld1 {v8.4s, v9.4s}, [x10], x12\n"
"ld1 {v10.4s, v11.4s}, [x10], x12\n"
"ld1 {v12.4s, v13.4s}, [x10], x12\n"
"ld1 {v14.4s, v15.4s}, [x10], x12\n"
"ld1 {v16.4s, v17.4s}, [x10], x12\n"
"ld1 {v18.4s, v19.4s}, [x10], x12\n"
"ld1 {v20.4s, v21.4s}, [x10], x12\n"
"ld1 {v22.4s, v23.4s}, [x10], x12\n"
"fcvtn v0.4h, v8.4s\n"
"fcvtn2 v0.8h, v9.4s\n"
"fcvtn v1.4h, v10.4s\n"
"fcvtn2 v1.8h, v11.4s\n"
"fcvtn v2.4h, v12.4s\n"
"fcvtn2 v2.8h, v13.4s\n"
"fcvtn v3.4h, v14.4s\n"
"fcvtn2 v3.8h, v15.4s\n"
"fcvtn v4.4h, v16.4s\n"
"fcvtn2 v4.8h, v17.4s\n"
"fcvtn v5.4h, v18.4s\n"
"fcvtn2 v5.8h, v19.4s\n"
"fcvtn v6.4h, v20.4s\n"
"fcvtn2 v6.8h, v21.4s\n"
"fcvtn v7.4h, v22.4s\n"
"fcvtn2 v7.8h, v23.4s\n"
"zip1 v8.8h, v0.8h, v1.8h\n"
"zip1 v9.8h, v2.8h, v3.8h\n"
"zip1 v10.8h, v4.8h, v5.8h\n"
"zip1 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v16.2d, v12.2d, v14.2d\n"
"trn2 v18.2d, v12.2d, v14.2d\n"
"trn1 v17.2d, v13.2d, v15.2d\n"
"trn2 v19.2d, v13.2d, v15.2d\n"
"zip2 v8.8h, v0.8h, v1.8h\n"
"zip2 v9.8h, v2.8h, v3.8h\n"
"zip2 v10.8h, v4.8h, v5.8h\n"
"zip2 v11.8h, v6.8h, v7.8h\n"
"trn1 v12.4s, v8.4s, v9.4s\n"
"trn1 v14.4s, v10.4s, v11.4s\n"
"trn2 v13.4s, v8.4s, v9.4s\n"
"trn2 v15.4s, v10.4s, v11.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v22.2d, v12.2d, v14.2d\n"
"trn1 v21.2d, v13.2d, v15.2d\n"
"trn2 v23.2d, v13.2d, v15.2d\n"
"st1 {v16.8h}, [x11], #16\n"
"st1 {v17.8h}, [x11], #16\n"
"st1 {v18.8h}, [x11], #16\n"
"st1 {v19.8h}, [x11], #16\n"
"st1 {v20.8h}, [x11], #16\n"
"st1 {v21.8h}, [x11], #16\n"
"st1 {v22.8h}, [x11], #16\n"
"st1 {v23.8h}, [x11], #16\n"
:
: [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
for (int tr = 0; tr < C8NUM; ++tr) {
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
}
}
#endif
}
for (; ri < row; ++ri) {
float *src_ptr1 = src + ci * row;
float16_t *dst_ptr1 = dst_ptr + ci * row;
for (int tc = 0; tc < C8NUM; ++tc) {
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
}
}
}
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int cd8 = c / 8;
int cm8 = c % 8;
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = (float16_t)(src[c * row + r]);
for (int tc = ci; tc < col; tc++) {
int cd8 = tc / C8NUM;
int cm8 = tc % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
}
}
}