enable BatchNorm grad in eval mode

This commit is contained in:
huangbo77 2021-10-11 15:57:07 +08:00
parent 2065ec8fe1
commit 70ebd9bdbc
1 changed files with 7 additions and 2 deletions

View File

@ -48,8 +48,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({SCALE_SHIFT_NUM, channel});
auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
auto prop_kind = dnnl::prop_kind::forward_training;
auto normalization_flags = dnnl::normalization_flags::use_scale_shift;
auto is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
auto prop_kind = dnnl::prop_kind::forward_inference;
auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
if (is_train) {
prop_kind = dnnl::prop_kind::forward_training;
normalization_flags = dnnl::normalization_flags::use_scale_shift;
}
// fused Batch Normalization forward description
dnnl::batch_normalization_forward::desc desc =