!28569 [MS][LITE][CPU] matmul avx512 optimize

Merge pull request !28569 from liuzhongkai/code_generate3
This commit is contained in:
i-robot 2022-01-07 01:40:14 +00:00 committed by Gitee
commit eb2d1f3759
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 287 additions and 42 deletions

View File

@ -173,3 +173,4 @@ mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspo
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::GetWeights
mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc:mindspore::parallel::PartitionNode
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable

View File

@ -2173,7 +2173,7 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
#endif
#endif
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row) {
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
int index = 0;
#ifdef ENABLE_AVX512
__m512 b_data16 = _mm512_set1_ps(b[0]);
@ -2213,3 +2213,58 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
c[index] = a[index] * b[0] + bias[0];
}
}
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
  // Dot-product GEMV: c[m, 1] = a[m, k] * b[k, 1], with the scalar bias[0] added to every row.
  int row = 0;
#ifdef ENABLE_AVX512
  // Fast path: 8 output rows per iteration, each reducing 16 depth elements per step.
  for (; row <= m - C8NUM; row += C8NUM) {
    const float *a_block = a + row * k;
    int depth = 0;
    MS_FLOAT32X8 acc = MS_MOV256_F32(bias[0]);
    MS_SET_ZERO512X8_F32(vacc_)
    for (; depth <= k - C16NUM; depth += C16NUM) {
      __m512 w_vec = _mm512_loadu_ps(b + depth);
      MS_LOAD512X8_F32(lhs, a_block + depth, k)
      MS_FMADD512X8_F32(lhs, w_vec, vacc_)
    }
    // Horizontally reduce each row's 512-bit accumulator into its lane of acc.
    MS_F32X8_GETI(acc, 0) += _mm512_reduce_add_ps(vacc_1);
    MS_F32X8_GETI(acc, 1) += _mm512_reduce_add_ps(vacc_2);
    MS_F32X8_GETI(acc, 2) += _mm512_reduce_add_ps(vacc_3);
    MS_F32X8_GETI(acc, 3) += _mm512_reduce_add_ps(vacc_4);
    MS_F32X8_GETI(acc, 4) += _mm512_reduce_add_ps(vacc_5);
    MS_F32X8_GETI(acc, 5) += _mm512_reduce_add_ps(vacc_6);
    MS_F32X8_GETI(acc, 6) += _mm512_reduce_add_ps(vacc_7);
    MS_F32X8_GETI(acc, 7) += _mm512_reduce_add_ps(vacc_8);
    // Scalar tail over the depth remainder (k % 16 elements) for all 8 rows.
    for (; depth < k; depth++) {
      MS_F32X8_GETI(acc, 0) += b[depth] * a_block[depth];
      MS_F32X8_GETI(acc, 1) += b[depth] * a_block[depth + k];
      MS_F32X8_GETI(acc, 2) += b[depth] * a_block[depth + 2 * k];
      MS_F32X8_GETI(acc, 3) += b[depth] * a_block[depth + 3 * k];
      MS_F32X8_GETI(acc, 4) += b[depth] * a_block[depth + 4 * k];
      MS_F32X8_GETI(acc, 5) += b[depth] * a_block[depth + 5 * k];
      MS_F32X8_GETI(acc, 6) += b[depth] * a_block[depth + 6 * k];
      MS_F32X8_GETI(acc, 7) += b[depth] * a_block[depth + 7 * k];
    }
    MS_ST256_F32(c + row, acc);
  }
#endif
  // Remaining rows, one at a time.
  for (; row < m; row++) {
    const float *a_row = a + row * k;
    c[row] = bias[0];
    int depth = 0;
#ifdef ENABLE_AVX512
    __m512 vsum = _mm512_setzero_ps();
    for (; depth <= k - C16NUM; depth += C16NUM) {
      vsum = _mm512_fmadd_ps(_mm512_loadu_ps(b + depth), _mm512_loadu_ps(a_row + depth), vsum);
    }
    c[row] += _mm512_reduce_add_ps(vsum);
#endif
    for (; depth < k; depth++) {
      c[row] += b[depth] * a_row[depth];
    }
  }
}

View File

@ -125,7 +125,9 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type);
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row);
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
#ifdef __cplusplus
}

View File

@ -512,7 +512,7 @@ void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, in
for (; oc < oc_remainder_c8; oc += C8NUM) {
const float *cur_src = src + index_batch + oc;
float *cur_dst = dst + oc;
LOAD256X16_F32(r, cur_src, channel);
MS_LOAD256X16_F32(r, cur_src, channel);
STORE256X16_F32(cur_dst, stride, r);
}
for (; oc < oc_remainder; ++oc) {
@ -821,7 +821,7 @@ inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_
#ifdef ENABLE_AVX
inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
LOAD256X8_F32(src, src_ptr, src_stride)
MS_LOAD256X8_F32(src, src_ptr, src_stride)
__m256 r1 = _mm256_unpacklo_ps(src1, src2);
__m256 r2 = _mm256_unpackhi_ps(src1, src2);
__m256 r3 = _mm256_unpacklo_ps(src3, src4);

View File

@ -35,7 +35,7 @@
#define MS_ADD512_EPI32 _mm512_add_epi32
#define MS_MOV512_F32 _mm512_set1_ps
#define MS_MOV512_EPI32 _mm512_set1_epi32
#define MS_MLA512_F32(src1, src2, src3) _mm512_add_ps(src1, _mm512_mul_ps(src2, src3))
#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1)
#define MS_ST512_F32 _mm512_storeu_ps
#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2)
#define MS_SUB512_F32 _mm512_sub_ps
@ -93,4 +93,51 @@ static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) {
return MS_MIN512_F32(MS_MAX512_F32(MS_DIV512_F32(a, b), neg), pos);
}
#endif
// Load 8 rows of 16 floats each (one 512-bit register per row) starting at
// input_ptr; consecutive rows are `num` floats apart. Declares src1..src8.
#define MS_LOAD512X8_F32(src, input_ptr, num) \
MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \
MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \
MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \
MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \
MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \
MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num);
// Load 4 rows of 16 floats each, `num` floats apart. Declares src1..src4.
#define MS_LOAD512X4_F32(src, input_ptr, num) \
MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \
MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \
MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \
MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num);
// Accumulate dst#i += src#i * weight for eight 512-bit registers.
// MS_MLA512_F32(acc, x, y) expands to _mm512_fmadd_ps(x, y, acc), i.e. the
// accumulator is the FIRST operand (see its definition above).
#define MS_FMADD512X8_F32(src, weight, dst) \
dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \
dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \
dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \
dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \
dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \
dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \
dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \
dst##8 = MS_MLA512_F32(dst##8, src##8, weight);
// Accumulate dst#i += src#i * weight for four 512-bit registers.
// Fix: MS_MLA512_F32(s1, s2, s3) expands to _mm512_fmadd_ps(s2, s3, s1),
// i.e. the accumulator is the FIRST operand. The previous operand order
// (src##i, weight, dst##i) computed dst = src + weight * dst, which scales
// the accumulator instead of accumulating. Pass the accumulator first, the
// same way MS_FMADD512X8_F32 above does.
#define MS_FMADD512X4_F32(src, weight, dst) \
dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \
dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \
dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \
dst##4 = MS_MLA512_F32(dst##4, src##4, weight);
// Declare dst1..dst8 as 512-bit float registers initialized to all zeros.
#define MS_SET_ZERO512X8_F32(dst) \
MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##5 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##6 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##7 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##8 = _mm512_setzero_ps();
// Declare dst1..dst4 as 512-bit float registers initialized to all zeros.
#define MS_SET_ZERO512X4_F32(dst) \
MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \
MS_FLOAT32X16 dst##4 = _mm512_setzero_ps();
#endif // MINDSPORE_NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -76,7 +76,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
return dst;
}
#define LOAD256X8_F32(src, input_ptr, num) \
#define MS_LOAD256X8_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
@ -86,7 +86,7 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
#define LOAD256X16_F32(src, input_ptr, num) \
#define MS_LOAD256X16_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
@ -154,4 +154,35 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
#endif
// Accumulate dst#i += src#i * weight for eight 256-bit registers.
// Accumulator is passed as the first MS_MLA256_F32 operand — assumes
// MS_MLA256_F32 mirrors MS_MLA512_F32's (acc, x, y) order; confirm against
// its definition.
#define MS_FMADD256X8_F32(src, weight, dst) \
dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \
dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \
dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \
dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \
dst##8 = MS_MLA256_F32(dst##8, src##8, weight);
// Declare dst1..dst8 as 256-bit float registers initialized to all zeros.
#define MS_SET_ZERO256X8_F32(dst) \
MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##5 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##6 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##7 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##8 = _mm256_setzero_ps();
// Accumulate dst#i += src#i * weight for four 256-bit registers (same
// operand-order assumption as MS_FMADD256X8_F32 above).
#define MS_FMADD256X4_F32(src, weight, dst) \
dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \
dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \
dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \
dst##4 = MS_MLA256_F32(dst##4, src##4, weight);
// Declare dst1..dst4 as 256-bit float registers initialized to all zeros.
#define MS_SET_ZERO256X4_F32(dst) \
MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
#endif // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -152,4 +152,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
return dst;
}
#endif
// Accumulate-into-dst macro for eight 128-bit registers.
// NOTE(review): operands are passed to MS_MLAQ_F32 as (src, weight, dst),
// whereas the 256/512-bit FMADD macros pass the accumulator FIRST. The
// MS_MLAQ_F32 definition is not visible here — verify its operand order
// actually yields dst#i += src#i * weight before relying on these macros.
#define MS_FMADD128X8_F32(src, weight, dst) \
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
// Load 4 rows of 4 floats each, `num` floats apart. Declares src1..src4.
#define MS_LOAD128X4_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
// Accumulate-into-dst macro for four 128-bit registers (same operand-order
// caveat as MS_FMADD128X8_F32 above).
#define MS_FMADD128X4_F32(src, weight, dst) \
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
// Load 8 rows of 4 floats each, `num` floats apart. Declares src1..src8.
#define MS_LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
// Declare dst1..dst8 as 128-bit float registers initialized to all zeros.
#define MS_SET_ZERO128X8_F32(dst) \
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
// Declare dst1..dst4 as 128-bit float registers initialized to all zeros.
#define MS_SET_ZERO128X4_F32(dst) \
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
#endif // MINDSPORE_NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -90,16 +90,6 @@ static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
return dst;
}
#define LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
#define STORE128X8_F32(output_ptr, num, dst) \
MS_STQ_F32(output_ptr + 0 * num, dst##1); \
MS_STQ_F32(output_ptr + 1 * num, dst##2); \
@ -137,4 +127,51 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
return dst;
}
#endif
// Accumulate-into-dst macro for eight 128-bit registers.
// NOTE(review): operands are passed to MS_MLAQ_F32 as (src, weight, dst),
// whereas the 256/512-bit FMADD macros pass the accumulator FIRST. The
// MS_MLAQ_F32 definition is not visible here — verify its operand order
// actually yields dst#i += src#i * weight before relying on these macros.
#define MS_FMADD128X8_F32(src, weight, dst) \
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \
dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \
dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \
dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \
dst##8 = MS_MLAQ_F32(src##8, weight, dst##8);
// Load 4 rows of 4 floats each, `num` floats apart. Declares src1..src4.
#define MS_LOAD128X4_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num);
// Accumulate-into-dst macro for four 128-bit registers (same operand-order
// caveat as MS_FMADD128X8_F32 above).
#define MS_FMADD128X4_F32(src, weight, dst) \
dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \
dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \
dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \
dst##4 = MS_MLAQ_F32(src##4, weight, dst##4);
// Load 8 rows of 4 floats each, `num` floats apart. Declares src1..src8.
#define MS_LOAD128X8_F32(src, input_ptr, num) \
MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
// Declare dst1..dst8 as 128-bit float registers initialized to all zeros.
#define MS_SET_ZERO128X8_F32(dst) \
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f);
// Declare dst1..dst4 as 128-bit float registers initialized to all zeros.
#define MS_SET_ZERO128X4_F32(dst) \
MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \
MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f);
#endif // MINDSPORE_NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -37,10 +37,8 @@ int MatmulRun(const void *cdata, int task_id, float, float) {
}
MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
if (is_pack_) {
FreeResizeBufA();
FreeResizeBufB();
}
if (is_pack_ && out_need_aligned_ && oc_res_ != 0 && output_data_ != nullptr) {
free(output_data_);
output_data_ = nullptr;
@ -250,7 +248,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
float *c = output_data_ + index * params_->row_ * params_->col_;
GemmIsNotPack(a, b, c, &bias, params_->row_);
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
}
return RET_OK;
}
@ -294,7 +292,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
return RET_OK;
}
void MatmulFp32BaseCPUKernel::init_global_variable() {
int MatmulFp32BaseCPUKernel::init_global_variable() {
#ifdef ENABLE_AVX512
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col64Major : RowMajor2Row64Major;
@ -335,18 +333,27 @@ void MatmulFp32BaseCPUKernel::init_global_variable() {
// need not aligned
col_step_ = params_->col_;
#endif
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
if (params_->col_ == 1 && params_->b_const_) {
is_pack_ = false;
matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
} else {
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
}
return RET_OK;
}
int MatmulFp32BaseCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
init_global_variable();
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
MS_LOG(ERROR) << "Matrix pack size is negative "
<< "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size=" << matrix_b_pack_size_;
@ -358,6 +365,8 @@ int MatmulFp32BaseCPUKernel::Prepare() {
return ret;
}
if (params_->a_const_) {
auto a_tensor = in_tensors_[0];
CHECK_NULL_RETURN(a_tensor);
if (InitBufferA() != RET_OK) {
return RET_ERROR;
}
@ -394,10 +403,6 @@ int MatmulFp32BaseCPUKernel::ReSize() {
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
}
GetThreadCuttingPolicy();
if (params_->col_ == 1 && params_->deep_ == 1) {
is_pack_ = false;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
}
auto ret = InitTmpOutBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
@ -438,7 +443,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
}
void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
if (params_->batch >= op_parameter_->thread_num_) {
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && params_->b_const_)) {
thread_count_ = op_parameter_->thread_num_;
batch_stride_ = UP_DIV(params_->batch, thread_count_);
batch_split_ = true;
@ -453,6 +458,15 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
batch_split_ = false;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC;
}
if (params_->col_ == 1 && params_->b_const_) {
is_pack_ = false;
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch;
if (params_->deep_ == 1) {
gemmIsNotPackFun = GemmIsNotPack;
} else {
gemmIsNotPackFun = GemmIsNotPackOptimize;
}
}
}
int MatmulFp32BaseCPUKernel::Run() {
@ -517,12 +531,20 @@ int MatmulFp32BaseCPUKernel::Run() {
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
}
if (!params_->a_const_) {
if (is_pack_) {
FreeResizeBufA();
} else {
a_pack_ptr_ = nullptr;
}
}
if (!params_->b_const_) {
if (is_pack_) {
FreeResizeBufB();
} else {
b_pack_ptr_ = nullptr;
}
}
return RET_OK;
}
} // namespace mindspore::kernel
} // namespace mindspore::kernel

View File

@ -33,6 +33,8 @@ using GemmFun = void (*)(const float *a, const float *b, float *c, const float *
const int depth, const int cur_col, const int col_align, const int row);
using GemvFun = void (*)(const float *a, const float *b, float *c, const float *bias, const int act_type,
const int depth, const int cur_col, const int col_align);
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k);
class MatmulFp32BaseCPUKernel : public InnerKernel {
public:
MatmulFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@ -62,7 +64,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
void FreeBiasBuf();
int InitBiasData();
void InitParameter();
void init_global_variable();
int init_global_variable();
private:
void ResizeParameter();
@ -105,6 +107,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
GemmFun gemmCalFun = nullptr;
GemvFun gemvCalFun = nullptr;
#endif
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_