avx support vs build

This commit is contained in:
sunsuodong 2021-09-02 17:20:26 +08:00
parent af2dd898a3
commit 6e845fe3a3
19 changed files with 78 additions and 53 deletions

View File

@ -41,7 +41,7 @@ asm_function ConvDwFp32Avx3x3
pushq %rdi // -96
addq $96, %rsp
#ifdef WIN32
#ifdef _WIN32
movq %rcx, %rdi
movq %rdx, %rsi
movq %r8, %rdx

View File

@ -24,7 +24,7 @@ asm_function ConvDwFp32Border
addq $96, %rsp
movq %rdi, %rdx
#ifdef WIN32
#ifdef _WIN32
movq %rcx, %rdx
#endif
movq 8(%rdx), %r12 // src

View File

@ -31,7 +31,7 @@ asm_function ConvDwFp32Row
pushq %rdi
addq $48, %rsp
#ifdef WIN32
#ifdef _WIN32
movq %rcx, %rdi // output_ptr
movq %rdx, %rsi // input_ptr
movq %r8, %rdx // weight_ptr

View File

@ -47,7 +47,7 @@ asm_function MatmulFloatAvxOpt
pushq %rsi // -104 rsi
pushq %rdi // -112 rdi
addq $112, %rsp
#ifdef WIN32
#ifdef _WIN32
movq %rcx, %rdi
movq %rdx, %rsi
movq %r8, %rdx

View File

@ -104,7 +104,7 @@ int Sigmoid(const float *src, int length, float *dst) {
int i = 0;
#if defined(ENABLE_AVX)
for (; i <= length - 8; i += 8) {
simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i);
simd_exp_avx(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src + i))), dst + i);
MS_ST256_F32(dst + i,
MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
}
@ -232,25 +232,32 @@ int Gelu(const float *src, int length, float *dst, bool approximate) {
if (approximate) {
// dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
#if defined(ENABLE_AVX)
MS_FLOAT32X8 para1 = MS_MOV256_F32(0.79788456080287f);
MS_FLOAT32X8 para2 = MS_MOV256_F32(0.035677408136f);
MS_FLOAT32X8 para3 = MS_MOV256_F32(1.0f);
MS_FLOAT32X8 para4 = MS_MOV256_F32(0.5f);
int C8 = DOWN_ROUND(length, C8NUM);
for (; i < C8; i += C8NUM) {
MS_FLOAT32X8 in = MS_LD256_F32(src + i);
const MS_FLOAT32X8 res = 0.5 * in * (1.0 + MS_TANHX8_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
const MS_FLOAT32X8 res = MS_MUL256_F32(
MS_MUL256_F32(para4, in),
MS_ADD256_F32(
para3, MS_TANHX8_F32(MS_MUL256_F32(MS_ADD256_F32(para1, MS_MUL256_F32(MS_MUL256_F32(para2, in), in)), in))));
MS_ST256_F32(dst + i, res);
}
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
MS_FLOAT32X4 para1 = MS_MOVQ_F32(0.79788456080287f);
MS_FLOAT32X4 para2 = MS_MOVQ_F32(0.035677408136f);
MS_FLOAT32X4 para3 = MS_MOVQ_F32(1.0f);
MS_FLOAT32X4 para4 = MS_MOVQ_F32(0.5f);
MS_FLOAT32X4 para5 = MS_MOVQ_F32(0.79788456080287f);
MS_FLOAT32X4 para6 = MS_MOVQ_F32(0.035677408136f);
MS_FLOAT32X4 para7 = MS_MOVQ_F32(1.0f);
MS_FLOAT32X4 para8 = MS_MOVQ_F32(0.5f);
int C4 = DOWN_ROUND(length, C4NUM);
for (; i < C4; i += C4NUM) {
MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
MS_FLOAT32X4 res = MS_MULQ_F32(
MS_MULQ_F32(para4, in),
MS_ADDQ_F32(para3,
MS_TANHX4_F32(MS_MULQ_F32(MS_ADDQ_F32(para1, MS_MULQ_F32(MS_MULQ_F32(para2, in), in)), in))));
MS_MULQ_F32(para8, in),
MS_ADDQ_F32(para7,
MS_TANHX4_F32(MS_MULQ_F32(MS_ADDQ_F32(para5, MS_MULQ_F32(MS_MULQ_F32(para6, in), in)), in))));
MS_STQ_F32(dst + i, res);
}
#endif

View File

@ -15,7 +15,11 @@
*/
#ifdef ENABLE_AVX
#include "nnacl/fp32/conv_1x1_x86_fp32.h"
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>
#endif
// sliding window to compate 1x1 conv in x86
void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data,

View File

@ -17,8 +17,12 @@
#include "nnacl/fp32/conv_common_fp32.h"
#include <string.h>
#ifdef ENABLE_AVX
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>
#endif
#endif
#include "nnacl/fp32/matmul_fp32.h"
// fp32 conv common

View File

@ -105,7 +105,7 @@ void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float
void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block);
const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block);
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels,
size_t output_width, size_t input_stride, size_t relu, size_t relu6);

View File

@ -67,13 +67,13 @@ static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) {
{0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}};
input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv));
MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]);
MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0];
MS_INT32X8 integer = MS_CVT256PS_EPI32(MS_DIV256_F32(input, param[0]));
MS_FLOAT32X8 decimal = MS_SUB256_F32(input, MS_MUL256_F32(MS_CVT256EPI32_PS(integer), param[0]));
MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23);
MS_FLOAT32X8 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_ST256_F32(dst, decimal_exp * MS_CAST256_F32_S32(int_exp));
MS_FLOAT32X8 tmp = MS_MUL256_F32(decimal, (MS_ADD256_F32(param[2], MS_MUL256_F32(decimal, param[1]))));
tmp = MS_MUL256_F32(decimal, MS_ADD256_F32(param[4], MS_MUL256_F32(decimal, MS_ADD256_F32(param[3], tmp))));
MS_FLOAT32X8 decimal_exp = MS_ADD256_F32(param[5], MS_MUL256_F32(decimal, MS_ADD256_F32(param[5], tmp)));
MS_ST256_F32(dst, MS_MUL256_F32(decimal_exp, MS_CAST256_F32_S32(int_exp)));
}
#endif

View File

@ -116,7 +116,7 @@ int SoftplusGrad(const float *src0, const float *src1, int length, float *dst) {
int i = 0;
#if defined(ENABLE_AVX)
for (; i <= length - C8NUM; i += C8NUM) {
simd_exp_avx(-(MS_LD256_F32(src1 + i)), dst + i);
simd_exp_avx(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src1 + i))), dst + i);
MS_ST256_F32(dst + i,
MS_DIV256_F32(MS_LD256_F32(src0 + i), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
}

View File

@ -47,7 +47,6 @@
#include "nnacl/infer/string/custom_normalize_infer.h"
#include "nnacl/infer/string/custom_predict_infer.h"
#include "nnacl/infer/deconv2d_infer.h"
#include "nnacl/infer/dedepthwise_conv2d_infer.h"
#include "nnacl/infer/depth_to_space_infer.h"
#include "nnacl/infer/depthwise_conv2d_infer.h"
#include "nnacl/infer/detection_post_process_infer.h"

View File

@ -15,11 +15,8 @@
*/
#include "nnacl/int8/add_int8.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_AVX
#include <x86intrin.h>
#include "nnacl/intrinsics/avx/common_utils.h"
#endif
#include "nnacl/int8/fixed_point.h"
@ -319,8 +316,8 @@ void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, in
const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_);
int index = 0;
for (; index <= size - 16; index += 16) {
const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(input0 + index));
const __m128i in1_src = _mm_loadu_si128((__m128i_u *)(input1 + index));
const __m128i in0_src = _mm_loadu_si128((__m128i *)(input0 + index));
const __m128i in1_src = _mm_loadu_si128((__m128i *)(input1 + index));
const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src);
const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0);
@ -398,7 +395,7 @@ void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, in
__m128i out = _mm_packs_epi16(out_s16_1, out_s16_2);
__m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out));
_mm_storeu_si128((__m128i_u *)(output + index), int8_out);
_mm_storeu_si128((__m128i *)(output + index), int8_out);
}
for (; index < size; index++) {
const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift;
@ -452,7 +449,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp
int index = 0;
for (; index <= size - 16; index += 16) {
const __m128i in0_src = _mm_loadu_si128((__m128i_u *)(ptr_in + index));
const __m128i in0_src = _mm_loadu_si128((__m128i *)(ptr_in + index));
const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src);
const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0);
const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1);
@ -516,7 +513,7 @@ void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *outp
__m128i out = _mm_packs_epi16(out_s16_1, out_s16_2);
__m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out));
_mm_storeu_si128((__m128i_u *)(output + index), int8_out);
_mm_storeu_si128((__m128i *)(output + index), int8_out);
}
for (; index < size; index++) {
const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift;

View File

@ -13,13 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_ADD_INT8_H_
#define MINDSPORE_NNACL_ADD_INT8_H_
#ifdef ENABLE_AVX
#include <x86intrin.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/arithmetic.h"

View File

@ -14,7 +14,11 @@
* limitations under the License.
*/
#ifdef ENABLE_AVX
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>
#endif
#include "nnacl/fp32/common_func_fp32.h"
static inline __m256 padd(__m256 v0, __m256 v1, __m256 v2, __m256 v3) {

View File

@ -14,11 +14,7 @@
* limitations under the License.
*/
#include "nnacl/intrinsics/avx/common_utils.h"
#ifdef WIN32
#ifdef ENABLE_AVX
#include <stdint.h>
#endif
#endif
__m128i _mm_adds_epi32(__m128i a, __m128i b) {
__m128i int_min = _mm_set1_epi32(0x80000000);

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_
#define MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_
#ifdef SUPPORT_MSVC
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>

View File

@ -24,7 +24,7 @@
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_AVX)
#ifdef SUPPORT_MSVC
#ifdef _MSC_VER
#include <immintrin.h>
#define MS_F32X4_GETI(src, i) src.m128_f32[i]
#else
@ -224,13 +224,25 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
MS_ST256_F32(output_ptr + 7 * num, dst##8);
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f,
135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 square = src * src;
MS_FLOAT32X8 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
MS_FLOAT32X8 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
return MS_MIN256_F32(MS_MAX256_F32(a / b, neg), pos);
MS_FLOAT32X8 square = MS_MUL256_F32(src, src);
MS_FLOAT32X8 a = MS_MUL256_F32(
MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(square, data0), square), data1), square),
data2),
src);
MS_FLOAT32X8 b = MS_ADD256_F32(
MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(data3, square), data4), square), data5),
square),
data2);
return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
#endif

View File

@ -15,19 +15,23 @@
*/
#ifdef ENABLE_AVX
#ifdef _MSC_VER
#include <immintrin.h>
#else
#include <x86intrin.h>
#endif
#include "nnacl/fp32/conv_depthwise_fp32.h"
#define INPUT_SIZE 25
void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels,
size_t output_width, size_t input_stride, size_t relu, size_t relu6) {
input_stride /= sizeof(float *);
size_t c8 = UP_DIV(channels, C8NUM) * C8NUM;
size_t c8_mod = channels % C8NUM;
const int kernel = 25;
float *in[INPUT_SIZE];
for (int i = 0; i < output_width; ++i) {
float *in[kernel];
for (int k = 0; k < kernel; k++) {
for (int k = 0; k < INPUT_SIZE; k++) {
in[k] = input[k];
}
input += input_stride;
@ -37,7 +41,7 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const
for (; c >= C8NUM; c -= C8NUM) {
__m256 out1 = _mm256_loadu_ps(bias1);
bias1 += 8;
for (int k = 0; k < kernel; k += 5) {
for (int k = 0; k < INPUT_SIZE; k += 5) {
__m256 in1 = _mm256_loadu_ps(in[k]);
__m256 w1 = _mm256_loadu_ps(w);
__m256 in2 = _mm256_loadu_ps(in[k + 1]);

View File

@ -444,8 +444,10 @@ if(NOT PLATFORM_ARM)
set(X86_64_SIMD "avx")
add_compile_definitions(ENABLE_SSE)
add_compile_definitions(ENABLE_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
if(NOT MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
endif()
elseif(MSLITE_ENABLE_SSE)
set(X86_64_SIMD "sse")
add_compile_definitions(ENABLE_SSE)