From cd8b664f13e6228ad228e9026bbe429218920c95 Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Sun, 23 Aug 2020 00:09:37 +0800 Subject: [PATCH] Optimize the post process of arm64 matmul int8 --- .../lite/nnacl/assembly/arm64/MatmulInt8.S | 141 +++++++- mindspore/lite/nnacl/common_func.c | 16 - mindspore/lite/nnacl/int8/matmul_int8.c | 66 ++-- mindspore/lite/nnacl/int8/matmul_int8.h | 13 +- .../kernel/arm/int8/fullconnection_int8.cc | 83 +++-- .../kernel/arm/int8/fullconnection_int8.h | 36 +- .../runtime/kernel/arm/int8/matmul_int8.cc | 113 ++---- .../src/runtime/kernel/arm/int8/matmul_int8.h | 59 +--- .../kernel/arm/int8/deconv_int8_tests.cc | 48 --- .../arm/int8/fullconnection_int8_tests.cc | 199 ++++++----- .../kernel/arm/int8/matmul_int8_tests.cc | 331 ++++++++++++++---- 11 files changed, 667 insertions(+), 438 deletions(-) diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S index 9f1c11a3e9b..e92adc1be62 100644 --- a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8.S @@ -24,7 +24,7 @@ //void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, // const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, -// int multiplier, int left_shift, int right_shift); +// int multiplier, int left_shift, int right_shift, int row, int col, int stride); // x0: a(left matrix ptr) // x1: b(right matrix ptr) @@ -40,13 +40,18 @@ // w11: multiplier // w12: left_shift // w13: right_shift +// w14: row +// w15: col +// w24: stride MatmulInt8Neon64: - sub sp, sp, #160 + sub sp, sp, #192 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 stp x19, x20, [sp], #16 stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 ldr w8, [sp] ldr w9, [sp, #8] @@ -54,25 +59,28 @@ MatmulInt8Neon64: ldr w11, [sp, #24] ldr w12, [sp, #32] ldr w13, [sp, #40] + ldr w14, [sp, #48] + ldr w15, [sp, #56] + ldr w24, [sp, #64] - mov w15, #0 // b col index - mov w16, #0 // a row index mov w17, #4 // sizeof(int8)*4 mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 - + mov w17, #1 + mov x25, x2 L1: - cmp w15, w4 + cmp w4, #0 // if at the end of col4 beq End1 - mov w16, #0 // reset a row index + mov w16, w3 // reset a row4 counter + mov w23, w14 // reset a row counter mov x17, x0 // reload a ptr mov x22, x6 // reload a_sums ptr L2: - cmp w16, w3 + cmp w16, #0 beq End2 mov x18, x1 // reload b ptr - mov x19, x7 // reload bias ptr + mov x19, x7 // reload bias ptr mov w20, w5 // reload depth dup v16.4s, wzr dup v17.4s, wzr @@ -256,21 +264,128 @@ End3: sqxtn v15.8b, v13.8h sqxtn2 v15.16b, v14.8h - st1 {v15.16b}, [x2], #16 - add w16, w16, #4 // a row index + 4 + cmp w23, #4 + blt Write // if rows < 4 + cmp w15, #4 + blt Write // if cols < 4 + + st1 {v15.s}[0], [x2], x24 + st1 {v15.s}[1], [x2], x24 + st1 {v15.s}[2], [x2], x24 + st1 {v15.s}[3], [x2], x24 + b Endwrite + +Write: + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol4: + st1 {v15.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.s}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.s}[2], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.s}[3], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + st1 {v15.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + st1 {v15.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + st1 {v15.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + st1 {v15.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol1: + st1 {v15.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.b}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.b}[8], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.b}[12], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #4 // a row4 counter - 4 + sub w23, w23, #4 // a row counter - 4 b L2 End2: - add w15, w15, #4 // b col index + 4 + sub w4, w4, #4 // b col4 counter - 4 + sub w15, w15, #4 // b col counter - 4 add x1, x1, x21 // b ptr + stride add x7, x7, #16 // bias ptr + stride + add x25, x25, #4 // output + stride(4 * sizeof(int8)) + mov x2, x25 b L1 End1: - sub sp, sp, #160 + sub sp, sp, #192 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 ldp x19, x20, [sp], #16 ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 ret #endif diff --git a/mindspore/lite/nnacl/common_func.c b/mindspore/lite/nnacl/common_func.c index de5b59cddf1..326e5e4f360 100644 --- a/mindspore/lite/nnacl/common_func.c +++ b/mindspore/lite/nnacl/common_func.c @@ -228,19 +228,3 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh return; } -void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, - int32_t left_shift, int32_t right_shift, int32_t zp) { - /* (int32_t)row8x8-major * multiplier => (int8_t)row-major */ - for (int r = 0; r < plane; r++) { - for (int c = 0; c < oc; c++) { - int c8div = c / 8, c8mod = c % 8; - int src_index = c8div * plane8 * 8 + r * 8 + c8mod; - int dst_index = r * oc + c; - int32_t value = in[src_index]; - value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; - value = MSMIN(CHAR_MAX, value); - value = MSMAX(CHAR_MIN, value); - out[dst_index] = (int8_t)value; - } - } -} diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 18d1c8c2800..31ac84a24b0 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -117,25 +117,6 @@ void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) } } -void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, - const int32_t a_zp, const int32_t b_zp) { - /* col8-major * row8-major => row8x8-major */ - for (int row = 0; row < row8; row++) { - for (int col = 0; col < col8; col++) { - int r8div = row / 8, r8mod = row % 8; - int c8div = col / 8, c8mod = col % 8; - size_t ci = c8div * row8 * 8 + row * 8 + c8mod; - int32_t value = 0; - for (int d = 0; d < deep; d++) { - size_t ai = r8div * deep * 8 + d * 8 + r8mod; - size_t bi = c8div * deep * 8 + d * 8 + c8mod; - value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp); - } - c[ci] = value; - } - } -} - void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, const int *input_sum, const int *bias) { /* row4x16-major * row16x4-major => row4x4-major */ @@ -191,6 +172,36 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row return; } +/* row4x16-major * col16x4-major => row4x4-major */ +void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min, + int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16, + int stride) { + int8_t *output = dst; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM; + int r4mod = r % C4NUM; + int c4div = c / C4NUM; + int c4mod = c % C4NUM; + int value = 0; + for (int d = 0; d < deep16; d++) { + int d16div = d / C16NUM; + int d16mod = d % C16NUM; + size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value += a[ai] * b[bi]; + } + value -= a_sums[r]; + value += bias[c]; + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp; + value = MSMIN(INT8_MAX, value); + value = MSMAX(INT8_MIN, value); + output[c] = (int8_t)value; + } + output += stride; + } +} + void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) { int stride = sizeof(int8_t) * 16 * 4; for (int r = 0; r < row; ++r) { @@ -213,23 +224,28 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1 } } -void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) { +// dst: weight_zp * input_row_sums +void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) { for (int r = 0; r < row; ++r) { + int sum = 0; for (int c = 0; c < col; ++c) { int src_idx = r * col + c; - dst[r] += a[src_idx]; + sum += input[src_idx]; } - dst[r] *= b_zp; + sum *= weight_zp; + dst[r] = sum; } } -void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) { +// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums +void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) { for (int c = 0; c < col; ++c) { + int sum = 0; for (int r = 0; r < row; ++r) { int src_idx = r * col + c; - dst[c] += b[src_idx]; + sum += weight[src_idx]; } - dst[c] = row * a_zp * b_zp - a_zp * dst[c]; + dst[c] = row * input_zp * weight_zp - input_zp * sum; if (bias) { dst[c] += bias[c]; } diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 7e7f2e9f944..04ff5972f33 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -24,8 +24,6 @@ #ifdef __cplusplus extern "C" { #endif -void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep, - const int a_zp, const int b_zp); void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, const int *input_sum, const int *bias); void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, @@ -39,15 +37,16 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16); -void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst); -void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst); -void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow); +void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst); +void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst); +void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min, + int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16, + int stride); #ifdef ENABLE_ARM64 -// bias = bias + depth * a_zp * b_zp - a_zp * b_sums void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift, - int right_shift); + int right_shift, int row, int col, int stride); void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, const int *input_sum, const int *bias); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc index 9a092eb994f..54284d72d4b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -39,36 +39,32 @@ int FullconnectionInt8CPUKernel::ReSize() { fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8); fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8); - thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); - thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); - - a_c8_ptr_ = - reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t))); - if (!a_c8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)); - b_r8_ptr_ = - reinterpret_cast(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t))); - if (!b_r8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)); + r4_ = UP_ROUND(fc_param_->row_, 4); + c4_ = UP_ROUND(fc_param_->col_, 4); + d16_ = UP_ROUND(fc_param_->deep_, 16); + thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4)); + thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_); + a_r4x16_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t))); + if (!a_r4x16_ptr_) return RET_MEMORY_FAILED; + memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t)); + b_c16x4_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t))); + if (!b_c16x4_ptr_) return RET_MEMORY_FAILED; + memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t)); + input_sums_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * sizeof(int))); + if (!input_sums_) return RET_MEMORY_FAILED; + memset(input_sums_, 0, r4_ * sizeof(int)); + weight_bias_sums_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * sizeof(int))); + if (!weight_bias_sums_) return RET_MEMORY_FAILED; + memset(weight_bias_sums_, 0, c4_ * sizeof(int)); auto weight_data = reinterpret_cast(in_tensors_[1]->Data()); - RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_); - c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int))); - if (!c_r8x8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)); - auto bias_len = fc_param_->col_8_ * sizeof(int); - bias_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(bias_len)); - if (!bias_ptr_) { - return RET_MEMORY_FAILED; - } - memset(bias_ptr_, 0, bias_len); + RowMajor2Row4x16Major(weight_data, fc_param_->col_, fc_param_->deep_, b_c16x4_ptr_, d16_); if (in_tensors_.size() == 3) { + auto bias_len = fc_param_->col_8_ * sizeof(int); + bias_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(bias_len)); + if (!bias_ptr_) return RET_MEMORY_FAILED; memcpy(bias_ptr_, in_tensors_[2]->Data(), bias_len); + } else { + bias_ptr_ = NULL; } auto input_tensor = in_tensors_[0]; @@ -93,18 +89,32 @@ int FullconnectionInt8CPUKernel::ReSize() { CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6, quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min, &quant_params_.out_act_max); + CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, + bias_ptr_, weight_bias_sums_); return RET_OK; } int FullconnectionInt8CPUKernel::RunImpl(int task_id) { - int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_); + int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_); if (cur_oc <= 0) { return RET_OK; } - auto &p = quant_params_; - auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_; - auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_; - MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_); + int cur_oc_res = MSMIN(thread_stride_ * C4NUM, fc_param_->col_ - task_id * thread_stride_ * C4NUM); + auto &q = quant_params_; + auto &p = fc_param_; + auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_; + auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM; + auto output_ptr = reinterpret_cast(out_tensors_[0]->Data()); + auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM; +#ifdef ENABLE_ARM64 + MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min, + q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, + p->col_ * sizeof(int8_t)); +#else + MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_, + q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_); +#endif + return RET_OK; } @@ -124,13 +134,10 @@ int FullconnectionInt8CPUKernel::Run() { MS_LOG(ERROR) << "Prepare failed."; return RET_ERROR; } - auto a_ptr = reinterpret_cast(in_tensors_[0]->Data()); - auto output_ptr = reinterpret_cast(out_tensors_[0]->Data()); - auto &p = quant_params_; - RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + auto input_ptr = reinterpret_cast(in_tensors_[0]->Data()); + RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_); + CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_); LiteBackendParallelLaunch(FcInt8Run, this, thread_count_); - PostFuncInt8C8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, p.quant_multiplier, p.left_shift, - p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h index 9af8b850fa6..9e2aca294c3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h @@ -41,28 +41,36 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel { private: void FreeTmpBuffer() { - if (a_c8_ptr_ != nullptr) { - ctx_->allocator->Free(a_c8_ptr_); - a_c8_ptr_ = nullptr; + if (a_r4x16_ptr_ != nullptr) { + ctx_->allocator->Free(a_r4x16_ptr_); + a_r4x16_ptr_ = nullptr; } - if (b_r8_ptr_ != nullptr) { - ctx_->allocator->Free(b_r8_ptr_); - b_r8_ptr_ = nullptr; + if (b_c16x4_ptr_ != nullptr) { + ctx_->allocator->Free(b_c16x4_ptr_); + b_c16x4_ptr_ = nullptr; } - if (c_r8x8_ptr_ != nullptr) { - ctx_->allocator->Free(c_r8x8_ptr_); - c_r8x8_ptr_ = nullptr; + if (input_sums_ != nullptr) { + ctx_->allocator->Free(input_sums_); + input_sums_ = nullptr; + } + if (weight_bias_sums_ != nullptr) { + ctx_->allocator->Free(weight_bias_sums_); + weight_bias_sums_ = nullptr; } if (bias_ptr_ != nullptr) { - ctx_->allocator->Free(bias_ptr_); - bias_ptr_ = nullptr; + ctx_->allocator->Free(weight_bias_sums_); + weight_bias_sums_ = nullptr; } } MatmulQuantArg quant_params_; - int8_t *a_c8_ptr_ = nullptr; - int8_t *b_r8_ptr_ = nullptr; - int *c_r8x8_ptr_ = nullptr; + int8_t *a_r4x16_ptr_ = nullptr; + int8_t *b_c16x4_ptr_ = nullptr; + int *input_sums_ = nullptr; + int *weight_bias_sums_ = nullptr; int *bias_ptr_ = nullptr; + int r4_; + int c4_; + int d16_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc index 7c7c798ee65..935f74d4541 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc @@ -48,46 +48,23 @@ int MatmulInt8CPUKernel::ReSize() { params_->row_8_ = UP_ROUND(params_->row_, 8); params_->col_8_ = UP_ROUND(params_->col_, 8); -#ifdef ENABLE_ARM64 r4_ = UP_ROUND(params_->row_, 4); c4_ = UP_ROUND(params_->col_, 4); d16_ = UP_ROUND(params_->deep_, 16); - a_r4d16_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t))); - if (!a_r4d16_ptr_) return RET_MEMORY_FAILED; - memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t)); - b_c4d16_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t))); - if (!b_c4d16_ptr_) return RET_MEMORY_FAILED; - memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t)); - c_r4c4_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t))); - if (!c_r4c4_ptr_) return RET_MEMORY_FAILED; - memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t)); - a_sums_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * sizeof(int))); - if (!a_sums_) return RET_MEMORY_FAILED; - memset(a_sums_, 0, r4_ * sizeof(int)); - b_bias_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * sizeof(int))); - if (!b_bias_) return RET_MEMORY_FAILED; - memset(b_bias_, 0, c4_ * sizeof(int)); + a_r4x16_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t))); + if (!a_r4x16_ptr_) return RET_MEMORY_FAILED; + memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t)); + b_c16x4_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t))); + if (!b_c16x4_ptr_) return RET_MEMORY_FAILED; + memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t)); + input_sums_ = reinterpret_cast(ctx_->allocator->Malloc(r4_ * sizeof(int))); + if (!input_sums_) return RET_MEMORY_FAILED; + memset(input_sums_, 0, r4_ * sizeof(int)); + weight_bias_sums_ = reinterpret_cast(ctx_->allocator->Malloc(c4_ * sizeof(int))); + if (!weight_bias_sums_) return RET_MEMORY_FAILED; + memset(weight_bias_sums_, 0, c4_ * sizeof(int)); thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4)); thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_); -#else - a_c8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t))); - if (!a_c8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(int8_t)); - b_r8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(int8_t))); - if (!b_r8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(int8_t)); - c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(int))); - if (!c_r8x8_ptr_) { - return RET_MEMORY_FAILED; - } - memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int)); - thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); - thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); -#endif auto input_tensor = in_tensors_[0]; auto params = input_tensor->GetQuantParams(); @@ -112,27 +89,25 @@ int MatmulInt8CPUKernel::ReSize() { } int MatmulInt8CPUKernel::RunImpl(int task_id) { -#ifdef ENABLE_ARM64 int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_); if (cur_oc <= 0) { return RET_OK; } - auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_; - auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_; - auto &p = quant_params_; - MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, c4_, d16_, a_sums_, b_bias_, INT_MIN, INT_MAX, p.output.zp_, - p.quant_multiplier, p.left_shift, p.right_shift); -#else - int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_); - if (cur_oc <= 0) { - return RET_OK; - } - auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; - auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_; + int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM); + auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_; + auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4; + auto cur_c = c_ptr_ + task_id * thread_stride_ * 4; - MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_, - quant_params_.weight.zp_); + auto &p = quant_params_; +#ifdef ENABLE_ARM64 + MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX, + p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res, + params_->col_ * sizeof(int8_t)); +#else + MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier, + p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_); #endif + return RET_OK; } @@ -162,43 +137,27 @@ int MatmulInt8CPUKernel::Run() { for (int i = 0; i < params_->batch; ++i) { auto cur_a_ptr = a_ptr + i * a_stride; auto cur_b_ptr = b_ptr + i * b_stride; - auto cur_c_ptr = c_ptr + i * c_stride; -#ifdef ENABLE_ARM64 if (params_->a_transpose_) { - RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_); + RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_); } else { - RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_); + RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_); } if (params_->b_transpose_) { - RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_); + RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_); } else { - RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_); + RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_); } + c_ptr_ = c_ptr + i * c_stride; auto &q = quant_params_; - RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_); - RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, b_bias_); - LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_); - Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_); -#else - if (params_->a_transpose_) { - RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_); - } else { - RowMajor2Col8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_); + CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_); + CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_); + ret = LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]"; + return ret; } - if (params_->b_transpose_) { - RowMajor2Col8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_); - } else { - RowMajor2Row8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_); - } - LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_); - auto &q = quant_params_; - SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier, - q.left_shift, q.right_shift, q.output.zp_); -#endif } - return RET_OK; } - } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h index 7e6d66f8fb5..d728d2aecd5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h @@ -39,57 +39,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel { private: void FreeTmpBuffer() { -#ifdef ENABLE_ARM64 - if (a_r4d16_ptr_ != nullptr) { - ctx_->allocator->Free(a_r4d16_ptr_); - a_r4d16_ptr_ = nullptr; + if (a_r4x16_ptr_ != nullptr) { + ctx_->allocator->Free(a_r4x16_ptr_); + a_r4x16_ptr_ = nullptr; } - if (b_c4d16_ptr_ != nullptr) { - ctx_->allocator->Free(b_c4d16_ptr_); - b_c4d16_ptr_ = nullptr; + if (b_c16x4_ptr_ != nullptr) { + ctx_->allocator->Free(b_c16x4_ptr_); + b_c16x4_ptr_ = nullptr; } - if (c_r4c4_ptr_ != nullptr) { - ctx_->allocator->Free(c_r4c4_ptr_); - c_r4c4_ptr_ = nullptr; + if (input_sums_ != nullptr) { + ctx_->allocator->Free(input_sums_); + input_sums_ = nullptr; } - if (a_sums_ != nullptr) { - ctx_->allocator->Free(a_sums_); - a_sums_ = nullptr; + if (weight_bias_sums_ != nullptr) { + ctx_->allocator->Free(weight_bias_sums_); + weight_bias_sums_ = nullptr; } - if (b_bias_ != nullptr) { - ctx_->allocator->Free(b_bias_); - b_bias_ = nullptr; - } -#else - if (a_c8_ptr_ != nullptr) { - ctx_->allocator->Free(a_c8_ptr_); - a_c8_ptr_ = nullptr; - } - if (b_r8_ptr_ != nullptr) { - ctx_->allocator->Free(b_r8_ptr_); - b_r8_ptr_ = nullptr; - } - if (c_r8x8_ptr_ != nullptr) { - ctx_->allocator->Free(c_r8x8_ptr_); - c_r8x8_ptr_ = nullptr; - } -#endif } MatmulQuantArg quant_params_; -#ifdef ENABLE_ARM64 - int8_t *a_r4d16_ptr_ = nullptr; - int8_t *b_c4d16_ptr_ = nullptr; - int8_t *c_r4c4_ptr_ = nullptr; - int *a_sums_ = nullptr; - int *b_bias_ = nullptr; + int8_t *a_r4x16_ptr_ = nullptr; + int8_t *b_c16x4_ptr_ = nullptr; + int8_t *c_ptr_ = nullptr; + int *input_sums_ = nullptr; + int *weight_bias_sums_ = nullptr; int r4_; int c4_; int d16_; -#else - int8_t *a_c8_ptr_ = nullptr; - int8_t *b_r8_ptr_ = nullptr; - int *c_r8x8_ptr_ = nullptr; -#endif }; // namespace mindspore::kernel } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc index 0791019c47a..d5adb010ad6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc @@ -134,54 +134,6 @@ TEST_F(TestDeconvInt8, PackInputTest1) { CompareOutputData(dst, co, 8 * 32, 1); } -TEST_F(TestDeconvInt8, MatMulTest1) { - int8_t a_row_major_10_12[] = { - -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105, - 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, 57, -41, -51, 77, - 1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, 57, 61, 55, -20, 3, 2, - 17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, -7, -16, -47, 112, 114, -26, -98, 53, - 15, -49, 26, 19, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 124, 113, -5, 15, - -8, 107, -65, -88, 50, -47, -80, -84, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10}; - int32_t zp_a = 15; - int8_t a_col8_major[16 * 12] = {0}; - int8_t b_col_major_12_18[] = { - 92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99, - 77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103, - -3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80, - -75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33, - -77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46, - -53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42, - 43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46, - 35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7, - -8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81, - 40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51, - 40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49}; - int32_t zp_b = -20; - int8_t b_row8_major[12 * 24] = {0}; - int32_t co_row_major_10_18[] = { - 32005, 3597, 16595, -3458, 6627, -6663, 818, -3910, 10228, 15079, -19205, -10203, -3178, -10046, - 10374, -6199, 5330, 12163, 1819, 20533, 17382, 18283, 9778, 9185, -12623, -26234, -11987, 7904, - 8144, -1603, 27611, -10190, -20053, 4999, -28389, 21852, 24680, 25858, 23506, 17944, 11768, 24378, - -6102, -4675, -23460, 10434, -47579, 1986, 12018, -19418, -7248, 4938, -32613, -941, 8171, -4788, - 3325, -11310, -8351, -14786, 6909, 16401, 2017, -6456, 11242, 7393, -9119, 17312, 2646, -14402, - 7201, -9949, 23986, 17607, 27461, -1547, 2783, 7558, 19487, 11158, -2686, 6328, -8225, -11668, - 21858, -2079, -8671, -639, -1544, 1235, 1156, 6582, 2829, -10311, -2692, 5154, 1527, 10870, - 106, -8189, -24174, -1846, -15399, -3598, 14874, -5591, -619, -13667, -6053, -31103, -24499, 13008, - 9143, -17982, 28437, 2176, -2114, -11631, 10779, -1032, -24690, -3112, 2125, 432, 20270, -33859, - 8907, 10063, 1603, 3761, 4805, 4904, -15594, 10786, 4287, -13591, -18777, -1679, 2109, -2243, - 12051, -8504, -6558, 4209, 13606, -25803, 27922, 12092, 7140, 27142, -12267, 2339, -26224, 23674, - -26579, -11398, -1823, -18976, 3641, 4415, -24878, -2045, 15937, 41465, 12601, -14513, -17619, -5728, - 334, -424, 8147, -1369, 5984, 11000, 19016, 4456, -25920, 4506, 5930, 15458}; - int32_t c_row8x8_major[16 * 24] = {0}; - - int32_t out_row_major[180] = {0}; - RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12); - RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12); - MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b); - Row8x8Major2RowMajor(reinterpret_cast(c_row8x8_major), reinterpret_cast(out_row_major), 10, 18, 18); - CompareOutputData(out_row_major, co_row_major_10_18, 180, 1); -} - TEST_F(TestDeconvInt8, InputSumTest1) { int8_t packed_a[] = { -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111, diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc index 60f1c2c5c71..9584d9fa336 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -29,99 +29,128 @@ class TestFcInt8 : public mindspore::CommonTest { TestFcInt8() {} }; -int FcInt8TestInit(std::vector *inputs_, std::vector *outputs_, - MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) { - float input_max = 20; - float input_min = -20; - float weight_max = 1; - float weight_min = -1; - float output_max = 20; - float output_min = -20; +struct TensorInfo { + float *data; + int *data_int; + float min; + float max; + int len; + std::vector *shape; +}; - double input_scale = - (input_max - input_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int input_zp = std::numeric_limits::max() - input_max / input_scale; - double weight_scale = - (weight_max - weight_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int weight_zp = std::numeric_limits::max() - weight_max / weight_scale; - double output_scale = - (output_max - output_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int output_zp = std::numeric_limits::max() - output_max / output_scale; - *scale = output_scale; - *zeropoint = output_zp; +extern void QuantProcess(float *input, int len, float min, float max, float *scale, int *zero_point, int8_t *output); +extern lite::tensor::Tensor *MakeQuantTensor(int8_t *data, int len, std::vector *shape, float scale, int zp); - Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast(1)); - in_t->MallocData(); - float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383, - 17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792}; - Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast(in_t->Data())); - auto in_quant_arg = new mindspore::lite::tensor::QuantArg(); - in_quant_arg->zeroPoint = input_zp; - in_quant_arg->scale = input_scale; - in_t->AddQuantParam(*in_quant_arg); - inputs_->push_back(in_t); - - Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast(1)); - weight_t->MallocData(); - float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435, - 0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552, - -0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459, - 0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343}; - Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast(weight_t->Data())); - auto weight_quant_arg = new mindspore::lite::tensor::QuantArg(); - weight_quant_arg->zeroPoint = weight_zp; - weight_quant_arg->scale = weight_scale; - weight_t->AddQuantParam(*weight_quant_arg); - inputs_->push_back(weight_t); - - Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast(1)); - bias_t->MallocData(); - memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum()); - inputs_->push_back(bias_t); - - Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast(1)); - out_t->MallocData(); - auto output_quant_arg = new mindspore::lite::tensor::QuantArg(); - output_quant_arg->zeroPoint = output_zp; - output_quant_arg->scale = output_scale; - out_t->AddQuantParam(*output_quant_arg); - outputs_->push_back(out_t); - - *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); - float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108}; - memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float)); - - matmal_param->b_transpose_ = true; - matmal_param->a_transpose_ = false; - matmal_param->has_bias_ = true; - matmal_param->act_type_ = ActType_No; - return out_t->ElementsNum(); +lite::tensor::Tensor *MakeIntTensor(int *data, int len, std::vector *shape) { + auto tensor = + new lite::tensor::Tensor(kNumberTypeInt32, *shape, schema::Format_NHWC, static_cast(1)); + tensor->MallocData(); + auto tensor_ptr = reinterpret_cast(tensor->Data()); + memcpy(tensor_ptr, data, len * sizeof(int)); + return tensor; } -TEST_F(TestFcInt8, fcint8) { - std::vector inputs_; - std::vector outputs_; - auto matmul_param = new MatMulParameter(); - float *correct; - double output_scale; - int output_zp; - int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp); - lite::Context *ctx = new lite::Context; +void FcInt8TestInit(std::vector *inputs, std::vector *outputs, + TensorInfo *in, TensorInfo *weight, TensorInfo *bias, TensorInfo *out) { + float in_scale, weight_scale, out_scale; + int in_zp, weight_zp, out_zp; + int8_t *in_data = new int8_t[in->len]; + int8_t *weight_data = new int8_t[weight->len]; + QuantProcess(in->data, in->len, in->min, in->max, &in_scale, &in_zp, in_data); + auto in_tensor = MakeQuantTensor(in_data, in->len, in->shape, in_scale, in_zp); + inputs->push_back(in_tensor); + QuantProcess(weight->data, weight->len, weight->min, weight->max, &weight_scale, &weight_zp, weight_data); + auto weight_tensor = MakeQuantTensor(weight_data, weight->len, weight->shape, weight_scale, weight_zp); + inputs->push_back(weight_tensor); + auto bias_tensor = MakeIntTensor(bias->data_int, bias->len, bias->shape); + inputs->push_back(bias_tensor); + QuantProcess(out->data, out->len, out->min, out->max, &out_scale, &out_zp, nullptr); + auto out_tensor = MakeQuantTensor(nullptr, out->len, out->shape, out_scale, out_zp); + outputs->push_back(out_tensor); + delete[] in_data; + delete[] weight_data; +} + +TEST_F(TestFcInt8, fctest1) { + float in[] = {4.259103407444801, 5.992151035772917, -9.495343223733581, 3.0509999931426215, -16.635707833991095, + -14.72005749234452, 2.8290916795754093, -15.827977973039049, -16.98208477063347, 2.8801101778935347, + -0.5905297521382735, 18.042746010536085, 3.913511213700396, 11.571264917136105, 19.084257392926148, + 8.571560238377568, 17.58868010598305, 12.433311533838427, 4.548078598583526, 15.609650071521138, + 6.663372887795717, 17.581323475674594, 1.453277207446778, -6.119351424589654, -16.87310296820285, + 11.906066592064796, -13.290100998834653, 19.627129875430548, 16.034262583959162, 10.255738135902781, + 12.134650347811792, -5.5882066903433305, 15.554050723026322, 15.288481461776783, 17.651080309797287, + -9.258779162183215, 4.218532791445092, -6.205309122668545, 1.2220458021156908, 1.6800736573947326}; + TensorInfo in_params; + in_params.data = in; + in_params.len = 40; + std::vector in_shape{5, 2, 2, 2}; + in_params.shape = &in_shape; + in_params.min = -20; + in_params.max = 20; + + float weight[] = { + -0.586269014312498, 0.10845796767603733, 0.8455159907124523, 0.20261291069007226, 0.7564258582027543, + 0.4505005038790615, -0.607259232240795, -0.6962171798923924, 0.7967573009922135, -0.46069496925353715, + -0.2967638879316592, -0.7025557337565955, -0.5313515272071268, 0.07584168670764102, -0.6860034691410029, + 0.9218806800279316, -0.07408538201953907, -0.7933652717840096, 0.6636691558029275, -0.30198695606477477, + 0.790225747868754, -0.9478140254555916, 0.4537316306461665, 0.1776848732022871, -0.7492316745474277, + -0.5825825240770948, 0.5680842804542614, -0.9255552309192772, 0.20866577718844725, 0.9570928647172854, + 0.18172570688854406, -0.26442830241827253, -0.24765169216720873, -0.19512285277145702, 0.1120696020054861, + 0.7558578199370625, -0.15032457481135109, -0.08485585411928809, 0.6343014796699504, 0.026380085222785787, + -0.40516674259120444, -0.7407588590646037, -0.28521396461492454, 0.2555841827858194, 0.023640857478332444, + -0.6540694390119834, 0.7439705499824205, -0.7579774562590929}; + TensorInfo weight_params; + weight_params.data = weight; + weight_params.len = 48; + std::vector weight_shape{6, 8}; + weight_params.shape = &weight_shape; + weight_params.min = -1; + weight_params.max = 1; + + int bias[6] = {0}; + TensorInfo bias_params; + bias_params.data_int = bias; + bias_params.len = 6; + std::vector bias_shape{6}; + bias_params.shape = &bias_shape; + + float correct[] = {-19.170732, -7.5019627, -13.015462, -27.760283, 4.1447954, 20.660276, 4.0412164, -33.750015, + -4.560128, 7.1035166, 27.976341, 9.75216, 14.383608, -12.87587, -24.688887, -12.185722, + 3.7933283, -19.266382, 17.193876, -49.99205, -15.480089, -3.1659412, 19.470417, 13.758459, + 4.0713396, 4.614437, 11.296907, -7.244551, -11.143417, -21.233654}; + TensorInfo out_params; + out_params.data = correct; + out_params.len = 30; + std::vector out_shape{5, 6}; + out_params.shape = &out_shape; + out_params.min = -50; + out_params.max = 50; + + auto fc_param = new MatMulParameter(); + fc_param->a_transpose_ = false; + fc_param->b_transpose_ = true; + fc_param->has_bias_ = true; + fc_param->act_type_ = ActType_No; + std::vector inputs; + std::vector outputs; + FcInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &bias_params, &out_params); + auto ctx = new lite::Context; ctx->thread_num_ = 2; - kernel::FullconnectionInt8CPUKernel *fc = new kernel::FullconnectionInt8CPUKernel( - reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); + + kernel::FullconnectionInt8CPUKernel *fc = + new kernel::FullconnectionInt8CPUKernel(reinterpret_cast(fc_param), inputs, outputs, ctx, nullptr); fc->Init(); fc->Run(); - float fout[6] = {0}; - Dequantize(reinterpret_cast(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp, - fout); - CompareOutputData(fout, correct, 6, 0.2); - delete matmul_param; + float out_scale; + int out_zp; + QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr); + float *out = new float[out_params.len]; + Dequantize(reinterpret_cast(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out); + CompareOutputData(out, correct, 6, 0.3); delete fc; - for (auto t : inputs_) delete t; - for (auto t : outputs_) delete t; - free(correct); + for (auto t : inputs) delete t; + for (auto t : outputs) delete t; + delete[] out; } - } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index 017a7657b77..49a93eafb08 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -18,8 +18,9 @@ #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h" -#include "mindspore/lite/nnacl/quantization/quantize.h" -#include "mindspore/lite/nnacl/common_func.h" +#include "nnacl/quantization/quantize.h" +#include "nnacl/common_func.h" +#include "nnacl/int8/matmul_int8.h" #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/src/lite_kernel.h" @@ -29,99 +30,283 @@ class TestMatmulInt8 : public mindspore::CommonTest { TestMatmulInt8() {} }; -int MMInt8TestInit(std::vector *inputs_, std::vector *outputs_, - MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) { - float input_max = 20; - float input_min = -20; - float weight_max = 1; - float weight_min = -1; - float output_max = 30; - float output_min = -30; +struct TensorInfo { + float *data; + float min; + float max; + int len; + std::vector *shape; +}; - double input_scale = - (input_max - input_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int input_zp = std::numeric_limits::max() - input_max / input_scale; - double weight_scale = - (weight_max - weight_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int weight_zp = std::numeric_limits::max() - weight_max / weight_scale; - double output_scale = - (output_max - output_min) / (std::numeric_limits::max() - std::numeric_limits::min()); - int output_zp = std::numeric_limits::max() - output_max / output_scale; - *scale = output_scale; - *zeropoint = output_zp; +void QuantProcess(float *input, int len, float min, float max, float *scale, int *zero_point, int8_t *output) { + *scale = (max - min) / (std::numeric_limits::max() - std::numeric_limits::min()); + *zero_point = std::numeric_limits::max() - max / (*scale); + if (output) { + Quantize(input, len, *scale, *zero_point, output); + } +} - auto in_t = - new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 8}, schema::Format_NHWC, static_cast(1)); - in_t->MallocData(); +lite::tensor::Tensor *MakeQuantTensor(int8_t *data, int len, std::vector *shape, float scale, int zp) { + auto tensor = + new lite::tensor::Tensor(kNumberTypeInt8, *shape, schema::Format_NHWC, static_cast(1)); + tensor->MallocData(); + if (data) { + auto tensor_ptr = reinterpret_cast(tensor->Data()); + memcpy(tensor_ptr, data, len * sizeof(int8_t)); + } + auto quant_arg = new mindspore::lite::tensor::QuantArg(); + quant_arg->zeroPoint = zp; + quant_arg->scale = scale; + tensor->AddQuantParam(*quant_arg); + return tensor; +} + +void MMInt8TestInit(std::vector *inputs, std::vector *outputs, + TensorInfo *in, TensorInfo *weight, TensorInfo *out) { + float in_scale, weight_scale, out_scale; + int in_zp, weight_zp, out_zp; + int8_t *in_data = new int8_t[in->len]; + int8_t *weight_data = new int8_t[weight->len]; + QuantProcess(in->data, in->len, in->min, in->max, &in_scale, &in_zp, in_data); + auto in_tensor = MakeQuantTensor(in_data, in->len, in->shape, in_scale, in_zp); + inputs->push_back(in_tensor); + QuantProcess(weight->data, weight->len, weight->min, weight->max, &weight_scale, &weight_zp, weight_data); + auto weight_tensor = MakeQuantTensor(weight_data, weight->len, weight->shape, weight_scale, weight_zp); + inputs->push_back(weight_tensor); + QuantProcess(out->data, out->len, out->min, out->max, &out_scale, &out_zp, nullptr); + auto out_tensor = MakeQuantTensor(nullptr, out->len, out->shape, out_scale, out_zp); + outputs->push_back(out_tensor); + delete[] in_data; + delete[] weight_data; +} + +TEST_F(TestMatmulInt8, simple) { +#define ROW 10 +#define COL 15 +#define DEPTH 10 +#define ROW4 UP_ROUND(ROW, 4) +#define COL4 UP_ROUND(COL, 4) +#define DEPTH16 UP_ROUND(DEPTH, 16) + int8_t a[ROW * DEPTH] = {-3, -3, 0, -2, -4, -2, 1, 0, -1, 0, 5, 1, 3, 4, 4, -3, -5, 2, -2, 4, + 4, 5, 1, -1, 5, 5, 2, -1, 0, 4, -4, 2, 5, -2, 5, 3, -1, 2, -4, 5, + -5, 4, 5, 3, 5, 4, -2, 5, 5, -5, -5, -5, 2, -4, -3, 3, -3, -5, 5, 0, + 2, -4, 4, 2, -5, 3, -1, 3, -3, 2, -5, -4, 0, -5, 2, 4, 0, -5, -1, 4, + 3, 5, 5, 2, -5, -5, -4, -5, 3, 3, 3, 0, -2, 0, -2, -3, -2, 3, 5, -5}; + int8_t b[DEPTH * COL] = {1, 2, -2, -5, -4, 2, 3, 2, -5, 4, -5, 4, 1, -2, 1, 5, 5, 5, 2, 5, -3, -3, + -1, -3, -1, 0, -4, 0, 1, -2, -2, -3, -5, 1, 1, 0, 4, 5, -3, -1, 4, 3, 5, 4, + 2, 4, -3, -4, 1, 4, -4, 5, -1, -2, 3, 5, 5, 2, 1, -4, 1, 2, -3, 0, -2, 4, + -3, -3, 1, 3, 4, -1, 3, 1, -5, -1, 2, 0, 0, 5, -1, -5, 5, -5, 0, 3, -3, 4, + 3, 1, -3, -3, 2, -2, -3, -3, 3, 4, 2, -1, 2, 0, -2, 4, 5, 3, -1, -3, -2, -1, + 4, 3, -5, 1, 0, 0, -1, -4, -3, -2, 5, 3, 2, 1, -4, 1, 4, 5, -1, 2, -2, 2, + 1, -2, 5, 2, -4, -4, 1, 1, 2, -1, -5, -4, 4, 1, -3, 4, -1, -4}; + + int8_t correct[ROW * COL] = { + -36, -33, 11, 4, -12, -7, 11, 0, 37, -30, -13, -2, -30, -3, 29, 46, -13, -84, -8, 6, 39, 26, + -67, -48, 57, 12, 32, 44, -24, -85, 22, 32, -8, -8, 20, 10, -45, 12, -69, 36, 22, -37, 58, 27, + -24, -11, -22, -50, 26, 50, 28, -56, -42, -23, -1, 70, -58, 54, 35, -61, 54, 40, -11, 35, 43, 3, + 7, 30, -7, -13, 73, -3, 26, 26, -11, -37, 0, 19, 34, -4, 0, -22, 71, 8, -25, -6, -5, 31, + 8, 63, -25, -55, -62, -17, 23, 1, 36, 12, -38, 2, 11, 27, 18, 5, 4, -59, -17, 1, 25, 9, + 13, -77, 13, 9, -11, 26, -52, 42, 28, 6, 44, 4, 2, 26, 19, -31, 46, 23, -57, 15, -31, 39, + 40, -9, 8, 38, 40, 27, -19, -47, 14, 50, 14, 18, 0, -59, 39, -48, -47, 35}; + + int8_t output[ROW * COL] = {0}; + int8_t *a_r4x16 = new int8_t[ROW4 * DEPTH16]; + memset(a_r4x16, 0, ROW4 * DEPTH16); + int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16]; + memset(b_c16x4, 0, COL4 * DEPTH16); + RowMajor2Row4x16Major(a, ROW, DEPTH, a_r4x16, DEPTH16); + RowMajor2Col16x4Major(b, DEPTH, COL, b_c16x4, DEPTH16); + int a_sums[ROW4] = {0}; + int bias[COL4] = {0}; + int multiplier, ls, rs; + QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs); +#ifdef ENABLE_ARM64 + MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, + rs, ROW, COL, COL); +#else + MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL); +#endif + CompareOutputData(output, correct, ROW * COL, 0.1); + delete[] a_r4x16; + delete[] b_c16x4; +} + +TEST_F(TestMatmulInt8, mmtest1) { float in[] = {6.583835634764597, 11.337275140963907, -4.125256949459629, 10.994337291530833, 19.086065139532636, 3.620842999158455, 13.167624585590346, -18.326739299407755, 14.877693740734841, -17.092677920571653, 19.24147072807235, -15.14805323833401, -18.075654829688737, -0.9164404591894204, -3.836646280336332, -10.870298671273918}; - Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast(in_t->Data())); - auto in_quant_arg = new mindspore::lite::tensor::QuantArg(); - in_quant_arg->zeroPoint = input_zp; - in_quant_arg->scale = input_scale; - in_t->AddQuantParam(*in_quant_arg); - inputs_->push_back(in_t); + TensorInfo in_params; + in_params.data = in; + in_params.len = 16; + std::vector in_shape{1, 2, 8}; + in_params.shape = &in_shape; + in_params.min = -20; + in_params.max = 20; - auto weight_t = - new lite::tensor::Tensor(kNumberTypeInt8, {1, 3, 8}, schema::Format_NHWC, static_cast(1)); - weight_t->MallocData(); float weight[] = {0.3651070698591563, -0.5856943921727129, -0.7472032663840145, 0.9489992871641959, -0.8179490270358738, -0.873058811259344, 0.39876672713807215, -0.1816769383004213, -0.13584645926733696, -0.7614673836659709, -0.2535825872616164, -0.05265760030895916, 0.28558728305658754, 0.15404213943520118, -0.1634824450738006, -0.5068199082730189, -0.026961256849111326, -0.1508441942453307, 0.9375335677537737, 0.3304690744194263, -0.5091563780251127, 0.029887336278646925, -0.39540496207319276, 0.46094065001445084}; - Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast(weight_t->Data())); - auto weight_quant_arg = new mindspore::lite::tensor::QuantArg(); - weight_quant_arg->zeroPoint = weight_zp; - weight_quant_arg->scale = weight_scale; - weight_t->AddQuantParam(*weight_quant_arg); - inputs_->push_back(weight_t); + TensorInfo weight_params; + weight_params.data = weight; + weight_params.len = 24; + std::vector weight_shape{1, 3, 8}; + weight_params.shape = &weight_shape; + weight_params.min = -1; + weight_params.max = 1; - auto out_t = - new lite::tensor::Tensor(kNumberTypeInt8, {1, 2, 3}, schema::Format_NHWC, static_cast(1)); - out_t->MallocData(); - auto output_quant_arg = new mindspore::lite::tensor::QuantArg(); - output_quant_arg->zeroPoint = output_zp; - output_quant_arg->scale = output_scale; - out_t->AddQuantParam(*output_quant_arg); - outputs_->push_back(out_t); + float correct[] = {-0.912632942, 4.08398056, -25.385608673, 2.720281124, 7.745952606, 20.893184662}; + TensorInfo out_params; + out_params.data = correct; + out_params.len = 6; + std::vector out_shape{1, 2, 3}; + out_params.shape = &out_shape; + out_params.min = -30; + out_params.max = 30; - *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(float))); - float nchw_co[] = {-0.912632942, 4.08398056, -25.385608673, 2.720281124, 7.745952606, 20.893184662}; - memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float)); - - matmal_param->b_transpose_ = true; - matmal_param->a_transpose_ = false; - matmal_param->has_bias_ = false; - return out_t->ElementsNum(); -} - -TEST_F(TestMatmulInt8, mmint8) { - std::vector inputs_; - std::vector outputs_; auto matmul_param = new MatMulParameter(); - float *correct; - double output_scale; - int output_zp; - int total_size = MMInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp); + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = true; + matmul_param->has_bias_ = false; + std::vector inputs; + std::vector outputs; + MMInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &out_params); auto ctx = new lite::Context; - ctx->thread_num_ = 2; + ctx->thread_num_ = 1; kernel::MatmulInt8CPUKernel *mm = - new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); + new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs, outputs, ctx, nullptr); mm->Init(); mm->Run(); - float fout[6] = {0}; - Dequantize(reinterpret_cast(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp, - fout); - CompareOutputData(fout, correct, 6, 0.3); + float out_scale; + int out_zp; + QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr); + float *out = new float[out_params.len]; + Dequantize(reinterpret_cast(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out); + CompareOutputData(out, correct, 6, 0.3); delete mm; - for (auto t : inputs_) delete t; - for (auto t : outputs_) delete t; - free(correct); + for (auto t : inputs) delete t; + for (auto t : outputs) delete t; + delete[] out; +} + +TEST_F(TestMatmulInt8, mmtest2) { + float in[] = { + -9.302902352910598, 16.65876088354537, -7.2801759810348265, -6.3246021711950995, 8.467234093555248, + -4.729482636552028, -3.747183865378627, -8.690477390174504, -2.7419930714530523, -3.9478573566319, + 7.399137633080947, -1.604450983941291, 0.3115665358682982, -16.864318496334278, 2.5447052588244112, + -13.428639671203255, 13.417832391771974, 10.37917002467671, 14.709787234172168, -16.347969268427146, + 4.652834783979106, 6.03601450738973, 2.5788179666401874, -9.236801653471375, -0.18997468903009462, + 19.977363387313744, 15.163337058447325, -12.602897730843484, -6.178507797555191, 13.457928661476004, + -10.65587824516124, -18.715557779424188, -9.758039647923935, 8.102044210643097, 19.66309736072973, + -13.368041407254193, 9.928467253978024, 4.9981961360698755, -4.2838547685981645, 1.5021547181513526, + -7.043468062239523, 11.964494917194845, -4.783071964346499, -17.646518743891008, -7.77810768119101, + 14.869414292570454, 8.333036603520906, 11.053769742928765, -1.768128637419725, -14.971400302494597, + -0.8653626097283293, -6.21101640878031, 14.83875267850518, 7.224097292538833, -16.747116419664213, + 15.310978507353724, -0.05593751363976551, 2.066880260042687, -3.7053788331264137, 9.788933831258937, + -13.614523856950811, -5.656865231633642, 4.720270075138178, -8.366650073458409, 7.197187069893303, + -18.78518907850054, 15.691652153539678, 7.914926057095165, 10.073559408864384, 10.437631177498353, + -3.0580194164595085, 17.36998905922836, 0.09998119223460122, 19.519199178417452, -11.121833210377702, + 19.655990774915622, -17.25682638091008, 11.701013896880006, -12.746728025401781, -9.370055221833699, + 18.720474512055908, 7.634198897927405, -15.521885320500694, -9.119267877304358, -1.5853789671841945, + 4.783147823043613, 14.6732610092525, -9.294170215010427, 9.835421489234331, 13.051159704193232, + -1.422599906517025, -1.5530696181467398, 19.51404609713284, -12.297429715833763, 6.8811248552401985, + 13.052476234003755, 18.66085390709462, -8.097735292301103, -6.868239274661935, -8.067142805841826, + 3.2707808734101533, 1.8239332220210827}; + TensorInfo in_params; + in_params.data = in; + in_params.len = 6 * 17; + std::vector in_shape{1, 6, 17}; + in_params.shape = &in_shape; + in_params.min = -20; + in_params.max = 20; + + float weight[] = { + -0.42740096214251677, 0.8557068789482212, 0.4560574664172552, -0.1317821769705021, 0.2845963675712846, + 0.8414603241768241, 0.24513271080109011, 0.16403708196683398, -0.09111601416189297, -0.714027790956111, + 0.12253431683185845, -0.4542459426686125, 0.7123202105555202, -0.3708573394849488, -0.4571735646072892, + -0.595627630450934, -0.5022671357384993, 0.2781065609468565, -0.07586181451887586, -0.2667701710291306, + 0.03141663091360791, -0.013304592900917456, -0.7507975439396768, 0.5886778622432618, -0.9056075431439199, + 0.9393767525356569, -0.2791312477047512, 0.7134531940450286, 0.3977932134993216, -0.027832574334469395, + 0.7222024948455503, -0.2084178952731608, -0.4869535410639745, -0.8255185994321805, 0.975443145421772, + 0.541914384763855, -0.8831162309708303, -0.3339354888475805, 0.3699271440691516, -0.26923635397292944, + -0.4975347179262828, 0.2440013185603882, 0.5553443771246633, 0.6111909921005778, -0.5968624036034165, + 0.8367593317557596, -0.843079440282104, -0.5651924211153698, 0.7169318662247579, 0.5116755837443465, + -0.9079299375502927, 0.025240632113315176, -0.5819662075810048, -0.37278414060319176, -0.172154755034845, + -0.7372352723583462, 0.2462103743741677, 0.11785417820789856, 0.6712183976911841, -0.7042964391243491, + -0.8215958062965967, -0.7304378130182314, 0.3991295415760667, -0.07226694075875573, 0.9329628273800614, + 0.7866596674858193, 0.9410341281569592, 0.39672750454198225, -0.5217505454791054, 0.9538253510722774, + -0.6286845762774464, -0.773460418882959, 0.002296000778892804, 0.9763898918063998, 0.9648708739062339, + 0.9400037814137154, -0.6011085333221611, -0.5890262409238565, -0.8078857772627164, 0.233661306598278, + -0.6726381934018617, -0.08533323149874539, 0.19055766469859425, -0.7956482347958518, -0.17012651641579035, + 0.7181052528631318, 0.1285045774388125, -0.6997527417326721, -0.8436484573035989, 0.342855467305474, + 0.4085157503460306, -0.6199324510955382, -0.6883822276097309, 0.4186437018431113, 0.3030114883148305, + 0.0948227655828271, -0.002521771948760465, -0.34878560791422397, 0.08513437045281003, 0.3116035319055901, + -0.7177514192203747, 0.050531673446029046, -0.7399803440665007, -0.9353609485885221, -0.3899340891814298, + 0.40867084031625356, -0.17462484099335662, -0.6313167634279941, -0.8135597146296727, -0.9762553414099975, + -0.1040485487920626, -0.6517520252975368, 0.5877412140956126, 0.9433584450325512, 0.24701546283170672, + -0.3236849444311023, -0.12043548611719657, 0.5300129281052712, -0.1380138229226111, -0.8787455295545508, + -0.4361728423289617, 0.7331994894985936, 0.45492774136929826, -0.17836517403432972, 0.10896668585054625, + 0.6176507847785211, 0.21617962964770676, -0.6821928873814629, 0.021775035324277825, 0.15089571088539566, + -0.9923383126255942, -0.6034706970202426, 0.17729888871670285, 0.1278810065499425, -0.6575545415840387, + -0.022704865415375197, -0.7366071817901978, -0.9300211224192332, -0.153494127035938, 0.4836121912045357, + -0.3318483587414114, -0.9658468087620375, 0.8388464445207262, 0.45745949405796127, -0.3671803281863002, + -0.1543498074773253, 0.18955899788963748, -0.4452120359256351, -0.5338599486040962, -0.06979561022721281, + -0.45964195574917355, -0.4343754114042866, -0.4318308749403197, 0.748107130947133, -0.4703901010752156, + 0.6655596561650823, 0.9075215202451821, 0.2708741258104177, -0.6540233471632313, 0.7250124906689572, + 0.6674821078610087, 0.8464696566759315, -0.6106156844283976, 0.8675828337337224, 0.8517737949695063, + -0.8126381016475459, -0.6140987457462099, -0.2984524227549874, 0.2816320572339577, -0.8131479383469931}; + TensorInfo weight_params; + weight_params.data = weight; + weight_params.len = 170; + std::vector weight_shape{1, 17, 10}; + weight_params.shape = &weight_shape; + weight_params.min = -1; + weight_params.max = 1; + + float correct[] = {35.815605, 26.532362, 14.777507, -12.651591, -2.0373726, -47.020798, -18.53121, 2.7848654, + 16.19751, -30.754261, 25.830605, 47.635204, 10.247462, -33.260662, 34.145412, -6.1611304, + -18.56802, -24.669813, 20.314533, -5.887198, -14.757037, 24.78901, 20.512205, 17.985718, + 17.62954, 20.365099, -26.223736, 0.99702793, 12.752281, -35.30419, -22.09603, 8.2218, + 8.120908, 27.685753, -44.010464, -1.879332, -4.531702, 21.434296, 4.2146144, 22.721859, + 7.485317, 20.148363, -15.49375, -4.5062046, 37.77292, -0.23385821, -45.532917, -21.055403, + 46.854183, -13.595161, 2.8823144, -23.905682, 2.3569264, 26.975227, 32.806625, 9.185071, + -39.330578, -1.0041192, -6.8353715, -33.2658}; + TensorInfo out_params; + out_params.data = correct; + out_params.len = 60; + std::vector out_shape{1, 6, 10}; + out_params.shape = &out_shape; + out_params.min = -50; + out_params.max = 50; + + auto matmul_param = new MatMulParameter(); + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = false; + matmul_param->has_bias_ = false; + std::vector inputs; + std::vector outputs; + MMInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &out_params); + auto ctx = new lite::Context; + ctx->thread_num_ = 2; + kernel::MatmulInt8CPUKernel *mm = + new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs, outputs, ctx, nullptr); + + mm->Init(); + mm->Run(); + float out_scale; + int out_zp; + QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr); + float *out = new float[out_params.len]; + Dequantize(reinterpret_cast(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out); + CompareOutputData(out, correct, 6, 0.6); + delete mm; + for (auto t : inputs) delete t; + for (auto t : outputs) delete t; + delete[] out; } } // namespace mindspore