nnacl: process several channel blocks per iteration in the packed InstanceNorm kernels

What the hunks below do:
  * fp16 InstanceNormNC8HW8Fp16: adds a loop that handles C16NUM channels per
    iteration (two packed blocks of C8NUM channels, read at src + index * C8NUM
    and stored with a stride of `channel`), ahead of the existing C8NUM loop.
    Sums and sums of squares are accumulated in float32 and converted to fp16
    before the variance / scale computation.
  * fp32 InstanceNormNC4HW4: adds C16NUM- and C8NUM-channel loops ahead of the
    existing C4NUM loop and drops the now-unused `c4_down` local; all loops are
    bounded by `channel_end - <block width>`.
  * fp32 InstanceNorm / InstanceNormNC4HW4: thread_num_ is an operand of UP_DIV
    and is now checked with NNACL_CHECK_ZERO_RETURN_ERR instead of the
    null-pointer check macro.

Review fixes folded into this revision of the patch:
  * fp16: srcv02 is now the high half of `srcv` and srcv11 the low half of
    `srcv1`. Previously both were taken from the wrong vector, so mean2 / mean3
    (and the matching square sums) described the wrong channels.
  * fp16: the second output block subtracts `mean_1` (the fp16 mean of channels
    c + C8NUM .. c + C16NUM - 1) rather than `mean1`, which is the float32x4_t
    sum accumulator of the first four channels.
  * fp32 InstanceNorm: added the trailing semicolon after
    NNACL_CHECK_ZERO_RETURN_ERR(...), matching the identical call in
    InstanceNormNC4HW4.
  * Corrected the copy-pasted "deno * gamma_data[c]" comments on the
    gammav1 / gammav2 / gammav3 lines.

NOTE(review): the leading whitespace of every hunk line was reconstructed
(2-space indentation, 120-column wrapping); if the context does not match the
target files exactly, apply with whitespace-tolerant options.

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c
index c4329a21577..6589e219e84 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c
@@ -100,7 +100,66 @@ int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const
     const float16_t *src_b = src_data + b * channel * hw_plane;
     float16_t *dst_b = dst_data + b * channel * hw_plane;
     int c = channel_begin;
-    for (; c < c8_down; c += C8NUM) {
+    for (; c <= channel_end - C16NUM; c += C16NUM) {
+      const float16_t *src = src_b + c * hw_plane;
+      const float16_t *src1 = src_b + (c + C8NUM) * hw_plane;
+      float16_t *dst = dst_b + c;
+      float32x4_t mean1 = vdupq_n_f32(0.0f);
+      float32x4_t mean2 = vdupq_n_f32(0.0f);
+      float32x4_t mean3 = vdupq_n_f32(0.0f);
+      float32x4_t mean4 = vdupq_n_f32(0.0f);
+      float32x4_t square_mean1 = vdupq_n_f32(0.0f);
+      float32x4_t square_mean2 = vdupq_n_f32(0.0f);
+      float32x4_t square_mean3 = vdupq_n_f32(0.0f);
+      float32x4_t square_mean4 = vdupq_n_f32(0.0f);
+      for (int index = 0; index < hw_plane; ++index) {
+        float16x8_t srcv = vld1q_f16(src + index * C8NUM);
+        float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM);
+
+        float32x4_t srcv01 = vcvt_f32_f16(vget_low_f16(srcv));
+        float32x4_t srcv02 = vcvt_f32_f16(vget_high_f16(srcv));
+        float32x4_t srcv11 = vcvt_f32_f16(vget_low_f16(srcv1));
+        float32x4_t srcv12 = vcvt_f32_f16(vget_high_f16(srcv1));
+        mean1 = vaddq_f32(mean1, srcv01);
+        mean2 = vaddq_f32(mean2, srcv02);
+        mean3 = vaddq_f32(mean3, srcv11);
+        mean4 = vaddq_f32(mean4, srcv12);
+        square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv01, srcv01));
+        square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv02, srcv02));
+        square_mean3 = vaddq_f32(square_mean3, vmulq_f32(srcv11, srcv11));
+        square_mean4 = vaddq_f32(square_mean4, vmulq_f32(srcv12, srcv12));
+      }
+      float16x8_t mean =
+        vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4)));
+      float16x8_t mean_1 =
+        vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean3, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean4, hw_plane_4)));
+      float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)),
+                                             vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4)));
+      float16x8_t square_mean_1 = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean3, hw_plane_4)),
+                                               vcvt_f16_f32(MS_DIVQ_F32(square_mean4, hw_plane_4)));
+      float16x8_t deno = vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_));
+      float16x8_t deno1 = vaddq_f16(vsubq_f16(square_mean_1, vmulq_f16(mean_1, mean_1)), vdupq_n_f16(param->epsilon_));
+      deno = 1 / MS_SQRTFX8_F16(deno);
+      deno1 = 1 / MS_SQRTFX8_F16(deno1);
+
+      float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno);            // deno * gamma_data[c]
+      float16x8_t gammav1 = vmulq_f16(vld1q_f16(gamma_data + c + C8NUM), deno1);  // deno1 * gamma_data[c + C8NUM]
+      float16x8_t betav = vld1q_f16(beta_data + c);
+      float16x8_t betav1 = vld1q_f16(beta_data + c + C8NUM);
+      for (int index = 0; index < hw_plane; ++index) {
+        float16x8_t srcv = vld1q_f16(src + index * C8NUM);
+        float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM);
+        float16x8_t outv = vsubq_f16(srcv, mean);
+        float16x8_t outv1 = vsubq_f16(srcv1, mean_1);
+        outv = vmulq_f16(outv, gammav);
+        outv1 = vmulq_f16(outv1, gammav1);
+        outv = vaddq_f16(outv, betav);
+        outv1 = vaddq_f16(outv1, betav1);
+        vst1q_f16(dst + index * channel, outv);
+        vst1q_f16(dst + index * channel + C8NUM, outv1);
+      }
+    }
+    for (; c <= channel_end - C8NUM; c += C8NUM) {
       const float16_t *src = src_b + c * hw_plane;
       float16_t *dst = dst_b + c;
       float32x4_t mean1 = vdupq_n_f32(0.0f);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/instance_norm_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/instance_norm_fp32.c
index 250814770e8..cf95e164a8f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/instance_norm_fp32.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/instance_norm_fp32.c
@@ -23,7 +23,7 @@ int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data
                  const InstanceNormParameter *param, size_t task_id) {
   NNACL_CHECK_NULL_RETURN_ERR(src_data);
   NNACL_CHECK_NULL_RETURN_ERR(dst_data);
-  NNACL_CHECK_NULL_RETURN_ERR(param->op_parameter_.thread_num_)
+  NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
   int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
   int channel_begin = (int)(task_id)*channel_step;
   int channel_end = MSMIN(channel_begin + channel_step, param->channel_);
@@ -116,14 +116,13 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm
                        const InstanceNormParameter *param, size_t task_id) {
   NNACL_CHECK_NULL_RETURN_ERR(src_data);
   NNACL_CHECK_NULL_RETURN_ERR(dst_data);
-  NNACL_CHECK_NULL_RETURN_ERR(param->op_parameter_.thread_num_);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
   int channel = param->channel_;
   int hw_plane = param->inner_size_;
   int channel_step = UP_DIV(UP_DIV(channel, C4NUM), param->op_parameter_.thread_num_) * C4NUM;
   int channel_begin = (int)(task_id)*channel_step;
   int channel_end = MSMIN(channel_begin + channel_step, channel);
 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
-  int c4_down = channel_end / C4NUM * C4NUM;
   MS_FLOAT32X4 hw_planev = MS_MOVQ_F32((float)(hw_plane));
 #endif
   for (int b = 0; b < param->batch_; b++) {
@@ -131,7 +130,135 @@ int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamm
     float *dst_b = dst_data + b * channel * hw_plane;
     int c = channel_begin;
 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
-    for (; c < c4_down; c += C4NUM) {
+    for (; c <= channel_end - C16NUM; c += C16NUM) {
+      const float *src = src_b + c * hw_plane;
+      const float *src1 = src_b + (c + C4NUM) * hw_plane;
+      const float *src2 = src_b + (c + C8NUM) * hw_plane;
+      const float *src3 = src_b + (c + C12NUM) * hw_plane;
+      float *dst = dst_b + c;
+      MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 mean3 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean2 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean3 = MS_MOVQ_F32(0.0f);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
+        MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
+        MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM);
+        MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM);
+        MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
+        MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1);
+        MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2);
+        MS_FLOAT32X4 squarev3 = MS_MULQ_F32(srcv3, srcv3);
+        mean = MS_ADDQ_F32(mean, srcv);
+        mean1 = MS_ADDQ_F32(mean1, srcv1);
+        mean2 = MS_ADDQ_F32(mean2, srcv2);
+        mean3 = MS_ADDQ_F32(mean3, srcv3);
+        square_mean = MS_ADDQ_F32(square_mean, squarev);
+        square_mean1 = MS_ADDQ_F32(square_mean1, squarev1);
+        square_mean2 = MS_ADDQ_F32(square_mean2, squarev2);
+        square_mean3 = MS_ADDQ_F32(square_mean3, squarev3);
+      }
+      mean = MS_DIVQ_F32(mean, hw_planev);
+      mean1 = MS_DIVQ_F32(mean1, hw_planev);
+      mean2 = MS_DIVQ_F32(mean2, hw_planev);
+      mean3 = MS_DIVQ_F32(mean3, hw_planev);
+      square_mean = MS_DIVQ_F32(square_mean, hw_planev);
+      square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev);
+      square_mean2 = MS_DIVQ_F32(square_mean2, hw_planev);
+      square_mean3 = MS_DIVQ_F32(square_mean3, hw_planev);
+      MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_));
+      MS_FLOAT32X4 deno1 =
+        MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_));
+      MS_FLOAT32X4 deno2 =
+        MS_ADDQ_F32(MS_SUBQ_F32(square_mean2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_));
+      MS_FLOAT32X4 deno3 =
+        MS_ADDQ_F32(MS_SUBQ_F32(square_mean3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_));
+      deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno));
+      deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1));
+      deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2));
+      deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3));
+
+      MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno);             // deno * gamma_data[c]
+      MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1);   // deno1 * gamma_data[c + C4NUM]
+      MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2);   // deno2 * gamma_data[c + C8NUM]
+      MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3);  // deno3 * gamma_data[c + C12NUM]
+      MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c);
+      MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM);
+      MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM);
+      MS_FLOAT32X4 betav3 = MS_LDQ_F32(beta_data + c + C12NUM);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
+        MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
+        MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM);
+        MS_FLOAT32X4 srcv3 = MS_LDQ_F32(src3 + index * C4NUM);
+        MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean);
+        MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1);
+        MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2);
+        MS_FLOAT32X4 outv3 = MS_SUBQ_F32(srcv3, mean3);
+        outv = MS_MULQ_F32(outv, gammav);
+        outv1 = MS_MULQ_F32(outv1, gammav1);
+        outv2 = MS_MULQ_F32(outv2, gammav2);
+        outv3 = MS_MULQ_F32(outv3, gammav3);
+        outv = MS_ADDQ_F32(outv, betav);
+        outv1 = MS_ADDQ_F32(outv1, betav1);
+        outv2 = MS_ADDQ_F32(outv2, betav2);
+        outv3 = MS_ADDQ_F32(outv3, betav3);
+        MS_STQ_F32(dst + index * channel, outv);
+        MS_STQ_F32(dst + index * channel + C4NUM, outv1);
+        MS_STQ_F32(dst + index * channel + C8NUM, outv2);
+        MS_STQ_F32(dst + index * channel + C12NUM, outv3);
+      }
+    }
+    for (; c <= channel_end - C8NUM; c += C8NUM) {
+      const float *src = src_b + c * hw_plane;
+      const float *src1 = src_b + (c + C4NUM) * hw_plane;
+      float *dst = dst_b + c;
+      MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 mean1 = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean = MS_MOVQ_F32(0.0f);
+      MS_FLOAT32X4 square_mean1 = MS_MOVQ_F32(0.0f);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
+        MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
+        MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
+        MS_FLOAT32X4 squarev1 = MS_MULQ_F32(srcv1, srcv1);
+        mean = MS_ADDQ_F32(mean, srcv);
+        mean1 = MS_ADDQ_F32(mean1, srcv1);
+        square_mean = MS_ADDQ_F32(square_mean, squarev);
+        square_mean1 = MS_ADDQ_F32(square_mean1, squarev1);
+      }
+      mean = MS_DIVQ_F32(mean, hw_planev);
+      mean1 = MS_DIVQ_F32(mean1, hw_planev);
+      square_mean = MS_DIVQ_F32(square_mean, hw_planev);
+      square_mean1 = MS_DIVQ_F32(square_mean1, hw_planev);
+      MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(square_mean, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_));
+      MS_FLOAT32X4 deno1 =
+        MS_ADDQ_F32(MS_SUBQ_F32(square_mean1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_));
+      deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno));
+      deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1));
+
+      MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno);            // deno * gamma_data[c]
+      MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1);  // deno1 * gamma_data[c + C4NUM]
+      MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c);
+      MS_FLOAT32X4 betav1 = MS_LDQ_F32(beta_data + c + C4NUM);
+      for (int index = 0; index < hw_plane; ++index) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM);
+        MS_FLOAT32X4 srcv1 = MS_LDQ_F32(src1 + index * C4NUM);
+        MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean);
+        MS_FLOAT32X4 outv1 = MS_SUBQ_F32(srcv1, mean1);
+        outv = MS_MULQ_F32(outv, gammav);
+        outv1 = MS_MULQ_F32(outv1, gammav1);
+        outv = MS_ADDQ_F32(outv, betav);
+        outv1 = MS_ADDQ_F32(outv1, betav1);
+        MS_STQ_F32(dst + index * channel, outv);
+        MS_STQ_F32(dst + index * channel + C4NUM, outv1);
+      }
+    }
+    for (; c <= channel_end - C4NUM; c += C4NUM) {
       const float *src = src_b + c * hw_plane;
       float *dst = dst_b + c;
       MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f);