!4551 add fast transpose algorithm

Merge pull request !4551 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-16 23:18:47 +08:00 committed by Gitee
commit eaeb3fe7ee
1 changed files with 108 additions and 17 deletions

View File

@ -865,13 +865,113 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
return;
}
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {
for (int hw = 0; hw < plane; hw++) {
int nhwc_index = n * channel * plane + hw * channel + c;
int nchw_index = n * channel * plane + c * plane + hw;
((float *)dst)[nchw_index] = ((float *)src)[nhwc_index];
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
int hw8 = plane / C8NUM * C8NUM;
int c8 = channel / C8NUM * C8NUM;
int batch = plane * channel;
for (int n = 0; n < batches; n++) {
const float *src_batch = (const float*) src + n * batch;
float *dst_batch = (float*) dst + n * batch;
int hw = 0;
for (; hw < hw8; hw += C8NUM) {
int c = 0;
for (; c < c8; c += C8NUM) {
const float *src_ptr = src_batch + hw * channel + c;
float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
int srcStride = channel * 4;
int dstStride = plane * 4;
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
"ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
"ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
"zip1 v8.4s, v0.4s, v2.4s\n"
"zip2 v9.4s, v0.4s, v2.4s\n"
"zip1 v12.4s, v1.4s, v3.4s\n"
"zip2 v13.4s, v1.4s, v3.4s\n"
"ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
"ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
"zip1 v10.4s, v4.4s, v6.4s\n"
"zip2 v11.4s, v4.4s, v6.4s\n"
"zip1 v14.4s, v5.4s, v7.4s\n"
"zip2 v15.4s, v5.4s, v7.4s\n"
"ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
"ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
"trn1 v16.2d, v8.2d, v10.2d\n"
"trn2 v18.2d, v8.2d, v10.2d\n"
"trn1 v20.2d, v9.2d, v11.2d\n"
"trn2 v22.2d, v9.2d, v11.2d\n"
"ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
"ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
"trn1 v24.2d, v12.2d, v14.2d\n"
"trn2 v26.2d, v12.2d, v14.2d\n"
"trn1 v28.2d, v13.2d, v15.2d\n"
"trn2 v30.2d, v13.2d, v15.2d\n"
"zip1 v8.4s, v0.4s, v2.4s\n"
"zip2 v9.4s, v0.4s, v2.4s\n"
"zip1 v12.4s, v1.4s, v3.4s\n"
"zip2 v13.4s, v1.4s, v3.4s\n"
"zip1 v10.4s, v4.4s, v6.4s\n"
"zip2 v11.4s, v4.4s, v6.4s\n"
"zip1 v14.4s, v5.4s, v7.4s\n"
"zip2 v15.4s, v5.4s, v7.4s\n"
"trn1 v17.2d, v8.2d, v10.2d\n"
"trn2 v19.2d, v8.2d, v10.2d\n"
"trn1 v21.2d, v9.2d, v11.2d\n"
"trn2 v23.2d, v9.2d, v11.2d\n"
"trn1 v25.2d, v12.2d, v14.2d\n"
"trn2 v27.2d, v12.2d, v14.2d\n"
"trn1 v29.2d, v13.2d, v15.2d\n"
"trn2 v31.2d, v13.2d, v15.2d\n"
"st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
"st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
"st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
"st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
"st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
"st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
"st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
"st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
:
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C8NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
}
}
#endif
}
for (; c < channel; c++) {
float *src_ptr = src_batch + hw * channel + c;
float *dst_ptr = dst_batch + c * plane + hw;
for (size_t i = 0; i < C8NUM; i++) {
dst_ptr[i] = src_ptr[i * channel];
}
}
}
for (; hw < plane; hw++) {
float *src_ptr = src_batch + hw * channel;
float *dst_ptr = dst_batch + hw;
for (size_t i = 0; i < channel; i++) {
dst_ptr[i * plane] = src_ptr[i];
}
}
}
@ -879,16 +979,7 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int ch
}
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {
for (int hw = 0; hw < plane; hw++) {
int nhwc_index = n * channel * plane + hw * channel + c;
int nchw_index = n * channel * plane + c * plane + hw;
((float *)dst)[nhwc_index] = ((float *)src)[nchw_index];
}
}
}
return;
return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}
void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) {