Fix input verification for input of GlobalbatchNorm.

This commit is contained in:
liuxiao93 2020-07-20 11:28:27 +08:00
parent fd9619bbf1
commit 4c98a00830
1 changed files with 12 additions and 3 deletions

View File

@ -45,7 +45,7 @@ class _BatchNorm(Cell):
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=1,
input_dims='1d'):
input_dims='2d'):
super(_BatchNorm, self).__init__()
if num_features < 1:
raise ValueError("num_features must be at least 1")
@ -151,6 +151,8 @@ class _BatchNorm(Cell):
_shape_check(self.shape(x))
if self.input_dims == '1d':
_shape_check_2d(self.shape(x))
if self.input_dims == 'both':
_shape_check_2d_or_4d(self.shape(x))
if self.use_batch_statistics is None:
flag = self.training
else:
@ -211,7 +213,13 @@ def _shape_check_2d(input_shape):
@constexpr
def _shape_check(in_shape):
if len(in_shape) != 4:
raise ValueError("The input must has 4 dims")
raise ValueError("The input must has 4 dims.")
@constexpr
def _shape_check_2d_or_4d(in_shape):
if len(in_shape) != 2 and len(in_shape) != 4:
raise ValueError("The input must has 2 dims or 4 dims.")
@constexpr
@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm):
moving_mean_init,
moving_var_init,
use_batch_statistics,
device_num_each_group)
device_num_each_group,
input_dims='both')
self.group = check_int_positive(device_num_each_group)
if self.group <= 1:
raise ValueError("the number of group must be greater than 1.")