forked from mindspore-Ecosystem/mindspore
fix globalbatchnorm bug
This commit is contained in:
parent
6c9a54afa1
commit
7b81ca68dc
@@ -116,15 +116,7 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _shape_infer(self, x):
-        """global batch normalization shape and axes infer"""
-        if len(self.shape(x)) == 4:
-            axes = (0, 2, 3)
-            re_shape = (1, self.num_features, 1, 1)
-        else:
-            axes = (0,)
-            re_shape = (1, self.num_features)
-        return axes, re_shape
-
     def _global_sync(self, x, axes, re_shape):
         """calculate global batch normalization output"""
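For orientation: the axes and re_shape values removed here are consumed by _global_sync just below, which reduces the batch statistics over axes and reshapes them so they broadcast against the input again. The following is a single-device NumPy sketch of that use; the shapes are hypothetical examples, and the cross-device all-reduce that the real method performs is omitted.

import numpy as np

def local_batch_norm(x, axes, re_shape, eps=1e-5):
    # Reduce over `axes` to get per-channel statistics, then reshape them to
    # `re_shape` so they broadcast against the NCHW (or NC) input again.
    mean = x.mean(axis=axes).reshape(re_shape)
    var = x.var(axis=axes).reshape(re_shape)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(8, 4, 3, 3).astype(np.float32)   # hypothetical NCHW batch
y = local_batch_norm(x, axes=(0, 2, 3), re_shape=(1, 4, 1, 1))
print(y.shape)  # (8, 4, 3, 3)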
@@ -150,7 +142,7 @@ class _BatchNorm(Cell):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    axes, re_shape = self._shape_infer(x)
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                     y = self._global_sync(x, axes, re_shape)
             else:
                 y, batch_mean, batch_var, _, _ = \
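The behavioral fix is at this call site: the old code passed the tensor itself to an instance method, while the new code passes only static Python values, the shape tuple from F.shape(x) and the channel count, so the helper can be evaluated at graph-compile time via @constexpr (added further down). Below is a standalone sketch of that contract, using a toy stand-in for mindspore.ops.constexpr rather than the real decorator.

def constexpr(fn):
    # Toy stand-in: the real decorator runs `fn` once at graph-compile time,
    # so it can only see static values (ints, tuples, ...), never tensors.
    def wrapper(*args):
        assert all(isinstance(a, (int, tuple)) for a in args), \
            "constexpr helpers receive static values, not tensors"
        return fn(*args)
    return wrapper

@constexpr
def _shape_infer(x_shape, num_feature):
    """Same logic as the helper added by this commit."""
    if len(x_shape) == 4:
        return (0, 2, 3), (1, num_feature, 1, 1)
    return (0,), (1, num_feature)

# New-style call: static shape tuple plus channel count, as in the patched line.
print(_shape_infer((8, 4, 3, 3), 4))  # ((0, 2, 3), (1, 4, 1, 1))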
@@ -189,6 +181,17 @@ def _channel_check(channel, num_channel):
     if channel != num_channel:
         raise ValueError("the input channel is not equal with num_channel")
 
+@constexpr
+def _shape_infer(x_shape, num_feature):
+    """global batch normalization shape and axes infer"""
+    if len(x_shape) == 4:
+        axes = (0, 2, 3)
+        re_shape = (1, num_feature, 1, 1)
+    else:
+        axes = (0,)
+        re_shape = (1, num_feature)
+    return axes, re_shape
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Batch normalization layer over a 2D input.
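A quick usage sketch of the BatchNorm1d layer that appears in the trailing context, assuming a MindSpore installation; the batch size and feature count are hypothetical examples.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

bn = nn.BatchNorm1d(num_features=4)                    # normalizes a (N, C) input
x = Tensor(np.random.randn(8, 4).astype(np.float32))   # hypothetical 2D batch
y = bn(x)
print(y.asnumpy().shape)  # (8, 4)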