forked from mindspore-Ecosystem/mindspore
!22757 [MS][LITE][CPU] sw nc4hw4 support
Merge pull request !22757 from liuzhongkai/code_re5
This commit is contained in:
commit
50847c9659
|
@ -33,59 +33,3 @@ int Offset6d(const int *shape, const int *dims) {
|
|||
// Returns the smaller of the two signed 8-bit operands (b when they are equal).
int8_t MinInt8(int8_t a, int8_t b) {
  if (a < b) {
    return a;
  }
  return b;
}
|
||||
|
||||
// Returns the larger of the two signed 8-bit operands (a when they are equal).
int8_t MaxInt8(int8_t a, int8_t b) {
  if (a < b) {
    return b;
  }
  return a;
}
|
||||
|
||||
// Element-wise ReLU: dst[i] = (data[i] < 0) ? 0 : data[i] for i in [0, ele_num).
//
// data    : input buffer with at least ele_num floats.
// dst     : output buffer with room for ele_num floats; passing the same pointer as
//           data performs the operation in place.
// ele_num : number of floats to process.
//
// When ENABLE_AVX is defined the bulk is processed 8 floats at a time, then 4 at a
// time under ENABLE_NEON / ENABLE_SSE; a scalar loop finishes whatever remains.
void ReluFp32(float *data, float *dst, int ele_num) {
  int index = 0;
#ifdef ENABLE_AVX
  int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM;
  MS_FLOAT32X8 zero_x8 = MS_MOV256_F32(0.0f);
  for (; index < c8_block; index += C8NUM) {
    MS_FLOAT32X8 relu_x8 = MS_LD256_F32(data + index);
    relu_x8 = MS_MAX256_F32(relu_x8, zero_x8);
    MS_ST256_F32(dst + index, relu_x8);
  }
#endif
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
  int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM;
  MS_FLOAT32X4 zero_x4 = MS_MOVQ_F32(0.0f);
  for (; index < c4_block; index += C4NUM) {
    MS_FLOAT32X4 relu_x4 = MS_LDQ_F32(data + index);
    relu_x4 = MS_MAXQ_F32(relu_x4, zero_x4);
    MS_STQ_F32(dst + index, relu_x4);
  }
#endif
  // Scalar remainder: store into dst (not data) so the tail matches the vector loops
  // and the input stays untouched when dst is a different buffer.
  for (; index < ele_num; ++index) {
    dst[index] = data[index] < 0.0f ? 0.0f : data[index];
  }
}
|
||||
|
||||
// Element-wise ReLU6: dst[i] = data[i] clamped to the range [0, 6] for i in [0, ele_num).
//
// data    : input buffer with at least ele_num floats.
// dst     : output buffer with room for ele_num floats; passing the same pointer as
//           data performs the operation in place.
// ele_num : number of floats to process.
//
// When ENABLE_AVX is defined the bulk is processed 8 floats at a time, then 4 at a
// time under ENABLE_NEON / ENABLE_SSE; a scalar loop finishes whatever remains.
void Relu6Fp32(float *data, float *dst, int ele_num) {
  int index = 0;
#ifdef ENABLE_AVX
  int c8_block = DOWN_DIV(ele_num, C8NUM) * C8NUM;
  MS_FLOAT32X8 zero_x8 = MS_MOV256_F32(0.0f);
  MS_FLOAT32X8 six_x8 = MS_MOV256_F32(6.0f);
  for (; index < c8_block; index += C8NUM) {
    MS_FLOAT32X8 relu6_x8 = MS_LD256_F32(data + index);
    relu6_x8 = MS_MAX256_F32(relu6_x8, zero_x8);
    relu6_x8 = MS_MIN256_F32(relu6_x8, six_x8);
    MS_ST256_F32(dst + index, relu6_x8);
  }
#endif

#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
  int c4_block = DOWN_DIV(ele_num, C4NUM) * C4NUM;
  MS_FLOAT32X4 zero_x4 = MS_MOVQ_F32(0.0f);
  MS_FLOAT32X4 six_x4 = MS_MOVQ_F32(6.0f);
  for (; index < c4_block; index += C4NUM) {
    MS_FLOAT32X4 relu6_x4 = MS_LDQ_F32(data + index);
    relu6_x4 = MS_MAXQ_F32(relu6_x4, zero_x4);
    relu6_x4 = MS_MINQ_F32(relu6_x4, six_x4);
    MS_STQ_F32(dst + index, relu6_x4);
  }
#endif
  // Scalar remainder: store into dst (not data) so the tail matches the vector loops
  // and the input stays untouched when dst is a different buffer.
  for (; index < ele_num; ++index) {
    float value = data[index];
    value = value < 0.0f ? 0.0f : value;
    value = value > 6.0f ? 6.0f : value;
    dst[index] = value;
  }
}
|
||||
|
|
|
@ -28,14 +28,6 @@ extern "C" {
|
|||
|
||||
// Returns the smaller of the two signed 8-bit values a and b.
int8_t MinInt8(int8_t a, int8_t b);
|
||||
// Returns the larger of the two signed 8-bit values a and b.
int8_t MaxInt8(int8_t a, int8_t b);
|
||||
void ReluFp32(float *data, float *dst, int ele_num);
|
||||
void Relu6Fp32(float *data, float *dst, int ele_num);
|
||||
#ifdef ENABLE_AVX
|
||||
#ifdef WIN32
|
||||
void ReluFp32C8(float *data, float *dst, int ele_num);
|
||||
void Relu6Fp32C8(float *data, float *dst, int ele_num);
|
||||
#endif
|
||||
#endif
|
||||
int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
|
||||
int OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
|
||||
int Offset4d(const int *shape, const int *dims);
|
||||
|
@ -62,15 +54,6 @@ static inline int GetStride(int *strides, const int *shape, int length) {
|
|||
}
|
||||
return stride;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void Relu6(float *data, size_t element4);
|
||||
void Relu(float *data, size_t element4);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -69,12 +69,6 @@ void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size);
|
||||
void Relu6(float *data, size_t element4);
|
||||
void Relu(float *data, size_t element4);
|
||||
|
||||
void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);
|
||||
|
||||
|
|
|
@ -161,7 +161,7 @@ void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float
|
|||
#ifdef ENABLE_AVX
|
||||
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
|
||||
int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, const SWConvKernel kernel,
|
||||
int act_type, int ow_bock, int oc_block) {
|
||||
int act_type, int ow_bock, int oc_block, size_t write_mode) {
|
||||
for (int oh = top; oh < bottom; oh++) { // now h is only loop one time
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
|
@ -179,7 +179,7 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
|
|||
kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock,
|
||||
oc_block, sw_param->block_channel_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_,
|
||||
sw_param->in_sw_step_,
|
||||
(conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_);
|
||||
(conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, write_mode);
|
||||
dst_kernel += ow_bock * sw_param->block_channel_;
|
||||
} // width loop
|
||||
dst += sw_param->out_h_step_;
|
||||
|
@ -232,6 +232,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
int in_h_start = top * stride_h - pad_u;
|
||||
int in_w_start = left * stride_w - pad_l;
|
||||
int center_step = in_h_start * in_h_step + in_w_start * ic_algin;
|
||||
int write_mode = conv_param->out_format_;
|
||||
int kernel_out_step = oc_algin;
|
||||
if (write_mode == 13) {
|
||||
kernel_out_step = out_h * out_w;
|
||||
}
|
||||
const int ow_block_num[4] = {12, 6, 4, 3};
|
||||
const SWConvKernel kernel[4][2] = {{SWConv1x8Kernel, SWConv12x8Kernel},
|
||||
{SWConv1x16Kernel, SWConv6x16Kernel},
|
||||
|
@ -254,11 +259,11 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
const SWConvKernel kernel_border = kernel[oc_block - 1][0];
|
||||
if (oh < top || oh >= bottom) { // oh in up or down border
|
||||
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, kernel_border, act_type,
|
||||
1, oc_block);
|
||||
1, oc_block, write_mode);
|
||||
} else { // oh in center
|
||||
// ow in left border (columns [0, left))
|
||||
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, kernel_border, act_type,
|
||||
1, oc_block);
|
||||
1, oc_block, write_mode);
|
||||
// ow in center
|
||||
const float *src_w = src_h + (oh - top) * in_sh_step;
|
||||
int ow_block = ow_block_num[oc_block - 1]; // 12 6 4 3
|
||||
|
@ -269,12 +274,12 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
}
|
||||
kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
|
||||
dst_w + ow * block_channel, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block,
|
||||
oc_algin, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0);
|
||||
kernel_out_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode);
|
||||
src_w += ow_block * in_sw_step;
|
||||
}
|
||||
// ow in right border (columns [right, out_w))
|
||||
SWBorder(dst_w, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, kernel_border,
|
||||
act_type, 1, oc_block);
|
||||
act_type, 1, oc_block, write_mode);
|
||||
}
|
||||
}
|
||||
} // output h loop
|
||||
|
@ -284,12 +289,14 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
}
|
||||
|
||||
void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_sw_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
oc_algin *= sizeof(float);
|
||||
float *dst_4 = dst + out_step * C3NUM;
|
||||
out_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %2\n"
|
||||
|
@ -403,6 +410,9 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm11, %%ymm11\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %3\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, 0x40(%2)\n"
|
||||
|
@ -415,19 +425,37 @@ void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vmovups %%ymm9, 0x20(%2, %1, 2)\n"
|
||||
"vmovups %%ymm10, 0x40(%2, %1, 2)\n"
|
||||
"vmovups %%ymm11, 0x60(%2, %1, 2)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write to nc8hw8
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm4, 0x20(%2)\n"
|
||||
"vmovups %%ymm8, 0x40(%2)\n"
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"vmovups %%ymm5, 0x20(%2, %1, 1)\n"
|
||||
"vmovups %%ymm9, 0x40(%2, %1, 1)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
"vmovups %%ymm6, 0x20(%2, %1, 2)\n"
|
||||
"vmovups %%ymm10, 0x40(%2, %1, 2)\n"
|
||||
"vmovups %%ymm3, (%4)\n"
|
||||
"vmovups %%ymm7, 0x20(%4)\n"
|
||||
"vmovups %%ymm11, 0x40(%4)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
oc_algin *= sizeof(float);
|
||||
out_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
float *dst_4 = dst + out_step * C3NUM;
|
||||
asm volatile(
|
||||
"cmpq $0, %2\n"
|
||||
"je 0f\n"
|
||||
|
@ -494,25 +522,37 @@ void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm3, %%ymm3\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %3\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, 0x40(%2)\n"
|
||||
"vmovups %%ymm3, 0x60(%2)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write to nc8hw8
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
"vmovups %%ymm3, (%4)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
in_sw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
size_t src_3_step = 3 * in_sw_step;
|
||||
float *dst_3 = dst + 3 * oc_algin;
|
||||
oc_algin *= sizeof(float);
|
||||
float *dst_3 = dst + 3 * out_step;
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %0\n"
|
||||
"je 0f\n"
|
||||
|
@ -640,6 +680,9 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm11, %%ymm11\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %4\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, 0x40(%2)\n"
|
||||
|
@ -652,19 +695,36 @@ void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vmovups %%ymm9, (%3)\n" // dst+3
|
||||
"vmovups %%ymm10, 0x20(%3)\n"
|
||||
"vmovups %%ymm11, 0x40(%3)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write to nc8hw8
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm3, 0x20(%2)\n"
|
||||
"vmovups %%ymm6, 0x40(%2)\n"
|
||||
"vmovups %%ymm9, 0x60(%2)\n"
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"vmovups %%ymm4, 0x20(%2, %1, 1)\n"
|
||||
"vmovups %%ymm7, 0x40(%2, %1, 1)\n"
|
||||
"vmovups %%ymm10, 0x60(%2, %1, 1)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
"vmovups %%ymm5, 0x20(%2, %1, 2)\n"
|
||||
"vmovups %%ymm8, 0x60(%2, %1, 2)\n"
|
||||
"vmovups %%ymm11, 0x60(%2, %1, 2)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
oc_algin *= sizeof(float);
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %2\n"
|
||||
"je 0f\n"
|
||||
|
@ -726,24 +786,35 @@ void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm2, %%ymm2\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %3\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, 0x40(%2)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write to nc4hw4
|
||||
"vmovups %%ymm0, (%2)\n"
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
in_sw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
size_t src_3_step = 3 * in_sw_step;
|
||||
float *dst_3 = dst + 3 * oc_algin;
|
||||
oc_algin *= sizeof(float);
|
||||
float *dst_3 = dst + 3 * out_step;
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %0\n"
|
||||
"je 0f\n"
|
||||
|
@ -874,6 +945,9 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm11, %%ymm11\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %4\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 1)\n"
|
||||
|
@ -886,19 +960,36 @@ void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vmovups %%ymm9, 0x20(%3, %1, 1)\n"
|
||||
"vmovups %%ymm10, (%3, %1, 2)\n"
|
||||
"vmovups %%ymm11, 0x20(%3, %1, 2)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write to nc8hw8
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm2, 0x20(%2)\n"
|
||||
"vmovups %%ymm4, 0x40(%2)\n"
|
||||
"vmovups %%ymm6, 0x60(%2)\n" // dst+3
|
||||
"vmovups %%ymm8, 0x80(%2)\n"
|
||||
"vmovups %%ymm10, 0xA0(%2)\n"
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"vmovups %%ymm3, 0x20(%2, %1, 1)\n"
|
||||
"vmovups %%ymm5, 0x40(%2, %1, 1)\n"
|
||||
"vmovups %%ymm7, 0x60(%2, %1, 1)\n"
|
||||
"vmovups %%ymm9, 0x80(%2, %1, 1)\n"
|
||||
"vmovups %%ymm11, 0xA0(%2, %1, 1)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
oc_algin *= sizeof(float);
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %2\n"
|
||||
"je 0f\n"
|
||||
|
@ -955,25 +1046,35 @@ void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm1, %%ymm1\n"
|
||||
|
||||
"0:\n"
|
||||
"cmpq $13, %3\n"
|
||||
"je 1f\n"
|
||||
// write to nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
// write nc8hw8
|
||||
"vmovups %%ymm0, (%2)\n"
|
||||
"vmovups %%ymm1, (%2, %1, 1)\n"
|
||||
"2:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_sw_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
size_t src_3_step = 3 * in_sw_step;
|
||||
float *dst_3 = dst + 3 * oc_algin;
|
||||
float *dst_5 = dst + 5 * oc_algin;
|
||||
float *dst_9 = dst + 9 * oc_algin;
|
||||
oc_algin *= sizeof(float);
|
||||
float *dst_3 = dst + 3 * out_step;
|
||||
float *dst_5 = dst + 5 * out_step;
|
||||
float *dst_9 = dst + 9 * out_step;
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %0\n"
|
||||
"je 0f\n"
|
||||
|
@ -1105,6 +1206,9 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vminps %%ymm14, %%ymm11, %%ymm11\n"
|
||||
|
||||
"Write:\n"
|
||||
"cmpq $13, %6\n"
|
||||
"je WriteNC8HW8\n"
|
||||
// write nhwc
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, (%2, %1)\n"
|
||||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
|
@ -1117,21 +1221,37 @@ void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const f
|
|||
"vmovups %%ymm9, (%5)\n" // dst_9
|
||||
"vmovups %%ymm10, (%5, %1, 1)\n"
|
||||
"vmovups %%ymm11, (%5, %1, 2)\n"
|
||||
"jmp End\n"
|
||||
"WriteNC8HW8:\n"
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
"vmovups %%ymm1, 0x20(%2)\n"
|
||||
"vmovups %%ymm2, 0x40(%2)\n"
|
||||
"vmovups %%ymm3, 0x60(%2)\n" // dst_3
|
||||
"vmovups %%ymm4, 0x80(%2)\n"
|
||||
"vmovups %%ymm5, 0xA0(%2)\n" // dst_5
|
||||
"vmovups %%ymm6, 0xC0(%2)\n"
|
||||
"vmovups %%ymm7, 0xE0(%2)\n"
|
||||
"vmovups %%ymm8, 0x100(%2)\n"
|
||||
"vmovups %%ymm9, 0x120(%2)\n" // dst_9
|
||||
"vmovups %%ymm10, 0x140(%2)\n"
|
||||
"vmovups %%ymm11, 0x160(%2)\n"
|
||||
"End:\n"
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_sw_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
size_t src_step = 3 * in_sw_step;
|
||||
float *dst_3 = dst + 3 * oc_algin;
|
||||
oc_algin *= sizeof(float);
|
||||
float *dst_3 = dst + 3 * out_step;
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %0\n"
|
||||
"je 0f\n"
|
||||
|
@ -1215,17 +1335,18 @@ void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const fl
|
|||
"vmovups %%ymm2, (%2, %1, 2)\n"
|
||||
"vmovups %%ymm3, (%3)\n" // dst_3
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3)
|
||||
: "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
in_kh_step *= sizeof(float);
|
||||
in_kw_step *= sizeof(float);
|
||||
kw_remainder *= sizeof(float);
|
||||
oc_algin *= sizeof(float);
|
||||
out_step *= sizeof(float);
|
||||
asm volatile(
|
||||
"cmpq $0, %2\n"
|
||||
"je 0f\n"
|
||||
|
@ -1277,16 +1398,18 @@ void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const fl
|
|||
"vminps %%ymm14, %%ymm0, %%ymm0\n"
|
||||
|
||||
"0:\n"
|
||||
// writing to nhwc and nc8hw8 is identical!
|
||||
"vmovups %%ymm0, (%2)\n" // dst_0
|
||||
:
|
||||
: "a"(act_flag), "r"(oc_algin), "r"(dst)
|
||||
: "a"(act_flag), "r"(out_step), "r"(dst)
|
||||
: "%ecx", "%ymm0", "%ymm12", "%ymm14");
|
||||
}
|
||||
|
||||
#ifdef ENABLE_DEBUG
|
||||
void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode) {
|
||||
__m256 dst_data[12];
|
||||
const float *src_kh[12];
|
||||
const float *src_kw[12];
|
||||
|
@ -1339,7 +1462,13 @@ void SWConvWxKKernel(float *dst, const float *src, const float *weight, const fl
|
|||
if (0x2 & act_flag) { // relu
|
||||
dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
|
||||
}
|
||||
_mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
|
||||
if (write_mode == 13) {
|
||||
// write nc8hw8
|
||||
_mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]);
|
||||
} else {
|
||||
// write nhwc
|
||||
_mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,11 +43,11 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
|
|||
typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step,
|
||||
size_t kw_remainder);
|
||||
size_t kw_remainder, size_t write_mode);
|
||||
|
||||
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
|
||||
int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, SWConvKernel kernel,
|
||||
int act_type, int ow_bock, int oc_block);
|
||||
int act_type, int ow_bock, int oc_block, size_t write_mode);
|
||||
|
||||
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data,
|
||||
int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param);
|
||||
|
@ -57,48 +57,59 @@ void SWCenter(float *dst, const float *src, const float *weight, const float *bi
|
|||
#ifdef ENABLE_DEBUG
|
||||
void SWConvWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
#endif
|
||||
|
||||
void SWConv3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv12x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
|
||||
void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
|
||||
size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder);
|
||||
size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder,
|
||||
size_t write_mode);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
|
||||
|
@ -80,10 +81,10 @@ int ConvDw(float *output_data, const float *input_data, const float *weight_data
|
|||
}
|
||||
}
|
||||
if (relu) {
|
||||
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
|
||||
}
|
||||
if (relu6) {
|
||||
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -779,10 +780,10 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
|
|||
}
|
||||
}
|
||||
if (relu) {
|
||||
ReluFp32(output, output, channels);
|
||||
Fp32Relu(output, channels, output);
|
||||
}
|
||||
if (relu6) {
|
||||
Relu6Fp32(output, output, channels);
|
||||
Fp32Relu6(output, channels, output);
|
||||
}
|
||||
output += channels;
|
||||
input = input + input_stride;
|
||||
|
|
|
@ -164,16 +164,24 @@ int InterpRow(const float *src_line, float *linear_output, int new_width, const
|
|||
int w;
|
||||
for (w = 0; w < new_width; w++) {
|
||||
int c = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t left_w = vdupq_n_f32(x_left_weights[w]);
|
||||
float32x4_t right_w = vdupq_n_f32(1.0f - x_left_weights[w]);
|
||||
|
||||
for (; c <= in_c - 4; c += 4) {
|
||||
float32x4_t left = vld1q_f32(src_line + x_lefts[w] * in_c + c);
|
||||
float32x4_t right = vld1q_f32(src_line + x_rights[w] * in_c + c);
|
||||
|
||||
float32x4_t interp_value = left * left_w + right * right_w;
|
||||
vst1q_f32(linear_output + w * in_c + c, interp_value);
|
||||
#if defined(ENABLE_AVX)
|
||||
MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]);
|
||||
MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]);
|
||||
for (; c <= in_c - C8NUM; c += C8NUM) {
|
||||
MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c);
|
||||
MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c);
|
||||
MS_FLOAT32X8 interp_value = left * left_w_8 + right * right_w_8;
|
||||
MS_ST256_F32(linear_output + w * in_c + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]);
|
||||
MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]);
|
||||
for (; c <= in_c - C4NUM; c += C4NUM) {
|
||||
MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c);
|
||||
MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c);
|
||||
MS_FLOAT32X4 interp_value = left * left_w + right * right_w;
|
||||
MS_STQ_F32(linear_output + w * in_c + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
int left_w_offset = x_lefts[w] * in_c;
|
||||
|
@ -192,15 +200,24 @@ int InterpCol(const float *bottom_line, const float *top_line, float *output, in
|
|||
int w;
|
||||
for (w = 0; w < new_width; w++) {
|
||||
int c = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t bottom_w = vdupq_n_f32(y_bottom_weight);
|
||||
float32x4_t top_w = vdupq_n_f32(1.0f - y_bottom_weight);
|
||||
|
||||
for (; c <= in_c - 4; c += 4) {
|
||||
float32x4_t bottom = vld1q_f32(bottom_line + w * in_c + c);
|
||||
float32x4_t top = vld1q_f32(top_line + w * in_c + c);
|
||||
float32x4_t interp_value = bottom * bottom_w + top * top_w;
|
||||
vst1q_f32(output + w * in_c + c, interp_value);
|
||||
#if defined(ENABLE_AVX)
|
||||
MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight);
|
||||
MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight);
|
||||
for (; c <= in_c - C8NUM; c += C8NUM) {
|
||||
MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c);
|
||||
MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c);
|
||||
MS_FLOAT32X8 interp_value = bottom * bottom_w_8 + top * top_w_8;
|
||||
MS_ST256_F32(output + w * in_c + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight);
|
||||
MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight);
|
||||
for (; c <= in_c - C4NUM; c += C4NUM) {
|
||||
MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c);
|
||||
MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c);
|
||||
MS_FLOAT32X4 interp_value = bottom * bottom_w + top * top_w;
|
||||
MS_STQ_F32(output + w * in_c + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
for (; c < in_c; c++) {
|
||||
|
@ -299,21 +316,40 @@ void BicubicInterpRow(const float *src, float *dst, const float *weights, const
|
|||
const float *src2_w = src + lefts[4 * w + 2] * channel;
|
||||
const float *src3_w = src + lefts[4 * w + 3] * channel;
|
||||
int c = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t weight0_vec = vdupq_n_f32(weight[0]);
|
||||
float32x4_t weight1_vec = vdupq_n_f32(weight[1]);
|
||||
float32x4_t weight2_vec = vdupq_n_f32(weight[2]);
|
||||
float32x4_t weight3_vec = vdupq_n_f32(weight[3]);
|
||||
|
||||
for (; c <= channel - 4; c += 4) {
|
||||
float32x4_t src0_vec = vld1q_f32(src0_w + c);
|
||||
float32x4_t src1_vec = vld1q_f32(src1_w + c);
|
||||
float32x4_t src2_vec = vld1q_f32(src2_w + c);
|
||||
float32x4_t src3_vec = vld1q_f32(src3_w + c);
|
||||
|
||||
float32x4_t interp_value =
|
||||
src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec;
|
||||
vst1q_f32(dst_w + c, interp_value);
|
||||
#if defined(ENABLE_AVX)
|
||||
MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]);
|
||||
MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]);
|
||||
MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]);
|
||||
MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]);
|
||||
for (; c <= channel - C8NUM; c += C8NUM) {
|
||||
MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c);
|
||||
MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c);
|
||||
MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c);
|
||||
MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c);
|
||||
MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8);
|
||||
MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8);
|
||||
MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8);
|
||||
MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8);
|
||||
MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0)));
|
||||
MS_ST256_F32(dst_w + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]);
|
||||
MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]);
|
||||
MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]);
|
||||
MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]);
|
||||
for (; c <= channel - C4NUM; c += C4NUM) {
|
||||
MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c);
|
||||
MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c);
|
||||
MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c);
|
||||
MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c);
|
||||
MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec);
|
||||
MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec);
|
||||
MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec);
|
||||
MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec);
|
||||
MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0)));
|
||||
MS_STQ_F32(dst_w + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
for (; c < channel; c++) {
|
||||
|
@ -334,20 +370,40 @@ void BicubicInterpCol(const float *src, float *dst, const float *weights, int wi
|
|||
const float *src2_w = src2 + w * channel;
|
||||
const float *src3_w = src3 + w * channel;
|
||||
int c = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t weight0_vec = vdupq_n_f32(weights[0]);
|
||||
float32x4_t weight1_vec = vdupq_n_f32(weights[1]);
|
||||
float32x4_t weight2_vec = vdupq_n_f32(weights[2]);
|
||||
float32x4_t weight3_vec = vdupq_n_f32(weights[3]);
|
||||
|
||||
for (; c <= channel - 4; c += 4) {
|
||||
float32x4_t src0_vec = vld1q_f32(src0_w + c);
|
||||
float32x4_t src1_vec = vld1q_f32(src1_w + c);
|
||||
float32x4_t src2_vec = vld1q_f32(src2_w + c);
|
||||
float32x4_t src3_vec = vld1q_f32(src3_w + c);
|
||||
float32x4_t interp_value =
|
||||
src0_vec * weight0_vec + src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec;
|
||||
vst1q_f32(dst_w + c, interp_value);
|
||||
#ifdef ENABLE_AVX
|
||||
MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]);
|
||||
MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]);
|
||||
MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]);
|
||||
MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]);
|
||||
for (; c <= channel - C8NUM; c += C8NUM) {
|
||||
MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c);
|
||||
MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c);
|
||||
MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c);
|
||||
MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c);
|
||||
MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8);
|
||||
MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8);
|
||||
MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8);
|
||||
MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8);
|
||||
MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2)));
|
||||
MS_ST256_F32(dst_w + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]);
|
||||
MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]);
|
||||
MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]);
|
||||
MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]);
|
||||
for (; c <= channel - C4NUM; c += C4NUM) {
|
||||
MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c);
|
||||
MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c);
|
||||
MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c);
|
||||
MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c);
|
||||
MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec);
|
||||
MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec);
|
||||
MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec);
|
||||
MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec);
|
||||
MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2)));
|
||||
MS_STQ_F32(dst_w + c, interp_value);
|
||||
}
|
||||
#endif
|
||||
for (; c < channel; c++) {
|
||||
|
|
|
@ -75,6 +75,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
|
|||
},
|
||||
{
|
||||
"conv_depthwise_fp32.c",
|
||||
"activation_fp32.c",
|
||||
},
|
||||
{});
|
||||
nnacl::NNaclFp32Serializer code;
|
||||
|
|
Loading…
Reference in New Issue