matmul-avx support split-by-row when B is a vector

This commit is contained in:
xuanyue 2022-05-09 17:32:50 +08:00
parent 2d7f4fee0a
commit 7b4ec7b351
10 changed files with 182 additions and 63 deletions

View File

@ -71,7 +71,7 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/pooling_fp32.c:AvgPooling
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel

View File

@ -1269,52 +1269,71 @@ void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, c
#endif
#endif
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep) {
#define ActCompute(bit_num, down_threshold, up_threshold) \
if (act_type != 0) { \
dst = MS_MAX##bit_num##_F32(dst, down_threshold); \
if (act_type == 3) { \
dst = MS_MIN##bit_num##_F32(dst, up_threshold); \
} \
}
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
int index = 0;
#ifdef ENABLE_AVX512
__m512 down_threshold512 = _mm512_setzero_ps();
__m512 up_threshold512 = _mm512_set1_ps(C6NUM);
__m512 b_data16 = _mm512_set1_ps(b[0]);
__m512 bias_data16 = _mm512_set1_ps(bias[0]);
#endif
#ifdef ENABLE_AVX
__m256 b_data8 = _mm256_set1_ps(b[0]);
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
#endif
#ifdef ENABLE_AVX512
for (; index < row - C16NUM; index += C16NUM) {
__m512 a_data = _mm512_loadu_ps(a + index);
_mm512_storeu_ps(c + index, b_data16 * a_data + bias_data16);
__m512 dst = b_data16 * a_data + bias_data16;
ActCompute(512, down_threshold512, up_threshold512);
_mm512_storeu_ps(c + index, dst);
}
#endif
#ifdef ENABLE_AVX
__m256 down_threshold256 = _mm256_setzero_ps();
__m256 up_threshold256 = _mm256_set1_ps(C6NUM);
__m256 b_data8 = _mm256_set1_ps(b[0]);
__m256 bias_data8 = _mm256_set1_ps(bias[0]);
for (; index < row - C8NUM; index += C8NUM) {
__m256 a_data = _mm256_loadu_ps(a + index);
_mm256_storeu_ps(c + index, b_data8 * a_data + bias_data8);
__m256 dst = b_data8 * a_data + bias_data8;
ActCompute(256, down_threshold256, up_threshold256);
_mm256_storeu_ps(c + index, dst);
}
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
MS_FLOAT32X4 b_data4 = MS_MOVQ_F32(b[0]);
MS_FLOAT32X4 bias_data4 = MS_MOVQ_F32(bias[0]);
for (; index < row - C4NUM; index += C4NUM) {
MS_FLOAT32X4 a_data = MS_LDQ_F32(a + index);
MS_STQ_F32(c + index, MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4));
MS_FLOAT32X4 dst = MS_ADD128_F32(MS_MUL128_F32(b_data4, a_data), bias_data4);
ActCompute(128, down_threshold128, up_threshold128);
MS_STQ_F32(c + index, dst);
}
#endif
for (; index < row; ++index) {
c[index] = a[index] * b[0] + bias[0];
float dst = a[index] * b[0] + bias[0];
ActCompute(32, 0, C6NUM);
c[index] = dst;
}
}
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k) {
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
// gemm dot is [m, k] * [k, 1] ==>> [m, 1]
int m_index = 0;
#ifdef ENABLE_AVX512
// block 8
MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps();
MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM);
for (; m_index <= m - C8NUM; m_index += C8NUM) {
int k_index = 0;
MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]);
@ -1324,31 +1343,59 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
MS_LOAD512X8_F32(src, a + m_index * k + k_index, k)
MS_FMADD512X8_F32(src, weight, dst16_)
}
MS_F32X8_GETI(dst, 0) += _mm512_reduce_add_ps(dst16_1);
MS_F32X8_GETI(dst, 1) += _mm512_reduce_add_ps(dst16_2);
MS_F32X8_GETI(dst, 2) += _mm512_reduce_add_ps(dst16_3);
MS_F32X8_GETI(dst, 3) += _mm512_reduce_add_ps(dst16_4);
MS_F32X8_GETI(dst, 4) += _mm512_reduce_add_ps(dst16_5);
MS_F32X8_GETI(dst, 5) += _mm512_reduce_add_ps(dst16_6);
MS_F32X8_GETI(dst, 6) += _mm512_reduce_add_ps(dst16_7);
MS_F32X8_GETI(dst, 7) += _mm512_reduce_add_ps(dst16_8);
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1);
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2);
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3);
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4);
MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5);
MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6);
MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7);
MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8);
for (; k_index < k; k_index++) {
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
MS_F32X8_GETI(dst, 2) += b[k_index] * a[m_index * k + k_index + 2 * k];
MS_F32X8_GETI(dst, 3) += b[k_index] * a[m_index * k + k_index + 3 * k];
MS_F32X8_GETI(dst, 4) += b[k_index] * a[m_index * k + k_index + 4 * k];
MS_F32X8_GETI(dst, 5) += b[k_index] * a[m_index * k + k_index + 5 * k];
MS_F32X8_GETI(dst, 6) += b[k_index] * a[m_index * k + k_index + 6 * k];
MS_F32X8_GETI(dst, 7) += b[k_index] * a[m_index * k + k_index + 7 * k];
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k];
MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k];
MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k];
MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k];
}
ActCompute(256, down_threshold256, up_threshold256);
MS_ST256_F32(c + m_index, dst);
}
#endif
#ifdef ENABLE_AVX
// block 4
MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
for (; m_index <= m - C4NUM; m_index += C4NUM) {
int k_index = 0;
MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
MS_SET_ZERO256X4_F32(dst_)
for (; k_index <= k - C8NUM; k_index += C8NUM) {
MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);
MS_FMADD256X4_F32(src, weight, dst_);
}
MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
for (; k_index < k; ++k_index) {
MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
}
ActCompute(128, down_threshold128, up_threshold128);
MS_ST128_F32(c + m_index, dst);
}
#endif
// block 1
for (; m_index < m; m_index++) {
c[m_index] = bias[0];
float dst = bias[0];
int k_index = 0;
#ifdef ENABLE_AVX512
__m512 dst1 = _mm512_setzero_ps();
@ -1357,16 +1404,29 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
__m512 a1 = _mm512_loadu_ps(a + m_index * k + k_index);
dst1 = _mm512_fmadd_ps(weight, a1, dst1);
}
c[m_index] += _mm512_reduce_add_ps(dst1);
dst += _mm512_reduce_add_ps(dst1);
#endif
#ifdef ENABLE_AVX
__m256 dst2 = _mm256_setzero_ps();
for (; k_index <= k - C8NUM; k_index += C8NUM) {
__m256 weight = _mm256_loadu_ps(b + k_index);
__m256 src = _mm256_loadu_ps(a + m_index * k + k_index);
dst2 = _mm256_fmadd_ps(weight, src, dst2);
}
dst += MS_REDUCE_ADD256_F32(dst2);
#endif
for (; k_index < k; k_index++) {
c[m_index] += b[k_index] * a[m_index * k + k_index];
dst += b[k_index] * a[m_index * k + k_index];
}
ActCompute(32, 0, C6NUM);
c[m_index] = dst;
}
}
#ifdef ENABLE_ARM64
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6.
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
size_t act_type) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
@ -1495,15 +1555,26 @@ void MatMul4x1Kernel(const float *input, const float *weight, float *output, con
"ld1r {v1.4s}, [%[bias]]\n"
"fadd v0.4s, v0.4s, v1.4s\n"
"9:\n"
"cbz %[act], 10f\n"
"dup v1.2d, xzr\n"
"fmax v0.4s, v0.4s, v1.4s\n"
"cmp %[act], #3\n"
"bne 10f\n"
"movi v1.4s, #6\n"
"scvtf v1.4s, v1.4s\n"
"fmin v0.4s, v0.4s, v1.4s\n"
"10:\n"
"st1 {v0.4s}, [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
[ act ] "r"(act_type)
: "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
size_t act_type) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
@ -1590,15 +1661,26 @@ void MatMul2x1Kernel(const float *input, const float *weight, float *output, con
"ld1r {v1.4s}, [%[bias]]\n"
"fadd v0.2s, v0.2s, v1.2s\n"
"9:\n"
"cbz %[act], 10f\n"
"fmov d1, xzr\n"
"fmax v0.2s, v0.2s, v1.2s\n"
"cmp %[act], #3\n"
"bne 10f\n"
"movi v1.2s, #6\n"
"scvtf v1.2s, v1.2s\n"
"fmin v0.2s, v0.2s, v1.2s\n"
"10:\n"
"st1 {v0.2s}, [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
[ act ] "r"(act_type)
: "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
"v30", "v31", "memory");
}
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
size_t act_type) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
@ -1665,30 +1747,40 @@ void MatMul1x1Kernel(const float *input, const float *weight, float *output, con
"ld1 {v1.s}[0], [%[bias]]\n"
"fadd s0, s0, s1\n"
"9:\n"
"st1 {v0.s}[0], [%[output]]\n"
"cbz %[act], 10f\n"
"fmov s1, wzr\n"
"fmax s0, s0, s1\n"
"cmp %[act], #3\n"
"bne 10f\n"
"mov x10, #6\n"
"scvtf s1, x10\n"
"fmin s0, s0, s1\n"
"10:\n"
"str s0, [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
[ act ] "r"(act_type)
: "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
}
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
int deep) {
int deep, int act_type) {
const float *input = a + start_row * deep;
float *output = c + start_row;
const int step = C4NUM * deep;
for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
MatMul4x1Kernel(input, b, output, bias, deep);
MatMul4x1Kernel(input, b, output, bias, deep, act_type);
input += step;
output += C4NUM;
}
for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
MatMul2x1Kernel(input, b, output, bias, deep);
MatMul2x1Kernel(input, b, output, bias, deep, act_type);
input += C2NUM * deep;
output += C2NUM;
}
if (start_row == end_row - 1) {
MatMul1x1Kernel(input, b, output, bias, deep);
MatMul1x1Kernel(input, b, output, bias, deep, act_type);
}
}
#endif

View File

@ -108,13 +108,13 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type);
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep);
void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type);
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);
#ifdef ENABLE_ARM64
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
int deep);
int deep, int act_type);
#endif
#ifdef __cplusplus
}

View File

@ -74,6 +74,7 @@
#define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2)
#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2)
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src)
static inline float MS_GET_MAX512_F32(__m512 src) {
float result = MS_F32X16_GETI(src, 0);

View File

@ -247,4 +247,6 @@ static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \
MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]);
#endif // MINDSPORE_NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -110,7 +110,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
return RET_OK;
}
GemmIsNotPackByRow(matrix_a_.pack_ptr, matrix_b_.pack_ptr, output_data_, matrix_c_.pack_ptr, start_row, end_row,
params_->deep_);
params_->deep_, params_->act_type_);
return RET_OK;
}

View File

@ -69,8 +69,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
}
const float *input = matrix_a_.pack_ptr + start_row * params_->deep_;
float *output = output_data_ + start_row * params_->col_align_;
if (params_->col_ == 1) {
float bias = 0;
if (matrix_c_.pack_ptr != nullptr) {
bias = matrix_c_.pack_ptr[0];
}
gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_);
} else {
MatMulAvxFp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
params_->col_align_, params_->col_align_, row_num);
}
return RET_OK;
}
@ -102,12 +110,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() {
if (b_batch_ != C1NUM) {
return false;
}
if (params_->col_ == 1 && !params_->a_const_) {
return false;
}
if (row_num_ < op_parameter_->thread_num_) {
return false;
}
if (params_->col_ == 1) {
row_min_unit_ = C4NUM;
return true;
}
row_min_unit_ = C3NUM;
if (col_step_ < C16NUM) {
row_min_unit_ = C8NUM;

View File

@ -70,8 +70,16 @@ int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
}
const float *input = matrix_a_.pack_ptr + start_row * params_->deep_;
float *output = output_data_ + start_row * params_->col_align_;
if (params_->col_ == 1) {
float bias = 0;
if (matrix_c_.pack_ptr != nullptr) {
bias = matrix_c_.pack_ptr[0];
}
gemmIsNotPackFun(input, matrix_b_.pack_ptr, output, &bias, row_num, params_->deep_, params_->act_type_);
} else {
MatMulAvx512Fp32(input, matrix_b_.pack_ptr, output, matrix_c_.pack_ptr, params_->act_type_, params_->deep_,
params_->col_align_, params_->col_align_, row_num);
}
return RET_OK;
}
@ -104,12 +112,13 @@ bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() {
if (b_batch_ != C1NUM) {
return false;
}
if (params_->batch >= op_parameter_->thread_num_ || params_->col_ == 1) {
return false;
}
if (row_num_ < op_parameter_->thread_num_) {
return false;
}
if (params_->col_ == 1) {
row_min_unit_ = C8NUM;
return true;
}
row_min_unit_ = C6NUM;
if (col_step_ < C48NUM) {
row_min_unit_ = C12NUM;

View File

@ -230,7 +230,7 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
const float *a = matrix_a_.pack_ptr + a_offset_[index] * params_->row_ * params_->deep_;
const float *b = matrix_b_.pack_ptr + b_offset_[index] * params_->deep_ * params_->col_;
float *c = output_data_ + index * params_->row_ * params_->col_;
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_, params_->act_type_);
}
return RET_OK;
}
@ -247,6 +247,11 @@ int MatmulFp32BaseCPUKernel::Prepare() {
MS_CHECK_TRUE_MSG(in_tensors_[THIRD_INPUT]->data_type() == kNumberTypeFloat32, RET_ERROR,
"matrix-c's data type is invalid.");
}
auto act_type = params_->act_type_;
if (act_type != ActType_No && act_type != ActType_Relu && act_type != ActType_Relu6) {
MS_LOG(ERROR) << "matmul don't support the act-type: " << act_type;
return RET_ERROR;
}
auto ret = InitParameter();
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Init parameters failed.");
if (params_->a_const_) {

View File

@ -30,7 +30,8 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
using MatrixPackFun = void (*)(const float *src_ptr, float *dst_ptr, int row, int col);
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k);
using GemmIsNotPackFun = void (*)(const float *a, const float *b, float *c, const float *bias, int m, int k,
int act_type);
class MatmulFp32BaseCPUKernel : public LiteKernel {
public:
@ -79,6 +80,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
protected:
MatMulParameter *params_ = nullptr;
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
int a_batch_ = 1;
int b_batch_ = 1;
std::vector<int> a_offset_;
@ -95,7 +97,6 @@ class MatmulFp32BaseCPUKernel : public LiteKernel {
float *output_data_ = nullptr;
bool out_need_aligned_ = false;
int col_step_ = 0;
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
std::vector<int> split_points_;
MatrixInfo matrix_a_;
MatrixInfo matrix_b_;