forked from mindspore-Ecosystem/mindspore
matmul-avx support split-by-row when B is a vector
This commit is contained in:
parent
2d7f4fee0a
commit
7b4ec7b351
|
@ -71,7 +71,7 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x
|
|||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/pooling_fp32.c:AvgPooling
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
|
||||
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
|
||||
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel
|
||||
|
|
|
@ -1269,52 +1269,71 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
|
|||
#endif
|
||||
#endif
|
||||
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
|
||||
#define ActCompute(bit_num, down_threshold, up_threshold) \
|
||||
if (act_type != 0) { \
|
||||
dst = MS_MAX##bit_num##_F32(dst, down_threshold); \
|
||||
if (act_type == 3) { \
|
||||
dst = MS_MIN##bit_num##_F32(dst, up_threshold); \
|
||||
} \
|
||||
}
|
||||
|
||||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 down_threshold512 = _mm512_setzero_ps();
|
||||
__m512 up_threshold512 = _mm512_set1_ps(C6NUM);
|
||||
__m512 b_data16 = _mm512_set1_ps(b[0]);
|
||||
__m512 bias_data16 = _mm512_set1_ps(bias[0]);
|
||||
#endif
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 b_data8 = _mm256_set1_ps(b[0]);
|
||||
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
|
||||
#endif
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
|
||||
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX512
|
||||
for (; index < row - C16NUM; index += C16NUM) {
|
||||
__m512 a_data = _mm512_loadu_ps(a + index);
|
||||
_mm512_storeu_ps(c + index, b_data16 * a_data + bias_data16);
|
||||
__m512 dst = b_data16 * a_data + bias_data16;
|
||||
ActCompute(512, down_threshold512, up_threshold512);
|
||||
_mm512_storeu_ps(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 down_threshold256 = _mm256_setzero_ps();
|
||||
__m256 up_threshold256 = _mm256_set1_ps(C6NUM);
|
||||
__m256 b_data8 = _mm256_set1_ps(b[0]);
|
||||
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
|
||||
for (; index < row - C8NUM; index += C8NUM) {
|
||||
__m256 a_data = _mm256_loadu_ps(a + index);
|
||||
_mm256_storeu_ps(c + index, b_data8 * a_data + bias_data8);
|
||||
__m256 dst = b_data8 * a_data + bias_data8;
|
||||
ActCompute(256, down_threshold256, up_threshold256);
|
||||
_mm256_storeu_ps(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
|
||||
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
|
||||
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
|
||||
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
|
||||
for (; index < row - C4NUM; index += C4NUM) {
|
||||
MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
|
||||
MS_STQ_F32(c + index, MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4));
|
||||
MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
|
||||
ActCompute(128, down_threshold128, up_threshold128);
|
||||
MS_STQ_F32(c + index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; index < row; ++index) {
|
||||
c[index] = a[index] * b[0] + bias[0];
|
||||
float dst = a[index] * b[0] + bias[0];
|
||||
ActCompute(32, 0, C6NUM);
|
||||
c[index] = dst;
|
||||
}
|
||||
}
|
||||
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
|
||||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
|
||||
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
|
||||
int m_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
// block 8
|
||||
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
|
||||
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
|
||||
for (; m_index <= m - C8NUM; m_index += C8NUM) {
|
||||
int k_index = 0;
|
||||
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
|
||||
|
@ -1324,31 +1343,59 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
|
|||
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
|
||||
MS_FMADD512X8_F32(src, weight, dst16_)
|
||||
}
|
||||
MS_F32X8_GETI(dst, 0) += _mm512_reduce_add_ps(dst16_1);
|
||||
MS_F32X8_GETI(dst, 1) += _mm512_reduce_add_ps(dst16_2);
|
||||
MS_F32X8_GETI(dst, 2) += _mm512_reduce_add_ps(dst16_3);
|
||||
MS_F32X8_GETI(dst, 3) += _mm512_reduce_add_ps(dst16_4);
|
||||
MS_F32X8_GETI(dst, 4) += _mm512_reduce_add_ps(dst16_5);
|
||||
MS_F32X8_GETI(dst, 5) += _mm512_reduce_add_ps(dst16_6);
|
||||
MS_F32X8_GETI(dst, 6) += _mm512_reduce_add_ps(dst16_7);
|
||||
MS_F32X8_GETI(dst, 7) += _mm512_reduce_add_ps(dst16_8);
|
||||
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
|
||||
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
|
||||
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
|
||||
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
|
||||
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
|
||||
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
|
||||
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
|
||||
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
|
||||
for (; k_index < k; k_index++) {
|
||||
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
|
||||
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
|
||||
MS_F32X8_GETI(dst, 2) += b[k_index] * a[m_index * k + k_index + 2 * k];
|
||||
MS_F32X8_GETI(dst, 3) += b[k_index] * a[m_index * k + k_index + 3 * k];
|
||||
MS_F32X8_GETI(dst, 4) += b[k_index] * a[m_index * k + k_index + 4 * k];
|
||||
MS_F32X8_GETI(dst, 5) += b[k_index] * a[m_index * k + k_index + 5 * k];
|
||||
MS_F32X8_GETI(dst, 6) += b[k_index] * a[m_index * k + k_index + 6 * k];
|
||||
MS_F32X8_GETI(dst, 7) += b[k_index] * a[m_index * k + k_index + 7 * k];
|
||||
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
|
||||
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
|
||||
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
|
||||
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
|
||||
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
|
||||
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
|
||||
}
|
||||
ActCompute(256, down_threshold256, up_threshold256);
|
||||
MS_ST256_F32(c + m_index, dst);
|
||||
}
|
||||
#endif
|
||||
#ifdef ENABLE_AVX
|
||||
// block 4
|
||||
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
|
||||
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
|
||||
for (; m_index <= m - C4NUM; m_index += C4NUM) {
|
||||
int k_index = 0;
|
||||
MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
|
||||
MS_SET_ZERO256X4_F32(dst_)
|
||||
for (; k_index <= k - C8NUM; k_index += C8NUM) {
|
||||
MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
|
||||
MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);
|
||||
MS_FMADD256X4_F32(src, weight, dst_);
|
||||
}
|
||||
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
|
||||
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
|
||||
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
|
||||
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
|
||||
for (; k_index < k; ++k_index) {
|
||||
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
|
||||
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
|
||||
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
|
||||
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
|
||||
}
|
||||
ActCompute(128, down_threshold128, up_threshold128);
|
||||
MS_ST128_F32(c + m_index, dst);
|
||||
}
|
||||
#endif
|
||||
|
||||
// block 1
|
||||
for (; m_index < m; m_index++) {
|
||||
c[m_index] = bias[0];
|
||||
float dst = bias[0];
|
||||
int k_index = 0;
|
||||
#ifdef ENABLE_AVX512
|
||||
__m512 dst1 = _mm512_setzero_ps();
|
||||
|
@ -1357,16 +1404,29 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
|
|||
__m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
|
||||
dst1 = _mm512_fmadd_ps(weight, a1, dst1);
|
||||
}
|
||||
c[m_index] += _mm512_reduce_add_ps(dst1);
|
||||
dst += _mm512_reduce_add_ps(dst1);
|
||||
#endif
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 dst2 = _mm256_setzero_ps();
|
||||
for (; k_index <= k - C8NUM; k_index += C8NUM) {
|
||||
__m256 weight = _mm256_loadu_ps(b + k_index);
|
||||
__m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
|
||||
dst2 = _mm256_fmadd_ps(weight, src, dst2);
|
||||
}
|
||||
dst += MS_REDUCE_ADD256_F32(dst2);
|
||||
#endif
|
||||
for (; k_index < k; k_index++) {
|
||||
c[m_index] += b[k_index] * a[m_index * k + k_index];
|
||||
dst += b[k_index] * a[m_index * k + k_index];
|
||||
}
|
||||
ActCompute(32, 0, C6NUM);
|
||||
c[m_index] = dst;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
|
||||
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
|
||||
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
|
||||
size_t act_type) {
|
||||
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
|
||||
// 9: WriteBack
|
||||
asm volatile(
|
||||
|
@ -1495,15 +1555,26 @@ void MatMul4x1Kernel(const float *input, const float *weight, float *output, con
|
|||
"ld1r {v1.4s}, [%[bias]]\n"
|
||||
"fadd v0.4s, v0.4s, v1.4s\n"
|
||||
"9:\n"
|
||||
"cbz %[act], 10f\n"
|
||||
"dup v1.2d, xzr\n"
|
||||
"fmax v0.4s, v0.4s, v1.4s\n"
|
||||
"cmp %[act], #3\n"
|
||||
"bne 10f\n"
|
||||
"movi v1.4s, #6\n"
|
||||
"scvtf v1.4s, v1.4s\n"
|
||||
"fmin v0.4s, v0.4s, v1.4s\n"
|
||||
"10:\n"
|
||||
"st1 {v0.4s}, [%[output]]\n"
|
||||
|
||||
:
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
|
||||
[ act ] "r"(act_type)
|
||||
: "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
|
||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
|
||||
}
|
||||
|
||||
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
|
||||
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
|
||||
size_t act_type) {
|
||||
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
|
||||
// 9: WriteBack
|
||||
asm volatile(
|
||||
|
@ -1590,15 +1661,26 @@ void MatMul2x1Kernel(const float *input, const float *weight, float *output, con
|
|||
"ld1r {v1.4s}, [%[bias]]\n"
|
||||
"fadd v0.2s, v0.2s, v1.2s\n"
|
||||
"9:\n"
|
||||
"cbz %[act], 10f\n"
|
||||
"fmov d1, xzr\n"
|
||||
"fmax v0.2s, v0.2s, v1.2s\n"
|
||||
"cmp %[act], #3\n"
|
||||
"bne 10f\n"
|
||||
"movi v1.2s, #6\n"
|
||||
"scvtf v1.2s, v1.2s\n"
|
||||
"fmin v0.2s, v0.2s, v1.2s\n"
|
||||
"10:\n"
|
||||
"st1 {v0.2s}, [%[output]]\n"
|
||||
|
||||
:
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
|
||||
[ act ] "r"(act_type)
|
||||
: "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
|
||||
"v30", "v31", "memory");
|
||||
}
|
||||
|
||||
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
|
||||
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
|
||||
size_t act_type) {
|
||||
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
|
||||
// 9: WriteBack
|
||||
asm volatile(
|
||||
|
@ -1665,30 +1747,40 @@ void MatMul1x1Kernel(const float *input, const float *weight, float *output, con
|
|||
"ld1 {v1.s}[0], [%[bias]]\n"
|
||||
"fadd s0, s0, s1\n"
|
||||
"9:\n"
|
||||
"st1 {v0.s}[0], [%[output]]\n"
|
||||
"cbz %[act], 10f\n"
|
||||
"fmov s1, wzr\n"
|
||||
"fmax s0, s0, s1\n"
|
||||
"cmp %[act], #3\n"
|
||||
"bne 10f\n"
|
||||
"mov x10, #6\n"
|
||||
"scvtf s1, x10\n"
|
||||
"fmin s0, s0, s1\n"
|
||||
"10:\n"
|
||||
"str s0, [%[output]]\n"
|
||||
|
||||
:
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
|
||||
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
|
||||
[ act ] "r"(act_type)
|
||||
: "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
|
||||
}
|
||||
|
||||
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
|
||||
int deep) {
|
||||
int deep, int act_type) {
|
||||
const float *input = a + start_row * deep;
|
||||
float *output = c + start_row;
|
||||
const int step = C4NUM * deep;
|
||||
for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
|
||||
MatMul4x1Kernel(input, b, output, bias, deep);
|
||||
MatMul4x1Kernel(input, b, output, bias, deep, act_type);
|
||||
input += step;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
|
||||
MatMul2x1Kernel(input, b, output, bias, deep);
|
||||
MatMul2x1Kernel(input, b, output, bias, deep, act_type);
|
||||
input += C2NUM * deep;
|
||||
output += C2NUM;
|
||||
}
|
||||
if (start_row == end_row - 1) {
|
||||
MatMul1x1Kernel(input, b, output, bias, deep);
|
||||
MatMul1x1Kernel(input, b, output, bias, deep, act_type);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -108,13 +108,13 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
|
|||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, int out_type);
|
||||
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
|
||||
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type);
|
||||
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
|
||||
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
|
||||
int deep);
|
||||
int deep, int act_type);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -74,6 +74,7 @@
|
|||
#define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2)
|
||||
#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2)
|
||||
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
|
||||
#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src)
|
||||
|
||||
static inline float MS_GET_MAX512_F32(__m512 src) {
|
||||
float result = MS_F32X16_GETI(src, 0);
|
||||
|
|
|
@ -247,4 +247,6 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
|||
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
|
||||
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
|
||||
|
||||
#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]);
|
||||
#endif // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -110,7 +110,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
|
|||
return RET_OK;
|
||||
}
|
||||
GemmIsNotPackByRow(matrix_a_.pack_ptr, matrix_b_.pack_ptr, output_data_, matrix_c_.pack_ptr, start_row, end_row,
|
||||
params_->deep_);
|
||||
params_->deep_, params_->act_type_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -69,8 +69,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
|
|||
}
|
||||
const float *input = matrix_a_.pack_ptr + start_row * params_->deep_;
|
||||
float *output = output_data_ + start_row * params_->col_align_;
|
||||
MatMulAvxFp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
|
||||
params_->col_align_, params_->col_align_, row_num);
|
||||
if (params_->col_ == 1) {
|
||||
float bias = 0;
|
||||
if (matrix_c_.pack_ptr != nullptr) {
|
||||
bias = matrix_c_.pack_ptr[0];
|
||||
}
|
||||
gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_);
|
||||
} else {
|
||||
MatMulAvxFp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
|
||||
params_->col_align_, params_->col_align_, row_num);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -102,12 +110,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() {
|
|||
if (b_batch_ != C1NUM) {
|
||||
return false;
|
||||
}
|
||||
if (params_->col_ == 1 && !params_->a_const_) {
|
||||
return false;
|
||||
}
|
||||
if (row_num_ < op_parameter_->thread_num_) {
|
||||
return false;
|
||||
}
|
||||
if (params_->col_ == 1) {
|
||||
row_min_unit_ = C4NUM;
|
||||
return true;
|
||||
}
|
||||
row_min_unit_ = C3NUM;
|
||||
if (col_step_ < C16NUM) {
|
||||
row_min_unit_ = C8NUM;
|
||||
|
|
|
@ -70,8 +70,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
|
|||
}
|
||||
const float *input = matrix_a_.pack_ptr + start_row * params_->deep_;
|
||||
float *output = output_data_ + start_row * params_->col_align_;
|
||||
MatMulAvx512Fp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
|
||||
params_->col_align_, params_->col_align_, row_num);
|
||||
if (params_->col_ == 1) {
|
||||
float bias = 0;
|
||||
if (matrix_c_.pack_ptr != nullptr) {
|
||||
bias = matrix_c_.pack_ptr[0];
|
||||
}
|
||||
gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_);
|
||||
} else {
|
||||
MatMulAvx512Fp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
|
||||
params_->col_align_, params_->col_align_, row_num);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -104,12 +112,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() {
|
|||
if (b_batch_ != C1NUM) {
|
||||
return false;
|
||||
}
|
||||
if (params_->batch >= op_parameter_->thread_num_ || params_->col_ == 1) {
|
||||
return false;
|
||||
}
|
||||
if (row_num_ < op_parameter_->thread_num_) {
|
||||
return false;
|
||||
}
|
||||
if (params_->col_ == 1) {
|
||||
row_min_unit_ = C8NUM;
|
||||
return true;
|
||||
}
|
||||
row_min_unit_ = C6NUM;
|
||||
if (col_step_ < C48NUM) {
|
||||
row_min_unit_ = C12NUM;
|
||||
|
|
|
@ -230,7 +230,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
|
|||
const float *a = matrix_a_.pack_ptr + a_offset_[index] * params_->row_ * params_->deep_;
|
||||
const float *b = matrix_b_.pack_ptr + b_offset_[index] * params_->deep_ * params_->col_;
|
||||
float *c = output_data_ + index * params_->row_ * params_->col_;
|
||||
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
|
||||
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_, params_->act_type_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -247,6 +247,11 @@ int MatmulFp32BaseCPUKernel::Prepare() {
|
|||
MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->data_type() == kNumberTypeFloat32, RET_ERROR,
|
||||
"matrix-c's data type is invalid.");
|
||||
}
|
||||
auto act_type = params_->act_type_;
|
||||
if (act_type != ActType_No && act_type != ActType_Relu && act_type != ActType_Relu6) {
|
||||
MS_LOG(ERROR) << "matmul don't support the act-type: " << act_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = InitParameter();
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Init parameters failed.");
|
||||
if (params_->a_const_) {
|
||||
|
|
|
@ -30,7 +30,8 @@ using mindspore::lite::RET_OK;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
using MatrixPackFun = void (*)(const float *src_ptr, float *dst_ptr, int row, int col);
|
||||
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k);
|
||||
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k,
|
||||
int act_type);
|
||||
|
||||
class MatmulFp32BaseCPUKernel : public LiteKernel {
|
||||
public:
|
||||
|
@ -79,6 +80,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
|
|||
|
||||
protected:
|
||||
MatMulParameter *params_ = nullptr;
|
||||
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
|
||||
int a_batch_ = 1;
|
||||
int b_batch_ = 1;
|
||||
std::vector<int> a_offset_;
|
||||
|
@ -95,7 +97,6 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
|
|||
float *output_data_ = nullptr;
|
||||
bool out_need_aligned_ = false;
|
||||
int col_step_ = 0;
|
||||
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
|
||||
std::vector<int> split_points_;
|
||||
MatrixInfo matrix_a_;
|
||||
MatrixInfo matrix_b_;
|
||||
|
|
Loading…
Reference in New Issue