!29009 [MS][LITE]Optimize deconv fp16 kernel

Merge pull request !29009 from 张学同/debug_deconv
This commit is contained in:
i-robot 2022-01-14 07:52:41 +00:00 committed by Gitee
commit 93732478d1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 44 additions and 28 deletions

View File

@@ -16,6 +16,45 @@
#include "nnacl/fp16/deconv_fp16.h"
#include <float.h>
/* Accumulates `count` 8-lane fp16 vectors from `source` into `dest`, in place.
 * For every k in [0, count): dest[k * dststride + 0..7] += source[k * srcStride + 0..7].
 *
 * source    - first input vector; read only.
 * dest      - first output vector; each vector is loaded, summed and stored back.
 * srcStride - distance between consecutive input vectors, in float16_t elements.
 * dststride - distance between consecutive output vectors, in float16_t elements.
 * count     - number of vectors to accumulate; nothing is touched when it is 0.
 *
 * Vectors are handled strictly in order (load, add, store), so the result of
 * vector k is in memory before vector k + 1 is loaded. */
void DeConvPostAddC8WithStride(const float16_t *source, float16_t *dest, size_t srcStride, size_t dststride,
                               size_t count) {
  size_t idx = 0;

  /* Main loop unrolled by two. Addresses are derived from the index, so an
   * offset is only ever formed for a vector that actually exists. */
  for (; idx + 1 < count; idx += 2) {
    const float16_t *in_even = source + idx * srcStride;
    float16_t *out_even = dest + idx * dststride;
    float16x8_t addend_even = vld1q_f16(in_even);
    float16x8_t acc_even = vld1q_f16(out_even);
    vst1q_f16(out_even, vaddq_f16(acc_even, addend_even));

    const float16_t *in_odd = in_even + srcStride;
    float16_t *out_odd = out_even + dststride;
    float16x8_t addend_odd = vld1q_f16(in_odd);
    float16x8_t acc_odd = vld1q_f16(out_odd);
    vst1q_f16(out_odd, vaddq_f16(acc_odd, addend_odd));
  }

  /* Odd count: one vector is left over after the paired loop. */
  if (idx < count) {
    const float16_t *in_last = source + idx * srcStride;
    float16_t *out_last = dest + idx * dststride;
    float16x8_t addend_last = vld1q_f16(in_last);
    float16x8_t acc_last = vld1q_f16(out_last);
    vst1q_f16(out_last, vaddq_f16(acc_last, addend_last));
  }
}
int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
const ConvParameter *conv_param) {
float16x8_t min_v = vdupq_n_f16(-FLT_MAX);
@@ -67,34 +106,11 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
for (int kh = kh_start; kh < kh_end; kh++) {
const float16_t *src_kh_ptr = src_in_ptr + kh * src_kh_stride;
float16_t *dst_kh_ptr = dst_in_ptr + kh * dst_kh_stride;
for (int kw = kw_start; kw < kw_end; kw++) {
const float16_t *src_kw_index = src_kh_ptr + kw * src_kw_stride;
float16_t *dst_kw_index = dst_kh_ptr + kw * dst_kw_stride;
#ifdef ENABLE_ARM64
asm volatile(
"mov x0, %[src_kw_index] \n"
"mov x1, %[dst_kw_index] \n"
"ld1 {v0.8h}, [x0] \n"
"ld1 {v1.8h}, [x1] \n"
"fadd v0.8h, v0.8h, v1.8h \n"
"st1 {v0.8h}, [x1] \n"
:
: [ src_kw_index ] "r"(src_kw_index), [ dst_kw_index ] "r"(dst_kw_index)
: "x0", "x1", "v0", "v1");
#else
for (int i = 0; i < C8NUM; i++) {
dst_kw_index[i] += src_kw_index[i];
}
#endif
} // kw
} // kh
} // iw
} // ih
DeConvPostAddC8WithStride(src_kh_ptr + kw_start * src_kw_stride, dst_kh_ptr + kw_start * dst_kw_stride,
src_kw_stride, dst_kw_stride, kw_end - kw_start);
} // kh
} // iw
} // ih
/* add bias for current oh*ow*C8
* write to output data ptr in nhwc format */