!26050 [MS][LITE]Add int8 clip
Merge pull request !26050 from gongdaguo/add_int8_clip
This commit is contained in:
commit
a14a777464
|
@ -46,7 +46,7 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_
|
|||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:AvgPoolingOptInt8
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:MaxPoolingWithQuantInt8
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPeroc
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.c:RegisterInfer
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col12MajorStride
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col8MajorStride
|
||||
|
@ -65,8 +65,8 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pooling_int8.c:
|
|||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8InputUnit
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8FilterTransform
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPeroc
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:Conv1x1PreOptPert
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling
|
||||
mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c:SWConv3x32Kernel, SWConv4x24Kernel, SWConv12x8Kernel, SWConv8x8Kernel, SWConv4x8Kernel, SWConv6x16Kernel, SWConv4x16Kernel
|
||||
|
|
|
@ -30,12 +30,27 @@ endif()
|
|||
# Core NNACL kernel sources. Note: int8 sources are intentionally NOT globbed
# here — they are selected below so that OP_INT8_CLIP builds can trim the int8
# kernel set down to the files the clip operator actually needs. (Globbing
# int8/*.c here as well would make the OP_INT8_CLIP branch a no-op and the
# else() branch a duplicate.)
file(GLOB KERNEL_SRC
    ${NNACL_DIR}/*.c
    ${NNACL_DIR}/fp32/*.c
    ${NNACL_DIR}/infer/*.c
    ${NNACL_DIR}/base/*.c
    ${NNACL_DIR}/fp32_grad/*.c
    )

if(OP_INT8_CLIP)
    # Minimal int8 subset required by the int8 clip operator.
    set(KERNEL_SRC
        ${KERNEL_SRC}
        ${NNACL_DIR}/int8/pack_int8.c
        ${NNACL_DIR}/int8/quantize.c
        )
else()
    # Full int8 kernel set.
    file(GLOB KERNEL_SRC_INT8
        ${NNACL_DIR}/int8/*.c
        )
    set(KERNEL_SRC
        ${KERNEL_SRC}
        ${KERNEL_SRC_INT8}
        )
endif()
|
||||
|
||||
if(MSLITE_ENABLE_SPARSE_COMPUTE)
|
||||
file(GLOB KERNEL_SRC_SPARSE
|
||||
${NNACL_DIR}/fp32_sparse/*.c
|
||||
|
|
|
@ -16,6 +16,818 @@
|
|||
|
||||
#include "nnacl/int8/conv_int8.h"
|
||||
|
||||
#ifdef ENABLE_ARM32
|
||||
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
|
||||
size_t plane_size, size_t input_channel, size_t output_channel) {
|
||||
size_t hw4 = UP_ROUND(plane_size, C4NUM);
|
||||
size_t ic16 = UP_ROUND(input_channel, C16NUM);
|
||||
|
||||
#ifdef ENABLE_ARM32
|
||||
size_t oc_div2 = output_channel / C2NUM * C2NUM;
|
||||
size_t oc_res2 = output_channel - oc_div2;
|
||||
size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
|
||||
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride);
|
||||
#else
|
||||
for (int ri = 0; ri < plane_size; ri++) {
|
||||
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
|
||||
for (int ci = 0; ci < output_channel; ci++) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
int ci2div = ci / C2NUM, ci2mod = ci % C2NUM;
|
||||
int32_t filter_zp = filter_zp_ptr[ci];
|
||||
for (int di = 0; di < input_channel; di++) {
|
||||
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
|
||||
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
|
||||
tmp_sum_value += input_value[src_index];
|
||||
}
|
||||
int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod;
|
||||
input_sum[dst_index] = tmp_sum_value * filter_zp;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Per-output-channel input-sum packing for the 16x4 int8 matmul layout.
 * On ARM64 this dispatches to the assembly kernel PreSum4x16Int8Peroc; otherwise
 * a reference C loop sums each pixel's input channels (input tiled as
 * [hw4/4][ic16/16][4][16]) and scales the sum by the per-channel filter zero
 * point, storing into input_sum in a C4NUM-tiled layout. */
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
                                size_t plane_size, size_t input_channel, size_t output_channel) {
  size_t hw4 = UP_ROUND(plane_size, C4NUM);       /* spatial size rounded up to 4 */
  size_t ic16 = UP_ROUND(input_channel, C16NUM);  /* input channels rounded up to 16 */
#ifdef ENABLE_ARM64
  size_t oc_div4 = output_channel / C4NUM * C4NUM;
  size_t oc_res4 = output_channel - oc_div4;
  /* stride between oc-quad tiles in bytes, minus the 4x4 int32 block just written.
   * NOTE(review): "inputsun" looks like a typo for "inputsum" — naming only. */
  size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
  PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride);
#else

  /* Reference C implementation. */
  for (int ri = 0; ri < plane_size; ri++) {
    int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
    for (int ci = 0; ci < output_channel; ci++) {
      int32_t tmp_sum_value = 0;
      int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
      int32_t filter_zp = filter_zp_ptr[ci];
      for (int di = 0; di < input_channel; di++) {
        size_t di16div = di / C16NUM, di16mod = di % C16NUM;
        /* index into the [hw4/4][ic16/16][4][16] tiled input */
        int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
        tmp_sum_value += input_value[src_index];
      }
      /* destination is C4NUM-tiled over output channels */
      int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
      input_sum[dst_index] = tmp_sum_value * filter_zp;
    }
  }
#endif
  return;
}
|
||||
|
||||
/* 1x1-convolution int8 pre-processing, per-output-channel (peroc) quantization.
 * Packs src_input (NHWC rows) into C8-row / C4-channel tiles in packed_input,
 * and for each row computes the sum of its input channels multiplied by each
 * output channel's filter zero point, stored into input_sum (C8-tiled over
 * output channels with tile stride inputsum_stride, in int32 elements).
 * The ARM64 path does pack + row-sum + zero-point multiply in one inline-asm
 * pass; the #else branch is the reference C implementation of the same steps. */
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                        size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int hw8 = UP_ROUND(plane_size, C8NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;   /* rows handled by the 8-row main loop */
  size_t oc_8div = output_channel / C8NUM * C8NUM;
  size_t oc_8res = output_channel - oc_8div;     /* leftover output channels (< 8) */
  size_t ic_4div = input_channel / C4NUM * C4NUM;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  int32_t *input_sum_r = input_sum;

  /* main loop: 8 spatial rows at a time */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    /* byte stride to the next oc8 tile, minus the 8x8 int32 block written per tile */
    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
    asm volatile(
      /* v16/v17 accumulate per-row channel sums (4 rows each) */
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      /* loop over input channels, 4 at a time; tail dispatched by ic_4res */
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      /* full 4-channel column: gather 8 rows x 4 bytes, store packed,
       * and fold the signed bytes into the row sums via saddlp */
      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3: a halfword load plus a byte load per row */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      /* broadcast each of the 8 row sums into v0..v7, then multiply by the
       * filter zero points, 8 output channels per iteration */
      "6: \n"
      "dup v0.4s, v16.s[0] \n"
      "dup v1.4s, v16.s[1] \n"
      "dup v2.4s, v16.s[2] \n"
      "dup v3.4s, v16.s[3] \n"
      "dup v4.4s, v17.s[0] \n"
      "dup v5.4s, v17.s[1] \n"
      "dup v6.4s, v17.s[2] \n"
      "dup v7.4s, v17.s[3] \n"
      "mov x4, #0 \n"
      "mov x10, %[filter_zp] \n"
      "mov x11, %[input_sum_oc] \n"

      "7: \n"
      "cmp x4, %[oc_8div] \n"
      "beq 8f \n"
      "add x4, x4, #8\n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.4s}, [x10], #16\n"

      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "add x11, x11, %[input_sum_stride] \n"
      "b 7b \n"

      /* oc tail: load the remaining (oc_8res) zero points, zero-padding to 8 */
      "8: \n"
      "cmp %[oc_8res], #0\n"
      "beq 17f \n"

      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "cmp %[oc_8res], #1\n"
      "beq 9f \n"
      "cmp %[oc_8res], #2\n"
      "beq 10f \n"
      "cmp %[oc_8res], #3\n"
      "beq 11f \n"
      "cmp %[oc_8res], #4\n"
      "beq 12f \n"
      "cmp %[oc_8res], #5\n"
      "beq 13f \n"
      "cmp %[oc_8res], #6\n"
      "beq 14f \n"
      "cmp %[oc_8res], #7\n"
      "beq 15f \n"

      "9: \n"
      "ld1 {v16.s}[0], [x10] \n"
      "b 16f \n"

      "10: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "b 16f \n"

      "11: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v16.s}[2], [x10] \n"
      "b 16f \n"

      "12: \n"
      "ld1 {v16.4s}, [x10] \n"
      "b 16f \n"

      "13: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.s}[0], [x10] \n"
      "b 16f \n"

      "14: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "b 16f \n"

      "15: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v17.s}[2], [x10] \n"
      "b 16f \n"

      "16: \n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "17: \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
    /* Reference C path: pack 8 rows x 4 channels at a time while accumulating
     * each row's channel sum in tmp_sum_value[]. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* channel tail (< 4 channels) */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    /* zero-pad packed channels up to ic4 */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    /* multiply row sums by per-channel filter zero points, 8 oc per tile */
    for (int oci = 0; oci < oc_8div; oci += C8NUM) {
      for (int ri = 0; ri < C8NUM; ri++) {
        input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
        input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
        input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
        input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
        input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
        input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
        input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
        input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
      }
      input_sum_oc += inputsum_stride;
    }
    if (oc_8div != output_channel) {
      for (int oci = 0; oci < oc_8res; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
        }
      }
      for (int oci = oc_8res; oci < C8NUM; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = 0;
        }
      }
    } /* oc8 res done */
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
    input_sum_r += C8NUM * C8NUM;
  }

  /* spatial tail: fewer than 8 rows remaining (plain C on all builds) */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t *input_sum_oc = input_sum_r;
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }

      for (int oci = 0; oci < oc_8div; oci += C8NUM) {
        for (int curoi = 0; curoi < C8NUM; curoi++) {
          input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
        }
        input_sum_oc += inputsum_stride;
      }
      if (oc_8div != output_channel) {
        for (int oci = 0; oci < oc_8res; oci += 1) {
          input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
        }
        for (int oci = oc_8res; oci < C8NUM; oci += 1) {
          input_sum_oc[oci] = 0;
        }
      } /* oc8 res done */

      src_r += input_channel;
      pack_r += C4NUM;
      input_sum_r += C8NUM;
    }

    /* zero the padded rows up to hw8 for every (padded) output channel */
    for (int hwi = plane_size; hwi < hw8; hwi++) {
      for (int oc = 0; oc < oc8; oc++) {
        int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
        input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
      }
    }
  }
  return;
}
|
||||
|
||||
/* 1x1-convolution int8 pre-processing, per-tensor (pert / per-layer) quantization.
 * Packs src_input (NHWC rows) into C8-row / C4-channel tiles in packed_input and
 * writes one input-sum per spatial position: (sum of that row's input channels)
 * multiplied by the single filter zero point. The ARM64 path fuses pack + sum
 * in inline asm; the #else branch is the reference C implementation. */
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                       size_t plane_size, const ConvParameter *conv_param) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;   /* rows handled by the 8-row main loop */
  size_t ic_4div = input_channel / C4NUM * C4NUM;
  int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;  /* single zp: per-layer */

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  /* per layer */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    asm volatile(
      /* v16/v17 accumulate the 8 row sums; v20 broadcasts the filter zp */
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x14, %[input_sum_r] \n"
      "dup v20.4s, %w[filter_zp] \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      /* loop over input channels 4 at a time; tail dispatched by ic_4res */
      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      /* full 4-channel column: 8 rows x 4 bytes, pack then fold into sums */
      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"

      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"

      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3: halfword + byte load per row */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      /* scale the 8 row sums by the filter zero point and store them */
      "6: \n"
      "mul v16.4s, v16.4s, v20.4s \n"
      "mul v17.4s, v17.4s, v20.4s \n"

      "st1 {v16.4s}, [x14], #16 \n"
      "st1 {v17.4s}, [x14], #16 \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
        [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
      : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
        "v20");
#else
    /* Reference C path: pack 8 rows x 4 channels while accumulating row sums. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* channel tail (< 4 channels) */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    /* zero-pad packed channels up to ic4 */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    for (int i = 0; i < C8NUM; i++) {
      input_sum_r[i] = tmp_sum_value[i] * filter_zp;
    }
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
  }

  /* spatial tail: fewer than 8 rows remaining (plain C on all builds) */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      input_sum[hwi] = tmp_sum_value * filter_zp;
      src_r += input_channel;
      pack_r += C4NUM;
    }
    /* zero the padded sums up to the next multiple of 8 rows */
    for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
      input_sum[hwi] = 0;
    }
  }
  return;
}
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
|
||||
const ConvParameter *conv_param) {
|
||||
size_t hw = conv_param->output_h_ * conv_param->output_w_;
|
||||
size_t hw4 = UP_ROUND(hw, C4NUM);
|
||||
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
|
||||
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
|
||||
} else {
|
||||
#ifdef ENABLE_ARM32
|
||||
PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_,
|
||||
conv_param->output_channel_);
|
||||
#else
|
||||
PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_,
|
||||
conv_param->output_channel_);
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* im2col for a tile of real_cal_num output pixels starting at block_index
 * (input format: NHWC), followed by the matmul-specific repack and input-sum
 * computation. Out-of-image kernel taps are skipped via the kh/kw clamp below,
 * leaving matmul_input zeros there (caller is assumed to provide a zeroed /
 * reused buffer — TODO confirm).
 * - is_optimize + per_channel: Conv1x1PreOptPeroc (C8 tiles, per-oc zps)
 * - is_optimize + per_layer:  Conv1x1PreOptPert  (C8 tiles, single zp)
 * - otherwise: RowMajor2Row16x4MajorInt8 repack + PackInputSum16x4* sums. */
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
                           int block_index, const int32_t *filter_zp, int32_t *input_sum,
                           const ConvParameter *conv_param, bool per_channel, bool is_optimize) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int stride_h = conv_param->stride_h_;
  int stride_w = conv_param->stride_w_;
  int pad_h = conv_param->pad_u_;
  int pad_w = conv_param->pad_l_;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  int kernel_plane = kernel_h * kernel_w;
  /* guard the divisions below (block_start / out_w, UP_DIV by dilation) */
  NNACL_CHECK_ZERO_RETURN(out_w);
  NNACL_CHECK_ZERO_RETURN(dilation_h);
  NNACL_CHECK_ZERO_RETURN(dilation_w);
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    /* top-left input coordinate of this output pixel's receptive field */
    int input_h = block_start / out_w * stride_h - pad_h;
    int input_w = block_start % out_w * stride_w - pad_w;
    int input_stride = input_h * in_w * in_channel + input_w * in_channel;
    /* clamp the kernel window to the taps that land inside the image */
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      /* fast path: contiguous kw span can be copied in one memcpy per kh */
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
      }  // kernel_h loop
    } else {
      /* dilated path: one memcpy per kernel tap */
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
        }
      }  // kernel_h loop
    }
  }  // tile num loop
  int deep = kernel_plane * in_channel;
  if (is_optimize) {
    if (per_channel) {
      Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num,
                         filter_zp, C8NUM * C8NUM);
    } else {
      Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param);
    }
  } else {
    RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep);
    if (per_channel) {
#ifdef ENABLE_ARM32
      PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep,
                                      conv_param->output_channel_);
#else
      PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_);
#endif
    } else {
      size_t hw4 = UP_ROUND(real_cal_num, C4NUM);
      size_t ic16 = UP_ROUND(deep, C16NUM);
      PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4,
                               ic16);
    }
  }
}
|
||||
|
||||
void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
|
||||
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
|
||||
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) {
|
||||
|
|
|
@ -16,799 +16,6 @@
|
|||
|
||||
#include "nnacl/int8/pack_int8.h"
|
||||
|
||||
#ifdef ENABLE_ARM32
/* Computes per-output-channel input sums for the 16x4-packed int8 matmul input on ARM32:
 * input_sum[oc][row] = (sum of the row's int8 inputs over all input channels) * filter_zp[oc].
 *
 * input_value   : 16x4-packed int8 input, layout [hw4/4][ic16/16][4][16].
 * input_sum     : output buffer, C2NUM-tiled over output channels, layout [oc/2][hw4][2].
 * filter_zp_ptr : per-output-channel filter zero points (length output_channel).
 * plane_size    : number of valid spatial rows; input_channel: reduce depth.
 */
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
                                     size_t plane_size, size_t input_channel, size_t output_channel) {
  size_t hw4 = UP_ROUND(plane_size, C4NUM);
  size_t ic16 = UP_ROUND(input_channel, C16NUM);
  /* The entire function is compiled only under ENABLE_ARM32 (outer guard above), so the
   * former nested "#ifdef ENABLE_ARM32 ... #else <scalar fallback> #endif" made the scalar
   * branch unreachable dead code; it has been removed and the asm kernel is always used. */
  size_t oc_div2 = output_channel / C2NUM * C2NUM;
  size_t oc_res2 = output_channel - oc_div2;
  /* Byte stride between consecutive C2NUM output-channel tiles of input_sum, minus the
   * bytes the kernel has already advanced within one tile.  (Renamed from the original
   * typo "inputsun_stride".) */
  size_t input_sum_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
  PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, input_sum_stride);
  return;
}
#endif
|
||||
|
||||
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr,
|
||||
size_t plane_size, size_t input_channel, size_t output_channel) {
|
||||
size_t hw4 = UP_ROUND(plane_size, C4NUM);
|
||||
size_t ic16 = UP_ROUND(input_channel, C16NUM);
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t oc_div4 = output_channel / C4NUM * C4NUM;
|
||||
size_t oc_res4 = output_channel - oc_div4;
|
||||
size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
|
||||
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride);
|
||||
#else
|
||||
|
||||
for (int ri = 0; ri < plane_size; ri++) {
|
||||
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
|
||||
for (int ci = 0; ci < output_channel; ci++) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
|
||||
int32_t filter_zp = filter_zp_ptr[ci];
|
||||
for (int di = 0; di < input_channel; di++) {
|
||||
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
|
||||
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
|
||||
tmp_sum_value += input_value[src_index];
|
||||
}
|
||||
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
|
||||
input_sum[dst_index] = tmp_sum_value * filter_zp;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
/* Packs the 1x1-conv int8 input into the C4-channel / C8-row tiled layout used by the
 * optimized per-output-channel matmul kernel, while simultaneously computing
 * input_sum = (row sum over input channels) * filter_zp[oc] for every (row, oc) pair.
 *
 * src_input      : row-major int8 input, plane_size x input_channel.
 * packed_input   : per hw-tile, layout [ic4/4][C8NUM rows][4 channels]; padded channels
 *                  and padded rows are zero-filled.
 * input_sum      : C8NUM-tiled over output channels; consecutive oc8 tiles are
 *                  inputsum_stride int32 elements apart.
 * filter_zp      : per-output-channel filter zero points.
 *
 * The ARM64 branch is a hand-written assembly version of the scalar #else branch;
 * the scalar branch documents the intended semantics.
 */
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                        size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int hw8 = UP_ROUND(plane_size, C8NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;
  size_t oc_8div = output_channel / C8NUM * C8NUM;
  size_t oc_8res = output_channel - oc_8div;
  size_t ic_4div = input_channel / C4NUM * C4NUM;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  int32_t *input_sum_r = input_sum;

  /* Main loop: full tiles of C8NUM spatial rows. */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_oc = input_sum_r;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4;
    /* v16/v17 accumulate the eight per-row sums (4 rows each); labels 2-5 handle a full
     * 4-channel column vs. 1/2/3 leftover channels; labels 7-16 multiply the row sums by
     * the filter zero points, 8 output channels at a time. */
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "6: \n"
      "dup v0.4s, v16.s[0] \n"
      "dup v1.4s, v16.s[1] \n"
      "dup v2.4s, v16.s[2] \n"
      "dup v3.4s, v16.s[3] \n"
      "dup v4.4s, v17.s[0] \n"
      "dup v5.4s, v17.s[1] \n"
      "dup v6.4s, v17.s[2] \n"
      "dup v7.4s, v17.s[3] \n"
      "mov x4, #0 \n"
      "mov x10, %[filter_zp] \n"
      "mov x11, %[input_sum_oc] \n"

      "7: \n"
      "cmp x4, %[oc_8div] \n"
      "beq 8f \n"
      "add x4, x4, #8\n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.4s}, [x10], #16\n"

      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"

      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"

      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"

      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "add x11, x11, %[input_sum_stride] \n"
      "b 7b \n"

      "8: \n"
      "cmp %[oc_8res], #0\n"
      "beq 17f \n"

      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "cmp %[oc_8res], #1\n"
      "beq 9f \n"
      "cmp %[oc_8res], #2\n"
      "beq 10f \n"
      "cmp %[oc_8res], #3\n"
      "beq 11f \n"
      "cmp %[oc_8res], #4\n"
      "beq 12f \n"
      "cmp %[oc_8res], #5\n"
      "beq 13f \n"
      "cmp %[oc_8res], #6\n"
      "beq 14f \n"
      "cmp %[oc_8res], #7\n"
      "beq 15f \n"

      "9: \n"
      "ld1 {v16.s}[0], [x10] \n"
      "b 16f \n"

      "10: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "b 16f \n"

      "11: \n"
      "ld1 {v16.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v16.s}[2], [x10] \n"
      "b 16f \n"

      "12: \n"
      "ld1 {v16.4s}, [x10] \n"
      "b 16f \n"

      "13: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.s}[0], [x10] \n"
      "b 16f \n"

      "14: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "b 16f \n"

      "15: \n"
      "ld1 {v16.4s}, [x10], #16\n"
      "ld1 {v17.d}[0], [x10] \n"
      "add x10, x10, #8 \n"
      "ld1 {v17.s}[2], [x10] \n"
      "b 16f \n"

      "16: \n"
      "mul v18.4s, v16.4s, v0.4s \n"
      "mul v19.4s, v17.4s, v0.4s \n"
      "mul v20.4s, v16.4s, v1.4s \n"
      "mul v21.4s, v17.4s, v1.4s \n"
      "mul v22.4s, v16.4s, v2.4s \n"
      "mul v23.4s, v17.4s, v2.4s \n"
      "mul v24.4s, v16.4s, v3.4s \n"
      "mul v25.4s, v17.4s, v3.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "mul v18.4s, v16.4s, v4.4s \n"
      "mul v19.4s, v17.4s, v4.4s \n"
      "mul v20.4s, v16.4s, v5.4s \n"
      "mul v21.4s, v17.4s, v5.4s \n"
      "mul v22.4s, v16.4s, v6.4s \n"
      "mul v23.4s, v17.4s, v6.4s \n"
      "mul v24.4s, v16.4s, v7.4s \n"
      "mul v25.4s, v17.4s, v7.4s \n"
      "st1 {v18.4s}, [x11], #16 \n"
      "st1 {v19.4s}, [x11], #16 \n"
      "st1 {v20.4s}, [x11], #16 \n"
      "st1 {v21.4s}, [x11], #16 \n"
      "st1 {v22.4s}, [x11], #16 \n"
      "st1 {v23.4s}, [x11], #16 \n"
      "st1 {v24.4s}, [x11], #16 \n"
      "st1 {v25.4s}, [x11], #16 \n"

      "17: \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp),
        [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride),
        [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res)
      : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
#else
    /* Scalar reference: pack C4NUM channels at a time for the 8 rows of this tile and
     * accumulate each row's channel sum. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* Leftover channels (input_channel not a multiple of C4NUM). */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    /* Zero-fill channel padding up to ic4. */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    /* Multiply each row sum by every filter zero point, 8 output channels per tile. */
    for (int oci = 0; oci < oc_8div; oci += C8NUM) {
      for (int ri = 0; ri < C8NUM; ri++) {
        input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
        input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
        input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
        input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
        input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
        input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
        input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
        input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
      }
      input_sum_oc += inputsum_stride;
    }
    /* Last (partial) oc8 tile: valid channels get real sums, padding gets zeros. */
    if (oc_8div != output_channel) {
      for (int oci = 0; oci < oc_8res; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
        }
      }
      for (int oci = oc_8res; oci < C8NUM; oci += 1) {
        for (int ri = 0; ri < C8NUM; ri++) {
          input_sum_oc[ri * C8NUM + oci] = 0;
        }
      }
    } /* oc8 res done */
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
    input_sum_r += C8NUM * C8NUM;
  }

  /* Tail rows (plane_size not a multiple of C8NUM): scalar path, one row at a time. */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t *input_sum_oc = input_sum_r;
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }

      for (int oci = 0; oci < oc_8div; oci += C8NUM) {
        for (int curoi = 0; curoi < C8NUM; curoi++) {
          input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
        }
        input_sum_oc += inputsum_stride;
      }
      if (oc_8div != output_channel) {
        for (int oci = 0; oci < oc_8res; oci += 1) {
          input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
        }
        for (int oci = oc_8res; oci < C8NUM; oci += 1) {
          input_sum_oc[oci] = 0;
        }
      } /* oc8 res done */

      src_r += input_channel;
      pack_r += C4NUM;
      input_sum_r += C8NUM;
    }

    /* Zero the input_sum entries of rows padded up to hw8. */
    for (int hwi = plane_size; hwi < hw8; hwi++) {
      for (int oc = 0; oc < oc8; oc++) {
        int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
        input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
      }
    }
  }
  return;
}
|
||||
|
||||
/* Per-layer (single zero point) variant of the 1x1-conv int8 input pre-pack: packs the
 * input into the C4-channel / C8-row tiled layout and computes
 * input_sum[row] = (sum of the row's int8 inputs over all input channels) * filter_zp,
 * where filter_zp is the single per-layer filter zero point from conv_param.
 *
 * src_input    : row-major int8 input, plane_size x input_channel.
 * packed_input : per hw-tile, layout [ic4/4][C8NUM rows][4 channels], zero-padded.
 * input_sum    : one int32 per row, zero-filled up to UP_ROUND(plane_size, C8NUM).
 *
 * The ARM64 branch is hand-written assembly; the scalar #else branch documents the
 * intended semantics.
 */
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
                       size_t plane_size, const ConvParameter *conv_param) {
  int ic4 = UP_ROUND(input_channel, C4NUM);
  size_t hw_8div = plane_size / C8NUM * C8NUM;
  size_t ic_4div = input_channel / C4NUM * C4NUM;
  int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;

  const int8_t *src_r = src_input;
  int8_t *pack_r = packed_input;
  /* per layer */
  for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
    const int8_t *src_ic = src_r;
    int8_t *pack_ic = pack_r;
    int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
    size_t src_stride = input_channel;
    size_t ic_4res = input_channel - ic_4div;
    /* v16/v17 accumulate the eight per-row sums; labels 2-5 handle a full 4-channel
     * column vs. 1/2/3 leftover channels; label 6 scales by filter_zp and stores. */
    asm volatile(
      "dup v16.4s, wzr \n"
      "dup v17.4s, wzr \n"
      "mov x14, %[input_sum_r] \n"
      "dup v20.4s, %w[filter_zp] \n"

      "mov x10, %[src_ic] \n"
      "mov x11, %[pack_ic] \n"

      "mov x0, #0 \n"
      "1: \n"
      "cmp x0, %[ic_4div] \n"
      "add x0, x0, #4\n"
      "mov x12, x10 \n"
      "add x10, x10, #4\n"
      "blt 2f \n"
      "cmp %[ic_4res], #0\n"
      "beq 6f \n"
      "cmp %[ic_4res], #1\n"
      "beq 3f \n"
      "cmp %[ic_4res], #2\n"
      "beq 4f \n"
      "cmp %[ic_4res], #3\n"
      "beq 5f \n"

      "2: \n"
      "ld1 {v0.s}[0], [x12], %[src_stride]\n"
      "ld1 {v0.s}[1], [x12], %[src_stride]\n"
      "ld1 {v0.s}[2], [x12], %[src_stride]\n"
      "ld1 {v0.s}[3], [x12], %[src_stride]\n"
      "ld1 {v1.s}[0], [x12], %[src_stride]\n"
      "ld1 {v1.s}[1], [x12], %[src_stride]\n"
      "ld1 {v1.s}[2], [x12], %[src_stride]\n"
      "ld1 {v1.s}[3], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"

      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"

      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"

      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 1b \n"

      "3: \n" /* col res 1 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.b}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[8], [x12], %[src_stride]\n"
      "ld1 {v0.b}[12], [x12], %[src_stride]\n"
      "ld1 {v1.b}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[8], [x12], %[src_stride]\n"
      "ld1 {v1.b}[12], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "4: \n" /* col res 2 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "5: \n" /* col res 3 */
      "dup v0.4s, wzr \n"
      "dup v1.4s, wzr \n"
      "add x13, x12, #2 \n"

      "ld1 {v0.h}[0], [x12], %[src_stride]\n"
      "ld1 {v0.b}[2], [x13], %[src_stride]\n"
      "ld1 {v0.h}[2], [x12], %[src_stride]\n"
      "ld1 {v0.b}[6], [x13], %[src_stride]\n"
      "ld1 {v0.h}[4], [x12], %[src_stride]\n"
      "ld1 {v0.b}[10], [x13], %[src_stride]\n"
      "ld1 {v0.h}[6], [x12], %[src_stride]\n"
      "ld1 {v0.b}[14], [x13], %[src_stride]\n"
      "ld1 {v1.h}[0], [x12], %[src_stride]\n"
      "ld1 {v1.b}[2], [x13], %[src_stride]\n"
      "ld1 {v1.h}[2], [x12], %[src_stride]\n"
      "ld1 {v1.b}[6], [x13], %[src_stride]\n"
      "ld1 {v1.h}[4], [x12], %[src_stride]\n"
      "ld1 {v1.b}[10], [x13], %[src_stride]\n"
      "ld1 {v1.h}[6], [x12], %[src_stride]\n"
      "ld1 {v1.b}[14], [x13], %[src_stride]\n"

      "st1 {v0.16b}, [x11], #16\n"
      "st1 {v1.16b}, [x11], #16\n"
      "saddlp v4.8h, v0.16b \n"
      "saddlp v5.8h, v1.16b \n"
      "saddlp v0.4s, v4.8h \n"
      "saddlp v1.4s, v5.8h \n"
      "add v16.4s, v16.4s, v0.4s \n"
      "add v17.4s, v17.4s, v1.4s \n"
      "b 6f \n"

      "6: \n"
      "mul v16.4s, v16.4s, v20.4s \n"
      "mul v17.4s, v17.4s, v20.4s \n"

      "st1 {v16.4s}, [x14], #16 \n"
      "st1 {v17.4s}, [x14], #16 \n"

      :
      : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
        [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
      : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
        "v20");
#else
    /* Scalar reference: pack C4NUM channels at a time for the 8 rows of this tile and
     * accumulate each row's channel sum. */
    int32_t tmp_sum_value[8] = {0};
    for (int ici = 0; ici < ic_4div; ici += C4NUM) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[0 + i * input_channel];
        tmp_sum_value[i] += src_ic[1 + i * input_channel];
        tmp_sum_value[i] += src_ic[2 + i * input_channel];
        tmp_sum_value[i] += src_ic[3 + i * input_channel];
        pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
        pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
        pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
        pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
      }
      src_ic += C4NUM;
      pack_ic += C4NUM * C8NUM;
    }
    /* Leftover channels (input_channel not a multiple of C4NUM). */
    for (int ici = ic_4div; ici < input_channel; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        tmp_sum_value[i] += src_ic[i * input_channel];
        pack_ic[i * C4NUM] = src_ic[i * input_channel];
      }
      src_ic += 1;
      pack_ic += 1;
    }

    /* Zero-fill channel padding up to ic4. */
    for (int ici = input_channel; ici < ic4; ici += 1) {
      for (int i = 0; i < C8NUM; i++) {
        pack_ic[i * C4NUM] = 0;
      }
      pack_ic += 1;
    }

    /* Scale each row sum by the single per-layer filter zero point. */
    for (int i = 0; i < C8NUM; i++) {
      input_sum_r[i] = tmp_sum_value[i] * filter_zp;
    }
#endif
    src_r += input_channel * C8NUM;
    pack_r += ic4 * C8NUM;
  }

  /* Tail rows (plane_size not a multiple of C8NUM): scalar path, one row at a time. */
  if (hw_8div != plane_size) {
    memset(pack_r, 0, C8NUM * ic4);
    for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
      int32_t tmp_sum_value = 0;
      const int8_t *src_ic = src_r;
      int8_t *pack_ic = pack_r;
      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
        tmp_sum_value += src_ic[0];
        tmp_sum_value += src_ic[1];
        tmp_sum_value += src_ic[2];
        tmp_sum_value += src_ic[3];
        pack_ic[0] = src_ic[0];
        pack_ic[1] = src_ic[1];
        pack_ic[2] = src_ic[2];
        pack_ic[3] = src_ic[3];
        src_ic += C4NUM;
        pack_ic += C4NUM * C8NUM;
      }
      for (int ici = ic_4div; ici < input_channel; ici += 1) {
        tmp_sum_value += src_ic[0];
        pack_ic[0] = src_ic[0];
        src_ic += 1;
        pack_ic += 1;
      }
      input_sum[hwi] = tmp_sum_value * filter_zp;
      src_r += input_channel;
      pack_r += C4NUM;
    }
    /* Zero the input_sum entries of rows padded up to the next C8NUM multiple. */
    for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
      input_sum[hwi] = 0;
    }
  }
  return;
}
|
||||
|
||||
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
|
||||
int block_index, const int32_t *filter_zp, int32_t *input_sum,
|
||||
const ConvParameter *conv_param, bool per_channel, bool is_optimize) {
|
||||
// input format : nhwc
|
||||
int kernel_h = conv_param->kernel_h_;
|
||||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
int stride_w = conv_param->stride_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int in_h = conv_param->input_h_;
|
||||
int in_w = conv_param->input_w_;
|
||||
int out_w = conv_param->output_w_;
|
||||
int kernel_plane = kernel_h * kernel_w;
|
||||
NNACL_CHECK_ZERO_RETURN(out_w);
|
||||
NNACL_CHECK_ZERO_RETURN(dilation_h);
|
||||
NNACL_CHECK_ZERO_RETURN(dilation_w);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int block_start = block_index + i;
|
||||
int input_h = block_start / out_w * stride_h - pad_h;
|
||||
int input_w = block_start % out_w * stride_w - pad_w;
|
||||
int input_stride = input_h * in_w * in_channel + input_w * in_channel;
|
||||
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
|
||||
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
|
||||
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
|
||||
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
|
||||
if (dilation_w == 1 && dilation_h == 1) {
|
||||
for (int j = kh_s; j < kh_e; j++) {
|
||||
int input_y_stride = j * in_w * in_channel + input_stride;
|
||||
int input_x_stride = input_y_stride + kw_s * in_channel;
|
||||
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
|
||||
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
|
||||
} // kernel_h loop
|
||||
} else {
|
||||
for (int j = kh_s; j < kh_e; j++) {
|
||||
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
|
||||
for (int k = kw_s; k < kw_e; ++k) {
|
||||
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
|
||||
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
|
||||
memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
|
||||
}
|
||||
} // kernel_h loop
|
||||
}
|
||||
} // tile num loop
|
||||
int deep = kernel_plane * in_channel;
|
||||
if (is_optimize) {
|
||||
if (per_channel) {
|
||||
Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num,
|
||||
filter_zp, C8NUM * C8NUM);
|
||||
} else {
|
||||
Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param);
|
||||
}
|
||||
} else {
|
||||
RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep);
|
||||
if (per_channel) {
|
||||
#ifdef ENABLE_ARM32
|
||||
PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep,
|
||||
conv_param->output_channel_);
|
||||
#else
|
||||
PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_);
|
||||
#endif
|
||||
} else {
|
||||
size_t hw4 = UP_ROUND(real_cal_num, C4NUM);
|
||||
size_t ic16 = UP_ROUND(deep, C16NUM);
|
||||
PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4,
|
||||
ic16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param) {
|
||||
int in_batch = conv_param->input_batch_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
|
@ -908,25 +115,6 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
|
|||
}
|
||||
}
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
|
||||
const ConvParameter *conv_param) {
|
||||
size_t hw = conv_param->output_h_ * conv_param->output_w_;
|
||||
size_t hw4 = UP_ROUND(hw, C4NUM);
|
||||
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
|
||||
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
|
||||
} else {
|
||||
#ifdef ENABLE_ARM32
|
||||
PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_,
|
||||
conv_param->output_channel_);
|
||||
#else
|
||||
PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_,
|
||||
conv_param->output_channel_);
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
|
||||
/* normal matmul : 4x16 * 16x4 -> 4x4 */
|
||||
#ifdef ENABLE_ARM
|
||||
|
@ -1081,25 +269,6 @@ void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c
|
|||
}
|
||||
}
|
||||
|
||||
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c8 = UP_DIV(channel, C8NUM);
|
||||
for (int b = 0; b < batch; b++) {
|
||||
int src_offset = b * plane * channel;
|
||||
int dst_offset = b * plane * c8 * C8NUM;
|
||||
for (int c = 0; c < channel; c++) {
|
||||
int c8_block_num = c / C8NUM;
|
||||
int c8_block_rem = c % C8NUM;
|
||||
int src_c_offset = src_offset + c * plane;
|
||||
int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM;
|
||||
for (int k = 0; k < plane; k++) {
|
||||
int src_kernel_offset = src_c_offset + k;
|
||||
int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem;
|
||||
((int8_t *)dst + dst_kernel_offset)[0] = ((int8_t *)src + src_kernel_offset)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
for (int b = 0; b < batch; b++) {
|
||||
|
@ -1127,21 +296,6 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int
|
|||
}
|
||||
}
|
||||
|
||||
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int hw = 0; hw < plane; hw++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
int c8div = c / C8NUM;
|
||||
int c8mod = c % C8NUM;
|
||||
int src_index = n * plane * channel + hw * channel + c;
|
||||
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
|
||||
((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
|
|
|
@ -30,20 +30,13 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c
|
|||
void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp,
|
||||
const ConvParameter *conv_param);
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
|
||||
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param);
|
||||
void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, const ConvParameter *conv_param);
|
||||
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
|
||||
int block_index, const int32_t *filter_zp, int32_t *input_sum,
|
||||
const ConvParameter *conv_param, bool per_channel, bool is_optimize);
|
||||
#ifdef ENABLE_ARM
|
||||
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
|
||||
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, const int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
|
||||
|
|
|
@ -27,6 +27,7 @@ option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
|
|||
option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
|
||||
option(MSLITE_ENABLE_V0 "support v0 schema" on)
|
||||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on)
|
||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||
|
@ -96,6 +97,14 @@ endif()
|
|||
if(DEFINED ENV{MSLITE_ENABLE_FP16})
|
||||
set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_INT8})
|
||||
set(MSLITE_ENABLE_INT8 $ENV{MSLITE_ENABLE_INT8})
|
||||
endif()
|
||||
if(MSLITE_ENABLE_INT8)
|
||||
set(OP_INT8_CLIP off)
|
||||
else()
|
||||
set(OP_INT8_CLIP on)
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
|
||||
set(MSLITE_ENABLE_SPARSE_COMPUTE $ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
|
||||
endif()
|
||||
|
@ -236,6 +245,7 @@ message(STATUS "\tBUILD_MINDDATA = \t${BUILD_MINDDATA}")
|
|||
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
|
||||
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
|
||||
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")
|
||||
message(STATUS "\tMSLITE_ENABLE_SPARSE_COMPUTE = \t${MSLITE_ENABLE_SPARSE_COMPUTE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_CONVERT}")
|
||||
|
@ -261,8 +271,9 @@ endif()
|
|||
|
||||
if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
|
||||
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
|
||||
message(FATAL_ERROR "If you want to build fp16 in arm82_a32, \
|
||||
message(STATUS "If you want to build fp16 in arm82_a32, \
|
||||
your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
|
||||
set(MSLITE_ENABLE_FP16 off)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_HIGH_PERFORMANCE)
|
||||
|
|
|
@ -27,6 +27,9 @@ endif()
|
|||
if(NOT MSLITE_ENABLE_DELEGATE)
|
||||
add_compile_definitions(DELEGATE_CLIP)
|
||||
endif()
|
||||
if(NOT MSLITE_ENABLE_INT8)
|
||||
add_compile_definitions(OP_INT8_CLIP)
|
||||
endif()
|
||||
|
||||
if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
#for performance
|
||||
|
|
|
@ -54,6 +54,10 @@ const char *const unsupport_fp16_log =
|
|||
"The mindspore-lite library does not support fp16. Set environment variable "
|
||||
"MSLITE_ENABLE_FP16 to on to "
|
||||
"recompile it.";
|
||||
const char *const unsupport_int8_log =
|
||||
"The mindspore-lite library does not support int8. Set environment variable "
|
||||
"MSLITE_ENABLE_INT8 to on to "
|
||||
"recompile it.";
|
||||
|
||||
static inline bool IsPrintDebug() {
|
||||
auto env = std::getenv("GLOG_v");
|
||||
|
|
|
@ -3,9 +3,21 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
|
|||
file(GLOB KERNEL_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_INT8)
|
||||
file(GLOB INT8_KERNEL_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
|
||||
)
|
||||
set(KERNEL_SRC
|
||||
${KERNEL_SRC}
|
||||
${INT8_KERNEL_SRC}
|
||||
)
|
||||
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
|
||||
else()
|
||||
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/base/quant_dtype_cast.cc)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_SPARSE_COMPUTE)
|
||||
file(GLOB SPARSE_KERNEL_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fp32_sparse/*.cc
|
||||
|
@ -34,7 +46,6 @@ if(MSLITE_ENABLE_CONTROLFLOW)
|
|||
${KERNEL_CONTROL_TENSORLIST}
|
||||
)
|
||||
endif()
|
||||
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
|
||||
|
|
|
@ -51,7 +51,16 @@ int DetectionPostProcessBaseCPUKernel::Prepare() {
|
|||
auto anchor_tensor = in_tensors_.at(2);
|
||||
MS_CHECK_GT(anchor_tensor->ElementsNum(), 0, RET_ERROR);
|
||||
CHECK_NULL_RETURN(anchor_tensor->data());
|
||||
if (anchor_tensor->data_type() == kNumberTypeInt8) {
|
||||
if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
|
||||
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
|
||||
if (params_->anchors_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc anchor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_FALSE(anchor_tensor->Size() == 0, RET_ERROR);
|
||||
memcpy(params_->anchors_, anchor_tensor->data(), anchor_tensor->Size());
|
||||
#ifndef OP_INT8_CLIP
|
||||
} else if (anchor_tensor->data_type() == kNumberTypeInt8) {
|
||||
auto quant_param = anchor_tensor->quant_params().front();
|
||||
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->data());
|
||||
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()];
|
||||
|
@ -73,14 +82,7 @@ int DetectionPostProcessBaseCPUKernel::Prepare() {
|
|||
DoDequantizeUInt8ToFp32(anchor_uint8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
|
||||
anchor_tensor->ElementsNum());
|
||||
params_->anchors_ = anchor_fp32;
|
||||
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
|
||||
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
|
||||
if (params_->anchors_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc anchor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_FALSE(anchor_tensor->Size() == 0, RET_ERROR);
|
||||
memcpy(params_->anchors_, anchor_tensor->data(), anchor_tensor->Size());
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type();
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -1111,6 +1111,11 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
|
||||
}
|
||||
}
|
||||
#ifdef OP_INT8_CLIP
|
||||
if (desc.data_type == kNumberTypeInt8) {
|
||||
MS_LOG(ERROR) << unsupport_int8_log;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,10 +28,8 @@ file(GLOB_RECURSE TEST_UT_SRC
|
|||
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
|
||||
${TEST_DIR}/st/mix_data_type_test.cc
|
||||
${TEST_DIR}/ut/nnacl/infer/*.cc
|
||||
${TEST_DIR}/ut/nnacl/int8/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
|
||||
)
|
||||
|
||||
|
@ -68,6 +66,14 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_INT8)
|
||||
file(GLOB_RECURSE TEST_INT8_UT_SRC
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||
${TEST_DIR}/ut/nnacl/int8/*.cc
|
||||
)
|
||||
list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC})
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_FP16)
|
||||
file(GLOB_RECURSE TEST_FP16_UT_SRC
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc
|
||||
|
|
|
@ -35,77 +35,6 @@ class TestDeconvInt8 : public mindspore::CommonTest {
|
|||
TestDeconvInt8() {}
|
||||
};
|
||||
|
||||
TEST_F(TestDeconvInt8, PackWeight1) {
|
||||
int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100,
|
||||
72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90,
|
||||
4, -106, 105, 28, -61, -79, 55, -54, 47, -38, 114, 125, -65, 100, 6, -72, -33, 60, 109, -68};
|
||||
int8_t co[] = {-8, 11, 99, -80, 8, -12, 0, 0, 112, 124, -109, 85, -24, 28, 0, 0, -110, 37, -72, 65,
|
||||
-124, 91, 0, 0, -14, -81, 67, 90, 4, -106, 0, 0, 47, -38, 114, 125, -65, 100, 0, 0,
|
||||
37, -45, 31, -69, -66, 26, 0, 0, -46, 100, 72, -36, -82, 64, 0, 0, -43, 99, 3, 100,
|
||||
19, 51, 0, 0, 105, 28, -61, -79, 55, -54, 0, 0, 6, -72, -33, 60, 109, -68, 0, 0};
|
||||
int8_t dst[80] = {0};
|
||||
/*5*1*2*6 nhwc*/
|
||||
PackNHWCToC8HWN8Int8(in, dst, 5, 2, 6);
|
||||
ASSERT_EQ(0, CompareOutputData(dst, co, 80, 1));
|
||||
}
|
||||
|
||||
TEST_F(TestDeconvInt8, PackWeight2) {
|
||||
int8_t in[] = {
|
||||
40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,
|
||||
-22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93,
|
||||
21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91,
|
||||
4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25,
|
||||
113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65,
|
||||
-1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52,
|
||||
104, -52, -91, 76, 79, 105, 102, -65, 43, 32, 13, 15, -38, 95, -18, -82, -7, 118, -79, -85,
|
||||
120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54, 63, 111, -16, 92, 82, -23, 111, 53,
|
||||
1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31, 94, 101, -10, 18, 0, -49, 108, 28,
|
||||
-36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66, 99, -121, -107, 31, -38, 56, -30, 109,
|
||||
-7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71, -54, 20, -45, 109, -42, 78, -79, 98,
|
||||
-10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92, 24, 55, 4, -110, -37, 112, -18, 10,
|
||||
-42, 16, -9, 31, 39, -70, 108, -3, -90, -60, -121, 11, 50, -88, -104, -29, -89, 94, 64, -91,
|
||||
-101, -7, 23, -57, 93, 16, 17, 35, -48, -25, 13, -121, 73, -68, -54, -122, -20, 12, 64, 20,
|
||||
-11, -6, -71, -52, -97, 109, 116, -107, 117, -124, 56, 80, -108, 30, 123, 56, -80, 39, -18, -97,
|
||||
-103, 122, 114, -10, -31, 97, -92, 105, -61, -25, 10, -119, -106, 41, 77, -117, 55, -83, -29, 14,
|
||||
27, -106, -86, 41, 43, 23, 11, -76, -34, 121, 94, 18, 69, 73, 100, 54, 43, 32, 13, 15,
|
||||
-38, 95, -18, -82, -7, 118, -79, -85, 120, -15, 2, 32, -94, 111, 115, 102, -18, 121, -106, 54,
|
||||
63, 111, -16, 92, 82, -23, 111, 53, 1, -48, 45, 19, -4, -15, -72, 41, 80, -51, 116, 31,
|
||||
94, 101, -10, 18, 0, -49, 108, 28, -36, 47, -14, -2, -10, 31, -92, -84, 74, -114, -107, 66,
|
||||
99, -121, -107, 31, -38, 56, -30, 109, -7, 28, -22, -17, -3, -2, 27, -3, 108, -84, -23, -71,
|
||||
-54, 20, -45, 109, -42, 78, -79, 98, -10, 57, 52, 1, 25, 73, 21, -78, 46, 121, 66, 92};
|
||||
int8_t co[] = {
|
||||
40, 24, 94, 122, 67, 34, -89, 31, -22, 32, 26, 112, -92, -23, 43, 9, 21, -92, -1, -82,
|
||||
-71, -54, 63, -93, 4, 70, -7, 116, 99, 81, -79, 124, 113, 119, 119, -75, -17, 7, 7, 1,
|
||||
-1, -15, -98, 77, 56, 119, -20, -32, 104, -52, -91, 76, 79, 105, 102, -65, 120, -15, 2, 32,
|
||||
-94, 111, 115, 102, 1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84,
|
||||
-7, 28, -22, -17, -3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -42, 16, -9, 31,
|
||||
39, -70, 108, -3, -101, -7, 23, -57, 93, 16, 17, 35, -11, -6, -71, -52, -97, 109, 116, -107,
|
||||
-103, 122, 114, -10, -31, 97, -92, 105, 27, -106, -86, 41, 43, 23, 11, -76, -38, 95, -18, -82,
|
||||
-7, 118, -79, -85, 63, 111, -16, 92, 82, -23, 111, 53, 94, 101, -10, 18, 0, -49, 108, 28,
|
||||
99, -121, -107, 31, -38, 56, -30, 109, -54, 20, -45, 109, -42, 78, -79, 98, -43, 121, 48, -54,
|
||||
44, -91, 35, 89, 81, 118, -73, -54, 65, -99, 51, -90, 92, -93, 99, 122, -104, -16, -8, -32,
|
||||
-14, 28, 97, 9, -97, 99, 88, -15, 69, 66, 40, -13, 80, -115, -98, -8, -54, -58, -16, 52,
|
||||
121, 126, -33, 43, 43, 32, 13, 15, -38, 95, -18, -82, -18, 121, -106, 54, 63, 111, -16, 92,
|
||||
80, -51, 116, 31, 94, 101, -10, 18, 74, -114, -107, 66, 99, -121, -107, 31, 108, -84, -23, -71,
|
||||
-54, 20, -45, 109, 46, 121, 66, 92, 24, 55, 4, -110, -90, -60, -121, 11, 50, -88, -104, -29,
|
||||
-48, -25, 13, -121, 73, -68, -54, -122, 117, -124, 56, 80, -108, 30, 123, 56, -61, -25, 10, -119,
|
||||
-106, 41, 77, -117, -34, 121, 94, 18, 69, 73, 100, 54, 120, -15, 2, 32, -94, 111, 115, 102,
|
||||
1, -48, 45, 19, -4, -15, -72, 41, -36, 47, -14, -2, -10, 31, -92, -84, -7, 28, -22, -17,
|
||||
-3, -2, 27, -3, -10, 57, 52, 1, 25, 73, 21, -78, -37, 114, -8, 103, 0, 0, 0, 0,
|
||||
121, -62, 119, -93, 0, 0, 0, 0, 90, -126, 51, 91, 0, 0, 0, 0, 54, 26, 77, -25,
|
||||
0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0,
|
||||
-7, 118, -79, -85, 0, 0, 0, 0, 82, -23, 111, 53, 0, 0, 0, 0, 0, -49, 108, 28,
|
||||
0, 0, 0, 0, -38, 56, -30, 109, 0, 0, 0, 0, -42, 78, -79, 98, 0, 0, 0, 0,
|
||||
-37, 112, -18, 10, 0, 0, 0, 0, -89, 94, 64, -91, 0, 0, 0, 0, -20, 12, 64, 20,
|
||||
0, 0, 0, 0, -80, 39, -18, -97, 0, 0, 0, 0, 55, -83, -29, 14, 0, 0, 0, 0,
|
||||
43, 32, 13, 15, 0, 0, 0, 0, -18, 121, -106, 54, 0, 0, 0, 0, 80, -51, 116, 31,
|
||||
0, 0, 0, 0, 74, -114, -107, 66, 0, 0, 0, 0, 108, -84, -23, -71, 0, 0, 0, 0,
|
||||
46, 121, 66, 92, 0, 0, 0, 0};
|
||||
int8_t dst[528] = {0};
|
||||
PackNHWCToC8HWN8Int8(in, dst, 22, 1, 20);
|
||||
ASSERT_EQ(0, CompareOutputData(dst, co, 528, 1));
|
||||
}
|
||||
|
||||
TEST_F(TestDeconvInt8, PackInputTest1) {
|
||||
/* 6 x 20 */
|
||||
int8_t in[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,
|
||||
|
|
Loading…
Reference in New Issue