!4997 Optimize the post-processing of the arm64 int8 matmul & fix some bugs in matmul_int8

Merge pull request !4997 from zhanyuan/dev
mindspore-ci-bot 2020-08-25 16:03:24 +08:00 committed by Gitee
commit 5971e31309
11 changed files with 667 additions and 438 deletions

View File

@@ -24,7 +24,7 @@
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift);
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
@@ -40,13 +40,18 @@
// w11: multiplier
// w12: left_shift
// w13: right_shift
// w14: row
// w15: col
// w24: stride
MatmulInt8Neon64:
sub sp, sp, #160
sub sp, sp, #192
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
@@ -54,25 +59,28 @@ MatmulInt8Neon64:
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
mov w17, #1
mov x25, x2
L1:
cmp w15, w4
cmp w4, #0 // if at the end of col4
beq End1
mov w16, #0 // reset a row index
mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, w3
cmp w16, #0
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
@@ -256,21 +264,128 @@ End3:
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h
st1 {v15.16b}, [x2], #16
add w16, w16, #4 // a row index + 4
cmp w23, #4
blt Write // if rows < 4
cmp w15, #4
blt Write // if cols < 4
st1 {v15.s}[0], [x2], x24
st1 {v15.s}[1], [x2], x24
st1 {v15.s}[2], [x2], x24
st1 {v15.s}[3], [x2], x24
b Endwrite
Write:
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.s}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.s}[2], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.s}[3], [x2], x24
b Endwrite
WriteCol3:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
st1 {v15.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
st1 {v15.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
st1 {v15.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
st1 {v15.b}[14], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol2:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol1:
st1 {v15.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.b}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.b}[8], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.b}[12], [x2], x24
b Endwrite
Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2
End2:
add w15, w15, #4 // b col index + 4
sub w4, w4, #4 // b col4 counter - 4
sub w15, w15, #4 // b col counter - 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25
b L1
End1:
sub sp, sp, #160
sub sp, sp, #192
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ret
#endif
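Note: the Write/WriteCol* paths added above replace the old unconditional `st1 {v15.16b}, [x2], #16` store; together with the new row/col/stride arguments they let the kernel store only the valid part of each 4x4 result tile directly into the row-major output. A rough C equivalent of that tail handling (illustrative sketch only; `WriteTile4x4` and its parameter names are hypothetical — the tile is what v15 holds, rows_left/cols_left mirror w23/w15, and stride mirrors x24, the output row stride in bytes):

#include <stdint.h>

static void WriteTile4x4(const int8_t tile[4][4], int8_t *dst, int rows_left, int cols_left,
                         int stride) {
  int rows = rows_left < 4 ? rows_left : 4;
  int cols = cols_left < 4 ? cols_left : 4;
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      dst[r * stride + c] = tile[r][c]; /* only the valid rows/columns of the tile are written */
    }
  }
}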

View File

@@ -228,19 +228,3 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh
return;
}
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp) {
/* (int32_t)row8x8-major * multiplier => (int8_t)row-major */
for (int r = 0; r < plane; r++) {
for (int c = 0; c < oc; c++) {
int c8div = c / 8, c8mod = c % 8;
int src_index = c8div * plane8 * 8 + r * 8 + c8mod;
int dst_index = r * oc + c;
int32_t value = in[src_index];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp;
value = MSMIN(CHAR_MAX, value);
value = MSMAX(CHAR_MIN, value);
out[dst_index] = (int8_t)value;
}
}
}

View File

@@ -117,25 +117,6 @@ void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
}
}
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp) {
/* col8-major * row8-major => row8x8-major */
for (int row = 0; row < row8; row++) {
for (int col = 0; col < col8; col++) {
int r8div = row / 8, r8mod = row % 8;
int c8div = col / 8, c8mod = col % 8;
size_t ci = c8div * row8 * 8 + row * 8 + c8mod;
int32_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp);
}
c[ci] = value;
}
}
}
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias) {
/* row4x16-major * row16x4-major => row4x4-major */
@@ -191,6 +172,36 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
return;
}
/* row4x16-major * col16x4-major => row4x4-major */
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride) {
int8_t *output = dst;
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM;
int r4mod = r % C4NUM;
int c4div = c / C4NUM;
int c4mod = c % C4NUM;
int value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM;
int d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value += a[ai] * b[bi];
}
value -= a_sums[r];
value += bias[c];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp;
value = MSMIN(INT8_MAX, value);
value = MSMAX(INT8_MIN, value);
output[c] = (int8_t)value;
}
output += stride;
}
}
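Note: the `a_sums`/`bias` terms folded in above are just the expansion of the zero-point-corrected product. In LaTeX, with $z_a$ the input zero point, $z_b$ the weight zero point and $D$ the depth:

$$\sum_{d=0}^{D-1} (a_{rd} - z_a)(b_{dc} - z_b) = \sum_d a_{rd} b_{dc} - z_b \sum_d a_{rd} - z_a \sum_d b_{dc} + D\, z_a z_b$$

so with $a\_sums[r] = z_b \sum_d a_{rd}$ (CalcInputSums below) and $bias[c] = bias_0[c] + D\, z_a z_b - z_a \sum_d b_{dc}$ (CalcWeightBiasSums below), the inner loop only needs the raw int8 dot product $\sum_d a_{rd} b_{dc}$.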
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row; ++r) {
@@ -213,23 +224,28 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1
}
}
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
// dst: weight_zp * input_row_sums
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
for (int r = 0; r < row; ++r) {
int sum = 0;
for (int c = 0; c < col; ++c) {
int src_idx = r * col + c;
dst[r] += a[src_idx];
sum += input[src_idx];
}
dst[r] *= b_zp;
sum *= weight_zp;
dst[r] = sum;
}
}
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
for (int c = 0; c < col; ++c) {
int sum = 0;
for (int r = 0; r < row; ++r) {
int src_idx = r * col + c;
dst[c] += b[src_idx];
sum += weight[src_idx];
}
dst[c] = row * a_zp * b_zp - a_zp * dst[c];
dst[c] = row * input_zp * weight_zp - input_zp * sum;
if (bias) {
dst[c] += bias[c];
}
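Note: a minimal sketch of how these helpers are meant to be chained (this mirrors the fullconnection/matmul kernel code later in this PR; the wrapper name and the packed_input/packed_weight buffers are placeholders, with d16 = UP_ROUND(deep, 16)):

#include <stdint.h>
#include "nnacl/int8/matmul_int8.h"

static void Int8MatmulSketch(const int8_t *packed_input, const int8_t *packed_weight, int8_t *input, int8_t *weight,
                             int *bias, int8_t *output, int row, int col, int deep, int d16, int input_zp,
                             int weight_zp, int out_zp, int multiplier, int left_shift, int right_shift,
                             int *input_sums, int *weight_bias_sums) {
  CalcInputSums(input, row, deep, weight_zp, input_sums); /* weight_zp * per-row sums of the input */
  CalcWeightBiasSums(weight, deep, col, input_zp, weight_zp, bias, weight_bias_sums);
  MatmulInt8(packed_input, packed_weight, output, input_sums, weight_bias_sums, INT8_MIN, INT8_MAX, out_zp,
             multiplier, left_shift, right_shift, row, col, d16, col /* output row stride */);
}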

View File

@@ -24,8 +24,6 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
const int a_zp, const int b_zp);
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
@@ -39,15 +37,16 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride);
#ifdef ENABLE_ARM64
// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
int right_shift);
int right_shift, int row, int col, int stride);
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);

View File

@@ -39,36 +39,32 @@ int FullconnectionInt8CPUKernel::ReSize() {
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
a_c8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
b_r8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
if (!b_r8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
r4_ = UP_ROUND(fc_param_->row_, 4);
c4_ = UP_ROUND(fc_param_->col_, 4);
d16_ = UP_ROUND(fc_param_->deep_, 16);
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, r4_ * sizeof(int));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!weight_bias_sums_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->Data());
RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
if (!c_r8x8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
if (!bias_ptr_) {
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, bias_len);
RowMajor2Row4x16Major(weight_data, fc_param_->col_, fc_param_->deep_, b_c16x4_ptr_, d16_);
if (in_tensors_.size() == 3) {
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
if (!bias_ptr_) return RET_MEMORY_FAILED;
memcpy(bias_ptr_, in_tensors_[2]->Data(), bias_len);
} else {
bias_ptr_ = NULL;
}
auto input_tensor = in_tensors_[0];
@@ -93,18 +89,32 @@ int FullconnectionInt8CPUKernel::ReSize() {
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
&quant_params_.out_act_max);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, weight_bias_sums_);
return RET_OK;
}
int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto &p = quant_params_;
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, fc_param_->col_ - task_id * thread_stride_ * C4NUM);
auto &q = quant_params_;
auto &p = fc_param_;
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM;
auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res,
p->col_ * sizeof(int8_t));
#else
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_,
q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_);
#endif
return RET_OK;
}
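Note (worked example, not part of the diff): with the fctest1 shapes below, col_ = 6 and two worker threads give c4_ = 8, UP_DIV(c4_, 4) = 2 and thread_stride_ = 1; task 0 then handles cur_oc = 1 four-column block with cur_oc_res = 4 valid columns, while task 1 handles cur_oc = 1 block with cur_oc_res = MSMIN(4, 6 - 4) = 2 valid columns, which the new row/col/stride arguments let MatmulInt8Neon64 write straight into the row-major output.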
@@ -124,13 +134,10 @@ int FullconnectionInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
auto &p = quant_params_;
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
PostFuncInt8C8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, p.quant_multiplier, p.left_shift,
p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
return RET_OK;
}

View File

@@ -41,28 +41,36 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
private:
void FreeTmpBuffer() {
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_r4x16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
}
MatmulQuantArg quant_params_;
int8_t *a_c8_ptr_ = nullptr;
int8_t *b_r8_ptr_ = nullptr;
int *c_r8x8_ptr_ = nullptr;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
int r4_;
int c4_;
int d16_;
};
} // namespace mindspore::kernel

View File

@@ -48,46 +48,23 @@ int MatmulInt8CPUKernel::ReSize() {
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->col_8_ = UP_ROUND(params_->col_, 8);
#ifdef ENABLE_ARM64
r4_ = UP_ROUND(params_->row_, 4);
c4_ = UP_ROUND(params_->col_, 4);
d16_ = UP_ROUND(params_->deep_, 16);
a_r4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4d16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c4d16_ptr_) return RET_MEMORY_FAILED;
memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
c_r4c4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t)));
if (!c_r4c4_ptr_) return RET_MEMORY_FAILED;
memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t));
a_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!a_sums_) return RET_MEMORY_FAILED;
memset(a_sums_, 0, r4_ * sizeof(int));
b_bias_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!b_bias_) return RET_MEMORY_FAILED;
memset(b_bias_, 0, c4_ * sizeof(int));
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, r4_ * sizeof(int));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!weight_bias_sums_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
#else
a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(int8_t));
b_r8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(int8_t)));
if (!b_r8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(int8_t));
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(int)));
if (!c_r8x8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
#endif
auto input_tensor = in_tensors_[0];
auto params = input_tensor->GetQuantParams();
@@ -112,27 +89,25 @@ int MatmulInt8CPUKernel::ReSize() {
}
int MatmulInt8CPUKernel::RunImpl(int task_id) {
#ifdef ENABLE_ARM64
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_;
auto &p = quant_params_;
MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, c4_, d16_, a_sums_, b_bias_, INT_MIN, INT_MAX, p.output.zp_,
p.quant_multiplier, p.left_shift, p.right_shift);
#else
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_,
quant_params_.weight.zp_);
auto &p = quant_params_;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res,
params_->col_ * sizeof(int8_t));
#else
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier,
p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_);
#endif
return RET_OK;
}
@@ -162,43 +137,27 @@ int MatmulInt8CPUKernel::Run() {
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
auto cur_c_ptr = c_ptr + i * c_stride;
#ifdef ENABLE_ARM64
if (params_->a_transpose_) {
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_);
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
} else {
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_);
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
}
if (params_->b_transpose_) {
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_);
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
} else {
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_);
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
}
c_ptr_ = c_ptr + i * c_stride;
auto &q = quant_params_;
RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_);
RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, b_bias_);
LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_);
#else
if (params_->a_transpose_) {
RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
} else {
RowMajor2Col8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
ret = LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
return ret;
}
if (params_->b_transpose_) {
RowMajor2Col8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_);
} else {
RowMajor2Row8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_);
}
LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
auto &q = quant_params_;
SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier,
q.left_shift, q.right_shift, q.output.zp_);
#endif
}
return RET_OK;
}
} // namespace mindspore::kernel

View File

@@ -39,57 +39,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
private:
void FreeTmpBuffer() {
#ifdef ENABLE_ARM64
if (a_r4d16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4d16_ptr_);
a_r4d16_ptr_ = nullptr;
if (a_r4x16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_c4d16_ptr_ != nullptr) {
ctx_->allocator->Free(b_c4d16_ptr_);
b_c4d16_ptr_ = nullptr;
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (c_r4c4_ptr_ != nullptr) {
ctx_->allocator->Free(c_r4c4_ptr_);
c_r4c4_ptr_ = nullptr;
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (a_sums_ != nullptr) {
ctx_->allocator->Free(a_sums_);
a_sums_ = nullptr;
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (b_bias_ != nullptr) {
ctx_->allocator->Free(b_bias_);
b_bias_ = nullptr;
}
#else
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
#endif
}
MatmulQuantArg quant_params_;
#ifdef ENABLE_ARM64
int8_t *a_r4d16_ptr_ = nullptr;
int8_t *b_c4d16_ptr_ = nullptr;
int8_t *c_r4c4_ptr_ = nullptr;
int *a_sums_ = nullptr;
int *b_bias_ = nullptr;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int r4_;
int c4_;
int d16_;
#else
int8_t *a_c8_ptr_ = nullptr;
int8_t *b_r8_ptr_ = nullptr;
int *c_r8x8_ptr_ = nullptr;
#endif
}; // namespace mindspore::kernel
} // namespace mindspore::kernel

View File

@@ -134,54 +134,6 @@ TEST_F(TestDeconvInt8, PackInputTest1) {
CompareOutputData(dst, co, 8 * 32, 1);
}
TEST_F(TestDeconvInt8, MatMulTest1) {
int8_t a_row_major_10_12[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105,
68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, 57, -41, -51, 77,
1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, 57, 61, 55, -20, 3, 2,
17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, -7, -16, -47, 112, 114, -26, -98, 53,
15, -49, 26, 19, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10};
int32_t zp_a = 15;
int8_t a_col8_major[16 * 12] = {0};
int8_t b_col_major_12_18[] = {
92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99,
77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103,
-3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80,
-75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33,
-77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46,
-53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42,
43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46,
35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7,
-8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81,
40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51,
40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49};
int32_t zp_b = -20;
int8_t b_row8_major[12 * 24] = {0};
int32_t co_row_major_10_18[] = {
32005, 3597, 16595, -3458, 6627, -6663, 818, -3910, 10228, 15079, -19205, -10203, -3178, -10046,
10374, -6199, 5330, 12163, 1819, 20533, 17382, 18283, 9778, 9185, -12623, -26234, -11987, 7904,
8144, -1603, 27611, -10190, -20053, 4999, -28389, 21852, 24680, 25858, 23506, 17944, 11768, 24378,
-6102, -4675, -23460, 10434, -47579, 1986, 12018, -19418, -7248, 4938, -32613, -941, 8171, -4788,
3325, -11310, -8351, -14786, 6909, 16401, 2017, -6456, 11242, 7393, -9119, 17312, 2646, -14402,
7201, -9949, 23986, 17607, 27461, -1547, 2783, 7558, 19487, 11158, -2686, 6328, -8225, -11668,
21858, -2079, -8671, -639, -1544, 1235, 1156, 6582, 2829, -10311, -2692, 5154, 1527, 10870,
106, -8189, -24174, -1846, -15399, -3598, 14874, -5591, -619, -13667, -6053, -31103, -24499, 13008,
9143, -17982, 28437, 2176, -2114, -11631, 10779, -1032, -24690, -3112, 2125, 432, 20270, -33859,
8907, 10063, 1603, 3761, 4805, 4904, -15594, 10786, 4287, -13591, -18777, -1679, 2109, -2243,
12051, -8504, -6558, 4209, 13606, -25803, 27922, 12092, 7140, 27142, -12267, 2339, -26224, 23674,
-26579, -11398, -1823, -18976, 3641, 4415, -24878, -2045, 15937, 41465, 12601, -14513, -17619, -5728,
334, -424, 8147, -1369, 5984, 11000, 19016, 4456, -25920, 4506, 5930, 15458};
int32_t c_row8x8_major[16 * 24] = {0};
int32_t out_row_major[180] = {0};
RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12);
RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12);
MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b);
Row8x8Major2RowMajor(reinterpret_cast<float *>(c_row8x8_major), reinterpret_cast<float *>(out_row_major), 10, 18, 18);
CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
}
TEST_F(TestDeconvInt8, InputSumTest1) {
int8_t packed_a[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111,

View File

@@ -29,99 +29,128 @@ class TestFcInt8 : public mindspore::CommonTest {
TestFcInt8() {}
};
int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
float input_max = 20;
float input_min = -20;
float weight_max = 1;
float weight_min = -1;
float output_max = 20;
float output_min = -20;
struct TensorInfo {
float *data;
int *data_int;
float min;
float max;
int len;
std::vector<int> *shape;
};
double input_scale =
(input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
double weight_scale =
(weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
double output_scale =
(output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
*scale = output_scale;
*zeropoint = output_zp;
extern void QuantProcess(float *input, int len, float min, float max, float *scale, int *zero_point, int8_t *output);
extern lite::tensor::Tensor *MakeQuantTensor(int8_t *data, int len, std::vector<int> *shape, float scale, int zp);
Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
in_quant_arg->zeroPoint = input_zp;
in_quant_arg->scale = input_scale;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435,
0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552,
-0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459,
0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343};
Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
weight_quant_arg->zeroPoint = weight_zp;
weight_quant_arg->scale = weight_scale;
weight_t->AddQuantParam(*weight_quant_arg);
inputs_->push_back(weight_t);
Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
inputs_->push_back(bias_t);
Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
output_quant_arg->zeroPoint = output_zp;
output_quant_arg->scale = output_scale;
out_t->AddQuantParam(*output_quant_arg);
outputs_->push_back(out_t);
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->act_type_ = ActType_No;
return out_t->ElementsNum();
lite::tensor::Tensor *MakeIntTensor(int *data, int len, std::vector<int> *shape) {
auto tensor =
new lite::tensor::Tensor(kNumberTypeInt32, *shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
tensor->MallocData();
auto tensor_ptr = reinterpret_cast<int *>(tensor->Data());
memcpy(tensor_ptr, data, len * sizeof(int));
return tensor;
}
TEST_F(TestFcInt8, fcint8) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
double output_scale;
int output_zp;
int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
lite::Context *ctx = new lite::Context;
void FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
TensorInfo *in, TensorInfo *weight, TensorInfo *bias, TensorInfo *out) {
float in_scale, weight_scale, out_scale;
int in_zp, weight_zp, out_zp;
int8_t *in_data = new int8_t[in->len];
int8_t *weight_data = new int8_t[weight->len];
QuantProcess(in->data, in->len, in->min, in->max, &in_scale, &in_zp, in_data);
auto in_tensor = MakeQuantTensor(in_data, in->len, in->shape, in_scale, in_zp);
inputs->push_back(in_tensor);
QuantProcess(weight->data, weight->len, weight->min, weight->max, &weight_scale, &weight_zp, weight_data);
auto weight_tensor = MakeQuantTensor(weight_data, weight->len, weight->shape, weight_scale, weight_zp);
inputs->push_back(weight_tensor);
auto bias_tensor = MakeIntTensor(bias->data_int, bias->len, bias->shape);
inputs->push_back(bias_tensor);
QuantProcess(out->data, out->len, out->min, out->max, &out_scale, &out_zp, nullptr);
auto out_tensor = MakeQuantTensor(nullptr, out->len, out->shape, out_scale, out_zp);
outputs->push_back(out_tensor);
delete[] in_data;
delete[] weight_data;
}
TEST_F(TestFcInt8, fctest1) {
float in[] = {4.259103407444801, 5.992151035772917, -9.495343223733581, 3.0509999931426215, -16.635707833991095,
-14.72005749234452, 2.8290916795754093, -15.827977973039049, -16.98208477063347, 2.8801101778935347,
-0.5905297521382735, 18.042746010536085, 3.913511213700396, 11.571264917136105, 19.084257392926148,
8.571560238377568, 17.58868010598305, 12.433311533838427, 4.548078598583526, 15.609650071521138,
6.663372887795717, 17.581323475674594, 1.453277207446778, -6.119351424589654, -16.87310296820285,
11.906066592064796, -13.290100998834653, 19.627129875430548, 16.034262583959162, 10.255738135902781,
12.134650347811792, -5.5882066903433305, 15.554050723026322, 15.288481461776783, 17.651080309797287,
-9.258779162183215, 4.218532791445092, -6.205309122668545, 1.2220458021156908, 1.6800736573947326};
TensorInfo in_params;
in_params.data = in;
in_params.len = 40;
std::vector<int> in_shape{5, 2, 2, 2};
in_params.shape = &in_shape;
in_params.min = -20;
in_params.max = 20;
float weight[] = {
-0.586269014312498, 0.10845796767603733, 0.8455159907124523, 0.20261291069007226, 0.7564258582027543,
0.4505005038790615, -0.607259232240795, -0.6962171798923924, 0.7967573009922135, -0.46069496925353715,
-0.2967638879316592, -0.7025557337565955, -0.5313515272071268, 0.07584168670764102, -0.6860034691410029,
0.9218806800279316, -0.07408538201953907, -0.7933652717840096, 0.6636691558029275, -0.30198695606477477,
0.790225747868754, -0.9478140254555916, 0.4537316306461665, 0.1776848732022871, -0.7492316745474277,
-0.5825825240770948, 0.5680842804542614, -0.9255552309192772, 0.20866577718844725, 0.9570928647172854,
0.18172570688854406, -0.26442830241827253, -0.24765169216720873, -0.19512285277145702, 0.1120696020054861,
0.7558578199370625, -0.15032457481135109, -0.08485585411928809, 0.6343014796699504, 0.026380085222785787,
-0.40516674259120444, -0.7407588590646037, -0.28521396461492454, 0.2555841827858194, 0.023640857478332444,
-0.6540694390119834, 0.7439705499824205, -0.7579774562590929};
TensorInfo weight_params;
weight_params.data = weight;
weight_params.len = 48;
std::vector<int> weight_shape{6, 8};
weight_params.shape = &weight_shape;
weight_params.min = -1;
weight_params.max = 1;
int bias[6] = {0};
TensorInfo bias_params;
bias_params.data_int = bias;
bias_params.len = 6;
std::vector<int> bias_shape{6};
bias_params.shape = &bias_shape;
float correct[] = {-19.170732, -7.5019627, -13.015462, -27.760283, 4.1447954, 20.660276, 4.0412164, -33.750015,
-4.560128, 7.1035166, 27.976341, 9.75216, 14.383608, -12.87587, -24.688887, -12.185722,
3.7933283, -19.266382, 17.193876, -49.99205, -15.480089, -3.1659412, 19.470417, 13.758459,
4.0713396, 4.614437, 11.296907, -7.244551, -11.143417, -21.233654};
TensorInfo out_params;
out_params.data = correct;
out_params.len = 30;
std::vector<int> out_shape{5, 6};
out_params.shape = &out_shape;
out_params.min = -50;
out_params.max = 50;
auto fc_param = new MatMulParameter();
fc_param->a_transpose_ = false;
fc_param->b_transpose_ = true;
fc_param->has_bias_ = true;
fc_param->act_type_ = ActType_No;
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
FcInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &bias_params, &out_params);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
kernel::FullconnectionInt8CPUKernel *fc = new kernel::FullconnectionInt8CPUKernel(
reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx, nullptr);
kernel::FullconnectionInt8CPUKernel *fc =
new kernel::FullconnectionInt8CPUKernel(reinterpret_cast<OpParameter *>(fc_param), inputs, outputs, ctx, nullptr);
fc->Init();
fc->Run();
float fout[6] = {0};
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
fout);
CompareOutputData(fout, correct, 6, 0.2);
delete matmul_param;
float out_scale;
int out_zp;
QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr);
float *out = new float[out_params.len];
Dequantize(reinterpret_cast<int8_t *>(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out);
CompareOutputData(out, correct, 6, 0.3);
delete fc;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
for (auto t : inputs) delete t;
for (auto t : outputs) delete t;
delete[] out;
}
} // namespace mindspore

View File

@@ -18,8 +18,9 @@
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h"
#include "mindspore/lite/nnacl/quantization/quantize.h"
#include "mindspore/lite/nnacl/common_func.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/common_func.h"
#include "nnacl/int8/matmul_int8.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
@@ -29,99 +30,283 @@ class TestMatmulInt8 : public mindspore::CommonTest {
TestMatmulInt8() {}
};
int MMInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
float input_max = 20;
float input_min = -20;
float weight_max = 1;
float weight_min = -1;
float output_max = 30;
float output_min = -30;
struct TensorInfo {
float *data;
float min;
float max;
int len;
std::vector<int> *shape;
};
double input_scale =
(input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
double weight_scale =
(weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
double output_scale =
(output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
*scale = output_scale;
*zeropoint = output_zp;
void QuantProcess(float *input, int len, float min, float max, float *scale, int *zero_point, int8_t *output) {
*scale = (max - min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
*zero_point = std::numeric_limits<int8_t>::max() - max / (*scale);
if (output) {
Quantize(input, len, *scale, *zero_point, output);
}
}
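// Worked example (illustrative, not part of the diff): for min = -20 and max = 20 this gives
// scale = 40 / 255 ≈ 0.157 and zero_point = 127 - 20 / 0.157 ≈ 127 - 127.5, which truncates to 0.
// MMInt8TestInit below uses this helper for the input, weight and expected-output tensors.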
auto in_t =
new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
lite::tensor::Tensor *MakeQuantTensor(int8_t *data, int len, std::vector<int> *shape, float scale, int zp) {
auto tensor =
new lite::tensor::Tensor(kNumberTypeInt8, *shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
tensor->MallocData();
if (data) {
auto tensor_ptr = reinterpret_cast<int8_t *>(tensor->Data());
memcpy(tensor_ptr, data, len * sizeof(int8_t));
}
auto quant_arg = new mindspore::lite::tensor::QuantArg();
quant_arg->zeroPoint = zp;
quant_arg->scale = scale;
tensor->AddQuantParam(*quant_arg);
return tensor;
}
void MMInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
TensorInfo *in, TensorInfo *weight, TensorInfo *out) {
float in_scale, weight_scale, out_scale;
int in_zp, weight_zp, out_zp;
int8_t *in_data = new int8_t[in->len];
int8_t *weight_data = new int8_t[weight->len];
QuantProcess(in->data, in->len, in->min, in->max, &in_scale, &in_zp, in_data);
auto in_tensor = MakeQuantTensor(in_data, in->len, in->shape, in_scale, in_zp);
inputs->push_back(in_tensor);
QuantProcess(weight->data, weight->len, weight->min, weight->max, &weight_scale, &weight_zp, weight_data);
auto weight_tensor = MakeQuantTensor(weight_data, weight->len, weight->shape, weight_scale, weight_zp);
inputs->push_back(weight_tensor);
QuantProcess(out->data, out->len, out->min, out->max, &out_scale, &out_zp, nullptr);
auto out_tensor = MakeQuantTensor(nullptr, out->len, out->shape, out_scale, out_zp);
outputs->push_back(out_tensor);
delete[] in_data;
delete[] weight_data;
}
TEST_F(TestMatmulInt8, simple) {
#define ROW 10
#define COL 15
#define DEPTH 10
#define ROW4 UP_ROUND(ROW, 4)
#define COL4 UP_ROUND(COL, 4)
#define DEPTH16 UP_ROUND(DEPTH, 16)
int8_t a[ROW * DEPTH] = {-3, -3, 0, -2, -4, -2, 1, 0, -1, 0, 5, 1, 3, 4, 4, -3, -5, 2, -2, 4,
4, 5, 1, -1, 5, 5, 2, -1, 0, 4, -4, 2, 5, -2, 5, 3, -1, 2, -4, 5,
-5, 4, 5, 3, 5, 4, -2, 5, 5, -5, -5, -5, 2, -4, -3, 3, -3, -5, 5, 0,
2, -4, 4, 2, -5, 3, -1, 3, -3, 2, -5, -4, 0, -5, 2, 4, 0, -5, -1, 4,
3, 5, 5, 2, -5, -5, -4, -5, 3, 3, 3, 0, -2, 0, -2, -3, -2, 3, 5, -5};
int8_t b[DEPTH * COL] = {1, 2, -2, -5, -4, 2, 3, 2, -5, 4, -5, 4, 1, -2, 1, 5, 5, 5, 2, 5, -3, -3,
-1, -3, -1, 0, -4, 0, 1, -2, -2, -3, -5, 1, 1, 0, 4, 5, -3, -1, 4, 3, 5, 4,
2, 4, -3, -4, 1, 4, -4, 5, -1, -2, 3, 5, 5, 2, 1, -4, 1, 2, -3, 0, -2, 4,
-3, -3, 1, 3, 4, -1, 3, 1, -5, -1, 2, 0, 0, 5, -1, -5, 5, -5, 0, 3, -3, 4,
3, 1, -3, -3, 2, -2, -3, -3, 3, 4, 2, -1, 2, 0, -2, 4, 5, 3, -1, -3, -2, -1,
4, 3, -5, 1, 0, 0, -1, -4, -3, -2, 5, 3, 2, 1, -4, 1, 4, 5, -1, 2, -2, 2,
1, -2, 5, 2, -4, -4, 1, 1, 2, -1, -5, -4, 4, 1, -3, 4, -1, -4};
int8_t correct[ROW * COL] = {
-36, -33, 11, 4, -12, -7, 11, 0, 37, -30, -13, -2, -30, -3, 29, 46, -13, -84, -8, 6, 39, 26,
-67, -48, 57, 12, 32, 44, -24, -85, 22, 32, -8, -8, 20, 10, -45, 12, -69, 36, 22, -37, 58, 27,
-24, -11, -22, -50, 26, 50, 28, -56, -42, -23, -1, 70, -58, 54, 35, -61, 54, 40, -11, 35, 43, 3,
7, 30, -7, -13, 73, -3, 26, 26, -11, -37, 0, 19, 34, -4, 0, -22, 71, 8, -25, -6, -5, 31,
8, 63, -25, -55, -62, -17, 23, 1, 36, 12, -38, 2, 11, 27, 18, 5, 4, -59, -17, 1, 25, 9,
13, -77, 13, 9, -11, 26, -52, 42, 28, 6, 44, 4, 2, 26, 19, -31, 46, 23, -57, 15, -31, 39,
40, -9, 8, 38, 40, 27, -19, -47, 14, 50, 14, 18, 0, -59, 39, -48, -47, 35};
int8_t output[ROW * COL] = {0};
int8_t *a_r4x16 = new int8_t[ROW4 * DEPTH16];
memset(a_r4x16, 0, ROW4 * DEPTH16);
int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16];
memset(b_c16x4, 0, COL4 * DEPTH16);
RowMajor2Row4x16Major(a, ROW, DEPTH, a_r4x16, DEPTH16);
RowMajor2Col16x4Major(b, DEPTH, COL, b_c16x4, DEPTH16);
int a_sums[ROW4] = {0};
int bias[COL4] = {0};
int multiplier, ls, rs;
QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls,
rs, ROW, COL, COL);
#else
MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL);
#endif
CompareOutputData(output, correct, ROW * COL, 0.1);
delete[] a_r4x16;
delete[] b_c16x4;
}
TEST_F(TestMatmulInt8, mmtest1) {
float in[] = {6.583835634764597, 11.337275140963907, -4.125256949459629, 10.994337291530833,
19.086065139532636, 3.620842999158455, 13.167624585590346, -18.326739299407755,
14.877693740734841, -17.092677920571653, 19.24147072807235, -15.14805323833401,
-18.075654829688737, -0.9164404591894204, -3.836646280336332, -10.870298671273918};
Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
in_quant_arg->zeroPoint = input_zp;
in_quant_arg->scale = input_scale;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
TensorInfo in_params;
in_params.data = in;
in_params.len = 16;
std::vector<int> in_shape{1, 2, 8};
in_params.shape = &in_shape;
in_params.min = -20;
in_params.max = 20;
auto weight_t =
new lite::tensor::Tensor(kNumberTypeInt8, {1, 3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {0.3651070698591563, -0.5856943921727129, -0.7472032663840145, 0.9489992871641959,
-0.8179490270358738, -0.873058811259344, 0.39876672713807215, -0.1816769383004213,
-0.13584645926733696, -0.7614673836659709, -0.2535825872616164, -0.05265760030895916,
0.28558728305658754, 0.15404213943520118, -0.1634824450738006, -0.5068199082730189,
-0.026961256849111326, -0.1508441942453307, 0.9375335677537737, 0.3304690744194263,
-0.5091563780251127, 0.029887336278646925, -0.39540496207319276, 0.46094065001445084};
Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
weight_quant_arg->zeroPoint = weight_zp;
weight_quant_arg->scale = weight_scale;
weight_t->AddQuantParam(*weight_quant_arg);
inputs_->push_back(weight_t);
TensorInfo weight_params;
weight_params.data = weight;
weight_params.len = 24;
std::vector<int> weight_shape{1, 3, 8};
weight_params.shape = &weight_shape;
weight_params.min = -1;
weight_params.max = 1;
auto out_t =
new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
output_quant_arg->zeroPoint = output_zp;
output_quant_arg->scale = output_scale;
out_t->AddQuantParam(*output_quant_arg);
outputs_->push_back(out_t);
float correct[] = {-0.912632942, 4.08398056, -25.385608673, 2.720281124, 7.745952606, 20.893184662};
TensorInfo out_params;
out_params.data = correct;
out_params.len = 6;
std::vector<int> out_shape{1, 2, 3};
out_params.shape = &out_shape;
out_params.min = -30;
out_params.max = 30;
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {-0.912632942, 4.08398056, -25.385608673, 2.720281124, 7.745952606, 20.893184662};
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = false;
return out_t->ElementsNum();
}
TEST_F(TestMatmulInt8, mmint8) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
double output_scale;
int output_zp;
int total_size = MMInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
matmul_param->a_transpose_ = false;
matmul_param->b_transpose_ = true;
matmul_param->has_bias_ = false;
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
MMInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &out_params);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
ctx->thread_num_ = 1;
kernel::MatmulInt8CPUKernel *mm =
new kernel::MatmulInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx, nullptr);
new kernel::MatmulInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs, outputs, ctx, nullptr);
mm->Init();
mm->Run();
float fout[6] = {0};
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
fout);
CompareOutputData(fout, correct, 6, 0.3);
float out_scale;
int out_zp;
QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr);
float *out = new float[out_params.len];
Dequantize(reinterpret_cast<int8_t *>(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out);
CompareOutputData(out, correct, 6, 0.3);
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
for (auto t : inputs) delete t;
for (auto t : outputs) delete t;
delete[] out;
}
TEST_F(TestMatmulInt8, mmtest2) {
float in[] = {
-9.302902352910598, 16.65876088354537, -7.2801759810348265, -6.3246021711950995, 8.467234093555248,
-4.729482636552028, -3.747183865378627, -8.690477390174504, -2.7419930714530523, -3.9478573566319,
7.399137633080947, -1.604450983941291, 0.3115665358682982, -16.864318496334278, 2.5447052588244112,
-13.428639671203255, 13.417832391771974, 10.37917002467671, 14.709787234172168, -16.347969268427146,
4.652834783979106, 6.03601450738973, 2.5788179666401874, -9.236801653471375, -0.18997468903009462,
19.977363387313744, 15.163337058447325, -12.602897730843484, -6.178507797555191, 13.457928661476004,
-10.65587824516124, -18.715557779424188, -9.758039647923935, 8.102044210643097, 19.66309736072973,
-13.368041407254193, 9.928467253978024, 4.9981961360698755, -4.2838547685981645, 1.5021547181513526,
-7.043468062239523, 11.964494917194845, -4.783071964346499, -17.646518743891008, -7.77810768119101,
14.869414292570454, 8.333036603520906, 11.053769742928765, -1.768128637419725, -14.971400302494597,
-0.8653626097283293, -6.21101640878031, 14.83875267850518, 7.224097292538833, -16.747116419664213,
15.310978507353724, -0.05593751363976551, 2.066880260042687, -3.7053788331264137, 9.788933831258937,
-13.614523856950811, -5.656865231633642, 4.720270075138178, -8.366650073458409, 7.197187069893303,
-18.78518907850054, 15.691652153539678, 7.914926057095165, 10.073559408864384, 10.437631177498353,
-3.0580194164595085, 17.36998905922836, 0.09998119223460122, 19.519199178417452, -11.121833210377702,
19.655990774915622, -17.25682638091008, 11.701013896880006, -12.746728025401781, -9.370055221833699,
18.720474512055908, 7.634198897927405, -15.521885320500694, -9.119267877304358, -1.5853789671841945,
4.783147823043613, 14.6732610092525, -9.294170215010427, 9.835421489234331, 13.051159704193232,
-1.422599906517025, -1.5530696181467398, 19.51404609713284, -12.297429715833763, 6.8811248552401985,
13.052476234003755, 18.66085390709462, -8.097735292301103, -6.868239274661935, -8.067142805841826,
3.2707808734101533, 1.8239332220210827};
TensorInfo in_params;
in_params.data = in;
in_params.len = 6 * 17;
std::vector<int> in_shape{1, 6, 17};
in_params.shape = &in_shape;
in_params.min = -20;
in_params.max = 20;
float weight[] = {
-0.42740096214251677, 0.8557068789482212, 0.4560574664172552, -0.1317821769705021, 0.2845963675712846,
0.8414603241768241, 0.24513271080109011, 0.16403708196683398, -0.09111601416189297, -0.714027790956111,
0.12253431683185845, -0.4542459426686125, 0.7123202105555202, -0.3708573394849488, -0.4571735646072892,
-0.595627630450934, -0.5022671357384993, 0.2781065609468565, -0.07586181451887586, -0.2667701710291306,
0.03141663091360791, -0.013304592900917456, -0.7507975439396768, 0.5886778622432618, -0.9056075431439199,
0.9393767525356569, -0.2791312477047512, 0.7134531940450286, 0.3977932134993216, -0.027832574334469395,
0.7222024948455503, -0.2084178952731608, -0.4869535410639745, -0.8255185994321805, 0.975443145421772,
0.541914384763855, -0.8831162309708303, -0.3339354888475805, 0.3699271440691516, -0.26923635397292944,
-0.4975347179262828, 0.2440013185603882, 0.5553443771246633, 0.6111909921005778, -0.5968624036034165,
0.8367593317557596, -0.843079440282104, -0.5651924211153698, 0.7169318662247579, 0.5116755837443465,
-0.9079299375502927, 0.025240632113315176, -0.5819662075810048, -0.37278414060319176, -0.172154755034845,
-0.7372352723583462, 0.2462103743741677, 0.11785417820789856, 0.6712183976911841, -0.7042964391243491,
-0.8215958062965967, -0.7304378130182314, 0.3991295415760667, -0.07226694075875573, 0.9329628273800614,
0.7866596674858193, 0.9410341281569592, 0.39672750454198225, -0.5217505454791054, 0.9538253510722774,
-0.6286845762774464, -0.773460418882959, 0.002296000778892804, 0.9763898918063998, 0.9648708739062339,
0.9400037814137154, -0.6011085333221611, -0.5890262409238565, -0.8078857772627164, 0.233661306598278,
-0.6726381934018617, -0.08533323149874539, 0.19055766469859425, -0.7956482347958518, -0.17012651641579035,
0.7181052528631318, 0.1285045774388125, -0.6997527417326721, -0.8436484573035989, 0.342855467305474,
0.4085157503460306, -0.6199324510955382, -0.6883822276097309, 0.4186437018431113, 0.3030114883148305,
0.0948227655828271, -0.002521771948760465, -0.34878560791422397, 0.08513437045281003, 0.3116035319055901,
-0.7177514192203747, 0.050531673446029046, -0.7399803440665007, -0.9353609485885221, -0.3899340891814298,
0.40867084031625356, -0.17462484099335662, -0.6313167634279941, -0.8135597146296727, -0.9762553414099975,
-0.1040485487920626, -0.6517520252975368, 0.5877412140956126, 0.9433584450325512, 0.24701546283170672,
-0.3236849444311023, -0.12043548611719657, 0.5300129281052712, -0.1380138229226111, -0.8787455295545508,
-0.4361728423289617, 0.7331994894985936, 0.45492774136929826, -0.17836517403432972, 0.10896668585054625,
0.6176507847785211, 0.21617962964770676, -0.6821928873814629, 0.021775035324277825, 0.15089571088539566,
-0.9923383126255942, -0.6034706970202426, 0.17729888871670285, 0.1278810065499425, -0.6575545415840387,
-0.022704865415375197, -0.7366071817901978, -0.9300211224192332, -0.153494127035938, 0.4836121912045357,
-0.3318483587414114, -0.9658468087620375, 0.8388464445207262, 0.45745949405796127, -0.3671803281863002,
-0.1543498074773253, 0.18955899788963748, -0.4452120359256351, -0.5338599486040962, -0.06979561022721281,
-0.45964195574917355, -0.4343754114042866, -0.4318308749403197, 0.748107130947133, -0.4703901010752156,
0.6655596561650823, 0.9075215202451821, 0.2708741258104177, -0.6540233471632313, 0.7250124906689572,
0.6674821078610087, 0.8464696566759315, -0.6106156844283976, 0.8675828337337224, 0.8517737949695063,
-0.8126381016475459, -0.6140987457462099, -0.2984524227549874, 0.2816320572339577, -0.8131479383469931};
TensorInfo weight_params;
weight_params.data = weight;
weight_params.len = 170;
std::vector<int> weight_shape{1, 17, 10};
weight_params.shape = &weight_shape;
weight_params.min = -1;
weight_params.max = 1;
float correct[] = {35.815605, 26.532362, 14.777507, -12.651591, -2.0373726, -47.020798, -18.53121, 2.7848654,
16.19751, -30.754261, 25.830605, 47.635204, 10.247462, -33.260662, 34.145412, -6.1611304,
-18.56802, -24.669813, 20.314533, -5.887198, -14.757037, 24.78901, 20.512205, 17.985718,
17.62954, 20.365099, -26.223736, 0.99702793, 12.752281, -35.30419, -22.09603, 8.2218,
8.120908, 27.685753, -44.010464, -1.879332, -4.531702, 21.434296, 4.2146144, 22.721859,
7.485317, 20.148363, -15.49375, -4.5062046, 37.77292, -0.23385821, -45.532917, -21.055403,
46.854183, -13.595161, 2.8823144, -23.905682, 2.3569264, 26.975227, 32.806625, 9.185071,
-39.330578, -1.0041192, -6.8353715, -33.2658};
TensorInfo out_params;
out_params.data = correct;
out_params.len = 60;
std::vector<int> out_shape{1, 6, 10};
out_params.shape = &out_shape;
out_params.min = -50;
out_params.max = 50;
auto matmul_param = new MatMulParameter();
matmul_param->a_transpose_ = false;
matmul_param->b_transpose_ = false;
matmul_param->has_bias_ = false;
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
MMInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &out_params);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
kernel::MatmulInt8CPUKernel *mm =
new kernel::MatmulInt8CPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs, outputs, ctx, nullptr);
mm->Init();
mm->Run();
float out_scale;
int out_zp;
QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr);
float *out = new float[out_params.len];
Dequantize(reinterpret_cast<int8_t *>(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out);
CompareOutputData(out, correct, 6, 0.6);
delete mm;
for (auto t : inputs) delete t;
for (auto t : outputs) delete t;
delete[] out;
}
} // namespace mindspore