!22757 [MS][LITE][CPU] sw nc4hw4 support

Merge pull request !22757 from liuzhongkai/code_re5
This commit is contained in:
i-robot 2021-09-02 07:42:55 +00:00 committed by Gitee
commit 50847c9659
8 changed files with 314 additions and 195 deletions

View File

@ -33,59 +33,3 @@ int Offset6d(const int *shape, const int *dims) {
int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
void ReluFp32(float *data, float *dst, int ele_num) {
int index = 0;
#ifdef ENABLE_AVX
int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM;
for (; index < c8_block; index += C8NUM) {
MS_FLOAT32X8 relu_data = MS_LD256_F32(data + index);
MS_FLOAT32X8 zero_data = MS_MOV256_F32(0.0f);
relu_data = MS_MAX256_F32(relu_data, zero_data);
MS_ST256_F32(dst + index, relu_data);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM;
for (; index < c4_block; index += C4NUM) {
MS_FLOAT32X4 relu_data = MS_LDQ_F32(data + index);
MS_FLOAT32X4 zero_data = MS_MOVQ_F32(0.0f);
relu_data = MS_MAXQ_F32(relu_data, zero_data);
MS_STQ_F32(dst + index, relu_data);
}
#endif
for (; index < ele_num; ++index) {
data[index] = data[index] < 0.0f ? 0.0f : data[index];
}
}
void Relu6Fp32(float *data, float *dst, int ele_num) {
int index = 0;
#ifdef ENABLE_AVX
int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM;
for (; index < c8_block; index += C8NUM) {
MS_FLOAT32X8 relu6_data = MS_LD256_F32(data + index);
MS_FLOAT32X8 zero_data = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 six_data = MS_MOV256_F32(6.0f);
relu6_data = MS_MAX256_F32(relu6_data, zero_data);
relu6_data = MS_MIN256_F32(relu6_data, six_data);
MS_ST256_F32(dst + index, relu6_data);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM;
for (; index < c4_block; index += C4NUM) {
MS_FLOAT32X4 relu6_data = MS_LDQ_F32(data + index);
MS_FLOAT32X4 zero_data = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 six_data = MS_MOVQ_F32(6.0f);
relu6_data = MS_MAXQ_F32(relu6_data, zero_data);
relu6_data = MS_MINQ_F32(relu6_data, six_data);
MS_STQ_F32(dst + index, relu6_data);
}
#endif
for (; index < ele_num; ++index) {
data[index] = data[index] < 0.0f ? 0.0f : data[index];
data[index] = data[index] > 6.0f ? 6.0f : data[index];
}
}

View File

@ -28,14 +28,6 @@ extern "C" {
int8_t MinInt8(int8_t a, int8_t b);
int8_t MaxInt8(int8_t a, int8_t b);
void ReluFp32(float *data, float *dst, int ele_num);
void Relu6Fp32(float *data, float *dst, int ele_num);
#ifdef ENABLE_AVX
#ifdef WIN32
void ReluFp32C8(float *data, float *dst, int ele_num);
void Relu6Fp32C8(float *data, float *dst, int ele_num);
#endif
#endif
int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
int OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
int Offset4d(const int *shape, const int *dims);
@ -62,15 +54,6 @@ static inline int GetStride(int *strides, const int *shape, int length) {
}
return stride;
}
#ifdef ENABLE_ARM64
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);
void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size);
void Relu6(float *data, size_t element4);
void Relu(float *data, size_t element4);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -69,12 +69,6 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
#endif
#ifdef ENABLE_ARM64
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);
void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size);
void Relu6(float *data, size_t element4);
void Relu(float *data, size_t element4);
void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);

View File

@ -161,7 +161,7 @@ void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float
#ifdef ENABLE_AVX
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, const SWConvKernel kernel,
int act_type, int ow_bock, int oc_block) {
int act_type, int ow_bock, int oc_block, size_t write_mode) {
for (int oh = top; oh < bottom; oh++) { // now h is only loop one time
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
@ -179,7 +179,7 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock,
oc_block, sw_param->block_channel_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_,
sw_param->in_sw_step_,
(conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_);
(conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, write_mode);
dst_kernel += ow_bock * sw_param->block_channel_;
} // width loop
dst += sw_param->out_h_step_;
@ -232,6 +232,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
int in_h_start = top * stride_h - pad_u;
int in_w_start = left * stride_w - pad_l;
int center_step = in_h_start * in_h_step + in_w_start * ic_algin;
int write_mode = conv_param->out_format_;
int kernel_out_step = oc_algin;
if (write_mode == 13) {
kernel_out_step = out_h * out_w;
}
const int ow_block_num[4] = {12, 6, 4, 3};
const SWConvKernel kernel[4][2] = {{SWConv1x8Kernel, SWConv12x8Kernel},
{SWConv1x16Kernel, SWConv6x16Kernel},
@ -254,11 +259,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
const SWConvKernel kernel_border = kernel[oc_block - 1][0];
if (oh < top || oh >= bottom) { // oh in up or down border
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, kernel_border, act_type,
1, oc_block);
1, oc_block, write_mode);
} else { // oh in center
// ow in right
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, kernel_border, act_type,
1, oc_block);
1, oc_block, write_mode);
// ow in center
const float *src_w = src_h + (oh - top) * in_sh_step;
int ow_block = ow_block_num[oc_block - 1]; // 12 6 4 3
@ -269,12 +274,12 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
}
kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
dst_w + ow * block_channel, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block,
oc_algin, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0);
kernel_out_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode);
src_w += ow_block * in_sw_step;
}
// ow in left
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, kernel_border,
act_type, 1, oc_block);
act_type, 1, oc_block, write_mode);
}
}
} // output h loop
@ -284,12 +289,14 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
}
void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_sw_step *= sizeof(float);
in_kw_step *= sizeof(float);
oc_algin *= sizeof(float);
float *dst_4 = dst + out_step * C3NUM;
out_step *= sizeof(float);
kw_remainder *= sizeof(float);
asm volatile(
"cmpq $0, %2\n"
@ -403,6 +410,9 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"0:\n"
"cmpq $13, %3\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
@ -415,19 +425,37 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
"vmovups %%ymm9, 0x20(%2, %1, 2)\n"
"vmovups %%ymm10, 0x40(%2, %1, 2)\n"
"vmovups %%ymm11, 0x60(%2, %1, 2)\n"
"jmp 2f\n"
"1:\n"
// write to nc8hw8
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm4, 0x20(%2)\n"
"vmovups %%ymm8, 0x40(%2)\n"
"vmovups %%ymm1, (%2, %1, 1)\n"
"vmovups %%ymm5, 0x20(%2, %1, 1)\n"
"vmovups %%ymm9, 0x40(%2, %1, 1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
"vmovups %%ymm6, 0x20(%2, %1, 2)\n"
"vmovups %%ymm10, 0x40(%2, %1, 2)\n"
"vmovups %%ymm3, (%4)\n"
"vmovups %%ymm7, 0x20(%4)\n"
"vmovups %%ymm11, 0x40(%4)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm14");
}
void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
oc_algin *= sizeof(float);
out_step *= sizeof(float);
kw_remainder *= sizeof(float);
float *dst_4 = dst + out_step * C3NUM;
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
@ -494,25 +522,37 @@ void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm3, %%ymm3\n"
"0:\n"
"cmpq $13, %3\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovups %%ymm3, 0x60(%2)\n"
"jmp 2f\n"
"1:\n"
// write to nc8hw8
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, (%2, %1, 1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
"vmovups %%ymm3, (%4)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
}
void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
in_sw_step *= sizeof(float);
kw_remainder *= sizeof(float);
size_t src_3_step = 3 * in_sw_step;
float *dst_3 = dst + 3 * oc_algin;
oc_algin *= sizeof(float);
float *dst_3 = dst + 3 * out_step;
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
@ -640,6 +680,9 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"0:\n"
"cmpq $13, %4\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
@ -652,19 +695,36 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
"vmovups %%ymm9, (%3)\n" // dst+3
"vmovups %%ymm10, 0x20(%3)\n"
"vmovups %%ymm11, 0x40(%3)\n"
"jmp 2f\n"
"1:\n"
// write to nc8hw8
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm3, 0x20(%2)\n"
"vmovups %%ymm6, 0x40(%2)\n"
"vmovups %%ymm9, 0x60(%2)\n"
"vmovups %%ymm1, (%2, %1, 1)\n"
"vmovups %%ymm4, 0x20(%2, %1, 1)\n"
"vmovups %%ymm7, 0x40(%2, %1, 1)\n"
"vmovups %%ymm10, 0x60(%2, %1, 1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
"vmovups %%ymm5, 0x20(%2, %1, 2)\n"
"vmovups %%ymm8, 0x60(%2, %1, 2)\n"
"vmovups %%ymm11, 0x60(%2, %1, 2)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm14");
}
void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
kw_remainder *= sizeof(float);
oc_algin *= sizeof(float);
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
@ -726,24 +786,35 @@ void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm2, %%ymm2\n"
"0:\n"
"cmpq $13, %3\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"jmp 2f\n"
"1:\n"
// write to nc4hw4
"vmovups %%ymm0, (%2)\n"
"vmovups %%ymm1, (%2, %1, 1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
}
void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
in_sw_step *= sizeof(float);
kw_remainder *= sizeof(float);
size_t src_3_step = 3 * in_sw_step;
float *dst_3 = dst + 3 * oc_algin;
oc_algin *= sizeof(float);
float *dst_3 = dst + 3 * out_step;
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
@ -874,6 +945,9 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"0:\n"
"cmpq $13, %4\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, (%2, %1, 1)\n"
@ -886,19 +960,36 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
"vmovups %%ymm9, 0x20(%3, %1, 1)\n"
"vmovups %%ymm10, (%3, %1, 2)\n"
"vmovups %%ymm11, 0x20(%3, %1, 2)\n"
"jmp 2f\n"
"1:\n"
// write to nc8hw8
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm2, 0x20(%2)\n"
"vmovups %%ymm4, 0x40(%2)\n"
"vmovups %%ymm6, 0x60(%2)\n" // dst+3
"vmovups %%ymm8, 0x80(%2)\n"
"vmovups %%ymm10, 0xA0(%2)\n"
"vmovups %%ymm1, (%2, %1, 1)\n"
"vmovups %%ymm3, 0x20(%2, %1, 1)\n"
"vmovups %%ymm5, 0x40(%2, %1, 1)\n"
"vmovups %%ymm7, 0x60(%2, %1, 1)\n"
"vmovups %%ymm9, 0x80(%2, %1, 1)\n"
"vmovups %%ymm11, 0xA0(%2, %1, 1)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm14");
}
void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
kw_remainder *= sizeof(float);
oc_algin *= sizeof(float);
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
@ -955,25 +1046,35 @@ void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm1, %%ymm1\n"
"0:\n"
"cmpq $13, %3\n"
"je 1f\n"
// write to nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"jmp 2f\n"
"1:\n"
// write nc8hw8
"vmovups %%ymm0, (%2)\n"
"vmovups %%ymm1, (%2, %1, 1)\n"
"2:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode)
: "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
}
void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_sw_step *= sizeof(float);
in_kw_step *= sizeof(float);
kw_remainder *= sizeof(float);
size_t src_3_step = 3 * in_sw_step;
float *dst_3 = dst + 3 * oc_algin;
float *dst_5 = dst + 5 * oc_algin;
float *dst_9 = dst + 9 * oc_algin;
oc_algin *= sizeof(float);
float *dst_3 = dst + 3 * out_step;
float *dst_5 = dst + 5 * out_step;
float *dst_9 = dst + 9 * out_step;
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
@ -1105,6 +1206,9 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
"vminps %%ymm14, %%ymm11, %%ymm11\n"
"Write:\n"
"cmpq $13, %6\n"
"je WriteNC8HW8\n"
// write nhwc
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, (%2, %1)\n"
"vmovups %%ymm2, (%2, %1, 2)\n"
@ -1117,21 +1221,37 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
"vmovups %%ymm9, (%5)\n" // dst_9
"vmovups %%ymm10, (%5, %1, 1)\n"
"vmovups %%ymm11, (%5, %1, 2)\n"
"jmp End\n"
"WriteNC8HW8:\n"
"vmovups %%ymm0, (%2)\n" // dst_0
"vmovups %%ymm1, 0x20(%2)\n"
"vmovups %%ymm2, 0x40(%2)\n"
"vmovups %%ymm3, 0x60(%2)\n" // dst_3
"vmovups %%ymm4, 0x80(%2)\n"
"vmovups %%ymm5, 0xA0(%2)\n" // dst_5
"vmovups %%ymm6, 0xC0(%2)\n"
"vmovups %%ymm7, 0xE0(%2)\n"
"vmovups %%ymm8, 0x100(%2)\n"
"vmovups %%ymm9, 0x120(%2)\n" // dst_9
"vmovups %%ymm10, 0x140(%2)\n"
"vmovups %%ymm11, 0x160(%2)\n"
"End:\n"
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm14");
}
void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_sw_step *= sizeof(float);
in_kw_step *= sizeof(float);
size_t src_step = 3 * in_sw_step;
float *dst_3 = dst + 3 * oc_algin;
oc_algin *= sizeof(float);
float *dst_3 = dst + 3 * out_step;
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %0\n"
"je 0f\n"
@ -1215,17 +1335,18 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl
"vmovups %%ymm2, (%2, %1, 2)\n"
"vmovups %%ymm3, (%3)\n" // dst_3
:
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3)
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14");
}
void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
in_kh_step *= sizeof(float);
in_kw_step *= sizeof(float);
kw_remainder *= sizeof(float);
oc_algin *= sizeof(float);
out_step *= sizeof(float);
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
@ -1277,16 +1398,18 @@ void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const fl
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"0:\n"
// write to nhec and nc8hw8 is identical!
"vmovups %%ymm0, (%2)\n" // dst_0
:
: "a"(act_flag), "r"(oc_algin), "r"(dst)
: "a"(act_flag), "r"(out_step), "r"(dst)
: "%ecx", "%ymm0", "%ymm12", "%ymm14");
}
#ifdef ENABLE_DEBUG
void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode) {
__m256 dst_data[12];
const float *src_kh[12];
const float *src_kw[12];
@ -1339,7 +1462,13 @@ void SWConvWxKKernel(float *dst, const float *src, const float *weight, const fl
if (0x2 & act_flag) { // relu
dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
}
_mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
if (write_mode == 13) {
// write nc8hw8
_mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]);
} else {
// write nhwc
_mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]);
}
}
}
}

View File

@ -43,11 +43,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step,
size_t kw_remainder);
size_t kw_remainder, size_t write_mode);
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, SWConvKernel kernel,
int act_type, int ow_bock, int oc_block);
int act_type, int ow_bock, int oc_block, size_t write_mode);
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data,
int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param);
@ -57,48 +57,59 @@ void SWCenter(float *dst, const float *src, const float *weight, const float *bi
#ifdef ENABLE_DEBUG
void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
#endif
void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
size_t write_mode);
#endif
#ifdef __cplusplus
}

View File

@ -18,6 +18,7 @@
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
@ -80,10 +81,10 @@ int ConvDw(float *output_data, const float *input_data, const float *weight_data
}
}
if (relu) {
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
}
if (relu6) {
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
}
}
}
@ -779,10 +780,10 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
}
}
if (relu) {
ReluFp32(output, output, channels);
Fp32Relu(output, channels, output);
}
if (relu6) {
Relu6Fp32(output, output, channels);
Fp32Relu6(output, channels, output);
}
output += channels;
input = input + input_stride;

View File

@ -164,16 +164,24 @@ int InterpRow(const float *src_line, float *linear_output, int new_width, const
int w;
for (w = 0; w < new_width; w++) {
int c = 0;
#ifdef ENABLE_NEON
float32x4_t left_w = vdupq_n_f32(x_left_weights[w]);
float32x4_t right_w = vdupq_n_f32(1.0f - x_left_weights[w]);
for (; c <= in_c - 4; c += 4) {
float32x4_t left = vld1q_f32(src_line + x_lefts[w] * in_c + c);
float32x4_t right = vld1q_f32(src_line + x_rights[w] * in_c + c);
float32x4_t interp_value = left * left_w + right * right_w;
vst1q_f32(linear_output + w * in_c + c, interp_value);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]);
MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]);
for (; c <= in_c - C8NUM; c += C8NUM) {
MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c);
MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c);
MS_FLOAT32X8 interp_value = left * left_w_8 + right * right_w_8;
MS_ST256_F32(linear_output + w * in_c + c, interp_value);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]);
MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]);
for (; c <= in_c - C4NUM; c += C4NUM) {
MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c);
MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c);
MS_FLOAT32X4 interp_value = left * left_w + right * right_w;
MS_STQ_F32(linear_output + w * in_c + c, interp_value);
}
#endif
int left_w_offset = x_lefts[w] * in_c;
@ -192,15 +200,24 @@ int InterpCol(const float *bottom_line, const float *top_line, float *output, in
int w;
for (w = 0; w < new_width; w++) {
int c = 0;
#ifdef ENABLE_NEON
float32x4_t bottom_w = vdupq_n_f32(y_bottom_weight);
float32x4_t top_w = vdupq_n_f32(1.0f - y_bottom_weight);
for (; c <= in_c - 4; c += 4) {
float32x4_t bottom = vld1q_f32(bottom_line + w * in_c + c);
float32x4_t top = vld1q_f32(top_line + w * in_c + c);
float32x4_t interp_value = bottom * bottom_w + top * top_w;
vst1q_f32(output + w * in_c + c, interp_value);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight);
MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight);
for (; c <= in_c - C8NUM; c += C8NUM) {
MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c);
MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c);
MS_FLOAT32X8 interp_value = bottom * bottom_w_8 + top * top_w_8;
MS_ST256_F32(output + w * in_c + c, interp_value);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight);
MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight);
for (; c <= in_c - C4NUM; c += C4NUM) {
MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c);
MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c);
MS_FLOAT32X4 interp_value = bottom * bottom_w + top * top_w;
MS_STQ_F32(output + w * in_c + c, interp_value);
}
#endif
for (; c < in_c; c++) {
@ -299,21 +316,40 @@ void BicubicInterpRow(const float *src, float *dst, const float *weights, const
const float *src2_w = src + lefts[4 * w + 2] * channel;
const float *src3_w = src + lefts[4 * w + 3] * channel;
int c = 0;
#ifdef ENABLE_NEON
float32x4_t weight0_vec = vdupq_n_f32(weight[0]);
float32x4_t weight1_vec = vdupq_n_f32(weight[1]);
float32x4_t weight2_vec = vdupq_n_f32(weight[2]);
float32x4_t weight3_vec = vdupq_n_f32(weight[3]);
for (; c <= channel - 4; c += 4) {
float32x4_t src0_vec = vld1q_f32(src0_w + c);
float32x4_t src1_vec = vld1q_f32(src1_w + c);
float32x4_t src2_vec = vld1q_f32(src2_w + c);
float32x4_t src3_vec = vld1q_f32(src3_w + c);
float32x4_t interp_value =
src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec;
vst1q_f32(dst_w + c, interp_value);
#if defined(ENABLE_AVX)
MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]);
MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]);
MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]);
MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]);
for (; c <= channel - C8NUM; c += C8NUM) {
MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c);
MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c);
MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c);
MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c);
MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8);
MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8);
MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8);
MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8);
MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0)));
MS_ST256_F32(dst_w + c, interp_value);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]);
MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]);
MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]);
MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]);
for (; c <= channel - C4NUM; c += C4NUM) {
MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c);
MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c);
MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c);
MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c);
MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec);
MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec);
MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec);
MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec);
MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0)));
MS_STQ_F32(dst_w + c, interp_value);
}
#endif
for (; c < channel; c++) {
@ -334,20 +370,40 @@ void BicubicInterpCol(const float *src, float *dst, const float *weights, int wi
const float *src2_w = src2 + w * channel;
const float *src3_w = src3 + w * channel;
int c = 0;
#ifdef ENABLE_NEON
float32x4_t weight0_vec = vdupq_n_f32(weights[0]);
float32x4_t weight1_vec = vdupq_n_f32(weights[1]);
float32x4_t weight2_vec = vdupq_n_f32(weights[2]);
float32x4_t weight3_vec = vdupq_n_f32(weights[3]);
for (; c <= channel - 4; c += 4) {
float32x4_t src0_vec = vld1q_f32(src0_w + c);
float32x4_t src1_vec = vld1q_f32(src1_w + c);
float32x4_t src2_vec = vld1q_f32(src2_w + c);
float32x4_t src3_vec = vld1q_f32(src3_w + c);
float32x4_t interp_value =
src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec;
vst1q_f32(dst_w + c, interp_value);
#ifdef ENABLE_AVX
MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]);
MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]);
MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]);
MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]);
for (; c <= channel - C8NUM; c += C8NUM) {
MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c);
MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c);
MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c);
MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c);
MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8);
MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8);
MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8);
MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8);
MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2)));
MS_ST256_F32(dst_w + c, interp_value);
}
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]);
MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]);
MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]);
MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]);
for (; c <= channel - C4NUM; c += C4NUM) {
MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c);
MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c);
MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c);
MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c);
MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec);
MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec);
MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec);
MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec);
MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2)));
MS_STQ_F32(dst_w + c, interp_value);
}
#endif
for (; c < channel; c++) {

View File

@ -75,6 +75,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
},
{
"conv_depthwise_fp32.c",
"activation_fp32.c",
},
{});
nnacl::NNaclFp32Serializer code;