forked from mindspore-Ecosystem/mindspore
NC4HW4 instance
This commit is contained in:
parent
ea3cd6d262
commit
291e5cbe5e
|
@ -100,7 +100,66 @@ int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const
|
|||
const float16_t *src_b = src_data + b * channel * hw_plane;
|
||||
float16_t *dst_b = dst_data + b * channel * hw_plane;
|
||||
int c = channel_begin;
|
||||
for (; c < c8_down; c += C8NUM) {
|
||||
for (; c <= channel_end - C16NUM; c += C16NUM) {
|
||||
const float16_t *src = src_b + c * hw_plane;
|
||||
const float16_t *src1 = src_b + (c + C8NUM) * hw_plane;
|
||||
float16_t *dst = dst_b + c;
|
||||
float32x4_t mean1 = vdupq_n_f32(0.0f);
|
||||
float32x4_t mean2 = vdupq_n_f32(0.0f);
|
||||
float32x4_t mean3 = vdupq_n_f32(0.0f);
|
||||
float32x4_t mean4 = vdupq_n_f32(0.0f);
|
||||
float32x4_t square_mean1 = vdupq_n_f32(0.0f);
|
||||
float32x4_t square_mean2 = vdupq_n_f32(0.0f);
|
||||
float32x4_t square_mean3 = vdupq_n_f32(0.0f);
|
||||
float32x4_t square_mean4 = vdupq_n_f32(0.0f);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
float16x8_t srcv = vld1q_f16(src + index * C8NUM);
|
||||
float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM);
|
||||
|
||||
float32x4_t srcv01 = vcvt_f32_f16(vget_low_f16(srcv));
|
||||
float32x4_t srcv02 = vcvt_f32_f16(vget_high_f16(srcv1));
|
||||
float32x4_t srcv11 = vcvt_f32_f16(vget_low_f16(srcv));
|
||||
float32x4_t srcv12 = vcvt_f32_f16(vget_high_f16(srcv1));
|
||||
mean1 = vaddq_f32(mean1, srcv01);
|
||||
mean2 = vaddq_f32(mean2, srcv02);
|
||||
mean3 = vaddq_f32(mean3, srcv11);
|
||||
mean4 = vaddq_f32(mean4, srcv12);
|
||||
square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv01, srcv01));
|
||||
square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv02, srcv02));
|
||||
square_mean3 = vaddq_f32(square_mean3, vmulq_f32(srcv11, srcv11));
|
||||
square_mean4 = vaddq_f32(square_mean4, vmulq_f32(srcv12, srcv12));
|
||||
}
|
||||
float16x8_t mean =
|
||||
vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4)));
|
||||
float16x8_t mean_1 =
|
||||
vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean3, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean4, hw_plane_4)));
|
||||
float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)),
|
||||
vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4)));
|
||||
float16x8_t square_mean_1 = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean3, hw_plane_4)),
|
||||
vcvt_f16_f32(MS_DIVQ_F32(square_mean4, hw_plane_4)));
|
||||
float16x8_t deno = vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_));
|
||||
float16x8_t deno1 = vaddq_f16(vsubq_f16(square_mean_1, vmulq_f16(mean_1, mean_1)), vdupq_n_f16(param->epsilon_));
|
||||
deno = 1 / MS_SQRTFX8_F16(deno);
|
||||
deno1 = 1 / MS_SQRTFX8_F16(deno1);
|
||||
|
||||
float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c]
|
||||
float16x8_t gammav1 = vmulq_f16(vld1q_f16(gamma_data + c + C8NUM), deno1); // deno * gamma_data[c]
|
||||
float16x8_t betav = vld1q_f16(beta_data + c);
|
||||
float16x8_t betav1 = vld1q_f16(beta_data + c + C8NUM);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
float16x8_t srcv = vld1q_f16(src + index * C8NUM);
|
||||
float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM);
|
||||
float16x8_t outv = vsubq_f16(srcv, mean);
|
||||
float16x8_t outv1 = vsubq_f16(srcv1, mean1);
|
||||
outv = vmulq_f16(outv, gammav);
|
||||
outv1 = vmulq_f16(outv1, gammav1);
|
||||
outv = vaddq_f16(outv, betav);
|
||||
outv1 = vaddq_f16(outv1, betav1);
|
||||
vst1q_f16(dst + index * channel, outv);
|
||||
vst1q_f16(dst + index * channel + C8NUM, outv1);
|
||||
}
|
||||
}
|
||||
for (; c <= channel_end - C8NUM; c += C8NUM) {
|
||||
const float16_t *src = src_b + c * hw_plane;
|
||||
float16_t *dst = dst_b + c;
|
||||
float32x4_t mean1 = vdupq_n_f32(0.0f);
|
||||
|
|
|
@ -23,7 +23,7 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data
|
|||
const InstanceNormParameter *param, size_t task_id) {
|
||||
NNACL_CHECK_NULL_RETURN_ERR(src_data);
|
||||
NNACL_CHECK_NULL_RETURN_ERR(dst_data);
|
||||
NNACL_CHECK_NULL_RETURN_ERR(param->op_parameter_.thread_num_)
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_)
|
||||
int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
|
||||
int channel_begin = (int)(task_id)*channel_step;
|
||||
int channel_end = MSMIN(channel_begin + channel_step, param->channel_);
|
||||
|
@ -116,14 +116,13 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm
|
|||
const InstanceNormParameter *param, size_t task_id) {
|
||||
NNACL_CHECK_NULL_RETURN_ERR(src_data);
|
||||
NNACL_CHECK_NULL_RETURN_ERR(dst_data);
|
||||
NNACL_CHECK_NULL_RETURN_ERR(param->op_parameter_.thread_num_);
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
|
||||
int channel = param->channel_;
|
||||
int hw_plane = param->inner_size_;
|
||||
int channel_step = UP_DIV(UP_DIV(channel, C4NUM), param->op_parameter_.thread_num_) * C4NUM;
|
||||
int channel_begin = (int)(task_id)*channel_step;
|
||||
int channel_end = MSMIN(channel_begin + channel_step, channel);
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
|
||||
int c4_down = channel_end / C4NUM * C4NUM;
|
||||
MS_FLOAT32X4 hw_planev = MS_MOVQ_F32((float)(hw_plane));
|
||||
#endif
|
||||
for (int b = 0; b < param->batch_; b++) {
|
||||
|
@ -131,7 +130,135 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm
|
|||
float *dst_b = dst_data + b * channel * hw_plane;
|
||||
int c = channel_begin;
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
for (; c < c4_down; c += C4NUM) {
|
||||
for (; c <= channel_end - C16NUM; c += C16NUM) {
|
||||
const float *src = src_b + c * hw_plane;
|
||||
const float *src1 = src_b + (c + C4NUM) * hw_plane;
|
||||
const float *src2 = src_b + (c + C8NUM) * hw_plane;
|
||||
const float *src3 = src_b + (c + C12NUM) * hw_plane;
|
||||
float *dst = dst_b + c;
|
||||
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 mean3 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean2 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean3 = MS_MOVQ_F32(0.0f);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM);
|
||||
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
|
||||
MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1);
|
||||
MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2);
|
||||
MS_FLOAT32X4 squarev3 = MS_MULQ_F32(srcv3, srcv3);
|
||||
mean = MS_ADDQ_F32(mean, srcv);
|
||||
mean1 = MS_ADDQ_F32(mean1, srcv1);
|
||||
mean2 = MS_ADDQ_F32(mean2, srcv2);
|
||||
mean3 = MS_ADDQ_F32(mean3, srcv3);
|
||||
square_mean = MS_ADDQ_F32(square_mean, squarev);
|
||||
square_mean1 = MS_ADDQ_F32(square_mean1, squarev1);
|
||||
square_mean2 = MS_ADDQ_F32(square_mean2, squarev2);
|
||||
square_mean3 = MS_ADDQ_F32(square_mean3, squarev3);
|
||||
}
|
||||
mean = MS_DIVQ_F32(mean, hw_planev);
|
||||
mean1 = MS_DIVQ_F32(mean1, hw_planev);
|
||||
mean2 = MS_DIVQ_F32(mean2, hw_planev);
|
||||
mean3 = MS_DIVQ_F32(mean3, hw_planev);
|
||||
square_mean = MS_DIVQ_F32(square_mean, hw_planev);
|
||||
square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev);
|
||||
square_mean2 = MS_DIVQ_F32(square_mean2, hw_planev);
|
||||
square_mean3 = MS_DIVQ_F32(square_mean3, hw_planev);
|
||||
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_));
|
||||
MS_FLOAT32X4 deno1 =
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_));
|
||||
MS_FLOAT32X4 deno2 =
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(square_mean2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_));
|
||||
MS_FLOAT32X4 deno3 =
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(square_mean3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_));
|
||||
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno));
|
||||
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1));
|
||||
deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2));
|
||||
deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3));
|
||||
|
||||
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c);
|
||||
MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM);
|
||||
MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM);
|
||||
MS_FLOAT32X4 betav3 = MS_LDQ_F32(beta_data + c + C12NUM);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM);
|
||||
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean);
|
||||
MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1);
|
||||
MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2);
|
||||
MS_FLOAT32X4 outv3 = MS_SUBQ_F32(srcv3, mean3);
|
||||
outv = MS_MULQ_F32(outv, gammav);
|
||||
outv1 = MS_MULQ_F32(outv1, gammav1);
|
||||
outv2 = MS_MULQ_F32(outv2, gammav2);
|
||||
outv3 = MS_MULQ_F32(outv3, gammav3);
|
||||
outv = MS_ADDQ_F32(outv, betav);
|
||||
outv1 = MS_ADDQ_F32(outv1, betav1);
|
||||
outv2 = MS_ADDQ_F32(outv2, betav2);
|
||||
outv3 = MS_ADDQ_F32(outv3, betav3);
|
||||
MS_STQ_F32(dst + index * channel, outv);
|
||||
MS_STQ_F32(dst + index * channel + C4NUM, outv1);
|
||||
MS_STQ_F32(dst + index * channel + C8NUM, outv2);
|
||||
MS_STQ_F32(dst + index * channel + C12NUM, outv3);
|
||||
}
|
||||
}
|
||||
for (; c <= channel_end - C8NUM; c += C8NUM) {
|
||||
const float *src = src_b + c * hw_plane;
|
||||
const float *src1 = src_b + (c + C4NUM) * hw_plane;
|
||||
float *dst = dst_b + c;
|
||||
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f);
|
||||
MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
|
||||
MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
|
||||
MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1);
|
||||
mean = MS_ADDQ_F32(mean, srcv);
|
||||
mean1 = MS_ADDQ_F32(mean1, srcv1);
|
||||
square_mean = MS_ADDQ_F32(square_mean, squarev);
|
||||
square_mean1 = MS_ADDQ_F32(square_mean1, squarev1);
|
||||
}
|
||||
mean = MS_DIVQ_F32(mean, hw_planev);
|
||||
mean1 = MS_DIVQ_F32(mean1, hw_planev);
|
||||
square_mean = MS_DIVQ_F32(square_mean, hw_planev);
|
||||
square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev);
|
||||
MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_));
|
||||
MS_FLOAT32X4 deno1 =
|
||||
MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_));
|
||||
deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno));
|
||||
deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1));
|
||||
|
||||
MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c]
|
||||
MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c);
|
||||
MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM);
|
||||
for (int index = 0; index < hw_plane; ++index) {
|
||||
MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
|
||||
MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
|
||||
MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean);
|
||||
MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1);
|
||||
outv = MS_MULQ_F32(outv, gammav);
|
||||
outv1 = MS_MULQ_F32(outv1, gammav1);
|
||||
outv = MS_ADDQ_F32(outv, betav);
|
||||
outv1 = MS_ADDQ_F32(outv1, betav1);
|
||||
MS_STQ_F32(dst + index * channel, outv);
|
||||
MS_STQ_F32(dst + index * channel + C4NUM, outv1);
|
||||
}
|
||||
}
|
||||
for (; c <= channel_end - C4NUM; c += C4NUM) {
|
||||
const float *src = src_b + c * hw_plane;
|
||||
float *dst = dst_b + c;
|
||||
MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);
|
||||
|
|
Loading…
Reference in New Issue