From 523c3f1e8747aa9c8f93c7417c44365b8a2a4a94 Mon Sep 17 00:00:00 2001
From: ling
Date: Thu, 13 Aug 2020 19:15:22 +0800
Subject: [PATCH] fp32 matmul optimize

---
 .../kernel/arm/fp32/convolution_1x1.cc        |  36 +-
 .../runtime/kernel/arm/fp32/convolution_1x1.h |   2 -
 .../runtime/kernel/arm/fp32/deconvolution.cc  |   2 +-
 .../runtime/kernel/arm/fp32/fullconnection.cc |   2 +-
 .../src/runtime/kernel/arm/fp32/matmul.cc     |   2 +-
 .../assembly/arm64/IndirectGemmFp32_8x8.S     |   4 +-
 .../assembly/arm64/{matmul.s => matmul.S}     | 410 ++++++++++++++++--
 .../runtime/kernel/arm/nnacl/fp32/matmul.c    |  65 ++-
 .../runtime/kernel/arm/nnacl/fp32/matmul.h    |   5 +-
 .../kernel/arm/fp32/conv1x1_fp32_tests.cc     |  47 +-
 10 files changed, 460 insertions(+), 115 deletions(-)
 rename mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/{matmul.s => matmul.S} (54%)

diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
index 94820b94809..9425b264980 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
@@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
     free(pack_input_);
     pack_input_ = nullptr;
   }
-  if (pack_output_ != nullptr) {
-    free(pack_output_);
-    pack_output_ = nullptr;
-  }
   if (pre_trans_input_ && input_ptr_ != nullptr) {
     free(input_ptr_);
     input_ptr_ = nullptr;
@@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
     return RET_MEMORY_FAILED;
   }
   memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
-
-  pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
-  if (pack_output_ == nullptr) {
-    MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!";
-    return RET_MEMORY_FAILED;
-  }
-  memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
   return RET_OK;
 }
@@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() {
 }
 
 int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
@@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
   auto bias = (bias_data_ == nullptr) ?
                 nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
   MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
-         pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_,
-         matmul_param_->deep_, matmul_param_->row_8_, cur_oc);
+         output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
+         matmul_param_->row_, cur_oc, matmul_param_->col_, true);
   return RET_OK;
 }
 
-int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
-  if (cur_oc <= 0) {
-    return RET_OK;
-  }
-  float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_;
-  float *dst = output_ptr_ + task_id * thread_stride_;
-  Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_);
-  return RET_OK;
-}
-
 int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
   auto error_code = conv1x1->DoConv1x1(task_id);
@@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   return RET_OK;
 }
 
-int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
-  auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
-  conv1x1->DoConv1x1Post(task_id);
-  return RET_OK;
-}
-
 int Convolution1x1CPUKernel::Run() {
   auto prepare_ret = Prepare();
   if (prepare_ret != RET_OK) {
@@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() {
       MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]";
       return RET_ERROR;
     }
-
-    LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_);
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
index 47511fee549..f3803e1dafc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
@@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
 
  public:
   int DoConv1x1(int task_id);
-  int DoConv1x1Post(int task_id);
 
  private:
   int InitConv1x1Param();
@@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
   int thread_stride_ = 0;
   float *weight_ptr_ = nullptr;
   float *pack_input_ = nullptr;
-  float *pack_output_ = nullptr;
   float *input_ptr_ = nullptr;
   float *output_ptr_ = nullptr;
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
index 1e6fcf36631..6ab1b8190eb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
@@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
 
   MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
          tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No,
-         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_);
+         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
index 1299b3e932e..096127e184d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
+++
b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc @@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) { MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_, - cur_oc * 8); + cur_oc * 8, 0, false); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc index dc515468e05..88a0bc7f8cb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc @@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) { } auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; - MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8); + MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S index 6373bb71325..483dfac09b0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S @@ -640,7 +640,7 @@ IndirectGemmStart: add x15, x15, x7 str s30, [x15] add x0, x0, #4 - b WriteEnd + b WriteEndHalf Write2: dup s17, v16.s[1] stp s16, s17, [x15] @@ -666,7 +666,7 @@ IndirectGemmStart: dup s31, v30.s[1] stp s30, s31, [x15] add x0, x0, #8 - b WriteEnd + b WriteEndHalf Write3: add x17, x15, #8 dup s17, v16.s[1] diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S similarity index 54% rename from mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s rename to mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S index 2231e8debf9..3c5433f62a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.S @@ -27,7 +27,7 @@ // accumulators 8x8 block // /////////////////////////////////////////////////////////////////////////////// -//OptLoopMul4 RM 1x8 block +//OptLoopMul4 RM 4x8 block // /--------------------------------------------\ // |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | // |v10.s[0] ... v10.s[3] v11.s[0] ... 
v11.s[3]|
@@ -46,7 +46,8 @@
 // accumulators 8x8 block
 /////////////////////////////////////////////////////////////////////////////////
 //
-// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
+// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+//                        int row, int col, int stride, bool write_nhwc)
 // x0: a
 // x1: b
 // x2: c
@@ -55,30 +56,30 @@
 // w5: depth
 // w6: row
 // w7: col
+// w17: stride
+// w13: write_nhwc (0: keep C8-packed output)
 
 MatmulFloatNeon64:
     sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

-    mov w9, #0 // rm col offset
-    mov w10, #0 // lm row offset
-    mov w18, #32 // sizeof(float)*8
-    mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
-    mov x11, x3 // bias flag
+    mov w18, #32 // sizeof(float) * 8
+    mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
+    mov x11, x3 // bias flag
+    mov x18, #4
+    ldr x17, [sp]
+    mul x17, x17, x18
+
 L1:
-    cmp w9, w7
-    beq End1
+    mov w10, w6 // reload lhs row
+    mov x12, x0 // reload lhs ptr
+    mov x18, x2 // reload dst ptr

-    mov w10, #0 // reset lm row offset
-    mov x12, x0 // reload lm ptr
 L2:
-    cmp w10, w6
-    beq End2
-
-    mov x16, x1 // reload rm ptr
-    mov w13, w5 // reload depth
-    mov x14, x3 // reload bias ptr
+    mov x16, x1 // reload rhs ptr
+    mov w13, w5 // reload depth
+    mov x14, x3 // reload bias ptr
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
@@ -96,10 +97,10 @@ L2:
    dup v30.4s, wzr
    dup v31.4s, wzr

-OptLoopMul4:
    cmp w13, #4
    blt CommLoopMul

+OptLoopMul4:
    ld1 {v0.4s, v1.4s}, [x12], #32
    ld1 {v8.4s, v9.4s}, [x16], #32
    fmla v16.4s, v8.4s, v0.s[0]
@@ -172,13 +173,14 @@ OptLoopMul4:
    fmla v29.4s, v15.4s, v7.s[2]
    fmla v30.4s, v14.4s, v7.s[3]
    fmla v31.4s, v15.4s, v7.s[3]
-    subs w13, w13, #4
-    b OptLoopMul4
+
+    sub w13, w13, #4
+    cmp w13, #0
+    ble Bias
+    cmp w13, #4
+    bge OptLoopMul4

 CommLoopMul:
-    cmp w13, #1
-    blt Bias
-
    ld1 {v0.4s, v1.4s}, [x12], #32
    ld1 {v2.4s, v3.4s}, [x16], #32
    fmla v16.4s, v2.4s, v0.s[0]
@@ -197,8 +199,9 @@ CommLoopMul:
    fmla v29.4s, v3.4s, v1.s[2]
    fmla v30.4s, v2.4s, v1.s[3]
    fmla v31.4s, v3.4s, v1.s[3]
+
    subs w13, w13, #1
-    b CommLoopMul
+    bgt CommLoopMul

 Bias:
    cbz x11, Activation
@@ -226,7 +229,8 @@ Activation:
    beq Relu6
    cmp w4, #1
    beq Relu
-    b TransToOut
+    b Write
+
 Relu6:
    mov w8, #6
    dup v15.4s, w8
@@ -247,6 +251,7 @@ Relu6:
    fmin v29.4s, v29.4s, v15.4s
    fmin v30.4s, v30.4s, v15.4s
    fmin v31.4s, v31.4s, v15.4s
+
 Relu:
    dup v14.4s, wzr
    fmax v16.4s, v16.4s, v14.4s
@@ -266,7 +271,317 @@ Relu:
    fmax v30.4s, v30.4s, v14.4s
    fmax v31.4s, v31.4s, v14.4s

-TransToOut:
+Write:
+    ldrb w13, [sp, #8]
+    cbz w13, WriteC8
+    cmp w7, #1
+    beq Write1
+    cmp w7, #2
+    beq Write2
+    cmp w7, #3
+    beq Write3
+    cmp w7, #4
+    beq Write4
+    cmp w7, #5
+    beq Write5
+    cmp w7, #6
+    beq Write6
+    cmp w7, #7
+    beq Write7
+    b Write8
+
+Write1:
+    str s16, [x18]
+    cmp w10, #1
+    beq WriteEnd
+    add x18, x18, x17
+    str s18, [x18]
+    cmp w10, #2
+    beq WriteEnd
+    add x18, x18, x17
+    str s20, [x18]
+    cmp w10, #3
+    beq WriteEnd
+    add x18, x18, x17
+    str s22, [x18]
+    cmp w10, #4
+    beq WriteEnd
+    add x18, x18, x17
+    str s24, [x18]
+    cmp w10, #5
+    beq WriteEnd
+    add x18, x18, x17
+    str s26, [x18]
+    cmp w10, #6
+    beq WriteEnd
+    add x18, x18, x17
+    str s28, [x18]
+    cmp w10, #7
+    beq WriteEnd
+    add x18, x18, x17
+    str s30, [x18]
+    add x18, x18, x17
+    b WriteEnd
+Write2:
+    dup s17, v16.s[1]
+    stp s16, s17, [x18]
+    cmp w10, #1
+    beq WriteEnd
+    add x18, x18, x17
+    dup s19, v18.s[1]
+    stp s18,
s19, [x18] + cmp w10, #2 + beq WriteEnd + add x18, x18, x17 + dup s21, v20.s[1] + stp s20, s21, [x18] + cmp w10, #3 + beq WriteEnd + add x18, x18, x17 + dup s23, v22.s[1] + stp s22, s23, [x18] + cmp w10, #4 + beq WriteEnd + add x18, x18, x17 + dup s25, v24.s[1] + stp s24, s25, [x18] + cmp w10, #5 + beq WriteEnd + add x18, x18, x17 + dup s27, v26.s[1] + stp s26, s27, [x18] + cmp w10, #6 + beq WriteEnd + add x18, x18, x17 + dup s29, v28.s[1] + stp s28, s29, [x18] + cmp w10, #7 + beq WriteEnd + add x18, x18, x17 + dup s31, v30.s[1] + stp s30, s31, [x18] + add x18, x18, x17 + b WriteEnd +Write3: + add x13, x18, #8 + dup s17, v16.s[1] + stp s16, s17, [x18] + add x18, x18, x17 + st1 {v16.s}[2], [x13], x17 + cmp w10, #1 + beq WriteEnd + dup s19, v18.s[1] + stp s18, s19, [x18] + add x18, x18, x17 + st1 {v18.s}[2], [x13], x17 + cmp w10, #2 + beq WriteEnd + dup s21, v20.s[1] + stp s20, s21, [x18] + add x18, x18, x17 + st1 {v20.s}[2], [x13], x17 + cmp w10, #3 + beq WriteEnd + dup s23, v22.s[1] + stp s22, s23, [x18] + add x18, x18, x17 + st1 {v22.s}[2], [x13], x17 + cmp w10, #4 + beq WriteEnd + dup s25, v24.s[1] + stp s24, s25, [x18] + add x18, x18, x17 + st1 {v24.s}[2], [x13], x17 + cmp w10, #5 + beq WriteEnd + dup s27, v26.s[1] + stp s26, s27, [x18] + add x18, x18, x17 + st1 {v26.s}[2], [x13], x17 + cmp w10, #6 + beq WriteEnd + dup s29, v28.s[1] + stp s28, s29, [x18] + add x18, x18, x17 + st1 {v28.s}[2], [x13], x17 + cmp w10, #7 + beq WriteEnd + dup s31, v30.s[1] + stp s30, s31, [x18] + add x18, x18, x17 + st1 {v30.s}[2], [x13] + b WriteEnd +Write4: + st1 {v16.4s}, [x18], x17 + cmp w10, #1 + beq WriteEnd + st1 {v18.4s}, [x18], x17 + cmp w10, #2 + beq WriteEnd + st1 {v20.4s}, [x18], x17 + cmp w10, #3 + beq WriteEnd + st1 {v22.4s}, [x18], x17 + cmp w10, #4 + beq WriteEnd + st1 {v24.4s}, [x18], x17 + cmp w10, #5 + beq WriteEnd + st1 {v26.4s}, [x18], x17 + cmp w10, #6 + beq WriteEnd + st1 {v28.4s}, [x18], x17 + cmp w10, #7 + beq WriteEnd + st1 {v30.4s}, [x18], x17 + b WriteEnd +Write5: + add x13, x18, #16 + st1 {v16.4s}, [x18], x17 + str s17, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x18], x17 + str s19, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x18], x17 + str s21, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x18], x17 + str s23, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x18], x17 + str s25, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x18], x17 + str s27, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x18], x17 + str s29, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x18], x17 + str s31, [x13] + b WriteEnd +Write6: + add x13, x18, #16 + st1 {v16.4s}, [x18], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x18], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x18], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x18], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x18], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x18], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x18], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + cmp w10, #7 
+ beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x18], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + b WriteEnd +Write7: + add x13, x18, #16 + add x16, x18, #24 + st1 {v16.4s}, [x18], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + add x13, x13, x17 + st1 {v17.s}[2], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v18.4s}, [x18], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + add x13, x13, x17 + st1 {v19.s}[2], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v20.4s}, [x18], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + add x13, x13, x17 + st1 {v21.s}[2], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v22.4s}, [x18], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + add x13, x13, x17 + st1 {v23.s}[2], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v24.4s}, [x18], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + add x13, x13, x17 + st1 {v25.s}[2], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v26.4s}, [x18], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + add x13, x13, x17 + st1 {v27.s}[2], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v28.4s}, [x18], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + add x13, x13, x17 + st1 {v29.s}[2], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v30.4s}, [x18], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + add x13, x13, x17 + st1 {v31.s}[2], [x16], x17 + b WriteEnd +WriteC8: st1 {v16.4s}, [x2], #16 st1 {v17.4s}, [x2], #16 st1 {v18.4s}, [x2], #16 @@ -283,19 +598,48 @@ TransToOut: st1 {v29.4s}, [x2], #16 st1 {v30.4s}, [x2], #16 st1 {v31.4s}, [x2], #16 + b WriteEnd +Write8: + st1 {v16.4s, v17.4s}, [x18], x17 + cmp w10, #1 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x18], x17 + cmp w10, #2 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x18], x17 + cmp w10, #3 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x18], x17 + cmp w10, #4 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x18], x17 + cmp w10, #5 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x18], x17 + cmp w10, #6 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x18], x17 + cmp w10, #7 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x18], x17 - add w10, w10, #8 // lm row offset + 8 - b L2 +WriteEnd: + subs w10, w10, #8 // lhs row - 8 + bgt L2 End2: - add w9, w9, #8 // rm col offset + 8 - add x1, x1, x15 // rm ptr + stride - add x3, x3, x18 // bias ptr + stride - b L1 + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + add x3, x3, #32 // bias ptr + stride + ldrb w13, [sp, #8] + cbz w13, NoDstStep + add x2, x2, #32 // dst ptr + stride + NoDstStep: + bgt L1 End1: sub sp, sp, #128 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 ret -#endif \ No newline at end of file +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c index bae8dacdbd3..b48c18a6de1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c @@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col return; } -void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, - int col_8_) { - /* col8-major * row8-major => col8x8-major */ - for (int row = 0; row < row_8_; row++) { - for (int col = 0; col < col_8_; col++) { - int r8div = row / 8, r8mod = row % 8; - int c8div = col / 8, c8mod = col % 8; - size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod; - float value = 0; - for (int d = 0; d < deep; d++) { - size_t ai = r8div * deep * 8 + d * 8 + r8mod; - size_t bi = 
c8div * deep * 8 + d * 8 + c8mod; - value = value + a[ai] * b[bi]; +void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride, bool write_nhwc) { + if (write_nhwc) { + /* col8-major * row8-major => col-major */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r8div = r / 8, r8mod = r % 8; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; + } + } + } else { + /* col8-major * row8-major => col8x8-major */ + int col_8 = UP_ROUND(col, C8NUM); + int row_8 = UP_ROUND(row, C8NUM); + for (int r = 0; r < row_8; r++) { + for (int c = 0; c < col_8; c++) { + int r8div = r / 8, r8mod = r % 8; + int c8div = c / 8, c8mod = c % 8; + size_t ci = c8div * row_8 * 8 + r * 8 + c8mod; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; } - if (bias != NULL) value += bias[col]; - if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); - if (act_type != ActType_No) value = MSMAX(0.0f, value); - c[ci] = value; } } return; } -void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_, - int col_8_) { +void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + int stride, bool write_nhwc) { #ifdef __aarch64__ - MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_); + MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc); #else - MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_); + MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc); #endif } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h index 5d1d7f3cffe..ce50f2ff56f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h @@ -26,13 +26,14 @@ #ifdef __cplusplus extern "C" { #endif -void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col); +void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, + int stride, bool write_nhwc); void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); #ifdef __aarch64__ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col); + int col, size_t stride, bool write_nhwc); #endif #ifdef __cplusplus } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc 
b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc
index 465cc28df3c..317654aa561 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc
@@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) {
   conv1x1->Run();
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
 
-  /* running warm up */
-  for (int i = 0; i < 0; i++) {
-    conv1x1->Run();
+  auto ptr = reinterpret_cast<float *>(outputs_[0]->Data());
+  bool first = true;
+  for (int i = 0; i < total_size; i++) {
+    if (fabs(ptr[i] - correct[i]) > 0.001 && first) {
+      printf("%d %f %f\n", i, ptr[i], correct[i]);
+      first = false;
+    }
   }
 
-  /* running time cost */
-  int loop_count = 1;
-  auto time_start = mindspore::lite::GetTimeUs();
-  for (int i = 0; i < loop_count; i++) {
-    conv1x1->Run();
-  }
-  auto time_end = mindspore::lite::GetTimeUs();
-  auto cost = time_end - time_start;
-  uint64_t time_avg = cost / loop_count;
-  printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
-
-  delete conv_param;
-  delete conv1x1;
-  for (auto t : inputs_) delete t;
-  for (auto t : outputs_) delete t;
-  free(correct);
+  // /* running warm up */
+  // for (int i = 0; i < 0; i++) {
+  //   conv1x1->Run();
+  // }
+  //
+  // /* running time cost */
+  // int loop_count = 1;
+  // auto time_start = mindspore::lite::GetTimeUs();
+  // for (int i = 0; i < loop_count; i++) {
+  //   conv1x1->Run();
+  // }
+  // auto time_end = mindspore::lite::GetTimeUs();
+  // auto cost = time_end - time_start;
+  // uint64_t time_avg = cost / loop_count;
+  // printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
+  //
+  // delete conv_param;
+  // delete conv1x1;
+  // for (auto t : inputs_) delete t;
+  // for (auto t : outputs_) delete t;
+  // free(correct);
 }
 } // namespace mindspore
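
The new MatMul contract in this patch takes `a` packed col8-major, `b` packed row8-major, and, when `write_nhwc` is true, writes the product straight into a row-major (NHWC) buffer with a caller-supplied `stride`; that is what lets conv1x1 drop its pack_output_ buffer and the Row8x8Major2RowMajor post-pass. Below is a minimal standalone sketch of that path, mirroring MatMul8x8's reference indexing. The helper names (pack_a_col8, pack_b_row8, ref_matmul_nhwc) and the 2x3x2 toy sizes are illustrative only, not part of nnacl:

/* Sketch of the packing layouts and the write_nhwc output path
 * (reference C only, not the NEON kernel). */
#include <stdio.h>
#include <stdlib.h>

#define UP8(x) (((x) + 7) / 8 * 8)

/* A (M x K, row-major) -> col8-major: 8-row blocks, column-interleaved. */
static void pack_a_col8(const float *a, float *dst, int M, int K) {
  for (int r = 0; r < M; r++)
    for (int d = 0; d < K; d++)
      dst[(r / 8) * K * 8 + d * 8 + (r % 8)] = a[r * K + d];
}

/* B (K x N, row-major) -> row8-major: 8-column blocks, row-interleaved;
 * padded columns stay zero because the buffer comes from calloc. */
static void pack_b_row8(const float *b, float *dst, int K, int N) {
  for (int d = 0; d < K; d++)
    for (int c = 0; c < N; c++)
      dst[(c / 8) * K * 8 + d * 8 + (c % 8)] = b[d * N + c];
}

/* write_nhwc == true path: the result lands directly in row-major C with a
 * caller-supplied stride, so no Row8x8Major2RowMajor pass is needed after. */
static void ref_matmul_nhwc(const float *a, const float *b, float *c,
                            int K, int M, int N, int stride) {
  for (int r = 0; r < M; r++) {
    for (int n = 0; n < N; n++) {
      float v = 0;
      for (int d = 0; d < K; d++)
        v += a[(r / 8) * K * 8 + d * 8 + (r % 8)] *
             b[(n / 8) * K * 8 + d * 8 + (n % 8)];
      c[r * stride + n] = v;  /* ci = r * stride + c, as in MatMul8x8 */
    }
  }
}

int main(void) {
  enum { M = 2, K = 3, N = 2 };
  const float A[M * K] = {1, 2, 3, 4, 5, 6};  /* 2x3 */
  const float B[K * N] = {1, 0, 0, 1, 1, 1};  /* 3x2 */
  float *pa = calloc(UP8(M) * K, sizeof(float));
  float *pb = calloc(UP8(N) * K, sizeof(float));
  float C[M * N];
  pack_a_col8(A, pa, M, K);
  pack_b_row8(B, pb, K, N);
  ref_matmul_nhwc(pa, pb, C, K, M, N, N);  /* stride == N: plain row-major */
  for (int i = 0; i < M * N; i++) printf("%g ", C[i]);  /* prints: 4 5 10 11 */
  printf("\n");
  free(pa);
  free(pb);
  return 0;
}

With write_nhwc false the kernel instead keeps the C8-packed layout (ci = c8div * row_8 * 8 + r * 8 + c8mod), which is why deconvolution, fullconnection, and the standalone matmul kernel pass `false` and keep writing into their row8 temporary buffers.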