From 5b2c4bceb10c3a9c3a6a639f1ccfbfd34435691b Mon Sep 17 00:00:00 2001 From: lzk Date: Tue, 15 Jun 2021 01:03:05 -0700 Subject: [PATCH] aligned_malloc --- .../cpu/nnacl/fp32/conv_1x1_x86_fp32.c | 526 +++++++++--------- .../cpu/nnacl/fp32/conv_1x1_x86_fp32.h | 24 +- .../cpu/nnacl/fp32/conv_common_fp32.c | 146 ++--- .../cpu/nnacl/fp32/conv_depthwise_fp32.c | 320 +++++------ mindspore/lite/src/runtime/inner_allocator.cc | 6 +- .../kernel/arm/base/convolution_base.cc | 18 + .../kernel/arm/base/convolution_base.h | 4 + .../arm/fp32/convolution_delegate_fp32.cc | 3 +- ...volution_depthwise_slidewindow_x86_fp32.cc | 15 +- ...nvolution_depthwise_slidewindow_x86_fp32.h | 1 + .../arm/fp32/convolution_slidewindow_fp32.cc | 23 +- .../arm/fp32/convolution_slidewindow_fp32.h | 8 +- 12 files changed, 545 insertions(+), 549 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c index a174219475b..77881a04475 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.c @@ -34,9 +34,11 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl int pad_l = conv_param->pad_l_; int pad_r = conv_param->pad_r_; int pad_u = conv_param->pad_u_; - int oc_algin = sw_param->block_channel_; - int ic_algin = sw_param->ic_align_; + int oc_align = sw_param->block_channel_; + int oc_align_float = oc_align * sizeof(float); + int ic_align = sw_param->ic_align_; int in_sw_step = sw_param->in_sw_step_; + int in_sw_step_float = sw_param->in_sw_step_ * sizeof(float); int kernel_step = sw_param->kernel_step_; int oc_num = sw_param->c_block_; int in_step = sw_param->in_step_; @@ -53,9 +55,9 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl for (int b = 0; b < conv_param->output_batch_; b++) { int ic_block = 128; int dst_flag = 0; - for (int ic = 0; ic < ic_algin; ic += ic_block) { - if (ic_algin - ic <= ic_block) { - ic_block = ic_algin - ic; + for (int ic = 0; ic < ic_align; ic += ic_block) { + if (ic_align - ic <= ic_block) { + ic_block = ic_align - ic; dst_flag = 3 - (ic == 0); } else { dst_flag = 1 - (ic == 0); @@ -76,9 +78,10 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl if (hw_block > ohw_end - hw) { // ow is not enough and process one ow hw_block = 1; } - float *dst_w = dst_oc + hw * oc_algin; - kernel[oc_block - 1][hw_block / ow_block_num[oc_block - 1]]( - dst_w, src_w, weight, bias, act_type, hw_block, oc_block, oc_algin, ic_block, in_sw_step, dst_flag); + float *dst_w = dst_oc + hw * oc_align; + kernel[oc_block - 1][hw_block / ow_block_num[oc_block - 1]](dst_w, src_w, weight, bias, act_type, hw_block, + oc_block, oc_align_float, ic_block >> 3, + in_sw_step_float, dst_flag); src_w += hw_block * in_sw_step; } } @@ -90,11 +93,8 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl } void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - oc_algin *= sizeof(float); - ic_algin /= 8; asm volatile( "movq %8, %%rax\n" "and $0x1, %%eax\n" @@ -115,18 +115,18 @@ void Conv1x1SW3x32Kernel(float *dst, 
const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups 0x20(%2), %%ymm5\n" - "vmovups 0x40(%2), %%ymm6\n" - "vmovups 0x60(%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups 0x20(%2), %%ymm9\n" - "vmovups 0x40(%2), %%ymm10\n" - "vmovups 0x60(%2), %%ymm11\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps 0x20(%2), %%ymm5\n" + "vmovaps 0x40(%2), %%ymm6\n" + "vmovaps 0x60(%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps 0x20(%2), %%ymm9\n" + "vmovaps 0x40(%2), %%ymm10\n" + "vmovaps 0x60(%2), %%ymm11\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -146,19 +146,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss (%0), %%ymm13\n" "vbroadcastss (%0, %4), %%ymm14\n" "vbroadcastss (%0, %4, 2), %%ymm15\n" - "vmovups (%1), %%ymm12\n" + "vmovaps (%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 0x20(%1), %%ymm12\n" + "vmovaps 0x20(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 0x40(%1), %%ymm12\n" + "vmovaps 0x40(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 0x60(%1), %%ymm12\n" + "vmovaps 0x60(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -166,19 +166,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 4(%0), %%ymm13\n" "vbroadcastss 4(%0, %4), %%ymm14\n" "vbroadcastss 4(%0, %4, 2), %%ymm15\n" - "vmovups 128(%1), %%ymm12\n" + "vmovaps 128(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 160(%1), %%ymm12\n" + "vmovaps 160(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 192(%1), %%ymm12\n" + "vmovaps 192(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 224(%1), %%ymm12\n" + "vmovaps 224(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -186,19 +186,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 8(%0), %%ymm13\n" "vbroadcastss 8(%0, %4), %%ymm14\n" "vbroadcastss 8(%0, %4, 2), %%ymm15\n" - "vmovups 256(%1), %%ymm12\n" + "vmovaps 256(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 288(%1), %%ymm12\n" + "vmovaps 288(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 320(%1), %%ymm12\n" + "vmovaps 320(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 352(%1), %%ymm12\n" + 
"vmovaps 352(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -206,19 +206,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 12(%0), %%ymm13\n" "vbroadcastss 12(%0, %4), %%ymm14\n" "vbroadcastss 12(%0, %4, 2), %%ymm15\n" - "vmovups 384(%1), %%ymm12\n" + "vmovaps 384(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 416(%1), %%ymm12\n" + "vmovaps 416(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 448(%1), %%ymm12\n" + "vmovaps 448(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 480(%1), %%ymm12\n" + "vmovaps 480(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -226,19 +226,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 16(%0), %%ymm13\n" "vbroadcastss 16(%0, %4), %%ymm14\n" "vbroadcastss 16(%0, %4, 2), %%ymm15\n" - "vmovups 512(%1), %%ymm12\n" + "vmovaps 512(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 544(%1), %%ymm12\n" + "vmovaps 544(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 576(%1), %%ymm12\n" + "vmovaps 576(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 608(%1), %%ymm12\n" + "vmovaps 608(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -246,19 +246,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 20(%0), %%ymm13\n" "vbroadcastss 20(%0, %4), %%ymm14\n" "vbroadcastss 20(%0, %4, 2), %%ymm15\n" - "vmovups 640(%1), %%ymm12\n" + "vmovaps 640(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 672(%1), %%ymm12\n" + "vmovaps 672(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 704(%1), %%ymm12\n" + "vmovaps 704(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 736(%1), %%ymm12\n" + "vmovaps 736(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -266,19 +266,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 24(%0), %%ymm13\n" "vbroadcastss 24(%0, %4), %%ymm14\n" "vbroadcastss 24(%0, %4, 2), %%ymm15\n" - "vmovups 768(%1), %%ymm12\n" + "vmovaps 768(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 800(%1), %%ymm12\n" + "vmovaps 800(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, 
%%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 832(%1), %%ymm12\n" + "vmovaps 832(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 864(%1), %%ymm12\n" + "vmovaps 864(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -286,19 +286,19 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vbroadcastss 28(%0), %%ymm13\n" "vbroadcastss 28(%0, %4), %%ymm14\n" "vbroadcastss 28(%0, %4, 2), %%ymm15\n" - "vmovups 896(%1), %%ymm12\n" + "vmovaps 896(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 928(%1), %%ymm12\n" + "vmovaps 928(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 960(%1), %%ymm12\n" + "vmovaps 960(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 992(%1), %%ymm12\n" + "vmovaps 992(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -362,18 +362,15 @@ void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm10, 0x40(%7, %6, 2)\n" "vmovups %%ymm11, 0x60(%7, %6, 2)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(act_flag), "r"(oc_algin), "r"(dst), + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_flag) // 8 : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); } void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - oc_algin *= sizeof(float); - ic_algin /= 8; asm volatile( "movq %8, %%rax\n" "and $0x1, %%eax\n" @@ -386,10 +383,10 @@ void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -399,80 +396,80 @@ void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, cons "2:\n" // LoopIC "vbroadcastss (%0), %%ymm13\n" - "vmovups (%1), %%ymm4\n" - "vmovups 0x20(%1), %%ymm5\n" - "vmovups 0x40(%1), %%ymm6\n" - "vmovups 0x60(%1), %%ymm7\n" + "vmovaps (%1), %%ymm4\n" + "vmovaps 0x20(%1), %%ymm5\n" + "vmovaps 0x40(%1), %%ymm6\n" + "vmovaps 0x60(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 4(%0), %%ymm13\n" - "vmovups 128(%1), %%ymm4\n" - "vmovups 160(%1), %%ymm5\n" - "vmovups 192(%1), %%ymm6\n" - "vmovups 224(%1), %%ymm7\n" + "vmovaps 
128(%1), %%ymm4\n" + "vmovaps 160(%1), %%ymm5\n" + "vmovaps 192(%1), %%ymm6\n" + "vmovaps 224(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 8(%0), %%ymm13\n" - "vmovups 256(%1), %%ymm4\n" - "vmovups 288(%1), %%ymm5\n" - "vmovups 320(%1), %%ymm6\n" - "vmovups 352(%1), %%ymm7\n" + "vmovaps 256(%1), %%ymm4\n" + "vmovaps 288(%1), %%ymm5\n" + "vmovaps 320(%1), %%ymm6\n" + "vmovaps 352(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 12(%0), %%ymm13\n" - "vmovups 384(%1), %%ymm4\n" - "vmovups 416(%1), %%ymm5\n" - "vmovups 448(%1), %%ymm6\n" - "vmovups 480(%1), %%ymm7\n" + "vmovaps 384(%1), %%ymm4\n" + "vmovaps 416(%1), %%ymm5\n" + "vmovaps 448(%1), %%ymm6\n" + "vmovaps 480(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 16(%0), %%ymm13\n" - "vmovups 512(%1), %%ymm4\n" - "vmovups 544(%1), %%ymm5\n" - "vmovups 576(%1), %%ymm6\n" - "vmovups 608(%1), %%ymm7\n" + "vmovaps 512(%1), %%ymm4\n" + "vmovaps 544(%1), %%ymm5\n" + "vmovaps 576(%1), %%ymm6\n" + "vmovaps 608(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 20(%0), %%ymm13\n" - "vmovups 640(%1), %%ymm4\n" - "vmovups 672(%1), %%ymm5\n" - "vmovups 704(%1), %%ymm6\n" - "vmovups 736(%1), %%ymm7\n" + "vmovaps 640(%1), %%ymm4\n" + "vmovaps 672(%1), %%ymm5\n" + "vmovaps 704(%1), %%ymm6\n" + "vmovaps 736(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 24(%0), %%ymm13\n" - "vmovups 768(%1), %%ymm4\n" - "vmovups 800(%1), %%ymm5\n" - "vmovups 832(%1), %%ymm6\n" - "vmovups 864(%1), %%ymm7\n" + "vmovaps 768(%1), %%ymm4\n" + "vmovaps 800(%1), %%ymm5\n" + "vmovaps 832(%1), %%ymm6\n" + "vmovaps 864(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" "vbroadcastss 28(%0), %%ymm13\n" - "vmovups 896(%1), %%ymm4\n" - "vmovups 928(%1), %%ymm5\n" - "vmovups 960(%1), %%ymm6\n" - "vmovups 992(%1), %%ymm7\n" + "vmovaps 896(%1), %%ymm4\n" + "vmovaps 928(%1), %%ymm5\n" + "vmovaps 960(%1), %%ymm6\n" + "vmovaps 992(%1), %%ymm7\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" @@ -512,20 +509,17 @@ void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm2, 0x40(%7)\n" "vmovups %%ymm3, 0x60(%7)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(act_flag), "r"(oc_algin), "r"(dst), + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_flag) // 8 : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm13", "%ymm14"); } void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + 
size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - ic_algin /= 8; size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * oc_align / sizeof(float); asm volatile( "movq %10, %%rax\n" // dst_flag "and $0x1, %%eax\n" @@ -546,18 +540,18 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups (%2), %%ymm3\n" - "vmovups 0x20(%2), %%ymm4\n" - "vmovups 0x40(%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups 0x40(%2), %%ymm8\n" - "vmovups (%2), %%ymm9\n" - "vmovups 0x20(%2), %%ymm10\n" - "vmovups 0x40(%2), %%ymm11\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm3\n" + "vmovaps 0x20(%2), %%ymm4\n" + "vmovaps 0x40(%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps 0x40(%2), %%ymm8\n" + "vmovaps (%2), %%ymm9\n" + "vmovaps 0x20(%2), %%ymm10\n" + "vmovaps 0x40(%2), %%ymm11\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -574,9 +568,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vxorps %%ymm11, %%ymm11, %%ymm11\n" "2:\n" // LoopIC - "vmovups (%1), %%ymm13\n" - "vmovups 0x20(%1), %%ymm14\n" - "vmovups 0x40(%1), %%ymm15\n" + "vmovaps (%1), %%ymm13\n" + "vmovaps 0x20(%1), %%ymm14\n" + "vmovaps 0x40(%1), %%ymm15\n" "vbroadcastss (%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -594,9 +588,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 96(%1), %%ymm13\n" - "vmovups 128(%1), %%ymm14\n" - "vmovups 160(%1), %%ymm15\n" + "vmovaps 96(%1), %%ymm13\n" + "vmovaps 128(%1), %%ymm14\n" + "vmovaps 160(%1), %%ymm15\n" "vbroadcastss 4(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -614,9 +608,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 192(%1), %%ymm13\n" - "vmovups 224(%1), %%ymm14\n" - "vmovups 256(%1), %%ymm15\n" + "vmovaps 192(%1), %%ymm13\n" + "vmovaps 224(%1), %%ymm14\n" + "vmovaps 256(%1), %%ymm15\n" "vbroadcastss 8(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -634,9 +628,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 288(%1), %%ymm13\n" - "vmovups 320(%1), %%ymm14\n" - "vmovups 352(%1), %%ymm15\n" + "vmovaps 288(%1), %%ymm13\n" + "vmovaps 320(%1), %%ymm14\n" + "vmovaps 352(%1), %%ymm15\n" "vbroadcastss 12(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -654,9 +648,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 384(%1), %%ymm13\n" - "vmovups 416(%1), %%ymm14\n" - "vmovups 448(%1), %%ymm15\n" + "vmovaps 384(%1), %%ymm13\n" + "vmovaps 416(%1), %%ymm14\n" + "vmovaps 448(%1), 
%%ymm15\n" "vbroadcastss 16(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -674,9 +668,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 480(%1), %%ymm13\n" - "vmovups 512(%1), %%ymm14\n" - "vmovups 544(%1), %%ymm15\n" + "vmovaps 480(%1), %%ymm13\n" + "vmovaps 512(%1), %%ymm14\n" + "vmovaps 544(%1), %%ymm15\n" "vbroadcastss 20(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -694,9 +688,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 576(%1), %%ymm13\n" - "vmovups 608(%1), %%ymm14\n" - "vmovups 640(%1), %%ymm15\n" + "vmovaps 576(%1), %%ymm13\n" + "vmovaps 608(%1), %%ymm14\n" + "vmovaps 640(%1), %%ymm15\n" "vbroadcastss 24(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -714,9 +708,9 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" - "vmovups 672(%1), %%ymm13\n" - "vmovups 704(%1), %%ymm14\n" - "vmovups 736(%1), %%ymm15\n" + "vmovaps 672(%1), %%ymm13\n" + "vmovaps 704(%1), %%ymm14\n" + "vmovaps 736(%1), %%ymm15\n" "vbroadcastss 28(%0), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -793,18 +787,15 @@ void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm10, 0x20(%9)\n" "vmovups %%ymm11, 0x40(%9)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 - "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); } void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - ic_algin /= 8; - oc_algin *= sizeof(float); asm volatile( "movq %8, %%rax\n" "and $0x1, %%eax\n" @@ -816,9 +807,9 @@ void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -827,65 +818,65 @@ void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, cons "2:\n" // LoopIC "vbroadcastss (%0), %%ymm13\n" - "vmovups (%1), %%ymm4\n" - "vmovups 0x20(%1), %%ymm5\n" - "vmovups 0x40(%1), %%ymm6\n" + "vmovaps (%1), %%ymm4\n" + "vmovaps 0x20(%1), %%ymm5\n" + "vmovaps 0x40(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 4(%0), %%ymm13\n" - "vmovups 96(%1), 
%%ymm4\n" - "vmovups 128(%1), %%ymm5\n" - "vmovups 160(%1), %%ymm6\n" + "vmovaps 96(%1), %%ymm4\n" + "vmovaps 128(%1), %%ymm5\n" + "vmovaps 160(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 8(%0), %%ymm13\n" - "vmovups 192(%1), %%ymm4\n" - "vmovups 224(%1), %%ymm5\n" - "vmovups 256(%1), %%ymm6\n" + "vmovaps 192(%1), %%ymm4\n" + "vmovaps 224(%1), %%ymm5\n" + "vmovaps 256(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 12(%0), %%ymm13\n" - "vmovups 288(%1), %%ymm4\n" - "vmovups 320(%1), %%ymm5\n" - "vmovups 352(%1), %%ymm6\n" + "vmovaps 288(%1), %%ymm4\n" + "vmovaps 320(%1), %%ymm5\n" + "vmovaps 352(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 16(%0), %%ymm13\n" - "vmovups 384(%1), %%ymm4\n" - "vmovups 416(%1), %%ymm5\n" - "vmovups 448(%1), %%ymm6\n" + "vmovaps 384(%1), %%ymm4\n" + "vmovaps 416(%1), %%ymm5\n" + "vmovaps 448(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 20(%0), %%ymm13\n" - "vmovups 480(%1), %%ymm4\n" - "vmovups 512(%1), %%ymm5\n" - "vmovups 544(%1), %%ymm6\n" + "vmovaps 480(%1), %%ymm4\n" + "vmovaps 512(%1), %%ymm5\n" + "vmovaps 544(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 24(%0), %%ymm13\n" - "vmovups 576(%1), %%ymm4\n" - "vmovups 608(%1), %%ymm5\n" - "vmovups 640(%1), %%ymm6\n" + "vmovaps 576(%1), %%ymm4\n" + "vmovaps 608(%1), %%ymm5\n" + "vmovaps 640(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" "vbroadcastss 28(%0), %%ymm13\n" - "vmovups 672(%1), %%ymm4\n" - "vmovups 704(%1), %%ymm5\n" - "vmovups 736(%1), %%ymm6\n" + "vmovaps 672(%1), %%ymm4\n" + "vmovaps 704(%1), %%ymm5\n" + "vmovaps 736(%1), %%ymm6\n" "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" @@ -922,19 +913,16 @@ void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm1, 0x20(%7)\n" "vmovups %%ymm2, 0x40(%7)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(act_flag), "r"(oc_algin), "r"(dst), + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_flag) // 8 : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6", "%ymm12", "%ymm13", "%ymm14"); } void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - oc_algin *= sizeof(float); - ic_algin /= 8; + float *dst_3 = dst + 3 * oc_align / sizeof(float); asm volatile( "movq %10, %%rax\n" // dst_flag "and $0x1, %%eax\n" @@ -955,19 +943,19 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" 
+ "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. - "vmovups (%2), %%ymm2\n" - "vmovups 0x20(%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups 0x20(%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups 0x20(%2), %%ymm9\n" - "vmovups (%2), %%ymm10\n" - "vmovups 0x20(%2), %%ymm11\n" + "vmovaps (%2), %%ymm2\n" + "vmovaps 0x20(%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps 0x20(%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps 0x20(%2), %%ymm9\n" + "vmovaps (%2), %%ymm10\n" + "vmovaps 0x20(%2), %%ymm11\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -987,8 +975,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "movq %0, %%rax\n" "addq %5, %%rax\n" - "vmovups (%1), %%ymm12\n" - "vmovups 0x20(%1), %%ymm13\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps 0x20(%1), %%ymm13\n" "vbroadcastss (%0), %%ymm14\n" "vbroadcastss (%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1008,8 +996,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 64(%1), %%ymm12\n" - "vmovups 96(%1), %%ymm13\n" + "vmovaps 64(%1), %%ymm12\n" + "vmovaps 96(%1), %%ymm13\n" "vbroadcastss 4(%0), %%ymm14\n" "vbroadcastss 4(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1029,8 +1017,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 128(%1), %%ymm12\n" - "vmovups 160(%1), %%ymm13\n" + "vmovaps 128(%1), %%ymm12\n" + "vmovaps 160(%1), %%ymm13\n" "vbroadcastss 8(%0), %%ymm14\n" "vbroadcastss 8(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1050,8 +1038,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 192(%1), %%ymm12\n" - "vmovups 224(%1), %%ymm13\n" + "vmovaps 192(%1), %%ymm12\n" + "vmovaps 224(%1), %%ymm13\n" "vbroadcastss 12(%0), %%ymm14\n" "vbroadcastss 12(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1071,8 +1059,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 256(%1), %%ymm12\n" - "vmovups 288(%1), %%ymm13\n" + "vmovaps 256(%1), %%ymm12\n" + "vmovaps 288(%1), %%ymm13\n" "vbroadcastss 16(%0), %%ymm14\n" "vbroadcastss 16(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1092,8 +1080,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 320(%1), %%ymm12\n" - "vmovups 352(%1), %%ymm13\n" + "vmovaps 320(%1), %%ymm12\n" + "vmovaps 352(%1), %%ymm13\n" "vbroadcastss 20(%0), %%ymm14\n" "vbroadcastss 20(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1113,8 +1101,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 384(%1), %%ymm12\n" - "vmovups 416(%1), %%ymm13\n" + "vmovaps 384(%1), %%ymm12\n" 
+ "vmovaps 416(%1), %%ymm13\n" "vbroadcastss 24(%0), %%ymm14\n" "vbroadcastss 24(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1134,8 +1122,8 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" - "vmovups 448(%1), %%ymm12\n" - "vmovups 480(%1), %%ymm13\n" + "vmovaps 448(%1), %%ymm12\n" + "vmovaps 480(%1), %%ymm13\n" "vbroadcastss 28(%0), %%ymm14\n" "vbroadcastss 28(%0, %4), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" @@ -1214,18 +1202,15 @@ void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm10, (%9, %7, 2)\n" "vmovups %%ymm11, 0x20(%9, %7, 2)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 - "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); } void Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - oc_algin *= sizeof(float); - ic_algin /= 8; asm volatile( "movq %8, %%rax\n" "and $0x1, %%eax\n" @@ -1236,8 +1221,8 @@ void Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1245,50 +1230,50 @@ void Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, cons "2:\n" // LoopIC "vbroadcastss (%0), %%ymm12\n" - "vmovups (%1), %%ymm13\n" - "vmovups 0x20(%1), %%ymm14\n" + "vmovaps (%1), %%ymm13\n" + "vmovaps 0x20(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 4(%0), %%ymm12\n" - "vmovups 64(%1), %%ymm13\n" - "vmovups 96(%1), %%ymm14\n" + "vmovaps 64(%1), %%ymm13\n" + "vmovaps 96(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 8(%0), %%ymm12\n" - "vmovups 128(%1), %%ymm13\n" - "vmovups 160(%1), %%ymm14\n" + "vmovaps 128(%1), %%ymm13\n" + "vmovaps 160(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 12(%0), %%ymm12\n" - "vmovups 192(%1), %%ymm13\n" - "vmovups 224(%1), %%ymm14\n" + "vmovaps 192(%1), %%ymm13\n" + "vmovaps 224(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 16(%0), %%ymm12\n" - "vmovups 256(%1), %%ymm13\n" - "vmovups 288(%1), %%ymm14\n" + "vmovaps 256(%1), %%ymm13\n" + "vmovaps 288(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 20(%0), %%ymm12\n" - "vmovups 320(%1), %%ymm13\n" - "vmovups 352(%1), %%ymm14\n" + "vmovaps 320(%1), %%ymm13\n" + "vmovaps 352(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 24(%0), 
%%ymm12\n" - "vmovups 384(%1), %%ymm13\n" - "vmovups 416(%1), %%ymm14\n" + "vmovaps 384(%1), %%ymm13\n" + "vmovaps 416(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vbroadcastss 28(%0), %%ymm12\n" - "vmovups 448(%1), %%ymm13\n" - "vmovups 480(%1), %%ymm14\n" + "vmovaps 448(%1), %%ymm13\n" + "vmovaps 480(%1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" @@ -1321,20 +1306,19 @@ void Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, cons "vmovups %%ymm0, (%7)\n" // dst_0 "vmovups %%ymm1, 0x20(%7)\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(act_flag), "r"(oc_algin), "r"(dst), + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_flag) // 8 : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm13", "%ymm14"); } void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); + ic_align <<= 3; size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - float *dst_5 = dst + 5 * oc_algin; - float *dst_9 = dst + 9 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * oc_align / sizeof(float); + float *dst_5 = dst + 5 * oc_align / sizeof(float); + float *dst_9 = dst + 9 * oc_align / sizeof(float); asm volatile( "movq %12, %%rax\n" "and $0x1, %%eax\n" @@ -1355,18 +1339,18 @@ void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, cons "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" - "vmovups (%2), %%ymm1\n" - "vmovups (%2), %%ymm2\n" - "vmovups (%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups (%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups (%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups (%2), %%ymm9\n" - "vmovups (%2), %%ymm10\n" - "vmovups (%2), %%ymm11\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps (%2), %%ymm1\n" + "vmovaps (%2), %%ymm2\n" + "vmovaps (%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps (%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps (%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps (%2), %%ymm9\n" + "vmovaps (%2), %%ymm10\n" + "vmovaps (%2), %%ymm11\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1383,7 +1367,7 @@ void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, cons "vxorps %%ymm11, %%ymm11, %%ymm11\n" "2:\n" // LoopIC - "vmovups (%1), %%ymm12\n" + "vmovaps (%1), %%ymm12\n" "movq %0, %%rax\n" "vbroadcastss (%%rax), %%ymm13\n" "vbroadcastss (%%rax, %4), %%ymm14\n" @@ -1417,8 +1401,8 @@ void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, cons "dec %3\n" "jg 2b\n" : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 - "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(dst_flag) // 12 + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(dst_flag) // 12 : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); @@ -1476,17 +1460,14 @@ void Conv1x1SW12x8Kernel(float *dst, 
const float *src, const float *weight, cons "vmovups %%ymm10, (%5, %1, 1)\n" "vmovups %%ymm11, (%5, %1, 2)\n" : - : "r"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "a"(dst_flag) // 6 + : "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "a"(dst_flag) // 6 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm14"); } void Conv1x1SW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { - in_sw_step *= sizeof(float); - oc_algin *= sizeof(float); - ic_algin /= 8; asm volatile( "movq %8, %%rax\n" "and $0x1, %%eax\n" @@ -1496,42 +1477,42 @@ void Conv1x1SW1x8Kernel(float *dst, const float *src, const float *weight, const "0:\n" "cmpq $0, %2\n" "je 1f\n" - "vmovups (%2), %%ymm0\n" + "vmovaps (%2), %%ymm0\n" "jmp 2f\n" "1:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" "2:\n" // LoopIC "vbroadcastss (%0), %%ymm12\n" - "vmovups (%1), %%ymm13\n" + "vmovaps (%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 4(%0), %%ymm12\n" - "vmovups 32(%1), %%ymm13\n" + "vmovaps 32(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 8(%0), %%ymm12\n" - "vmovups 64(%1), %%ymm13\n" + "vmovaps 64(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 12(%0), %%ymm12\n" - "vmovups 96(%1), %%ymm13\n" + "vmovaps 96(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 16(%0), %%ymm12\n" - "vmovups 128(%1), %%ymm13\n" + "vmovaps 128(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 20(%0), %%ymm12\n" - "vmovups 160(%1), %%ymm13\n" + "vmovaps 160(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 24(%0), %%ymm12\n" - "vmovups 192(%1), %%ymm13\n" + "vmovaps 192(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vbroadcastss 28(%0), %%ymm12\n" - "vmovups 224(%1), %%ymm13\n" + "vmovaps 224(%1), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "addq $256, %1\n" "addq $32, %0\n" @@ -1559,15 +1540,18 @@ void Conv1x1SW1x8Kernel(float *dst, const float *src, const float *weight, const "3:\n" "vmovups %%ymm0, (%7)\n" // dst_0 : - : "r"(src), "r"(weight), "r"(bias), "r"(ic_algin), "r"(in_sw_step), "r"(act_flag), "r"(oc_algin), "r"(dst), + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_flag) // 8 : "%rax", "%ecx", "%ymm0", "%ymm12", "%ymm13"); } #ifdef ENABLE_DEBUG void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag) { + oc_align /= sizeof(float); + in_sw_step /= sizeof(float); + ic_align <<= 3; __m256 dst_data[12]; const float *src_sw[12]; __m256 weight_data[4]; @@ -1577,7 +1561,7 @@ void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, con for (int i = 0; i < ow_block; ++i) { if (dst_flag & 0x01) { for (int j = 0; j < oc_block; ++j) { - dst_data[i * oc_block + j] = _mm256_loadu_ps(dst + i * oc_algin + j * C8NUM); + dst_data[i * oc_block + j] = _mm256_loadu_ps(dst + i * oc_align + j * C8NUM); } } else { if (bias 
!= NULL) { @@ -1593,7 +1577,7 @@ void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, con src_sw[i] = src + i * in_sw_step; } const float *weight_kernel = weight; - for (int ic = 0; ic < ic_algin; ++ic) { + for (int ic = 0; ic < ic_align; ++ic) { for (int j = 0; j < oc_block; ++j) { weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); } @@ -1615,7 +1599,7 @@ void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, con dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); } } - _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]); + _mm256_storeu_ps(dst + i * oc_align + j * C8NUM, dst_data[i * oc_block + j]); } } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.h index a4c9a6471b8..160344f616b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_1x1_x86_fp32.h @@ -24,7 +24,7 @@ extern "C" { #endif typedef void (*Conv1x1SWKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1X1SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, @@ -36,48 +36,48 @@ void Conv1x1SWFp32(const float *input_data, const float *packed_weight, const fl #ifdef ENABLE_DEBUG void Conv1x1SWOWxOCKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); #endif void Conv1x1SW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void 
Conv1x1SW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); void Conv1x1SW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, - size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_sw_step, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, size_t dst_flag); #endif #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c index 522339a99d2..89920ae573c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c @@ -218,18 +218,18 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups 0x20(%2), %%ymm5\n" - "vmovups 0x40(%2), %%ymm6\n" - "vmovups 0x60(%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups 0x20(%2), %%ymm9\n" - "vmovups 0x40(%2), %%ymm10\n" - "vmovups 0x60(%2), %%ymm11\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps 0x20(%2), %%ymm5\n" + "vmovaps 0x40(%2), %%ymm6\n" + "vmovaps 0x60(%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps 0x20(%2), %%ymm9\n" + "vmovaps 0x40(%2), %%ymm10\n" + "vmovaps 0x60(%2), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -254,19 +254,19 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f "vbroadcastss (%%rdx), %%ymm13\n" "vbroadcastss (%%rdx, %8), %%ymm14\n" "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" - "vmovups (%1), %%ymm12\n" + "vmovaps (%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 0x20(%1), %%ymm12\n" + "vmovaps 0x20(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 0x40(%1), %%ymm12\n" + "vmovaps 0x40(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, 
%%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 0x60(%1), %%ymm12\n" + "vmovaps 0x60(%1), %%ymm12\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -355,10 +355,10 @@ void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -440,19 +440,19 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. - "vmovups (%2), %%ymm3\n" - "vmovups 0x20(%2), %%ymm4\n" - "vmovups 0x40(%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups 0x40(%2), %%ymm8\n" - "vmovups (%2), %%ymm9\n" - "vmovups 0x20(%2), %%ymm10\n" - "vmovups 0x40(%2), %%ymm11\n" + "vmovaps (%2), %%ymm3\n" + "vmovaps 0x20(%2), %%ymm4\n" + "vmovaps 0x40(%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps 0x40(%2), %%ymm8\n" + "vmovaps (%2), %%ymm9\n" + "vmovaps 0x20(%2), %%ymm10\n" + "vmovaps 0x40(%2), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -474,9 +474,9 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f "movq %%rcx, %%rdx\n" "movq %5, %%r12\n" // ic_algin "3:\n" // LoopIC - "vmovups (%1), %%ymm12\n" - "vmovups 0x20(%1), %%ymm13\n" - "vmovups 0x40(%1), %%ymm14\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps 0x20(%1), %%ymm13\n" + "vmovaps 0x40(%1), %%ymm14\n" "vbroadcastss (%%rdx), %%ymm15\n" "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" @@ -585,9 +585,9 @@ void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -664,19 +664,19 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
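
On the "couldn't find the corresponding instruction" comment above: a register-to-register copy does exist. vmovaps (and vmovdqa) with two register operands copies one YMM register into another without touching memory, so the repeated loads from (%2) could in principle become one load plus copies. A sketch of that pattern in GCC inline-asm syntax (bias_row is an illustrative stand-in for the %2 operand):

    float bias_row[8] __attribute__((aligned(32))) = {0.0f};
    asm volatile(
      "vmovaps (%0), %%ymm0\n"    /* single load of the bias row     */
      "vmovaps %%ymm0, %%ymm2\n"  /* reg-to-reg copy, no memory read */
      "vmovaps %%ymm0, %%ymm4\n"
      :
      : "r"(bias_row)
      : "%ymm0", "%ymm2", "%ymm4");
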
- "vmovups (%2), %%ymm2\n" - "vmovups 0x20(%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups 0x20(%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups 0x20(%2), %%ymm9\n" - "vmovups (%2), %%ymm10\n" - "vmovups 0x20(%2), %%ymm11\n" + "vmovaps (%2), %%ymm2\n" + "vmovaps 0x20(%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps 0x20(%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps 0x20(%2), %%ymm9\n" + "vmovaps (%2), %%ymm10\n" + "vmovaps 0x20(%2), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -698,8 +698,8 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f "movq %%rcx, %%rdx\n" "movq %5, %%r12\n" // ic_algin "3:\n" // LoopIC - "vmovups (%1), %%ymm12\n" - "vmovups 0x20(%1), %%ymm13\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps 0x20(%1), %%ymm13\n" "vbroadcastss (%%rdx), %%ymm15\n" "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" @@ -812,8 +812,8 @@ void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -887,18 +887,18 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f asm volatile( "cmpq $0, %0\n" "je 0f\n" - "vmovups (%0), %%ymm0\n" - "vmovups (%0), %%ymm1\n" - "vmovups (%0), %%ymm2\n" - "vmovups (%0), %%ymm3\n" - "vmovups (%0), %%ymm4\n" - "vmovups (%0), %%ymm5\n" - "vmovups (%0), %%ymm6\n" - "vmovups (%0), %%ymm7\n" - "vmovups (%0), %%ymm8\n" - "vmovups (%0), %%ymm9\n" - "vmovups (%0), %%ymm10\n" - "vmovups (%0), %%ymm11\n" + "vmovaps (%0), %%ymm0\n" + "vmovaps (%0), %%ymm1\n" + "vmovaps (%0), %%ymm2\n" + "vmovaps (%0), %%ymm3\n" + "vmovaps (%0), %%ymm4\n" + "vmovaps (%0), %%ymm5\n" + "vmovaps (%0), %%ymm6\n" + "vmovaps (%0), %%ymm7\n" + "vmovaps (%0), %%ymm8\n" + "vmovaps (%0), %%ymm9\n" + "vmovaps (%0), %%ymm10\n" + "vmovaps (%0), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -926,7 +926,7 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f "movq %%rcx, %%rdx\n" "movq %4, %%r12\n" // ic_algin "LoopIC:\n" - "vmovups (%1), %%ymm12\n" + "vmovaps (%1), %%ymm12\n" "movq %%rdx, %%rax\n" "addq $32, %1\n" "vbroadcastss (%%rax), %%ymm13\n" @@ -1043,10 +1043,10 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups (%2), %%ymm1\n" - "vmovups (%2), %%ymm2\n" - "vmovups (%2), %%ymm3\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps (%2), %%ymm1\n" + "vmovaps (%2), %%ymm2\n" + "vmovaps (%2), %%ymm3\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1060,7 +1060,7 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl "movq %%rcx, %%rdx\n" "movq %5, %%r12\n" // ic_algin "3:\n" // LoopIC - "vmovups (%1), %%ymm12\n" + "vmovaps (%1), %%ymm12\n" "movq %%rdx, %%rax\n" "addq $32, %1\n" "vbroadcastss (%%rax), %%ymm13\n" @@ -1131,7 +1131,7 @@ void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const fl asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" + "vmovaps (%2), %%ymm0\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c index 3d087864638..bb511cb74cc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c @@ -1181,18 +1181,18 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" - "vmovups (%2), %%ymm4\n" - "vmovups 0x20(%2), %%ymm5\n" - "vmovups 0x40(%2), %%ymm6\n" - "vmovups 0x60(%2), %%ymm7\n" - "vmovups (%2), %%ymm8\n" - "vmovups 0x20(%2), %%ymm9\n" - "vmovups 0x40(%2), %%ymm10\n" - "vmovups 0x60(%2), %%ymm11\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm4\n" + "vmovaps 0x20(%2), %%ymm5\n" + "vmovaps 0x40(%2), %%ymm6\n" + "vmovaps 0x60(%2), %%ymm7\n" + "vmovaps (%2), %%ymm8\n" + "vmovaps 0x20(%2), %%ymm9\n" + "vmovaps 0x40(%2), %%ymm10\n" + "vmovaps 0x60(%2), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1212,34 +1212,34 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co "movq %0, %%rcx\n" // src_h "2:\n" // LoopW - "vmovups (%1), %%ymm12\n" - "vmovups (%%rcx), %%ymm13\n" - "vmovups (%%rcx, %7), %%ymm14\n" - "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps (%%rcx), %%ymm13\n" + "vmovaps (%%rcx, %7), %%ymm14\n" + "vmovaps (%%rcx, %7, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" - "vmovups 0x20(%1), %%ymm12\n" - "vmovups 0x20(%%rcx), %%ymm13\n" - "vmovups 0x20(%%rcx, %7), %%ymm14\n" - "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x20(%1), %%ymm12\n" + "vmovaps 0x20(%%rcx), %%ymm13\n" + "vmovaps 0x20(%%rcx, %7), %%ymm14\n" + "vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" - "vmovups 0x40(%1), %%ymm12\n" - "vmovups 0x40(%%rcx), %%ymm13\n" - "vmovups 0x40(%%rcx, %7), %%ymm14\n" - "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x40(%1), %%ymm12\n" + "vmovaps 0x40(%%rcx), %%ymm13\n" + "vmovaps 0x40(%%rcx, %7), %%ymm14\n" + "vmovaps 0x40(%%rcx, %7, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" - "vmovups 0x60(%1), %%ymm12\n" - "vmovups 0x60(%%rcx), %%ymm13\n" - "vmovups 0x60(%%rcx, %7), %%ymm14\n" - "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x60(%1), %%ymm12\n" + "vmovaps 0x60(%%rcx), %%ymm13\n" + "vmovaps 0x60(%%rcx, %7), %%ymm14\n" + "vmovaps 0x60(%%rcx, %7, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" @@ -1297,18 +1297,18 @@ void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm11, %%ymm11\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" - "vmovups %%ymm2, 0x40(%2)\n" - "vmovups %%ymm3, 0x60(%2)\n" - "vmovups %%ymm4, (%2, %1, 1)\n" - "vmovups %%ymm5, 0x20(%2, %1, 1)\n" - "vmovups %%ymm6, 0x40(%2, %1, 1)\n" - "vmovups %%ymm7, 0x60(%2, %1, 1)\n" - "vmovups %%ymm8, (%2, %1, 2)\n" - "vmovups %%ymm9, 0x20(%2, %1, 2)\n" - "vmovups %%ymm10, 0x40(%2, %1, 2)\n" - "vmovups 
%%ymm11, 0x60(%2, %1, 2)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm2, 0x40(%2)\n" + "vmovaps %%ymm3, 0x60(%2)\n" + "vmovaps %%ymm4, (%2, %1, 1)\n" + "vmovaps %%ymm5, 0x20(%2, %1, 1)\n" + "vmovaps %%ymm6, 0x40(%2, %1, 1)\n" + "vmovaps %%ymm7, 0x60(%2, %1, 1)\n" + "vmovaps %%ymm8, (%2, %1, 2)\n" + "vmovaps %%ymm9, 0x20(%2, %1, 2)\n" + "vmovaps %%ymm10, 0x40(%2, %1, 2)\n" + "vmovaps %%ymm11, 0x60(%2, %1, 2)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", @@ -1325,10 +1325,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" - "vmovups 0x60(%2), %%ymm3\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" + "vmovaps 0x60(%2), %%ymm3\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1339,10 +1339,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // Loopw - "vmovups (%%rcx), %%ymm4\n" - "vmovups 0x20(%%rcx), %%ymm5\n" - "vmovups 0x40(%%rcx), %%ymm6\n" - "vmovups 0x60(%%rcx), %%ymm7\n" + "vmovaps (%%rcx), %%ymm4\n" + "vmovaps 0x20(%%rcx), %%ymm5\n" + "vmovaps 0x40(%%rcx), %%ymm6\n" + "vmovaps 0x60(%%rcx), %%ymm7\n" // Weight data is loaded directly from memory instead of into registers for calculation. "vfmadd231ps (%1), %%ymm4, %%ymm0\n" "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" @@ -1385,10 +1385,10 @@ void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm3, %%ymm3\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" - "vmovups %%ymm2, 0x40(%2)\n" - "vmovups %%ymm3, 0x60(%2)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm2, 0x40(%2)\n" + "vmovaps %%ymm3, 0x60(%2)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); @@ -1407,19 +1407,19 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
- "vmovups (%2), %%ymm3\n" - "vmovups 0x20(%2), %%ymm4\n" - "vmovups 0x40(%2), %%ymm5\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups 0x40(%2), %%ymm8\n" - "vmovups (%2), %%ymm9\n" - "vmovups 0x20(%2), %%ymm10\n" - "vmovups 0x40(%2), %%ymm11\n" + "vmovaps (%2), %%ymm3\n" + "vmovaps 0x20(%2), %%ymm4\n" + "vmovaps 0x40(%2), %%ymm5\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps 0x40(%2), %%ymm8\n" + "vmovaps (%2), %%ymm9\n" + "vmovaps 0x20(%2), %%ymm10\n" + "vmovaps 0x40(%2), %%ymm11\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1438,33 +1438,33 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // LoopW - "vmovups (%1), %%ymm12\n" - "vmovups (%%rcx), %%ymm13\n" - "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps (%%rcx), %%ymm13\n" + "vmovaps (%%rcx, %7, 1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" - "vmovups (%%rcx, %7, 2), %%ymm15\n" - "vmovups (%%rcx, %9), %%ymm13\n" + "vmovaps (%%rcx, %7, 2), %%ymm15\n" + "vmovaps (%%rcx, %9), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" - "vmovups 0x20(%1), %%ymm12\n" - "vmovups 0x20(%%rcx), %%ymm13\n" - "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vmovaps 0x20(%1), %%ymm12\n" + "vmovaps 0x20(%%rcx), %%ymm13\n" + "vmovaps 0x20(%%rcx, %7, 1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" - "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" - "vmovups 0x20(%%rcx, %9), %%ymm13\n" + "vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x20(%%rcx, %9), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n" - "vmovups 0x40(%1), %%ymm12\n" - "vmovups 0x40(%%rcx), %%ymm13\n" - "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n" + "vmovaps 0x40(%1), %%ymm12\n" + "vmovaps 0x40(%%rcx), %%ymm13\n" + "vmovaps 0x40(%%rcx, %7, 1), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" - "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" - "vmovups 0x40(%%rcx, %9), %%ymm13\n" + "vmovaps 0x40(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x40(%%rcx, %9), %%ymm13\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n" @@ -1521,18 +1521,18 @@ void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm11, %%ymm11\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" - "vmovups %%ymm2, 0x40(%2)\n" - "vmovups %%ymm3, (%2, %1, 1)\n" - "vmovups %%ymm4, 0x20(%2, %1, 1)\n" - "vmovups %%ymm5, 0x40(%2, %1, 1)\n" - "vmovups %%ymm6, (%2, %1, 2)\n" - "vmovups %%ymm7, 0x20(%2, %1, 2)\n" - "vmovups %%ymm8, 0x40(%2, %1, 2)\n" - "vmovups %%ymm9, (%3)\n" // dst+3 - "vmovups %%ymm10, 0x20(%3)\n" - "vmovups %%ymm11, 0x40(%3)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm2, 0x40(%2)\n" + "vmovaps %%ymm3, (%2, %1, 1)\n" + "vmovaps %%ymm4, 0x20(%2, %1, 1)\n" + "vmovaps %%ymm5, 0x40(%2, %1, 1)\n" + "vmovaps %%ymm6, (%2, %1, 2)\n" + "vmovaps %%ymm7, 0x20(%2, %1, 2)\n" + "vmovaps %%ymm8, 0x40(%2, %1, 2)\n" + "vmovaps %%ymm9, (%3)\n" // dst+3 + "vmovaps %%ymm10, 0x20(%3)\n" + "vmovaps %%ymm11, 0x40(%3)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", @@ -1549,9 +1549,9 @@ void 
DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" - "vmovups 0x40(%2), %%ymm2\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" + "vmovaps 0x40(%2), %%ymm2\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1561,9 +1561,9 @@ void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // Loopw - "vmovups (%%rcx), %%ymm4\n" - "vmovups 0x20(%%rcx), %%ymm5\n" - "vmovups 0x40(%%rcx), %%ymm6\n" + "vmovaps (%%rcx), %%ymm4\n" + "vmovaps 0x20(%%rcx), %%ymm5\n" + "vmovaps 0x40(%%rcx), %%ymm6\n" // Weight data is loaded directly from memory instead of into registers for calculation. "vfmadd231ps (%1), %%ymm4, %%ymm0\n" "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" @@ -1603,9 +1603,9 @@ void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm2, %%ymm2\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" - "vmovups %%ymm2, 0x40(%2)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm2, 0x40(%2)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); @@ -1624,15 +1624,15 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. - "vmovups (%2), %%ymm3\n" - "vmovups 0x20(%2), %%ymm4\n" - "vmovups (%2), %%ymm6\n" - "vmovups 0x20(%2), %%ymm7\n" - "vmovups (%2), %%ymm9\n" - "vmovups 0x20(%2), %%ymm10\n" + "vmovaps (%2), %%ymm3\n" + "vmovaps 0x20(%2), %%ymm4\n" + "vmovaps (%2), %%ymm6\n" + "vmovaps 0x20(%2), %%ymm7\n" + "vmovaps (%2), %%ymm9\n" + "vmovaps 0x20(%2), %%ymm10\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1647,21 +1647,21 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // LoopW - "vmovups (%1), %%ymm12\n" - "vmovups (%%rcx), %%ymm13\n" - "vmovups (%%rcx, %7, 1), %%ymm14\n" - "vmovups (%%rcx, %7, 2), %%ymm15\n" - "vmovups (%%rcx, %9), %%ymm2\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps (%%rcx), %%ymm13\n" + "vmovaps (%%rcx, %7, 1), %%ymm14\n" + "vmovaps (%%rcx, %7, 2), %%ymm15\n" + "vmovaps (%%rcx, %9), %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n" - "vmovups 0x20(%1), %%ymm12\n" - "vmovups 0x20(%%rcx), %%ymm13\n" - "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" - "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" - "vmovups 0x20(%%rcx, %9), %%ymm2\n" + "vmovaps 0x20(%1), %%ymm12\n" + "vmovaps 0x20(%%rcx), %%ymm13\n" + "vmovaps 0x20(%%rcx, %7, 1), %%ymm14\n" + "vmovaps 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovaps 0x20(%%rcx, %9), %%ymm2\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" @@ -1712,14 +1712,14 @@ void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm10, %%ymm10\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" - "vmovups %%ymm3, (%2, %1, 1)\n" - "vmovups %%ymm4, 
0x20(%2, %1, 1)\n" - "vmovups %%ymm6, (%2, %1, 2)\n" - "vmovups %%ymm7, 0x20(%2, %1, 2)\n" - "vmovups %%ymm9, (%3)\n" // dst+3 - "vmovups %%ymm10, 0x20(%3)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm3, (%2, %1, 1)\n" + "vmovaps %%ymm4, 0x20(%2, %1, 1)\n" + "vmovaps %%ymm6, (%2, %1, 2)\n" + "vmovaps %%ymm7, 0x20(%2, %1, 2)\n" + "vmovaps %%ymm9, (%3)\n" // dst+3 + "vmovaps %%ymm10, 0x20(%3)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14"); @@ -1735,8 +1735,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" - "vmovups 0x20(%2), %%ymm1\n" + "vmovaps (%2), %%ymm0\n" + "vmovaps 0x20(%2), %%ymm1\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1745,8 +1745,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // Loopw - "vmovups (%%rcx), %%ymm4\n" - "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovaps (%%rcx), %%ymm4\n" + "vmovaps 0x20(%%rcx), %%ymm5\n" // Weight data is loaded directly from memory instead of into registers for calculation. "vfmadd231ps (%1), %%ymm4, %%ymm0\n" "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" @@ -1783,8 +1783,8 @@ void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, co "vminps %%ymm14, %%ymm1, %%ymm1\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, 0x20(%2)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, 0x20(%2)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst) : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); @@ -1804,14 +1804,14 @@ void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con asm volatile( "cmpq $0, %0\n" "je 0f\n" - "vmovups (%0), %%ymm0\n" - "vmovups (%0), %%ymm1\n" - "vmovups (%0), %%ymm2\n" - "vmovups (%0), %%ymm3\n" - "vmovups (%0), %%ymm4\n" - "vmovups (%0), %%ymm5\n" - "vmovups (%0), %%ymm6\n" - "vmovups (%0), %%ymm7\n" + "vmovaps (%0), %%ymm0\n" + "vmovaps (%0), %%ymm1\n" + "vmovaps (%0), %%ymm2\n" + "vmovaps (%0), %%ymm3\n" + "vmovaps (%0), %%ymm4\n" + "vmovaps (%0), %%ymm5\n" + "vmovaps (%0), %%ymm6\n" + "vmovaps (%0), %%ymm7\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1833,23 +1833,23 @@ void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con "movq %0, %%rcx\n" // src_h "LoopW:\n" "movq %%rcx, %%rax\n" - "vmovups (%1), %%ymm12\n" - "vmovups (%%rax), %%ymm13\n" - "vmovups (%%rax, %6), %%ymm14\n" - "vmovups (%%rax, %6, 2), %%ymm15\n" + "vmovaps (%1), %%ymm12\n" + "vmovaps (%%rax), %%ymm13\n" + "vmovaps (%%rax, %6), %%ymm14\n" + "vmovaps (%%rax, %6, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" "addq %7, %%rax\n" - "vmovups (%%rax), %%ymm13\n" - "vmovups (%%rax, %6), %%ymm14\n" - "vmovups (%%rax, %6, 2), %%ymm15\n" + "vmovaps (%%rax), %%ymm13\n" + "vmovaps (%%rax, %6), %%ymm14\n" + "vmovaps (%%rax, %6, 2), %%ymm15\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" "addq %7, %%rax\n" - "vmovups (%%rax), %%ymm13\n" - "vmovups (%%rax, %6), %%ymm14\n" + "vmovaps (%%rax), %%ymm13\n" + "vmovaps (%%rax, %6), %%ymm14\n" "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" @@ -1898,14 +1898,14 @@ void 
DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, con "vminps %%ymm14, %%ymm7, %%ymm7\n" "Write:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 - "vmovups %%ymm1, (%2, %1)\n" - "vmovups %%ymm2, (%2, %1, 2)\n" - "vmovups %%ymm3, (%3)\n" // dst_3 - "vmovups %%ymm4, (%2, %1, 4)\n" - "vmovups %%ymm5, (%4)\n" // dst_5 - "vmovups %%ymm6, (%4, %1, 1)\n" - "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovaps %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm1, (%2, %1)\n" + "vmovaps %%ymm2, (%2, %1, 2)\n" + "vmovaps %%ymm3, (%3)\n" // dst_3 + "vmovaps %%ymm4, (%2, %1, 4)\n" + "vmovaps %%ymm5, (%4)\n" // dst_5 + "vmovaps %%ymm6, (%4, %1, 1)\n" + "vmovaps %%ymm7, (%4, %1, 2)\n" : : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14"); @@ -1921,7 +1921,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con asm volatile( "cmpq $0, %2\n" "je 0f\n" - "vmovups (%2), %%ymm0\n" + "vmovaps (%2), %%ymm0\n" "jmp 1f\n" "0:\n" "vxorps %%ymm0, %%ymm0, %%ymm0\n" @@ -1929,7 +1929,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con "movq %4, %%rsi\n" // width "movq %0, %%rcx\n" // src_h "2:\n" // Loopw - "vmovups (%%rcx), %%ymm4\n" + "vmovaps (%%rcx), %%ymm4\n" // Weight data is loaded directly from memory instead of into registers for calculation. "vfmadd231ps (%1), %%ymm4, %%ymm0\n" "addq $32, %1\n" @@ -1963,7 +1963,7 @@ void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, con "vminps %%ymm14, %%ymm0, %%ymm0\n" "0:\n" - "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovaps %%ymm0, (%2)\n" // dst_0 : : "a"(act_flag), "r"(oc_algin), "r"(dst) : "%ecx", "%ymm0", "%ymm12", "%ymm14"); diff --git a/mindspore/lite/src/runtime/inner_allocator.cc b/mindspore/lite/src/runtime/inner_allocator.cc index 8f26e0bfb61..bc0d5223bf0 100644 --- a/mindspore/lite/src/runtime/inner_allocator.cc +++ b/mindspore/lite/src/runtime/inner_allocator.cc @@ -78,10 +78,8 @@ void *DefaultAllocator::Malloc(size_t size) { this->total_size_ += size; membuf->ref_count_ = 0; membuf->size = size; - auto aligned_bytes = - reinterpret_cast((reinterpret_cast(membuf.get()) + sizeof(MemBuf))) % aligned_size_; - aligned_bytes = aligned_bytes == 0 ? 
0 : aligned_size_ - aligned_bytes;
-  membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf) + aligned_bytes;
+  membuf->buf = reinterpret_cast<void *>(
+    (reinterpret_cast<uintptr_t>(membuf.get()) + sizeof(MemBuf) + aligned_size_ - 1) & (~(aligned_size_ - 1)));
   auto bufPtr = membuf->buf;
   allocatedList_[bufPtr] = membuf.release();
   UnLock();
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
index f14d31bb67e..d3af6af32c6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
@@ -27,6 +27,24 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::ActivationType;
 
 namespace mindspore::kernel {
+void *ConvolutionBaseCPUKernel::MallocAlignedData(size_t alignment, size_t size) {
+  auto ptr = malloc(size + alignment);
+  if (ptr == nullptr) {
+    MS_LOG(ERROR) << "MallocAlignedData failed!";
+    return nullptr;
+  }
+  auto aligned_ptr = (reinterpret_cast<uintptr_t>(ptr) + alignment - 1) & (~(alignment - 1));
+  addr_map[aligned_ptr] = ptr;
+  return reinterpret_cast<void *>(aligned_ptr);
+}
+
+void ConvolutionBaseCPUKernel::FreeAlignedData(void **ptr) {
+  if (*ptr != nullptr) {
+    free(addr_map[reinterpret_cast<uintptr_t>(*ptr)]);
+    *ptr = nullptr;
+  }
+}
+
 ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() {
   if (bias_data_ != nullptr) {
     free(bias_data_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
index 378e7780dae..2f208f1cd8a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_
 
 #include
+#include <map>
 #include
 #include
 #ifdef ENABLE_ARM
@@ -56,8 +57,11 @@ class ConvolutionBaseCPUKernel : public InnerKernel {
   void SetRoundingAndMultipilerMode();
   int CheckResizeValid();
   void FreeQuantParam();
+  void *MallocAlignedData(size_t alignment, size_t size);
+  void FreeAlignedData(void **ptr);
 
  protected:
+  std::map<uintptr_t, void *> addr_map;
   bool is_repack() { return is_repack_; }
   void *bias_data_ = nullptr;
   const InnerContext *ctx_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
index 46d66cceb99..ee22cd1229f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc
@@ -145,8 +145,7 @@ kernel::InnerKernel *ConvolutionDelegateCPUKernel::CpuConvFp32KernelSelect() {
   if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
 #ifdef ENABLE_AVX
     if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 &&
-        conv_param->output_channel_ % 8 == 0 && conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 &&
-        conv_param->input_channel_ % 8 == 0) {
+        conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ % 8 == 0) {
       kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(
         op_parameter_, in_tensors_, out_tensors_, static_cast<const lite::InnerContext *>(this->context_),
         origin_weight_, origin_bias_);
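
Both DefaultAllocator::Malloc and ConvolutionBaseCPUKernel::MallocAlignedData rely on the same power-of-two align-up idiom: over-allocate by `alignment` bytes, then round the raw address up with (addr + alignment - 1) & ~(alignment - 1). A minimal standalone sketch of the pattern (helper names here are illustrative, not code from this patch):

  // Align-up allocation sketch. Assumes `alignment` is a power of two; the
  // extra `alignment` bytes of slack guarantee that the rounded-up pointer
  // still lies inside the raw block.
  #include <cstddef>
  #include <cstdint>
  #include <cstdlib>
  #include <map>

  static std::map<uintptr_t, void *> raw_by_aligned;  // aligned addr -> raw malloc() result

  void *MallocAligned(size_t alignment, size_t size) {
    void *raw = std::malloc(size + alignment);
    if (raw == nullptr) {
      return nullptr;
    }
    uintptr_t aligned = (reinterpret_cast<uintptr_t>(raw) + alignment - 1) & ~(alignment - 1);
    raw_by_aligned[aligned] = raw;
    return reinterpret_cast<void *>(aligned);
  }

  void FreeAligned(void **ptr) {
    if (ptr == nullptr || *ptr == nullptr) {
      return;
    }
    auto it = raw_by_aligned.find(reinterpret_cast<uintptr_t>(*ptr));
    if (it != raw_by_aligned.end()) {
      std::free(it->second);
      raw_by_aligned.erase(it);  // also drop the bookkeeping entry
    }
    *ptr = nullptr;
  }

One difference from the patch is deliberate: FreeAlignedData above frees the raw block but leaves its entry in addr_map, so the map accumulates stale aligned-address keys across repeated repacks; erasing the entry after the free, as in the sketch, avoids that.
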
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
index 94821a6cbcf..dd2c75fff7c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.cc
@@ -28,14 +28,8 @@ ConvolutionDepthwiseSWCPUKernelX86::~ConvolutionDepthwiseSWCPUKernelX86() {
     delete sliding_;
     sliding_ = nullptr;
   }
-  if (packed_weight_ != nullptr) {
-    free(packed_weight_);
-    packed_weight_ = nullptr;
-  }
-  if (packed_bias_ != nullptr) {
-    free(packed_bias_);
-    packed_bias_ = nullptr;
-  }
+  FreeAlignedData(reinterpret_cast<void **>(&packed_weight_));
+  FreeAlignedData(reinterpret_cast<void **>(&packed_bias_));
 }
 
 int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
@@ -45,8 +39,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
   MS_ASSERT(origin_weight_ != nullptr);
   int oc_algin = UP_DIV(weight_tensor->Batch(), oc_tile_);
   int pack_weight_size = oc_algin * oc_tile_ * weight_tensor->Height() * weight_tensor->Width();
-
-  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
+  packed_weight_ = reinterpret_cast<float *>(MallocAlignedData(alignment, pack_weight_size * sizeof(float)));
   if (packed_weight_ == nullptr) {
     MS_LOG(ERROR) << "Malloc packed_weight_ is failed!";
     return RET_NULL_PTR;
@@ -57,7 +50,7 @@ int ConvolutionDepthwiseSWCPUKernelX86::InitWeightBias() {
     auto bias_size = oc_algin * oc_tile_;
     auto bias_tensor = in_tensors_.at(kBiasIndex);
     auto ori_bias = reinterpret_cast<float *>(bias_tensor->data_c());
-    packed_bias_ = reinterpret_cast<float *>(malloc(bias_size * sizeof(float)));
+    packed_bias_ = reinterpret_cast<float *>(MallocAlignedData(alignment, bias_size * sizeof(float)));
     if (packed_bias_ == nullptr) {
       MS_LOG(ERROR) << "Malloc bias_data buffer failed.";
       return RET_NULL_PTR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.h
index fe060df82a7..a966ae6c8b4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_x86_fp32.h
@@ -51,6 +51,7 @@ class ConvolutionDepthwiseSWCPUKernelX86 : public ConvolutionBaseCPUKernel {
   float *origin_weight_ = nullptr;
   bool input_need_align_ = false;
   bool output_need_align_ = false;
+  size_t alignment = C32NUM;
 };
 
 }  // namespace mindspore::kernel
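
The new `alignment = C32NUM` member pins every packed buffer to 32 bytes, the width of one YMM register, because the kernels above now use vmovaps instead of vmovups. The semantic difference, sketched with intrinsics (this snippet is explanatory, not code from the patch):

  // _mm256_load_ps / _mm256_store_ps compile to vmovaps, which raises #GP
  // when the address is not 32-byte aligned; _mm256_loadu_ps /
  // _mm256_storeu_ps compile to vmovups and accept any address, at a small
  // penalty on cache-line-splitting accesses.
  #include <immintrin.h>
  #include <cstddef>

  void ScaleAligned(const float *src, float *dst, size_t n, float k) {
    __m256 vk = _mm256_set1_ps(k);
    for (size_t i = 0; i + 8 <= n; i += 8) {
      __m256 v = _mm256_load_ps(src + i);              // vmovaps load: src + i must be 32-byte aligned
      _mm256_store_ps(dst + i, _mm256_mul_ps(v, vk));  // vmovaps store: same requirement on dst + i
    }
  }

Note that an aligned base pointer alone is not enough for strided accesses such as vmovaps %%ymm4, (%2, %1, 1): the row stride (oc_algin floats) must also be a multiple of eight floats so that every row start stays on a 32-byte boundary, which the channel-block padding in the sliding-window parameters appears to guarantee. As an aside on the comments in DepthwiseSW4x24Kernel and DepthwiseSW4x16Kernel about copying ymm0 into ymm3: AVX does provide a register-to-register form, e.g. vmovaps %%ymm0, %%ymm3, which would avoid the repeated bias loads.
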
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.cc
index 9543d1de170..75bb734a486 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.cc
@@ -39,27 +39,22 @@ int ConvolutionSWCPUKernel::InitWeightBias() {
   int kernel_plane = kernel_h * kernel_w;
   int oc_block_num = UP_DIV(output_channel, oc_tile_);
   int pack_weight_size = oc_block_num * oc_tile_ * input_channel * kernel_plane;
-  packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
+  packed_weight_ = reinterpret_cast<float *>(MallocAlignedData(alignment, pack_weight_size * sizeof(float)));
   if (packed_weight_ == nullptr) {
-    MS_LOG(ERROR) << "malloc packed weight failed.";
+    MS_LOG(ERROR) << "MallocAlignedData packed weight failed.";
     return RET_NULL_PTR;
   }
   memset(packed_weight_, 0, pack_weight_size * sizeof(float));
   PackNHWCTo1HWCNXFp32(kernel_h, kernel_w, output_channel, oc_block_num, input_channel, packed_weight_,
                        ori_weight_data_);
   if (in_tensors_.size() == kInputSize2) {
-    bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_tile_ * sizeof(float)));
-    if (bias_data_ == nullptr) {
-      MS_LOG(ERROR) << "malloc bias failed.";
+    packed_bias_ = reinterpret_cast<float *>(MallocAlignedData(alignment, oc_block_num * oc_tile_ * sizeof(float)));
+    if (packed_bias_ == nullptr) {
+      MS_LOG(ERROR) << "MallocAlignedData bias failed.";
       return RET_NULL_PTR;
     }
-    memset(bias_data_, 0, oc_block_num * oc_tile_ * sizeof(float));
-    memcpy(bias_data_, ori_bias_data_, output_channel * sizeof(float));
-  } else {
-    if (bias_data_ != nullptr) {
-      free(bias_data_);
-      bias_data_ = nullptr;
-    }
+    memset(packed_bias_, 0, oc_block_num * oc_tile_ * sizeof(float));
+    memcpy(packed_bias_, ori_bias_data_, output_channel * sizeof(float));
   }
   return RET_OK;
 }
@@ -113,10 +108,10 @@ int ConvolutionSWCPUKernel::ReSize() {
 
 int ConvolutionSWCPUKernel::RunImpl(int task_id) {
   if (conv_param_->kernel_w_ == 1 && conv_param_->kernel_h_ == 1) {
-    Conv1x1SWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(bias_data_), output_data_, task_id,
+    Conv1x1SWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(packed_bias_), output_data_, task_id,
                   conv_param_, slidingWindow_param_);
   } else {
-    ConvSWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(bias_data_), output_data_, task_id, conv_param_,
+    ConvSWFp32(input_data_, packed_weight_, reinterpret_cast<float *>(packed_bias_), output_data_, task_id, conv_param_,
                slidingWindow_param_);
   }
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.h
index c0fcece2914..ffdbb02a2c0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow_fp32.h
@@ -34,8 +34,10 @@ class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
 
   ~ConvolutionSWCPUKernel() override {
     if (packed_weight_ != nullptr) {
-      free(packed_weight_);
-      packed_weight_ = nullptr;
+      FreeAlignedData(reinterpret_cast<void **>(&packed_weight_));
+    }
+    if (packed_bias_ != nullptr) {
+      FreeAlignedData(reinterpret_cast<void **>(&packed_bias_));
     }
     if (slidingWindow_param_ != nullptr) {
       delete slidingWindow_param_;
@@ -68,8 +70,10 @@ class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
   float *ori_weight_data_ = nullptr;
   float *ori_bias_data_ = nullptr;
   float *packed_weight_ = nullptr;
+  float *packed_bias_ = nullptr;
   float *output_data_ = nullptr;
   float *input_data_ = nullptr;
+  int alignment = C32NUM;
   SlidingWindowParam *slidingWindow_param_ = nullptr;
 };
 }  // namespace mindspore::kernel
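
Because vmovaps turns any misaligned pointer into a hard fault at run time, a cheap debug-time check of the alignment contract can catch regressions when new callers hand buffers to these kernels. A hedged sketch (helper names are illustrative, not part of the MindSpore tree):

  #include <cstddef>
  #include <cstdint>

  inline bool Aligned32(const void *p) {
    // True when p sits on a 32-byte boundary, as vmovaps requires for YMM accesses.
    return (reinterpret_cast<uintptr_t>(p) & 0x1F) == 0;
  }

  // The AVX sliding-window kernels assume all three conditions hold:
  // aligned base pointers, plus a row stride that preserves alignment.
  inline bool SWKernelArgsOk(const float *packed_weight, const float *packed_bias, size_t oc_align_floats) {
    return Aligned32(packed_weight) && Aligned32(packed_bias) && (oc_align_floats * sizeof(float)) % 32 == 0;
  }

Guarding RunImpl with such a check in debug builds would surface a misaligned packed_weight_ or packed_bias_ as a readable assertion failure instead of a SIGSEGV inside the inline assembly.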