!24671 Enable CPU BatchNorm grad in eval mode

Merge pull request !24671 from huangbo/master_codedex
This commit is contained in:
i-robot 2021-10-12 02:09:00 +00:00 committed by Gitee
commit 4caa16dee8
1 changed files with 7 additions and 2 deletions

View File

@ -56,8 +56,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({SCALE_SHIFT_NUM, channel});
auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
auto prop_kind = dnnl::prop_kind::forward_training;
auto normalization_flags = dnnl::normalization_flags::use_scale_shift;
auto is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
auto prop_kind = dnnl::prop_kind::forward_inference;
auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
if (is_train) {
prop_kind = dnnl::prop_kind::forward_training;
normalization_flags = dnnl::normalization_flags::use_scale_shift;
}
// fused Batch Normalization forward description
dnnl::batch_normalization_forward::desc desc =