!30809 [MSLITE] winograd build bug

Merge pull request !30809 from ling/sr
This commit is contained in:
i-robot 2022-03-04 01:20:44 +00:00 committed by Gitee
commit 9fe8945d2e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 39 additions and 32 deletions

View File

@ -637,46 +637,53 @@ void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, i
}
}
/*
 * Combines eight input vectors s[0..7] into eight output vectors m[0..7]
 * using fixed scalar coefficients. Every operation is lane-wise on a
 * float16x8_t.
 *
 *   m[0]      = 0.5625*s0 - 3.0625*s2 + 3.5*s4 - s6
 *   m[1]/m[2] = (even +/- odd) -/+ 1.625*s3 + s6
 *   m[3]/m[4] = (even +/- odd) -/+ 2.5*s3   + s6
 *   m[5]/m[6] = (even +/- odd) -/+ 1.875*s3 + s6
 *   m[7]      = -0.5625*s1 + 3.0625*s3 - 3.5*s5 + s7
 *
 * where "odd" is a weighted sum of s1 and s5 and "even" is a weighted
 * difference of s2 and s4; the weights change for each output pair.
 *
 * s: eight input vectors (only read).
 * m: eight output vectors (all eight are written).
 */
void InputTransform8x8StepFp16_uint(float16x8_t *s, float16x8_t *m) {
  float16x8_t acc;   /* running partial result for the output being built */
  float16x8_t odd;   /* weighted s[1] + s[5] term shared by an output pair */
  float16x8_t even;  /* weighted s[2] - s[4] term shared by an output pair */

  /* m[0]: even-indexed inputs only. */
  acc = vsubq_f16(vmulq_n_f16(s[0], 0.5625), vmulq_n_f16(s[2], 3.0625));
  acc = vaddq_f16(acc, vmulq_n_f16(s[4], 3.5));
  m[0] = vsubq_f16(acc, s[6]);

  /* m[1], m[2]: s[3] scaled by 1.625. */
  odd = vaddq_f16(vmulq_n_f16(s[1], 1.125), vmulq_n_f16(s[5], 0.5));
  even = vsubq_f16(vmulq_n_f16(s[2], 2.25), vmulq_n_f16(s[4], 3.25));
  acc = vsubq_f16(vaddq_f16(odd, even), vmulq_n_f16(s[3], 1.625));
  m[1] = vaddq_f16(acc, s[6]);
  acc = vaddq_f16(vsubq_f16(even, odd), vmulq_n_f16(s[3], 1.625));
  m[2] = vaddq_f16(acc, s[6]);

  /* m[3], m[4]: s[3] scaled by 2.5. */
  odd = vaddq_f16(vmulq_n_f16(s[1], 0.5625), s[5]);
  even = vsubq_f16(vmulq_n_f16(s[2], 0.5625), vmulq_n_f16(s[4], 2.5));
  acc = vsubq_f16(vaddq_f16(odd, even), vmulq_n_f16(s[3], 2.5));
  m[3] = vaddq_f16(acc, s[6]);
  acc = vaddq_f16(vsubq_f16(even, odd), vmulq_n_f16(s[3], 2.5));
  m[4] = vaddq_f16(acc, s[6]);

  /* m[5], m[6]: s[3] scaled by 1.875. */
  odd = vaddq_f16(vmulq_n_f16(s[1], 0.375), vmulq_n_f16(s[5], 1.5));
  even = vsubq_f16(vmulq_n_f16(s[2], 0.25), vmulq_n_f16(s[4], 1.25));
  acc = vsubq_f16(vaddq_f16(odd, even), vmulq_n_f16(s[3], 1.875));
  m[5] = vaddq_f16(acc, s[6]);
  acc = vaddq_f16(vsubq_f16(even, odd), vmulq_n_f16(s[3], 1.875));
  m[6] = vaddq_f16(acc, s[6]);

  /* m[7]: odd-indexed inputs only. */
  acc = vaddq_f16(vmulq_n_f16(s[1], -0.5625), vmulq_n_f16(s[3], 3.0625));
  acc = vsubq_f16(acc, vmulq_n_f16(s[5], 3.5));
  m[7] = vaddq_f16(acc, s[7]);
}
/*
 * Runs the 8-point transform step on eight consecutive groups of vectors.
 *
 * For each group l in [0, 8):
 *   - loads eight float16x8_t vectors from
 *     src_data + l * 8 * src_step + i * src_step, for i in [0, 8);
 *   - combines them with InputTransform8x8StepFp16_uint();
 *   - stores the eight results to
 *     dst_data + l * dst_row_step + i * dst_step, for i in [0, 8).
 *
 * src_data:     input buffer (only read).
 * dst_data:     output buffer (written).
 * src_step:     element stride between consecutive input vectors.
 * dst_step:     element stride between consecutive output vectors in a group.
 * dst_row_step: element stride between the output of consecutive groups.
 *
 * All loads for a group are issued before any store for that group.
 */
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
                               int dst_row_step) {
  for (int l = 0; l < 8; ++l) {
    const float16_t *src_ptr = src_data + l * 8 * src_step;
    float16_t *dst_ptr = dst_data + l * dst_row_step;
    float16x8_t s[8];
    float16x8_t m[8];

    for (int i = 0; i < 8; ++i) {
      s[i] = vld1q_f16(src_ptr + i * src_step);
    }

    /* The shared helper performs the arithmetic once; no inline duplicate. */
    InputTransform8x8StepFp16_uint(s, m);

    for (int i = 0; i < 8; ++i) {
      vst1q_f16(dst_ptr + i * dst_step, m[i]);
    }
  }
}