!26050 [MS][LITE]Add int8 clip

Merge pull request !26050 from gongdaguo/add_int8_clip
i-robot 2021-11-15 11:53:28 +00:00 committed by Gitee
commit a14a777464
13 changed files with 887 additions and 942 deletions

@@ -46,7 +46,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:AvgPoolingOptInt8
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:MaxPoolingWithQuantInt8
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPeroc
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.c:RegisterInfer
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col12MajorStride
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col8MajorStride
@@ -65,8 +65,8 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8InputUnit
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8FilterTransform
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPeroc
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPert
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel

@@ -30,12 +30,27 @@ endif()
file(GLOB KERNEL_SRC
${NNACL_DIR}/*.c
${NNACL_DIR}/fp32/*.c
${NNACL_DIR}/int8/*.c
${NNACL_DIR}/infer/*.c
${NNACL_DIR}/base/*.c
${NNACL_DIR}/fp32_grad/*.c
)
if(OP_INT8_CLIP)
set(KERNEL_SRC
${KERNEL_SRC}
${NNACL_DIR}/int8/pack_int8.c
${NNACL_DIR}/int8/quantize.c
)
else()
file(GLOB KERNEL_SRC_INT8
${NNACL_DIR}/int8/*.c
)
set(KERNEL_SRC
${KERNEL_SRC}
${KERNEL_SRC_INT8}
)
endif()
if(MSLITE_ENABLE_SPARSE_COMPUTE)
file(GLOB KERNEL_SRC_SPARSE
${NNACL_DIR}/fp32_sparse/*.c

@@ -16,6 +16,818 @@
#include "nnacl/int8/conv_int8.h"
#ifdef ENABLE_ARM32
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM32
size_t oc_div2 = output_channel / C2NUM * C2NUM;
size_t oc_res2 = output_channel - oc_div2;
size_t inputsum_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsum_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci2div = ci / C2NUM, ci2mod = ci % C2NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}
#endif
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM64
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsum_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsum_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}
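For reference, a minimal scalar sketch (illustration only, not part of this patch) of what the per-channel input sum computes once the 16x4 packing is stripped away: one zero-point-weighted row sum per (pixel, output channel) pair. The plain row-major input layout below is an assumption made for clarity.

#include <stdint.h>

/* Hedged reference sketch, assuming an unpacked row-major int8 input:
 * input_sum[hw][oc] = filter_zp[oc] * sum over ic of input[hw][ic]. */
void PerChannelInputSumRef(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp, int plane_size,
                           int input_channel, int output_channel) {
  for (int hw = 0; hw < plane_size; hw++) {
    int32_t row_sum = 0;
    for (int ic = 0; ic < input_channel; ic++) {
      row_sum += input[hw * input_channel + ic];
    }
    for (int oc = 0; oc < output_channel; oc++) {
      input_sum[hw * output_channel + oc] = row_sum * filter_zp[oc];
    }
  }
}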
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) {
int ic4 = UP_ROUND(input_channel, C4NUM);
int oc8 = UP_ROUND(output_channel, C8NUM);
int hw8 = UP_ROUND(plane_size, C8NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t oc_8div = output_channel / C8NUM * C8NUM;
size_t oc_8res = output_channel - oc_8div;
size_t ic_4div = input_channel / C4NUM * C4NUM;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
int32_t *input_sum_r = input_sum;
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"dup v0.4s, v16.s[0] \n"
"dup v1.4s, v16.s[1] \n"
"dup v2.4s, v16.s[2] \n"
"dup v3.4s, v16.s[3] \n"
"dup v4.4s, v17.s[0] \n"
"dup v5.4s, v17.s[1] \n"
"dup v6.4s, v17.s[2] \n"
"dup v7.4s, v17.s[3] \n"
"mov x4, #0 \n"
"mov x10, %[filter_zp] \n"
"mov x11, %[input_sum_oc] \n"
"7: \n"
"cmp x4, %[oc_8div] \n"
"beq 8f \n"
"add x4, x4, #8\n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.4s}, [x10], #16\n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"add x11, x11, %[input_sum_stride] \n"
"b 7b \n"
"8: \n"
"cmp %[oc_8res], #0\n"
"beq 17f \n"
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"cmp %[oc_8res], #1\n"
"beq 9f \n"
"cmp %[oc_8res], #2\n"
"beq 10f \n"
"cmp %[oc_8res], #3\n"
"beq 11f \n"
"cmp %[oc_8res], #4\n"
"beq 12f \n"
"cmp %[oc_8res], #5\n"
"beq 13f \n"
"cmp %[oc_8res], #6\n"
"beq 14f \n"
"cmp %[oc_8res], #7\n"
"beq 15f \n"
"9: \n"
"ld1 {v16.s}[0], [x10] \n"
"b 16f \n"
"10: \n"
"ld1 {v16.d}[0], [x10] \n"
"b 16f \n"
"11: \n"
"ld1 {v16.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v16.s}[2], [x10] \n"
"b 16f \n"
"12: \n"
"ld1 {v16.4s}, [x10] \n"
"b 16f \n"
"13: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.s}[0], [x10] \n"
"b 16f \n"
"14: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.d}[0], [x10] \n"
"b 16f \n"
"15: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v17.s}[2], [x10] \n"
"b 16f \n"
"16: \n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"17: \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
[ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
[ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
: "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
}
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = 0;
}
}
} /* oc8 res done */
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
input_sum_r += C8NUM * C8NUM;
}
if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t *input_sum_oc = input_sum_r;
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int curoi = 0; curoi < C8NUM; curoi++) {
input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
input_sum_oc[oci] = 0;
}
} /* oc8 res done */
src_r += input_channel;
pack_r += C4NUM;
input_sum_r += C8NUM;
}
for (int hwi = plane_size; hwi < hw8; hwi++) {
for (int oc = 0; oc < oc8; oc++) {
int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
}
}
}
return;
}
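A hedged usage sketch for the per-channel pre-pack (function name and allocation scheme are illustrative, not from this patch): the packed input is padded to hw8 x ic4 bytes, and input_sum holds oc8 / C8NUM blocks of hw8 x C8NUM int32 values, which is why hw8 * C8NUM is passed as the per-block stride below.

#include <stdlib.h>
#include "nnacl/int8/conv_int8.h"

void Conv1x1PerocExample(const int8_t *src, size_t ic, size_t oc, size_t hw, const int32_t *filter_zp) {
  size_t ic4 = UP_ROUND(ic, C4NUM);
  size_t oc8 = UP_ROUND(oc, C8NUM);
  size_t hw8 = UP_ROUND(hw, C8NUM);
  int8_t *packed = (int8_t *)malloc(hw8 * ic4);
  int32_t *input_sum = (int32_t *)malloc(hw8 * oc8 * sizeof(int32_t));
  if (packed == NULL || input_sum == NULL) {
    free(packed);
    free(input_sum);
    return;
  }
  /* hw8 * C8NUM is the stride of one oc8 block of sums; an assumption that
   * matches the zero-fill indexing input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] above */
  Conv1x1PreOptPeroc(src, packed, input_sum, ic, oc, hw, filter_zp, hw8 * C8NUM);
  /* ... feed packed / input_sum to the int8 1x1 matmul ... */
  free(packed);
  free(input_sum);
}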
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, const ConvParameter *conv_param) {
int ic4 = UP_ROUND(input_channel, C4NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"mul v16.4s, v16.4s, v20.4s \n"
"mul v17.4s, v17.4s, v20.4s \n"
"st1 {v16.4s}, [x14], #16 \n"
"st1 {v17.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
}
if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
input_sum[hwi] = 0;
}
}
return;
}
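The per-tensor counterpart needs only one sum per output pixel, since every output channel shares a single filter zero point; a similarly hedged sketch (conv_param setup elided):

#include <stdlib.h>
#include "nnacl/int8/conv_int8.h"

void Conv1x1PertExample(const int8_t *src, size_t ic, size_t hw, const ConvParameter *conv_param) {
  size_t ic4 = UP_ROUND(ic, C4NUM);
  size_t hw8 = UP_ROUND(hw, C8NUM);
  int8_t *packed = (int8_t *)malloc(hw8 * ic4);
  int32_t *input_sum = (int32_t *)malloc(hw8 * sizeof(int32_t)); /* one sum per pixel */
  if (packed == NULL || input_sum == NULL) {
    free(packed);
    free(input_sum);
    return;
  }
  Conv1x1PreOptPert(src, packed, input_sum, ic, hw, conv_param);
  free(packed);
  free(input_sum);
}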
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
const ConvParameter *conv_param) {
size_t hw = conv_param->output_h_ * conv_param->output_w_;
size_t hw4 = UP_ROUND(hw, C4NUM);
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
#ifdef ENABLE_ARM32
PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_,
conv_param->output_channel_);
#else
PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_,
conv_param->output_channel_);
#endif
}
return;
}
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, const int32_t *filter_zp, int32_t *input_sum,
const ConvParameter *conv_param, bool per_channel, bool is_optimize) {
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int kernel_plane = kernel_h * kernel_w;
NNACL_CHECK_ZERO_RETURN(out_w);
NNACL_CHECK_ZERO_RETURN(dilation_h);
NNACL_CHECK_ZERO_RETURN(dilation_w);
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
int input_stride = input_h * in_w * in_channel + input_w * in_channel;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
if (dilation_w == 1 && dilation_h == 1) {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * in_w * in_channel + input_stride;
int input_x_stride = input_y_stride + kw_s * in_channel;
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
} // kernel_h loop
} else {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int k = kw_s; k < kw_e; ++k) {
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
}
} // kernel_h loop
}
} // tile num loop
int deep = kernel_plane * in_channel;
if (is_optimize) {
if (per_channel) {
Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num,
filter_zp, C8NUM * C8NUM);
} else {
Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param);
}
} else {
RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep);
if (per_channel) {
#ifdef ENABLE_ARM32
PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep,
conv_param->output_channel_);
#else
PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_);
#endif
} else {
size_t hw4 = UP_ROUND(real_cal_num, C4NUM);
size_t ic16 = UP_ROUND(deep, C16NUM);
PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4,
ic16);
}
}
}
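A hedged sketch of the tiling loop a convolution kernel might use to drive Im2ColPackUnitInt8Opt, one C8NUM-pixel tile per call; the loop shape and buffer handling are assumptions for illustration, not code from this patch.

void Im2ColTilesExample(const int8_t *input, int8_t *packed, int8_t *matmul_buf, const int32_t *filter_zp,
                        int32_t *input_sum, const ConvParameter *conv_param) {
  int out_plane = conv_param->output_h_ * conv_param->output_w_;
  for (int index = 0; index < out_plane; index += C8NUM) {
    int real_cal_num = MSMIN(out_plane - index, C8NUM);
    /* per-channel quantization on the optimized matmul path */
    Im2ColPackUnitInt8Opt(input, packed, matmul_buf, real_cal_num, index, filter_zp, input_sum, conv_param,
                          true, true);
    /* ... run the int8 matmul and requantize this tile ... */
  }
}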
void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) {

@@ -16,799 +16,6 @@
#include "nnacl/int8/pack_int8.h"
#ifdef ENABLE_ARM32
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM32
size_t oc_div2 = output_channel / C2NUM * C2NUM;
size_t oc_res2 = output_channel - oc_div2;
size_t inputsum_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsum_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci2div = ci / C2NUM, ci2mod = ci % C2NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}
#endif
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM64
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsum_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsum_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) {
int ic4 = UP_ROUND(input_channel, C4NUM);
int oc8 = UP_ROUND(output_channel, C8NUM);
int hw8 = UP_ROUND(plane_size, C8NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t oc_8div = output_channel / C8NUM * C8NUM;
size_t oc_8res = output_channel - oc_8div;
size_t ic_4div = input_channel / C4NUM * C4NUM;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
int32_t *input_sum_r = input_sum;
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"dup v0.4s, v16.s[0] \n"
"dup v1.4s, v16.s[1] \n"
"dup v2.4s, v16.s[2] \n"
"dup v3.4s, v16.s[3] \n"
"dup v4.4s, v17.s[0] \n"
"dup v5.4s, v17.s[1] \n"
"dup v6.4s, v17.s[2] \n"
"dup v7.4s, v17.s[3] \n"
"mov x4, #0 \n"
"mov x10, %[filter_zp] \n"
"mov x11, %[input_sum_oc] \n"
"7: \n"
"cmp x4, %[oc_8div] \n"
"beq 8f \n"
"add x4, x4, #8\n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.4s}, [x10], #16\n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"add x11, x11, %[input_sum_stride] \n"
"b 7b \n"
"8: \n"
"cmp %[oc_8res], #0\n"
"beq 17f \n"
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"cmp %[oc_8res], #1\n"
"beq 9f \n"
"cmp %[oc_8res], #2\n"
"beq 10f \n"
"cmp %[oc_8res], #3\n"
"beq 11f \n"
"cmp %[oc_8res], #4\n"
"beq 12f \n"
"cmp %[oc_8res], #5\n"
"beq 13f \n"
"cmp %[oc_8res], #6\n"
"beq 14f \n"
"cmp %[oc_8res], #7\n"
"beq 15f \n"
"9: \n"
"ld1 {v16.s}[0], [x10] \n"
"b 16f \n"
"10: \n"
"ld1 {v16.d}[0], [x10] \n"
"b 16f \n"
"11: \n"
"ld1 {v16.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v16.s}[2], [x10] \n"
"b 16f \n"
"12: \n"
"ld1 {v16.4s}, [x10] \n"
"b 16f \n"
"13: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.s}[0], [x10] \n"
"b 16f \n"
"14: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.d}[0], [x10] \n"
"b 16f \n"
"15: \n"
"ld1 {v16.4s}, [x10], #16\n"
"ld1 {v17.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v17.s}[2], [x10] \n"
"b 16f \n"
"16: \n"
"mul v18.4s, v16.4s, v0.4s \n"
"mul v19.4s, v17.4s, v0.4s \n"
"mul v20.4s, v16.4s, v1.4s \n"
"mul v21.4s, v17.4s, v1.4s \n"
"mul v22.4s, v16.4s, v2.4s \n"
"mul v23.4s, v17.4s, v2.4s \n"
"mul v24.4s, v16.4s, v3.4s \n"
"mul v25.4s, v17.4s, v3.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"mul v18.4s, v16.4s, v4.4s \n"
"mul v19.4s, v17.4s, v4.4s \n"
"mul v20.4s, v16.4s, v5.4s \n"
"mul v21.4s, v17.4s, v5.4s \n"
"mul v22.4s, v16.4s, v6.4s \n"
"mul v23.4s, v17.4s, v6.4s \n"
"mul v24.4s, v16.4s, v7.4s \n"
"mul v25.4s, v17.4s, v7.4s \n"
"st1 {v18.4s}, [x11], #16 \n"
"st1 {v19.4s}, [x11], #16 \n"
"st1 {v20.4s}, [x11], #16 \n"
"st1 {v21.4s}, [x11], #16 \n"
"st1 {v22.4s}, [x11], #16 \n"
"st1 {v23.4s}, [x11], #16 \n"
"st1 {v24.4s}, [x11], #16 \n"
"st1 {v25.4s}, [x11], #16 \n"
"17: \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
[ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
[ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
: "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
}
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = 0;
}
}
} /* oc8 res done */
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
input_sum_r += C8NUM * C8NUM;
}
if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t *input_sum_oc = input_sum_r;
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int curoi = 0; curoi < C8NUM; curoi++) {
input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
input_sum_oc[oci] = 0;
}
} /* oc8 res done */
src_r += input_channel;
pack_r += C4NUM;
input_sum_r += C8NUM;
}
for (int hwi = plane_size; hwi < hw8; hwi++) {
for (int oc = 0; oc < oc8; oc++) {
int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
}
}
}
return;
}
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, const ConvParameter *conv_param) {
int ic4 = UP_ROUND(input_channel, C4NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v16.4s, wzr \n"
"dup v17.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v16.4s, v16.4s, v0.4s \n"
"add v17.4s, v17.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"mul v16.4s, v16.4s, v20.4s \n"
"mul v17.4s, v17.4s, v20.4s \n"
"st1 {v16.4s}, [x14], #16 \n"
"st1 {v17.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
}
if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
input_sum[hwi] = 0;
}
}
return;
}
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, const int32_t *filter_zp, int32_t *input_sum,
const ConvParameter *conv_param, bool per_channel, bool is_optimize) {
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int kernel_plane = kernel_h * kernel_w;
NNACL_CHECK_ZERO_RETURN(out_w);
NNACL_CHECK_ZERO_RETURN(dilation_h);
NNACL_CHECK_ZERO_RETURN(dilation_w);
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
int input_stride = input_h * in_w * in_channel + input_w * in_channel;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
if (dilation_w == 1 && dilation_h == 1) {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * in_w * in_channel + input_stride;
int input_x_stride = input_y_stride + kw_s * in_channel;
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
} // kernel_h loop
} else {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int k = kw_s; k < kw_e; ++k) {
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
}
} // kernel_h loop
}
} // tile num loop
int deep = kernel_plane * in_channel;
if (is_optimize) {
if (per_channel) {
Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num,
filter_zp, C8NUM * C8NUM);
} else {
Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param);
}
} else {
RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep);
if (per_channel) {
#ifdef ENABLE_ARM32
PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep,
conv_param->output_channel_);
#else
PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_);
#endif
} else {
size_t hw4 = UP_ROUND(real_cal_num, C4NUM);
size_t ic16 = UP_ROUND(deep, C16NUM);
PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4,
ic16);
}
}
}
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param) {
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
@@ -908,25 +115,6 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
}
}
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
const ConvParameter *conv_param) {
size_t hw = conv_param->output_h_ * conv_param->output_w_;
size_t hw4 = UP_ROUND(hw, C4NUM);
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
#ifdef ENABLE_ARM32
PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_,
conv_param->output_channel_);
#else
PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_,
conv_param->output_channel_);
#endif
}
return;
}
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* normal matmul : 4x16 * 16x4 -> 4x4 */
#ifdef ENABLE_ARM
@@ -1081,25 +269,6 @@ void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
}
}
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) {
int c8 = UP_DIV(channel, C8NUM);
for (int b = 0; b < batch; b++) {
int src_offset = b * plane * channel;
int dst_offset = b * plane * c8 * C8NUM;
for (int c = 0; c < channel; c++) {
int c8_block_num = c / C8NUM;
int c8_block_rem = c % C8NUM;
int src_c_offset = src_offset + c * plane;
int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_c_offset + k;
int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem;
((int8_t *)dst + dst_kernel_offset)[0] = ((int8_t *)src + src_kernel_offset)[0];
}
}
}
}
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
@@ -1127,21 +296,6 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
}
}
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int hw = 0; hw < plane; hw++) {
for (int c = 0; c < channel; c++) {
int c8div = c / C8NUM;
int c8mod = c % C8NUM;
int src_index = n * plane * channel + hw * channel + c;
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index];
}
}
}
return;
}
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {

@@ -30,20 +30,13 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
const ConvParameter *conv_param);
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param);
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, const ConvParameter *conv_param);
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
int block_index, const int32_t *filter_zp, int32_t *input_sum,
const ConvParameter *conv_param, bool per_channel, bool is_optimize);
#ifdef ENABLE_ARM
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, const int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,

@@ -27,6 +27,7 @@ option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
option(MSLITE_ENABLE_V0 "support v0 schema" on)
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
option(MSLITE_ENABLE_ACL "enable ACL" off)
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on)
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
@@ -96,6 +97,14 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_FP16})
set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
endif()
if(DEFINED ENV{MSLITE_ENABLE_INT8})
set(MSLITE_ENABLE_INT8 $ENV{MSLITE_ENABLE_INT8})
endif()
if(MSLITE_ENABLE_INT8)
set(OP_INT8_CLIP off)
else()
set(OP_INT8_CLIP on)
endif()
if(DEFINED ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
set(MSLITE_ENABLE_SPARSE_COMPUTE $ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
endif()
@@ -236,6 +245,7 @@ message(STATUS "\tBUILD_MINDDATA = \t${BUILD_MINDDATA}")
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")
message(STATUS "\tMSLITE_ENABLE_SPARSE_COMPUTE = \t${MSLITE_ENABLE_SPARSE_COMPUTE}")
message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_CONVERT}")
@@ -261,8 +271,9 @@ endif()
if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
message(FATAL_ERROR "If you want to build fp16 in arm82_a32, \
message(STATUS "If you want to build fp16 in arm82_a32, \
your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android ndk r21e!")
set(MSLITE_ENABLE_FP16 off)
endif()
if(MSLITE_ENABLE_HIGH_PERFORMANCE)

@@ -27,6 +27,9 @@ endif()
if(NOT MSLITE_ENABLE_DELEGATE)
add_compile_definitions(DELEGATE_CLIP)
endif()
if(NOT MSLITE_ENABLE_INT8)
add_compile_definitions(OP_INT8_CLIP)
endif()
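Once OP_INT8_CLIP is defined for builds without int8 support, C/C++ sources can compile out int8-only paths with an ordinary preprocessor guard, as the scheduler and detection-post-process hunks below do. A minimal hedged sketch (the function is illustrative, not from this patch):

#ifndef OP_INT8_CLIP
int RunInt8Kernel(void) { return 0; }  /* int8 kernels compiled in */
#else
int RunInt8Kernel(void) { return -1; } /* int8 support clipped out */
#endif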
if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64)
#for performance

@@ -54,6 +54,10 @@ const char *const unsupport_fp16_log =
"The mindspore-lite library does not support fp16. Set environment variable "
"MSLITE_ENABLE_FP16 to on to "
"recompile it.";
const char *const unsupport_int8_log =
"The mindspore-lite library does not support int8. Set environment variable "
"MSLITE_ENABLE_INT8 to on to "
"recompile it.";
static inline bool IsPrintDebug() {
auto env = std::getenv("GLOG_v");

View File

@ -3,9 +3,21 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
)
if(MSLITE_ENABLE_INT8)
file(GLOB INT8_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
)
set(KERNEL_SRC
${KERNEL_SRC}
${INT8_KERNEL_SRC}
)
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
else()
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/base/quant_dtype_cast.cc)
endif()
if(MSLITE_ENABLE_SPARSE_COMPUTE)
file(GLOB SPARSE_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fp32_sparse/*.cc
@ -34,7 +46,6 @@ if(MSLITE_ENABLE_CONTROLFLOW)
${KERNEL_CONTROL_TENSORLIST}
)
endif()
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
if(SUPPORT_TRAIN)
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)

View File

@ -51,7 +51,16 @@ int DetectionPostProcessBaseCPUKernel::Prepare() {
auto anchor_tensor = in_tensors_.at(2);
MS_CHECK_GT(anchor_tensor->ElementsNum(), 0, RET_ERROR);
CHECK_NULL_RETURN(anchor_tensor->data());
if (anchor_tensor->data_type() == kNumberTypeInt8) {
if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (params_->anchors_ == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR;
}
MS_CHECK_FALSE(anchor_tensor->Size() == 0, RET_ERROR);
memcpy(params_->anchors_, anchor_tensor->data(), anchor_tensor->Size());
#ifndef OP_INT8_CLIP
} else if (anchor_tensor->data_type() == kNumberTypeInt8) {
auto quant_param = anchor_tensor->quant_params().front();
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->data());
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()];
@ -73,14 +82,7 @@ int DetectionPostProcessBaseCPUKernel::Prepare() {
DoDequantizeUInt8ToFp32(anchor_uint8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
anchor_tensor->ElementsNum());
params_->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (params_->anchors_ == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR;
}
MS_CHECK_FALSE(anchor_tensor->Size() == 0, RET_ERROR);
memcpy(params_->anchors_, anchor_tensor->data(), anchor_tensor->Size());
#endif
} else {
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type();
return RET_ERROR;
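The int8/uint8 anchor branches above rely on affine dequantization,
fp32 = scale * (q - zero_point). A standalone sketch of that conversion, written
for this note with a hypothetical name; it mirrors the semantics the
DoDequantizeInt8ToFp32 call provides, not the library source:

#include <stdint.h>

/* Affine dequantization: fp32 = scale * (q - zero_point). */
void DequantizeInt8ToFp32Ref(const int8_t *in, float *out, float scale, int32_t zero_point, int num) {
  for (int i = 0; i < num; ++i) {
    out[i] = scale * (float)(in[i] - zero_point);
  }
}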

View File

@ -1111,6 +1111,11 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
}
}
#ifdef OP_INT8_CLIP
if (desc.data_type == kNumberTypeInt8) {
MS_LOG(ERROR) << unsupport_int8_log;
}
#endif
return nullptr;
}

View File

@ -28,10 +28,8 @@ file(GLOB_RECURSE TEST_UT_SRC
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
${TEST_DIR}/st/mix_data_type_test.cc
${TEST_DIR}/ut/nnacl/infer/*.cc
${TEST_DIR}/ut/nnacl/int8/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
)
@ -68,6 +66,14 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
endif()
endif()
if(MSLITE_ENABLE_INT8)
file(GLOB_RECURSE TEST_INT8_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
${TEST_DIR}/ut/nnacl/int8/*.cc
)
list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC})
endif()
if(MSLITE_ENABLE_FP16)
file(GLOB_RECURSE TEST_FP16_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc

View File

@ -35,77 +35,6 @@ class TestDeconvInt8 : public mindspore::CommonTest {
TestDeconvInt8() {}
};
TEST_F(TestDeconvInt8, PackWeight1) {
int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100,
72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90,
4, -106, 105, 28, -61, -79, 55, -54, 47, -38, 114, 125, -65, 100, 6, -72, -33, 60, 109, -68};
int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110, 37, -72, 65,
-124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38, 114, 125, -65, 100, 0, 0,
37, -45, 31, -69, -66, 26, 0, 0, -46, 100, 72, -36, -82, 64, 0, 0, -43, 99, 3, 100,
19, 51, 0, 0, 105, 28, -61, -79, 55, -54, 0, 0, 6, -72, -33, 60, 109, -68, 0, 0};
int8_t dst[80] = {0};
/*5*1*2*6 nhwc*/
PackNHWCToC8HWN8Int8(in, dst, 5, 2, 6);
ASSERT_EQ(0, CompareOutputData(dst, co, 80, 1));
}
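For reference, the index mapping PackWeight1 exercises: NHWC weights are regrouped
into channel blocks of 8 (outermost), then plane, then batch, with the tail channel
block zero-padded. A plain-C reconstruction inferred from the test vectors above,
offered as an illustration rather than the nnacl source:

#include <stdint.h>
#include <string.h>

#define C8NUM 8

void PackNHWCToC8HWN8Int8Ref(const int8_t *src, int8_t *dst, int batch, int plane, int channel) {
  int c8 = (channel + C8NUM - 1) / C8NUM;  /* channel blocks; tail block stays zero */
  memset(dst, 0, (size_t)(c8 * C8NUM * plane * batch));
  for (int n = 0; n < batch; ++n) {
    for (int hw = 0; hw < plane; ++hw) {
      for (int c = 0; c < channel; ++c) {
        int dst_index = (c / C8NUM) * plane * batch * C8NUM + hw * batch * C8NUM + n * C8NUM + (c % C8NUM);
        dst[dst_index] = src[n * plane * channel + hw * channel + c];
      }
    }
  }
}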
TEST_F(TestDeconvInt8, PackWeight2) {
int8_t in[] = {
40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,
-22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93,
21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91,
4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25,
113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65,
-1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52,
104, -52, -91, 76, 79, 105, 102, -65, 43, 32, 13, 15, -38, 95, -18, -82, -7, 118, -79, -85,
120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54, 63, 111, -16, 92, 82, -23, 111, 53,
1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31, 94, 101, -10, 18, 0, -49, 108, 28,
-36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66, 99, -121, -107, 31, -38, 56, -30, 109,
-7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71, -54, 20, -45, 109, -42, 78, -79, 98,
-10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92, 24, 55, 4, -110, -37, 112, -18, 10,
-42, 16, -9, 31, 39, -70, 108, -3, -90, -60, -121, 11, 50, -88, -104, -29, -89, 94, 64, -91,
-101, -7, 23, -57, 93, 16, 17, 35, -48, -25, 13, -121, 73, -68, -54, -122, -20, 12, 64, 20,
-11, -6, -71, -52, -97, 109, 116, -107, 117, -124, 56, 80, -108, 30, 123, 56, -80, 39, -18, -97,
-103, 122, 114, -10, -31, 97, -92, 105, -61, -25, 10, -119, -106, 41, 77, -117, 55, -83, -29, 14,
27, -106, -86, 41, 43, 23, 11, -76, -34, 121, 94, 18, 69, 73, 100, 54, 43, 32, 13, 15,
-38, 95, -18, -82, -7, 118, -79, -85, 120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54,
63, 111, -16, 92, 82, -23, 111, 53, 1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31,
94, 101, -10, 18, 0, -49, 108, 28, -36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66,
99, -121, -107, 31, -38, 56, -30, 109, -7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71,
-54, 20, -45, 109, -42, 78, -79, 98, -10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92};
int8_t co[] = {
40, 24, 94, 122, 67, 34, -89, 31, -22, 32, 26, 112, -92, -23, 43, 9, 21, -92, -1, -82,
-71, -54, 63, -93, 4, 70, -7, 116, 99, 81, -79, 124, 113, 119, 119, -75, -17, 7, 7, 1,
-1, -15, -98, 77, 56, 119, -20, -32, 104, -52, -91, 76, 79, 105, 102, -65, 120, -15, 2, 32,
-94, 111, 115, 102, 1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84,
-7, 28, -22, -17, -3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -42, 16, -9, 31,
39, -70, 108, -3, -101, -7, 23, -57, 93, 16, 17, 35, -11, -6, -71, -52, -97, 109, 116, -107,
-103, 122, 114, -10, -31, 97, -92, 105, 27, -106, -86, 41, 43, 23, 11, -76, -38, 95, -18, -82,
-7, 118, -79, -85, 63, 111, -16, 92, 82, -23, 111, 53, 94, 101, -10, 18, 0, -49, 108, 28,
99, -121, -107, 31, -38, 56, -30, 109, -54, 20, -45, 109, -42, 78, -79, 98, -43, 121, 48, -54,
44, -91, 35, 89, 81, 118, -73, -54, 65, -99, 51, -90, 92, -93, 99, 122, -104, -16, -8, -32,
-14, 28, 97, 9, -97, 99, 88, -15, 69, 66, 40, -13, 80, -115, -98, -8, -54, -58, -16, 52,
121, 126, -33, 43, 43, 32, 13, 15, -38, 95, -18, -82, -18, 121, -106, 54, 63, 111, -16, 92,
80, -51, 116, 31, 94, 101, -10, 18, 74, -114, -107, 66, 99, -121, -107, 31, 108, -84, -23, -71,
-54, 20, -45, 109, 46, 121, 66, 92, 24, 55, 4, -110, -90, -60, -121, 11, 50, -88, -104, -29,
-48, -25, 13, -121, 73, -68, -54, -122, 117, -124, 56, 80, -108, 30, 123, 56, -61, -25, 10, -119,
-106, 41, 77, -117, -34, 121, 94, 18, 69, 73, 100, 54, 120, -15, 2, 32, -94, 111, 115, 102,
1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84, -7, 28, -22, -17,
-3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -37, 114, -8, 103, 0, 0, 0, 0,
121, -62, 119, -93, 0, 0, 0, 0, 90, -126, 51, 91, 0, 0, 0, 0, 54, 26, 77, -25,
0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0,
-7, 118, -79, -85, 0, 0, 0, 0, 82, -23, 111, 53, 0, 0, 0, 0, 0, -49, 108, 28,
0, 0, 0, 0, -38, 56, -30, 109, 0, 0, 0, 0, -42, 78, -79, 98, 0, 0, 0, 0,
-37, 112, -18, 10, 0, 0, 0, 0, -89, 94, 64, -91, 0, 0, 0, 0, -20, 12, 64, 20,
0, 0, 0, 0, -80, 39, -18, -97, 0, 0, 0, 0, 55, -83, -29, 14, 0, 0, 0, 0,
43, 32, 13, 15, 0, 0, 0, 0, -18, 121, -106, 54, 0, 0, 0, 0, 80, -51, 116, 31,
0, 0, 0, 0, 74, -114, -107, 66, 0, 0, 0, 0, 108, -84, -23, -71, 0, 0, 0, 0,
46, 121, 66, 92, 0, 0, 0, 0};
int8_t dst[528] = {0};
PackNHWCToC8HWN8Int8(in, dst, 22, 1, 20);
ASSERT_EQ(0, CompareOutputData(dst, co, 528, 1));
}
TEST_F(TestDeconvInt8, PackInputTest1) {
/* 6 x 20 */
int8_t in[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,