forked from mindspore-Ecosystem/mindspore
!29009 [MS][LITE]Optimize deconv fp16 kernel
Merge pull request !29009 from 张学同/debug_deconv
This commit is contained in:
commit
93732478d1
|
@ -16,6 +16,45 @@
|
|||
|
||||
#include "nnacl/fp16/deconv_fp16.h"
|
||||
#include <float.h>
|
||||
|
||||
void DeConvPostAddC8WithStride(const float16_t *source, float16_t *dest, size_t srcStride, size_t dststride,
|
||||
size_t count) {
|
||||
if (count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float16_t *src_ptr = source;
|
||||
float16_t *dst_ptr = dest;
|
||||
float16x8_t src1 = vld1q_f16(src_ptr);
|
||||
float16x8_t dst1 = vld1q_f16(dst_ptr);
|
||||
float16x8_t src2;
|
||||
float16x8_t dst2;
|
||||
size_t i = 1;
|
||||
while (i < count - 1) {
|
||||
dst1 = vaddq_f16(dst1, src1);
|
||||
vst1q_f16(dst_ptr, dst1);
|
||||
|
||||
src2 = vld1q_f16(src_ptr + srcStride);
|
||||
dst2 = vld1q_f16(dst_ptr + dststride);
|
||||
dst2 = vaddq_f16(dst2, src2);
|
||||
vst1q_f16(dst_ptr + dststride, dst2);
|
||||
i = i + 2;
|
||||
src1 = vld1q_f16(src_ptr + srcStride + srcStride);
|
||||
dst1 = vld1q_f16(dst_ptr + dststride + dststride);
|
||||
|
||||
src_ptr = src_ptr + srcStride + srcStride;
|
||||
dst_ptr = dst_ptr + dststride + dststride;
|
||||
}
|
||||
dst1 = vaddq_f16(dst1, src1);
|
||||
vst1q_f16(dst_ptr, dst1);
|
||||
if (i < count) {
|
||||
src2 = vld1q_f16(src_ptr + srcStride);
|
||||
dst2 = vld1q_f16(dst_ptr + dststride);
|
||||
dst2 = vaddq_f16(dst2, src2);
|
||||
vst1q_f16(dst_ptr + dststride, dst2);
|
||||
}
|
||||
}
|
||||
|
||||
int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
|
||||
const ConvParameter *conv_param) {
|
||||
float16x8_t min_v = vdupq_n_f16(-FLT_MAX);
|
||||
|
@ -67,34 +106,11 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
|
|||
for (int kh = kh_start; kh < kh_end; kh++) {
|
||||
const float16_t *src_kh_ptr = src_in_ptr + kh * src_kh_stride;
|
||||
float16_t *dst_kh_ptr = dst_in_ptr + kh * dst_kh_stride;
|
||||
|
||||
for (int kw = kw_start; kw < kw_end; kw++) {
|
||||
const float16_t *src_kw_index = src_kh_ptr + kw * src_kw_stride;
|
||||
float16_t *dst_kw_index = dst_kh_ptr + kw * dst_kw_stride;
|
||||
#ifdef ENABLE_ARM64
|
||||
asm volatile(
|
||||
"mov x0, %[src_kw_index] \n"
|
||||
"mov x1, %[dst_kw_index] \n"
|
||||
|
||||
"ld1 {v0.8h}, [x0] \n"
|
||||
"ld1 {v1.8h}, [x1] \n"
|
||||
|
||||
"fadd v0.8h, v0.8h, v1.8h \n"
|
||||
|
||||
"st1 {v0.8h}, [x1] \n"
|
||||
|
||||
:
|
||||
: [ src_kw_index ] "r"(src_kw_index), [ dst_kw_index ] "r"(dst_kw_index)
|
||||
: "x0", "x1", "v0", "v1");
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
dst_kw_index[i] += src_kw_index[i];
|
||||
}
|
||||
#endif
|
||||
} // kw
|
||||
} // kh
|
||||
} // iw
|
||||
} // ih
|
||||
DeConvPostAddC8WithStride(src_kh_ptr + kw_start * src_kw_stride, dst_kh_ptr + kw_start * dst_kw_stride,
|
||||
src_kw_stride, dst_kw_stride, kw_end - kw_start);
|
||||
} // kh
|
||||
} // iw
|
||||
} // ih
|
||||
|
||||
/* add bias for current oh*ow*C8
|
||||
* write to output data ptr in nhwc format */
|
||||
|
|
Loading…
Reference in New Issue