!3836 1. multithreading support for fc_int8_op. 2. change asm matmul output layout

Merge pull request !3836 from zhanyuan/master
This commit is contained in:
mindspore-ci-bot 2020-08-01 14:36:40 +08:00 committed by Gitee
commit 53c40d0ec7
9 changed files with 367 additions and 261 deletions

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "src/runtime/kernel/arm/opclib/int8/matmul.h"
#include "src/runtime/kernel/arm/opclib/common_func.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
using mindspore::lite::RET_MEMORY_FAILED;
@ -25,22 +26,42 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int FullconnectionInt8CPUKernel::Init() {
fc_param_->row_ = (inputs_[0]->shape())[0];
fc_param_->col_ = (inputs_[1]->shape())[1];
fc_param_->deep_ = (inputs_[1]->shape())[0];
fc_param_->col_ = (inputs_[1]->shape())[0];
fc_param_->deep_ = (inputs_[1]->shape())[1];
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
a_c8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
b_r8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
if (!a_c8_ptr_ || !b_r8_ptr_ || !c_r8x8_ptr_) {
if (!b_r8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
auto weight_data = reinterpret_cast<int8_t *>(inputs_[1]->Data());
RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
if (!c_r8x8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
if (!bias_ptr_) {
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, bias_len);
if (inputs_.size() == 3) {
memcpy(bias_ptr_, inputs_[2]->Data(), bias_len);
}
auto input_tensor = inputs_[0];
auto params = input_tensor->GetQuantParams();
@ -59,7 +80,8 @@ int FullconnectionInt8CPUKernel::Init() {
quant_params_.output.scale_ = params.front().scale;
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeMultiplier(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.output_shift);
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
@ -68,22 +90,37 @@ int FullconnectionInt8CPUKernel::Init() {
int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; }
int FullconnectionInt8CPUKernel::Run() {
auto a_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
auto b_ptr = reinterpret_cast<int8_t *>(inputs_.at(1)->Data());
auto bias_ptr = reinterpret_cast<int *>(inputs_.at(2)->Data());
auto output_ptr = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto &p = quant_params_;
// rows*depth -> rows*depth, col_8 major
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
// cols*depth -> cols*depth, col_8 major == depth*cols, row_8 major
RowMajor2Col8MajorInt8(b_ptr, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
MatMulInt8(a_c8_ptr_, b_r8_ptr_, c_r8x8_ptr_, fc_param_->row_8_, fc_param_->col_8_, fc_param_->deep_, p.input.zp_,
p.weight.zp_);
PostFuncInt8(c_r8x8_ptr_, bias_ptr, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->col_8_,
fc_param_->row_8_, p.quant_multiplier, p.output_shift, p.output.zp_, p.out_act_min, p.out_act_max);
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
return RET_OK;
}
int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto fc = reinterpret_cast<FullconnectionInt8CPUKernel *>(cdata);
auto ret = fc->RunImpl(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FcInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}
int FullconnectionInt8CPUKernel::Run() {
auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
auto output_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
auto &p = quant_params_;
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_,
p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -31,20 +31,22 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
: FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~FullconnectionInt8CPUKernel() override {
free(a_c8_ptr_);
free(b_r8_ptr_);
free(c_r8x8_ptr_);
ctx_->allocator->Free(a_c8_ptr_);
ctx_->allocator->Free(b_r8_ptr_);
ctx_->allocator->Free(c_r8x8_ptr_);
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
private:
FcQuantArg quant_params_;
int8_t *a_c8_ptr_;
int8_t *b_r8_ptr_;
int *c_r8x8_ptr_;
int *bias_ptr_;
};
} // namespace mindspore::kernel

View File

@ -17,17 +17,17 @@
// \-----------------------------------------/
// LM 8x1 block
// /---------------------\ /-----------------------------------------\
// | v0.s[0] | |v16.s[0] ... v30.s[0]|
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
// | ... | | ... ... |
// | v0.s[3] | |v16.s[3] ... v30.s[3]|
// | v1.s[0] | |v17.s[0] ... v31.s[0]|
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
// | ... | | ... ... |
// | v1.s[3] | |v17.s[3] ... v31.s[3]|
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
///////////////////////////////////////////////////////////////////////////////
//OptLoopMul4 RHS 1x8 block
//OptLoopMul4 RM 1x8 block
// /--------------------------------------------\
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
@ -36,12 +36,12 @@
// \--------------------------------------------/
// LM 8x4 block
// /---------------------------------\ /--------------------------------------------\
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0] ... v30.s[0]|
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
// | ... ... ... ... | | ... ... |
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v16.s[3] ... v30.s[3]|
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v17.s[0] ... v31.s[0]|
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
// | ... ... ... ... | | ... ... |
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v17.s[3] ... v31.s[3]|
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
// \---------------------------------/ \--------------------------------------------/
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
@ -64,25 +64,22 @@ MatMulFloatNeon64:
mov w7, v0.s[0]
mov w8, v1.s[0]
mov w9, 0 // row counter
mov w10, 0 // col counter
mov w18, #32
mul w15, w4, w18 // the stride of a or b
mul w16, w6, w18 // the stride of c
mov w9, 0 // rm col offset
mov w10, 0 // lm row offset
mov w18, #32 // sizeof(float)*8
mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth
L1:
cmp w9, w5
cmp w9, w6
beq End1
mov w10, 0 // reset col counter
mov x12, x1 // reload b ptr
mov x17, x2 // reload current c ptr
mov w10, 0 // reset lm row offset
mov x12, x0 // reload lm ptr
mov x14, x3 // reload bias ptr
L2:
cmp w10, w6
beq End2
mov x11, x0 // reload a ptr
mov w13, w4 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
@ -105,142 +102,127 @@ OptLoopMul4:
cmp w13, #4
blt CommLoopMul
ld1 {v0.4s}, [x11], #16
ld1 {v8.4s}, [x12], #16
fmla v16.4s, v0.4s, v8.s[0]
fmla v18.4s, v0.4s, v8.s[1]
ld1 {v1.4s}, [x11], #16
fmla v20.4s, v0.4s, v8.s[2]
fmla v22.4s, v0.4s, v8.s[3]
ld1 {v9.4s}, [x12], #16
fmla v25.4s, v1.4s, v9.s[0]
fmla v27.4s, v1.4s, v9.s[1]
fmla v29.4s, v1.4s, v9.s[2]
fmla v31.4s, v1.4s, v9.s[3]
ld1 {v2.4s}, [x11], #16
ld1 {v3.4s}, [x11], #16
fmla v24.4s, v0.4s, v9.s[0]
fmla v26.4s, v0.4s, v9.s[1]
fmla v28.4s, v0.4s, v9.s[2]
fmla v30.4s, v0.4s, v9.s[3]
fmla v17.4s, v1.4s, v8.s[0]
fmla v19.4s, v1.4s, v8.s[1]
fmla v21.4s, v1.4s, v8.s[2]
fmla v23.4s, v1.4s, v8.s[3]
ld1 {v10.4s}, [x12], #16
ld1 {v11.4s}, [x12], #16
fmla v16.4s, v2.4s, v10.s[0]
fmla v18.4s, v2.4s, v10.s[1]
fmla v20.4s, v2.4s, v10.s[2]
fmla v22.4s, v2.4s, v10.s[3]
fmla v25.4s, v3.4s, v11.s[0]
fmla v27.4s, v3.4s, v11.s[1]
fmla v29.4s, v3.4s, v11.s[2]
fmla v31.4s, v3.4s, v11.s[3]
ld1 {v4.4s}, [x11], #16
ld1 {v5.4s}, [x11], #16
fmla v24.4s, v2.4s, v11.s[0]
fmla v26.4s, v2.4s, v11.s[1]
fmla v28.4s, v2.4s, v11.s[2]
fmla v30.4s, v2.4s, v11.s[3]
fmla v17.4s, v3.4s, v10.s[0]
fmla v19.4s, v3.4s, v10.s[1]
fmla v21.4s, v3.4s, v10.s[2]
fmla v23.4s, v3.4s, v10.s[3]
ld1 {v12.4s}, [x12], #16
ld1 {v13.4s}, [x12], #16
fmla v16.4s, v4.4s, v12.s[0]
fmla v18.4s, v4.4s, v12.s[1]
fmla v20.4s, v4.4s, v12.s[2]
fmla v22.4s, v4.4s, v12.s[3]
fmla v25.4s, v5.4s, v13.s[0]
fmla v27.4s, v5.4s, v13.s[1]
fmla v29.4s, v5.4s, v13.s[2]
fmla v31.4s, v5.4s, v13.s[3]
ld1 {v6.4s}, [x11], #16
ld1 {v7.4s}, [x11], #16
fmla v24.4s, v4.4s, v13.s[0]
fmla v26.4s, v4.4s, v13.s[1]
fmla v28.4s, v4.4s, v13.s[2]
fmla v30.4s, v4.4s, v13.s[3]
fmla v17.4s, v5.4s, v12.s[0]
fmla v19.4s, v5.4s, v12.s[1]
fmla v21.4s, v5.4s, v12.s[2]
fmla v23.4s, v5.4s, v12.s[3]
ld1 {v14.4s}, [x12], #16
ld1 {v15.4s}, [x12], #16
fmla v16.4s, v6.4s, v14.s[0]
fmla v18.4s, v6.4s, v14.s[1]
fmla v20.4s, v6.4s, v14.s[2]
fmla v22.4s, v6.4s, v14.s[3]
fmla v25.4s, v7.4s, v15.s[0]
fmla v27.4s, v7.4s, v15.s[1]
fmla v29.4s, v7.4s, v15.s[2]
fmla v31.4s, v7.4s, v15.s[3]
fmla v24.4s, v6.4s, v15.s[0]
fmla v26.4s, v6.4s, v15.s[1]
fmla v28.4s, v6.4s, v15.s[2]
fmla v30.4s, v6.4s, v15.s[3]
fmla v17.4s, v7.4s, v14.s[0]
fmla v19.4s, v7.4s, v14.s[1]
fmla v21.4s, v7.4s, v14.s[2]
fmla v23.4s, v7.4s, v14.s[3]
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v8.4s, v9.4s}, [x1], #32
fmla v16.4s, v8.4s, v0.s[0]
fmla v17.4s, v9.4s, v0.s[0]
fmla v18.4s, v8.4s, v0.s[1]
fmla v19.4s, v9.4s, v0.s[1]
fmla v20.4s, v8.4s, v0.s[2]
fmla v21.4s, v9.4s, v0.s[2]
fmla v22.4s, v8.4s, v0.s[3]
fmla v23.4s, v9.4s, v0.s[3]
ld1 {v10.4s, v11.4s}, [x1], #32
fmla v24.4s, v8.4s, v1.s[0]
fmla v25.4s, v9.4s, v1.s[0]
fmla v26.4s, v8.4s, v1.s[1]
fmla v27.4s, v9.4s, v1.s[1]
ld1 {v2.4s, v3.4s}, [x12], #32
fmla v28.4s, v8.4s, v1.s[2]
fmla v29.4s, v9.4s, v1.s[2]
fmla v30.4s, v8.4s, v1.s[3]
fmla v31.4s, v9.4s, v1.s[3]
fmla v16.4s, v10.4s, v2.s[0]
fmla v17.4s, v11.4s, v2.s[0]
fmla v18.4s, v10.4s, v2.s[1]
fmla v19.4s, v11.4s, v2.s[1]
fmla v20.4s, v10.4s, v2.s[2]
fmla v21.4s, v11.4s, v2.s[2]
fmla v22.4s, v10.4s, v2.s[3]
fmla v23.4s, v11.4s, v2.s[3]
ld1 {v12.4s, v13.4s}, [x1], #32
fmla v24.4s, v10.4s, v3.s[0]
fmla v25.4s, v11.4s, v3.s[0]
fmla v26.4s, v10.4s, v3.s[1]
fmla v27.4s, v11.4s, v3.s[1]
ld1 {v4.4s, v5.4s}, [x12], #32
fmla v28.4s, v10.4s, v3.s[2]
fmla v29.4s, v11.4s, v3.s[2]
fmla v30.4s, v10.4s, v3.s[3]
fmla v31.4s, v11.4s, v3.s[3]
fmla v16.4s, v12.4s, v4.s[0]
fmla v17.4s, v13.4s, v4.s[0]
fmla v18.4s, v12.4s, v4.s[1]
fmla v19.4s, v13.4s, v4.s[1]
fmla v20.4s, v12.4s, v4.s[2]
fmla v21.4s, v13.4s, v4.s[2]
fmla v22.4s, v12.4s, v4.s[3]
fmla v23.4s, v13.4s, v4.s[3]
ld1 {v6.4s,v7.4s}, [x12], #32
fmla v24.4s, v12.4s, v5.s[0]
fmla v25.4s, v13.4s, v5.s[0]
fmla v26.4s, v12.4s, v5.s[1]
fmla v27.4s, v13.4s, v5.s[1]
ld1 {v14.4s, v15.4s}, [x1], #32
fmla v28.4s, v12.4s, v5.s[2]
fmla v29.4s, v13.4s, v5.s[2]
fmla v30.4s, v12.4s, v5.s[3]
fmla v31.4s, v13.4s, v5.s[3]
fmla v16.4s, v14.4s, v6.s[0]
fmla v17.4s, v15.4s, v6.s[0]
fmla v18.4s, v14.4s, v6.s[1]
fmla v19.4s, v15.4s, v6.s[1]
fmla v20.4s, v14.4s, v6.s[2]
fmla v21.4s, v15.4s, v6.s[2]
fmla v22.4s, v14.4s, v6.s[3]
fmla v23.4s, v15.4s, v6.s[3]
fmla v24.4s, v14.4s, v7.s[0]
fmla v25.4s, v15.4s, v7.s[0]
fmla v26.4s, v14.4s, v7.s[1]
fmla v27.4s, v15.4s, v7.s[1]
fmla v28.4s, v14.4s, v7.s[2]
fmla v29.4s, v15.4s, v7.s[2]
fmla v30.4s, v14.4s, v7.s[3]
fmla v31.4s, v15.4s, v7.s[3]
subs w13, w13, #4
b OptLoopMul4
CommLoopMul:
cmp w13, #1
blt Bias
ld1 {v0.4s}, [x11], #16
ld1 {v2.4s}, [x12], #16
fmla v16.4s, v0.4s, v2.s[0]
fmla v18.4s, v0.4s, v2.s[1]
ld1 {v1.4s}, [x11], #16
fmla v20.4s, v0.4s, v2.s[2]
fmla v22.4s, v0.4s, v2.s[3]
ld1 {v3.4s}, [x12], #16
fmla v25.4s, v1.4s, v3.s[0]
fmla v27.4s, v1.4s, v3.s[1]
fmla v29.4s, v1.4s, v3.s[2]
fmla v31.4s, v1.4s, v3.s[3]
fmla v24.4s, v0.4s, v3.s[0]
fmla v26.4s, v0.4s, v3.s[1]
fmla v28.4s, v0.4s, v3.s[2]
fmla v30.4s, v0.4s, v3.s[3]
fmla v17.4s, v1.4s, v2.s[0]
fmla v19.4s, v1.4s, v2.s[1]
fmla v21.4s, v1.4s, v2.s[2]
fmla v23.4s, v1.4s, v2.s[3]
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v2.4s, v3.4s}, [x1], #32
fmla v16.4s, v2.4s, v0.s[0]
fmla v17.4s, v3.4s, v0.s[0]
fmla v18.4s, v2.4s, v0.s[1]
fmla v19.4s, v3.4s, v0.s[1]
fmla v20.4s, v2.4s, v0.s[2]
fmla v21.4s, v3.4s, v0.s[2]
fmla v22.4s, v2.4s, v0.s[3]
fmla v23.4s, v3.4s, v0.s[3]
fmla v24.4s, v2.4s, v1.s[0]
fmla v25.4s, v3.4s, v1.s[0]
fmla v26.4s, v2.4s, v1.s[1]
fmla v27.4s, v3.4s, v1.s[1]
fmla v28.4s, v2.4s, v1.s[2]
fmla v29.4s, v3.4s, v1.s[2]
fmla v30.4s, v2.4s, v1.s[3]
fmla v31.4s, v3.4s, v1.s[3]
subs w13, w13, #1
b CommLoopMul
Bias:
cmp x3, #0
beq Relu
ld1 {v0.4s}, [x14], #16
ld1 {v1.4s}, [x14], #16
dup v2.4s, v0.s[0]
fadd v16.4s, v16.4s, v2.4s
fadd v17.4s, v17.4s, v2.4s
dup v3.4s, v0.s[1]
fadd v18.4s, v18.4s, v3.4s
fadd v19.4s, v19.4s, v3.4s
dup v4.4s, v0.s[2]
fadd v20.4s, v20.4s, v4.4s
fadd v21.4s, v21.4s, v4.4s
dup v5.4s, v0.s[3]
fadd v22.4s, v22.4s, v5.4s
fadd v23.4s, v23.4s, v5.4s
dup v2.4s, v1.s[0]
fadd v24.4s, v24.4s, v2.4s
fadd v25.4s, v25.4s, v2.4s
dup v3.4s, v1.s[1]
fadd v26.4s, v26.4s, v3.4s
fadd v27.4s, v27.4s, v3.4s
dup v4.4s, v1.s[2]
fadd v28.4s, v28.4s, v4.4s
fadd v29.4s, v29.4s, v4.4s
dup v5.4s, v1.s[3]
fadd v30.4s, v30.4s, v5.4s
fadd v31.4s, v31.4s, v5.4s
fadd v16.4s, v16.4s, v0.4s
fadd v17.4s, v17.4s, v1.4s
fadd v18.4s, v18.4s, v0.4s
fadd v19.4s, v19.4s, v1.4s
fadd v20.4s, v20.4s, v0.4s
fadd v21.4s, v21.4s, v1.4s
fadd v22.4s, v22.4s, v0.4s
fadd v23.4s, v23.4s, v1.4s
fadd v24.4s, v24.4s, v0.4s
fadd v25.4s, v25.4s, v1.4s
fadd v26.4s, v26.4s, v0.4s
fadd v27.4s, v27.4s, v1.4s
fadd v28.4s, v28.4s, v0.4s
fadd v29.4s, v29.4s, v1.4s
fadd v30.4s, v30.4s, v0.4s
fadd v31.4s, v31.4s, v1.4s
Relu:
dup v15.4s, w7
@ -281,30 +263,28 @@ Relu:
fmin v31.4s, v31.4s, v15.4s
TransToOut:
st1 {v16.4s}, [x17], #16
st1 {v17.4s}, [x17], #16
st1 {v18.4s}, [x17], #16
st1 {v19.4s}, [x17], #16
st1 {v20.4s}, [x17], #16
st1 {v21.4s}, [x17], #16
st1 {v22.4s}, [x17], #16
st1 {v23.4s}, [x17], #16
st1 {v24.4s}, [x17], #16
st1 {v25.4s}, [x17], #16
st1 {v26.4s}, [x17], #16
st1 {v27.4s}, [x17], #16
st1 {v28.4s}, [x17], #16
st1 {v29.4s}, [x17], #16
st1 {v30.4s}, [x17], #16
st1 {v31.4s}, [x17], #16
st1 {v16.4s}, [x2], #16
st1 {v17.4s}, [x2], #16
st1 {v18.4s}, [x2], #16
st1 {v19.4s}, [x2], #16
st1 {v20.4s}, [x2], #16
st1 {v21.4s}, [x2], #16
st1 {v22.4s}, [x2], #16
st1 {v23.4s}, [x2], #16
st1 {v24.4s}, [x2], #16
st1 {v25.4s}, [x2], #16
st1 {v26.4s}, [x2], #16
st1 {v27.4s}, [x2], #16
st1 {v28.4s}, [x2], #16
st1 {v29.4s}, [x2], #16
st1 {v30.4s}, [x2], #16
st1 {v31.4s}, [x2], #16
add w10, w10, #8 // col+=8
add w10, w10, #8 // lhs row offset + 8
b L2
End2:
add x0, x0, x15 // stride a ptr
add x2, x2, x16 // stride c ptr
add w9, w9, #8 // row+=8
add w9, w9, #8 // rhs col offset + 8
b L1
End1:

View File

@ -74,5 +74,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_,
int col_8_) {
#ifdef __aarch64__
MatMulFloatNeon64(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
#else
MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
#endif
}

View File

@ -21,19 +21,22 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row,
int col);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);
void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
int row_8_, int col_8_);
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __aarch64__
void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
int row, int col);
#endif
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_

View File

@ -48,54 +48,3 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
}
return;
}
// todo: need to delete, replace by above functions. z00445833
void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col8 = UP_ROUND(col, 8);
for (int r = 0; r < row; r++) {
int rd8 = r / 8;
int rm8 = r % 8;
for (int c = 0; c < col; c++) {
dst_ptr[r * col + c] = src_ptr[rd8 * col8 * 8 + c * 8 + rm8];
}
}
}
void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
int depth, FcQuantArg *params) {
int lhs_offset = params->input.zp_;
int rhs_offset = params->weight.zp_;
int output_offset = params->output.zp_;
int output_multiplier = params->quant_multiplier;
int output_shift = params->output_shift;
for (int row = 0; row < 8; ++row) {
for (int col = 0; col < 8; ++col) {
int c_index = col * 8 + row;
int acc = 0;
for (int d = 0; d < depth; ++d) {
int a_index = d * 8 + row;
int b_index = d * 8 + col;
acc += (lhs_data[a_index] - lhs_offset) * (rhs_data[b_index] - rhs_offset);
}
acc += bias_data[col];
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift, output_shift) + output_offset;
acc = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, acc));
output_data[c_index] = (int8_t)acc;
}
}
}
void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
int row_8, int col_8, int depth, FcQuantArg *params) {
for (int r = 0; r < row_8; r += 8) {
int8_t *output = output_data + r * col_8;
const int8_t *input = input_data + r * depth;
for (int c = 0; c < col_8; c += 8) {
const int8_t *bias = bias_data + c;
const int8_t *weights = weights_data + c * depth;
int8_t *dst = output + c * 8;
Gemm8x8Int8(input, weights, bias, dst, depth, params);
}
}
}

View File

@ -20,23 +20,9 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/matmul.h"
#ifdef __cplusplus
extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp);
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
int depth, FcQuantArg *params);
void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
int row_8, int col_8, int depth, FcQuantArg *params);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_
#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_INT8_MATMUL_H_

View File

@ -54,7 +54,8 @@ struct FcQuantArg {
QuantArg output;
int32_t out_act_min;
int32_t out_act_max;
int32_t output_shift;
int32_t left_shift;
int32_t right_shift;
int32_t quant_multiplier;
};

View File

@ -0,0 +1,144 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
using lite::tensor::Tensor;
class TestFcInt8 : public mindspore::Common {
public:
TestFcInt8(){}
};
void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
for (int i = 0; i < length; ++i) {
int8_t q = static_cast<int8_t>(std::max<float>(
std::numeric_limits<int8_t>::min(),
std::min<float>(std::numeric_limits<int8_t>::max(), std::round(zero_point + (input_data[i] / scale)))));
output_data[i] = q;
}
}
void Dequantize(int8_t *input_data, int length, float scale, int zero_point, float *output_data) {
for (int i = 0; i < length; ++i) {
output_data[i] = scale * (input_data[i] - zero_point);
}
}
int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
float input_max = 20;
float input_min = -20;
float weight_max = 1;
float weight_min = -1;
float output_max = 20;
float output_min = -20;
double input_scale =
(input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
double weight_scale =
(weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
double output_scale =
(output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
*scale = output_scale;
*zeropoint = output_zp;
Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
in_quant_arg->zeroPoint = input_zp;
in_quant_arg->scale = input_scale;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435,
0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552,
-0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459,
0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343};
Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
weight_quant_arg->zeroPoint = weight_zp;
weight_quant_arg->scale = weight_scale;
weight_t->AddQuantParam(*weight_quant_arg);
inputs_->push_back(weight_t);
Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
inputs_->push_back(bias_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
output_quant_arg->zeroPoint = output_zp;
output_quant_arg->scale = output_scale;
out_t->AddQuantParam(*output_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->minf_ = -FLT_MAX;
matmal_param->maxf_ = FLT_MAX;
return out_t->ElementsNum();
}
TEST_F(TestFcInt8, fcint8) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
double output_scale;
int output_zp;
int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
lite::Context *ctx = new lite::Context;
ctx->threadNum = 2;
kernel::FullconnectionInt8CPUKernel *fc =
new kernel::FullconnectionInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
fc->Init();
fc->Run();
float fout[6] = {0};
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
fout);
CompareOutputData(fout, correct, 6, 0.2);
delete matmul_param;
delete fc;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
}
} // namespace mindspore