deconv optimize

This commit is contained in:
lzk 2021-10-17 23:47:22 -07:00
parent 5caf88badd
commit 64728dcef1
2 changed files with 162 additions and 4 deletions

View File

@ -78,3 +78,5 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/WinogradTrans.c:WinogradTransRight
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/WinogradTrans.c:WinogradTransLeft

View File

@ -22,10 +22,10 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
size_t S_step = length * w * 4;
for (int h1 = 0; h1 < h; ++h1) {
const float *SW = S;
memset(M, 0, len_c4 * w * sizeof(float));
for (int w_tmp = w; w_tmp > 0; --w_tmp) {
const float *SK = SW;
const float *BK = B;
memset(M, 0, len_c4 * sizeof(float));
int k_tmp = k;
for (; k_tmp >= 7; k_tmp -= 7) {
__m128 k1 = _mm_load_ps1(BK);
@ -37,6 +37,26 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
__m128 k7 = _mm_load_ps1(BK + 6 * h);
BK += 7 * h;
for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s1 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s1, k1, M1);
__m128 s2 = _mm_loadu_ps(SK + S_step);
M2 = _mm_fmadd_ps(s2, k2, M2);
__m128 s3 = _mm_loadu_ps(SK + 2 * S_step);
M1 = _mm_fmadd_ps(s3, k3, M1);
__m128 s4 = _mm_loadu_ps(SK + 3 * S_step);
M2 = _mm_fmadd_ps(s4, k4, M2);
__m128 s5 = _mm_loadu_ps(SK + 4 * S_step);
M1 = _mm_fmadd_ps(s5, k5, M1);
__m128 s6 = _mm_loadu_ps(SK + 5 * S_step);
M2 = _mm_fmadd_ps(s6, k6, M2);
__m128 s7 = _mm_loadu_ps(SK + 6 * S_step);
M1 = _mm_fmadd_ps(s7, k7, M1);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -54,6 +74,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
M -= len_c4;
SK += 7 * S_step - len_c4;
@ -64,7 +85,46 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
__m128 k3 = _mm_load_ps1(BK + 2 * h);
__m128 k4 = _mm_load_ps1(BK + 3 * h);
BK += 4 * h;
for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) {
int len_tmp = length;
#ifdef ENABLE_AVX
for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C8NUM, M += C8NUM) {
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_loadu_ps(M + C4NUM);
__m128 s1 = _mm_loadu_ps(SK);
__m128 s11 = _mm_loadu_ps(SK + C4NUM);
M1 = _mm_fmadd_ps(s1, k1, M1);
M2 = _mm_fmadd_ps(s11, k1, M2);
__m128 s2 = _mm_loadu_ps(SK + S_step);
__m128 s22 = _mm_loadu_ps(SK + S_step + C4NUM);
M1 = _mm_fmadd_ps(s2, k2, M1);
M2 = _mm_fmadd_ps(s22, k2, M2);
__m128 s3 = _mm_loadu_ps(SK + 2 * S_step);
__m128 s33 = _mm_loadu_ps(SK + 2 * S_step + C4NUM);
M1 = _mm_fmadd_ps(s3, k3, M1);
M2 = _mm_fmadd_ps(s33, k3, M2);
__m128 s4 = _mm_loadu_ps(SK + 3 * S_step);
__m128 s44 = _mm_loadu_ps(SK + 3 * S_step + C4NUM);
M1 = _mm_fmadd_ps(s4, k4, M1);
M2 = _mm_fmadd_ps(s44, k4, M2);
_mm_storeu_ps(M, M1);
_mm_storeu_ps(M + C4NUM, M2);
}
#endif
for (; len_tmp > 0; --len_tmp, SK += 4, M += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s1 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s1, k1, M1);
__m128 s2 = _mm_loadu_ps(SK + S_step);
M2 = _mm_fmadd_ps(s2, k2, M2);
__m128 s3 = _mm_loadu_ps(SK + 2 * S_step);
M1 = _mm_fmadd_ps(s3, k3, M1);
__m128 s4 = _mm_loadu_ps(SK + 3 * S_step);
M2 = _mm_fmadd_ps(s4, k4, M2);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -76,6 +136,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
M -= len_c4;
SK += 4 * S_step - len_c4;
@ -86,6 +147,18 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
__m128 k3 = _mm_load_ps1(BK + 2 * h);
BK += 3 * h;
for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s1 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s1, k1, M1);
__m128 s2 = _mm_loadu_ps(SK + S_step);
M2 = _mm_fmadd_ps(s2, k2, M2);
__m128 s3 = _mm_loadu_ps(SK + 2 * S_step);
M1 = _mm_fmadd_ps(s3, k3, M1);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -95,6 +168,7 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
M -= len_c4;
SK += 3 * S_step - len_c4;
@ -105,7 +179,11 @@ void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_
for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) {
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
#ifdef ENABLE_AVX
M1 = _mm_fmadd_ps(s0, k1, M1);
#else
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
#endif
_mm_storeu_ps(M, M1);
}
M -= len_c4;
@ -122,9 +200,9 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
size_t len_c4 = length * 4, k_step = len_c4 * k;
for (int h1 = 0; h1 < h; ++h1, S += k_step) {
const float *BW = B;
memset(M, 0, len_c4 * w * sizeof(float));
for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) {
const float *SK = S, *BK = BW;
memset(M, 0, len_c4 * sizeof(float));
int k_tmp = k;
for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) {
__m128 k1 = _mm_load_ps1(BK);
@ -140,6 +218,26 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
const float *S6 = S5 + len_c4, *S7 = S6 + len_c4;
for (int len_tmp = length; len_tmp > 0;
--len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s1 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s1, k1, M1);
__m128 s2 = _mm_loadu_ps(S2);
M2 = _mm_fmadd_ps(s2, k2, M2);
__m128 s3 = _mm_loadu_ps(S3);
M1 = _mm_fmadd_ps(s3, k3, M1);
__m128 s4 = _mm_loadu_ps(S4);
M2 = _mm_fmadd_ps(s4, k4, M2);
__m128 s5 = _mm_loadu_ps(S5);
M1 = _mm_fmadd_ps(s5, k5, M1);
__m128 s6 = _mm_loadu_ps(S6);
M2 = _mm_fmadd_ps(s6, k6, M2);
__m128 s7 = _mm_loadu_ps(S7);
M1 = _mm_fmadd_ps(s7, k7, M1);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -157,6 +255,7 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
SK = S7;
}
@ -169,7 +268,46 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
const float *S2 = SK + len_c4;
const float *S3 = S2 + len_c4;
const float *S4 = S3 + len_c4;
for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) {
int len_tmp = length;
#ifdef ENABLE_AVX
for (; len_tmp >= C2NUM; len_tmp -= C2NUM, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) {
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_loadu_ps(M + C4NUM);
__m128 s1 = _mm_loadu_ps(SK);
__m128 s11 = _mm_loadu_ps(SK + C4NUM);
M1 = _mm_fmadd_ps(s1, k1, M1);
M2 = _mm_fmadd_ps(s11, k1, M2);
__m128 s2 = _mm_loadu_ps(S2);
__m128 s22 = _mm_loadu_ps(S2 + C4NUM);
M1 = _mm_fmadd_ps(s2, k2, M1);
M2 = _mm_fmadd_ps(s22, k2, M2);
__m128 s3 = _mm_loadu_ps(S3);
__m128 s33 = _mm_loadu_ps(S3 + C4NUM);
M1 = _mm_fmadd_ps(s3, k3, M1);
M2 = _mm_fmadd_ps(s33, k3, M2);
__m128 s4 = _mm_loadu_ps(S4);
__m128 s44 = _mm_loadu_ps(S4 + C4NUM);
M1 = _mm_fmadd_ps(s4, k4, M1);
M2 = _mm_fmadd_ps(s44, k4, M2);
_mm_storeu_ps(M, M1);
_mm_storeu_ps(M + C4NUM, M2);
}
#endif
for (; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s1 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s1, k1, M1);
__m128 s2 = _mm_loadu_ps(S2);
M2 = _mm_fmadd_ps(s2, k2, M2);
__m128 s3 = _mm_loadu_ps(S3);
M1 = _mm_fmadd_ps(s3, k3, M1);
__m128 s4 = _mm_loadu_ps(S4);
M2 = _mm_fmadd_ps(s4, k4, M2);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -181,6 +319,7 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
SK = S4;
}
@ -192,6 +331,18 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
const float *S2 = SK + len_c4;
const float *S3 = S2 + len_c4;
for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) {
#ifdef ENABLE_AVX
__m128 M1 = _mm_loadu_ps(M);
__m128 M2 = _mm_set1_ps(0.0f);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_fmadd_ps(s0, k1, M1);
__m128 s1 = _mm_loadu_ps(S2);
M2 = _mm_fmadd_ps(s1, k2, M2);
__m128 s3 = _mm_loadu_ps(S3);
M1 = _mm_fmadd_ps(s3, k3, M1);
M1 = _mm_add_ps(M1, M2);
_mm_storeu_ps(M, M1);
#else
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
@ -201,6 +352,7 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3));
M1 = _mm_add_ps(M1, s1);
_mm_storeu_ps(M, M1);
#endif
}
SK = S3;
}
@ -210,7 +362,11 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) {
__m128 M1 = _mm_loadu_ps(M);
__m128 s0 = _mm_loadu_ps(SK);
#ifdef ENABLE_AVX
M1 = _mm_fmadd_ps(s0, k1, M1);
#else
M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1));
#endif
_mm_storeu_ps(M, M1);
}
}