diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index fc08744ccdc..5d4380e4a68 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -45,7 +45,7 @@ class _BatchNorm(Cell):
                  moving_var_init='ones',
                  use_batch_statistics=None,
                  device_num_each_group=1,
-                 input_dims='1d'):
+                 input_dims='2d'):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
             raise ValueError("num_features must be at least 1")
@@ -151,6 +151,8 @@ class _BatchNorm(Cell):
         _shape_check(self.shape(x))
         if self.input_dims == '1d':
             _shape_check_2d(self.shape(x))
+        if self.input_dims == 'both':
+            _shape_check_2d_or_4d(self.shape(x))
         if self.use_batch_statistics is None:
             flag = self.training
         else:
@@ -211,7 +213,13 @@ def _shape_check_2d(input_shape):
 @constexpr
 def _shape_check(in_shape):
     if len(in_shape) != 4:
-        raise ValueError("The input must has 4 dims")
+        raise ValueError("The input must have 4 dims.")
+
+
+@constexpr
+def _shape_check_2d_or_4d(in_shape):
+    if len(in_shape) not in (2, 4):
+        raise ValueError("The input must have 2 dims or 4 dims.")
 
 
 @constexpr
@@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm):
                                               moving_mean_init,
                                               moving_var_init,
                                               use_batch_statistics,
-                                              device_num_each_group)
+                                              device_num_each_group,
+                                              input_dims='both')
         self.group = check_int_positive(device_num_each_group)
         if self.group <= 1:
             raise ValueError("the number of group must be greater than 1.")