forked from mindspore-Ecosystem/mindspore
enable BatchNorm grad in eval mode
This commit is contained in:
parent
2065ec8fe1
commit
70ebd9bdbc
|
@ -48,8 +48,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
|
||||
dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({SCALE_SHIFT_NUM, channel});
|
||||
auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
|
||||
auto prop_kind = dnnl::prop_kind::forward_training;
|
||||
auto normalization_flags = dnnl::normalization_flags::use_scale_shift;
|
||||
auto is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
|
||||
auto prop_kind = dnnl::prop_kind::forward_inference;
|
||||
auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
|
||||
if (is_train) {
|
||||
prop_kind = dnnl::prop_kind::forward_training;
|
||||
normalization_flags = dnnl::normalization_flags::use_scale_shift;
|
||||
}
|
||||
|
||||
// fused Batch Normalization forward description
|
||||
dnnl::batch_normalization_forward::desc desc =
|
||||
|
|
Loading…
Reference in New Issue