forked from mindspore-Ecosystem/mindspore
!30809 [MSLITE] winograd build bug
Merge pull request !30809 from ling/sr
This commit is contained in:
commit
9fe8945d2e
|
@ -637,46 +637,53 @@ void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, i
|
|||
}
|
||||
}
|
||||
|
||||
/*
 * Computes eight linear combinations of the eight input vectors s[0..7] and
 * writes them to m[0..7]. Every operation is lane-wise on float16x8_t values.
 *
 *   m[0] = 0.5625*s0 - 3.0625*s2 + 3.5*s4 - s6
 *   m[1] / m[2] = even +/- (odd - 1.625*s3) + s6, with
 *                 odd = 1.125*s1 + 0.5*s5, even = 2.25*s2 - 3.25*s4
 *   m[3] / m[4] = same shape, odd = 0.5625*s1 + s5, even = 0.5625*s2 - 2.5*s4, s3 weight 2.5
 *   m[5] / m[6] = same shape, odd = 0.375*s1 + 1.5*s5, even = 0.25*s2 - 1.25*s4, s3 weight 1.875
 *   m[7] = -0.5625*s1 + 3.0625*s3 - 3.5*s5 + s7
 *
 * NOTE(review): the coefficients look like one pass of an 8x8 Winograd input
 * transform -- confirm against the transform matrix used by the caller.
 *
 * s: array of at least 8 input vectors (read only).
 * m: array of at least 8 output vectors (every element is overwritten).
 */
void InputTransform8x8StepFp16_uint(float16x8_t *s, float16x8_t *m) {
  /* Output 0 uses the even-indexed inputs only. */
  float16x8_t acc = vsubq_f16(vmulq_n_f16(s[0], 0.5625), vmulq_n_f16(s[2], 3.0625));
  acc = vaddq_f16(acc, vmulq_n_f16(s[4], 3.5));
  m[0] = vsubq_f16(acc, s[6]);

  /* Outputs 1 and 2 share the same odd/even partial sums. */
  float16x8_t odd = vaddq_f16(vmulq_n_f16(s[1], 1.125), vmulq_n_f16(s[5], 0.5));
  float16x8_t even = vsubq_f16(vmulq_n_f16(s[2], 2.25), vmulq_n_f16(s[4], 3.25));
  acc = vaddq_f16(odd, even);
  acc = vsubq_f16(acc, vmulq_n_f16(s[3], 1.625));
  m[1] = vaddq_f16(acc, s[6]);
  acc = vsubq_f16(even, odd);
  acc = vaddq_f16(acc, vmulq_n_f16(s[3], 1.625));
  m[2] = vaddq_f16(acc, s[6]);

  /* Outputs 3 and 4. */
  odd = vaddq_f16(vmulq_n_f16(s[1], 0.5625), s[5]);
  even = vsubq_f16(vmulq_n_f16(s[2], 0.5625), vmulq_n_f16(s[4], 2.5));
  acc = vaddq_f16(odd, even);
  acc = vsubq_f16(acc, vmulq_n_f16(s[3], 2.5));
  m[3] = vaddq_f16(acc, s[6]);
  acc = vsubq_f16(even, odd);
  acc = vaddq_f16(acc, vmulq_n_f16(s[3], 2.5));
  m[4] = vaddq_f16(acc, s[6]);

  /* Outputs 5 and 6. */
  odd = vaddq_f16(vmulq_n_f16(s[1], 0.375), vmulq_n_f16(s[5], 1.5));
  even = vsubq_f16(vmulq_n_f16(s[2], 0.25), vmulq_n_f16(s[4], 1.25));
  acc = vaddq_f16(odd, even);
  acc = vsubq_f16(acc, vmulq_n_f16(s[3], 1.875));
  m[5] = vaddq_f16(acc, s[6]);
  acc = vsubq_f16(even, odd);
  acc = vaddq_f16(acc, vmulq_n_f16(s[3], 1.875));
  m[6] = vaddq_f16(acc, s[6]);

  /* Output 7 uses the odd-indexed inputs only. */
  acc = vaddq_f16(vmulq_n_f16(s[1], -0.5625), vmulq_n_f16(s[3], 3.0625));
  acc = vsubq_f16(acc, vmulq_n_f16(s[5], 3.5));
  m[7] = vaddq_f16(acc, s[7]);
}
|
||||
|
||||
/*
 * Runs InputTransform8x8StepFp16_uint over eight groups of eight vectors.
 *
 * For each l in [0, 8):
 *   - loads eight float16x8_t vectors starting at src_data + l * 8 * src_step,
 *     consecutive vectors being src_step elements apart;
 *   - transforms them with InputTransform8x8StepFp16_uint;
 *   - stores the eight results starting at dst_data + l * dst_row_step,
 *     consecutive vectors being dst_step elements apart.
 *
 * The transform is computed exactly once per group. An earlier version also
 * evaluated the same formulas inline and stored the (identical) results before
 * calling the helper, which doubled the loads, arithmetic and stores.
 *
 * src_data:     source buffer; must hold at least 63 * src_step + 8 readable elements.
 * dst_data:     destination buffer; must hold at least
 *               7 * dst_row_step + 7 * dst_step + 8 writable elements.
 * src_step:     distance, in float16_t elements, between consecutive source vectors.
 * dst_step:     distance, in float16_t elements, between consecutive stored vectors.
 * dst_row_step: distance, in float16_t elements, between the outputs of consecutive groups.
 */
void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
                               int dst_row_step) {
  for (int l = 0; l < 8; ++l) {
    const float16_t *src_ptr = src_data + l * 8 * src_step;
    float16_t *dst_ptr = dst_data + l * dst_row_step;

    float16x8_t s[8];
    float16x8_t m[8];

    for (int i = 0; i < 8; ++i) {
      s[i] = vld1q_f16(src_ptr + i * src_step);
    }

    InputTransform8x8StepFp16_uint(s, m);

    for (int i = 0; i < 8; ++i) {
      vst1q_f16(dst_ptr + i * dst_step, m[i]);
    }
  }
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue