From 4c98a00830a60d82ad4fcc9bf719b967bcb3c7e3 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Mon, 20 Jul 2020 11:28:27 +0800 Subject: [PATCH] Fix input verification for input of GlobalbatchNorm. --- mindspore/nn/layer/normalization.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index fc08744ccdc..5d4380e4a68 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -45,7 +45,7 @@ class _BatchNorm(Cell): moving_var_init='ones', use_batch_statistics=None, device_num_each_group=1, - input_dims='1d'): + input_dims='2d'): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") @@ -151,6 +151,8 @@ class _BatchNorm(Cell): _shape_check(self.shape(x)) if self.input_dims == '1d': _shape_check_2d(self.shape(x)) + if self.input_dims == 'both': + _shape_check_2d_or_4d(self.shape(x)) if self.use_batch_statistics is None: flag = self.training else: @@ -211,7 +213,13 @@ def _shape_check_2d(input_shape): @constexpr def _shape_check(in_shape): if len(in_shape) != 4: - raise ValueError("The input must has 4 dims") + raise ValueError("The input must has 4 dims.") + + +@constexpr +def _shape_check_2d_or_4d(in_shape): + if len(in_shape) != 2 and len(in_shape) != 4: + raise ValueError("The input must has 2 dims or 4 dims.") @constexpr @@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm): moving_mean_init, moving_var_init, use_batch_statistics, - device_num_each_group) + device_num_each_group, + input_dims='both') self.group = check_int_positive(device_num_each_group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.")