forked from mindspore-Ecosystem/mindspore
avx512 self feel, support matmal op
This commit is contained in:
parent
d813a7da0a
commit
b2aa35fa25
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "nnacl/fp32/matmul_avx512_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
|
||||
void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag,
|
||||
const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride,
|
||||
|
@ -202,5 +203,53 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
|
||||
int k, int act_type) {
|
||||
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
|
||||
// block 8
|
||||
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
|
||||
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
|
||||
for (; m_index <= m - C8NUM; m_index += C8NUM) {
|
||||
int k_index = 0;
|
||||
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
|
||||
MS_SET_ZERO512X8_F32(dst16_)
|
||||
for (; k_index <= k - C16NUM; k_index += C16NUM) {
|
||||
__m512 weight = _mm512_loadu_ps(b + k_index);
|
||||
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
|
||||
MS_FMADD512X8_F32(src, weight, dst16_)
|
||||
}
|
||||
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
|
||||
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
|
||||
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
|
||||
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
|
||||
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
|
||||
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
|
||||
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
|
||||
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
|
||||
for (; k_index < k; k_index++) {
|
||||
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
|
||||
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
|
||||
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
|
||||
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
|
||||
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
|
||||
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
|
||||
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
|
||||
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
|
||||
}
|
||||
|
||||
if (act_type != 0) {
|
||||
dst = MS_MAX256_F32(dst, down_threshold256);
|
||||
if (act_type == 3) { // 3: relu6
|
||||
dst = MS_MIN256_F32(dst, up_threshold256);
|
||||
}
|
||||
}
|
||||
|
||||
MS_ST256_F32(c + m_index, dst);
|
||||
}
|
||||
return m_index;
|
||||
}
|
||||
|
||||
#pragma GCC pop_options
|
||||
#endif
|
||||
|
|
|
@ -30,6 +30,9 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
|
|||
void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
|
||||
int col_align, int row);
|
||||
|
||||
int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
|
||||
int k, int act_type);
|
||||
|
||||
// 64 block
|
||||
void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
|
|
|
@ -16,8 +16,25 @@
|
|||
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "nnacl/fp32/pack_fp32.h"
|
||||
#include "nnacl/fp32/matmul_avx512_fp32.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
|
||||
#ifdef ENABLE_AVX512
|
||||
#include "nnacl/avx512/matmul_fp32_avx512.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
#include "nnacl/avx/matmul_fp32_avx.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_SSE
|
||||
#include "nnacl/sse/matmul_fp32_sse.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
#include "nnacl/neon/matmul_fp32_neon.h"
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ARM
|
||||
void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
|
||||
for (int ci = 0; ci < col; ci++) {
|
||||
|
@ -1271,44 +1288,8 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
|
|||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 down_threshold512 = _mm512_setzero_ps();
|
||||
__m512 up_threshold512 = _mm512_set1_ps(C6NUM);
|
||||
__m512 b_data16 = _mm512_set1_ps(b[0]);
|
||||
__m512 bias_data16 = _mm512_set1_ps(bias[0]);
|
||||
for (; index < row - C16NUM; index += C16NUM) {
|
||||
__m512 a_data = _mm512_loadu_ps(a + index);
|
||||
__m512 dst = b_data16 * a_data + bias_data16;
|
||||
ActCompute(512, down_threshold512, up_threshold512);
|
||||
_mm512_storeu_ps(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 down_threshold256 = _mm256_setzero_ps();
|
||||
__m256 up_threshold256 = _mm256_set1_ps(C6NUM);
|
||||
__m256 b_data8 = _mm256_set1_ps(b[0]);
|
||||
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
|
||||
for (; index < row - C8NUM; index += C8NUM) {
|
||||
__m256 a_data = _mm256_loadu_ps(a + index);
|
||||
__m256 dst = b_data8 * a_data + bias_data8;
|
||||
ActCompute(256, down_threshold256, up_threshold256);
|
||||
_mm256_storeu_ps(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
|
||||
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
|
||||
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
|
||||
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
|
||||
for (; index < row - C4NUM; index += C4NUM) {
|
||||
MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
|
||||
MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
|
||||
ActCompute(128, down_threshold128, up_threshold128);
|
||||
MS_STQ_F32(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
|
||||
|
||||
for (; index < row; ++index) {
|
||||
float dst = a[index] * b[0] + bias[0];
|
||||
|
@ -1321,41 +1302,9 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
|
|||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
|
||||
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
|
||||
int m_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
// block 8
|
||||
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
|
||||
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
|
||||
for (; m_index <= m - C8NUM; m_index += C8NUM) {
|
||||
int k_index = 0;
|
||||
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
|
||||
MS_SET_ZERO512X8_F32(dst16_)
|
||||
for (; k_index <= k - C16NUM; k_index += C16NUM) {
|
||||
__m512 weight = _mm512_loadu_ps(b + k_index);
|
||||
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
|
||||
MS_FMADD512X8_F32(src, weight, dst16_)
|
||||
}
|
||||
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
|
||||
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
|
||||
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
|
||||
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
|
||||
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
|
||||
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
|
||||
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
|
||||
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
|
||||
for (; k_index < k; k_index++) {
|
||||
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
|
||||
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
|
||||
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
|
||||
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
|
||||
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
|
||||
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
|
||||
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
|
||||
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
|
||||
}
|
||||
ActCompute(256, down_threshold256, up_threshold256);
|
||||
MS_ST256_F32(c + m_index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
// block 4
|
||||
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
|
||||
|
@ -1388,24 +1337,10 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
|
|||
for (; m_index < m; m_index++) {
|
||||
float dst = bias[0];
|
||||
int k_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 dst1 = _mm512_setzero_ps();
|
||||
for (; k_index <= k - C16NUM; k_index += C16NUM) {
|
||||
__m512 weight = _mm512_loadu_ps(b + k_index);
|
||||
__m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
|
||||
dst1 = _mm512_fmadd_ps(weight, a1, dst1);
|
||||
}
|
||||
dst += _mm512_reduce_add_ps(dst1);
|
||||
#endif
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 dst2 = _mm256_setzero_ps();
|
||||
for (; k_index <= k - C8NUM; k_index += C8NUM) {
|
||||
__m256 weight = _mm256_loadu_ps(b + k_index);
|
||||
__m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
|
||||
dst2 = _mm256_fmadd_ps(weight, src, dst2);
|
||||
}
|
||||
dst += MS_REDUCE_ADD256_F32(dst2);
|
||||
#endif
|
||||
|
||||
SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
|
||||
SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
|
||||
|
||||
for (; k_index < k; k_index++) {
|
||||
dst += b[k_index] * a[m_index * k + k_index];
|
||||
}
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
|
||||
#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
|
||||
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@SIMD_INSTRUCTION_BEGIN@
|
||||
|
||||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
|
||||
int deep, int act_type) {
|
||||
SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
|
||||
SIMD_F32 up_threshold = SIMD_MOV_F32(6);
|
||||
SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]);
|
||||
SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]);
|
||||
for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
|
||||
SIMD_F32 a_data = SIMD_LD_F32(a + index);
|
||||
SIMD_F32 dst = b_data16 * a_data + bias_data16;
|
||||
if (act_type != 0) {
|
||||
dst = SIMD_MAX_F32(dst, down_threshold);
|
||||
if (act_type == 3) {
|
||||
dst = SIMD_MIN_F32(dst, up_threshold);
|
||||
}
|
||||
}
|
||||
SIMD_ST_F32(c + index, dst);
|
||||
}
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
|
||||
static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) {
|
||||
SIMD_F32 dst1 = SIMD_MOV_F32(0.0f);
|
||||
for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
|
||||
SIMD_F32 weight = SIMD_LD_F32(b + index);
|
||||
SIMD_F32 a1 = SIMD_LD_F32(a + index);
|
||||
dst1 = SIMD_FMADD_F32(weight, a1, dst1);
|
||||
}
|
||||
*dst += SIMD_REDUCE_ADD_F32(dst1);
|
||||
return index;
|
||||
}
|
||||
#endif
|
||||
|
||||
@SIMD_INSTRUCTION_END@
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -134,6 +134,7 @@
|
|||
|
||||
// get max (float/int) op
|
||||
#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM)
|
||||
#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32)
|
||||
|
||||
// clamp (float/int) op
|
||||
#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val)
|
||||
|
|
Loading…
Reference in New Issue