fix globalbatchnorm bug

This commit is contained in:
zhaojichen 2020-04-29 04:53:23 -04:00
parent 6c9a54afa1
commit 7b81ca68dc
1 changed files with 13 additions and 10 deletions

View File

@ -116,15 +116,7 @@ class _BatchNorm(Cell):
group_list = [list(i) for i in world_rank_list]
return group_list
def _shape_infer(self, x):
"""global batch normalization shape and axes infer"""
if len(self.shape(x)) == 4:
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)
else:
axes = (0,)
re_shape = (1, self.num_features)
return axes, re_shape
def _global_sync(self, x, axes, re_shape):
"""calculate global batch normalization output"""
@ -150,7 +142,7 @@ class _BatchNorm(Cell):
if self.training and self.use_batch_statistics:
if self.is_ge_backend:
if self.is_global:
axes, re_shape = self._shape_infer(x)
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
else:
y, batch_mean, batch_var, _, _ = \
@ -189,6 +181,17 @@ def _channel_check(channel, num_channel):
if channel != num_channel:
raise ValueError("the input channel is not equal with num_channel")
@constexpr
def _shape_infer(x_shape, num_feature):
"""global batch normalization shape and axes infer"""
if len(x_shape) == 4:
axes = (0, 2, 3)
re_shape = (1, num_feature, 1, 1)
else:
axes = (0,)
re_shape = (1, num_feature)
return axes, re_shape
class BatchNorm1d(_BatchNorm):
r"""
Batch normalization layer over a 2D input.