From 9cab9f7a91f40f4531e9a8a5b70ef9b8cd796b68 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Sat, 1 Aug 2020 11:14:47 +0800
Subject: [PATCH] 1. multithreading support for fc_int8_op. 2. change asm
 matmul output layout from col8x8 to row8x8

---
 .../kernel/arm/int8/fullconnection_int8.cc    |  81 +++--
 .../kernel/arm/int8/fullconnection_int8.h     |   8 +-
 .../kernel/arm/opclib/assembly/arm64/matmul.s | 306 ++++++++----------
 .../runtime/kernel/arm/opclib/fp32/matmul.cc  |   4 +
 .../runtime/kernel/arm/opclib/fp32/matmul.h   |  15 +-
 .../runtime/kernel/arm/opclib/int8/matmul.cc  |  51 ---
 .../runtime/kernel/arm/opclib/int8/matmul.h   |  16 +-
 .../kernel/arm/opclib/quantization/quantize.h |   3 +-
 .../arm/int8/fullconnection_int8_tests.cc     | 144 +++++++++
 9 files changed, 367 insertions(+), 261 deletions(-)
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc

diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
index 8f05cf2b27a..b7c23f5ece1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
@@ -17,6 +17,7 @@
 #include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
 #include "src/runtime/kernel/arm/opclib/int8/matmul.h"
 #include "src/runtime/kernel/arm/opclib/common_func.h"
+#include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"
 
 using mindspore::lite::RET_MEMORY_FAILED;
@@ -25,22 +26,42 @@ using mindspore::lite::RET_OK;
 namespace mindspore::kernel {
 int FullconnectionInt8CPUKernel::Init() {
   fc_param_->row_ = (inputs_[0]->shape())[0];
-  fc_param_->col_ = (inputs_[1]->shape())[1];
-  fc_param_->deep_ = (inputs_[1]->shape())[0];
+  fc_param_->col_ = (inputs_[1]->shape())[0];
+  fc_param_->deep_ = (inputs_[1]->shape())[1];
   fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
   fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
+  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
+  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
   a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
+  if (!a_c8_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
   memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
   b_r8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
-  memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
-  c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
-  memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
-  if (!a_c8_ptr_ || !b_r8_ptr_ || !c_r8x8_ptr_) {
+  if (!b_r8_ptr_) {
     return RET_MEMORY_FAILED;
   }
+  memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
+  auto weight_data = reinterpret_cast<int8_t *>(inputs_[1]->Data());
+  RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
+  c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
+  if (!c_r8x8_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
+  auto bias_len = fc_param_->col_8_ * sizeof(int);
+  bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
+  if (!bias_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(bias_ptr_, 0, bias_len);
+  if (inputs_.size() == 3) {
+    memcpy(bias_ptr_, inputs_[2]->Data(), bias_len);
+  }
 
   auto input_tensor = inputs_[0];
   auto params = input_tensor->GetQuantParams();
@@ -59,7 +80,8 @@ int FullconnectionInt8CPUKernel::Init() {
   quant_params_.output.scale_ = params.front().scale;
 
   double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
-  QuantizeMultiplier(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.output_shift);
+  QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
+                         &quant_params_.right_shift);
   CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
                                     quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
@@ -68,22 +90,37 @@ int FullconnectionInt8CPUKernel::Init() {
 
 int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; }
 
-int FullconnectionInt8CPUKernel::Run() {
-  auto a_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
-  auto b_ptr = reinterpret_cast<int8_t *>(inputs_.at(1)->Data());
-  auto bias_ptr = reinterpret_cast<int *>(inputs_.at(2)->Data());
-  auto output_ptr = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
+int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
   auto &p = quant_params_;
-
-  // rows*depth -> rows*depth, col_8 major
-  RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
-  // cols*depth -> cols*depth, col_8 major == depth*cols, row_8 major
-  RowMajor2Col8MajorInt8(b_ptr, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
-  MatMulInt8(a_c8_ptr_, b_r8_ptr_, c_r8x8_ptr_, fc_param_->row_8_, fc_param_->col_8_, fc_param_->deep_, p.input.zp_,
-             p.weight.zp_);
-  PostFuncInt8(c_r8x8_ptr_, bias_ptr, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->col_8_,
-               fc_param_->row_8_, p.quant_multiplier, p.output_shift, p.output.zp_, p.out_act_min, p.out_act_max);
-
+  auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
+  auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
+  MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
   return RET_OK;
 }
+
+int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto fc = reinterpret_cast<FullconnectionInt8CPUKernel *>(cdata);
+  auto ret = fc->RunImpl(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "FcInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
+    return ret;
+  }
+  return RET_OK;
+}
+
+int FullconnectionInt8CPUKernel::Run() {
+  auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
+  auto output_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
+  auto &p = quant_params_;
+  RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
+  LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
+  PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_,
+               p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
+  return RET_OK;
+}
+
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
index e953d4e6383..4bc62f1b082 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
@@ -31,20 +31,22 @@ class FullconnectionInt8CPUKernel : public
FullconnectionBaseCPUKernel { const std::vector &outputs, const Context *ctx) : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {} ~FullconnectionInt8CPUKernel() override { - free(a_c8_ptr_); - free(b_r8_ptr_); - free(c_r8x8_ptr_); + ctx_->allocator->Free(a_c8_ptr_); + ctx_->allocator->Free(b_r8_ptr_); + ctx_->allocator->Free(c_r8x8_ptr_); } int Init() override; int ReSize() override; int Run() override; + int RunImpl(int task_id); private: FcQuantArg quant_params_; int8_t *a_c8_ptr_; int8_t *b_r8_ptr_; int *c_r8x8_ptr_; + int *bias_ptr_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s index 17dddeb355d..b33c71d34ec 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s @@ -17,17 +17,17 @@ // \-----------------------------------------/ // LM 8x1 block // /---------------------\ /-----------------------------------------\ -// | v0.s[0] | |v16.s[0] ... v30.s[0]| +// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]| // | ... | | ... ... | -// | v0.s[3] | |v16.s[3] ... v30.s[3]| -// | v1.s[0] | |v17.s[0] ... v31.s[0]| +// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]| +// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]| // | ... | | ... ... | -// | v1.s[3] | |v17.s[3] ... v31.s[3]| +// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]| // \---------------------/ \-----------------------------------------/ // accumulators 8x8 block // /////////////////////////////////////////////////////////////////////////////// -//OptLoopMul4 RHS 1x8 block +//OptLoopMul4 RM 1x8 block // /--------------------------------------------\ // |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | // |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| @@ -36,12 +36,12 @@ // \--------------------------------------------/ // LM 8x4 block // /---------------------------------\ /--------------------------------------------\ -// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0] ... v30.s[0]| +// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] | // | ... ... ... ... | | ... ... | -// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v16.s[3] ... v30.s[3]| -// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v17.s[0] ... v31.s[0]| +// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] | +// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] | // | ... ... ... ... | | ... ... | -// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v17.s[3] ... 
v31.s[3]| +// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] | // \---------------------------------/ \--------------------------------------------/ // accumulators 8x8 block ///////////////////////////////////////////////////////////////////////////////// @@ -64,25 +64,22 @@ MatMulFloatNeon64: mov w7, v0.s[0] mov w8, v1.s[0] - mov w9, 0 // row counter - mov w10, 0 // col counter - mov w18, #32 - mul w15, w4, w18 // the stride of a or b - mul w16, w6, w18 // the stride of c + mov w9, 0 // rm col offset + mov w10, 0 // lm row offset + mov w18, #32 // sizeof(float)*8 + mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth L1: - cmp w9, w5 + cmp w9, w6 beq End1 - mov w10, 0 // reset col counter - mov x12, x1 // reload b ptr - mov x17, x2 // reload current c ptr + mov w10, 0 // reset lm row offset + mov x12, x0 // reload lm ptr mov x14, x3 // reload bias ptr L2: cmp w10, w6 beq End2 - mov x11, x0 // reload a ptr mov w13, w4 // reload depth dup v16.4s, wzr dup v17.4s, wzr @@ -105,142 +102,127 @@ OptLoopMul4: cmp w13, #4 blt CommLoopMul - ld1 {v0.4s}, [x11], #16 - ld1 {v8.4s}, [x12], #16 - fmla v16.4s, v0.4s, v8.s[0] - fmla v18.4s, v0.4s, v8.s[1] - ld1 {v1.4s}, [x11], #16 - fmla v20.4s, v0.4s, v8.s[2] - fmla v22.4s, v0.4s, v8.s[3] - ld1 {v9.4s}, [x12], #16 - fmla v25.4s, v1.4s, v9.s[0] - fmla v27.4s, v1.4s, v9.s[1] - fmla v29.4s, v1.4s, v9.s[2] - fmla v31.4s, v1.4s, v9.s[3] - ld1 {v2.4s}, [x11], #16 - ld1 {v3.4s}, [x11], #16 - fmla v24.4s, v0.4s, v9.s[0] - fmla v26.4s, v0.4s, v9.s[1] - fmla v28.4s, v0.4s, v9.s[2] - fmla v30.4s, v0.4s, v9.s[3] - fmla v17.4s, v1.4s, v8.s[0] - fmla v19.4s, v1.4s, v8.s[1] - fmla v21.4s, v1.4s, v8.s[2] - fmla v23.4s, v1.4s, v8.s[3] - ld1 {v10.4s}, [x12], #16 - ld1 {v11.4s}, [x12], #16 - fmla v16.4s, v2.4s, v10.s[0] - fmla v18.4s, v2.4s, v10.s[1] - fmla v20.4s, v2.4s, v10.s[2] - fmla v22.4s, v2.4s, v10.s[3] - fmla v25.4s, v3.4s, v11.s[0] - fmla v27.4s, v3.4s, v11.s[1] - fmla v29.4s, v3.4s, v11.s[2] - fmla v31.4s, v3.4s, v11.s[3] - ld1 {v4.4s}, [x11], #16 - ld1 {v5.4s}, [x11], #16 - fmla v24.4s, v2.4s, v11.s[0] - fmla v26.4s, v2.4s, v11.s[1] - fmla v28.4s, v2.4s, v11.s[2] - fmla v30.4s, v2.4s, v11.s[3] - fmla v17.4s, v3.4s, v10.s[0] - fmla v19.4s, v3.4s, v10.s[1] - fmla v21.4s, v3.4s, v10.s[2] - fmla v23.4s, v3.4s, v10.s[3] - ld1 {v12.4s}, [x12], #16 - ld1 {v13.4s}, [x12], #16 - fmla v16.4s, v4.4s, v12.s[0] - fmla v18.4s, v4.4s, v12.s[1] - fmla v20.4s, v4.4s, v12.s[2] - fmla v22.4s, v4.4s, v12.s[3] - fmla v25.4s, v5.4s, v13.s[0] - fmla v27.4s, v5.4s, v13.s[1] - fmla v29.4s, v5.4s, v13.s[2] - fmla v31.4s, v5.4s, v13.s[3] - ld1 {v6.4s}, [x11], #16 - ld1 {v7.4s}, [x11], #16 - fmla v24.4s, v4.4s, v13.s[0] - fmla v26.4s, v4.4s, v13.s[1] - fmla v28.4s, v4.4s, v13.s[2] - fmla v30.4s, v4.4s, v13.s[3] - fmla v17.4s, v5.4s, v12.s[0] - fmla v19.4s, v5.4s, v12.s[1] - fmla v21.4s, v5.4s, v12.s[2] - fmla v23.4s, v5.4s, v12.s[3] - ld1 {v14.4s}, [x12], #16 - ld1 {v15.4s}, [x12], #16 - fmla v16.4s, v6.4s, v14.s[0] - fmla v18.4s, v6.4s, v14.s[1] - fmla v20.4s, v6.4s, v14.s[2] - fmla v22.4s, v6.4s, v14.s[3] - fmla v25.4s, v7.4s, v15.s[0] - fmla v27.4s, v7.4s, v15.s[1] - fmla v29.4s, v7.4s, v15.s[2] - fmla v31.4s, v7.4s, v15.s[3] - fmla v24.4s, v6.4s, v15.s[0] - fmla v26.4s, v6.4s, v15.s[1] - fmla v28.4s, v6.4s, v15.s[2] - fmla v30.4s, v6.4s, v15.s[3] - fmla v17.4s, v7.4s, v14.s[0] - fmla v19.4s, v7.4s, v14.s[1] - fmla v21.4s, v7.4s, v14.s[2] - fmla v23.4s, v7.4s, v14.s[3] + ld1 {v0.4s, v1.4s}, [x12], #32 + ld1 {v8.4s, v9.4s}, [x1], #32 + fmla 
v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v0.s[1] + fmla v19.4s, v9.4s, v0.s[1] + fmla v20.4s, v8.4s, v0.s[2] + fmla v21.4s, v9.4s, v0.s[2] + fmla v22.4s, v8.4s, v0.s[3] + fmla v23.4s, v9.4s, v0.s[3] + ld1 {v10.4s, v11.4s}, [x1], #32 + fmla v24.4s, v8.4s, v1.s[0] + fmla v25.4s, v9.4s, v1.s[0] + fmla v26.4s, v8.4s, v1.s[1] + fmla v27.4s, v9.4s, v1.s[1] + ld1 {v2.4s, v3.4s}, [x12], #32 + fmla v28.4s, v8.4s, v1.s[2] + fmla v29.4s, v9.4s, v1.s[2] + fmla v30.4s, v8.4s, v1.s[3] + fmla v31.4s, v9.4s, v1.s[3] + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v11.4s, v2.s[0] + fmla v18.4s, v10.4s, v2.s[1] + fmla v19.4s, v11.4s, v2.s[1] + fmla v20.4s, v10.4s, v2.s[2] + fmla v21.4s, v11.4s, v2.s[2] + fmla v22.4s, v10.4s, v2.s[3] + fmla v23.4s, v11.4s, v2.s[3] + ld1 {v12.4s, v13.4s}, [x1], #32 + fmla v24.4s, v10.4s, v3.s[0] + fmla v25.4s, v11.4s, v3.s[0] + fmla v26.4s, v10.4s, v3.s[1] + fmla v27.4s, v11.4s, v3.s[1] + ld1 {v4.4s, v5.4s}, [x12], #32 + fmla v28.4s, v10.4s, v3.s[2] + fmla v29.4s, v11.4s, v3.s[2] + fmla v30.4s, v10.4s, v3.s[3] + fmla v31.4s, v11.4s, v3.s[3] + fmla v16.4s, v12.4s, v4.s[0] + fmla v17.4s, v13.4s, v4.s[0] + fmla v18.4s, v12.4s, v4.s[1] + fmla v19.4s, v13.4s, v4.s[1] + fmla v20.4s, v12.4s, v4.s[2] + fmla v21.4s, v13.4s, v4.s[2] + fmla v22.4s, v12.4s, v4.s[3] + fmla v23.4s, v13.4s, v4.s[3] + ld1 {v6.4s,v7.4s}, [x12], #32 + fmla v24.4s, v12.4s, v5.s[0] + fmla v25.4s, v13.4s, v5.s[0] + fmla v26.4s, v12.4s, v5.s[1] + fmla v27.4s, v13.4s, v5.s[1] + ld1 {v14.4s, v15.4s}, [x1], #32 + fmla v28.4s, v12.4s, v5.s[2] + fmla v29.4s, v13.4s, v5.s[2] + fmla v30.4s, v12.4s, v5.s[3] + fmla v31.4s, v13.4s, v5.s[3] + fmla v16.4s, v14.4s, v6.s[0] + fmla v17.4s, v15.4s, v6.s[0] + fmla v18.4s, v14.4s, v6.s[1] + fmla v19.4s, v15.4s, v6.s[1] + fmla v20.4s, v14.4s, v6.s[2] + fmla v21.4s, v15.4s, v6.s[2] + fmla v22.4s, v14.4s, v6.s[3] + fmla v23.4s, v15.4s, v6.s[3] + fmla v24.4s, v14.4s, v7.s[0] + fmla v25.4s, v15.4s, v7.s[0] + fmla v26.4s, v14.4s, v7.s[1] + fmla v27.4s, v15.4s, v7.s[1] + fmla v28.4s, v14.4s, v7.s[2] + fmla v29.4s, v15.4s, v7.s[2] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] subs w13, w13, #4 b OptLoopMul4 CommLoopMul: cmp w13, #1 blt Bias - ld1 {v0.4s}, [x11], #16 - ld1 {v2.4s}, [x12], #16 - fmla v16.4s, v0.4s, v2.s[0] - fmla v18.4s, v0.4s, v2.s[1] - ld1 {v1.4s}, [x11], #16 - fmla v20.4s, v0.4s, v2.s[2] - fmla v22.4s, v0.4s, v2.s[3] - ld1 {v3.4s}, [x12], #16 - fmla v25.4s, v1.4s, v3.s[0] - fmla v27.4s, v1.4s, v3.s[1] - fmla v29.4s, v1.4s, v3.s[2] - fmla v31.4s, v1.4s, v3.s[3] - fmla v24.4s, v0.4s, v3.s[0] - fmla v26.4s, v0.4s, v3.s[1] - fmla v28.4s, v0.4s, v3.s[2] - fmla v30.4s, v0.4s, v3.s[3] - fmla v17.4s, v1.4s, v2.s[0] - fmla v19.4s, v1.4s, v2.s[1] - fmla v21.4s, v1.4s, v2.s[2] - fmla v23.4s, v1.4s, v2.s[3] + + ld1 {v0.4s, v1.4s}, [x12], #32 + ld1 {v2.4s, v3.4s}, [x1], #32 + fmla v16.4s, v2.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[0] + fmla v18.4s, v2.4s, v0.s[1] + fmla v19.4s, v3.4s, v0.s[1] + fmla v20.4s, v2.4s, v0.s[2] + fmla v21.4s, v3.4s, v0.s[2] + fmla v22.4s, v2.4s, v0.s[3] + fmla v23.4s, v3.4s, v0.s[3] + fmla v24.4s, v2.4s, v1.s[0] + fmla v25.4s, v3.4s, v1.s[0] + fmla v26.4s, v2.4s, v1.s[1] + fmla v27.4s, v3.4s, v1.s[1] + fmla v28.4s, v2.4s, v1.s[2] + fmla v29.4s, v3.4s, v1.s[2] + fmla v30.4s, v2.4s, v1.s[3] + fmla v31.4s, v3.4s, v1.s[3] subs w13, w13, #1 b CommLoopMul Bias: + cmp x3, #0 + beq Relu ld1 {v0.4s}, [x14], #16 ld1 {v1.4s}, [x14], #16 - dup v2.4s, v0.s[0] - fadd v16.4s, v16.4s, v2.4s - fadd v17.4s, v17.4s, v2.4s - dup 
v3.4s, v0.s[1] - fadd v18.4s, v18.4s, v3.4s - fadd v19.4s, v19.4s, v3.4s - dup v4.4s, v0.s[2] - fadd v20.4s, v20.4s, v4.4s - fadd v21.4s, v21.4s, v4.4s - dup v5.4s, v0.s[3] - fadd v22.4s, v22.4s, v5.4s - fadd v23.4s, v23.4s, v5.4s - dup v2.4s, v1.s[0] - fadd v24.4s, v24.4s, v2.4s - fadd v25.4s, v25.4s, v2.4s - dup v3.4s, v1.s[1] - fadd v26.4s, v26.4s, v3.4s - fadd v27.4s, v27.4s, v3.4s - dup v4.4s, v1.s[2] - fadd v28.4s, v28.4s, v4.4s - fadd v29.4s, v29.4s, v4.4s - dup v5.4s, v1.s[3] - fadd v30.4s, v30.4s, v5.4s - fadd v31.4s, v31.4s, v5.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s Relu: dup v15.4s, w7 @@ -281,30 +263,28 @@ Relu: fmin v31.4s, v31.4s, v15.4s TransToOut: - st1 {v16.4s}, [x17], #16 - st1 {v17.4s}, [x17], #16 - st1 {v18.4s}, [x17], #16 - st1 {v19.4s}, [x17], #16 - st1 {v20.4s}, [x17], #16 - st1 {v21.4s}, [x17], #16 - st1 {v22.4s}, [x17], #16 - st1 {v23.4s}, [x17], #16 - st1 {v24.4s}, [x17], #16 - st1 {v25.4s}, [x17], #16 - st1 {v26.4s}, [x17], #16 - st1 {v27.4s}, [x17], #16 - st1 {v28.4s}, [x17], #16 - st1 {v29.4s}, [x17], #16 - st1 {v30.4s}, [x17], #16 - st1 {v31.4s}, [x17], #16 + st1 {v16.4s}, [x2], #16 + st1 {v17.4s}, [x2], #16 + st1 {v18.4s}, [x2], #16 + st1 {v19.4s}, [x2], #16 + st1 {v20.4s}, [x2], #16 + st1 {v21.4s}, [x2], #16 + st1 {v22.4s}, [x2], #16 + st1 {v23.4s}, [x2], #16 + st1 {v24.4s}, [x2], #16 + st1 {v25.4s}, [x2], #16 + st1 {v26.4s}, [x2], #16 + st1 {v27.4s}, [x2], #16 + st1 {v28.4s}, [x2], #16 + st1 {v29.4s}, [x2], #16 + st1 {v30.4s}, [x2], #16 + st1 {v31.4s}, [x2], #16 - add w10, w10, #8 // col+=8 + add w10, w10, #8 // lhs row offset + 8 b L2 End2: - add x0, x0, x15 // stride a ptr - add x2, x2, x16 // stride c ptr - add w9, w9, #8 // row+=8 + add w9, w9, #8 // rhs col offset + 8 b L1 End1: diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc index 5e2ebe7bc2f..87d8843a566 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc @@ -74,5 +74,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_, int col_8_) { +#ifdef __aarch64__ + MatMulFloatNeon64(a, b, c, bias, maxf, minf, deep, row_8_, col_8_); +#else MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_); +#endif } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h index 97c6db417ee..fb6f213db49 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h @@ -21,19 +21,22 @@ #include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/matmul.h" -#ifdef __cplusplus -extern "C" { -#endif - void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col); void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); void 
RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col); void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col); - +void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, + int row_8_, int col_8_); +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __aarch64__ +void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, + int row, int col); +#endif #ifdef __cplusplus } #endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc index aa5de959de0..0517f8b5fd0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc @@ -48,54 +48,3 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co } return; } - -// todo: need to delete, replace by above functions. z00445833 -void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { - int col8 = UP_ROUND(col, 8); - for (int r = 0; r < row; r++) { - int rd8 = r / 8; - int rm8 = r % 8; - for (int c = 0; c < col; c++) { - dst_ptr[r * col + c] = src_ptr[rd8 * col8 * 8 + c * 8 + rm8]; - } - } -} - -void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data, - int depth, FcQuantArg *params) { - int lhs_offset = params->input.zp_; - int rhs_offset = params->weight.zp_; - int output_offset = params->output.zp_; - int output_multiplier = params->quant_multiplier; - int output_shift = params->output_shift; - - for (int row = 0; row < 8; ++row) { - for (int col = 0; col < 8; ++col) { - int c_index = col * 8 + row; - int acc = 0; - for (int d = 0; d < depth; ++d) { - int a_index = d * 8 + row; - int b_index = d * 8 + col; - acc += (lhs_data[a_index] - lhs_offset) * (rhs_data[b_index] - rhs_offset); - } - acc += bias_data[col]; - acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift, output_shift) + output_offset; - acc = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, acc)); - output_data[c_index] = (int8_t)acc; - } - } -} - -void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data, - int row_8, int col_8, int depth, FcQuantArg *params) { - for (int r = 0; r < row_8; r += 8) { - int8_t *output = output_data + r * col_8; - const int8_t *input = input_data + r * depth; - for (int c = 0; c < col_8; c += 8) { - const int8_t *bias = bias_data + c; - const int8_t *weights = weights_data + c * depth; - int8_t *dst = output + c * 8; - Gemm8x8Int8(input, weights, bias, dst, depth, params); - } - } -} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h index d0c85ba5a19..6fc2166461b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h @@ -20,23 +20,9 @@ #include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/matmul.h" -#ifdef __cplusplus -extern "C" { -#endif - void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, const int32_t a_zp, const int32_t b_zp); void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); -void GemmRowCol8x8Major2RowMajorInt8(int8_t 
*src_ptr, int8_t *dst_ptr, int row, int col); -void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data, - int depth, FcQuantArg *params); -void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data, - int row_8, int col_8, int depth, FcQuantArg *params); - -#ifdef __cplusplus -} -#endif - -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_ +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_INT8_MATMUL_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h index dfe0b8ab45f..79de3dc5d0b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h @@ -54,7 +54,8 @@ struct FcQuantArg { QuantArg output; int32_t out_act_min; int32_t out_act_max; - int32_t output_shift; + int32_t left_shift; + int32_t right_shift; int32_t quant_multiplier; }; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc new file mode 100644 index 00000000000..fff45f30cb2 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include
+#include
+#include "utils/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h"
+#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h"
+#include "mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h"
+#include "mindspore/lite/src/kernel_registry.h"
+#include "mindspore/lite/src/lite_kernel.h"
+
+namespace mindspore {
+using lite::tensor::Tensor;
+
+class TestFcInt8 : public mindspore::Common {
+ public:
+  TestFcInt8() {}
+};
+
+void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
+  for (int i = 0; i < length; ++i) {
+    int8_t q = static_cast<int8_t>(std::max<float>(
+      std::numeric_limits<int8_t>::min(),
+      std::min<float>(std::numeric_limits<int8_t>::max(), std::round(zero_point + (input_data[i] / scale)))));
+    output_data[i] = q;
+  }
+}
+
+void Dequantize(int8_t *input_data, int length, float scale, int zero_point, float *output_data) {
+  for (int i = 0; i < length; ++i) {
+    output_data[i] = scale * (input_data[i] - zero_point);
+  }
+}
+
+int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
+                   MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
+  float input_max = 20;
+  float input_min = -20;
+  float weight_max = 1;
+  float weight_min = -1;
+  float output_max = 20;
+  float output_min = -20;
+
+  double input_scale =
+    (input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
+  int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
+  double weight_scale =
+    (weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
+  int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
+  double output_scale =
+    (output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
+  int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
+  *scale = output_scale;
+  *zeropoint = output_zp;
+
+  Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  in_t->MallocData();
+  float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
+                17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
+  Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
+  auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
+  in_quant_arg->zeroPoint = input_zp;
+  in_quant_arg->scale = input_scale;
+  in_t->AddQuantParam(*in_quant_arg);
+  inputs_->push_back(in_t);
+
+  Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  weight_t->MallocData();
+  float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435,
+                    0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552,
+                    -0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459,
+                    0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343};
+  Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
+  auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
+  weight_quant_arg->zeroPoint = weight_zp;
+  weight_quant_arg->scale = weight_scale;
+  weight_t->AddQuantParam(*weight_quant_arg);
+  inputs_->push_back(weight_t);
+
+  Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  bias_t->MallocData();
+  memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
+  inputs_->push_back(bias_t);
+
+  Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
+  out_t->MallocData();
+  auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
+  output_quant_arg->zeroPoint = output_zp;
+  output_quant_arg->scale = output_scale;
+  out_t->AddQuantParam(*output_quant_arg);
+  outputs_->push_back(out_t);
+
+  *correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
+  float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
+  memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
+
+  matmal_param->b_transpose_ = true;
+  matmal_param->a_transpose_ = false;
+  matmal_param->has_bias_ = true;
+  matmal_param->minf_ = -FLT_MAX;
+  matmal_param->maxf_ = FLT_MAX;
+  return out_t->ElementsNum();
+}
+
+TEST_F(TestFcInt8, fcint8) {
+  std::vector<lite::tensor::Tensor *> inputs_;
+  std::vector<lite::tensor::Tensor *> outputs_;
+  auto matmul_param = new MatMulParameter();
+  float *correct;
+  double output_scale;
+  int output_zp;
+  int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
+  lite::Context *ctx = new lite::Context;
+  ctx->threadNum = 2;
+  kernel::FullconnectionInt8CPUKernel *fc =
+    new kernel::FullconnectionInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
+
+  fc->Init();
+  fc->Run();
+  float fout[6] = {0};
+  Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
+             fout);
+  CompareOutputData(fout, correct, 6, 0.2);
+  delete matmul_param;
+  delete fc;
+  for (auto t : inputs_) delete t;
+  for (auto t : outputs_) delete t;
+  free(correct);
+}
+
+}  // namespace mindspore
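
Note on the threading scheme in fullconnection_int8.cc above: RunImpl() splits the work along the 8-column blocks of the packed weight matrix, and each task reads a contiguous slice of b_r8_ptr_ and writes a contiguous slice of c_r8x8_ptr_. The sketch below restates that index arithmetic in plain C++ so it is easier to follow; it is an illustration only, with hypothetical names (TaskSlice, PartitionForTask), not code from the patch.

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative restatement of the per-task partitioning used by
// FullconnectionInt8CPUKernel::RunImpl(). Names here are hypothetical.
struct TaskSlice {
  int oc_blocks;     // number of 8-column blocks this task handles (cur_oc)
  int64_t b_offset;  // offset into the packed int8 weight buffer (b_r8_ptr_)
  int64_t c_offset;  // offset into the int32 accumulator buffer (c_r8x8_ptr_)
};

TaskSlice PartitionForTask(int task_id, int thread_count, int col_8, int row_8, int deep) {
  const int kBlock = 8;                                          // C8NUM in opclib
  int col8_blocks = (col_8 + kBlock - 1) / kBlock;               // UP_DIV(col_8_, 8)
  int stride = (col8_blocks + thread_count - 1) / thread_count;  // thread_stride_
  int remaining = col8_blocks - task_id * stride;
  TaskSlice s;
  s.oc_blocks = std::max(0, std::min(stride, remaining));        // 0 means this task has no work
  // Each 8-column block of the packed weights occupies 8 * deep int8 values,
  // and produces a contiguous 8 * row_8 int32 tile in the output buffer.
  s.b_offset = static_cast<int64_t>(task_id) * stride * kBlock * deep;
  s.c_offset = static_cast<int64_t>(task_id) * stride * kBlock * row_8;
  return s;
}
```

Because the per-task accumulator tiles are disjoint, LiteBackendParallelLaunch can run the tasks concurrently with no synchronization beyond the final join; PostFuncInt8 then requantizes the whole row8x8 buffer in one pass on the calling thread.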
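The quantize.h change replaces the single output_shift in FcQuantArg with separate left_shift/right_shift fields, which Init() now fills via QuantizeRoundParameter. That function's body is not part of this diff; the snippet below sketches the usual gemmlowp/TFLite-style decomposition of a real-valued requantization multiplier into a fixed-point multiplier plus shifts, as an assumption about what those fields hold, not as the opclib implementation.

```cpp
#include <cmath>
#include <cstdint>

// Sketch (assumed, not from the patch): split real_multiplier into an int32
// fixed-point multiplier in Q31 plus a pre-multiply left shift and a
// post-multiply rounding right shift.
void DecomposeMultiplier(double real_multiplier, int32_t *quantized_multiplier,
                         int32_t *left_shift, int32_t *right_shift) {
  int shift = 0;
  // frexp normalizes into [0.5, 1) and returns the power-of-two exponent.
  double q = std::frexp(real_multiplier, &shift);
  auto q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed the mantissa up to 1.0
    q_fixed /= 2;
    ++shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
  *left_shift = shift > 0 ? shift : 0;
  *right_shift = shift > 0 ? 0 : -shift;
}

// Typical use on an int32 accumulator acc (names follow gemmlowp conventions):
//   acc = RoundingDivideByPOT(
//           SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), quantized_multiplier),
//           right_shift);
// which is the shape of the call PostFuncInt8 makes with quant_multiplier,
// left_shift and right_shift in the new Run() above.
```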