diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/batch_norm.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/batch_norm.c index 122e7116f4f..ebc2133beaf 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32_grad/batch_norm.c @@ -53,11 +53,14 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float for (int c = 0; c < ch; c++) { int ix = i * ch + c; dbias[c] += yt[ix]; - // dscale - float x_hat = (in[ix] - mean[c]) * invar[c]; + // x_hat should also be multiplied by invar[c]; that factor is hoisted out of this loop and applied once per channel below. + float x_hat = in[ix] - mean[c]; dscale[c] += (yt[ix] * x_hat); } } + for (int c = 0; c < ch; c++) { + dscale[c] *= invar[c]; + } backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train); } @@ -73,11 +76,14 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float for (int c = 0; c < ch; c++) { int ix = i * ch + c; dbias[c] += yt[ix]; - // dscale - float x_hat = (in[ix] - mean[c]) * invar[c]; + // x_hat should also be multiplied by invar[c]; that factor is hoisted out of this loop and applied once per channel below.
+ float x_hat = in[ix] - mean[c]; dscale[c] += (yt[ix] * x_hat); } } + for (int c = 0; c < ch; c++) { + dscale[c] *= invar[c]; + } } #ifdef _MSC_VER diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index ab840929331..e84ef8d9075 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -70,6 +70,7 @@ int BNGradCPUKernel::DoExecute(int task_id) { input_var = in_tensors_.at(5); } auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); + CHECK_NULL_RETURN(bn_param); int stage = stage_; int thread_num = thread_num_; float *save_mean = reinterpret_cast<float *>(input_mean->MutableData()); @@ -126,7 +127,8 @@ int BNGradCPUKernel::DoExecute(int task_id) { } } if (thread_num == 1) { - backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, (IsTrain())); + backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, + (IsTrain() && bn_param->is_training_)); } break; } @@ -136,7 +138,8 @@ int BNGradCPUKernel::DoExecute(int task_id) { } case 2: { backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, dscale, dbias, - scale, count, total, channels, dx + task_id * stride * channels, (IsTrain())); + scale, count, total, channels, dx + task_id * stride * channels, + (IsTrain() && bn_param->is_training_)); break; } default: diff --git a/mindspore/lite/tools/converter/import/primitive_adjust.cc b/mindspore/lite/tools/converter/import/primitive_adjust.cc index e6b2f5f64bc..ba7deae719a 100644 --- a/mindspore/lite/tools/converter/import/primitive_adjust.cc +++ b/mindspore/lite/tools/converter/import/primitive_adjust.cc @@ -583,6 +583,26 @@ int MoveAttrBatchNorm(const CNodePtr &cnode) { value_node->set_value(dst_prim_c); return lite::RET_OK; } + +int MoveAttrMapBatchNormGrad(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + auto value_node 
= cnode->input(0)->cast<ValueNodePtr>(); + MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value_node is nullptr"); + auto src_prim = GetValueNode<PrimitivePtr>(value_node); + MS_CHECK_TRUE_MSG(src_prim != nullptr, RET_NULL_PTR, "src_prim is nullptr"); + auto dst_prim = std::make_shared<ops::BatchNormGrad>(); + MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_NULL_PTR, "dst_prim is nullptr."); + auto dst_prim_c = dst_prim->GetPrim(); + MS_CHECK_TRUE_MSG(dst_prim_c != nullptr, RET_NULL_PTR, "dst_prim_c is nullptr."); + dst_prim_c->SetAttrs(src_prim->attrs()); + auto is_training_attr = src_prim->GetAttr(ops::kIsTraining); + if (is_training_attr == nullptr) { + MS_LOG(INFO) << "no \"is_training\" attr found in BatchNormGrad, will set it to true by default."; + dst_prim->set_is_training(true); + } + value_node->set_value(dst_prim_c); + return lite::RET_OK; +} } // namespace bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) { @@ -665,8 +685,8 @@ REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation) REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad) REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon) -REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon) -REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapCommon) +REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapBatchNormGrad) +REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapBatchNormGrad) REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation) REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad) REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)