add x86_64 sse optimize

This commit is contained in:
liuzhongkai 2020-11-03 14:06:36 +08:00
parent bbfef6233a
commit e4fa3c5b85
6 changed files with 785 additions and 4 deletions

View File

@ -26,7 +26,7 @@ usage()
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\"
echo " [-B on|off] [-w on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\"
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] \\"
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\"
echo ""
echo "Options:"
echo " -d Debug mode"
@ -65,6 +65,7 @@ usage()
echo " -o Enable mindspore lite tools compilation, enabled when -I is specified, default on"
echo " -S Enable enable download cmake compile dependency from gitee , default off"
echo " -k Enable make clean, clean up compilation generated cache "
echo " -W Enable x86_64 SSE or AVX instruction set, use [sse|avx|neon|off], default off"
}
# check value of input is 'on' or 'off'
@ -118,9 +119,10 @@ checkopts()
ENABLE_GITEE="off"
ANDROID_STL="c++_shared"
ENABLE_MAKE_CLEAN="off"
X86_64_SIMD="off"
# Process the options
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:swB:En:T:A:C:o:S:k:' opt
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:swB:En:T:A:C:o:S:k:W:' opt
do
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
case "${opt}" in
@ -341,6 +343,16 @@ checkopts()
check_on_off $OPTARG o
ENABLE_TOOLS="$OPTARG"
;;
W)
if [[ "$OPTARG" != "sse" && "$OPTARG" != "off" && "$OPTARG" != "avx" && "$OPTARG" != "neon" ]]; then
echo "Invalid value ${OPTARG} for option -W, -W parameter must be sse|neon|avx|off"
usage
exit 1
fi
if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" ]]; then
X86_64_SIMD="$OPTARG"
fi
;;
*)
echo "Unknown option ${opt}!"
usage
@ -702,7 +714,7 @@ build_lite()
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
-DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite"
fi
make -j$THREAD_NUM && make install && make package
COMPILE_RET=$?

View File

@ -20,6 +20,7 @@ option(SUPPORT_GPU "if support gpu" off)
option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
option(BUILD_MINDDATA_EXAMPLE "" on)
option(ENABLE_VERBOSE "" off)
option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off)
set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@ -174,6 +175,12 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64)
endif()
endif()
if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
if ("${X86_64_SIMD}" STREQUAL "sse")
add_compile_definitions(ENABLE_X86_64_SSE)
endif ()
endif ()
if (BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full")
# add sentencepiece dependency
# include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake)

View File

@ -32,6 +32,11 @@ if (PLATFORM_ARM32)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
if ("${X86_64_SIMD}" STREQUAL "sse")
file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
########################### build nnacl static library ########################
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})

View File

@ -121,7 +121,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) {
return NNACL_OK;
}
#ifndef ENABLE_ARM
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
int in_channel, int c4_channel) {
int cnt = 0;

View File

@ -0,0 +1,747 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#include "nnacl/minimal_filtering_generator.h"
#include "nnacl/op_base.h"
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
int in_channel, int c4_channel) {
const float *src1 = matix_a;
int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM;
int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM;
for (int i = 0; i < m; ++i) {
const float *src1_n = src1;
const float *src2_n = matrix_b;
for (int j = 0; j < n; ++j) {
const float *src1_j = src1_n;
int y = 0;
// 16 channel
for (; y < c16; y += C16NUM) {
__m128 dst1 = _mm_setzero_ps();
__m128 dst2 = _mm_setzero_ps();
__m128 dst3 = _mm_setzero_ps();
__m128 dst4 = _mm_setzero_ps();
const float *src2_y = src2_n;
for (int z = 0; z < k; ++z) {
__m128 ma1 = _mm_loadu_ps(src1_j);
__m128 ma2 = _mm_loadu_ps(src1_j + 4);
__m128 ma3 = _mm_loadu_ps(src1_j + 8);
__m128 ma4 = _mm_loadu_ps(src1_j + 12);
__m128 mb = _mm_load_ps1(src2_y);
__m128 tmp1 = _mm_mul_ps(ma1, mb);
__m128 tmp2 = _mm_mul_ps(ma2, mb);
__m128 tmp3 = _mm_mul_ps(ma3, mb);
__m128 tmp4 = _mm_mul_ps(ma4, mb);
dst1 = _mm_add_ps(dst1, tmp1);
dst2 = _mm_add_ps(dst2, tmp2);
dst3 = _mm_add_ps(dst3, tmp3);
dst4 = _mm_add_ps(dst4, tmp4);
src1_j += in_channel;
src2_y += n;
}
_mm_store_ps(matrix_c, dst1);
_mm_store_ps(matrix_c + 4, dst2);
_mm_store_ps(matrix_c + 8, dst3);
_mm_store_ps(matrix_c + 12, dst4);
src1_j -= in_channel * k;
src1_j += C16NUM;
matrix_c += C16NUM;
}
// 8 channel
for (; y < c8; y += C8NUM) {
__m128 dst1 = _mm_setzero_ps();
__m128 dst2 = _mm_setzero_ps();
const float *src2_y = src2_n;
for (int z = 0; z < k; ++z) {
__m128 ma1 = _mm_loadu_ps(src1_j);
__m128 ma2 = _mm_loadu_ps(src1_j + 4);
__m128 mb = _mm_load_ps1(src2_y);
__m128 tmp1 = _mm_mul_ps(ma1, mb);
__m128 tmp2 = _mm_mul_ps(ma2, mb);
dst1 = _mm_add_ps(dst1, tmp1);
dst2 = _mm_add_ps(dst2, tmp2);
src1_j += in_channel;
src2_y += n;
}
_mm_store_ps(matrix_c, dst1);
_mm_store_ps(matrix_c + 4, dst2);
src1_j -= in_channel * k;
src1_j += C8NUM;
matrix_c += C8NUM;
}
// remain chann
for (; y < in_channel; ++y) {
float tmp = 0;
for (int z = 0; z < k; ++z) {
tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
}
*matrix_c++ = tmp;
}
src2_n += 1;
}
src1 += k * in_channel;
}
}
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode) {
int C8Steps = row * C8NUM;
int WinoSteps1 = stride * col;
int WinoSteps2 = stride * C8NUM;
for (int r = row; r > 0; r -= C4NUM) {
const float *srcb_d = b;
const float *bias_d = bias;
float *dst = NULL;
for (int cc = col; cc > 0; cc -= C8NUM) {
if (write_mode != 0) { // writec8
dst = c;
}
const float *srca_d = a;
__m128 dst1 = _mm_setzero_ps();
__m128 dst2 = _mm_setzero_ps();
__m128 dst3 = _mm_setzero_ps();
__m128 dst4 = _mm_setzero_ps();
__m128 dst5 = _mm_setzero_ps();
__m128 dst6 = _mm_setzero_ps();
__m128 dst7 = _mm_setzero_ps();
__m128 dst8 = _mm_setzero_ps();
for (int d = depth; d > 0; --d) {
__m128 b1 = _mm_loadu_ps(srcb_d);
__m128 b2 = _mm_loadu_ps(srcb_d + 4);
__m128 a1 = _mm_load_ps1(srca_d);
__m128 a2 = _mm_load_ps1(srca_d + 1);
__m128 tmp1 = _mm_mul_ps(b1, a1);
__m128 tmp2 = _mm_mul_ps(b2, a1);
__m128 tmp3 = _mm_mul_ps(b1, a2);
__m128 tmp4 = _mm_mul_ps(b2, a2);
a1 = _mm_load_ps1(srca_d + 2);
dst1 = _mm_add_ps(dst1, tmp1);
dst2 = _mm_add_ps(dst2, tmp2);
a2 = _mm_load_ps1(srca_d + 3);
dst3 = _mm_add_ps(dst3, tmp3);
dst4 = _mm_add_ps(dst4, tmp4);
tmp1 = _mm_mul_ps(b1, a1);
tmp2 = _mm_mul_ps(b2, a1);
tmp3 = _mm_mul_ps(b1, a2);
tmp4 = _mm_mul_ps(b2, a2);
dst5 = _mm_add_ps(dst5, tmp1);
dst6 = _mm_add_ps(dst6, tmp2);
dst7 = _mm_add_ps(dst7, tmp3);
dst8 = _mm_add_ps(dst8, tmp4);
srcb_d += C8NUM;
srca_d += C4NUM;
}
if (bias != NULL) {
__m128 bias1 = _mm_loadu_ps(bias_d);
__m128 bias2 = _mm_loadu_ps(bias_d + C4NUM);
dst1 = _mm_add_ps(dst1, bias1);
dst2 = _mm_add_ps(dst2, bias2);
dst3 = _mm_add_ps(dst3, bias1);
dst4 = _mm_add_ps(dst4, bias2);
dst5 = _mm_add_ps(dst5, bias1);
dst6 = _mm_add_ps(dst6, bias2);
dst7 = _mm_add_ps(dst7, bias1);
dst8 = _mm_add_ps(dst8, bias2);
bias_d += C8NUM;
}
if (act_type == 3) {
__m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0);
dst1 = _mm_min_ps(dst1, relu6);
dst2 = _mm_min_ps(dst2, relu6);
dst3 = _mm_min_ps(dst3, relu6);
dst4 = _mm_min_ps(dst4, relu6);
dst5 = _mm_min_ps(dst5, relu6);
dst6 = _mm_min_ps(dst6, relu6);
dst7 = _mm_min_ps(dst7, relu6);
dst8 = _mm_min_ps(dst8, relu6);
}
if (act_type == 1 || act_type == 3) {
__m128 zero = _mm_setzero_ps();
dst1 = _mm_max_ps(dst1, zero);
dst2 = _mm_max_ps(dst2, zero);
dst3 = _mm_max_ps(dst3, zero);
dst4 = _mm_max_ps(dst4, zero);
dst5 = _mm_max_ps(dst5, zero);
dst6 = _mm_max_ps(dst6, zero);
dst7 = _mm_max_ps(dst7, zero);
dst8 = _mm_max_ps(dst8, zero);
}
if (write_mode == 2) { // WriteWino
c = dst + WinoSteps2;
_mm_store_ps(dst, dst1);
_mm_store_ps(dst + 4, dst2);
dst += WinoSteps1;
_mm_store_ps(dst, dst3);
_mm_store_ps(dst + 4, dst4);
dst += WinoSteps1;
_mm_store_ps(dst, dst5);
_mm_store_ps(dst + 4, dst6);
dst += WinoSteps1;
_mm_store_ps(dst, dst7);
_mm_store_ps(dst + 4, dst8);
} else if (write_mode == 0) { // WriteC8
_mm_store_ps(c, dst1);
_mm_store_ps(c + 4, dst2);
_mm_store_ps(c + 8, dst3);
_mm_store_ps(c + 12, dst4);
_mm_store_ps(c + 16, dst5);
_mm_store_ps(c + 20, dst6);
_mm_store_ps(c + 24, dst7);
_mm_store_ps(c + 28, dst8);
c += C8Steps;
} else {
switch (cc) {
case 1: // write1
c = dst + 1;
_mm_store_ss(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst += stride;
dst += 1;
}
break;
case 2: // write2
c = dst + 2;
_mm_store_ss(dst, dst1);
dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst7);
dst += stride;
dst += 2;
}
break;
case 3: // write3
c = dst + 3;
_mm_store_ss(dst, dst1);
dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst1);
dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst7);
dst += stride;
dst += 3;
}
break;
case 4: // write4
c = dst + 4;
_mm_store_ps(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
dst += stride;
dst += 4;
}
break;
case 5: // write5
c = dst + 5;
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst += stride;
dst += 5;
}
break;
case 6: // write6
c = dst + 6;
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst8);
dst += stride;
dst += 6;
}
break;
case 7: // write7
c = dst + 7;
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst8);
dst += stride;
dst += 7;
}
break;
default: // write8
c = dst + C8NUM;
_mm_store_ps(dst, dst1);
_mm_store_ps(dst + 4, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ps(dst + 4, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ps(dst + 4, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ps(dst + 4, dst8);
dst += stride;
dst += C8NUM;
}
break;
}
}
if (cc <= C8NUM) { // write end
break;
}
} // col end
a += C4NUM * depth;
switch (write_mode) {
case 0: // C8DstStep
c += 32;
break;
case 2:
c = dst + WinoSteps2;
break;
default:
c = dst - col;
break;
}
if (r <= C4NUM) {
break;
}
}
}
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, size_t writeNhwc, size_t WriteWino) {
size_t DstWinoSteps = stride * C8NUM;
size_t WriteWinoSteps = stride * col;
for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) {
const float *srca_d = a;
float *dst = c;
for (int r = row; r > 0; r -= C4NUM) {
const float *srcb_d = b;
__m128 dst1 = _mm_setzero_ps();
__m128 dst2 = _mm_setzero_ps();
__m128 dst3 = _mm_setzero_ps();
__m128 dst4 = _mm_setzero_ps();
__m128 dst5 = _mm_setzero_ps();
__m128 dst6 = _mm_setzero_ps();
__m128 dst7 = _mm_setzero_ps();
__m128 dst8 = _mm_setzero_ps();
for (int d = 0; d < depth; d++) {
__m128 b1 = _mm_loadu_ps(srcb_d);
__m128 b2 = _mm_loadu_ps(srcb_d + 4);
__m128 a1 = _mm_load_ps1(srca_d);
__m128 a2 = _mm_load_ps1(srca_d + 1);
__m128 tmp1 = _mm_mul_ps(b1, a1);
__m128 tmp2 = _mm_mul_ps(b2, a1);
__m128 tmp3 = _mm_mul_ps(b1, a2);
__m128 tmp4 = _mm_mul_ps(b2, a2);
a1 = _mm_load_ps1(srca_d + 2);
dst1 = _mm_add_ps(dst1, tmp1);
dst2 = _mm_add_ps(dst2, tmp2);
a2 = _mm_load_ps1(srca_d + 3);
dst3 = _mm_add_ps(dst3, tmp3);
dst4 = _mm_add_ps(dst4, tmp4);
tmp1 = _mm_mul_ps(b1, a1);
tmp2 = _mm_mul_ps(b2, a1);
tmp3 = _mm_mul_ps(b1, a2);
tmp4 = _mm_mul_ps(b2, a2);
dst5 = _mm_add_ps(dst5, tmp1);
dst6 = _mm_add_ps(dst6, tmp2);
dst7 = _mm_add_ps(dst7, tmp3);
dst8 = _mm_add_ps(dst8, tmp4);
srcb_d += C8NUM;
srca_d += C4NUM;
}
if (bias != NULL) {
__m128 bias1 = _mm_loadu_ps(bias);
__m128 bias2 = _mm_loadu_ps(bias + C4NUM);
dst1 = _mm_add_ps(dst1, bias1);
dst2 = _mm_add_ps(dst2, bias2);
dst3 = _mm_add_ps(dst3, bias1);
dst4 = _mm_add_ps(dst4, bias2);
dst5 = _mm_add_ps(dst5, bias1);
dst6 = _mm_add_ps(dst6, bias2);
dst7 = _mm_add_ps(dst7, bias1);
dst8 = _mm_add_ps(dst8, bias2);
}
if (act_type == 3) {
__m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0);
dst1 = _mm_min_ps(dst1, relu6);
dst2 = _mm_min_ps(dst2, relu6);
dst3 = _mm_min_ps(dst3, relu6);
dst4 = _mm_min_ps(dst4, relu6);
dst5 = _mm_min_ps(dst5, relu6);
dst6 = _mm_min_ps(dst6, relu6);
dst7 = _mm_min_ps(dst7, relu6);
dst8 = _mm_min_ps(dst8, relu6);
}
if (act_type == 1 || act_type == 3) {
__m128 zero = _mm_setzero_ps();
dst1 = _mm_max_ps(dst1, zero);
dst2 = _mm_max_ps(dst2, zero);
dst3 = _mm_max_ps(dst3, zero);
dst4 = _mm_max_ps(dst4, zero);
dst5 = _mm_max_ps(dst5, zero);
dst6 = _mm_max_ps(dst6, zero);
dst7 = _mm_max_ps(dst7, zero);
dst8 = _mm_max_ps(dst8, zero);
}
if (WriteWino != 0) { // WriteWino
_mm_store_ps(dst, dst1);
_mm_store_ps(dst + 4, dst2);
dst += WriteWinoSteps;
_mm_store_ps(dst, dst3);
_mm_store_ps(dst + 4, dst4);
dst += WriteWinoSteps;
_mm_store_ps(dst, dst5);
_mm_store_ps(dst + 4, dst6);
dst += WriteWinoSteps;
_mm_store_ps(dst, dst7);
_mm_store_ps(dst + 4, dst8);
dst += WriteWinoSteps;
} else if (writeNhwc == 0) { // WriteC8
_mm_store_ps(dst, dst1);
_mm_store_ps(dst + 4, dst2);
_mm_store_ps(dst + 8, dst3);
_mm_store_ps(dst + 12, dst4);
_mm_store_ps(dst + 16, dst5);
_mm_store_ps(dst + 20, dst6);
_mm_store_ps(dst + 24, dst7);
_mm_store_ps(dst + 28, dst8);
dst += 32;
c = dst;
} else {
switch (col) {
case 1: // write1
_mm_store_ss(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst += stride;
}
case 2: // write2
_mm_store_ss(dst, dst1);
dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst, dst7);
}
case 3: // write3
_mm_store_ss(dst, dst1);
dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst1);
dst1 = _mm_shuffle_ps(dst1 + 2, dst1, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ss(dst, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst3);
dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ss(dst, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst5);
dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ss(dst, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 1, dst7);
dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 2, dst7);
dst += stride;
}
case 4: // write4
_mm_store_ps(dst, dst1);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
dst += stride;
}
case 5: // // write5
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst += stride;
}
case 6: // write6
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst8);
dst += stride;
}
case 7: // write7
_mm_store_ps(dst, dst1);
_mm_store_ss(dst + 4, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst2);
dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ss(dst + 4, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst4);
dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ss(dst + 4, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst6);
dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ss(dst + 4, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 5, dst8);
dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1));
_mm_store_ss(dst + 6, dst8);
dst += stride;
}
default: // write8
_mm_store_ps(dst, dst1);
_mm_store_ps(dst + 4, dst2);
if (r > 1) {
dst += stride;
_mm_store_ps(dst, dst3);
_mm_store_ps(dst + 4, dst4);
}
if (r > 2) {
dst += stride;
_mm_store_ps(dst, dst5);
_mm_store_ps(dst + 4, dst6);
}
if (r > 3) {
dst += stride;
_mm_store_ps(dst, dst7);
_mm_store_ps(dst + 4, dst8);
dst += stride;
}
}
}
if (r <= C4NUM) { // WriteEnd
break;
}
}
b += depth * C8NUM;
bias += (bias != NULL) ? C8NUM : 0;
if (WriteWino != 0) {
c += DstWinoSteps;
} else if (writeNhwc != 0) {
c += C8NUM;
}
if (col_tmp <= C8NUM) {
break;
}
}
}
#endif

View File

@ -76,6 +76,16 @@ if (ENABLE_FP16)
${KERNEL_OP_FP16_SRC}
)
endif ()
if ("${X86_64_SIMD}" STREQUAL "sse")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
${TEST_ASSEMBLY_SRC}
)
endif()
### gpu kernel
if (SUPPORT_GPU)
file(GLOB GPU_KERNEL_OP_SRC