1. multithreading support for fc_int8_op. 2. change asm matmul output layout from col8x8 to row8x8
This commit is contained in:
parent
6c4ee3f3d1
commit
9cab9f7a91
|
@ -17,6 +17,7 @@
|
||||||
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
|
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
|
||||||
#include "src/runtime/kernel/arm/opclib/int8/matmul.h"
|
#include "src/runtime/kernel/arm/opclib/int8/matmul.h"
|
||||||
#include "src/runtime/kernel/arm/opclib/common_func.h"
|
#include "src/runtime/kernel/arm/opclib/common_func.h"
|
||||||
|
#include "src/runtime/runtime_api.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
using mindspore::lite::RET_MEMORY_FAILED;
|
using mindspore::lite::RET_MEMORY_FAILED;
|
||||||
|
@ -25,22 +26,42 @@ using mindspore::lite::RET_OK;
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
int FullconnectionInt8CPUKernel::Init() {
|
int FullconnectionInt8CPUKernel::Init() {
|
||||||
fc_param_->row_ = (inputs_[0]->shape())[0];
|
fc_param_->row_ = (inputs_[0]->shape())[0];
|
||||||
fc_param_->col_ = (inputs_[1]->shape())[1];
|
fc_param_->col_ = (inputs_[1]->shape())[0];
|
||||||
fc_param_->deep_ = (inputs_[1]->shape())[0];
|
fc_param_->deep_ = (inputs_[1]->shape())[1];
|
||||||
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
|
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
|
||||||
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
|
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
|
||||||
|
|
||||||
|
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
|
||||||
|
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
|
||||||
|
|
||||||
a_c8_ptr_ =
|
a_c8_ptr_ =
|
||||||
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
|
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
|
||||||
|
if (!a_c8_ptr_) {
|
||||||
|
return RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
|
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
|
||||||
b_r8_ptr_ =
|
b_r8_ptr_ =
|
||||||
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
|
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
|
||||||
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
|
if (!b_r8_ptr_) {
|
||||||
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
|
|
||||||
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
|
|
||||||
if (!a_c8_ptr_ || !b_r8_ptr_ || !c_r8x8_ptr_) {
|
|
||||||
return RET_MEMORY_FAILED;
|
return RET_MEMORY_FAILED;
|
||||||
}
|
}
|
||||||
|
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
|
||||||
|
auto weight_data = reinterpret_cast<int8_t *>(inputs_[1]->Data());
|
||||||
|
RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
|
||||||
|
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
|
||||||
|
if (!c_r8x8_ptr_) {
|
||||||
|
return RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
|
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
|
||||||
|
auto bias_len = fc_param_->col_8_ * sizeof(int);
|
||||||
|
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
|
||||||
|
if (!bias_ptr_) {
|
||||||
|
return RET_MEMORY_FAILED;
|
||||||
|
}
|
||||||
|
memset(bias_ptr_, 0, bias_len);
|
||||||
|
if (inputs_.size() == 3) {
|
||||||
|
memcpy(bias_ptr_, inputs_[2]->Data(), bias_len);
|
||||||
|
}
|
||||||
|
|
||||||
auto input_tensor = inputs_[0];
|
auto input_tensor = inputs_[0];
|
||||||
auto params = input_tensor->GetQuantParams();
|
auto params = input_tensor->GetQuantParams();
|
||||||
|
@ -59,7 +80,8 @@ int FullconnectionInt8CPUKernel::Init() {
|
||||||
quant_params_.output.scale_ = params.front().scale;
|
quant_params_.output.scale_ = params.front().scale;
|
||||||
|
|
||||||
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
|
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
|
||||||
QuantizeMultiplier(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.output_shift);
|
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
|
||||||
|
&quant_params_.right_shift);
|
||||||
CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
|
CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
|
||||||
quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
|
quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
|
||||||
|
|
||||||
|
@ -68,22 +90,37 @@ int FullconnectionInt8CPUKernel::Init() {
|
||||||
|
|
||||||
int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; }
|
int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
int FullconnectionInt8CPUKernel::Run() {
|
int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
|
||||||
auto a_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
|
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
|
||||||
auto b_ptr = reinterpret_cast<int8_t *>(inputs_.at(1)->Data());
|
if (cur_oc <= 0) {
|
||||||
auto bias_ptr = reinterpret_cast<int *>(inputs_.at(2)->Data());
|
return RET_OK;
|
||||||
auto output_ptr = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
|
}
|
||||||
auto &p = quant_params_;
|
auto &p = quant_params_;
|
||||||
|
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
|
||||||
// rows*depth -> rows*depth, col_8 major
|
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
|
||||||
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
|
MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
|
||||||
// cols*depth -> cols*depth, col_8 major == depth*cols, row_8 major
|
|
||||||
RowMajor2Col8MajorInt8(b_ptr, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
|
|
||||||
MatMulInt8(a_c8_ptr_, b_r8_ptr_, c_r8x8_ptr_, fc_param_->row_8_, fc_param_->col_8_, fc_param_->deep_, p.input.zp_,
|
|
||||||
p.weight.zp_);
|
|
||||||
PostFuncInt8(c_r8x8_ptr_, bias_ptr, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->col_8_,
|
|
||||||
fc_param_->row_8_, p.quant_multiplier, p.output_shift, p.output.zp_, p.out_act_min, p.out_act_max);
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||||
|
auto fc = reinterpret_cast<FullconnectionInt8CPUKernel *>(cdata);
|
||||||
|
auto ret = fc->RunImpl(task_id);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "FcInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int FullconnectionInt8CPUKernel::Run() {
|
||||||
|
auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
|
||||||
|
auto output_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
|
||||||
|
auto &p = quant_params_;
|
||||||
|
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
|
||||||
|
LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
|
||||||
|
PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_,
|
||||||
|
p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -31,20 +31,22 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
|
const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
|
||||||
: FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
|
: FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||||
~FullconnectionInt8CPUKernel() override {
|
~FullconnectionInt8CPUKernel() override {
|
||||||
free(a_c8_ptr_);
|
ctx_->allocator->Free(a_c8_ptr_);
|
||||||
free(b_r8_ptr_);
|
ctx_->allocator->Free(b_r8_ptr_);
|
||||||
free(c_r8x8_ptr_);
|
ctx_->allocator->Free(c_r8x8_ptr_);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
|
int RunImpl(int task_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FcQuantArg quant_params_;
|
FcQuantArg quant_params_;
|
||||||
int8_t *a_c8_ptr_;
|
int8_t *a_c8_ptr_;
|
||||||
int8_t *b_r8_ptr_;
|
int8_t *b_r8_ptr_;
|
||||||
int *c_r8x8_ptr_;
|
int *c_r8x8_ptr_;
|
||||||
|
int *bias_ptr_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
||||||
|
|
|
@ -17,17 +17,17 @@
|
||||||
// \-----------------------------------------/
|
// \-----------------------------------------/
|
||||||
// LM 8x1 block
|
// LM 8x1 block
|
||||||
// /---------------------\ /-----------------------------------------\
|
// /---------------------\ /-----------------------------------------\
|
||||||
// | v0.s[0] | |v16.s[0] ... v30.s[0]|
|
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
|
||||||
// | ... | | ... ... |
|
// | ... | | ... ... |
|
||||||
// | v0.s[3] | |v16.s[3] ... v30.s[3]|
|
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
|
||||||
// | v1.s[0] | |v17.s[0] ... v31.s[0]|
|
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
|
||||||
// | ... | | ... ... |
|
// | ... | | ... ... |
|
||||||
// | v1.s[3] | |v17.s[3] ... v31.s[3]|
|
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
|
||||||
// \---------------------/ \-----------------------------------------/
|
// \---------------------/ \-----------------------------------------/
|
||||||
// accumulators 8x8 block
|
// accumulators 8x8 block
|
||||||
//
|
//
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
//OptLoopMul4 RHS 1x8 block
|
//OptLoopMul4 RM 1x8 block
|
||||||
// /--------------------------------------------\
|
// /--------------------------------------------\
|
||||||
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
|
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
|
||||||
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
|
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
|
||||||
|
@ -36,12 +36,12 @@
|
||||||
// \--------------------------------------------/
|
// \--------------------------------------------/
|
||||||
// LM 8x4 block
|
// LM 8x4 block
|
||||||
// /---------------------------------\ /--------------------------------------------\
|
// /---------------------------------\ /--------------------------------------------\
|
||||||
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0] ... v30.s[0]|
|
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
|
||||||
// | ... ... ... ... | | ... ... |
|
// | ... ... ... ... | | ... ... |
|
||||||
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v16.s[3] ... v30.s[3]|
|
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
|
||||||
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v17.s[0] ... v31.s[0]|
|
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
|
||||||
// | ... ... ... ... | | ... ... |
|
// | ... ... ... ... | | ... ... |
|
||||||
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v17.s[3] ... v31.s[3]|
|
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
|
||||||
// \---------------------------------/ \--------------------------------------------/
|
// \---------------------------------/ \--------------------------------------------/
|
||||||
// accumulators 8x8 block
|
// accumulators 8x8 block
|
||||||
/////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -64,25 +64,22 @@ MatMulFloatNeon64:
|
||||||
|
|
||||||
mov w7, v0.s[0]
|
mov w7, v0.s[0]
|
||||||
mov w8, v1.s[0]
|
mov w8, v1.s[0]
|
||||||
mov w9, 0 // row counter
|
mov w9, 0 // rm col offset
|
||||||
mov w10, 0 // col counter
|
mov w10, 0 // lm row offset
|
||||||
mov w18, #32
|
mov w18, #32 // sizeof(float)*8
|
||||||
mul w15, w4, w18 // the stride of a or b
|
mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth
|
||||||
mul w16, w6, w18 // the stride of c
|
|
||||||
|
|
||||||
L1:
|
L1:
|
||||||
cmp w9, w5
|
cmp w9, w6
|
||||||
beq End1
|
beq End1
|
||||||
|
|
||||||
mov w10, 0 // reset col counter
|
mov w10, 0 // reset lm row offset
|
||||||
mov x12, x1 // reload b ptr
|
mov x12, x0 // reload lm ptr
|
||||||
mov x17, x2 // reload current c ptr
|
|
||||||
mov x14, x3 // reload bias ptr
|
mov x14, x3 // reload bias ptr
|
||||||
L2:
|
L2:
|
||||||
cmp w10, w6
|
cmp w10, w6
|
||||||
beq End2
|
beq End2
|
||||||
|
|
||||||
mov x11, x0 // reload a ptr
|
|
||||||
mov w13, w4 // reload depth
|
mov w13, w4 // reload depth
|
||||||
dup v16.4s, wzr
|
dup v16.4s, wzr
|
||||||
dup v17.4s, wzr
|
dup v17.4s, wzr
|
||||||
|
@ -105,142 +102,127 @@ OptLoopMul4:
|
||||||
cmp w13, #4
|
cmp w13, #4
|
||||||
blt CommLoopMul
|
blt CommLoopMul
|
||||||
|
|
||||||
ld1 {v0.4s}, [x11], #16
|
ld1 {v0.4s, v1.4s}, [x12], #32
|
||||||
ld1 {v8.4s}, [x12], #16
|
ld1 {v8.4s, v9.4s}, [x1], #32
|
||||||
fmla v16.4s, v0.4s, v8.s[0]
|
fmla v16.4s, v8.4s, v0.s[0]
|
||||||
fmla v18.4s, v0.4s, v8.s[1]
|
fmla v17.4s, v9.4s, v0.s[0]
|
||||||
ld1 {v1.4s}, [x11], #16
|
fmla v18.4s, v8.4s, v0.s[1]
|
||||||
fmla v20.4s, v0.4s, v8.s[2]
|
fmla v19.4s, v9.4s, v0.s[1]
|
||||||
fmla v22.4s, v0.4s, v8.s[3]
|
fmla v20.4s, v8.4s, v0.s[2]
|
||||||
ld1 {v9.4s}, [x12], #16
|
fmla v21.4s, v9.4s, v0.s[2]
|
||||||
fmla v25.4s, v1.4s, v9.s[0]
|
fmla v22.4s, v8.4s, v0.s[3]
|
||||||
fmla v27.4s, v1.4s, v9.s[1]
|
fmla v23.4s, v9.4s, v0.s[3]
|
||||||
fmla v29.4s, v1.4s, v9.s[2]
|
ld1 {v10.4s, v11.4s}, [x1], #32
|
||||||
fmla v31.4s, v1.4s, v9.s[3]
|
fmla v24.4s, v8.4s, v1.s[0]
|
||||||
ld1 {v2.4s}, [x11], #16
|
fmla v25.4s, v9.4s, v1.s[0]
|
||||||
ld1 {v3.4s}, [x11], #16
|
fmla v26.4s, v8.4s, v1.s[1]
|
||||||
fmla v24.4s, v0.4s, v9.s[0]
|
fmla v27.4s, v9.4s, v1.s[1]
|
||||||
fmla v26.4s, v0.4s, v9.s[1]
|
ld1 {v2.4s, v3.4s}, [x12], #32
|
||||||
fmla v28.4s, v0.4s, v9.s[2]
|
fmla v28.4s, v8.4s, v1.s[2]
|
||||||
fmla v30.4s, v0.4s, v9.s[3]
|
fmla v29.4s, v9.4s, v1.s[2]
|
||||||
fmla v17.4s, v1.4s, v8.s[0]
|
fmla v30.4s, v8.4s, v1.s[3]
|
||||||
fmla v19.4s, v1.4s, v8.s[1]
|
fmla v31.4s, v9.4s, v1.s[3]
|
||||||
fmla v21.4s, v1.4s, v8.s[2]
|
fmla v16.4s, v10.4s, v2.s[0]
|
||||||
fmla v23.4s, v1.4s, v8.s[3]
|
fmla v17.4s, v11.4s, v2.s[0]
|
||||||
ld1 {v10.4s}, [x12], #16
|
fmla v18.4s, v10.4s, v2.s[1]
|
||||||
ld1 {v11.4s}, [x12], #16
|
fmla v19.4s, v11.4s, v2.s[1]
|
||||||
fmla v16.4s, v2.4s, v10.s[0]
|
fmla v20.4s, v10.4s, v2.s[2]
|
||||||
fmla v18.4s, v2.4s, v10.s[1]
|
fmla v21.4s, v11.4s, v2.s[2]
|
||||||
fmla v20.4s, v2.4s, v10.s[2]
|
fmla v22.4s, v10.4s, v2.s[3]
|
||||||
fmla v22.4s, v2.4s, v10.s[3]
|
fmla v23.4s, v11.4s, v2.s[3]
|
||||||
fmla v25.4s, v3.4s, v11.s[0]
|
ld1 {v12.4s, v13.4s}, [x1], #32
|
||||||
fmla v27.4s, v3.4s, v11.s[1]
|
fmla v24.4s, v10.4s, v3.s[0]
|
||||||
fmla v29.4s, v3.4s, v11.s[2]
|
fmla v25.4s, v11.4s, v3.s[0]
|
||||||
fmla v31.4s, v3.4s, v11.s[3]
|
fmla v26.4s, v10.4s, v3.s[1]
|
||||||
ld1 {v4.4s}, [x11], #16
|
fmla v27.4s, v11.4s, v3.s[1]
|
||||||
ld1 {v5.4s}, [x11], #16
|
ld1 {v4.4s, v5.4s}, [x12], #32
|
||||||
fmla v24.4s, v2.4s, v11.s[0]
|
fmla v28.4s, v10.4s, v3.s[2]
|
||||||
fmla v26.4s, v2.4s, v11.s[1]
|
fmla v29.4s, v11.4s, v3.s[2]
|
||||||
fmla v28.4s, v2.4s, v11.s[2]
|
fmla v30.4s, v10.4s, v3.s[3]
|
||||||
fmla v30.4s, v2.4s, v11.s[3]
|
fmla v31.4s, v11.4s, v3.s[3]
|
||||||
fmla v17.4s, v3.4s, v10.s[0]
|
fmla v16.4s, v12.4s, v4.s[0]
|
||||||
fmla v19.4s, v3.4s, v10.s[1]
|
fmla v17.4s, v13.4s, v4.s[0]
|
||||||
fmla v21.4s, v3.4s, v10.s[2]
|
fmla v18.4s, v12.4s, v4.s[1]
|
||||||
fmla v23.4s, v3.4s, v10.s[3]
|
fmla v19.4s, v13.4s, v4.s[1]
|
||||||
ld1 {v12.4s}, [x12], #16
|
fmla v20.4s, v12.4s, v4.s[2]
|
||||||
ld1 {v13.4s}, [x12], #16
|
fmla v21.4s, v13.4s, v4.s[2]
|
||||||
fmla v16.4s, v4.4s, v12.s[0]
|
fmla v22.4s, v12.4s, v4.s[3]
|
||||||
fmla v18.4s, v4.4s, v12.s[1]
|
fmla v23.4s, v13.4s, v4.s[3]
|
||||||
fmla v20.4s, v4.4s, v12.s[2]
|
ld1 {v6.4s,v7.4s}, [x12], #32
|
||||||
fmla v22.4s, v4.4s, v12.s[3]
|
fmla v24.4s, v12.4s, v5.s[0]
|
||||||
fmla v25.4s, v5.4s, v13.s[0]
|
fmla v25.4s, v13.4s, v5.s[0]
|
||||||
fmla v27.4s, v5.4s, v13.s[1]
|
fmla v26.4s, v12.4s, v5.s[1]
|
||||||
fmla v29.4s, v5.4s, v13.s[2]
|
fmla v27.4s, v13.4s, v5.s[1]
|
||||||
fmla v31.4s, v5.4s, v13.s[3]
|
ld1 {v14.4s, v15.4s}, [x1], #32
|
||||||
ld1 {v6.4s}, [x11], #16
|
fmla v28.4s, v12.4s, v5.s[2]
|
||||||
ld1 {v7.4s}, [x11], #16
|
fmla v29.4s, v13.4s, v5.s[2]
|
||||||
fmla v24.4s, v4.4s, v13.s[0]
|
fmla v30.4s, v12.4s, v5.s[3]
|
||||||
fmla v26.4s, v4.4s, v13.s[1]
|
fmla v31.4s, v13.4s, v5.s[3]
|
||||||
fmla v28.4s, v4.4s, v13.s[2]
|
fmla v16.4s, v14.4s, v6.s[0]
|
||||||
fmla v30.4s, v4.4s, v13.s[3]
|
fmla v17.4s, v15.4s, v6.s[0]
|
||||||
fmla v17.4s, v5.4s, v12.s[0]
|
fmla v18.4s, v14.4s, v6.s[1]
|
||||||
fmla v19.4s, v5.4s, v12.s[1]
|
fmla v19.4s, v15.4s, v6.s[1]
|
||||||
fmla v21.4s, v5.4s, v12.s[2]
|
fmla v20.4s, v14.4s, v6.s[2]
|
||||||
fmla v23.4s, v5.4s, v12.s[3]
|
fmla v21.4s, v15.4s, v6.s[2]
|
||||||
ld1 {v14.4s}, [x12], #16
|
fmla v22.4s, v14.4s, v6.s[3]
|
||||||
ld1 {v15.4s}, [x12], #16
|
fmla v23.4s, v15.4s, v6.s[3]
|
||||||
fmla v16.4s, v6.4s, v14.s[0]
|
fmla v24.4s, v14.4s, v7.s[0]
|
||||||
fmla v18.4s, v6.4s, v14.s[1]
|
fmla v25.4s, v15.4s, v7.s[0]
|
||||||
fmla v20.4s, v6.4s, v14.s[2]
|
fmla v26.4s, v14.4s, v7.s[1]
|
||||||
fmla v22.4s, v6.4s, v14.s[3]
|
fmla v27.4s, v15.4s, v7.s[1]
|
||||||
fmla v25.4s, v7.4s, v15.s[0]
|
fmla v28.4s, v14.4s, v7.s[2]
|
||||||
fmla v27.4s, v7.4s, v15.s[1]
|
fmla v29.4s, v15.4s, v7.s[2]
|
||||||
fmla v29.4s, v7.4s, v15.s[2]
|
fmla v30.4s, v14.4s, v7.s[3]
|
||||||
fmla v31.4s, v7.4s, v15.s[3]
|
fmla v31.4s, v15.4s, v7.s[3]
|
||||||
fmla v24.4s, v6.4s, v15.s[0]
|
|
||||||
fmla v26.4s, v6.4s, v15.s[1]
|
|
||||||
fmla v28.4s, v6.4s, v15.s[2]
|
|
||||||
fmla v30.4s, v6.4s, v15.s[3]
|
|
||||||
fmla v17.4s, v7.4s, v14.s[0]
|
|
||||||
fmla v19.4s, v7.4s, v14.s[1]
|
|
||||||
fmla v21.4s, v7.4s, v14.s[2]
|
|
||||||
fmla v23.4s, v7.4s, v14.s[3]
|
|
||||||
subs w13, w13, #4
|
subs w13, w13, #4
|
||||||
b OptLoopMul4
|
b OptLoopMul4
|
||||||
|
|
||||||
CommLoopMul:
|
CommLoopMul:
|
||||||
cmp w13, #1
|
cmp w13, #1
|
||||||
blt Bias
|
blt Bias
|
||||||
ld1 {v0.4s}, [x11], #16
|
|
||||||
ld1 {v2.4s}, [x12], #16
|
ld1 {v0.4s, v1.4s}, [x12], #32
|
||||||
fmla v16.4s, v0.4s, v2.s[0]
|
ld1 {v2.4s, v3.4s}, [x1], #32
|
||||||
fmla v18.4s, v0.4s, v2.s[1]
|
fmla v16.4s, v2.4s, v0.s[0]
|
||||||
ld1 {v1.4s}, [x11], #16
|
fmla v17.4s, v3.4s, v0.s[0]
|
||||||
fmla v20.4s, v0.4s, v2.s[2]
|
fmla v18.4s, v2.4s, v0.s[1]
|
||||||
fmla v22.4s, v0.4s, v2.s[3]
|
fmla v19.4s, v3.4s, v0.s[1]
|
||||||
ld1 {v3.4s}, [x12], #16
|
fmla v20.4s, v2.4s, v0.s[2]
|
||||||
fmla v25.4s, v1.4s, v3.s[0]
|
fmla v21.4s, v3.4s, v0.s[2]
|
||||||
fmla v27.4s, v1.4s, v3.s[1]
|
fmla v22.4s, v2.4s, v0.s[3]
|
||||||
fmla v29.4s, v1.4s, v3.s[2]
|
fmla v23.4s, v3.4s, v0.s[3]
|
||||||
fmla v31.4s, v1.4s, v3.s[3]
|
fmla v24.4s, v2.4s, v1.s[0]
|
||||||
fmla v24.4s, v0.4s, v3.s[0]
|
fmla v25.4s, v3.4s, v1.s[0]
|
||||||
fmla v26.4s, v0.4s, v3.s[1]
|
fmla v26.4s, v2.4s, v1.s[1]
|
||||||
fmla v28.4s, v0.4s, v3.s[2]
|
fmla v27.4s, v3.4s, v1.s[1]
|
||||||
fmla v30.4s, v0.4s, v3.s[3]
|
fmla v28.4s, v2.4s, v1.s[2]
|
||||||
fmla v17.4s, v1.4s, v2.s[0]
|
fmla v29.4s, v3.4s, v1.s[2]
|
||||||
fmla v19.4s, v1.4s, v2.s[1]
|
fmla v30.4s, v2.4s, v1.s[3]
|
||||||
fmla v21.4s, v1.4s, v2.s[2]
|
fmla v31.4s, v3.4s, v1.s[3]
|
||||||
fmla v23.4s, v1.4s, v2.s[3]
|
|
||||||
subs w13, w13, #1
|
subs w13, w13, #1
|
||||||
b CommLoopMul
|
b CommLoopMul
|
||||||
|
|
||||||
Bias:
|
Bias:
|
||||||
|
cmp x3, #0
|
||||||
|
beq Relu
|
||||||
ld1 {v0.4s}, [x14], #16
|
ld1 {v0.4s}, [x14], #16
|
||||||
ld1 {v1.4s}, [x14], #16
|
ld1 {v1.4s}, [x14], #16
|
||||||
dup v2.4s, v0.s[0]
|
fadd v16.4s, v16.4s, v0.4s
|
||||||
fadd v16.4s, v16.4s, v2.4s
|
fadd v17.4s, v17.4s, v1.4s
|
||||||
fadd v17.4s, v17.4s, v2.4s
|
fadd v18.4s, v18.4s, v0.4s
|
||||||
dup v3.4s, v0.s[1]
|
fadd v19.4s, v19.4s, v1.4s
|
||||||
fadd v18.4s, v18.4s, v3.4s
|
fadd v20.4s, v20.4s, v0.4s
|
||||||
fadd v19.4s, v19.4s, v3.4s
|
fadd v21.4s, v21.4s, v1.4s
|
||||||
dup v4.4s, v0.s[2]
|
fadd v22.4s, v22.4s, v0.4s
|
||||||
fadd v20.4s, v20.4s, v4.4s
|
fadd v23.4s, v23.4s, v1.4s
|
||||||
fadd v21.4s, v21.4s, v4.4s
|
fadd v24.4s, v24.4s, v0.4s
|
||||||
dup v5.4s, v0.s[3]
|
fadd v25.4s, v25.4s, v1.4s
|
||||||
fadd v22.4s, v22.4s, v5.4s
|
fadd v26.4s, v26.4s, v0.4s
|
||||||
fadd v23.4s, v23.4s, v5.4s
|
fadd v27.4s, v27.4s, v1.4s
|
||||||
dup v2.4s, v1.s[0]
|
fadd v28.4s, v28.4s, v0.4s
|
||||||
fadd v24.4s, v24.4s, v2.4s
|
fadd v29.4s, v29.4s, v1.4s
|
||||||
fadd v25.4s, v25.4s, v2.4s
|
fadd v30.4s, v30.4s, v0.4s
|
||||||
dup v3.4s, v1.s[1]
|
fadd v31.4s, v31.4s, v1.4s
|
||||||
fadd v26.4s, v26.4s, v3.4s
|
|
||||||
fadd v27.4s, v27.4s, v3.4s
|
|
||||||
dup v4.4s, v1.s[2]
|
|
||||||
fadd v28.4s, v28.4s, v4.4s
|
|
||||||
fadd v29.4s, v29.4s, v4.4s
|
|
||||||
dup v5.4s, v1.s[3]
|
|
||||||
fadd v30.4s, v30.4s, v5.4s
|
|
||||||
fadd v31.4s, v31.4s, v5.4s
|
|
||||||
|
|
||||||
Relu:
|
Relu:
|
||||||
dup v15.4s, w7
|
dup v15.4s, w7
|
||||||
|
@ -281,30 +263,28 @@ Relu:
|
||||||
fmin v31.4s, v31.4s, v15.4s
|
fmin v31.4s, v31.4s, v15.4s
|
||||||
|
|
||||||
TransToOut:
|
TransToOut:
|
||||||
st1 {v16.4s}, [x17], #16
|
st1 {v16.4s}, [x2], #16
|
||||||
st1 {v17.4s}, [x17], #16
|
st1 {v17.4s}, [x2], #16
|
||||||
st1 {v18.4s}, [x17], #16
|
st1 {v18.4s}, [x2], #16
|
||||||
st1 {v19.4s}, [x17], #16
|
st1 {v19.4s}, [x2], #16
|
||||||
st1 {v20.4s}, [x17], #16
|
st1 {v20.4s}, [x2], #16
|
||||||
st1 {v21.4s}, [x17], #16
|
st1 {v21.4s}, [x2], #16
|
||||||
st1 {v22.4s}, [x17], #16
|
st1 {v22.4s}, [x2], #16
|
||||||
st1 {v23.4s}, [x17], #16
|
st1 {v23.4s}, [x2], #16
|
||||||
st1 {v24.4s}, [x17], #16
|
st1 {v24.4s}, [x2], #16
|
||||||
st1 {v25.4s}, [x17], #16
|
st1 {v25.4s}, [x2], #16
|
||||||
st1 {v26.4s}, [x17], #16
|
st1 {v26.4s}, [x2], #16
|
||||||
st1 {v27.4s}, [x17], #16
|
st1 {v27.4s}, [x2], #16
|
||||||
st1 {v28.4s}, [x17], #16
|
st1 {v28.4s}, [x2], #16
|
||||||
st1 {v29.4s}, [x17], #16
|
st1 {v29.4s}, [x2], #16
|
||||||
st1 {v30.4s}, [x17], #16
|
st1 {v30.4s}, [x2], #16
|
||||||
st1 {v31.4s}, [x17], #16
|
st1 {v31.4s}, [x2], #16
|
||||||
|
|
||||||
add w10, w10, #8 // col+=8
|
add w10, w10, #8 // lhs row offset + 8
|
||||||
b L2
|
b L2
|
||||||
|
|
||||||
End2:
|
End2:
|
||||||
add x0, x0, x15 // stride a ptr
|
add w9, w9, #8 // rhs col offset + 8
|
||||||
add x2, x2, x16 // stride c ptr
|
|
||||||
add w9, w9, #8 // row+=8
|
|
||||||
b L1
|
b L1
|
||||||
|
|
||||||
End1:
|
End1:
|
||||||
|
|
|
@ -74,5 +74,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
|
||||||
|
|
||||||
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_,
|
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_,
|
||||||
int col_8_) {
|
int col_8_) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
MatMulFloatNeon64(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
|
||||||
|
#else
|
||||||
MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
|
MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,19 +21,22 @@
|
||||||
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
||||||
#include "src/runtime/kernel/arm/opclib/matmul.h"
|
#include "src/runtime/kernel/arm/opclib/matmul.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row,
|
void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row,
|
||||||
int col);
|
int col);
|
||||||
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||||
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);
|
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);
|
||||||
|
void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
|
||||||
|
int row_8_, int col_8_);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
#ifdef __aarch64__
|
||||||
|
void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
|
||||||
|
int row, int col);
|
||||||
|
#endif
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_
|
||||||
|
|
||||||
|
|
|
@ -48,54 +48,3 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: need to delete, replace by above functions. z00445833
|
|
||||||
void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
|
||||||
int col8 = UP_ROUND(col, 8);
|
|
||||||
for (int r = 0; r < row; r++) {
|
|
||||||
int rd8 = r / 8;
|
|
||||||
int rm8 = r % 8;
|
|
||||||
for (int c = 0; c < col; c++) {
|
|
||||||
dst_ptr[r * col + c] = src_ptr[rd8 * col8 * 8 + c * 8 + rm8];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
|
|
||||||
int depth, FcQuantArg *params) {
|
|
||||||
int lhs_offset = params->input.zp_;
|
|
||||||
int rhs_offset = params->weight.zp_;
|
|
||||||
int output_offset = params->output.zp_;
|
|
||||||
int output_multiplier = params->quant_multiplier;
|
|
||||||
int output_shift = params->output_shift;
|
|
||||||
|
|
||||||
for (int row = 0; row < 8; ++row) {
|
|
||||||
for (int col = 0; col < 8; ++col) {
|
|
||||||
int c_index = col * 8 + row;
|
|
||||||
int acc = 0;
|
|
||||||
for (int d = 0; d < depth; ++d) {
|
|
||||||
int a_index = d * 8 + row;
|
|
||||||
int b_index = d * 8 + col;
|
|
||||||
acc += (lhs_data[a_index] - lhs_offset) * (rhs_data[b_index] - rhs_offset);
|
|
||||||
}
|
|
||||||
acc += bias_data[col];
|
|
||||||
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift, output_shift) + output_offset;
|
|
||||||
acc = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, acc));
|
|
||||||
output_data[c_index] = (int8_t)acc;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
|
|
||||||
int row_8, int col_8, int depth, FcQuantArg *params) {
|
|
||||||
for (int r = 0; r < row_8; r += 8) {
|
|
||||||
int8_t *output = output_data + r * col_8;
|
|
||||||
const int8_t *input = input_data + r * depth;
|
|
||||||
for (int c = 0; c < col_8; c += 8) {
|
|
||||||
const int8_t *bias = bias_data + c;
|
|
||||||
const int8_t *weights = weights_data + c * depth;
|
|
||||||
int8_t *dst = output + c * 8;
|
|
||||||
Gemm8x8Int8(input, weights, bias, dst, depth, params);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -20,23 +20,9 @@
|
||||||
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
#include "src/runtime/kernel/arm/opclib/op_base.h"
|
||||||
#include "src/runtime/kernel/arm/opclib/matmul.h"
|
#include "src/runtime/kernel/arm/opclib/matmul.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
|
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
|
||||||
const int32_t a_zp, const int32_t b_zp);
|
const int32_t a_zp, const int32_t b_zp);
|
||||||
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||||
|
|
||||||
void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_INT8_MATMUL_H_
|
||||||
void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
|
|
||||||
int depth, FcQuantArg *params);
|
|
||||||
void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
|
|
||||||
int row_8, int col_8, int depth, FcQuantArg *params);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_
|
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,8 @@ struct FcQuantArg {
|
||||||
QuantArg output;
|
QuantArg output;
|
||||||
int32_t out_act_min;
|
int32_t out_act_min;
|
||||||
int32_t out_act_max;
|
int32_t out_act_max;
|
||||||
int32_t output_shift;
|
int32_t left_shift;
|
||||||
|
int32_t right_shift;
|
||||||
int32_t quant_multiplier;
|
int32_t quant_multiplier;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "common/common_test.h"
|
||||||
|
#include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h"
|
||||||
|
#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h"
|
||||||
|
#include "mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h"
|
||||||
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
|
#include "mindspore/lite/src/lite_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
using lite::tensor::Tensor;
|
||||||
|
class TestFcInt8 : public mindspore::Common {
|
||||||
|
public:
|
||||||
|
TestFcInt8(){}
|
||||||
|
};
|
||||||
|
|
||||||
|
void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
int8_t q = static_cast<int8_t>(std::max<float>(
|
||||||
|
std::numeric_limits<int8_t>::min(),
|
||||||
|
std::min<float>(std::numeric_limits<int8_t>::max(), std::round(zero_point + (input_data[i] / scale)))));
|
||||||
|
output_data[i] = q;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Dequantize(int8_t *input_data, int length, float scale, int zero_point, float *output_data) {
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
output_data[i] = scale * (input_data[i] - zero_point);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
|
||||||
|
MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
|
||||||
|
float input_max = 20;
|
||||||
|
float input_min = -20;
|
||||||
|
float weight_max = 1;
|
||||||
|
float weight_min = -1;
|
||||||
|
float output_max = 20;
|
||||||
|
float output_min = -20;
|
||||||
|
|
||||||
|
double input_scale =
|
||||||
|
(input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
|
||||||
|
int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
|
||||||
|
double weight_scale =
|
||||||
|
(weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
|
||||||
|
int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
|
||||||
|
double output_scale =
|
||||||
|
(output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
|
||||||
|
int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
|
||||||
|
*scale = output_scale;
|
||||||
|
*zeropoint = output_zp;
|
||||||
|
|
||||||
|
Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
|
||||||
|
in_t->MallocData();
|
||||||
|
float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
|
||||||
|
17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
|
||||||
|
Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
|
||||||
|
auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
|
||||||
|
in_quant_arg->zeroPoint = input_zp;
|
||||||
|
in_quant_arg->scale = input_scale;
|
||||||
|
in_t->AddQuantParam(*in_quant_arg);
|
||||||
|
inputs_->push_back(in_t);
|
||||||
|
|
||||||
|
Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
|
||||||
|
weight_t->MallocData();
|
||||||
|
float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435,
|
||||||
|
0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552,
|
||||||
|
-0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459,
|
||||||
|
0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343};
|
||||||
|
Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
|
||||||
|
auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
|
||||||
|
weight_quant_arg->zeroPoint = weight_zp;
|
||||||
|
weight_quant_arg->scale = weight_scale;
|
||||||
|
weight_t->AddQuantParam(*weight_quant_arg);
|
||||||
|
inputs_->push_back(weight_t);
|
||||||
|
|
||||||
|
Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
|
||||||
|
bias_t->MallocData();
|
||||||
|
memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
|
||||||
|
inputs_->push_back(bias_t);
|
||||||
|
|
||||||
|
Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
|
||||||
|
out_t->MallocData();
|
||||||
|
auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
|
||||||
|
output_quant_arg->zeroPoint = output_zp;
|
||||||
|
output_quant_arg->scale = output_scale;
|
||||||
|
out_t->AddQuantParam(*output_quant_arg);
|
||||||
|
outputs_->push_back(out_t);
|
||||||
|
|
||||||
|
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
|
||||||
|
float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
|
||||||
|
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
|
||||||
|
|
||||||
|
matmal_param->b_transpose_ = true;
|
||||||
|
matmal_param->a_transpose_ = false;
|
||||||
|
matmal_param->has_bias_ = true;
|
||||||
|
matmal_param->minf_ = -FLT_MAX;
|
||||||
|
matmal_param->maxf_ = FLT_MAX;
|
||||||
|
return out_t->ElementsNum();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestFcInt8, fcint8) {
|
||||||
|
std::vector<lite::tensor::Tensor *> inputs_;
|
||||||
|
std::vector<lite::tensor::Tensor *> outputs_;
|
||||||
|
auto matmul_param = new MatMulParameter();
|
||||||
|
float *correct;
|
||||||
|
double output_scale;
|
||||||
|
int output_zp;
|
||||||
|
int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
|
||||||
|
lite::Context *ctx = new lite::Context;
|
||||||
|
ctx->threadNum = 2;
|
||||||
|
kernel::FullconnectionInt8CPUKernel *fc =
|
||||||
|
new kernel::FullconnectionInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
|
||||||
|
|
||||||
|
fc->Init();
|
||||||
|
fc->Run();
|
||||||
|
float fout[6] = {0};
|
||||||
|
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
|
||||||
|
fout);
|
||||||
|
CompareOutputData(fout, correct, 6, 0.2);
|
||||||
|
delete matmul_param;
|
||||||
|
delete fc;
|
||||||
|
for (auto t : inputs_) delete t;
|
||||||
|
for (auto t : outputs_) delete t;
|
||||||
|
free(correct);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
Loading…
Reference in New Issue