!28569 [MS][LITE][CPU] matmul avx512 optimize
Merge pull request !28569 from liuzhongkai/code_generate3
This commit is contained in:
commit
eb2d1f3759
|
@ -173,3 +173,4 @@ mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspo
|
|||
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::GetWeights
|
||||
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
|
||||
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable
|
||||
|
|
|
@ -2173,7 +2173,7 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
|
|||
#endif
|
||||
#endif
|
||||
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row) {
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 b_data16 = _mm512_set1_ps(b[0]);
|
||||
|
@ -2213,3 +2213,58 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
|
|||
c[index] = a[index] * b[0] + bias[0];
|
||||
}
|
||||
}
|
||||
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
|
||||
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
|
||||
int m_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
// block 8
|
||||
for (; m_index <= m - C8NUM; m_index += C8NUM) {
|
||||
int k_index = 0;
|
||||
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
|
||||
MS_SET_ZERO512X8_F32(dst16_)
|
||||
for (; k_index <= k - C16NUM; k_index += C16NUM) {
|
||||
__m512 weight = _mm512_loadu_ps(b + k_index);
|
||||
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
|
||||
MS_FMADD512X8_F32(src, weight, dst16_)
|
||||
}
|
||||
MS_F32X8_GETI(dst, 0) += _mm512_reduce_add_ps(dst16_1);
|
||||
MS_F32X8_GETI(dst, 1) += _mm512_reduce_add_ps(dst16_2);
|
||||
MS_F32X8_GETI(dst, 2) += _mm512_reduce_add_ps(dst16_3);
|
||||
MS_F32X8_GETI(dst, 3) += _mm512_reduce_add_ps(dst16_4);
|
||||
MS_F32X8_GETI(dst, 4) += _mm512_reduce_add_ps(dst16_5);
|
||||
MS_F32X8_GETI(dst, 5) += _mm512_reduce_add_ps(dst16_6);
|
||||
MS_F32X8_GETI(dst, 6) += _mm512_reduce_add_ps(dst16_7);
|
||||
MS_F32X8_GETI(dst, 7) += _mm512_reduce_add_ps(dst16_8);
|
||||
for (; k_index < k; k_index++) {
|
||||
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
|
||||
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
|
||||
MS_F32X8_GETI(dst, 2) += b[k_index] * a[m_index * k + k_index + 2 * k];
|
||||
MS_F32X8_GETI(dst, 3) += b[k_index] * a[m_index * k + k_index + 3 * k];
|
||||
MS_F32X8_GETI(dst, 4) += b[k_index] * a[m_index * k + k_index + 4 * k];
|
||||
MS_F32X8_GETI(dst, 5) += b[k_index] * a[m_index * k + k_index + 5 * k];
|
||||
MS_F32X8_GETI(dst, 6) += b[k_index] * a[m_index * k + k_index + 6 * k];
|
||||
MS_F32X8_GETI(dst, 7) += b[k_index] * a[m_index * k + k_index + 7 * k];
|
||||
}
|
||||
MS_ST256_F32(c + m_index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
// block 1
|
||||
for (; m_index < m; m_index++) {
|
||||
c[m_index] = bias[0];
|
||||
int k_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 dst1 = _mm512_setzero_ps();
|
||||
for (; k_index <= k - C16NUM; k_index += C16NUM) {
|
||||
__m512 weight = _mm512_loadu_ps(b + k_index);
|
||||
__m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
|
||||
dst1 = _mm512_fmadd_ps(weight, a1, dst1);
|
||||
}
|
||||
c[m_index] += _mm512_reduce_add_ps(dst1);
|
||||
#endif
|
||||
for (; k_index < k; k_index++) {
|
||||
c[m_index] += b[k_index] * a[m_index * k + k_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,7 +125,9 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
|
|||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, int out_type);
|
||||
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row);
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
|
||||
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -512,7 +512,7 @@ void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, in
|
|||
for (; oc < oc_remainder_c8; oc += C8NUM) {
|
||||
const float *cur_src = src + index_batch + oc;
|
||||
float *cur_dst = dst + oc;
|
||||
LOAD256X16_F32(r, cur_src, channel);
|
||||
MS_LOAD256X16_F32(r, cur_src, channel);
|
||||
STORE256X16_F32(cur_dst, stride, r);
|
||||
}
|
||||
for (; oc < oc_remainder; ++oc) {
|
||||
|
@ -821,7 +821,7 @@ inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_
|
|||
|
||||
#ifdef ENABLE_AVX
|
||||
inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
|
||||
LOAD256X8_F32(src, src_ptr, src_stride)
|
||||
MS_LOAD256X8_F32(src, src_ptr, src_stride)
|
||||
__m256 r1 = _mm256_unpacklo_ps(src1, src2);
|
||||
__m256 r2 = _mm256_unpackhi_ps(src1, src2);
|
||||
__m256 r3 = _mm256_unpacklo_ps(src3, src4);
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
#define MS_ADD512_EPI32 _mm512_add_epi32
|
||||
#define MS_MOV512_F32 _mm512_set1_ps
|
||||
#define MS_MOV512_EPI32 _mm512_set1_epi32
|
||||
#define MS_MLA512_F32(src1, src2, src3) _mm512_add_ps(src1, _mm512_mul_ps(src2, src3))
|
||||
#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1)
|
||||
#define MS_ST512_F32 _mm512_storeu_ps
|
||||
#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2)
|
||||
#define MS_SUB512_F32 _mm512_sub_ps
|
||||
|
@ -93,4 +93,51 @@ static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) {
|
|||
return MS_MIN512_F32(MS_MAX512_F32(MS_DIV512_F32(a, b), neg), pos);
|
||||
}
|
||||
|
||||
#endif
|
||||
#define MS_LOAD512X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \
|
||||
MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \
|
||||
MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \
|
||||
MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \
|
||||
MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num);
|
||||
|
||||
#define MS_LOAD512X4_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \
|
||||
MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num);
|
||||
|
||||
#define MS_FMADD512X8_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \
|
||||
dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \
|
||||
dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \
|
||||
dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \
|
||||
dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \
|
||||
dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \
|
||||
dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \
|
||||
dst##8 = MS_MLA512_F32(dst##8, src##8, weight);
|
||||
|
||||
#define MS_FMADD512X4_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLA512_F32(src##1, weight, dst##1); \
|
||||
dst##2 = MS_MLA512_F32(src##2, weight, dst##2); \
|
||||
dst##3 = MS_MLA512_F32(src##3, weight, dst##3); \
|
||||
dst##4 = MS_MLA512_F32(src##4, weight, dst##4);
|
||||
|
||||
#define MS_SET_ZERO512X8_F32(dst) \
|
||||
MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##5 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##6 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##7 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##8 = _mm512_setzero_ps();
|
||||
|
||||
#define MS_SET_ZERO512X4_F32(dst) \
|
||||
MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \
|
||||
MS_FLOAT32X16 dst##4 = _mm512_setzero_ps();
|
||||
#endif // MINDSPORE_NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -76,7 +76,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
|
|||
return dst;
|
||||
}
|
||||
|
||||
#define LOAD256X8_F32(src, input_ptr, num) \
|
||||
#define MS_LOAD256X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
|
||||
|
@ -86,7 +86,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
|
|||
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
|
||||
|
||||
#define LOAD256X16_F32(src, input_ptr, num) \
|
||||
#define MS_LOAD256X16_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
|
||||
|
@ -154,4 +154,35 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
|||
return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
|
||||
}
|
||||
|
||||
#endif
|
||||
#define MS_FMADD256X8_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
|
||||
dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
|
||||
dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
|
||||
dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \
|
||||
dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \
|
||||
dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \
|
||||
dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \
|
||||
dst##8 = MS_MLA256_F32(dst##8, src##8, weight);
|
||||
|
||||
#define MS_SET_ZERO256X8_F32(dst) \
|
||||
MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##5 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##6 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##7 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##8 = _mm256_setzero_ps();
|
||||
|
||||
#define MS_FMADD256X4_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
|
||||
dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
|
||||
dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
|
||||
dst##4 = MS_MLA256_F32(dst##4, src##4, weight);
|
||||
|
||||
#define MS_SET_ZERO256X4_F32(dst) \
|
||||
MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
|
||||
#endif // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -152,4 +152,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
|
|||
return dst;
|
||||
}
|
||||
|
||||
#endif
|
||||
#define MS_FMADD128X8_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
|
||||
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
|
||||
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
|
||||
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
|
||||
dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
|
||||
dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
|
||||
dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
|
||||
dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
|
||||
|
||||
#define MS_LOAD128X4_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
|
||||
|
||||
#define MS_FMADD128X4_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
|
||||
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
|
||||
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
|
||||
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
|
||||
|
||||
#define MS_LOAD128X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
|
||||
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
|
||||
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
|
||||
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
|
||||
|
||||
#define MS_SET_ZERO128X8_F32(dst) \
|
||||
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
|
||||
|
||||
#define MS_SET_ZERO128X4_F32(dst) \
|
||||
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
|
||||
#endif // MINDSPORE_NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -90,16 +90,6 @@ static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
|
|||
return dst;
|
||||
}
|
||||
|
||||
#define LOAD128X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
|
||||
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
|
||||
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
|
||||
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
|
||||
|
||||
#define STORE128X8_F32(output_ptr, num, dst) \
|
||||
MS_STQ_F32(output_ptr + 0 * num, dst##1); \
|
||||
MS_STQ_F32(output_ptr + 1 * num, dst##2); \
|
||||
|
@ -137,4 +127,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
|
|||
return dst;
|
||||
}
|
||||
|
||||
#endif
|
||||
#define MS_FMADD128X8_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
|
||||
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
|
||||
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
|
||||
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
|
||||
dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
|
||||
dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
|
||||
dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
|
||||
dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
|
||||
|
||||
#define MS_LOAD128X4_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
|
||||
|
||||
#define MS_FMADD128X4_F32(src, weight, dst) \
|
||||
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
|
||||
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
|
||||
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
|
||||
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
|
||||
|
||||
#define MS_LOAD128X8_F32(src, input_ptr, num) \
|
||||
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
|
||||
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
|
||||
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
|
||||
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
|
||||
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
|
||||
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
|
||||
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
|
||||
|
||||
#define MS_SET_ZERO128X8_F32(dst) \
|
||||
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
|
||||
|
||||
#define MS_SET_ZERO128X4_F32(dst) \
|
||||
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
|
||||
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
|
||||
#endif // MINDSPORE_NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -37,10 +37,8 @@ int MatmulRun(const void *cdata, int task_id, float, float) {
|
|||
}
|
||||
|
||||
MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
|
||||
if (is_pack_) {
|
||||
FreeResizeBufA();
|
||||
FreeResizeBufB();
|
||||
}
|
||||
FreeResizeBufA();
|
||||
FreeResizeBufB();
|
||||
if (is_pack_ && out_need_aligned_ && oc_res_ != 0 && output_data_ != nullptr) {
|
||||
free(output_data_);
|
||||
output_data_ = nullptr;
|
||||
|
@ -250,7 +248,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
|
|||
const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
|
||||
const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
|
||||
float *c = output_data_ + index * params_->row_ * params_->col_;
|
||||
GemmIsNotPack(a, b, c, &bias, params_->row_);
|
||||
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -294,7 +292,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::init_global_variable() {
|
||||
int MatmulFp32BaseCPUKernel::init_global_variable() {
|
||||
#ifdef ENABLE_AVX512
|
||||
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col64Major : RowMajor2Row64Major;
|
||||
|
@ -335,18 +333,27 @@ void MatmulFp32BaseCPUKernel::init_global_variable() {
|
|||
// need not aligned
|
||||
col_step_ = params_->col_;
|
||||
#endif
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
|
||||
if (params_->col_ == 1 && params_->b_const_) {
|
||||
is_pack_ = false;
|
||||
matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
|
||||
matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
|
||||
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
|
||||
} else {
|
||||
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
|
||||
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::Prepare() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), 1);
|
||||
init_global_variable();
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
|
||||
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
|
||||
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
|
||||
if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
|
||||
MS_LOG(ERROR) << "Matrix pack size is negative "
|
||||
<< "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size=" << matrix_b_pack_size_;
|
||||
|
@ -358,6 +365,8 @@ int MatmulFp32BaseCPUKernel::Prepare() {
|
|||
return ret;
|
||||
}
|
||||
if (params_->a_const_) {
|
||||
auto a_tensor = in_tensors_[0];
|
||||
CHECK_NULL_RETURN(a_tensor);
|
||||
if (InitBufferA() != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -394,10 +403,6 @@ int MatmulFp32BaseCPUKernel::ReSize() {
|
|||
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
|
||||
}
|
||||
GetThreadCuttingPolicy();
|
||||
if (params_->col_ == 1 && params_->deep_ == 1) {
|
||||
is_pack_ = false;
|
||||
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
|
||||
}
|
||||
auto ret = InitTmpOutBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
|
||||
|
@ -438,7 +443,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
|
|||
}
|
||||
|
||||
void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
||||
if (params_->batch >= op_parameter_->thread_num_) {
|
||||
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && params_->b_const_)) {
|
||||
thread_count_ = op_parameter_->thread_num_;
|
||||
batch_stride_ = UP_DIV(params_->batch, thread_count_);
|
||||
batch_split_ = true;
|
||||
|
@ -453,6 +458,15 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
|
|||
batch_split_ = false;
|
||||
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
|
||||
}
|
||||
if (params_->col_ == 1 && params_->b_const_) {
|
||||
is_pack_ = false;
|
||||
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
|
||||
if (params_->deep_ == 1) {
|
||||
gemmIsNotPackFun = GemmIsNotPack;
|
||||
} else {
|
||||
gemmIsNotPackFun = GemmIsNotPackOptimize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int MatmulFp32BaseCPUKernel::Run() {
|
||||
|
@ -517,12 +531,20 @@ int MatmulFp32BaseCPUKernel::Run() {
|
|||
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
|
||||
}
|
||||
if (!params_->a_const_) {
|
||||
FreeResizeBufA();
|
||||
if (is_pack_) {
|
||||
FreeResizeBufA();
|
||||
} else {
|
||||
a_pack_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (!params_->b_const_) {
|
||||
FreeResizeBufB();
|
||||
if (is_pack_) {
|
||||
FreeResizeBufB();
|
||||
} else {
|
||||
b_pack_ptr_ = nullptr;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -33,6 +33,8 @@ using GemmFun = void (*)(const float *a, const float *b, float *c, const float *
|
|||
const int depth, const int cur_col, const int col_align, const int row);
|
||||
using GemvFun = void (*)(const float *a, const float *b, float *c, const float *bias, const int act_type,
|
||||
const int depth, const int cur_col, const int col_align);
|
||||
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k);
|
||||
|
||||
class MatmulFp32BaseCPUKernel : public InnerKernel {
|
||||
public:
|
||||
MatmulFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
|
@ -62,7 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
void FreeBiasBuf();
|
||||
int InitBiasData();
|
||||
void InitParameter();
|
||||
void init_global_variable();
|
||||
int init_global_variable();
|
||||
|
||||
private:
|
||||
void ResizeParameter();
|
||||
|
@ -105,6 +107,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
|
|||
GemmFun gemmCalFun = nullptr;
|
||||
GemvFun gemvCalFun = nullptr;
|
||||
#endif
|
||||
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_
|
||||
|
|
Loading…
Reference in New Issue