deconv optimize

This commit is contained in:
lzk 2021-10-12 00:39:52 -07:00
parent ef44a0a981
commit 30864ad283
3 changed files with 294 additions and 15 deletions

View File

@ -76,3 +76,4 @@ mindspore/mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py:dsdbpropimpl
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c:Conv1x1SW3x32Kernel, Conv1x1SW4x24Kernel, Conv1x1SW12x8Kernel, Conv1x1SW8x8Kernel, Conv1x1SW4x8Kernel, Conv1x1SW6x16Kernel, Conv1x1SW4x16Kernel, Conv1x1SW1x32Kernel, Conv1x1SW1x24Kernel, Conv1x1SW1x16Kernel, Conv1x1SW1x8Kernel
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,26 +23,269 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
size_t plane_size, size_t plane_stride, size_t relu_type) {
size_t stride = oc4div + oc4mod;
plane_stride /= sizeof(float);
for (size_t loop_c4 = 0; loop_c4 < oc4div; loop_c4 += C4NUM) {
int loop_c4 = 0;
size_t src_stride = plane_size * C4NUM + plane_stride;
for (; loop_c4 <= (int)(oc4div)-C16NUM; loop_c4 += C16NUM) {
size_t plane_size_tmp = plane_size;
float *dst_c4 = dst + loop_c4;
__m128 bias1 = _mm_setzero_ps();
__m128 bias2 = _mm_setzero_ps();
__m128 bias3 = _mm_setzero_ps();
__m128 bias4 = _mm_setzero_ps();
if (bias != NULL) {
bias1 = _mm_loadu_ps(bias);
bias2 = _mm_loadu_ps(bias + C4NUM);
bias3 = _mm_loadu_ps(bias + C8NUM);
bias4 = _mm_loadu_ps(bias + C12NUM);
bias += C16NUM;
}
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + C4NUM);
__m128 src5 = _mm_loadu_ps(src + src_stride);
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
__m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM);
__m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM);
__m128 src13 = _mm_loadu_ps(src + src_stride * C3NUM);
__m128 src14 = _mm_loadu_ps(src + src_stride * C3NUM + C4NUM);
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias1);
src5 = _mm_add_ps(src5, bias2);
src6 = _mm_add_ps(src6, bias2);
src9 = _mm_add_ps(src9, bias3);
src10 = _mm_add_ps(src10, bias3);
src13 = _mm_add_ps(src13, bias4);
src14 = _mm_add_ps(src14, bias4);
ActBlock8(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src5);
_mm_storeu_ps(dst_c4 + C8NUM, src9);
_mm_storeu_ps(dst_c4 + C12NUM, src13);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src2);
_mm_storeu_ps(dst_c4 + C4NUM, src6);
_mm_storeu_ps(dst_c4 + C8NUM, src10);
_mm_storeu_ps(dst_c4 + C12NUM, src14);
dst_c4 += stride;
__m128 src3 = _mm_loadu_ps(src + C8NUM);
__m128 src4 = _mm_loadu_ps(src + C12NUM);
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
__m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM);
__m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM);
__m128 src15 = _mm_loadu_ps(src + src_stride * C3NUM + C8NUM);
__m128 src16 = _mm_loadu_ps(src + src_stride * C3NUM + C12NUM);
src3 = _mm_add_ps(src3, bias1);
src4 = _mm_add_ps(src4, bias1);
src7 = _mm_add_ps(src7, bias2);
src8 = _mm_add_ps(src8, bias2);
src11 = _mm_add_ps(src11, bias3);
src12 = _mm_add_ps(src12, bias3);
src15 = _mm_add_ps(src15, bias4);
src16 = _mm_add_ps(src16, bias4);
ActBlock8(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type);
_mm_storeu_ps(dst_c4, src3);
_mm_storeu_ps(dst_c4 + C4NUM, src7);
_mm_storeu_ps(dst_c4 + C8NUM, src11);
_mm_storeu_ps(dst_c4 + C12NUM, src15);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src4);
_mm_storeu_ps(dst_c4 + C4NUM, src8);
_mm_storeu_ps(dst_c4 + C8NUM, src12);
_mm_storeu_ps(dst_c4 + C12NUM, src16);
dst_c4 += stride;
src += C16NUM;
}
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + src_stride);
__m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM);
__m128 src4 = _mm_loadu_ps(src + src_stride * C3NUM);
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias2);
src3 = _mm_add_ps(src3, bias3);
src4 = _mm_add_ps(src4, bias4);
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src2);
_mm_storeu_ps(dst_c4 + C8NUM, src3);
_mm_storeu_ps(dst_c4 + C12NUM, src4);
dst_c4 += stride;
src += C4NUM;
}
src += plane_stride;
src += C3NUM * src_stride;
}
for (; loop_c4 <= (int)(oc4div)-C12NUM; loop_c4 += C12NUM) {
size_t plane_size_tmp = plane_size;
float *dst_c4 = dst + loop_c4;
__m128 bias1 = _mm_setzero_ps();
__m128 bias2 = _mm_setzero_ps();
__m128 bias3 = _mm_setzero_ps();
if (bias != NULL) {
bias1 = _mm_loadu_ps(bias);
bias2 = _mm_loadu_ps(bias + C4NUM);
bias3 = _mm_loadu_ps(bias + C8NUM);
bias += C12NUM;
}
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + C4NUM);
__m128 src3 = _mm_loadu_ps(src + C8NUM);
__m128 src4 = _mm_loadu_ps(src + C12NUM);
__m128 src5 = _mm_loadu_ps(src + src_stride);
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
__m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM);
__m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM);
__m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM);
__m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM);
src += C16NUM;
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias1);
src3 = _mm_add_ps(src3, bias1);
src4 = _mm_add_ps(src4, bias1);
src5 = _mm_add_ps(src5, bias2);
src6 = _mm_add_ps(src6, bias2);
src7 = _mm_add_ps(src7, bias2);
src8 = _mm_add_ps(src8, bias2);
src9 = _mm_add_ps(src9, bias3);
src10 = _mm_add_ps(src10, bias3);
src11 = _mm_add_ps(src11, bias3);
src12 = _mm_add_ps(src12, bias3);
ActBlock12(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, relu_type == 1,
relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src5);
_mm_storeu_ps(dst_c4 + C8NUM, src9);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src2);
_mm_storeu_ps(dst_c4 + C4NUM, src6);
_mm_storeu_ps(dst_c4 + C8NUM, src10);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src3);
_mm_storeu_ps(dst_c4 + C4NUM, src7);
_mm_storeu_ps(dst_c4 + C8NUM, src11);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src4);
_mm_storeu_ps(dst_c4 + C4NUM, src8);
_mm_storeu_ps(dst_c4 + C8NUM, src12);
dst_c4 += stride;
}
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + src_stride);
__m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM);
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias2);
src3 = _mm_add_ps(src3, bias3);
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
ActBlock1(&src2, relu_type == 1, relu_type == C3NUM);
ActBlock1(&src3, relu_type == 1, relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src2);
_mm_storeu_ps(dst_c4 + C8NUM, src3);
dst_c4 += stride;
src += C4NUM;
}
src += plane_stride;
src += C2NUM * src_stride;
}
for (; loop_c4 <= (int)(oc4div)-C8NUM; loop_c4 += C8NUM) {
size_t plane_size_tmp = plane_size;
float *dst_c4 = dst + loop_c4;
__m128 bias1 = _mm_setzero_ps();
__m128 bias2 = _mm_setzero_ps();
if (bias != NULL) {
bias1 = _mm_loadu_ps(bias);
bias2 = _mm_loadu_ps(bias + C4NUM);
bias += C8NUM;
}
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + C4NUM);
__m128 src3 = _mm_loadu_ps(src + C8NUM);
__m128 src4 = _mm_loadu_ps(src + C12NUM);
__m128 src5 = _mm_loadu_ps(src + src_stride);
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
src += C16NUM;
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias1);
src3 = _mm_add_ps(src3, bias1);
src4 = _mm_add_ps(src4, bias1);
src5 = _mm_add_ps(src5, bias2);
src6 = _mm_add_ps(src6, bias2);
src7 = _mm_add_ps(src7, bias2);
src8 = _mm_add_ps(src8, bias2);
ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src5);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src2);
_mm_storeu_ps(dst_c4 + C4NUM, src6);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src3);
_mm_storeu_ps(dst_c4 + C4NUM, src7);
dst_c4 += stride;
_mm_storeu_ps(dst_c4, src4);
_mm_storeu_ps(dst_c4 + C4NUM, src8);
dst_c4 += stride;
}
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + src_stride);
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias2);
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
ActBlock1(&src2, relu_type == 1, relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
_mm_storeu_ps(dst_c4 + C4NUM, src2);
dst_c4 += stride;
src += C4NUM;
}
src += plane_stride;
src += src_stride;
}
for (; loop_c4 < (int)(oc4div); loop_c4 += C4NUM) {
size_t plane_size_tmp = plane_size;
float *dst_c4 = dst + loop_c4;
__m128 bias1 = _mm_setzero_ps();
if (bias != NULL) {
bias1 = _mm_loadu_ps(bias);
bias += 4;
bias += C4NUM;
}
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
__m128 src1 = _mm_loadu_ps(src);
__m128 src2 = _mm_loadu_ps(src + 4);
__m128 src2 = _mm_loadu_ps(src + C4NUM);
__m128 src3 = _mm_loadu_ps(src + 8);
__m128 src4 = _mm_loadu_ps(src + 12);
src += 16;
src += C16NUM;
src1 = _mm_add_ps(src1, bias1);
src2 = _mm_add_ps(src2, bias1);
src3 = _mm_add_ps(src3, bias1);
src4 = _mm_add_ps(src4, bias1);
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == 3);
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
dst_c4 += stride;
@ -57,11 +300,11 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
__m128 src1 = _mm_loadu_ps(src);
src1 = _mm_add_ps(src1, bias1);
ActBlock1(&src1, relu_type == 1, relu_type == 3);
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
_mm_storeu_ps(dst_c4, src1);
dst_c4 += stride;
src += 4;
src += C4NUM;
}
src += plane_stride;
}
@ -71,32 +314,32 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
__m128 bias1 = _mm_setzero_ps();
if (bias != NULL) {
bias1 = _mm_loadu_ps(bias);
bias += 4;
bias += C4NUM;
}
float *dst_c1 = dst + oc4div;
for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) {
__m128 src1 = _mm_loadu_ps(src);
src += 4;
src += C4NUM;
src1 = _mm_add_ps(src1, bias1);
ActBlock1(&src1, relu_type == 1, relu_type == 3);
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
switch (oc4mod) {
case 1:
_mm_store_ss(dst_c1, src1);
dst_c1 += stride;
break;
case 2:
case C2NUM:
_mm_storel_pi((__m64 *)(dst_c1), src1);
dst_c1 += stride;
break;
case 3:
case C3NUM:
_mm_storel_pi((__m64 *)(dst_c1), src1);
src1 = _mm_unpackhi_ps(src1, src1);
_mm_store_ss(dst_c1 + 2, src1);
_mm_store_ss(dst_c1 + C2NUM, src1);
dst_c1 += stride;
break;
case 4:
case C4NUM:
_mm_storeu_ps(dst_c1, src1);
dst_c1 += stride;
break;

View File

@ -58,6 +58,41 @@ static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, siz
}
}
static inline void ActBlock12(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
__m128 *v8, __m128 *v9, __m128 *v10, __m128 *v11, __m128 *v12, size_t relu,
size_t relu6) {
if (relu || relu6) {
__m128 zero_ma = _mm_setzero_ps();
*v1 = _mm_max_ps(zero_ma, *v1);
*v2 = _mm_max_ps(zero_ma, *v2);
*v3 = _mm_max_ps(zero_ma, *v3);
*v4 = _mm_max_ps(zero_ma, *v4);
*v5 = _mm_max_ps(zero_ma, *v5);
*v6 = _mm_max_ps(zero_ma, *v6);
*v7 = _mm_max_ps(zero_ma, *v7);
*v8 = _mm_max_ps(zero_ma, *v8);
*v9 = _mm_max_ps(zero_ma, *v9);
*v10 = _mm_max_ps(zero_ma, *v10);
*v11 = _mm_max_ps(zero_ma, *v11);
*v12 = _mm_max_ps(zero_ma, *v12);
}
if (relu6) {
__m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
*v1 = _mm_min_ps(relu6_ma, *v1);
*v2 = _mm_min_ps(relu6_ma, *v2);
*v3 = _mm_min_ps(relu6_ma, *v3);
*v4 = _mm_min_ps(relu6_ma, *v4);
*v5 = _mm_min_ps(relu6_ma, *v5);
*v6 = _mm_min_ps(relu6_ma, *v6);
*v7 = _mm_min_ps(relu6_ma, *v7);
*v8 = _mm_min_ps(relu6_ma, *v8);
*v9 = _mm_min_ps(relu6_ma, *v9);
*v10 = _mm_min_ps(relu6_ma, *v10);
*v11 = _mm_min_ps(relu6_ma, *v11);
*v12 = _mm_min_ps(relu6_ma, *v12);
}
}
static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
__m128 *v8, size_t relu_type) {
__m128 relu6 = _mm_set_ps1(6.0);