avx512 hardware self-awareness, support matmul op

greatpan 2022-06-01 10:50:10 +08:00
parent d813a7da0a
commit b2aa35fa25
5 changed files with 144 additions and 90 deletions

View File

@@ -19,6 +19,7 @@
#include "nnacl/fp32/matmul_avx512_fp32.h"
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag,
const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride,
@@ -202,5 +203,53 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
}
}
}
// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
int k, int act_type) {
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
// block 8
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
for (; m_index <= m - C8NUM; m_index += C8NUM) {
int k_index = 0;
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
MS_SET_ZERO512X8_F32(dst16_)
for (; k_index <= k - C16NUM; k_index += C16NUM) {
__m512 weight = _mm512_loadu_ps(b + k_index);
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
MS_FMADD512X8_F32(src, weight, dst16_)
}
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
for (; k_index < k; k_index++) {
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
}
if (act_type != 0) {
dst = MS_MAX256_F32(dst, down_threshold256);
if (act_type == 3) { // 3: relu6
dst = MS_MIN256_F32(dst, up_threshold256);
}
}
MS_ST256_F32(c + m_index, dst);
}
return m_index;
}
#pragma GCC pop_options
#endif
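For reference, the new AVX512 kernel above computes the unpacked GEMV c[i] = act(bias[0] + dot(a[i, :], b)), eight rows per iteration with one 512-bit accumulator per row. A minimal scalar sketch of the same computation (illustrative only, not part of this commit; the helper name is hypothetical):
// Scalar sketch (assumption: for illustration only, not part of the commit).
// Computes c[i] = act(bias[0] + dot(a[i, :], b)) for rows [m_index, m),
// i.e. the per-row result the 8-row AVX512 block kernel produces.
static int64_t GemmIsNotPackOptimizeScalar(int64_t m_index, const float *a, const float *b, float *c,
                                           const float *bias, int m, int k, int act_type) {
  for (; m_index < m; m_index++) {
    float sum = bias[0];
    for (int k_index = 0; k_index < k; k_index++) {
      sum += a[m_index * k + k_index] * b[k_index];
    }
    if (act_type != 0) {
      sum = sum > 0.0f ? sum : 0.0f;  // relu
      if (act_type == 3) {
        sum = sum < 6.0f ? sum : 6.0f;  // relu6
      }
    }
    c[m_index] = sum;
  }
  return m_index;
}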

View File

@@ -30,6 +30,9 @@ void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *
void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
int col_align, int row);
int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m,
int k, int act_type);
// 64 block
void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,

View File

@@ -16,8 +16,25 @@
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/matmul_avx512_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_AVX512
#include "nnacl/avx512/matmul_fp32_avx512.h"
#endif
#ifdef ENABLE_AVX
#include "nnacl/avx/matmul_fp32_avx.h"
#endif
#ifdef ENABLE_SSE
#include "nnacl/sse/matmul_fp32_sse.h"
#endif
#ifdef ENABLE_ARM
#include "nnacl/neon/matmul_fp32_neon.h"
#endif
#ifndef ENABLE_ARM
void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
for (int ci = 0; ci < col; ci++) {
@@ -1271,44 +1288,8 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
int index = 0;
#ifdef ENABLE_AVX512
__m512 down_threshold512 = _mm512_setzero_ps();
__m512 up_threshold512 = _mm512_set1_ps(C6NUM);
__m512 b_data16 = _mm512_set1_ps(b[0]);
__m512 bias_data16 = _mm512_set1_ps(bias[0]);
for (; index < row - C16NUM; index += C16NUM) {
__m512 a_data = _mm512_loadu_ps(a + index);
__m512 dst = b_data16 * a_data + bias_data16;
ActCompute(512, down_threshold512, up_threshold512);
_mm512_storeu_ps(c + index, dst);
}
#endif
#ifdef ENABLE_AVX
__m256 down_threshold256 = _mm256_setzero_ps();
__m256 up_threshold256 = _mm256_set1_ps(C6NUM);
__m256 b_data8 = _mm256_set1_ps(b[0]);
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
for (; index < row - C8NUM; index += C8NUM) {
__m256 a_data = _mm256_loadu_ps(a + index);
__m256 dst = b_data8 * a_data + bias_data8;
ActCompute(256, down_threshold256, up_threshold256);
_mm256_storeu_ps(c + index, dst);
}
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
for (; index < row - C4NUM; index += C4NUM) {
MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
ActCompute(128, down_threshold128, up_threshold128);
MS_STQ_F32(c + index, dst);
}
#endif
SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
for (; index < row; ++index) {
float dst = a[index] * b[0] + bias[0];
@@ -1321,41 +1302,9 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
int m_index = 0;
#ifdef ENABLE_AVX512
// block 8
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
for (; m_index <= m - C8NUM; m_index += C8NUM) {
int k_index = 0;
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
MS_SET_ZERO512X8_F32(dst16_)
for (; k_index <= k - C16NUM; k_index += C16NUM) {
__m512 weight = _mm512_loadu_ps(b + k_index);
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
MS_FMADD512X8_F32(src, weight, dst16_)
}
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
for (; k_index < k; k_index++) {
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
}
ActCompute(256, down_threshold256, up_threshold256);
MS_ST256_F32(c + m_index, dst);
}
#endif
SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);
#ifdef ENABLE_AVX
// block 4
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
@@ -1388,24 +1337,10 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
for (; m_index < m; m_index++) {
float dst = bias[0];
int k_index = 0;
#ifdef ENABLE_AVX512
__m512 dst1 = _mm512_setzero_ps();
for (; k_index <= k - C16NUM; k_index += C16NUM) {
__m512 weight = _mm512_loadu_ps(b + k_index);
__m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
dst1 = _mm512_fmadd_ps(weight, a1, dst1);
}
dst += _mm512_reduce_add_ps(dst1);
#endif
#ifdef ENABLE_AVX
__m256 dst2 = _mm256_setzero_ps();
for (; k_index <= k - C8NUM; k_index += C8NUM) {
__m256 weight = _mm256_loadu_ps(b + k_index);
__m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
dst2 = _mm256_fmadd_ps(weight, src, dst2);
}
dst += MS_REDUCE_ADD256_F32(dst2);
#endif
SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
for (; k_index < k; k_index++) {
dst += b[k_index] * a[m_index * k + k_index];
}
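The removed per-ISA #ifdef blocks are now reached through the SIMD_RUN_AVX512 / SIMD_RUN_AVX / SIMD_RUN_NO_SCALAR dispatch macros. A rough sketch of the expansion such a macro is assumed to perform (the real definition lives in nnacl's generated SIMD headers and may additionally wrap a runtime AVX512 hardware check):
/* Hypothetical expansion, for illustration only. */
#ifdef ENABLE_AVX512
#define SIMD_RUN_AVX512(function, index, ...)     \
  do {                                            \
    index = function##AVX512(index, __VA_ARGS__); \
  } while (0)
#else
#define SIMD_RUN_AVX512(function, index, ...)
#endif
// e.g. SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type)
// pastes the suffix and calls GemmIsNotPackOptimizeAVX512(m_index, a, b, c, bias, m, k, act_type),
// advancing m_index past the rows already handled by the AVX512 block kernel.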

View File

@@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
#ifdef __cplusplus
extern "C" {
#endif
@SIMD_INSTRUCTION_BEGIN@
// act_type must be 0, 1 or 3. 0: no_act, 1: relu, 3: relu6.
static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
int deep, int act_type) {
SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
SIMD_F32 up_threshold = SIMD_MOV_F32(6.0f);
SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]);
SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]);
for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 a_data = SIMD_LD_F32(a + index);
SIMD_F32 dst = b_data16 * a_data + bias_data16;
if (act_type != 0) {
dst = SIMD_MAX_F32(dst, down_threshold);
if (act_type == 3) {
dst = SIMD_MIN_F32(dst, up_threshold);
}
}
SIMD_ST_F32(c + index, dst);
}
return index;
}
#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) {
SIMD_F32 dst1 = SIMD_MOV_F32(0.0f);
for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
SIMD_F32 weight = SIMD_LD_F32(b + index);
SIMD_F32 a1 = SIMD_LD_F32(a + index);
dst1 = SIMD_FMADD_F32(weight, a1, dst1);
}
*dst += SIMD_REDUCE_ADD_F32(dst1);
return index;
}
#endif
@SIMD_INSTRUCTION_END@
#ifdef __cplusplus
}
#endif
#endif
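The @SIMD_INSTRUCTION@ placeholders in this template are substituted per instruction set when the header is generated. As an assumed example, the AVX expansion of the reduce kernel would map SIMD_F32 to __m256, BLOCK_NUM to 8, and the load/FMA/reduce wrappers to the matching AVX intrinsics, roughly:
// Assumed AVX expansion, for illustration only; the generated code may differ.
// Relies on the nnacl AVX instruction header (or <immintrin.h>) being included.
static inline int64_t GemmIsNotPackOptimizeCoreAVX(int64_t index, const float *a, const float *b, int k,
                                                   float *dst) {
  __m256 dst1 = _mm256_setzero_ps();
  for (int block_max_size = k - 8 + 1; index < block_max_size; index += 8) {
    __m256 weight = _mm256_loadu_ps(b + index);
    __m256 a1 = _mm256_loadu_ps(a + index);
    dst1 = _mm256_fmadd_ps(weight, a1, dst1);  // dst1 += weight * a1, lane-wise
  }
  *dst += MS_REDUCE_ADD256_F32(dst1);  // horizontal sum of the 8 lanes
  return index;
}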

View File

@@ -134,6 +134,7 @@
// get max (float/int) op
#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM)
#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32)
// clamp (float/int) op
#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val)
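The new SIMD_REDUCE_ADD_F32 wrapper resolves through MS_SIMD_INSTRUCTION to the per-width helpers already used above (MS_REDUCE_ADD512_F32 on AVX512, MS_REDUCE_ADD256_F32 on AVX). As an assumption about what such a 256-bit horizontal reduce-add can look like (the actual helper may use a different instruction sequence):
#include <immintrin.h>

// Illustrative horizontal sum of a __m256; sketch only, not the real MS_REDUCE_ADD256_F32.
static inline float ReduceAdd256F32(__m256 v) {
  __m128 lo = _mm256_castps256_ps128(v);    // lanes 0..3
  __m128 hi = _mm256_extractf128_ps(v, 1);  // lanes 4..7
  __m128 sum = _mm_add_ps(lo, hi);          // 8 -> 4 partial sums
  sum = _mm_hadd_ps(sum, sum);              // 4 -> 2
  sum = _mm_hadd_ps(sum, sum);              // 2 -> 1
  return _mm_cvtss_f32(sum);
}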