!28051 [MSLITE][CPU] AVX512/256/SSE/NEON advanced packaging, and Reduce op refactoring and optimization

Merge pull request !28051 from Greatpan/master_reduce
This commit is contained in:
i-robot 2021-12-25 02:39:14 +00:00 committed by Gitee
commit 9901f0fdb2
6 changed files with 655 additions and 583 deletions

View File

@ -23,261 +23,133 @@
#include "nnacl/reduce_parameter.h"
#endif
int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num) {
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  if (thread_num == 0) {
    return NNACL_PARAM_INVALID;
  }
  int i, j, k;
  for (j = tid; j < outer_size; j += thread_num) {
    const float *outer_src = src_data + j * axis_size * inner_size;
    float *outer_dst = dst_data + j * inner_size;
    for (k = 0; k < inner_size; k++) {
      const float *inner_src = outer_src + k;
      float *inner_dst = outer_dst + k;
      float tmp = 0.0f;
      for (i = 0; i < axis_size; i++) {
        tmp += inner_src[i * inner_size];
      }
      *inner_dst = tmp / (float)axis_size;
    }
  }
  return NNACL_OK;
}
// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1)
#define ReduceCoreCalc(block_size, block_num, op_name, op_type, outer_src, outer_dst, k)      \
  for (int block_max_size = inner_size - block_num + 1; k < block_max_size; k += block_num) { \
    const op_type *inner_src = outer_src + k;                                                 \
    op_type *inner_dst = outer_dst + k;                                                       \
    op_name##PreDeal(block_size, block_num);                                                  \
    for (int i = 0; i < axis_size; i++) {                                                     \
      op_name##MidCalc(block_size, block_num);                                                \
    }                                                                                         \
    op_name##PostDeal(block_size, block_num);                                                 \
  }
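// ReduceCoreCalc walks inner_size in strides of block_num; the
// "inner_size - block_num + 1" bound keeps every block_num-lane load in range,
// and k is left at the first unprocessed element for the next (narrower) tier.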
#define RegReduceOp(op_name, op_type)                                                                              \
  int op_name(int outer_size, int inner_size, int axis_size, const op_type *src_data, op_type *dst_data, int tid, \
              int thread_num) {                                                                                    \
    MS_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR);                                       \
    MS_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID);                                                        \
    for (int j = tid; j < outer_size; j += thread_num) {                                                           \
      const op_type *outer_src = src_data + j * axis_size * inner_size;                                            \
      op_type *outer_dst = dst_data + j * inner_size;                                                              \
      int k = 0;                                                                                                   \
      MS_SIMD_RUN(ReduceCoreCalc, op_name, op_type, outer_src, outer_dst, k);                                      \
    }                                                                                                              \
    return NNACL_OK;                                                                                               \
  }
int IntReduceMean(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
                  int thread_num) {
  if (axis_size == 0) {
    return NNACL_ERR;
  }
  if (src_data == NULL || dst_data == NULL) {
    return NNACL_NULL_PTR;
  }
  if (thread_num == 0) {
    return NNACL_PARAM_INVALID;
  }
  NNACL_CHECK_ZERO_RETURN_ERR(axis_size);
int i, j;
#ifdef ENABLE_NEON
int block_mod = inner_size % C4NUM;
int block_c4 = inner_size - block_mod;
#endif
for (j = tid; j < outer_size; j += thread_num) {
const int *outer_src = src_data + j * axis_size * inner_size;
int *outer_dst = dst_data + j * inner_size;
int k = 0;
#ifdef ENABLE_NEON
for (; k < block_c4; k += C4NUM) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int32x4_t tmp = {0, 0, 0, 0};
for (i = 0; i < axis_size; i++) {
tmp = vaddq_s32(tmp, vld1q_s32(inner_src + i * inner_size));
}
tmp[0] /= axis_size;
tmp[1] /= axis_size;
tmp[2] /= axis_size;
tmp[3] /= axis_size;
vst1q_s32(inner_dst, tmp);
}
#endif
for (; k < inner_size; k++) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int tmp = 0;
for (i = 0; i < axis_size; i++) {
tmp += inner_src[i * inner_size];
}
*inner_dst = tmp / axis_size;
}
}
return NNACL_OK;
}
int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j;
#ifdef ENABLE_NEON
int block_mod = inner_size % C4NUM;
int block_c4 = inner_size - block_mod;
#endif
for (j = tid; j < outer_size; j += thread_num) {
const float *outer_src = src_data + j * axis_size * inner_size;
float *outer_dst = dst_data + j * inner_size;
int k = 0;
#ifdef ENABLE_NEON
for (; k < block_c4; k += C4NUM) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float32x4_t tmp = {0, 0, 0, 0};
for (i = 0; i < axis_size; i++) {
tmp = vaddq_f32(tmp, vld1q_f32(inner_src + i * inner_size));
}
vst1q_f32(inner_dst, tmp);
}
#endif
for (; k < inner_size; k++) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float tmp = 0.0f;
for (i = 0; i < axis_size; i++) {
tmp += inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceSum
// (c style) ReduceSumPreDeal : float tmp = 0;
#define ReduceSumPreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0);
// (c style) ReduceSumMidCalc : tmp = tmp + *(inner_src + i * inner_size);
#define ReduceSumMidCalc(block_size, block_num) \
tmp = MS_ADD_F32(block_size, tmp, MS_LD_F32(block_size, inner_src + i * inner_size));
// (c style) ReduceSumPostDeal : *inner_dst = tmp;
#define ReduceSumPostDeal(block_size, block_num) MS_ST_F32(block_size, inner_dst, tmp);
RegReduceOp(ReduceSum, float);
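// For reference, with the scalar tier (block_size 32, block_num 1) the
// registration above expands to roughly the old C loop body:
//   float tmp = 0;                                // ReduceSumPreDeal
//   for (int i = 0; i < axis_size; i++) {
//     tmp = tmp + *(inner_src + i * inner_size);  // ReduceSumMidCalc
//   }
//   *inner_dst = tmp;                             // ReduceSumPostDeal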
int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j;
#ifdef ENABLE_NEON
int block_mod = inner_size % C4NUM;
int block_c4 = inner_size - block_mod;
#endif
for (j = tid; j < outer_size; j += thread_num) {
const int *outer_src = src_data + j * axis_size * inner_size;
int *outer_dst = dst_data + j * inner_size;
int k = 0;
#ifdef ENABLE_NEON
for (; k < block_c4; k += C4NUM) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int32x4_t tmp = {0, 0, 0, 0};
for (i = 0; i < axis_size; i++) {
tmp = vaddq_s32(tmp, vld1q_s32(inner_src + i * inner_size));
}
vst1q_s32(inner_dst, tmp);
}
#endif
for (; k < inner_size; k++) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int tmp = 0;
for (i = 0; i < axis_size; i++) {
tmp += inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceMean
// (c style) ReduceMeanPreDeal : float tmp = 0.0f;
#define ReduceMeanPreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0);
// (c style) ReduceMeanMidCalc : tmp = tmp + *(inner_src + i * inner_size);
#define ReduceMeanMidCalc(block_size, block_num) \
tmp = MS_ADD_F32(block_size, tmp, MS_LD_F32(block_size, inner_src + i * inner_size));
// (c style) ReduceMeanPostDeal : *inner_dst = tmp / axis_size;
#define ReduceMeanPostDeal(block_size, block_num) \
MS_ST_F32(block_size, inner_dst, MS_DIV_N_F32(block_size, tmp, axis_size));
RegReduceOp(ReduceMean, float);
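// Hypothetical usage sketch: averaging a [2, 8, 4] tensor over axis 1 on one
// thread would be ReduceMean(/*outer_size=*/2, /*inner_size=*/4, /*axis_size=*/8,
// src, dst, /*tid=*/0, /*thread_num=*/1);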
int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const float *outer_src = src_data + j * axis_size * inner_size;
float *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float tmp = -FLT_MAX;
for (i = 0; i < axis_size; i++) {
tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceMin
// (c style) ReduceMinPreDeal : float tmp = FLT_MAX;
#define ReduceMinPreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, FLT_MAX);
// (c style) ReduceMinMidCalc : tmp = fminf(tmp, *(inner_src + i * inner_size));
#define ReduceMinMidCalc(block_size, block_num) \
tmp = MS_MIN_F32(block_size, tmp, MS_LD_F32(block_size, inner_src + i * inner_size));
// (c style) ReduceMinPostDeal : *inner_dst = tmp;
#define ReduceMinPostDeal(block_size, block_num) MS_ST_F32(block_size, inner_dst, tmp);
RegReduceOp(ReduceMin, float);
int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const int *outer_src = src_data + j * axis_size * inner_size;
int *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int tmp = -INT_MAX;
for (i = 0; i < axis_size; i++) {
tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceMax
// (c style) ReduceMaxPreDeal : float tmp = -FLT_MAX;
#define ReduceMaxPreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, -FLT_MAX);
// (c style) ReduceMaxMidCalc : tmp = fmaxf(tmp, *(inner_src + i * inner_size));
#define ReduceMaxMidCalc(block_size, block_num) \
tmp = MS_MAX_F32(block_size, tmp, MS_LD_F32(block_size, inner_src + i * inner_size));
// (c style) ReduceMaxPostDeal : *inner_dst = tmp;
#define ReduceMaxPostDeal(block_size, block_num) MS_ST_F32(block_size, inner_dst, tmp);
RegReduceOp(ReduceMax, float);
int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const float *outer_src = src_data + j * axis_size * inner_size;
float *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float tmp = FLT_MAX;
for (i = 0; i < axis_size; i++) {
tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceProd
// (c style) ReduceProdPreDeal : float tmp = 1.0f;
#define ReduceProdPreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 1.0f);
// (c style) ReduceProdMidCalc : tmp = tmp * (*(inner_src + i * inner_size));
#define ReduceProdMidCalc(block_size, block_num) \
tmp = MS_MUL_F32(block_size, tmp, MS_LD_F32(block_size, inner_src + i * inner_size));
// (c style) ReduceProdPostDeal : *inner_dst = tmp;
#define ReduceProdPostDeal(block_size, block_num) MS_ST_F32(block_size, inner_dst, tmp);
RegReduceOp(ReduceProd, float);
int IntReduceMin(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const int *outer_src = src_data + j * axis_size * inner_size;
int *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const int *inner_src = outer_src + k;
int *inner_dst = outer_dst + k;
int tmp = INT32_MAX;
for (i = 0; i < axis_size; i++) {
tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
// ReduceSumSquare
// (c style) ReduceSumSquarePreDeal : float tmp = 0;
#define ReduceSumSquarePreDeal(block_size, block_num) MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0);
// (c style) ReduceSumSquareMidCalc : float val = *(inner_src + i * inner_size); tmp = tmp + val * val;
#define ReduceSumSquareMidCalc(block_size, block_num) \
tmp = MS_ADD_F32(block_size, tmp, MS_MUL_SQUARE_F32(block_size, MS_LD_F32(block_size, inner_src + i * inner_size)));
// (c style) ReduceSumSquarePostDeal : *inner_dst = tmp;
#define ReduceSumSquarePostDeal(block_size, block_num) MS_ST_F32(block_size, inner_dst, tmp);
RegReduceOp(ReduceSumSquare, float);
// IntReduceSum
// (c style) IntReduceSumPreDeal : int tmp = 0;
#define IntReduceSumPreDeal(block_size, block_num) MS_INT_32xN(block_num) tmp = MS_MOVN_EPI32(block_size, 0);
// (c style) IntReduceSumMidCalc : tmp = tmp + *(inner_src + i * inner_size);
#define IntReduceSumMidCalc(block_size, block_num) \
tmp = MS_ADD_EPI32(block_size, tmp, MS_LD_EPI32(block_size, inner_src + i * inner_size));
// (c style) IntReduceSumPostDeal : *inner_dst = tmp;
#define IntReduceSumPostDeal(block_size, block_num) MS_ST_EPI32(block_size, inner_dst, tmp);
RegReduceOp(IntReduceSum, int);
// IntReduceMean
// (c style) IntReduceMeanPreDeal : int tmp = 0;
#define IntReduceMeanPreDeal(block_size, block_num) MS_INT_32xN(block_num) tmp = MS_MOVN_EPI32(block_size, 0);
// (c style) IntReduceMeanMidCalc : tmp = tmp + *(inner_src + i * inner_size);
#define IntReduceMeanMidCalc(block_size, block_num) \
  tmp = MS_ADD_EPI32(block_size, tmp, MS_LD_EPI32(block_size, inner_src + i * inner_size));
// (c style) IntReduceMeanPostDeal : *inner_dst = tmp / axis_size;
#define IntReduceMeanPostDeal(block_size, block_num) \
MS_ST_EPI32(block_size, inner_dst, MS_DIV_N_EPI32(block_size, tmp, axis_size));
RegReduceOp(IntReduceMean, int);
// IntReduceMin
// (c style) IntReduceMinPreDeal : int tmp = INT32_MAX;
#define IntReduceMinPreDeal(block_size, block_num) MS_INT_32xN(block_num) tmp = MS_MOVN_EPI32(block_size, INT32_MAX);
// (c style) IntReduceMinMidCalc : tmp = tmp < *(inner_src + i * inner_size) ? tmp : *(inner_src + i * inner_size);
#define IntReduceMinMidCalc(block_size, block_num) \
tmp = MS_MIN_EPI32(block_size, tmp, MS_LD_EPI32(block_size, inner_src + i * inner_size));
// (c style) IntReduceMinPostDeal : *inner_dst = tmp;
#define IntReduceMinPostDeal(block_size, block_num) MS_ST_EPI32(block_size, inner_dst, tmp);
RegReduceOp(IntReduceMin, int);
// IntReduceMax
// (c style) IntReduceMaxPreDeal : int tmp = INT32_MIN;
#define IntReduceMaxPreDeal(block_size, block_num) MS_INT_32xN(block_num) tmp = MS_MOVN_EPI32(block_size, INT32_MIN);
// (c style) IntReduceMaxMidCalc : tmp = tmp > *(inner_src + i * inner_size) ? tmp : *(inner_src + i * inner_size);
#define IntReduceMaxMidCalc(block_size, block_num) \
  tmp = MS_MAX_EPI32(block_size, tmp, MS_LD_EPI32(block_size, inner_src + i * inner_size));
// (c style) IntReduceMaxPostDeal : *inner_dst = tmp;
#define IntReduceMaxPostDeal(block_size, block_num) MS_ST_EPI32(block_size, inner_dst, tmp);
RegReduceOp(IntReduceMax, int);
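// Every op above follows the same three hooks: <Op>PreDeal initializes the
// accumulator, <Op>MidCalc folds one slice along the reduced axis, and
// <Op>PostDeal writes (and, for the mean ops, scales) the result.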
int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid,
int thread_num) {
@ -304,31 +176,6 @@ int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_dat
return NNACL_OK;
}
int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const float *outer_src = src_data + j * axis_size * inner_size;
float *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float tmp = 1.0f;
for (i = 0; i < axis_size; i++) {
tmp *= inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
int IntReduceProd(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
@ -357,31 +204,6 @@ int IntReduceProd(int outer_size, int inner_size, int axis_size, const int *src_
return NNACL_OK;
}
int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
if (thread_num == 0) {
return NNACL_PARAM_INVALID;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const float *outer_src = src_data + j * axis_size * inner_size;
float *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const float *inner_src = outer_src + k;
float *inner_dst = outer_dst + k;
float tmp = 0.0f;
for (i = 0; i < axis_size; i++) {
tmp += inner_src[i * inner_size] * inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
#ifdef ENABLE_NNACL_INFER_SHAPE
int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_format, int *out_format,
int *in_datatype, int *out_datatype, OpParameter *param) {
@ -392,15 +214,11 @@ int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_f
int num_axes = reduce_parameter->num_axes_;
int *in_shape0 = in_shape[0];
int rank = dim_size[0];
if (rank <= 0 || rank > REDUCE_MAX_AXES_NUM) {
return NNACL_PARAM_INVALID;
}
MS_CHECK_TRUE_RET(rank > 0 && rank <= REDUCE_MAX_AXES_NUM, NNACL_PARAM_INVALID);
int axes[REDUCE_MAX_AXES_NUM];
int actual_axes_num = num_axes;
for (int i = 0; i < num_axes; ++i) {
if (reduce_parameter->axes_[i] < -rank || reduce_parameter->axes_[i] >= rank) {
return NNACL_PARAM_INVALID;
}
MS_CHECK_TRUE_RET(reduce_parameter->axes_[i] >= -rank && reduce_parameter->axes_[i] < rank, NNACL_PARAM_INVALID);
if (reduce_parameter->axes_[i] < 0) {
axes[i] = reduce_parameter->axes_[i] + rank;
} else {
@ -408,9 +226,7 @@ int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_f
}
}
if (reduce_parameter->reduce_to_end_) {
if (num_axes != 1) {
return NNACL_PARAM_INVALID;
}
MS_CHECK_TRUE_RET(num_axes == 1, NNACL_PARAM_INVALID);
int begin_axis = axes[0];
num_axes = rank - begin_axis;
for (int i = begin_axis + 1; i < rank; ++i) {

View File

@ -0,0 +1,154 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#define MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#ifdef _MSC_VER
#include <immintrin.h>
#define MS_F32X8_GETI(src, i) src.m256_f32[i]
#else
#include <x86intrin.h>
#define MS_F32X8_GETI(src, i) src[i]
#endif
#define MS_FLOAT32X8 __m256
#define MS_INT32X8 __m256i
#define MS_LD256_F32 _mm256_loadu_ps
#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
#define MS_ADD256_F32 _mm256_add_ps
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1)
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
#define MS_SUB256_F32 _mm256_sub_ps
#define MS_MAX256_F32 _mm256_max_ps
#define MS_MAX256_EPI32 _mm256_max_epi32
#define MS_MIN256_F32 _mm256_min_ps
#define MS_MIN256_EPI32 _mm256_min_epi32
#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
#define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2)
#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
#define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2))
#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)  // truncate float to int
#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)
#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src)
#define MS_DIV256_EPI32(src1, src2) \
_mm256_cvttps_epi32(MS_DIV256_F32(_mm256_cvtepi32_ps(src1), _mm256_cvtepi32_ps(src2)))
static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
MS_FLOAT32X8 dst;
MS_F32X8_GETI(dst, 0) = sqrtf(MS_F32X8_GETI(src, 0));
MS_F32X8_GETI(dst, 1) = sqrtf(MS_F32X8_GETI(src, 1));
MS_F32X8_GETI(dst, 2) = sqrtf(MS_F32X8_GETI(src, 2));
MS_F32X8_GETI(dst, 3) = sqrtf(MS_F32X8_GETI(src, 3));
MS_F32X8_GETI(dst, 4) = sqrtf(MS_F32X8_GETI(src, 4));
MS_F32X8_GETI(dst, 5) = sqrtf(MS_F32X8_GETI(src, 5));
MS_F32X8_GETI(dst, 6) = sqrtf(MS_F32X8_GETI(src, 6));
MS_F32X8_GETI(dst, 7) = sqrtf(MS_F32X8_GETI(src, 7));
return dst;
}
#define LOAD256X8_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
#define LOAD256X16_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \
MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \
MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \
MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \
MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \
MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \
MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \
MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \
MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num);
#define STORE256X8_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
MS_ST256_F32(output_ptr + 2 * num, dst##3); \
MS_ST256_F32(output_ptr + 3 * num, dst##4); \
MS_ST256_F32(output_ptr + 4 * num, dst##5); \
MS_ST256_F32(output_ptr + 5 * num, dst##6); \
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8);
#define STORE256X16_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
MS_ST256_F32(output_ptr + 2 * num, dst##3); \
MS_ST256_F32(output_ptr + 3 * num, dst##4); \
MS_ST256_F32(output_ptr + 4 * num, dst##5); \
MS_ST256_F32(output_ptr + 5 * num, dst##6); \
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8); \
MS_ST256_F32(output_ptr + 8 * num, dst##9); \
MS_ST256_F32(output_ptr + 9 * num, dst##10); \
MS_ST256_F32(output_ptr + 10 * num, dst##11); \
MS_ST256_F32(output_ptr + 11 * num, dst##12); \
MS_ST256_F32(output_ptr + 12 * num, dst##13); \
MS_ST256_F32(output_ptr + 13 * num, dst##14); \
MS_ST256_F32(output_ptr + 14 * num, dst##15); \
MS_ST256_F32(output_ptr + 15 * num, dst##16);
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f,
135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 square = MS_MUL256_F32(src, src);
MS_FLOAT32X8 a = MS_MUL256_F32(
MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(square, data0), square), data1), square),
data2),
src);
MS_FLOAT32X8 b = MS_ADD256_F32(
MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(data3, square), data4), square), data5),
square),
data2);
return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
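// The rational function above is a Padé-style approximation of tanh; the
// final min/max pair clamps the result to [-1, 1].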
#endif

View File

@ -18,292 +18,103 @@
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#ifdef ENABLE_ARM
#include <arm_neon.h>
#define MS_F32X4_GETI(src, i) src[i]
#endif
// Scaler
#define MS_FLOAT32X1 float
#define MS_INT32X1 int
#define MS_MOV32_F32(value) (value)
#define MS_MOV32_EPI32(value) (value)
#define MS_LD32_F32(address) (*(address))
#define MS_LD32_EPI32(address) (*(address))
#define MS_ST32_F32(address, value) (*(address) = (value))
#define MS_ST32_EPI32(address, value) (*(address) = (value))
#define MS_ADD32_F32(value1, value2) ((value1) + (value2))
#define MS_ADD32_EPI32(value1, value2) ((value1) + (value2))
#define MS_MUL32_F32(value1, value2) ((value1) * (value2))
#define MS_MUL32_EPI32(value1, value2) ((value1) * (value2))
#define MS_DIV32_F32(value1, value2) ((value1) / (value2))
#define MS_DIV32_EPI32(value1, value2) ((value1) / (value2))
#define MS_MIN32_F32(value1, value2) (fmin((value1), (value2)))
#define MS_MIN32_EPI32(value1, value2) ((value1) < (value2) ? (value1) : (value2))
#define MS_MAX32_F32(value1, value2) (fmax((value1), (value2)))
#define MS_MAX32_EPI32(value1, value2) ((value1) > (value2) ? (value1) : (value2))
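// These one-lane "vectors" let the shared macro bodies compile to plain C when
// no SIMD extension is available.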
#if defined(ENABLE_SSE)
#ifdef _MSC_VER
#include <immintrin.h>
#define MS_F32X4_GETI(src, i) src.m128_f32[i]
// define (float/int) data
#define MS_FLOAT_32xN(byte_num) MS_FLOAT32##X##byte_num
#define MS_INT_32xN(byte_num) MS_INT32##X##byte_num
// move (float/int) data
#define MS_MOVN_F32(byte_num, ...) MS_MOV##byte_num##_F32(__VA_ARGS__)
#define MS_MOVN_EPI32(byte_num, ...) MS_MOV##byte_num##_EPI32(__VA_ARGS__)
// load (float/int) data
#define MS_LD_F32(bit_num, ...) MS_LD##bit_num##_F32(__VA_ARGS__)
#define MS_LD_EPI32(bit_num, ...) MS_LD##bit_num##_EPI32(__VA_ARGS__)
// stored (float/int) data
#define MS_ST_F32(bit_num, ...) MS_ST##bit_num##_F32(__VA_ARGS__)
#define MS_ST_EPI32(bit_num, ...) MS_ST##bit_num##_EPI32(__VA_ARGS__)
// add (float/int) op
#define MS_ADD_F32(bit_num, ...) MS_ADD##bit_num##_F32(__VA_ARGS__)
#define MS_ADD_EPI32(bit_num, ...) MS_ADD##bit_num##_EPI32(__VA_ARGS__)
// div (float/int) op
#define MS_DIV_F32(bit_num, ...) MS_DIV##bit_num##_F32(__VA_ARGS__)
#define MS_DIV_EPI32(bit_num, ...) MS_DIV##bit_num##_EPI32(__VA_ARGS__)
// div (float/int) op
#define MS_DIV_N_F32(bit_num, val1, val2) MS_DIV##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))
#define MS_DIV_N_EPI32(bit_num, val1, val2) MS_DIV##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2))
// min (float/int) op
#define MS_MIN_F32(bit_num, ...) MS_MIN##bit_num##_F32(__VA_ARGS__)
#define MS_MIN_EPI32(bit_num, ...) MS_MIN##bit_num##_EPI32(__VA_ARGS__)
// max (float/int) op
#define MS_MAX_F32(bit_num, ...) MS_MAX##bit_num##_F32(__VA_ARGS__)
#define MS_MAX_EPI32(bit_num, ...) MS_MAX##bit_num##_EPI32(__VA_ARGS__)
// mul (float/int) op
#define MS_MUL_F32(bit_num, ...) MS_MUL##bit_num##_F32(__VA_ARGS__)
#define MS_MUL_EPI32(bit_num, ...) MS_MUL##bit_num##_EPI32(__VA_ARGS__)
// square (float/int) op
#define MS_MUL_SQUARE_F32(bit_num, val) (MS_MUL##bit_num##_F32(val, val))
#define MS_MUL_SQUARE_EPI32(bit_num, val) (MS_MUL##bit_num##_EPI32(val, val))
// enable avx512
#if defined(ENABLE_AVX512)
#define MS_SIMD_RUN_AVX512(function, ...) function(512, 16, __VA_ARGS__)
#else
#include <x86intrin.h>
#define MS_F32X4_GETI(src, i) src[i]
#endif
#endif
#ifdef ENABLE_AVX
#ifdef _MSC_VER
#include <immintrin.h>
#define MS_F32X8_GETI(src, i) src.m256_f32[i]
#else
#define MS_F32X8_GETI(src, i) src[i]
#endif
#endif
#ifdef ENABLE_ARM
#define MS_FLOAT32X4 float32x4_t
#define MS_INT32X4 int32x4_t
#define MS_UINT32X4 uint32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_LDQ_EPI32 vld1q_s32
#define MS_ADDQ_F32 vaddq_f32
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_SUBQ_F32 vsubq_f32
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_STQ_F32 vst1q_f32
#define MS_STQ_EPI32 vst1q_s32
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MAXQ_EPI32 vmaxq_s32
#define MS_MINQ_F32 vminq_f32
#define MS_MINQ_EPI32 vminq_s32
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
#ifdef ENABLE_ARM64
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#else
static inline float32x4_t vrecp(float32x4_t v) {
float32x4_t r = vrecpeq_f32(v);
r = vmulq_f32(vrecpsq_f32(v, r), r);
r = vmulq_f32(vrecpsq_f32(v, r), r);
return r;
}
#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
#endif
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
#define MS_CMPLEQ_F32(src1, src2) vcleq_f32(src1, src2)
#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32
#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
#define MS_CAST_F32_S32(src) vreinterpretq_f32_s32(src)
#define MS_SIMD_RUN_AVX512(function, ...)
#endif
// enable avx256
#if defined(ENABLE_AVX)
#define MS_FLOAT32X8 __m256
#define MS_INT32X8 __m256i
#define MS_LD256_F32 _mm256_loadu_ps
#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
#define MS_ADD256_F32 _mm256_add_ps
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1)
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
#define MS_SUB256_F32 _mm256_sub_ps
#define MS_MAX256_F32 _mm256_max_ps
#define MS_MAX256_EPI32 _mm256_max_epi32
#define MS_MIN256_F32 _mm256_min_ps
#define MS_MIN256_EPI32 _mm256_min_epi32
#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
#define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2)
#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
#define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2))
#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)  // truncate float to int
#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)
#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src)
#define MS_SIMD_RUN_AVX(function, ...) function(256, 8, __VA_ARGS__)
#else
#define MS_SIMD_RUN_AVX(function, ...)
#endif
#if defined(ENABLE_SSE)
#define MS_FLOAT32X4 __m128
#define MS_INT32X4 __m128i
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_ADDQ_F32 _mm_add_ps
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_SUBQ_F32 _mm_sub_ps
#define MS_MAXQ_F32 _mm_max_ps
#define MS_MAXQ_EPI32 _mm_max_epi32
#define MS_MINQ_F32 _mm_min_ps
#define MS_MINQ_EPI32 _mm_min_epi32
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
#define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2)
#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
#define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2))
#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int
#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
#define MS_CMPLEQ_F32(src1, src2) _mm_cmple_ps(src1, src2)
#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
// enable neon/sse
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
#define MS_SIMD_RUN_SSEORNEON128(function, ...) function(128, 4, __VA_ARGS__)
#else
#define MS_SIMD_RUN_SSEORNEON128(function, ...)
#endif
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3));
return dst;
}
// scalar (c style data)
#define MS_SIMD_RUN_SCALAR(function, ...) function(32, 1, __VA_ARGS__)
#define LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
#define MS_SIMD_RUN(function, ...) \
MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \
MS_SIMD_RUN_AVX(function, __VA_ARGS__); \
MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \
MS_SIMD_RUN_SCALAR(function, __VA_ARGS__);
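// Note: the tiers run back to back and share the caller's index k, so each
// narrower tier only consumes the tail left by the wider one (AVX512 strides
// by 16, then AVX by 8, then SSE/NEON by 4, then the scalar tier finishes).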
#define STORE128X8_F32(output_ptr, num, dst) \
MS_STQ_F32(output_ptr + 0 * num, dst##1); \
MS_STQ_F32(output_ptr + 1 * num, dst##2); \
MS_STQ_F32(output_ptr + 2 * num, dst##3); \
MS_STQ_F32(output_ptr + 3 * num, dst##4); \
MS_STQ_F32(output_ptr + 4 * num, dst##5); \
MS_STQ_F32(output_ptr + 5 * num, dst##6); \
MS_STQ_F32(output_ptr + 6 * num, dst##7); \
MS_STQ_F32(output_ptr + 7 * num, dst##8);
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X4 square = MS_MULQ_F32(src, src);
MS_FLOAT32X4 a = MS_MULQ_F32(
MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src);
MS_FLOAT32X4 b = MS_ADDQ_F32(
MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square),
data2);
return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
}
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3));
return dst;
}
#endif
#ifdef ENABLE_AVX
static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
MS_FLOAT32X8 dst;
MS_F32X8_GETI(dst, 0) = sqrtf(MS_F32X8_GETI(src, 0));
MS_F32X8_GETI(dst, 1) = sqrtf(MS_F32X8_GETI(src, 1));
MS_F32X8_GETI(dst, 2) = sqrtf(MS_F32X8_GETI(src, 2));
MS_F32X8_GETI(dst, 3) = sqrtf(MS_F32X8_GETI(src, 3));
MS_F32X8_GETI(dst, 4) = sqrtf(MS_F32X8_GETI(src, 4));
MS_F32X8_GETI(dst, 5) = sqrtf(MS_F32X8_GETI(src, 5));
MS_F32X8_GETI(dst, 6) = sqrtf(MS_F32X8_GETI(src, 6));
MS_F32X8_GETI(dst, 7) = sqrtf(MS_F32X8_GETI(src, 7));
return dst;
}
#define LOAD256X8_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
#define LOAD256X16_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \
MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \
MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \
MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \
MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \
MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \
MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \
MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \
MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num);
#define STORE256X8_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
MS_ST256_F32(output_ptr + 2 * num, dst##3); \
MS_ST256_F32(output_ptr + 3 * num, dst##4); \
MS_ST256_F32(output_ptr + 4 * num, dst##5); \
MS_ST256_F32(output_ptr + 5 * num, dst##6); \
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8);
#define STORE256X16_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
MS_ST256_F32(output_ptr + 2 * num, dst##3); \
MS_ST256_F32(output_ptr + 3 * num, dst##4); \
MS_ST256_F32(output_ptr + 4 * num, dst##5); \
MS_ST256_F32(output_ptr + 5 * num, dst##6); \
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8); \
MS_ST256_F32(output_ptr + 8 * num, dst##9); \
MS_ST256_F32(output_ptr + 9 * num, dst##10); \
MS_ST256_F32(output_ptr + 10 * num, dst##11); \
MS_ST256_F32(output_ptr + 11 * num, dst##12); \
MS_ST256_F32(output_ptr + 12 * num, dst##13); \
MS_ST256_F32(output_ptr + 13 * num, dst##14); \
MS_ST256_F32(output_ptr + 14 * num, dst##15); \
MS_ST256_F32(output_ptr + 15 * num, dst##16);
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f,
135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X8 square = MS_MUL256_F32(src, src);
MS_FLOAT32X8 a = MS_MUL256_F32(
MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(square, data0), square), data1), square),
data2),
src);
MS_FLOAT32X8 b = MS_ADD256_F32(
MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(data3, square), data4), square), data5),
square),
data2);
return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
#endif
#define MS_SIMD_RUN_NO_SCALAR(function, ...) \
MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \
MS_SIMD_RUN_AVX(function, __VA_ARGS__); \
MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__);
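// Presumably for call sites that handle the remaining tail with dedicated
// scalar code, this variant dispatches only the vector tiers.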
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -0,0 +1,148 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#define MINDSPORE_NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#include <arm_neon.h>
#define MS_F32X4_GETI(src, i) src[i]
#define MS_FLOAT32X4 float32x4_t
#define MS_INT32X4 int32x4_t
#define MS_UINT32X4 uint32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_LD128_F32 vld1q_f32
#define MS_LDQ_EPI32 vld1q_s32
#define MS_LD128_EPI32 vld1q_s32
#define MS_ADDQ_F32 vaddq_f32
#define MS_ADD128_F32 vaddq_f32
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_ADD128_EPI32 vaddq_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOV128_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_MOV128_EPI32 vmovq_n_s32
#define MS_SUBQ_F32 vsubq_f32
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_STQ_F32 vst1q_f32
#define MS_ST128_F32 vst1q_f32
#define MS_STQ_EPI32 vst1q_s32
#define MS_ST128_EPI32 vst1q_s32
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MAXQ_EPI32 vmaxq_s32
#define MS_MAX128_F32 vmaxq_f32
#define MS_MAX128_EPI32 vmaxq_s32
#define MS_MINQ_F32 vminq_f32
#define MS_MINQ_EPI32 vminq_s32
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
#define MS_MIN128_F32 vminq_f32
#define MS_MIN128_EPI32 vminq_s32
#define MS_MUL128_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MUL128_EPI32(src1, src2) vmulq_s32(src1, src2)
#ifdef ENABLE_ARM64
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#define MS_DIV128_F32(src1, src2) vdivq_f32(src1, src2)
#else
static inline float32x4_t vrecp(float32x4_t v) {
float32x4_t r = vrecpeq_f32(v);
r = vmulq_f32(vrecpsq_f32(v, r), r);
r = vmulq_f32(vrecpsq_f32(v, r), r);
return r;
}
#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
#define MS_DIV128_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
#endif
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
#define MS_CMPLEQ_F32(src1, src2) vcleq_f32(src1, src2)
#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32
#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
#define MS_CAST_F32_S32(src) vreinterpretq_f32_s32(src)
static inline int32x4_t MS_DIV128_EPI32(int32x4_t src1, int32x4_t src2) {
int32x4_t result;
result[0] = src1[0] / src2[0]; // C0 : 0
result[1] = src1[1] / src2[1]; // C1 : 1
result[2] = src1[2] / src2[2]; // C2 : 2
result[3] = src1[3] / src2[3]; // C3 : 3
return result;
}
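// Lane-wise fallback: arm_neon.h provides no int32 vector divide, so each of
// the four lanes is divided individually.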
static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3));
return dst;
}
#define LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
#define STORE128X8_F32(output_ptr, num, dst) \
MS_STQ_F32(output_ptr + 0 * num, dst##1); \
MS_STQ_F32(output_ptr + 1 * num, dst##2); \
MS_STQ_F32(output_ptr + 2 * num, dst##3); \
MS_STQ_F32(output_ptr + 3 * num, dst##4); \
MS_STQ_F32(output_ptr + 4 * num, dst##5); \
MS_STQ_F32(output_ptr + 5 * num, dst##6); \
MS_STQ_F32(output_ptr + 6 * num, dst##7); \
MS_STQ_F32(output_ptr + 7 * num, dst##8);
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X4 square = MS_MULQ_F32(src, src);
MS_FLOAT32X4 a = MS_MULQ_F32(
MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src);
MS_FLOAT32X4 b = MS_ADDQ_F32(
MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square),
data2);
return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
}
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3));
return dst;
}
#endif

View File

@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#define MINDSPORE_NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#ifdef _MSC_VER
#include <immintrin.h>
#define MS_F32X4_GETI(src, i) src.m128_f32[i]
#else
#include <x86intrin.h>
#define MS_F32X4_GETI(src, i) src[i]
#endif
#define MS_FLOAT32X4 __m128
#define MS_INT32X4 __m128i
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_LD128_F32 _mm_loadu_ps
#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_LD128_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_ADDQ_F32 _mm_add_ps
#define MS_ADD128_F32 _mm_add_ps
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_ADD128_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOV128_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
#define MS_MOV128_EPI32 _mm_set1_epi32
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_ST128_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_ST128_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_SUBQ_F32 _mm_sub_ps
#define MS_MAXQ_F32 _mm_max_ps
#define MS_MAXQ_EPI32 _mm_max_epi32
#define MS_MAX128_F32 _mm_max_ps
#define MS_MAX128_EPI32 _mm_max_epi32
#define MS_MINQ_F32 _mm_min_ps
#define MS_MINQ_EPI32 _mm_min_epi32
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
#define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2)
#define MS_MIN128_F32 _mm_min_ps
#define MS_MIN128_EPI32 _mm_min_epi32
#define MS_MUL128_F32(src1, src2) _mm_mul_ps(src1, src2)
#define MS_MUL128_EPI32(src1, src2) _mm_mul_epi32(src1, src2)
#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
#define MS_DIV128_F32(src1, src2) _mm_div_ps(src1, src2)
#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
#define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2))
#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int
#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
#define MS_CMPLEQ_F32(src1, src2) _mm_cmple_ps(src1, src2)
#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
#define MS_DIV128_EPI32(src1, src2) _mm_cvttps_epi32(MS_DIV128_F32(_mm_cvtepi32_ps(src1), _mm_cvtepi32_ps(src2)))
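// Integer division is emulated via float divide plus truncation; this is exact
// only while the operands fit in float's 24-bit mantissa.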
static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3));
return dst;
}
#define LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
#define STORE128X8_F32(output_ptr, num, dst) \
MS_STQ_F32(output_ptr + 0 * num, dst##1); \
MS_STQ_F32(output_ptr + 1 * num, dst##2); \
MS_STQ_F32(output_ptr + 2 * num, dst##3); \
MS_STQ_F32(output_ptr + 3 * num, dst##4); \
MS_STQ_F32(output_ptr + 4 * num, dst##5); \
MS_STQ_F32(output_ptr + 5 * num, dst##6); \
MS_STQ_F32(output_ptr + 6 * num, dst##7); \
MS_STQ_F32(output_ptr + 7 * num, dst##8);
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f};
static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f};
static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f};
static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f};
static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f};
static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
MS_FLOAT32X4 square = MS_MULQ_F32(src, src);
MS_FLOAT32X4 a = MS_MULQ_F32(
MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src);
MS_FLOAT32X4 b = MS_ADDQ_F32(
MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square),
data2);
return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
}
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
MS_FLOAT32X4 dst;
MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0));
MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1));
MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2));
MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3));
return dst;
}
#endif

View File

@ -27,10 +27,20 @@
#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
#endif
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_AVX
#include "nnacl/intrinsics/ms_simd_avx_instructions.h"
#endif
#ifdef ENABLE_SSE
#include "nnacl/intrinsics/ms_simd_sse_instructions.h"
#endif
#ifdef ENABLE_NEON
#include "nnacl/intrinsics/ms_simd_neon_instructions.h"
#endif
#include "nnacl/intrinsics/ms_simd_instructions.h"
#define C1NUM 1
#define C2NUM 2
#define C3NUM 3