forked from mindspore-Ecosystem/mindspore
Fix input verification for input of GlobalbatchNorm.
This commit is contained in:
parent
fd9619bbf1
commit
4c98a00830
|
@ -45,7 +45,7 @@ class _BatchNorm(Cell):
|
|||
moving_var_init='ones',
|
||||
use_batch_statistics=None,
|
||||
device_num_each_group=1,
|
||||
input_dims='1d'):
|
||||
input_dims='2d'):
|
||||
super(_BatchNorm, self).__init__()
|
||||
if num_features < 1:
|
||||
raise ValueError("num_features must be at least 1")
|
||||
|
@ -151,6 +151,8 @@ class _BatchNorm(Cell):
|
|||
_shape_check(self.shape(x))
|
||||
if self.input_dims == '1d':
|
||||
_shape_check_2d(self.shape(x))
|
||||
if self.input_dims == 'both':
|
||||
_shape_check_2d_or_4d(self.shape(x))
|
||||
if self.use_batch_statistics is None:
|
||||
flag = self.training
|
||||
else:
|
||||
|
@ -211,7 +213,13 @@ def _shape_check_2d(input_shape):
|
|||
@constexpr
|
||||
def _shape_check(in_shape):
|
||||
if len(in_shape) != 4:
|
||||
raise ValueError("The input must has 4 dims")
|
||||
raise ValueError("The input must has 4 dims.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _shape_check_2d_or_4d(in_shape):
|
||||
if len(in_shape) != 2 and len(in_shape) != 4:
|
||||
raise ValueError("The input must has 2 dims or 4 dims.")
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm):
|
|||
moving_mean_init,
|
||||
moving_var_init,
|
||||
use_batch_statistics,
|
||||
device_num_each_group)
|
||||
device_num_each_group,
|
||||
input_dims='both')
|
||||
self.group = check_int_positive(device_num_each_group)
|
||||
if self.group <= 1:
|
||||
raise ValueError("the number of group must be greater than 1.")
|
||||
|
|
Loading…
Reference in New Issue