From 399276790e918f4aeea2db7605144ea665473cda Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 9 Feb 2022 14:24:17 +0800 Subject: [PATCH] move in matmul, transpose and bias_add's opt --- .jenkins/check/config/whitelizard.txt | 1 + .../kernel_compiler/cpu/nnacl/fp32/bias_add.c | 142 ++++++++ .../kernel_compiler/cpu/nnacl/fp32/bias_add.h | 34 ++ .../cpu/nnacl/fp32/matmul_fp32.c | 330 +++++++++++++++++- .../cpu/nnacl/fp32/matmul_fp32.h | 4 + .../runtime/kernel/arm/fp16/transpose_fp16.cc | 6 +- .../runtime/kernel/arm/fp16/transpose_fp16.h | 5 +- .../src/runtime/kernel/arm/fp32/bias_fp32.cc | 115 ++++-- .../src/runtime/kernel/arm/fp32/bias_fp32.h | 10 +- .../kernel/arm/fp32/matmul_fp32_base.cc | 144 ++++++-- .../kernel/arm/fp32/matmul_fp32_base.h | 12 +- .../runtime/kernel/arm/fp32/transpose_fp32.cc | 216 ++++++++---- .../runtime/kernel/arm/fp32/transpose_fp32.h | 20 +- 13 files changed, 893 insertions(+), 146 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.c create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.h diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index cef71aeb6b2..2041de14270 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -95,6 +95,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32 mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64 +mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel mindspore/mindspore/ccsrc/backend/session/gpu_session.cc:mindspore::session::gpu::GPUSession::LoadInputData mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.c new file mode 100644 index 00000000000..375609b2079 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.c @@ -0,0 +1,142 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/fp32/bias_add.h" +#include "nnacl/op_base.h" + +void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) { + int64_t index = 0; +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + for (; index <= num - C4NUM; index += C4NUM) { + MS_FLOAT32X4 input_data = MS_LDQ_F32(input + index); + MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index); + MS_STQ_F32(output + index, MS_ADD128_F32(input_data, bias_data)); + } +#endif + + for (; index < num; ++index) { + output[index] = input[index] + bias[index]; + } +} + +void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) { + float *output1 = output; + float *output2 = output + num; + float *output3 = output + num * 2; + float *output4 = output + num * 3; + int64_t index = 0; +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + for (; index <= num - C4NUM; index += C4NUM) { + MS_LOAD128X4_F32(input_data, input + index, num); + MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index); + MS_STQ_F32(output1 + index, MS_ADD128_F32(input_data1, bias_data)); + MS_STQ_F32(output2 + index, MS_ADD128_F32(input_data2, bias_data)); + MS_STQ_F32(output3 + index, MS_ADD128_F32(input_data3, bias_data)); + MS_STQ_F32(output4 + index, MS_ADD128_F32(input_data4, bias_data)); + } +#endif + const float *input_data1 = input; + const float *input_data2 = input + num; + const float *input_data3 = input + num * 2; + const float *input_data4 = input + num * 3; + for (; index < num; ++index) { + output1[index] = input_data1[index] + bias[index]; + output2[index] = input_data2[index] + bias[index]; + output3[index] = input_data3[index] + bias[index]; + output4[index] = input_data4[index] + bias[index]; + } +} + +void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start, int64_t end, + int64_t inner_num) { + if (inner_num == 0) { + return; + } + int64_t start_outer = start / inner_num; + int64_t start_inner = start % inner_num; + int64_t end_outer = end / inner_num; + int64_t end_inner = end % inner_num; + const float *cur_input = input + start; + const float *cur_bias = bias + start_inner; + float *cur_output = output + start; + if (start_outer == end_outer) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner - start_inner); + return; + } + if (start_inner != 0) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner); + start_outer += 1; + cur_input += inner_num - start_inner; + cur_bias = bias; + cur_output += inner_num - start_inner; + } + int64_t step = C4NUM * inner_num; + for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) { + BiasAddByBatchCore(cur_input, cur_bias, cur_output, inner_num); + cur_input += step; + cur_output += step; + } + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num); + cur_input += inner_num; + cur_output += inner_num; + } + BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner); +} + +void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start, int64_t end, + int64_t inner_num) { + if (inner_num == 0) { + return; + } + int64_t start_outer = start / inner_num; + int64_t start_inner = start % inner_num; + int64_t end_outer = end / inner_num; + int64_t end_inner = end % inner_num; + const float *cur_input = input + start; + const float *cur_bias = bias + start_inner; + float *cur_output = output + start; + if (start_outer == end_outer) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, 
end_inner - start_inner); + return; + } else { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner); + start_outer += 1; + cur_input += inner_num - start_inner; + cur_bias = bias; + cur_output += inner_num - start_inner; + } + if (start_outer == end_outer) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner); + return; + } else { + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num); + cur_input += inner_num; + cur_output += inner_num; + } + } + BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner); +} + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority) { + if (batch_priority) { + DoBiasAddByBatch(input, bias, output, start, end, inner_num); + } else { + DoBiasAddByInner(input, bias, output, start, end, inner_num); + } +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.h new file mode 100644 index 00000000000..210b176858d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/bias_add.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority); + +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c index 57da9bb6e14..13875921cb5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.c @@ -726,7 +726,7 @@ void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) for (int i = 0; i < all_block_num; i += cur_block) { cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 int dst_stride = cur_block * C16NUM; - int row_num = MSMIN(dst_stride, row - i * C8NUM); + int row_num = MSMIN(dst_stride, row - i * C16NUM); const float *src = src_ptr + i * C16NUM * col; float *dst = dst_ptr + i * C16NUM * col; int r = 0; @@ -2268,3 +2268,331 @@ void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float } } } + +#ifdef ENABLE_ARM64 +void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "add x6, %[input], %[deep], LSL #3\n" + "add x7, x5, %[deep], LSL #3\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "dup v2.2d, xzr\n" + "dup v3.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "fmla v2.4s, v23.4s, v31.4s\n" + "fmla v3.4s, v27.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n" + "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], 
#32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v20.4s, v21.4s}, [x6], #32\n" + "ld1 {v24.4s, v25.4s}, [x7], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v20.4s}, [x6], #16\n" + "ld1 {v24.4s}, [x7], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "dup v20.2d, xzr\n" + "dup v24.2d, xzr\n" + "dup v28.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v20.d}[0], [x6], #8\n" + "ld1 {v24.d}[0], [x7], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v20.s}[2], [x6]\n" + "ld1 {v24.s}[2], [x7]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v5.4s, v2.4s, v3.4s\n" + "faddp v0.4s, v4.4s, v5.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.4s, v0.4s, v1.4s\n" + "9:\n" + "st1 {v0.4s}, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep) + : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], 
#32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.2s, v0.2s, v1.2s\n" + "9:\n" + "st1 {v0.2s}, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep) + : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31", "memory"); +} + +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "dup v0.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[3], [x8]\n" + "ld1 {v28.s}[3], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v0.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1 {v1.s}[0], [%[bias]]\n" + "fadd s0, s0, s1\n" + "9:\n" + "st1 {v0.s}[0], [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep) + : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31"); +} + +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep) { + const float *input = a + start_row * deep; + float *output = c + 
start_row; + const int step = C4NUM * deep; + for (; start_row <= end_row - C4NUM; start_row += C4NUM) { + MatMul4x1Kernel(input, b, output, bias, deep); + input += step; + output += C4NUM; + } + for (; start_row <= end_row - C2NUM; start_row += C2NUM) { + MatMul2x1Kernel(input, b, output, bias, deep); + input += C2NUM * deep; + output += C2NUM; + } + if (start_row == end_row - 1) { + MatMul1x1Kernel(input, b, output, bias, deep); + } +} +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h index 2d7543a8e6b..e03def587a6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h @@ -129,6 +129,10 @@ void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k); +#ifdef ENABLE_ARM64 +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep); +#endif #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc index 89b3472d6c6..63db3b77c45 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,7 @@ using mindspore::lite::RET_OP_EXECUTE_FAILURE; using mindspore::schema::PrimitiveType_Transpose; namespace mindspore::kernel { -void TransposeFp16CPUKernel::GetNchwToNhwcFunc(TypeId dtype) { NHNCTransposeFunc_ = PackNCHWToNHWCFp16; } - -void TransposeFp16CPUKernel::GetNhwcToNchwFunc(TypeId dtype) { NHNCTransposeFunc_ = PackNHWCToNCHWFp16; } +void TransposeFp16CPUKernel::SetOptTransposeFunc() { optTransposeFunc_ = PackNHWCToNCHWFp16; } int TransposeFp16CPUKernel::TransposeDim2to6() { return DoTransposeFp16(static_cast(in_data_), static_cast(out_data_), out_shape_, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h index 8b751fc1d44..678fd265809 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -32,8 +32,7 @@ class TransposeFp16CPUKernel : public TransposeCPUKernel {
   ~TransposeFp16CPUKernel() = default;
 
  private:
-  void GetNchwToNhwcFunc(TypeId dtype) override;
-  void GetNhwcToNchwFunc(TypeId dtype) override;
+  void SetOptTransposeFunc() override;
   int TransposeDim2to6() override;
   int TransposeDimGreaterThan6(int task_id) override;
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.cc
index 71bb72ce0b1..16e2e1b0f67 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 
 #include "src/runtime/kernel/arm/fp32/bias_fp32.h"
 #include <vector>
+#include "nnacl/fp32/bias_add.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -27,39 +28,13 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_BiasAdd;
 
 namespace mindspore::kernel {
-int BiasCPUKernel::ReSize() {
-  auto dims = in_tensors_.at(0)->shape();
-  bias_param_->ndim_ = dims.size();
-  if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) {
-    MS_LOG(ERROR) << "input shape is invalid";
-    return RET_ERROR;
+int BiasAddRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  CHECK_NULL_RETURN(cdata);
+  auto kernel = reinterpret_cast<BiasCPUKernel *>(cdata);
+  auto ret = kernel->DoExecute(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BiasAddRun error task_id[" << task_id << "] error_code[" << ret << "]";
   }
-  for (size_t i = 0; i < bias_param_->ndim_; i++) {
-    bias_param_->in_shape0_[i] = dims[i];
-    bias_param_->in_shape1_[i] = 1;
-    bias_param_->out_shape_[i] = dims[i];
-  }
-  bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
-  return RET_OK;
-}
-
-int BiasCPUKernel::Run() {
-  auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
-  auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
-  auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
-  size_t data_size = static_cast<size_t>(in_tensors_.at(0)->ElementsNum());
-  CHECK_NULL_RETURN(ms_context_->allocator);
-  float *tile_in = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
-  float *tile_bias = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
-  if (tile_in == nullptr || tile_bias == nullptr) {
-    MS_LOG(ERROR) << "Memory allocation failed";
-    ms_context_->allocator->Free(tile_in);
-    ms_context_->allocator->Free(tile_bias);
-    return RET_ERROR;
-  }
-  auto ret = BroadcastAdd(in, bias, tile_in, tile_bias, out, static_cast<int>(data_size), bias_param_);
-  ms_context_->allocator->Free(tile_in);
-  ms_context_->allocator->Free(tile_bias);
   return ret;
 }
 
@@ -73,5 +48,79 @@ int BiasCPUKernel::Prepare() {
   return ReSize();
 }
 
+int BiasCPUKernel::ReSize() {
+  auto in_dims = in_tensors_.at(0)->shape();
+  auto bias_dims = in_tensors_.at(1)->shape();
+  if (bias_dims.empty() || in_dims.empty() || in_dims.size() < bias_dims.size()) {
+    MS_LOG(ERROR) << "inTensors' shape are invalid.";
+    return RET_ERROR;
+  }
+  size_t dim_offset = in_dims.size() - bias_dims.size();
+  inner_num_ = 1;
+  for (size_t i = 0; i < bias_dims.size(); ++i) {
+    if (in_dims[i + dim_offset] != bias_dims[i]) {
+      MS_LOG(ERROR) << "inTensors' shape cannot match.";
+      return RET_ERROR;
+    }
+    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(bias_dims[i], inner_num_), RET_ERROR, "mul overflow.");
+    inner_num_ *= bias_dims[i];
+  }
+  outer_num_ = 1;
+  for (size_t i = 0; i < dim_offset; ++i) {
+    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(in_dims[i], outer_num_), RET_ERROR, "mul overflow.");
+    outer_num_ *= in_dims[i];
+  }
+  MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(inner_num_, outer_num_), RET_ERROR, "mul overflow.");
+  total_num_ = inner_num_ * outer_num_;
+  GetThreadSegmentInfos();
+  return RET_OK;
+}
+
+void BiasCPUKernel::GetThreadSegmentInfos() {
+  split_start_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
+  split_end_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
+  int64_t step = MSMAX(total_num_ / op_parameter_->thread_num_, C128NUM);
+  int64_t remain_data = MSMAX(total_num_ - step * op_parameter_->thread_num_, 0);
+  for (int i = 0; i < op_parameter_->thread_num_; ++i) {
+    if (i == 0) {
+      split_end_points_[i] = MSMIN(step, total_num_) + (i < remain_data ? 1 : 0);
+      continue;
+    }
+    split_start_points_[i] = split_end_points_[i - 1];
+    if (split_start_points_[i] >= total_num_) {
+      split_start_points_[i] = 0;
+      break;
+    }
+    split_end_points_[i] =
+      split_start_points_[i] + MSMIN(step, total_num_ - split_start_points_[i]) + (i < remain_data ? 1 : 0);
+  }
+  MS_ASSERT(inner_num_ != 0);
+  if (inner_num_ >= C64NUM && step / inner_num_ >= C6NUM) {
+    batch_priority_ = true;
+  } else {
+    batch_priority_ = false;
+  }
+}
+
+int BiasCPUKernel::Run() {
+  auto ret = ParallelLaunch(this->ms_context_, BiasAddRun, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BiasAddRun error error_code[" << ret << "]";
+  }
+  return ret;
+}
+
+int BiasCPUKernel::DoExecute(int task_id) {
+  auto input = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
+  auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+  auto output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
+  if (split_start_points_[task_id] == split_end_points_[task_id]) {
+    return lite::RET_OK;
+  }
+  BiasAddOpt(input, bias, output, split_start_points_[task_id], split_end_points_[task_id], inner_num_,
+             batch_priority_);
+  return lite::RET_OK;
+}
+
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUKernel>)
 } // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.h
index 4e4e9aa5f84..5b08ac16d13 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_fp32.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,9 +33,17 @@ class BiasCPUKernel : public InnerKernel { int Prepare() override; int ReSize() override; int Run() override; + int DoExecute(int task_id); private: + void GetThreadSegmentInfos(); ArithmeticParameter *bias_param_; + bool batch_priority_{false}; + int64_t inner_num_{0}; + int64_t outer_num_{0}; + int64_t total_num_{0}; + std::vector split_start_points_; + std::vector split_end_points_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc index b37a693f11d..b075a01af8e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc @@ -39,7 +39,7 @@ int MatmulRun(const void *cdata, int task_id, float, float) { MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() { FreeResizeBufA(); FreeResizeBufB(); - if (is_pack_ && out_need_aligned_ && oc_res_ != 0 && output_data_ != nullptr) { + if (out_need_aligned_ && output_data_ != nullptr) { free(output_data_); output_data_ = nullptr; } @@ -287,6 +287,34 @@ int MatmulFp32BaseCPUKernel::ParallelRunByBatch(int task_id) const { return RET_OK; } +#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) || defined(ENABLE_ARM64) +int MatmulFp32BaseCPUKernel::ParallelRunByRow(int task_id) const { + int start_row = row_split_points_[task_id]; + int end_row = row_num_; + if (task_id < (thread_count_ - 1)) { + end_row = row_split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return RET_OK; + } +#if defined(ENABLE_AVX512) + const float *input = a_pack_ptr_ + start_row * params_->deep_; + float *output = output_data_ + start_row * params_->col_align_; + MatMulAvx512Fp32(input, b_pack_ptr_, output, bias_ptr_, params_->act_type_, params_->deep_, params_->col_align_, + params_->col_align_, row_num); +#elif defined(ENABLE_AVX) + const float *input = a_pack_ptr_ + start_row * params_->deep_; + float *output = output_data_ + start_row * params_->col_align_; + MatMulAvxFp32(input, b_pack_ptr_, output, bias_ptr_, params_->act_type_, params_->deep_, params_->col_align_, + params_->col_align_, row_num); +#elif defined(ENABLE_ARM64) + GemmIsNotPackByRow(a_pack_ptr_, b_pack_ptr_, output_data_, bias_ptr_, start_row, end_row, params_->deep_); +#endif + return RET_OK; +} +#endif + int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const { int start_batch = task_id * batch_stride_; int end_batch = MSMIN(params_->batch, start_batch + batch_stride_); @@ -295,8 +323,8 @@ int MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch(int task_id) const { bias = bias_ptr_[0]; } for (int index = start_batch; index < end_batch; ++index) { - const float *a = a_pack_ptr_ + index * params_->row_ * params_->deep_; - const float *b = b_pack_ptr_ + index * params_->deep_ * params_->col_; + const float *a = a_pack_ptr_ + a_offset_[index] * params_->row_ * params_->deep_; + const float *b = b_pack_ptr_ + b_offset_[index] * params_->deep_ * params_->col_; float *c = output_data_ + index * params_->row_ * params_->col_; gemmIsNotPackFun(a, b, c, &bias, params_->row_, params_->deep_); } @@ -375,28 +403,28 @@ int MatmulFp32BaseCPUKernel::init_global_variable() { row_tile_ = C12NUM; col_tile_ = C8NUM; #endif + if (params_->col_ == 1 && !params_->a_const_) { + is_pack_ = false; + out_need_aligned_ = false; + row_tile_ = 1; + col_tile_ = 1; + matrix_a_pack_fun_ = params_->a_transpose_ ? 
RowMajor2ColMajor : RowMajor2RowMajor; + matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor; + } params_->row_align_ = UP_ROUND(params_->row_, row_tile_); params_->col_align_ = UP_ROUND(params_->col_, col_tile_); + MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR); + MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR); + MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR); + MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR); + matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_; + matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_; #if defined(ENABLE_AVX) || defined(ENABLE_AVX512) col_step_ = params_->col_align_; #else // need not aligned col_step_ = params_->col_; #endif - MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->row_align_, RET_ERROR); - MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->row_align_, params_->deep_, RET_ERROR); - MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_, params_->col_align_, RET_ERROR); - MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch_ * params_->col_align_, params_->deep_, RET_ERROR); - if (params_->col_ == 1 && params_->b_const_) { - is_pack_ = false; - matrix_a_pack_size_ = a_batch_ * params_->row_ * params_->deep_; - matrix_b_pack_size_ = b_batch_ * params_->col_ * params_->deep_; - matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor; - matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2ColMajor : RowMajor2RowMajor; - } else { - matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_; - matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_; - } return RET_OK; } @@ -455,6 +483,8 @@ int MatmulFp32BaseCPUKernel::Prepare() { int MatmulFp32BaseCPUKernel::ReSize() { ResizeParameter(); + MS_CHECK_FALSE(INT_MUL_OVERFLOW(a_batch_, params_->row_), RET_ERROR); + row_num_ = a_batch_ * params_->row_; matrix_a_pack_size_ = a_batch_ * params_->row_align_ * params_->deep_; matrix_b_pack_size_ = b_batch_ * params_->col_align_ * params_->deep_; if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) { @@ -465,8 +495,12 @@ int MatmulFp32BaseCPUKernel::ReSize() { if (op_parameter_->is_train_session_) { set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * static_cast(sizeof(float))); } - GetThreadCuttingPolicy(); - auto ret = InitTmpOutBuffer(); + auto ret = GetThreadCuttingPolicy(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ThreadCuttingPolicy error!"; + return ret; + } + ret = InitTmpOutBuffer(); if (ret != RET_OK) { MS_LOG(ERROR) << "InitTmpOutBuffer error!"; return ret; @@ -483,11 +517,11 @@ void MatmulFp32BaseCPUKernel::ResizeParameter() { vec_matmul_ = false; } params_->row_align_ = UP_ROUND(params_->row_, row_tile_); - oc_res_ = params_->col_ % col_tile_; + out_need_aligned_ = (out_need_aligned_ && ((params_->col_ % col_tile_) != 0)); } int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() { - if (is_pack_ && out_need_aligned_ && oc_res_ != 0) { + if (out_need_aligned_) { if (output_data_ != nullptr) { free(output_data_); } @@ -505,12 +539,22 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() { return RET_OK; } -void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() { - if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && params_->b_const_)) { +int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() { + if (params_->batch >= op_parameter_->thread_num_ || (params_->col_ == 1 && !params_->a_const_)) { thread_count_ 
= op_parameter_->thread_num_; batch_stride_ = UP_DIV(params_->batch, thread_count_); batch_split_ = true; parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByBatch; + } else if (CheckThreadCuttingByRow()) { +#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) + is_pack_ = !params_->b_const_; + batch_split_ = true; + parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow; + GetThreadCuttingInfoByRow(); +#else + MS_LOG(ERROR) << "current branch only support avx."; + return RET_ERROR; +#endif } else { thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_)); #if defined(ENABLE_AVX) || defined(ENABLE_AVX512) // thread tile by col_tile * C4NUM @@ -521,21 +565,57 @@ void MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() { batch_split_ = false; parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByOC; } - if (params_->col_ == 1 && params_->b_const_) { + if (params_->col_ == 1 && !params_->a_const_) { is_pack_ = false; + batch_split_ = true; parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunIsNotPackByBatch; if (params_->deep_ == 1) { gemmIsNotPackFun = GemmIsNotPack; } else { gemmIsNotPackFun = GemmIsNotPackOptimize; +#ifdef ENABLE_ARM64 + if (b_batch_ == 1) { + parallel_fun_ = &MatmulFp32BaseCPUKernel::ParallelRunByRow; + GetThreadCuttingInfoByRow(); + } +#endif } } + return RET_OK; +} + +bool MatmulFp32BaseCPUKernel::CheckThreadCuttingByRow() { + if (b_batch_ != C1NUM) { + return false; + } +#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) + if (row_num_ >= op_parameter_->thread_num_) { + return true; + } +#endif + return false; +} + +void MatmulFp32BaseCPUKernel::GetThreadCuttingInfoByRow() { + int row_step = MSMAX(row_num_ / op_parameter_->thread_num_, C64NUM); + int row_remaining = MSMAX(row_num_ - row_step * op_parameter_->thread_num_, 0); + row_split_points_.resize(op_parameter_->thread_num_); + for (size_t i = 0; i < row_split_points_.size(); ++i) { + if (i == 0) { + row_split_points_[i] = 0; + continue; + } + row_split_points_[i] = + MSMIN(row_split_points_[i - 1] + row_step + (static_cast(i) < row_remaining ? 
1 : 0), row_num_); + } + int unused_thread_num = std::count(row_split_points_.begin(), row_split_points_.end(), row_num_); + thread_count_ = op_parameter_->thread_num_ - unused_thread_num; } int MatmulFp32BaseCPUKernel::Run() { auto out_data = reinterpret_cast(out_tensors_.front()->data()); CHECK_NULL_RETURN(out_data); - if (!is_pack_ || !out_need_aligned_ || oc_res_ == 0) { + if (!out_need_aligned_) { output_data_ = out_data; } if (!params_->b_const_) { @@ -557,9 +637,7 @@ int MatmulFp32BaseCPUKernel::Run() { if (!params_->a_const_) { auto a_ptr = reinterpret_cast(in_tensors_[0]->data()); CHECK_NULL_RETURN(a_ptr); - if (!is_pack_) { - a_pack_ptr_ = a_ptr; - } else { + if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) { if (InitBufferA() != RET_OK) { return RET_ERROR; } @@ -568,10 +646,12 @@ int MatmulFp32BaseCPUKernel::Run() { MS_LOG(ERROR) << "InitMatrixA failed!"; return ret; } + } else { + a_pack_ptr_ = a_ptr; } } - if (batch_split_ || !is_pack_) { + if (batch_split_) { auto ret = ParallelLaunch(this->ms_context_, MatmulRun, this, thread_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "MatmulRun failed in split by batch"; @@ -590,11 +670,13 @@ int MatmulFp32BaseCPUKernel::Run() { } } - if (oc_res_ != 0 && out_need_aligned_ && is_pack_) { + if (out_need_aligned_) { PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_); + } else { + output_data_ = nullptr; } if (!params_->a_const_) { - if (is_pack_) { + if (is_pack_ || (params_->a_transpose_ && params_->deep_ != 1)) { FreeResizeBufA(); } else { a_pack_ptr_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h index 03557032d5b..b74bae35f7a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h @@ -51,12 +51,14 @@ class MatmulFp32BaseCPUKernel : public InnerKernel { int ReSize() override; int Run() override; +#if defined(ENABLE_AVX) || defined(ENABLE_AVX512) || defined(ENABLE_ARM64) + int ParallelRunByRow(int task_id) const; +#endif int ParallelRunByOC(int task_id) const; int ParallelRunByBatch(int task_id) const; int ParallelRunIsNotPackByBatch(int task_id) const; using ParallelRun = int (MatmulFp32BaseCPUKernel::*)(int task_id) const; ParallelRun parallel_fun_ = nullptr; - bool is_pack_ = true; protected: int InitBufferA(); @@ -74,7 +76,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel { void FreeResizeBufB(); int CalBroadCastBiasDataElements(); int InitTmpOutBuffer(); - void GetThreadCuttingPolicy(); + int GetThreadCuttingPolicy(); + bool CheckThreadCuttingByRow(); + void GetThreadCuttingInfoByRow(); protected: MatMulParameter *params_ = nullptr; @@ -92,7 +96,6 @@ class MatmulFp32BaseCPUKernel : public InnerKernel { private: int col_tile_ = 0; int row_tile_ = 0; - int oc_res_ = 0; int batch_stride_ = 0; int oc_stride_ = 0; int thread_count_ = 0; @@ -107,6 +110,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel { MatrixPackFun matrix_a_pack_fun_ = nullptr; MatrixPackFun matrix_b_pack_fun_ = nullptr; bool batch_split_ = false; + bool is_pack_ = true; bool out_need_aligned_ = false; int col_step_ = 0; #if defined(ENABLE_AVX) || defined(ENABLE_AVX512) @@ -114,6 +118,8 @@ class MatmulFp32BaseCPUKernel : public InnerKernel { GemvFun gemvCalFun = nullptr; #endif GemmIsNotPackFun gemmIsNotPackFun = nullptr; + int row_num_; + std::vector row_split_points_; }; } // namespace mindspore::kernel 
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc index 12dc4959b90..31dbe79bf75 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,21 +36,31 @@ int TransposeCPUKernel::Prepare() { } int TransposeCPUKernel::ReSize() { + auto &inTensor = in_tensors_.front(); + auto in_shape = inTensor->shape(); if (in_tensors_.size() == 2) { param_->num_axes_ = in_tensors_.at(1)->ElementsNum(); } - int trans3d[3] = {0, 2, 1}; + if (in_shape.size() > MAX_TRANSPOSE_DIM_SIZE) { + MS_LOG(ERROR) << "input shape out of range."; + return RET_ERROR; + } + int transNd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; int *perm_data = nullptr; auto input_tensor = in_tensors_.at(kInputIndex); if (input_tensor->shape().size() != static_cast(param_->num_axes_)) { - if (input_tensor->shape().size() == 3 && param_->num_axes_ == 4) { - param_->num_axes_ = 3; - perm_data = trans3d; - } else { - return RET_OK; + perm_data = transNd; + if (input_tensor->shape().size() == C3NUM && param_->num_axes_ == C4NUM) { + param_->num_axes_ = C3NUM; + } + if (param_->num_axes_ == 0) { + for (int i = 0; i < static_cast(in_shape.size()); ++i) { + transNd[i] = static_cast(in_shape.size()) - 1 - i; + } + param_->num_axes_ = static_cast(in_shape.size()); } } else { - MS_ASSERT(in_tensors_.size() == 2); + MS_ASSERT(in_tensors_.size() == C2NUM); auto perm_tensor = in_tensors_.at(1); if (perm_tensor->data_type() != kNumberTypeInt32) { MS_LOG(ERROR) << "Unsupported type id: " << perm_tensor->data_type() << " of perm tensor."; @@ -59,30 +69,31 @@ int TransposeCPUKernel::ReSize() { perm_data = reinterpret_cast(perm_tensor->data()); MSLITE_CHECK_PTR(perm_data); } - if (param_->num_axes_ > MAX_TRANSPOSE_DIM_SIZE || param_->num_axes_ < 0) { - MS_LOG(ERROR) << "num_axes_ " << param_->num_axes_ << "is invalid."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(param_->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, RET_ERROR, "transpose's perm is invalid."); for (int i = 0; i < param_->num_axes_; ++i) { param_->perm_[i] = perm_data[i]; } - for (int i = 0; i < param_->num_axes_; i++) { - if (param_->perm_[i] < 0 || param_->perm_[i] >= param_->num_axes_) { - MS_LOG(ERROR) << "Check perm failed."; - return RET_ERROR; - } + if (GetOptParameters() != RET_OK) { + MS_LOG(ERROR) << "cannot compute optimizer parameters."; + return RET_ERROR; + } + DecideIfOnlyCopy(); + if (only_copy_) { + return RET_OK; + } + GetOptTransposeFunc(); + if (optTransposeFunc_ != nullptr) { + return RET_OK; } - auto &inTensor = in_tensors_.front(); auto &outTensor = out_tensors_.front(); - auto in_shape = inTensor->shape(); auto out_shape = outTensor->shape(); param_->strides_[param_->num_axes_ - 1] = 1; param_->out_strides_[param_->num_axes_ - 1] = 1; param_->data_num_ = inTensor->ElementsNum(); - MS_CHECK_LE(static_cast(param_->num_axes_), in_shape.size(), RET_ERROR); - MS_CHECK_LE(static_cast(param_->num_axes_), out_shape.size(), RET_ERROR); + MS_CHECK_TRUE_RET(static_cast(param_->num_axes_) == in_shape.size(), RET_ERROR); + MS_CHECK_TRUE_RET(static_cast(param_->num_axes_) == out_shape.size(), 
RET_ERROR);
   for (int i = param_->num_axes_ - 2; i >= 0; i--) {
     param_->strides_[i] = in_shape.at(i + 1) * param_->strides_[i + 1];
     param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
   }
@@ -102,24 +113,104 @@ int TransposeCPUKernel::ReSize() {
   return RET_OK;
 }
 
+int TransposeCPUKernel::GetOptParameters() {
+  auto in_shape = in_tensors_[0]->shape();
+  if (in_shape.size() != static_cast<size_t>(param_->num_axes_)) {
+    return RET_OK;
+  }
+  for (int i = 0; i < param_->num_axes_; i++) {
+    if (param_->perm_[i] < 0 || param_->perm_[i] >= param_->num_axes_) {
+      MS_LOG(ERROR) << "Check perm failed.";
+      return RET_ERROR;
+    }
+  }
+  std::vector<std::vector<int>> segments;
+  for (int i = 0; i < param_->num_axes_;) {
+    std::vector<int> segment{param_->perm_[i]};
+    ++i;
+    for (; i < param_->num_axes_; ++i) {
+      if (param_->perm_[i] - 1 != param_->perm_[i - 1]) {
+        break;
+      }
+      segment.push_back(param_->perm_[i]);
+    }
+    segments.push_back(segment);
+  }
+  in_shape_opt_ = std::vector<int>(segments.size(), 1);
+  perm_opt_ = std::vector<int>(segments.size(), 0);
+  for (size_t i = 0; i < segments.size(); ++i) {
+    for (size_t j = 0; j < segments.size(); ++j) {
+      perm_opt_[i] += (segments[j].front() < segments[i].front() ? 1 : 0);
+    }
+    for (auto index : segments[i]) {
+      MS_CHECK_FALSE(INT_MUL_OVERFLOW(in_shape_opt_[perm_opt_[i]], in_shape[index]), RET_ERROR);
+      in_shape_opt_[perm_opt_[i]] *= in_shape[index];
+    }
+  }
+  return RET_OK;
+}
+
+void TransposeCPUKernel::DecideIfOnlyCopy() {
+  auto in_shape = in_tensors_[0]->shape();
+  int dim = 0;
+  if (in_shape.size() != static_cast<size_t>(param_->num_axes_) || perm_opt_.size() == 1) {
+    only_copy_ = true;
+    return;
+  }
+  dim = 0;
+  std::vector<int> need_trans_dims;
+  std::for_each(perm_opt_.begin(), perm_opt_.end(), [&dim, &need_trans_dims](int val) {
+    if (val != dim) {
+      need_trans_dims.push_back(dim);
+    }
+    ++dim;
+  });
+  if (need_trans_dims.size() == C2NUM && need_trans_dims.back() - need_trans_dims.front() == C1NUM) {
+    if (in_shape_opt_[need_trans_dims.front()] == 1 || in_shape_opt_[need_trans_dims.back()] == 1) {
+      only_copy_ = true;
+      return;
+    }
+  }
+  only_copy_ = false;
+}
+
+void TransposeCPUKernel::SetOptTransposeFunc() { optTransposeFunc_ = PackNHWCToNCHWFp32; }
+
+int TransposeCPUKernel::GetOptTransposeFunc() {
+  if (in_tensors_[0]->data_type() != kNumberTypeFloat32 || perm_opt_.size() > C3NUM || perm_opt_.size() < C2NUM) {
+    optTransposeFunc_ = nullptr;
+    return RET_OK;
+  }
+  bool trans_last_two_dim{true};
+  for (size_t i = 0; i < perm_opt_.size() - C2NUM; ++i) {
+    if (perm_opt_[i] != static_cast<int>(i)) {
+      trans_last_two_dim = false;
+      break;
+    }
+  }
+  if (!trans_last_two_dim) {
+    optTransposeFunc_ = nullptr;
+    return RET_OK;
+  }
+  SetOptTransposeFunc();
+  if (perm_opt_.size() == C2NUM) {
+    nhnc_param_[FIRST_INPUT] = 1;
+    nhnc_param_[SECOND_INPUT] = in_shape_opt_.front();
+    nhnc_param_[THIRD_INPUT] = in_shape_opt_.back();
+  } else {
+    nhnc_param_[FIRST_INPUT] = in_shape_opt_.front();
+    nhnc_param_[SECOND_INPUT] = in_shape_opt_[SECOND_INPUT];
+    nhnc_param_[THIRD_INPUT] = in_shape_opt_.back();
+  }
+  return RET_OK;
+}
+
 TransposeCPUKernel::~TransposeCPUKernel() {
   if (this->out_shape_ != nullptr) {
     free(this->out_shape_);
   }
 }
 
-void TransposeCPUKernel::GetNchwToNhwcFunc(TypeId dtype) {
-  if (dtype == kNumberTypeFloat32) {
-    NHNCTransposeFunc_ = PackNCHWToNHWCFp32;
-  }
-}
-
-void TransposeCPUKernel::GetNhwcToNchwFunc(TypeId dtype) {
-  if (dtype == kNumberTypeFloat32) {
-    NHNCTransposeFunc_ = PackNHWCToNCHWFp32;
-  }
-}
-
 int TransposeCPUKernel::TransposeDim2to6() {
   return
DoTransposeFp32(static_cast(in_data_), static_cast(out_data_), out_shape_, param_); } @@ -130,34 +221,35 @@ int TransposeCPUKernel::TransposeDimGreaterThan6(int task_id) { return RET_OK; } -int TransposeCPUKernel::GetNHNCTransposeFunc(const lite::Tensor *in_tensor, const lite::Tensor *out_tensor) { - if (in_tensor->shape().size() != 4) { +int TransposeCPUKernel::CopyInputToOutput() { + auto in_tensor = in_tensors().front(); + CHECK_NULL_RETURN(in_tensor); + auto out_tensor = out_tensors().front(); + CHECK_NULL_RETURN(out_tensor); + if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() || + in_tensor->allocator() != ms_context_->allocator || op_parameter_->is_train_session_ || + ((in_tensor->IsGraphInput() || in_tensor->IsGraphOutput()) && out_tensor->IsGraphOutput())) { + CHECK_NULL_RETURN(out_tensor->data()); + CHECK_NULL_RETURN(in_tensor->data()); + MS_CHECK_FALSE(in_tensor->Size() == 0, RET_ERROR); + if (in_tensor->data() != out_tensor->data()) { + memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size()); + } return RET_OK; } - auto out_shape = out_tensor->shape(); - if (param_->perm_[FIRST_INPUT] == FIRST_INPUT && param_->perm_[SECOND_INPUT] == THIRD_INPUT && - param_->perm_[THIRD_INPUT] == FOURTH_INPUT && param_->perm_[FOURTH_INPUT] == SECOND_INPUT) { - nhnc_param_[FIRST_INPUT] = out_shape[FIRST_INPUT]; - MS_CHECK_FALSE(INT_MUL_OVERFLOW(out_shape[SECOND_INPUT], out_shape[THIRD_INPUT]), RET_ERROR); - nhnc_param_[SECOND_INPUT] = out_shape[SECOND_INPUT] * out_shape[THIRD_INPUT]; - nhnc_param_[THIRD_INPUT] = out_shape[FOURTH_INPUT]; - GetNchwToNhwcFunc(in_tensor->data_type()); - } - if (param_->perm_[FIRST_INPUT] == FIRST_INPUT && param_->perm_[SECOND_INPUT] == FOURTH_INPUT && - param_->perm_[THIRD_INPUT] == SECOND_INPUT && param_->perm_[FOURTH_INPUT] == THIRD_INPUT) { - nhnc_param_[FIRST_INPUT] = out_shape[FIRST_INPUT]; - MS_CHECK_FALSE(INT_MUL_OVERFLOW(out_shape[THIRD_INPUT], out_shape[FOURTH_INPUT]), RET_ERROR); - nhnc_param_[SECOND_INPUT] = out_shape[THIRD_INPUT] * out_shape[FOURTH_INPUT]; - nhnc_param_[THIRD_INPUT] = out_shape[SECOND_INPUT]; - GetNhwcToNchwFunc(in_tensor->data_type()); - } + + out_tensor->FreeData(); + out_tensor->ResetRefCount(); + in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count()); + out_tensor->set_data(in_tensor->data()); + out_tensor->set_own_data(in_tensor->own_data()); return RET_OK; } int TransposeCPUKernel::RunImpl(int task_id) { - if (NHNCTransposeFunc_ != nullptr) { - NHNCTransposeFunc_(in_data_, out_data_, nhnc_param_[FIRST_INPUT], nhnc_param_[SECOND_INPUT], - nhnc_param_[THIRD_INPUT], task_id, op_parameter_->thread_num_); + if (optTransposeFunc_ != nullptr) { + optTransposeFunc_(in_data_, out_data_, nhnc_param_[FIRST_INPUT], nhnc_param_[SECOND_INPUT], + nhnc_param_[THIRD_INPUT], task_id, op_parameter_->thread_num_); } else { return TransposeDimGreaterThan6(task_id); } @@ -176,6 +268,9 @@ int TransposeImpl(void *kernel, int task_id, float lhs_scale, float rhs_scale) { int TransposeCPUKernel::Run() { MS_ASSERT(in_tensors_.size() == 1 || in_tensors_.size() == 2); MS_ASSERT(out_tensors_.size() == 1); + if (only_copy_) { + return CopyInputToOutput(); + } auto &in_tensor = in_tensors_.front(); auto &out_tensor = out_tensors_.front(); if (in_tensor == nullptr || out_tensor == nullptr) { @@ -186,16 +281,7 @@ int TransposeCPUKernel::Run() { out_data_ = out_tensor->data(); CHECK_NULL_RETURN(in_data_); CHECK_NULL_RETURN(out_data_); - - if (in_tensor->shape().size() != 
static_cast<size_t>(param_->num_axes_)) {
-    memcpy(out_data_, in_data_, in_tensor->Size());
-    return RET_OK;
-  }
-  if (GetNHNCTransposeFunc(in_tensor, out_tensor) != RET_OK) {
-    MS_LOG(ERROR) << "Get NHWC tranpose func fail!";
-    return RET_ERROR;
-  }
-  if (NHNCTransposeFunc_ != nullptr) {
+  if (optTransposeFunc_ != nullptr) {
     return ParallelLaunch(this->ms_context_, TransposeImpl, this, op_parameter_->thread_num_);
   }
   if (out_tensor->shape().size() <= DIMENSION_6D) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h
index d24cf8b882a..62364b75757 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -43,18 +43,28 @@ class TransposeCPUKernel : public InnerKernel {
   int RunImpl(int task_id);
 
  protected:
-  virtual void GetNchwToNhwcFunc(TypeId dtype);
-  virtual void GetNhwcToNchwFunc(TypeId dtype);
+  virtual void SetOptTransposeFunc();
   virtual int TransposeDim2to6();
   virtual int TransposeDimGreaterThan6(int task_id);
-  int GetNHNCTransposeFunc(const lite::Tensor *in_tensor, const lite::Tensor *out_tensor);
 
+ private:
+  int GetOptParameters();
+  void DecideIfOnlyCopy();
+  int GetOptTransposeFunc();
+  int CopyInputToOutput();
+
+ protected:
   void *in_data_ = nullptr;
   void *out_data_ = nullptr;
   int *out_shape_ = nullptr;
   TransposeParameter *param_ = nullptr;
-  TransposeFunc NHNCTransposeFunc_ = nullptr;
+  TransposeFunc optTransposeFunc_ = nullptr;
+
+ private:
   int nhnc_param_[3] = {0};
+  bool only_copy_{false};
+  std::vector<int> in_shape_opt_;
+  std::vector<int> perm_opt_;
 };
 } // namespace mindspore::kernel
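
Usage note (illustrative only, not part of the patch): the sketch below drives the new BiasAddOpt entry point the way BiasCPUKernel::DoExecute() does, handing each worker one [start, end) element range over the flattened outer*inner data. The shapes, the two-way split, and batch_priority = false are invented for the example.

#include <stdint.h>
#include <stdbool.h>
#include <stdio.h>
#include "nnacl/fp32/bias_add.h"

int main(void) {
  /* Hypothetical shapes: input is [outer = 2, inner = 8], bias is [inner = 8]. */
  enum { OUTER = 2, INNER = 8 };
  float input[OUTER * INNER];
  float bias[INNER];
  float output[OUTER * INNER];
  for (int i = 0; i < OUTER * INNER; ++i) {
    input[i] = (float)i;
  }
  for (int i = 0; i < INNER; ++i) {
    bias[i] = 0.5f;
  }
  /* Two "threads", each taking one contiguous [start, end) element range,
   * mirroring the split_start_points_/split_end_points_ computed in ReSize(). */
  const int64_t total = OUTER * INNER;
  const int64_t splits[3] = {0, total / 2, total};
  for (int t = 0; t < 2; ++t) {
    /* batch_priority = false selects the DoBiasAddByInner path. */
    BiasAddOpt(input, bias, output, splits[t], splits[t + 1], INNER, false);
  }
  printf("output[0] = %.1f, output[last] = %.1f\n", output[0], output[OUTER * INNER - 1]);
  return 0;
}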