matmul opt

xuanyue 2022-01-19 17:12:53 +08:00
parent 4cc0cb5b2b
commit 27ddd48788
5 changed files with 414 additions and 24 deletions


@@ -70,6 +70,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Con
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling
+mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
mindspore/mindspore/core/ir/dtype/type.cc:mindspore::ObjectIdLabel

mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c

@@ -2268,3 +2268,331 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float
}
}
}
#ifdef ENABLE_ARM64
void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
"mov x8, %[input]\n"
"mov x9, %[weight]\n"
"mov x10, %[deep]\n"
"add x5, %[input], %[deep], LSL #2\n"
"add x6, %[input], %[deep], LSL #3\n"
"add x7, x5, %[deep], LSL #3\n"
"dup v0.2d, xzr\n"
"dup v1.2d, xzr\n"
"dup v2.2d, xzr\n"
"dup v3.2d, xzr\n"
"subs x10, x10, #16\n"
"blt 2f\n"
"1:\n" // LoopD16
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v2.4s, v20.4s, v28.4s\n"
"fmla v3.4s, v24.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"fmla v2.4s, v21.4s, v29.4s\n"
"fmla v3.4s, v25.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"fmla v1.4s, v18.4s, v30.4s\n"
"fmla v2.4s, v22.4s, v30.4s\n"
"fmla v3.4s, v26.4s, v30.4s\n"
"fmla v0.4s, v7.4s, v31.4s\n"
"fmla v1.4s, v19.4s, v31.4s\n"
"fmla v2.4s, v23.4s, v31.4s\n"
"fmla v3.4s, v27.4s, v31.4s\n"
"subs x10, x10, #16\n"
"bge 1b\n"
"2:\n" // LoopD12
"adds x10, x10, #16\n"
"cbz x10, 6f\n"
"cmp x10, #12\n"
"blt 3f\n"
"ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
"ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
"ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
"ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
"ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v2.4s, v20.4s, v28.4s\n"
"fmla v3.4s, v24.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"fmla v2.4s, v21.4s, v29.4s\n"
"fmla v3.4s, v25.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"fmla v1.4s, v18.4s, v30.4s\n"
"fmla v2.4s, v22.4s, v30.4s\n"
"fmla v3.4s, v26.4s, v30.4s\n"
"sub x10, x10, #12\n"
"b 7f\n"
"3:\n" // LoopD8
"cmp x10, #8\n"
"blt 4f\n"
"ld1 {v4.4s, v5.4s}, [x8], #32\n"
"ld1 {v16.4s, v17.4s}, [x5], #32\n"
"ld1 {v20.4s, v21.4s}, [x6], #32\n"
"ld1 {v24.4s, v25.4s}, [x7], #32\n"
"ld1 {v28.4s, v29.4s}, [x9], #32\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v2.4s, v20.4s, v28.4s\n"
"fmla v3.4s, v24.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"fmla v2.4s, v21.4s, v29.4s\n"
"fmla v3.4s, v25.4s, v29.4s\n"
"sub x10, x10, #8\n"
"b 7f\n"
"4:\n" // LoopD4
"cmp x10, #4\n"
"blt 7f\n"
"ld1 {v4.4s}, [x8], #16\n"
"ld1 {v16.4s}, [x5], #16\n"
"ld1 {v20.4s}, [x6], #16\n"
"ld1 {v24.4s}, [x7], #16\n"
"ld1 {v28.4s}, [x9], #16\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v2.4s, v20.4s, v28.4s\n"
"fmla v3.4s, v24.4s, v28.4s\n"
"sub x10, x10, #4\n"
"7:\n"
"cbz x10, 6f\n"
"dup v4.2d, xzr\n"
"dup v16.2d, xzr\n"
"dup v20.2d, xzr\n"
"dup v24.2d, xzr\n"
"dup v28.2d, xzr\n"
"subs x10, x10, #2\n"
"blt 5f\n"
"ld1 {v4.d}[0], [x8], #8\n" // LoopD2
"ld1 {v16.d}[0], [x5], #8\n"
"ld1 {v20.d}[0], [x6], #8\n"
"ld1 {v24.d}[0], [x7], #8\n"
"ld1 {v28.d}[0], [x9], #8\n"
"cbz x10, 8f\n"
"5:\n" // LoopD1
"ld1 {v4.s}[2], [x8]\n"
"ld1 {v16.s}[2], [x5]\n"
"ld1 {v20.s}[2], [x6]\n"
"ld1 {v24.s}[2], [x7]\n"
"ld1 {v28.s}[2], [x9]\n"
"8:\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v2.4s, v20.4s, v28.4s\n"
"fmla v3.4s, v24.4s, v28.4s\n"
"6:\n"
"faddp v4.4s, v0.4s, v1.4s\n"
"faddp v5.4s, v2.4s, v3.4s\n"
"faddp v0.4s, v4.4s, v5.4s\n"
"cbz %[bias], 9f\n"
"ld1r {v1.4s}, [%[bias]]\n"
"fadd v0.4s, v0.4s, v1.4s\n"
"9:\n"
"st1 {v0.4s}, [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
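For readers who don't want to trace the AArch64 assembly: MatMul4x1Kernel computes four independent row-by-vector dot products (this path only runs when the output has a single column), accumulating in v0..v3 and reducing with the faddp pairs at label 6. A minimal C sketch of the same contract — named MatMul4x1Ref here to stress that it is illustrative, not code from this commit:

#include <stddef.h>

// out[i] = dot(input + i * deep, weight) + bias[0], for i in [0, 4).
// The assembly above additionally unrolls the depth loop by 16/12/8/4/2/1
// and keeps one NEON accumulator per output row.
static void MatMul4x1Ref(const float *input, const float *weight, float *output,
                         const float *bias, size_t deep) {
  for (size_t i = 0; i < 4; ++i) {
    float acc = 0.0f;
    const float *row = input + i * deep;
    for (size_t d = 0; d < deep; ++d) {
      acc += row[d] * weight[d];
    }
    output[i] = acc + (bias != NULL ? bias[0] : 0.0f);
  }
}

MatMul2x1Kernel and MatMul1x1Kernel below are the same reduction for two rows and one row; only the number of live accumulators and the final faddp/st1 shapes change.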
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
"mov x8, %[input]\n"
"mov x9, %[weight]\n"
"mov x10, %[deep]\n"
"add x5, %[input], %[deep], LSL #2\n"
"dup v0.2d, xzr\n"
"dup v1.2d, xzr\n"
"subs x10, x10, #16\n"
"blt 2f\n"
"1:\n" // LoopD16
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"fmla v1.4s, v18.4s, v30.4s\n"
"fmla v0.4s, v7.4s, v31.4s\n"
"fmla v1.4s, v19.4s, v31.4s\n"
"subs x10, x10, #16\n"
"bge 1b\n"
"2:\n" // LoopD12
"adds x10, x10, #16\n"
"cbz x10, 6f\n"
"cmp x10, #12\n"
"blt 3f\n"
"ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
"ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
"ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"fmla v1.4s, v18.4s, v30.4s\n"
"sub x10, x10, #12\n"
"b 7f\n"
"3:\n" // LoopD8
"cmp x10, #8\n"
"blt 4f\n"
"ld1 {v4.4s, v5.4s}, [x8], #32\n"
"ld1 {v16.4s, v17.4s}, [x5], #32\n"
"ld1 {v28.4s, v29.4s}, [x9], #32\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v1.4s, v17.4s, v29.4s\n"
"sub x10, x10, #8\n"
"b 7f\n"
"4:\n" // LoopD4
"cmp x10, #4\n"
"blt 7f\n"
"ld1 {v4.4s}, [x8], #16\n"
"ld1 {v16.4s}, [x5], #16\n"
"ld1 {v28.4s}, [x9], #16\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"sub x10, x10, #4\n"
"7:\n"
"cbz x10, 6f\n"
"dup v4.2d, xzr\n"
"dup v16.2d, xzr\n"
"subs x10, x10, #2\n"
"blt 5f\n"
"ld1 {v4.d}[0], [x8], #8\n" // LoopD2
"ld1 {v16.d}[0], [x5], #8\n"
"ld1 {v28.d}[0], [x9], #8\n"
"cbz x10, 8f\n"
"5:\n" // LoopD1
"ld1 {v4.s}[2], [x8]\n"
"ld1 {v16.s}[2], [x5]\n"
"ld1 {v28.s}[2], [x9]\n"
"8:\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v1.4s, v16.4s, v28.4s\n"
"6:\n"
"faddp v4.4s, v0.4s, v1.4s\n"
"faddp v0.4s, v4.4s, v4.4s\n"
"cbz %[bias], 9f\n"
"ld1r {v1.4s}, [%[bias]]\n"
"fadd v0.2s, v0.2s, v1.2s\n"
"9:\n"
"st1 {v0.2s}, [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
"v30", "v31", "memory");
}
void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) {
// 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
// 9: WriteBack
asm volatile(
"mov x8, %[input]\n"
"mov x9, %[weight]\n"
"mov x10, %[deep]\n"
"dup v0.2d, xzr\n"
"subs x10, x10, #16\n"
"blt 2f\n"
"1:\n" // LoopD16
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"fmla v0.4s, v7.4s, v31.4s\n"
"subs x10, x10, #16\n"
"bge 1b\n"
"2:\n" // LoopD12
"adds x10, x10, #16\n"
"cbz x10, 6f\n"
"cmp x10, #12\n"
"blt 3f\n"
"ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
"ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"fmla v0.4s, v6.4s, v30.4s\n"
"sub x10, x10, #12\n"
"b 7f\n"
"3:\n" // LoopD8
"cmp x10, #8\n"
"blt 4f\n"
"ld1 {v4.4s, v5.4s}, [x8], #32\n"
"ld1 {v28.4s, v29.4s}, [x9], #32\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"fmla v0.4s, v5.4s, v29.4s\n"
"sub x10, x10, #8\n"
"b 7f\n"
"4:\n" // LoopD4
"cmp x10, #4\n"
"blt 7f\n"
"ld1 {v4.4s}, [x8], #16\n"
"ld1 {v28.4s}, [x9], #16\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"sub x10, x10, #4\n"
"7:\n"
"cbz x10, 6f\n"
"dup v4.2d, xzr\n"
"subs x10, x10, #2\n"
"blt 5f\n"
"ld1 {v4.d}[0], [x8], #8\n" // LoopD2
"ld1 {v28.d}[0], [x9], #8\n"
"cbz x10, 8f\n"
"5:\n" // LoopD1
"ld1 {v4.s}[3], [x8]\n"
"ld1 {v28.s}[3], [x9]\n"
"8:\n"
"fmla v0.4s, v4.4s, v28.4s\n"
"6:\n"
"faddp v4.4s, v0.4s, v0.4s\n"
"faddp v0.4s, v4.4s, v4.4s\n"
"cbz %[bias], 9f\n"
"ld1 {v1.s}[0], [%[bias]]\n"
"fadd s0, s0, s1\n"
"9:\n"
"st1 {v0.s}[0], [%[output]]\n"
:
: [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep)
: "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
}
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
int deep) {
const float *input = a + start_row * deep;
float *output = c + start_row;
const int step = C4NUM * deep;
for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
MatMul4x1Kernel(input, b, output, bias, deep);
input += step;
output += C4NUM;
}
for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
MatMul2x1Kernel(input, b, output, bias, deep);
input += C2NUM * deep;
output += C2NUM;
}
if (start_row == end_row - 1) {
MatMul1x1Kernel(input, b, output, bias, deep);
}
}
#endif
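GemmIsNotPackByRow tiles an arbitrary row range with the widest kernel that still fits: four rows at a time, then two, then a final single row. A worked trace with illustrative numbers (not from this commit):

// start_row = 0, end_row = 11:
//   rows 0..3 -> MatMul4x1Kernel  (first loop, iteration at start_row = 0)
//   rows 4..7 -> MatMul4x1Kernel  (first loop, iteration at start_row = 4)
//   rows 8..9 -> MatMul2x1Kernel  (second loop)
//   row  10   -> MatMul1x1Kernel  (start_row == end_row - 1)

Since col_ == 1 on this path, a single bias value is shared by every output row, which is why `bias` is forwarded unchanged to each kernel.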

mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h

@@ -129,6 +129,10 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias,
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k);
#ifdef ENABLE_ARM64
void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
int deep);
#endif
#ifdef __cplusplus
}
#endif

mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc

@@ -228,7 +228,8 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
}
void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
-  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr && is_pack_) {
+  if (!vec_matmul_ && !op_parameter_->is_train_session_ && a_pack_ptr_ != nullptr &&
+      (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1))) {
#ifdef USING_SERVING
if (a_is_packed_ == lite::MALLOC) {
#endif
@@ -293,8 +294,8 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const {
bias = bias_ptr_[0];
}
for (int index = start_batch; index < end_batch; ++index) {
-    const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_;
-    const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_;
+    const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_ * params_->deep_;
+    const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_;
float *c = output_data_ + index * params_->row_ * params_->col_;
gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_);
}
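The switch from `index` to `a_offset_[index]` / `b_offset_[index]` decouples the output batch index from the input batch index, so a broadcast operand can be stored once and shared. Assuming offset tables built elsewhere in the kernel (their construction is not part of this hunk), an illustrative case:

// batch = 4, A has four real batches, B is broadcast from one matrix:
//   a_offset_ = {0, 1, 2, 3}
//   b_offset_ = {0, 0, 0, 0}
// Every iteration then reads b = b_pack_ptr_ + 0, while a advances per batch.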
@@ -340,6 +341,21 @@ int MatmulFp32BaseCPUKernel::ParallelRunByOC(int task_id) const {
return RET_OK;
}
#ifdef ENABLE_ARM64
int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const {
int start_row = row_split_points_[task_id];
int end_row = row_num_;
if (task_id < (static_cast<int>(row_split_points_.size()) - 1)) {
end_row = row_split_points_[task_id + 1];
}
if (start_row == end_row) {
return RET_OK;
}
GemmIsNotPackByRow(a_pack_ptr_, b_pack_ptr_, output_data_, bias_ptr_, start_row, end_row, params_->deep_);
return RET_OK;
}
#endif
int MatmulFp32BaseCPUKernel::init_global_variable() {
#ifdef ENABLE_AVX512
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
@@ -372,14 +388,6 @@ int MatmulFp32BaseCPUKernel::init_global_variable() {
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
-  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
-  params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
-#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
-  col_step_ = params_->col_align_;
-#else
-  // need not aligned
-  col_step_ = params_->col_;
-#endif
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR);
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR);
@@ -387,14 +395,21 @@
MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR);
  if (params_->col_ == 1 && !params_->a_const_) {
    is_pack_ = false;
-    matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_;
-    matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_;
+    row_tile_ = 1;
+    col_tile_ = 1;
    matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
    matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor;
-  } else {
+  }
+  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
+  params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
  matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_;
  matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_;
-  }
+#if defined(ENABLE_AVX) || defined(ENABLE_AVX512)
+  col_step_ = params_->col_align_;
+#else
+  // need not aligned
+  col_step_ = params_->col_;
+#endif
return RET_OK;
}
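With row_tile_ and col_tile_ forced to 1, the shared UP_ROUND calls after the branch make the aligned sizes collapse to the exact matrix sizes on the no-pack path, so no padding is allocated. Illustrative numbers (not from this commit):

// row_ = 100, col_ = 1, deep_ = 50, a_batch_ = b_batch_ = 1:
//   row_align_ = UP_ROUND(100, 1) = 100,  col_align_ = UP_ROUND(1, 1) = 1
//   matrix_a_pack_size_ = 1 * 100 * 50   (exactly the unpacked A)
//   matrix_b_pack_size_ = 1 * 1 * 50     (exactly the unpacked B)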
@@ -463,8 +478,12 @@ int MatmulFp32BaseCPUKernel::ReSize() {
if (op_parameter_->is_train_session_) {
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast<int>(sizeof(float)));
}
-  GetThreadCuttingPolicy();
-  auto ret = InitTmpOutBuffer();
+  auto ret = GetThreadCuttingPolicy();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ThreadCuttingPolicy error!";
+    return ret;
+  }
+  ret = InitTmpOutBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
return ret;
@@ -503,7 +522,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
return RET_OK;
}
-void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
+int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && !params_->a_const_)) {
thread_count_ = op_parameter_->thread_num_;
batch_stride_ = UP_DIV(params_->batch, thread_count_);
@@ -526,10 +545,38 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() {
gemmIsNotPackFun = GemmIsNotPack;
} else {
gemmIsNotPackFun = GemmIsNotPackOptimize;
+#ifdef ENABLE_ARM64
+      auto ret = GetThreadCuttingPolicyForArm64WhenVb();
+      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "get thread policy for arm64 failed.");
+#endif
    }
  }
+  return RET_OK;
}
#ifdef ENABLE_ARM64
int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicyForArm64WhenVb() {
if (b_batch_ != 1) {
return RET_OK;
}
parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow;
MS_CHECK_FALSE(INT_MUL_OVERFLOW(a_batch_, params_->row_), RET_ERROR);
row_num_ = a_batch_ * params_->row_;
int row_step = MSMAX(row_num_ / op_parameter_->thread_num_, C64NUM);
int row_remaining = MSMAX(row_num_ - row_step * op_parameter_->thread_num_, 0);
row_split_points_.resize(op_parameter_->thread_num_);
for (size_t i = 0; i < row_split_points_.size(); ++i) {
if (i == 0) {
row_split_points_[i] = 0;
continue;
}
row_split_points_[i] =
MSMIN(row_split_points_[i - 1] + row_step + (static_cast<int>(i) < row_remaining ? 1 : 0), row_num_);
}
return RET_OK;
}
#endif
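row_split_points_ holds each thread's starting row: thread t covers [row_split_points_[t], row_split_points_[t+1]), and the last thread runs to row_num_ (see ParallelRunByRow above). Because row_step is floored at C64NUM, small workloads collapse onto fewer threads. Illustrative numbers (not from this commit):

// row_num_ = 100, thread_num_ = 4 (C64NUM = 64):
//   row_step      = MSMAX(100 / 4, 64)     = 64
//   row_remaining = MSMAX(100 - 64 * 4, 0) = 0
//   row_split_points_ = {0, 64, 100, 100}
// Thread 0 handles rows [0, 64), thread 1 handles [64, 100); threads 2 and 3
// see start_row == end_row and return RET_OK immediately.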
int MatmulFp32BaseCPUKernel::Run() {
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->data());
CHECK_NULL_RETURN(out_data);
@@ -555,9 +602,7 @@ int MatmulFp32BaseCPUKernel::Run() {
if (!params_->a_const_) {
auto a_ptr = reinterpret_cast<float *>(in_tensors_[0]->data());
CHECK_NULL_RETURN(a_ptr);
-    if (!is_pack_) {
-      a_pack_ptr_ = a_ptr;
-    } else {
+    if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
if (InitBufferA() != RET_OK) {
return RET_ERROR;
}
@@ -566,6 +611,8 @@ int MatmulFp32BaseCPUKernel::Run() {
MS_LOG(ERROR) << "InitMatrixA failed!";
return ret;
}
+    } else {
+      a_pack_ptr_ = a_ptr;
}
}
@@ -592,7 +639,7 @@ int MatmulFp32BaseCPUKernel::Run() {
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
}
if (!params_->a_const_) {
-    if (is_pack_) {
+    if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) {
FreeResizeBufA();
} else {
a_pack_ptr_ = nullptr;
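A note on the condition repeated here and in FreeResizeBufA above: on the no-pack path the row kernels read A row-major, so a transposed A still has to be repacked; deep_ == 1 is carved out because transposing a vector leaves its memory layout unchanged (illustrative reasoning, not a comment from this commit):

// a_transpose_ && deep_ == 1: A is stored as row_ contiguous floats; its
// transpose has a byte-identical layout, so a_pack_ptr_ = a_ptr stays valid.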

mindspore/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h

@@ -52,6 +52,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
int Run() override;
public:
#ifdef ENABLE_ARM64
int ParallelRunByRow(int task_id) const;
#endif
int ParallelRunByOC(int task_id) const;
int ParallelRunByBatch(int task_id) const;
int ParallelRunIsNotPackByBatch(int task_id) const;
@@ -75,7 +78,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
void FreeResizeBufB();
int CalBroadCastBiasDataElements();
int InitTmpOutBuffer();
-  void GetThreadCuttingPolicy();
+  int GetThreadCuttingPolicy();
#ifdef ENABLE_ARM64
int GetThreadCuttingPolicyForArm64WhenVb();
#endif
protected:
MatMulParameter *params_ = nullptr;
@@ -115,6 +121,10 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
GemvFun gemvCalFun = nullptr;
#endif
GemmIsNotPackFun gemmIsNotPackFun = nullptr;
#ifdef ENABLE_ARM64
int row_num_ = 0;
std::vector<int> row_split_points_;
#endif
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_