!3844 little optimization for depth wise kernels

Merge pull request !3844 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-01 17:30:27 +08:00 committed by Gitee
commit 9ba4e4d88d
3 changed files with 24 additions and 56 deletions

View File

@ -32,16 +32,6 @@ ConvDwFp32Center:
ldr x14, [sp, #48]
ldr x15, [sp, #56]
mov x16, #4
mul x8, x8, x16
mul x9, x9, x16
mul x10, x10, x16
mul x11, x11, x16
mul x12, x12, x16
mul x13, x13, x16
mov x16, #16
mul x19, x7, x16
ld1 {v5.4s}, [x3]
LoopH:
@ -52,20 +42,17 @@ ConvDwFp32Center:
mov x16, x23
mov x17, x2
mov x20, x6
ld1 {v0.4s}, [x3]
fadd v0.4s, v0.4s, v5.4s
mov v0.16b, v5.16b
LoopKh:
mov x18, x7
mov x21, x17
mov x22, x16
LoopKw:
ld1 {v1.4s}, [x22], x13
ld1 {v2.4s}, [x21], #16
ld1 {v2.4s}, [x17], #16
fmla v0.4s, v1.4s, v2.4s
subs x18, x18, #1
bne LoopKw
add x16, x16, x12
add x17, x17, x19
subs x20, x20, #1
bne LoopKh
cbnz x15, Relu6

View File

@ -27,16 +27,6 @@ DeconvDwFp32Center:
ldr x11, [sp, #24]
ldr x12, [sp, #32]
mov x13, #4
mul x7, x7, x13
mul x8, x8, x13
mul x9, x9, x13
mul x10, x10, x13
mul x11, x11, x13
mul x12, x12, x13
mov x13, #16
mul x14, x6, x13
LoopH:
mov x15, x0
mov x16, x1
@ -45,20 +35,18 @@ DeconvDwFp32Center:
mov x18, x15
mov x19, x2
mov x20, x5
dup v0.4s, wzr
LoopKh:
mov x21, x18
mov x22, x19
mov x13, x6
LoopKw:
ld1 {v0.4s}, [x21]
ld1 {v1.4s}, [x16]
ld1 {v2.4s}, [x22], #16
ld1 {v2.4s}, [x19], #16
fmla v0.4s, v1.4s, v2.4s
st1 {v0.4s}, [x21], x12
subs x13, x13, #1
bne LoopKw
add x18, x18, x11
add x19, x19, x14
subs x20, x20, #1
bne LoopKh
add x15, x15, x10

View File

@ -120,13 +120,10 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
} // height loop
}
#ifndef ENABLE_ARM64
void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
#ifdef ENABLE_ARM64
ConvDwFp32Center(dst, src, weight, bias, height, width, kernel_h, kernel_w, out_h_step, block_channel,
in_sh_step, in_sw_step, in_kh_step, in_kw_step, is_relu, is_relu6);
#else
float *dst_h = dst;
const float *src_h = src;
for (int oh = 0; oh < height; oh++) {
@ -139,17 +136,9 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
const float *src_kw = src_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
#ifdef ENABLE_ARM64
float32x4_t src_4 = vld1q_f32(src_kw);
float32x4_t weight_4 = vld1q_f32(weight_kw);
float32x4_t dst_4 = vld1q_f32(dst_w);
dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
vst1q_f32(dst_w, dst_4);
#else
for (int c = 0; c < C4NUM; c++) {
dst_w[c] += src_kw[c] * weight_kw[c];
}
#endif
src_kw += in_kw_step;
weight_kw += C4NUM;
} // kernel_w loop
@ -168,8 +157,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
dst_h += out_h_step;
src_h += in_sh_step;
} // dst_height loop
#endif
}
#endif
// conv depthwise fp32: sliding window
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
@ -196,11 +185,18 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
#else
DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->is_relu_, conv_param->is_relu6_);
#endif
}
} // output C4 loop
src += sliding->in_step_;
@ -265,13 +261,10 @@ void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, in
} // height loop
}
#ifndef ENABLE_ARM64
void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h,
int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
int in_kh_step, int in_kw_step) {
#ifdef ENABLE_ARM64
DeconvDwFp32Center(dst, src, weight, height, width, kernel_h, kernel_w, out_h_step, block_channel,
in_sh_step, in_sw_step, in_kh_step, in_kw_step);
#else
float *dst_h = dst;
const float *src_h = src;
for (int oh = 0; oh < height; oh++) {
@ -284,17 +277,9 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in
float *dst_kw = dst_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < kernel_w; kw++) {
#ifdef ENABLE_ARM64
float32x4_t src_4 = vld1q_f32(src_w);
float32x4_t weight_4 = vld1q_f32(weight_kw);
float32x4_t dst_4 = vld1q_f32(dst_kw);
dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
vst1q_f32(dst_kw, dst_4);
#else
for (int c = 0; c < C4NUM; c++) {
dst_kw[c] += src_w[c] * weight_kw[c];
}
#endif
dst_kw += in_kw_step;
weight_kw += C4NUM;
} // kernel_w loop
@ -307,8 +292,8 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in
dst_h += in_sh_step;
src_h += out_h_step;
} // dst_height loop
#endif
}
#endif
void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
float *dst_k = dst;
@ -347,10 +332,18 @@ void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *we
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
sliding->in_kw_step_ * sizeof(float));
#else
DeconvDepthwiseCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_,
sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_,
sliding->in_kw_step_);
#endif
}
DeconvDepthwisePostFunc(dst_data, bias, sliding->block_channel_, conv_param);
} // output C4 loop