forked from mindspore-Ecosystem/mindspore
!24732 [MS][LITE][CPU] deconv 性能优化
Merge pull request !24732 from liuzhongkai/winograd_op1
This commit is contained in:
commit
6c4a131c5a
|
@ -77,3 +77,4 @@ mindspore/mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py:dsdbpropimpl
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c:Conv1x1SW3x32Kernel, Conv1x1SW4x24Kernel, Conv1x1SW12x8Kernel, Conv1x1SW8x8Kernel, Conv1x1SW4x8Kernel, Conv1x1SW6x16Kernel, Conv1x1SW4x16Kernel, Conv1x1SW1x32Kernel, Conv1x1SW1x24Kernel, Conv1x1SW1x16Kernel, Conv1x1SW1x8Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c:Conv1x1SW3x32Kernel, Conv1x1SW4x24Kernel, Conv1x1SW12x8Kernel, Conv1x1SW8x8Kernel, Conv1x1SW4x8Kernel, Conv1x1SW6x16Kernel, Conv1x1SW4x16Kernel, Conv1x1SW1x32Kernel, Conv1x1SW1x24Kernel, Conv1x1SW1x16Kernel, Conv1x1SW1x8Kernel
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
|
||||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
|
||||||
|
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -23,26 +23,269 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
|
||||||
size_t plane_size, size_t plane_stride, size_t relu_type) {
|
size_t plane_size, size_t plane_stride, size_t relu_type) {
|
||||||
size_t stride = oc4div + oc4mod;
|
size_t stride = oc4div + oc4mod;
|
||||||
plane_stride /= sizeof(float);
|
plane_stride /= sizeof(float);
|
||||||
for (size_t loop_c4 = 0; loop_c4 < oc4div; loop_c4 += C4NUM) {
|
int loop_c4 = 0;
|
||||||
|
size_t src_stride = plane_size * C4NUM + plane_stride;
|
||||||
|
for (; loop_c4 <= (int)(oc4div)-C16NUM; loop_c4 += C16NUM) {
|
||||||
|
size_t plane_size_tmp = plane_size;
|
||||||
|
float *dst_c4 = dst + loop_c4;
|
||||||
|
__m128 bias1 = _mm_setzero_ps();
|
||||||
|
__m128 bias2 = _mm_setzero_ps();
|
||||||
|
__m128 bias3 = _mm_setzero_ps();
|
||||||
|
__m128 bias4 = _mm_setzero_ps();
|
||||||
|
if (bias != NULL) {
|
||||||
|
bias1 = _mm_loadu_ps(bias);
|
||||||
|
bias2 = _mm_loadu_ps(bias + C4NUM);
|
||||||
|
bias3 = _mm_loadu_ps(bias + C8NUM);
|
||||||
|
bias4 = _mm_loadu_ps(bias + C12NUM);
|
||||||
|
bias += C16NUM;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + C4NUM);
|
||||||
|
__m128 src5 = _mm_loadu_ps(src + src_stride);
|
||||||
|
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
|
||||||
|
__m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM);
|
||||||
|
__m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM);
|
||||||
|
__m128 src13 = _mm_loadu_ps(src + src_stride * C3NUM);
|
||||||
|
__m128 src14 = _mm_loadu_ps(src + src_stride * C3NUM + C4NUM);
|
||||||
|
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias1);
|
||||||
|
src5 = _mm_add_ps(src5, bias2);
|
||||||
|
src6 = _mm_add_ps(src6, bias2);
|
||||||
|
src9 = _mm_add_ps(src9, bias3);
|
||||||
|
src10 = _mm_add_ps(src10, bias3);
|
||||||
|
src13 = _mm_add_ps(src13, bias4);
|
||||||
|
src14 = _mm_add_ps(src14, bias4);
|
||||||
|
|
||||||
|
ActBlock8(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src5);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src9);
|
||||||
|
_mm_storeu_ps(dst_c4 + C12NUM, src13);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src2);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src6);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src10);
|
||||||
|
_mm_storeu_ps(dst_c4 + C12NUM, src14);
|
||||||
|
dst_c4 += stride;
|
||||||
|
|
||||||
|
__m128 src3 = _mm_loadu_ps(src + C8NUM);
|
||||||
|
__m128 src4 = _mm_loadu_ps(src + C12NUM);
|
||||||
|
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
|
||||||
|
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
|
||||||
|
__m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM);
|
||||||
|
__m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM);
|
||||||
|
__m128 src15 = _mm_loadu_ps(src + src_stride * C3NUM + C8NUM);
|
||||||
|
__m128 src16 = _mm_loadu_ps(src + src_stride * C3NUM + C12NUM);
|
||||||
|
src3 = _mm_add_ps(src3, bias1);
|
||||||
|
src4 = _mm_add_ps(src4, bias1);
|
||||||
|
src7 = _mm_add_ps(src7, bias2);
|
||||||
|
src8 = _mm_add_ps(src8, bias2);
|
||||||
|
src11 = _mm_add_ps(src11, bias3);
|
||||||
|
src12 = _mm_add_ps(src12, bias3);
|
||||||
|
src15 = _mm_add_ps(src15, bias4);
|
||||||
|
src16 = _mm_add_ps(src16, bias4);
|
||||||
|
|
||||||
|
ActBlock8(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src3);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src7);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src11);
|
||||||
|
_mm_storeu_ps(dst_c4 + C12NUM, src15);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src4);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src8);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src12);
|
||||||
|
_mm_storeu_ps(dst_c4 + C12NUM, src16);
|
||||||
|
dst_c4 += stride;
|
||||||
|
src += C16NUM;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + src_stride);
|
||||||
|
__m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM);
|
||||||
|
__m128 src4 = _mm_loadu_ps(src + src_stride * C3NUM);
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias2);
|
||||||
|
src3 = _mm_add_ps(src3, bias3);
|
||||||
|
src4 = _mm_add_ps(src4, bias4);
|
||||||
|
|
||||||
|
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src2);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src3);
|
||||||
|
_mm_storeu_ps(dst_c4 + C12NUM, src4);
|
||||||
|
dst_c4 += stride;
|
||||||
|
src += C4NUM;
|
||||||
|
}
|
||||||
|
src += plane_stride;
|
||||||
|
src += C3NUM * src_stride;
|
||||||
|
}
|
||||||
|
for (; loop_c4 <= (int)(oc4div)-C12NUM; loop_c4 += C12NUM) {
|
||||||
|
size_t plane_size_tmp = plane_size;
|
||||||
|
float *dst_c4 = dst + loop_c4;
|
||||||
|
__m128 bias1 = _mm_setzero_ps();
|
||||||
|
__m128 bias2 = _mm_setzero_ps();
|
||||||
|
__m128 bias3 = _mm_setzero_ps();
|
||||||
|
if (bias != NULL) {
|
||||||
|
bias1 = _mm_loadu_ps(bias);
|
||||||
|
bias2 = _mm_loadu_ps(bias + C4NUM);
|
||||||
|
bias3 = _mm_loadu_ps(bias + C8NUM);
|
||||||
|
bias += C12NUM;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + C4NUM);
|
||||||
|
__m128 src3 = _mm_loadu_ps(src + C8NUM);
|
||||||
|
__m128 src4 = _mm_loadu_ps(src + C12NUM);
|
||||||
|
__m128 src5 = _mm_loadu_ps(src + src_stride);
|
||||||
|
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
|
||||||
|
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
|
||||||
|
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
|
||||||
|
__m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM);
|
||||||
|
__m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM);
|
||||||
|
__m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM);
|
||||||
|
__m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM);
|
||||||
|
src += C16NUM;
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias1);
|
||||||
|
src3 = _mm_add_ps(src3, bias1);
|
||||||
|
src4 = _mm_add_ps(src4, bias1);
|
||||||
|
src5 = _mm_add_ps(src5, bias2);
|
||||||
|
src6 = _mm_add_ps(src6, bias2);
|
||||||
|
src7 = _mm_add_ps(src7, bias2);
|
||||||
|
src8 = _mm_add_ps(src8, bias2);
|
||||||
|
src9 = _mm_add_ps(src9, bias3);
|
||||||
|
src10 = _mm_add_ps(src10, bias3);
|
||||||
|
src11 = _mm_add_ps(src11, bias3);
|
||||||
|
src12 = _mm_add_ps(src12, bias3);
|
||||||
|
|
||||||
|
ActBlock12(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, relu_type == 1,
|
||||||
|
relu_type == C3NUM);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src5);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src9);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src2);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src6);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src10);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src3);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src7);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src11);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src4);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src8);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src12);
|
||||||
|
dst_c4 += stride;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + src_stride);
|
||||||
|
__m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM);
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias2);
|
||||||
|
src3 = _mm_add_ps(src3, bias3);
|
||||||
|
|
||||||
|
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
|
||||||
|
ActBlock1(&src2, relu_type == 1, relu_type == C3NUM);
|
||||||
|
ActBlock1(&src3, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src2);
|
||||||
|
_mm_storeu_ps(dst_c4 + C8NUM, src3);
|
||||||
|
dst_c4 += stride;
|
||||||
|
src += C4NUM;
|
||||||
|
}
|
||||||
|
src += plane_stride;
|
||||||
|
src += C2NUM * src_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; loop_c4 <= (int)(oc4div)-C8NUM; loop_c4 += C8NUM) {
|
||||||
|
size_t plane_size_tmp = plane_size;
|
||||||
|
float *dst_c4 = dst + loop_c4;
|
||||||
|
__m128 bias1 = _mm_setzero_ps();
|
||||||
|
__m128 bias2 = _mm_setzero_ps();
|
||||||
|
if (bias != NULL) {
|
||||||
|
bias1 = _mm_loadu_ps(bias);
|
||||||
|
bias2 = _mm_loadu_ps(bias + C4NUM);
|
||||||
|
bias += C8NUM;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + C4NUM);
|
||||||
|
__m128 src3 = _mm_loadu_ps(src + C8NUM);
|
||||||
|
__m128 src4 = _mm_loadu_ps(src + C12NUM);
|
||||||
|
__m128 src5 = _mm_loadu_ps(src + src_stride);
|
||||||
|
__m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM);
|
||||||
|
__m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM);
|
||||||
|
__m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM);
|
||||||
|
src += C16NUM;
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias1);
|
||||||
|
src3 = _mm_add_ps(src3, bias1);
|
||||||
|
src4 = _mm_add_ps(src4, bias1);
|
||||||
|
src5 = _mm_add_ps(src5, bias2);
|
||||||
|
src6 = _mm_add_ps(src6, bias2);
|
||||||
|
src7 = _mm_add_ps(src7, bias2);
|
||||||
|
src8 = _mm_add_ps(src8, bias2);
|
||||||
|
|
||||||
|
ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src5);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src2);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src6);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src3);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src7);
|
||||||
|
dst_c4 += stride;
|
||||||
|
_mm_storeu_ps(dst_c4, src4);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src8);
|
||||||
|
dst_c4 += stride;
|
||||||
|
}
|
||||||
|
for (; plane_size_tmp > 0; plane_size_tmp -= 1) {
|
||||||
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
|
__m128 src2 = _mm_loadu_ps(src + src_stride);
|
||||||
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
src2 = _mm_add_ps(src2, bias2);
|
||||||
|
|
||||||
|
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
|
||||||
|
ActBlock1(&src2, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
|
_mm_storeu_ps(dst_c4 + C4NUM, src2);
|
||||||
|
dst_c4 += stride;
|
||||||
|
src += C4NUM;
|
||||||
|
}
|
||||||
|
src += plane_stride;
|
||||||
|
src += src_stride;
|
||||||
|
}
|
||||||
|
for (; loop_c4 < (int)(oc4div); loop_c4 += C4NUM) {
|
||||||
size_t plane_size_tmp = plane_size;
|
size_t plane_size_tmp = plane_size;
|
||||||
float *dst_c4 = dst + loop_c4;
|
float *dst_c4 = dst + loop_c4;
|
||||||
__m128 bias1 = _mm_setzero_ps();
|
__m128 bias1 = _mm_setzero_ps();
|
||||||
if (bias != NULL) {
|
if (bias != NULL) {
|
||||||
bias1 = _mm_loadu_ps(bias);
|
bias1 = _mm_loadu_ps(bias);
|
||||||
bias += 4;
|
bias += C4NUM;
|
||||||
}
|
}
|
||||||
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
|
for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) {
|
||||||
__m128 src1 = _mm_loadu_ps(src);
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
__m128 src2 = _mm_loadu_ps(src + 4);
|
__m128 src2 = _mm_loadu_ps(src + C4NUM);
|
||||||
__m128 src3 = _mm_loadu_ps(src + 8);
|
__m128 src3 = _mm_loadu_ps(src + 8);
|
||||||
__m128 src4 = _mm_loadu_ps(src + 12);
|
__m128 src4 = _mm_loadu_ps(src + 12);
|
||||||
src += 16;
|
src += C16NUM;
|
||||||
src1 = _mm_add_ps(src1, bias1);
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
src2 = _mm_add_ps(src2, bias1);
|
src2 = _mm_add_ps(src2, bias1);
|
||||||
src3 = _mm_add_ps(src3, bias1);
|
src3 = _mm_add_ps(src3, bias1);
|
||||||
src4 = _mm_add_ps(src4, bias1);
|
src4 = _mm_add_ps(src4, bias1);
|
||||||
|
|
||||||
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == 3);
|
ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
_mm_storeu_ps(dst_c4, src1);
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
dst_c4 += stride;
|
dst_c4 += stride;
|
||||||
|
@ -57,11 +300,11 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
|
||||||
__m128 src1 = _mm_loadu_ps(src);
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
src1 = _mm_add_ps(src1, bias1);
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
|
||||||
ActBlock1(&src1, relu_type == 1, relu_type == 3);
|
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
_mm_storeu_ps(dst_c4, src1);
|
_mm_storeu_ps(dst_c4, src1);
|
||||||
dst_c4 += stride;
|
dst_c4 += stride;
|
||||||
src += 4;
|
src += C4NUM;
|
||||||
}
|
}
|
||||||
src += plane_stride;
|
src += plane_stride;
|
||||||
}
|
}
|
||||||
|
@ -71,32 +314,32 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
|
||||||
__m128 bias1 = _mm_setzero_ps();
|
__m128 bias1 = _mm_setzero_ps();
|
||||||
if (bias != NULL) {
|
if (bias != NULL) {
|
||||||
bias1 = _mm_loadu_ps(bias);
|
bias1 = _mm_loadu_ps(bias);
|
||||||
bias += 4;
|
bias += C4NUM;
|
||||||
}
|
}
|
||||||
float *dst_c1 = dst + oc4div;
|
float *dst_c1 = dst + oc4div;
|
||||||
for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) {
|
for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) {
|
||||||
__m128 src1 = _mm_loadu_ps(src);
|
__m128 src1 = _mm_loadu_ps(src);
|
||||||
src += 4;
|
src += C4NUM;
|
||||||
src1 = _mm_add_ps(src1, bias1);
|
src1 = _mm_add_ps(src1, bias1);
|
||||||
|
|
||||||
ActBlock1(&src1, relu_type == 1, relu_type == 3);
|
ActBlock1(&src1, relu_type == 1, relu_type == C3NUM);
|
||||||
|
|
||||||
switch (oc4mod) {
|
switch (oc4mod) {
|
||||||
case 1:
|
case 1:
|
||||||
_mm_store_ss(dst_c1, src1);
|
_mm_store_ss(dst_c1, src1);
|
||||||
dst_c1 += stride;
|
dst_c1 += stride;
|
||||||
break;
|
break;
|
||||||
case 2:
|
case C2NUM:
|
||||||
_mm_storel_pi((__m64 *)(dst_c1), src1);
|
_mm_storel_pi((__m64 *)(dst_c1), src1);
|
||||||
dst_c1 += stride;
|
dst_c1 += stride;
|
||||||
break;
|
break;
|
||||||
case 3:
|
case C3NUM:
|
||||||
_mm_storel_pi((__m64 *)(dst_c1), src1);
|
_mm_storel_pi((__m64 *)(dst_c1), src1);
|
||||||
src1 = _mm_unpackhi_ps(src1, src1);
|
src1 = _mm_unpackhi_ps(src1, src1);
|
||||||
_mm_store_ss(dst_c1 + 2, src1);
|
_mm_store_ss(dst_c1 + C2NUM, src1);
|
||||||
dst_c1 += stride;
|
dst_c1 += stride;
|
||||||
break;
|
break;
|
||||||
case 4:
|
case C4NUM:
|
||||||
_mm_storeu_ps(dst_c1, src1);
|
_mm_storeu_ps(dst_c1, src1);
|
||||||
dst_c1 += stride;
|
dst_c1 += stride;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -58,6 +58,41 @@ static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, siz
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline void ActBlock12(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
|
||||||
|
__m128 *v8, __m128 *v9, __m128 *v10, __m128 *v11, __m128 *v12, size_t relu,
|
||||||
|
size_t relu6) {
|
||||||
|
if (relu || relu6) {
|
||||||
|
__m128 zero_ma = _mm_setzero_ps();
|
||||||
|
*v1 = _mm_max_ps(zero_ma, *v1);
|
||||||
|
*v2 = _mm_max_ps(zero_ma, *v2);
|
||||||
|
*v3 = _mm_max_ps(zero_ma, *v3);
|
||||||
|
*v4 = _mm_max_ps(zero_ma, *v4);
|
||||||
|
*v5 = _mm_max_ps(zero_ma, *v5);
|
||||||
|
*v6 = _mm_max_ps(zero_ma, *v6);
|
||||||
|
*v7 = _mm_max_ps(zero_ma, *v7);
|
||||||
|
*v8 = _mm_max_ps(zero_ma, *v8);
|
||||||
|
*v9 = _mm_max_ps(zero_ma, *v9);
|
||||||
|
*v10 = _mm_max_ps(zero_ma, *v10);
|
||||||
|
*v11 = _mm_max_ps(zero_ma, *v11);
|
||||||
|
*v12 = _mm_max_ps(zero_ma, *v12);
|
||||||
|
}
|
||||||
|
if (relu6) {
|
||||||
|
__m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
|
||||||
|
*v1 = _mm_min_ps(relu6_ma, *v1);
|
||||||
|
*v2 = _mm_min_ps(relu6_ma, *v2);
|
||||||
|
*v3 = _mm_min_ps(relu6_ma, *v3);
|
||||||
|
*v4 = _mm_min_ps(relu6_ma, *v4);
|
||||||
|
*v5 = _mm_min_ps(relu6_ma, *v5);
|
||||||
|
*v6 = _mm_min_ps(relu6_ma, *v6);
|
||||||
|
*v7 = _mm_min_ps(relu6_ma, *v7);
|
||||||
|
*v8 = _mm_min_ps(relu6_ma, *v8);
|
||||||
|
*v9 = _mm_min_ps(relu6_ma, *v9);
|
||||||
|
*v10 = _mm_min_ps(relu6_ma, *v10);
|
||||||
|
*v11 = _mm_min_ps(relu6_ma, *v11);
|
||||||
|
*v12 = _mm_min_ps(relu6_ma, *v12);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
|
static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
|
||||||
__m128 *v8, size_t relu_type) {
|
__m128 *v8, size_t relu_type) {
|
||||||
__m128 relu6 = _mm_set_ps1(6.0);
|
__m128 relu6 = _mm_set_ps1(6.0);
|
||||||
|
|
Loading…
Reference in New Issue