From 741883f4ea39f45eb57745c55760b669ea1aaadd Mon Sep 17 00:00:00 2001 From: lzk Date: Wed, 1 Sep 2021 00:51:03 -0700 Subject: [PATCH] x86nc4hw4 --- .../kernel_compiler/cpu/nnacl/common_func.c | 56 ----- .../kernel_compiler/cpu/nnacl/common_func.h | 17 -- .../cpu/nnacl/fp32/common_func_fp32.h | 6 - .../cpu/nnacl/fp32/conv_common_fp32.c | 231 ++++++++++++++---- .../cpu/nnacl/fp32/conv_common_fp32.h | 37 ++- .../cpu/nnacl/fp32/conv_depthwise_fp32.c | 9 +- .../cpu/nnacl/fp32/resize_fp32.c | 152 ++++++++---- .../fp32/convolution_depthwise_fp32_coder.cc | 1 + 8 files changed, 314 insertions(+), 195 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.c index 7f4e7817a93..aac7015b8e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.c @@ -33,59 +33,3 @@ int Offset6d(const int *shape, const int *dims) { int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } - -void ReluFp32(float *data, float *dst, int ele_num) { - int index = 0; -#ifdef ENABLE_AVX - int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM; - for (; index < c8_block; index += C8NUM) { - MS_FLOAT32X8 relu_data = MS_LD256_F32(data + index); - MS_FLOAT32X8 zero_data = MS_MOV256_F32(0.0f); - relu_data = MS_MAX256_F32(relu_data, zero_data); - MS_ST256_F32(dst + index, relu_data); - } -#endif -#if defined(ENABLE_NEON) || defined(ENABLE_SSE) - int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM; - for (; index < c4_block; index += C4NUM) { - MS_FLOAT32X4 relu_data = MS_LDQ_F32(data + index); - MS_FLOAT32X4 zero_data = MS_MOVQ_F32(0.0f); - relu_data = MS_MAXQ_F32(relu_data, zero_data); - MS_STQ_F32(dst + index, relu_data); - } -#endif - for (; index < ele_num; ++index) { - data[index] = data[index] < 0.0f ? 0.0f : data[index]; - } -} - -void Relu6Fp32(float *data, float *dst, int ele_num) { - int index = 0; -#ifdef ENABLE_AVX - int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM; - for (; index < c8_block; index += C8NUM) { - MS_FLOAT32X8 relu6_data = MS_LD256_F32(data + index); - MS_FLOAT32X8 zero_data = MS_MOV256_F32(0.0f); - MS_FLOAT32X8 six_data = MS_MOV256_F32(6.0f); - relu6_data = MS_MAX256_F32(relu6_data, zero_data); - relu6_data = MS_MIN256_F32(relu6_data, six_data); - MS_ST256_F32(dst + index, relu6_data); - } -#endif - -#if defined(ENABLE_NEON) || defined(ENABLE_SSE) - int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM; - for (; index < c4_block; index += C4NUM) { - MS_FLOAT32X4 relu6_data = MS_LDQ_F32(data + index); - MS_FLOAT32X4 zero_data = MS_MOVQ_F32(0.0f); - MS_FLOAT32X4 six_data = MS_MOVQ_F32(6.0f); - relu6_data = MS_MAXQ_F32(relu6_data, zero_data); - relu6_data = MS_MINQ_F32(relu6_data, six_data); - MS_STQ_F32(dst + index, relu6_data); - } -#endif - for (; index < ele_num; ++index) { - data[index] = data[index] < 0.0f ? 0.0f : data[index]; - data[index] = data[index] > 6.0f ? 6.0f : data[index]; - } -} diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.h index 74f418d430a..ba4d519e850 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/common_func.h @@ -28,14 +28,6 @@ extern "C" { int8_t MinInt8(int8_t a, int8_t b); int8_t MaxInt8(int8_t a, int8_t b); -void ReluFp32(float *data, float *dst, int ele_num); -void Relu6Fp32(float *data, float *dst, int ele_num); -#ifdef ENABLE_AVX -#ifdef WIN32 -void ReluFp32C8(float *data, float *dst, int ele_num); -void Relu6Fp32C8(float *data, float *dst, int ele_num); -#endif -#endif int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); int OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2); int Offset4d(const int *shape, const int *dims); @@ -62,15 +54,6 @@ static inline int GetStride(int *strides, const int *shape, int length) { } return stride; } - -#ifdef ENABLE_ARM64 -void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); -void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); -void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); -void Relu6(float *data, size_t element4); -void Relu(float *data, size_t element4); -#endif - #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/common_func_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/common_func_fp32.h index 649850dcdb2..dc23ea74718 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/common_func_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/common_func_fp32.h @@ -69,12 +69,6 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t #endif #ifdef ENABLE_ARM64 -void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); -void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); -void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); -void Relu6(float *data, size_t element4); -void Relu(float *data, size_t element4); - void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c index 5d3b9688d48..d5e0d916bf1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.c @@ -161,7 +161,7 @@ void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float #ifdef ENABLE_AVX void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, const SWConvKernel kernel, - int act_type, int ow_bock, int oc_block) { + int act_type, int ow_bock, int oc_block, size_t write_mode) { for (int oh = top; oh < bottom; oh++) { // now h is only loop one time int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); @@ -179,7 +179,7 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, oc_block, sw_param->block_channel_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_, - (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_); + (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, write_mode); dst_kernel += ow_bock * sw_param->block_channel_; } // width loop dst += sw_param->out_h_step_; @@ -232,6 +232,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float int in_h_start = top * stride_h - pad_u; int in_w_start = left * stride_w - pad_l; int center_step = in_h_start * in_h_step + in_w_start * ic_algin; + int write_mode = conv_param->out_format_; + int kernel_out_step = oc_algin; + if (write_mode == 13) { + kernel_out_step = out_h * out_w; + } const int ow_block_num[4] = {12, 6, 4, 3}; const SWConvKernel kernel[4][2] = {{SWConv1x8Kernel, SWConv12x8Kernel}, {SWConv1x16Kernel, SWConv6x16Kernel}, @@ -254,11 +259,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float const SWConvKernel kernel_border = kernel[oc_block - 1][0]; if (oh < top || oh >= bottom) { // oh in up or down border SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, kernel_border, act_type, - 1, oc_block); + 1, oc_block, write_mode); } else { // oh in center // ow in right SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, kernel_border, act_type, - 1, oc_block); + 1, oc_block, write_mode); // ow in center const float *src_w = src_h + (oh - top) * in_sh_step; int ow_block = ow_block_num[oc_block - 1]; // 12 6 4 3 @@ -269,12 +274,12 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float } kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( dst_w + ow * block_channel, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, - oc_algin, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0); + kernel_out_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); src_w += ow_block * in_sw_step; } // ow in left SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, kernel_border, - act_type, 1, oc_block); + act_type, 1, oc_block, write_mode); } } } // output h loop @@ -284,12 +289,14 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float } void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_sw_step *= sizeof(float); in_kw_step *= sizeof(float); - oc_algin *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); kw_remainder *= sizeof(float); asm volatile( "cmpq $0, %2\n" @@ -403,6 +410,9 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm11, %%ymm11\n" "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" "vmovups %%ymm2, 0x40(%2)\n" @@ -415,19 +425,37 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f "vmovups %%ymm9, 0x20(%2, %1, 2)\n" "vmovups %%ymm10, 0x40(%2, %1, 2)\n" "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm4, 0x20(%2)\n" + "vmovups %%ymm8, 0x40(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm9, 0x40(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm6, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "vmovups %%ymm7, 0x20(%4)\n" + "vmovups %%ymm11, 0x40(%4)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm14"); } void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); - oc_algin *= sizeof(float); + out_step *= sizeof(float); kw_remainder *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; asm volatile( "cmpq $0, %2\n" "je 0f\n" @@ -494,25 +522,37 @@ void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm3, %%ymm3\n" "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" "vmovups %%ymm2, 0x40(%2)\n" "vmovups %%ymm3, 0x60(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); } void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); in_sw_step *= sizeof(float); kw_remainder *= sizeof(float); size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); asm volatile( "cmpq $0, %0\n" "je 0f\n" @@ -640,6 +680,9 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm11, %%ymm11\n" "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" "vmovups %%ymm2, 0x40(%2)\n" @@ -652,19 +695,36 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f "vmovups %%ymm9, (%3)\n" // dst+3 "vmovups %%ymm10, 0x20(%3)\n" "vmovups %%ymm11, 0x40(%3)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm3, 0x20(%2)\n" + "vmovups %%ymm6, 0x40(%2)\n" + "vmovups %%ymm9, 0x60(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm7, 0x40(%2, %1, 1)\n" + "vmovups %%ymm10, 0x60(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x60(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm14"); } void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); kw_remainder *= sizeof(float); - oc_algin *= sizeof(float); + out_step *= sizeof(float); asm volatile( "cmpq $0, %2\n" "je 0f\n" @@ -726,24 +786,35 @@ void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm2, %%ymm2\n" "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" "vmovups %%ymm2, 0x40(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc4hw4 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); } void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); in_sw_step *= sizeof(float); kw_remainder *= sizeof(float); size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); asm volatile( "cmpq $0, %0\n" "je 0f\n" @@ -874,6 +945,9 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm11, %%ymm11\n" "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" "vmovups %%ymm2, (%2, %1, 1)\n" @@ -886,19 +960,36 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f "vmovups %%ymm9, 0x20(%3, %1, 1)\n" "vmovups %%ymm10, (%3, %1, 2)\n" "vmovups %%ymm11, 0x20(%3, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm2, 0x20(%2)\n" + "vmovups %%ymm4, 0x40(%2)\n" + "vmovups %%ymm6, 0x60(%2)\n" // dst+3 + "vmovups %%ymm8, 0x80(%2)\n" + "vmovups %%ymm10, 0xA0(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm9, 0x80(%2, %1, 1)\n" + "vmovups %%ymm11, 0xA0(%2, %1, 1)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm14"); } void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); kw_remainder *= sizeof(float); - oc_algin *= sizeof(float); + out_step *= sizeof(float); asm volatile( "cmpq $0, %2\n" "je 0f\n" @@ -955,25 +1046,35 @@ void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm1, %%ymm1\n" "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, 0x20(%2)\n" + "jmp 2f\n" + "1:\n" + // write nc8hw8 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "2:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); } void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_sw_step *= sizeof(float); in_kw_step *= sizeof(float); kw_remainder *= sizeof(float); size_t src_3_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - float *dst_5 = dst + 5 * oc_algin; - float *dst_9 = dst + 9 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * out_step; + float *dst_5 = dst + 5 * out_step; + float *dst_9 = dst + 9 * out_step; + out_step *= sizeof(float); asm volatile( "cmpq $0, %0\n" "je 0f\n" @@ -1105,6 +1206,9 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f "vminps %%ymm14, %%ymm11, %%ymm11\n" "Write:\n" + "cmpq $13, %6\n" + "je WriteNC8HW8\n" + // write nhwc "vmovups %%ymm0, (%2)\n" // dst_0 "vmovups %%ymm1, (%2, %1)\n" "vmovups %%ymm2, (%2, %1, 2)\n" @@ -1117,21 +1221,37 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f "vmovups %%ymm9, (%5)\n" // dst_9 "vmovups %%ymm10, (%5, %1, 1)\n" "vmovups %%ymm11, (%5, %1, 2)\n" + "jmp End\n" + "WriteNC8HW8:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" // dst_3 + "vmovups %%ymm4, 0x80(%2)\n" + "vmovups %%ymm5, 0xA0(%2)\n" // dst_5 + "vmovups %%ymm6, 0xC0(%2)\n" + "vmovups %%ymm7, 0xE0(%2)\n" + "vmovups %%ymm8, 0x100(%2)\n" + "vmovups %%ymm9, 0x120(%2)\n" // dst_9 + "vmovups %%ymm10, 0x140(%2)\n" + "vmovups %%ymm11, 0x160(%2)\n" + "End:\n" : - : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm14"); } void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_sw_step *= sizeof(float); in_kw_step *= sizeof(float); size_t src_step = 3 * in_sw_step; - float *dst_3 = dst + 3 * oc_algin; - oc_algin *= sizeof(float); + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); asm volatile( "cmpq $0, %0\n" "je 0f\n" @@ -1215,17 +1335,18 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl "vmovups %%ymm2, (%2, %1, 2)\n" "vmovups %%ymm3, (%3)\n" // dst_3 : - : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3) : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14"); } void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { in_kh_step *= sizeof(float); in_kw_step *= sizeof(float); kw_remainder *= sizeof(float); - oc_algin *= sizeof(float); + out_step *= sizeof(float); asm volatile( "cmpq $0, %2\n" "je 0f\n" @@ -1277,16 +1398,18 @@ void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const fl "vminps %%ymm14, %%ymm0, %%ymm0\n" "0:\n" + // write to nhec and nc8hw8 is identical! "vmovups %%ymm0, (%2)\n" // dst_0 : - : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "a"(act_flag), "r"(out_step), "r"(dst) : "%ecx", "%ymm0", "%ymm12", "%ymm14"); } #ifdef ENABLE_DEBUG void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, - size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { __m256 dst_data[12]; const float *src_kh[12]; const float *src_kw[12]; @@ -1339,7 +1462,13 @@ void SWConvWxKKernel(float *dst, const float *src, const float *weight, const fl if (0x2 & act_flag) { // relu dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); } - _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]); + if (write_mode == 13) { + // write nc8hw8 + _mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]); + } else { + // write nhwc + _mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]); + } } } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.h index ae8b581a9d3..873b8cd0739 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_common_fp32.h @@ -43,11 +43,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, - size_t kw_remainder); + size_t kw_remainder, size_t write_mode); void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, SWConvKernel kernel, - int act_type, int ow_bock, int oc_block); + int act_type, int ow_bock, int oc_block, size_t write_mode); void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); @@ -57,48 +57,59 @@ void SWCenter(float *dst, const float *src, const float *weight, const float *bi #ifdef ENABLE_DEBUG void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); #endif void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, - size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); #endif #ifdef __cplusplus } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c index 6a6164893c2..5e1d019d7d4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/conv_depthwise_fp32.c @@ -18,6 +18,7 @@ #include "nnacl/fp32/common_func_fp32.h" #include "nnacl/intrinsics/ms_simd_instructions.h" #include "nnacl/errorcode.h" +#include "nnacl/fp32/activation_fp32.h" #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, @@ -80,10 +81,10 @@ int ConvDw(float *output_data, const float *input_data, const float *weight_data } } if (relu) { - ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); } if (relu6) { - Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); } } } @@ -779,10 +780,10 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c } } if (relu) { - ReluFp32(output, output, channels); + Fp32Relu(output, channels, output); } if (relu6) { - Relu6Fp32(output, output, channels); + Fp32Relu6(output, channels, output); } output += channels; input = input + input_stride; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/resize_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/resize_fp32.c index de4c238d48f..982b6f62af3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/resize_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/resize_fp32.c @@ -164,16 +164,24 @@ int InterpRow(const float *src_line, float *linear_output, int new_width, const int w; for (w = 0; w < new_width; w++) { int c = 0; -#ifdef ENABLE_NEON - float32x4_t left_w = vdupq_n_f32(x_left_weights[w]); - float32x4_t right_w = vdupq_n_f32(1.0f - x_left_weights[w]); - - for (; c <= in_c - 4; c += 4) { - float32x4_t left = vld1q_f32(src_line + x_lefts[w] * in_c + c); - float32x4_t right = vld1q_f32(src_line + x_rights[w] * in_c + c); - - float32x4_t interp_value = left * left_w + right * right_w; - vst1q_f32(linear_output + w * in_c + c, interp_value); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]); + MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X8 interp_value = left * left_w_8 + right * right_w_8; + MS_ST256_F32(linear_output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]); + MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X4 interp_value = left * left_w + right * right_w; + MS_STQ_F32(linear_output + w * in_c + c, interp_value); } #endif int left_w_offset = x_lefts[w] * in_c; @@ -192,15 +200,24 @@ int InterpCol(const float *bottom_line, const float *top_line, float *output, in int w; for (w = 0; w < new_width; w++) { int c = 0; -#ifdef ENABLE_NEON - float32x4_t bottom_w = vdupq_n_f32(y_bottom_weight); - float32x4_t top_w = vdupq_n_f32(1.0f - y_bottom_weight); - - for (; c <= in_c - 4; c += 4) { - float32x4_t bottom = vld1q_f32(bottom_line + w * in_c + c); - float32x4_t top = vld1q_f32(top_line + w * in_c + c); - float32x4_t interp_value = bottom * bottom_w + top * top_w; - vst1q_f32(output + w * in_c + c, interp_value); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight); + MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c); + MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c); + MS_FLOAT32X8 interp_value = bottom * bottom_w_8 + top * top_w_8; + MS_ST256_F32(output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight); + MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c); + MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c); + MS_FLOAT32X4 interp_value = bottom * bottom_w + top * top_w; + MS_STQ_F32(output + w * in_c + c, interp_value); } #endif for (; c < in_c; c++) { @@ -299,21 +316,40 @@ void BicubicInterpRow(const float *src, float *dst, const float *weights, const const float *src2_w = src + lefts[4 * w + 2] * channel; const float *src3_w = src + lefts[4 * w + 3] * channel; int c = 0; -#ifdef ENABLE_NEON - float32x4_t weight0_vec = vdupq_n_f32(weight[0]); - float32x4_t weight1_vec = vdupq_n_f32(weight[1]); - float32x4_t weight2_vec = vdupq_n_f32(weight[2]); - float32x4_t weight3_vec = vdupq_n_f32(weight[3]); - - for (; c <= channel - 4; c += 4) { - float32x4_t src0_vec = vld1q_f32(src0_w + c); - float32x4_t src1_vec = vld1q_f32(src1_w + c); - float32x4_t src2_vec = vld1q_f32(src2_w + c); - float32x4_t src3_vec = vld1q_f32(src3_w + c); - - float32x4_t interp_value = - src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec; - vst1q_f32(dst_w + c, interp_value); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0))); + MS_STQ_F32(dst_w + c, interp_value); } #endif for (; c < channel; c++) { @@ -334,20 +370,40 @@ void BicubicInterpCol(const float *src, float *dst, const float *weights, int wi const float *src2_w = src2 + w * channel; const float *src3_w = src3 + w * channel; int c = 0; -#ifdef ENABLE_NEON - float32x4_t weight0_vec = vdupq_n_f32(weights[0]); - float32x4_t weight1_vec = vdupq_n_f32(weights[1]); - float32x4_t weight2_vec = vdupq_n_f32(weights[2]); - float32x4_t weight3_vec = vdupq_n_f32(weights[3]); - - for (; c <= channel - 4; c += 4) { - float32x4_t src0_vec = vld1q_f32(src0_w + c); - float32x4_t src1_vec = vld1q_f32(src1_w + c); - float32x4_t src2_vec = vld1q_f32(src2_w + c); - float32x4_t src3_vec = vld1q_f32(src3_w + c); - float32x4_t interp_value = - src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec; - vst1q_f32(dst_w + c, interp_value); +#ifdef ENABLE_AVX + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2))); + MS_STQ_F32(dst_w + c, interp_value); } #endif for (; c < channel; c++) { diff --git a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc index e5adf72139a..4c4df49c8dc 100644 --- a/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc +++ b/mindspore/lite/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc @@ -75,6 +75,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) { }, { "conv_depthwise_fp32.c", + "activation_fp32.c", }, {}); nnacl::NNaclFp32Serializer code;