aligned_malloc

lzk 2021-06-15 01:03:05 -07:00
parent 8f132916f5
commit 5b2c4bceb1
12 changed files with 545 additions and 549 deletions

View File

@ -24,7 +24,7 @@
extern "C" {
#endif
typedef void (*Conv1x1SWKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1X1SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom,
@ -36,48 +36,48 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl
#ifdef ENABLE_DEBUG
void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
#endif
void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
void Conv1x1SW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step,
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step,
size_t dst_flag);
#endif
#ifdef __cplusplus
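
A note on the naming before the assembly below: Conv1x1SW{M}x{N}Kernel computes an M-output-width by N-output-channel tile per call (3x32, for example, keeps twelve 8-float ymm accumulators live). A hypothetical scalar reference for one such tile, with the weight and dst packing layouts assumed for illustration rather than taken from the source:

#include <cstddef>

// Hypothetical scalar reference for one ow_block x oc_block tile of a 1x1
// convolution. Assumes weights are packed oc-contiguous per input channel and
// dst rows are oc_block apart; both layouts are illustrative assumptions.
static void Conv1x1SWTileRef(float *dst, const float *src, const float *weight,
                             const float *bias, std::size_t ow_block,
                             std::size_t oc_block, std::size_t ic_align,
                             std::size_t in_sw_step) {
  for (std::size_t ow = 0; ow < ow_block; ++ow) {
    for (std::size_t oc = 0; oc < oc_block; ++oc) {
      float acc = (bias != nullptr) ? bias[oc] : 0.0f;
      for (std::size_t ic = 0; ic < ic_align; ++ic) {
        acc += src[ow * in_sw_step + ic] * weight[ic * oc_block + oc];
      }
      dst[ow * oc_block + oc] = acc;
    }
  }
}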

View File

@ -218,18 +218,18 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovups 0x60(%2), %%ymm3\n"
"vmovups (%2), %%ymm4\n"
"vmovups 0x20(%2), %%ymm5\n"
"vmovups 0x40(%2), %%ymm6\n"
"vmovups 0x60(%2), %%ymm7\n"
"vmovups (%2), %%ymm8\n"
"vmovups 0x20(%2), %%ymm9\n"
"vmovups 0x40(%2), %%ymm10\n"
"vmovups 0x60(%2), %%ymm11\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"vmovaps 0x60(%2), %%ymm3\n"
"vmovaps (%2), %%ymm4\n"
"vmovaps 0x20(%2), %%ymm5\n"
"vmovaps 0x40(%2), %%ymm6\n"
"vmovaps 0x60(%2), %%ymm7\n"
"vmovaps (%2), %%ymm8\n"
"vmovaps 0x20(%2), %%ymm9\n"
"vmovaps 0x40(%2), %%ymm10\n"
"vmovaps 0x60(%2), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -254,19 +254,19 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
"vbroadcastss (%%rdx), %%ymm13\n"
"vbroadcastss (%%rdx, %8), %%ymm14\n"
"vbroadcastss (%%rdx, %8, 2), %%ymm15\n"
"vmovups (%1), %%ymm12\n"
"vmovaps (%1), %%ymm12\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
"vmovups 0x20(%1), %%ymm12\n"
"vmovaps 0x20(%1), %%ymm12\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"
"vmovups 0x40(%1), %%ymm12\n"
"vmovaps 0x40(%1), %%ymm12\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"
"vmovups 0x60(%1), %%ymm12\n"
"vmovaps 0x60(%1), %%ymm12\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
@ -355,10 +355,10 @@ void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovups 0x60(%2), %%ymm3\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"vmovaps 0x60(%2), %%ymm3\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -440,19 +440,19 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
// We would like to copy ymm0 to ymm3 to reduce I/O time, but unfortunately no suitable instruction was found.
"vmovups (%2), %%ymm3\n"
"vmovups 0x20(%2), %%ymm4\n"
"vmovups 0x40(%2), %%ymm5\n"
"vmovups (%2), %%ymm6\n"
"vmovups 0x20(%2), %%ymm7\n"
"vmovups 0x40(%2), %%ymm8\n"
"vmovups (%2), %%ymm9\n"
"vmovups 0x20(%2), %%ymm10\n"
"vmovups 0x40(%2), %%ymm11\n"
"vmovaps (%2), %%ymm3\n"
"vmovaps 0x20(%2), %%ymm4\n"
"vmovaps 0x40(%2), %%ymm5\n"
"vmovaps (%2), %%ymm6\n"
"vmovaps 0x20(%2), %%ymm7\n"
"vmovaps 0x40(%2), %%ymm8\n"
"vmovaps (%2), %%ymm9\n"
"vmovaps 0x20(%2), %%ymm10\n"
"vmovaps 0x40(%2), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -474,9 +474,9 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
"movq %%rcx, %%rdx\n"
"movq %5, %%r12\n" // ic_algin
"3:\n" // LoopIC
"vmovups (%1), %%ymm12\n"
"vmovups 0x20(%1), %%ymm13\n"
"vmovups 0x40(%1), %%ymm14\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps 0x20(%1), %%ymm13\n"
"vmovaps 0x40(%1), %%ymm14\n"
"vbroadcastss (%%rdx), %%ymm15\n"
"vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n"
@ -585,9 +585,9 @@ void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -664,19 +664,19 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
// We would like to copy ymm0 to ymm3 to reduce I/O time, but unfortunately no suitable instruction was found.
"vmovups (%2), %%ymm2\n"
"vmovups 0x20(%2), %%ymm3\n"
"vmovups (%2), %%ymm4\n"
"vmovups 0x20(%2), %%ymm5\n"
"vmovups (%2), %%ymm6\n"
"vmovups 0x20(%2), %%ymm7\n"
"vmovups (%2), %%ymm8\n"
"vmovups 0x20(%2), %%ymm9\n"
"vmovups (%2), %%ymm10\n"
"vmovups 0x20(%2), %%ymm11\n"
"vmovaps (%2), %%ymm2\n"
"vmovaps 0x20(%2), %%ymm3\n"
"vmovaps (%2), %%ymm4\n"
"vmovaps 0x20(%2), %%ymm5\n"
"vmovaps (%2), %%ymm6\n"
"vmovaps 0x20(%2), %%ymm7\n"
"vmovaps (%2), %%ymm8\n"
"vmovaps 0x20(%2), %%ymm9\n"
"vmovaps (%2), %%ymm10\n"
"vmovaps 0x20(%2), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -698,8 +698,8 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
"movq %%rcx, %%rdx\n"
"movq %5, %%r12\n" // ic_algin
"3:\n" // LoopIC
"vmovups (%1), %%ymm12\n"
"vmovups 0x20(%1), %%ymm13\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps 0x20(%1), %%ymm13\n"
"vbroadcastss (%%rdx), %%ymm15\n"
"vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n"
@ -812,8 +812,8 @@ void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -887,18 +887,18 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
"vmovups (%0), %%ymm0\n"
"vmovups (%0), %%ymm1\n"
"vmovups (%0), %%ymm2\n"
"vmovups (%0), %%ymm3\n"
"vmovups (%0), %%ymm4\n"
"vmovups (%0), %%ymm5\n"
"vmovups (%0), %%ymm6\n"
"vmovups (%0), %%ymm7\n"
"vmovups (%0), %%ymm8\n"
"vmovups (%0), %%ymm9\n"
"vmovups (%0), %%ymm10\n"
"vmovups (%0), %%ymm11\n"
"vmovaps (%0), %%ymm0\n"
"vmovaps (%0), %%ymm1\n"
"vmovaps (%0), %%ymm2\n"
"vmovaps (%0), %%ymm3\n"
"vmovaps (%0), %%ymm4\n"
"vmovaps (%0), %%ymm5\n"
"vmovaps (%0), %%ymm6\n"
"vmovaps (%0), %%ymm7\n"
"vmovaps (%0), %%ymm8\n"
"vmovaps (%0), %%ymm9\n"
"vmovaps (%0), %%ymm10\n"
"vmovaps (%0), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -926,7 +926,7 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
"movq %%rcx, %%rdx\n"
"movq %4, %%r12\n" // ic_algin
"LoopIC:\n"
"vmovups (%1), %%ymm12\n"
"vmovaps (%1), %%ymm12\n"
"movq %%rdx, %%rax\n"
"addq $32, %1\n"
"vbroadcastss (%%rax), %%ymm13\n"
@ -1043,10 +1043,10 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups (%2), %%ymm1\n"
"vmovups (%2), %%ymm2\n"
"vmovups (%2), %%ymm3\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps (%2), %%ymm1\n"
"vmovaps (%2), %%ymm2\n"
"vmovaps (%2), %%ymm3\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1060,7 +1060,7 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl
"movq %%rcx, %%rdx\n"
"movq %5, %%r12\n" // ic_algin
"3:\n" // LoopIC
"vmovups (%1), %%ymm12\n"
"vmovaps (%1), %%ymm12\n"
"movq %%rdx, %%rax\n"
"addq $32, %1\n"
"vbroadcastss (%%rax), %%ymm13\n"
@ -1131,7 +1131,7 @@ void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const fl
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovaps (%2), %%ymm0\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"

View File

@ -1181,18 +1181,18 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovups 0x60(%2), %%ymm3\n"
"vmovups (%2), %%ymm4\n"
"vmovups 0x20(%2), %%ymm5\n"
"vmovups 0x40(%2), %%ymm6\n"
"vmovups 0x60(%2), %%ymm7\n"
"vmovups (%2), %%ymm8\n"
"vmovups 0x20(%2), %%ymm9\n"
"vmovups 0x40(%2), %%ymm10\n"
"vmovups 0x60(%2), %%ymm11\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"vmovaps 0x60(%2), %%ymm3\n"
"vmovaps (%2), %%ymm4\n"
"vmovaps 0x20(%2), %%ymm5\n"
"vmovaps 0x40(%2), %%ymm6\n"
"vmovaps 0x60(%2), %%ymm7\n"
"vmovaps (%2), %%ymm8\n"
"vmovaps 0x20(%2), %%ymm9\n"
"vmovaps 0x40(%2), %%ymm10\n"
"vmovaps 0x60(%2), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1212,34 +1212,34 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co
"movq %0, %%rcx\n" // src_h
"2:\n" // LoopW
"vmovups (%1), %%ymm12\n"
"vmovups (%%rcx), %%ymm13\n"
"vmovups (%%rcx, %7), %%ymm14\n"
"vmovups (%%rcx, %7, 2), %%ymm15\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps (%%rcx), %%ymm13\n"
"vmovaps (%%rcx, %7), %%ymm14\n"
"vmovaps (%%rcx, %7, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
"vmovups 0x20(%1), %%ymm12\n"
"vmovups 0x20(%%rcx), %%ymm13\n"
"vmovups 0x20(%%rcx, %7), %%ymm14\n"
"vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x20(%1), %%ymm12\n"
"vmovaps 0x20(%%rcx), %%ymm13\n"
"vmovaps 0x20(%%rcx, %7), %%ymm14\n"
"vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"
"vmovups 0x40(%1), %%ymm12\n"
"vmovups 0x40(%%rcx), %%ymm13\n"
"vmovups 0x40(%%rcx, %7), %%ymm14\n"
"vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x40(%1), %%ymm12\n"
"vmovaps 0x40(%%rcx), %%ymm13\n"
"vmovaps 0x40(%%rcx, %7), %%ymm14\n"
"vmovaps 0x40(%%rcx, %7, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"
"vmovups 0x60(%1), %%ymm12\n"
"vmovups 0x60(%%rcx), %%ymm13\n"
"vmovups 0x60(%%rcx, %7), %%ymm14\n"
"vmovups 0x60(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x60(%1), %%ymm12\n"
"vmovaps 0x60(%%rcx), %%ymm13\n"
"vmovaps 0x60(%%rcx, %7), %%ymm14\n"
"vmovaps 0x60(%%rcx, %7, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
@ -1297,18 +1297,18 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovups %%ymm3, 0x60(%2)\n"
"vmovups %%ymm4, (%2, %1, 1)\n"
"vmovups %%ymm5, 0x20(%2, %1, 1)\n"
"vmovups %%ymm6, 0x40(%2, %1, 1)\n"
"vmovups %%ymm7, 0x60(%2, %1, 1)\n"
"vmovups %%ymm8, (%2, %1, 2)\n"
"vmovups %%ymm9, 0x20(%2, %1, 2)\n"
"vmovups %%ymm10, 0x40(%2, %1, 2)\n"
"vmovups %%ymm11, 0x60(%2, %1, 2)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm2, 0x40(%2)\n"
"vmovaps %%ymm3, 0x60(%2)\n"
"vmovaps %%ymm4, (%2, %1, 1)\n"
"vmovaps %%ymm5, 0x20(%2, %1, 1)\n"
"vmovaps %%ymm6, 0x40(%2, %1, 1)\n"
"vmovaps %%ymm7, 0x60(%2, %1, 1)\n"
"vmovaps %%ymm8, (%2, %1, 2)\n"
"vmovaps %%ymm9, 0x20(%2, %1, 2)\n"
"vmovaps %%ymm10, 0x40(%2, %1, 2)\n"
"vmovaps %%ymm11, 0x60(%2, %1, 2)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
@ -1325,10 +1325,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovups 0x60(%2), %%ymm3\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"vmovaps 0x60(%2), %%ymm3\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1339,10 +1339,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // Loopw
"vmovups (%%rcx), %%ymm4\n"
"vmovups 0x20(%%rcx), %%ymm5\n"
"vmovups 0x40(%%rcx), %%ymm6\n"
"vmovups 0x60(%%rcx), %%ymm7\n"
"vmovaps (%%rcx), %%ymm4\n"
"vmovaps 0x20(%%rcx), %%ymm5\n"
"vmovaps 0x40(%%rcx), %%ymm6\n"
"vmovaps 0x60(%%rcx), %%ymm7\n"
// Weight data is used directly from memory as the FMA operand instead of being loaded into a register first.
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
@ -1385,10 +1385,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm3, %%ymm3\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovups %%ymm3, 0x60(%2)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm2, 0x40(%2)\n"
"vmovaps %%ymm3, 0x60(%2)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
@ -1407,19 +1407,19 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
// We would like to copy ymm0 to ymm3 to reduce I/O time, but unfortunately no suitable instruction was found.
"vmovups (%2), %%ymm3\n"
"vmovups 0x20(%2), %%ymm4\n"
"vmovups 0x40(%2), %%ymm5\n"
"vmovups (%2), %%ymm6\n"
"vmovups 0x20(%2), %%ymm7\n"
"vmovups 0x40(%2), %%ymm8\n"
"vmovups (%2), %%ymm9\n"
"vmovups 0x20(%2), %%ymm10\n"
"vmovups 0x40(%2), %%ymm11\n"
"vmovaps (%2), %%ymm3\n"
"vmovaps 0x20(%2), %%ymm4\n"
"vmovaps 0x40(%2), %%ymm5\n"
"vmovaps (%2), %%ymm6\n"
"vmovaps 0x20(%2), %%ymm7\n"
"vmovaps 0x40(%2), %%ymm8\n"
"vmovaps (%2), %%ymm9\n"
"vmovaps 0x20(%2), %%ymm10\n"
"vmovaps 0x40(%2), %%ymm11\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1438,33 +1438,33 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // LoopW
"vmovups (%1), %%ymm12\n"
"vmovups (%%rcx), %%ymm13\n"
"vmovups (%%rcx, %7, 1), %%ymm14\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps (%%rcx), %%ymm13\n"
"vmovaps (%%rcx, %7, 1), %%ymm14\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
"vmovups (%%rcx, %7, 2), %%ymm15\n"
"vmovups (%%rcx, %9), %%ymm13\n"
"vmovaps (%%rcx, %7, 2), %%ymm15\n"
"vmovaps (%%rcx, %9), %%ymm13\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n"
"vmovups 0x20(%1), %%ymm12\n"
"vmovups 0x20(%%rcx), %%ymm13\n"
"vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
"vmovaps 0x20(%1), %%ymm12\n"
"vmovaps 0x20(%%rcx), %%ymm13\n"
"vmovaps 0x20(%%rcx, %7, 1), %%ymm14\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
"vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
"vmovups 0x20(%%rcx, %9), %%ymm13\n"
"vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x20(%%rcx, %9), %%ymm13\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n"
"vmovups 0x40(%1), %%ymm12\n"
"vmovups 0x40(%%rcx), %%ymm13\n"
"vmovups 0x40(%%rcx, %7, 1), %%ymm14\n"
"vmovaps 0x40(%1), %%ymm12\n"
"vmovaps 0x40(%%rcx), %%ymm13\n"
"vmovaps 0x40(%%rcx, %7, 1), %%ymm14\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
"vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
"vmovups 0x40(%%rcx, %9), %%ymm13\n"
"vmovaps 0x40(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x40(%%rcx, %9), %%ymm13\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n"
@ -1521,18 +1521,18 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovups %%ymm3, (%2, %1, 1)\n"
"vmovups %%ymm4, 0x20(%2, %1, 1)\n"
"vmovups %%ymm5, 0x40(%2, %1, 1)\n"
"vmovups %%ymm6, (%2, %1, 2)\n"
"vmovups %%ymm7, 0x20(%2, %1, 2)\n"
"vmovups %%ymm8, 0x40(%2, %1, 2)\n"
"vmovups %%ymm9, (%3)\n" // dst+3
"vmovups %%ymm10, 0x20(%3)\n"
"vmovups %%ymm11, 0x40(%3)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm2, 0x40(%2)\n"
"vmovaps %%ymm3, (%2, %1, 1)\n"
"vmovaps %%ymm4, 0x20(%2, %1, 1)\n"
"vmovaps %%ymm5, 0x40(%2, %1, 1)\n"
"vmovaps %%ymm6, (%2, %1, 2)\n"
"vmovaps %%ymm7, 0x20(%2, %1, 2)\n"
"vmovaps %%ymm8, 0x40(%2, %1, 2)\n"
"vmovaps %%ymm9, (%3)\n" // dst+3
"vmovaps %%ymm10, 0x20(%3)\n"
"vmovaps %%ymm11, 0x40(%3)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
@ -1549,9 +1549,9 @@ void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"vmovaps 0x40(%2), %%ymm2\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1561,9 +1561,9 @@ void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // Loopw
"vmovups (%%rcx), %%ymm4\n"
"vmovups 0x20(%%rcx), %%ymm5\n"
"vmovups 0x40(%%rcx), %%ymm6\n"
"vmovaps (%%rcx), %%ymm4\n"
"vmovaps 0x20(%%rcx), %%ymm5\n"
"vmovaps 0x40(%%rcx), %%ymm6\n"
// Weight data is used directly from memory as the FMA operand instead of being loaded into a register first.
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
@ -1603,9 +1603,9 @@ void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm2, %%ymm2\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm2, 0x40(%2)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
@ -1624,15 +1624,15 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
// We would like to copy ymm0 to ymm3 to reduce I/O time, but unfortunately no suitable instruction was found.
"vmovups (%2), %%ymm3\n"
"vmovups 0x20(%2), %%ymm4\n"
"vmovups (%2), %%ymm6\n"
"vmovups 0x20(%2), %%ymm7\n"
"vmovups (%2), %%ymm9\n"
"vmovups 0x20(%2), %%ymm10\n"
"vmovaps (%2), %%ymm3\n"
"vmovaps 0x20(%2), %%ymm4\n"
"vmovaps (%2), %%ymm6\n"
"vmovaps 0x20(%2), %%ymm7\n"
"vmovaps (%2), %%ymm9\n"
"vmovaps 0x20(%2), %%ymm10\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1647,21 +1647,21 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // LoopW
"vmovups (%1), %%ymm12\n"
"vmovups (%%rcx), %%ymm13\n"
"vmovups (%%rcx, %7, 1), %%ymm14\n"
"vmovups (%%rcx, %7, 2), %%ymm15\n"
"vmovups (%%rcx, %9), %%ymm2\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps (%%rcx), %%ymm13\n"
"vmovaps (%%rcx, %7, 1), %%ymm14\n"
"vmovaps (%%rcx, %7, 2), %%ymm15\n"
"vmovaps (%%rcx, %9), %%ymm2\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
"vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n"
"vmovups 0x20(%1), %%ymm12\n"
"vmovups 0x20(%%rcx), %%ymm13\n"
"vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
"vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
"vmovups 0x20(%%rcx, %9), %%ymm2\n"
"vmovaps 0x20(%1), %%ymm12\n"
"vmovaps 0x20(%%rcx), %%ymm13\n"
"vmovaps 0x20(%%rcx, %7, 1), %%ymm14\n"
"vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n"
"vmovaps 0x20(%%rcx, %9), %%ymm2\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
@ -1712,14 +1712,14 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm10, %%ymm10\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm3, (%2, %1, 1)\n"
"vmovups %%ymm4, 0x20(%2, %1, 1)\n"
"vmovups %%ymm6, (%2, %1, 2)\n"
"vmovups %%ymm7, 0x20(%2, %1, 2)\n"
"vmovups %%ymm9, (%3)\n" // dst+3
"vmovups %%ymm10, 0x20(%3)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm3, (%2, %1, 1)\n"
"vmovaps %%ymm4, 0x20(%2, %1, 1)\n"
"vmovaps %%ymm6, (%2, %1, 2)\n"
"vmovaps %%ymm7, 0x20(%2, %1, 2)\n"
"vmovaps %%ymm9, (%3)\n" // dst+3
"vmovaps %%ymm10, 0x20(%3)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
: "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14");
@ -1735,8 +1735,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovaps (%2), %%ymm0\n"
"vmovaps 0x20(%2), %%ymm1\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1745,8 +1745,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // Loopw
"vmovups (%%rcx), %%ymm4\n"
"vmovups 0x20(%%rcx), %%ymm5\n"
"vmovaps (%%rcx), %%ymm4\n"
"vmovaps 0x20(%%rcx), %%ymm5\n"
// Weight data is used directly from memory as the FMA operand instead of being loaded into a register first.
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
@ -1783,8 +1783,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co
"vminps %%ymm14, %%ymm1, %%ymm1\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, 0x20(%2)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
@ -1804,14 +1804,14 @@ void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
"vmovups (%0), %%ymm0\n"
"vmovups (%0), %%ymm1\n"
"vmovups (%0), %%ymm2\n"
"vmovups (%0), %%ymm3\n"
"vmovups (%0), %%ymm4\n"
"vmovups (%0), %%ymm5\n"
"vmovups (%0), %%ymm6\n"
"vmovups (%0), %%ymm7\n"
"vmovaps (%0), %%ymm0\n"
"vmovaps (%0), %%ymm1\n"
"vmovaps (%0), %%ymm2\n"
"vmovaps (%0), %%ymm3\n"
"vmovaps (%0), %%ymm4\n"
"vmovaps (%0), %%ymm5\n"
"vmovaps (%0), %%ymm6\n"
"vmovaps (%0), %%ymm7\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1833,23 +1833,23 @@ void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con
"movq %0, %%rcx\n" // src_h
"LoopW:\n"
"movq %%rcx, %%rax\n"
"vmovups (%1), %%ymm12\n"
"vmovups (%%rax), %%ymm13\n"
"vmovups (%%rax, %6), %%ymm14\n"
"vmovups (%%rax, %6, 2), %%ymm15\n"
"vmovaps (%1), %%ymm12\n"
"vmovaps (%%rax), %%ymm13\n"
"vmovaps (%%rax, %6), %%ymm14\n"
"vmovaps (%%rax, %6, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n"
"addq %7, %%rax\n"
"vmovups (%%rax), %%ymm13\n"
"vmovups (%%rax, %6), %%ymm14\n"
"vmovups (%%rax, %6, 2), %%ymm15\n"
"vmovaps (%%rax), %%ymm13\n"
"vmovaps (%%rax, %6), %%ymm14\n"
"vmovaps (%%rax, %6, 2), %%ymm15\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
"vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
"addq %7, %%rax\n"
"vmovups (%%rax), %%ymm13\n"
"vmovups (%%rax, %6), %%ymm14\n"
"vmovaps (%%rax), %%ymm13\n"
"vmovaps (%%rax, %6), %%ymm14\n"
"vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n"
"vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
@ -1898,14 +1898,14 @@ void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con
"vminps %%ymm14, %%ymm7, %%ymm7\n"
"Write:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, (%2, %1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
"vmovups %%ymm3, (%3)\n" // dst_3
"vmovups %%ymm4, (%2, %1, 4)\n"
"vmovups %%ymm5, (%4)\n" // dst_5
"vmovups %%ymm6, (%4, %1, 1)\n"
"vmovups %%ymm7, (%4, %1, 2)\n"
"vmovaps %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm1, (%2, %1)\n"
"vmovaps %%ymm2, (%2, %1, 2)\n"
"vmovaps %%ymm3, (%3)\n" // dst_3
"vmovaps %%ymm4, (%2, %1, 4)\n"
"vmovaps %%ymm5, (%4)\n" // dst_5
"vmovaps %%ymm6, (%4, %1, 1)\n"
"vmovaps %%ymm7, (%4, %1, 2)\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14");
@ -1921,7 +1921,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovaps (%2), %%ymm0\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
@ -1929,7 +1929,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con
"movq %4, %%rsi\n" // width
"movq %0, %%rcx\n" // src_h
"2:\n" // Loopw
"vmovups (%%rcx), %%ymm4\n"
"vmovaps (%%rcx), %%ymm4\n"
// Weight data is used directly from memory as the FMA operand instead of being loaded into a register first.
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"addq $32, %1\n"
@ -1963,7 +1963,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"0:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovaps %%ymm0, (%2)\n" // dst_0
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "%ecx", "%ymm0", "%ymm12", "%ymm14");

View File

@ -78,10 +78,8 @@ void *DefaultAllocator::Malloc(size_t size) {
this->total_size_ += size;
membuf->ref_count_ = 0;
membuf->size = size;
auto aligned_bytes =
reinterpret_cast<size_t>((reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf))) % aligned_size_;
aligned_bytes = aligned_bytes == 0 ? 0 : aligned_size_ - aligned_bytes;
membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf) + aligned_bytes;
membuf->buf = reinterpret_cast<char *>(
(reinterpret_cast<uintptr_t>(membuf.get()) + sizeof(MemBuf) + aligned_size_ - 1) & (~(aligned_size_ - 1)));
auto bufPtr = membuf->buf;
allocatedList_[bufPtr] = membuf.release();
UnLock();
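
The rewritten allocator collapses the old two-step remainder computation into the standard round-up idiom: for a power-of-two aligned_size_, adding align - 1 and masking off the low bits takes any address to the next multiple of align. A minimal sketch with a worked value:

#include <cstdint>

// Round p up to the next multiple of align; align must be a power of two.
inline std::uintptr_t AlignUp(std::uintptr_t p, std::uintptr_t align) {
  return (p + align - 1) & ~(align - 1);
}
// AlignUp(0x1007, 32) == 0x1020; AlignUp(0x1020, 32) == 0x1020 (unchanged).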

View File

@ -27,6 +27,24 @@ using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType;
namespace mindspore::kernel {
void *ConvolutionBaseCPUKernel::MallocAlignedData(size_t alignment, size_t size) {
auto ptr = malloc(size + alignment);
if (ptr == nullptr) {
MS_LOG(ERROR) << "MallocAlignedData failed!";
return nullptr;
}
auto aligned_ptr = (reinterpret_cast<uintptr_t>(ptr) + alignment - 1) & (~(alignment - 1));
addr_map[aligned_ptr] = ptr;
return reinterpret_cast<void *>(aligned_ptr);
}
void ConvolutionBaseCPUKernel::FreeAlignedData(void **ptr) {
if (*ptr != nullptr) {
free(addr_map[reinterpret_cast<uintptr_t>(*ptr)]);
*ptr = nullptr;
}
}
ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
if (bias_data_ != nullptr) {
free(bias_data_);
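
MallocAlignedData over-allocates by alignment bytes, rounds the raw pointer up, and records the raw pointer in addr_map keyed by the aligned address so FreeAlignedData can hand the original back to free(). Since the aligned pointer sits at most alignment - 1 bytes past the raw one, size + alignment bytes always cover the request. A self-contained sketch of the same scheme (this sketch also erases the map entry on free so stale keys do not accumulate):

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>

static std::map<std::uintptr_t, void *> g_addr_map;

void *MallocAligned(std::size_t alignment, std::size_t size) {
  // Room for up to alignment - 1 skipped bytes; alignment must be a power of two.
  void *raw = std::malloc(size + alignment);
  if (raw == nullptr) return nullptr;
  auto aligned = (reinterpret_cast<std::uintptr_t>(raw) + alignment - 1) & ~(alignment - 1);
  g_addr_map[aligned] = raw;  // remember the pointer free() actually needs
  return reinterpret_cast<void *>(aligned);
}

void FreeAligned(void *p) {
  if (p == nullptr) return;
  auto it = g_addr_map.find(reinterpret_cast<std::uintptr_t>(p));
  if (it == g_addr_map.end()) return;
  std::free(it->second);
  g_addr_map.erase(it);
}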

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_
#include <vector>
#include <map>
#include <string>
#include <limits>
#ifdef ENABLE_ARM
@ -56,8 +57,11 @@ class ConvolutionBaseCPUKernel : public InnerKernel {
void SetRoundingAndMultipilerMode();
int CheckResizeValid();
void FreeQuantParam();
void *MallocAlignedData(size_t alignment, size_t size);
void FreeAlignedData(void **ptr);
protected:
std::map<uintptr_t, void *> addr_map;
bool is_repack() { return is_repack_; }
void *bias_data_ = nullptr;
const InnerContext *ctx_ = nullptr;

View File

@ -145,8 +145,7 @@ kernel::InnerKernel *ConvolutionDelegateCPUKernel::CpuConvFp32KernelSelect() {
if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
#ifdef ENABLE_AVX
if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 &&
conv_param->output_channel_ % 8 == 0 && conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 &&
conv_param->input_channel_ % 8 == 0) {
conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ % 8 == 0) {
kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(
op_parameter_, in_tensors_, out_tensors_, static_cast<const lite::InnerContext *>(this->context_),
origin_weight_, origin_bias_);

View File

@ -28,14 +28,8 @@ ConvolutionDepthwiseSWCPUKernelX86::~ConvolutionDepthwiseSWCPUKernelX86() {
delete sliding_;
sliding_ = nullptr;
}
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (packed_bias_ != nullptr) {
free(packed_bias_);
packed_bias_ = nullptr;
}
FreeAlignedData(reinterpret_cast<void **>(&packed_weight_));
FreeAlignedData(reinterpret_cast<void **>(&packed_bias_));
}
int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
@ -45,8 +39,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
MS_ASSERT(origin_weight_ != nullptr);
int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
int pack_weight_size = oc_algin * oc_tile_ * weight_tensor->Height() * weight_tensor->Width();
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
packed_weight_ = reinterpret_cast<float *>(MallocAlignedData(alignment, pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc packed_weight_ is failed!";
return RET_NULL_PTR;
@ -57,7 +50,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
auto bias_size = oc_algin * oc_tile_;
auto bias_tensor = in_tensors_.at(kBiasIndex);
auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
packed_bias_ = reinterpret_cast<float *>(malloc(bias_size * sizeof(float)));
packed_bias_ = reinterpret_cast<float *>(MallocAlignedData(alignment, bias_size * sizeof(float)));
if (packed_bias_ == nullptr) {
MS_LOG(ERROR) << "Malloc bias_data buffer failed.";
return RET_NULL_PTR;

View File

@ -51,6 +51,7 @@ class ConvolutionDepthwiseSWCPUKernelX86 : public ConvolutionBaseCPUKernel {
float *origin_weight_ = nullptr;
bool input_need_align_ = false;
bool output_need_align_ = false;
size_t alignment = C32NUM;
};
} // namespace mindspore::kernel
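
The alignment member is C32NUM, presumably 32 bytes: exactly one 256-bit ymm register of eight floats, the unit every vmovaps in the kernels above moves. A one-line sanity check under that assumption:

// Assumption: C32NUM == 32. A 256-bit ymm register holds 8 floats, so 32-byte
// alignment puts every register-sized chunk of the packed buffers on a legal
// vmovaps boundary.
static_assert(32 == 8 * sizeof(float), "one ymm register of floats");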

View File

@ -39,27 +39,22 @@ int ConvolutionSWCPUKernel::InitWeightBias() {
int kernel_plane = kernel_h * kernel_w;
int oc_block_num = UP_DIV(output_channel, oc_tile_);
int pack_weight_size = oc_block_num * oc_tile_ * input_channel * kernel_plane;
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
packed_weight_ = reinterpret_cast<float *>(MallocAlignedData(alignment, pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed weight failed.";
MS_LOG(ERROR) << "MallocAlignedData packed weight failed.";
return RET_NULL_PTR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
PackNHWCTo1HWCNXFp32(kernel_h, kernel_w, output_channel, oc_block_num, input_channel, packed_weight_,
ori_weight_data_);
if (in_tensors_.size() == kInputSize2) {
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_tile_ * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
packed_bias_ = reinterpret_cast<float *>(MallocAlignedData(alignment, oc_block_num * oc_tile_ * sizeof(float)));
if (packed_bias_ == nullptr) {
MS_LOG(ERROR) << "MallocAlignedData bias failed.";
return RET_NULL_PTR;
}
memset(bias_data_, 0, oc_block_num * oc_tile_ * sizeof(float));
memcpy(bias_data_, ori_bias_data_, output_channel * sizeof(float));
} else {
if (bias_data_ != nullptr) {
free(bias_data_);
bias_data_ = nullptr;
}
memset(packed_bias_, 0, oc_block_num * oc_tile_ * sizeof(float));
memcpy(packed_bias_, ori_bias_data_, output_channel * sizeof(float));
}
return RET_OK;
}
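
InitWeightBias now sizes the bias buffer to the full tile (oc_block_num * oc_tile_ floats), zero-fills it, and only then copies the real output_channel values, so the tail lanes read by a full-width vmovaps hold zeros instead of heap garbage. A worked sketch of the padding (helper names are hypothetical):

#include <cstring>

void PadBiasDemo(const float *origin_bias) {
  constexpr int output_channel = 20, oc_tile = 8;
  // UP_DIV(20, 8) == 3 blocks -> 24 floats; lanes 20..23 stay zero, so the
  // last 8-wide register load contributes nothing to the output.
  constexpr int padded = ((output_channel + oc_tile - 1) / oc_tile) * oc_tile;
  float bias[padded] = {};  // zero-initialized, including the padded tail
  std::memcpy(bias, origin_bias, output_channel * sizeof(float));
  (void)bias;
}
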
@ -113,10 +108,10 @@ int ConvolutionSWCPUKernel::ReSize() {
int ConvolutionSWCPUKernel::RunImpl(int task_id) {
if (conv_param_->kernel_w_ == 1 && conv_param_->kernel_h_ == 1) {
Conv1x1SWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(bias_data_), output_data_, task_id,
Conv1x1SWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(packed_bias_), output_data_, task_id,
conv_param_, slidingWindow_param_);
} else {
ConvSWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(bias_data_), output_data_, task_id, conv_param_,
ConvSWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(packed_bias_), output_data_, task_id, conv_param_,
slidingWindow_param_);
}
return RET_OK;

View File

@ -34,8 +34,10 @@ class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
~ConvolutionSWCPUKernel() override {
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
FreeAlignedData(reinterpret_cast<void **>(&packed_weight_));
}
if (packed_bias_ != nullptr) {
FreeAlignedData(reinterpret_cast<void **>(&packed_bias_));
}
if (slidingWindow_param_ != nullptr) {
delete slidingWindow_param_;
@ -68,8 +70,10 @@ class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
float *ori_weight_data_ = nullptr;
float *ori_bias_data_ = nullptr;
float *packed_weight_ = nullptr;
float *packed_bias_ = nullptr;
float *output_data_ = nullptr;
float *input_data_ = nullptr;
int alignment = C32NUM;
SlidingWindowParam *slidingWindow_param_ = nullptr;
};
} // namespace mindspore::kernel