From 70ebd9bdbc137cdc8ab212b986bf0d687c948b0d Mon Sep 17 00:00:00 2001
From: huangbo77
Date: Mon, 11 Oct 2021 15:57:07 +0800
Subject: [PATCH] enable BatchNorm grad in eval mode

---
 .../cpu/mkldnn/batch_norm_grad_cpu_kernel.cc | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc
index 36b2242a75e..d7cc2be9467 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc
@@ -48,8 +48,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
   dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({SCALE_SHIFT_NUM, channel});
   auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
-  auto prop_kind = dnnl::prop_kind::forward_training;
-  auto normalization_flags = dnnl::normalization_flags::use_scale_shift;
+  auto is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
+  auto prop_kind = dnnl::prop_kind::forward_inference;
+  auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
+  if (is_train) {
+    prop_kind = dnnl::prop_kind::forward_training;
+    normalization_flags = dnnl::normalization_flags::use_scale_shift;
+  }
   // fused Batch Normalization forward description
   dnnl::batch_normalization_forward::desc desc =
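
The sketch below (not part of the patch) illustrates how the prop_kind / normalization_flags
selection above maps onto a oneDNN forward descriptor: in eval mode the descriptor is built with
forward_inference plus use_global_stats so the stored running statistics are used, while in
training mode forward_training recomputes statistics from the current batch. It assumes oneDNN 2.x
headers (dnnl.hpp); the engine, shapes, and epsilon value are illustrative only and do not come
from the MindSpore kernel.

    // Minimal standalone sketch of the prop_kind / flags choice made in the patch.
    #include <dnnl.hpp>

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);

      // Illustrative NCHW float32 input: batch=2, channels=3, 4x4 spatial.
      dnnl::memory::desc x_desc({2, 3, 4, 4}, dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::nchw);
      const float epsilon = 1e-5f;
      const bool is_train = false;  // eval mode, the case the patch enables

      // Eval mode: inference prop_kind and the stored (global) running statistics.
      auto prop_kind = dnnl::prop_kind::forward_inference;
      auto normalization_flags = dnnl::normalization_flags::use_scale_shift |
                                 dnnl::normalization_flags::use_global_stats;
      if (is_train) {
        // Training mode: statistics are computed from the current batch.
        prop_kind = dnnl::prop_kind::forward_training;
        normalization_flags = dnnl::normalization_flags::use_scale_shift;
      }

      // Forward descriptor / primitive descriptor, as later used as the hint
      // when constructing the batch_normalization_backward primitive.
      dnnl::batch_normalization_forward::desc fwd_desc(prop_kind, x_desc, epsilon,
                                                       normalization_flags);
      dnnl::batch_normalization_forward::primitive_desc fwd_pd(fwd_desc, eng);
      (void)fwd_pd;
      return 0;
    }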