Forked from mindspore-Ecosystem/mindspore
!832 fix GlobalBatchNorm bug
Merge pull request !832 from JichenZhao/syncbn
Commit: 64c87170fd
@@ -74,8 +74,12 @@ class _BatchNorm(Cell):
                     management.create_group('group' + str(i), self.rank_list[i])
                     self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
         self.shape = P.Shape()
-        self.reduce_mean = P.ReduceMean()
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.square = P.Square()
+        self.sqrt = P.Sqrt()
+        self.cast = P.Cast()
+        self.dtype = P.DType()
+        self.reshape = P.Reshape()
 
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
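A minimal NumPy sketch (illustrative only; the names below are not from the repository) of why the switch to ReduceMean(keep_dims=True) matters: keeping the reduced axes as size-1 dimensions lets the per-channel statistics broadcast back against the NCHW input that the new _global_sync later normalizes.

    import numpy as np

    x = np.random.randn(4, 3, 8, 8).astype(np.float32)   # NCHW input

    # Without keep_dims the per-channel mean collapses to shape (3,),
    # which does not broadcast against (4, 3, 8, 8) along the channel axis.
    mean_no_keep = x.mean(axis=(0, 2, 3))                 # shape (3,)

    # With keep_dims the reduced axes survive as size-1 dims, shape (1, 3, 1, 1),
    # so (x - mean) / sqrt(var + eps) broadcasts per channel as intended.
    mean_keep = x.mean(axis=(0, 2, 3), keepdims=True)     # shape (1, 3, 1, 1)
    var_keep = x.var(axis=(0, 2, 3), keepdims=True)
    normalized = (x - mean_keep) / np.sqrt(var_keep + 1e-5)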
@@ -112,29 +116,34 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
+    def _global_sync(self, x, axes, re_shape):
+        """calculate global batch normalization output"""
+        x_mean = self.reduce_mean(x, axes)
+        x_mean_square = self.reduce_mean(self.square(x), axes)
+        global_batch_mean = self.all_reduce(x_mean) / self.group
+        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+        global_mean = global_batch_mean
+        global_var = global_batch_mean_square - self.square(global_mean)
+        var_sqrt = self.sqrt(global_var + self.eps)
+        mean_first = (x - global_mean) / var_sqrt
+        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+        mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
+        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
+        return y
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    x_mean = self.reduce_mean(x)
-                    x_mean_square = self.reduce_mean(self.square(x))
-                    global_batch_mean = self.all_reduce(x_mean) / self.group
-                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-                    global_mean = global_batch_mean
-                    global_var = global_batch_mean_square - self.square(global_batch_mean)
-                    y, batch_mean, batch_var, _, _ = \
-                        self.bn_train(x,
-                                      self.gamma,
-                                      self.beta,
-                                      None,
-                                      None)
-
-                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
-                    temp_mean = self.mul_mean(mean_sub, self.momentum)
-                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
-                    temp_variance = self.mul_var(mean_sub2, self.momentum)
-                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
+                    y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,
@@ -172,6 +181,17 @@ def _channel_check(channel, num_channel):
     if channel != num_channel:
         raise ValueError("the input channel is not equal with num_channel")
 
+@constexpr
+def _shape_infer(x_shape, num_feature):
+    """global batch normalization shape and axes infer"""
+    if len(x_shape) == 4:
+        axes = (0, 2, 3)
+        re_shape = (1, num_feature, 1, 1)
+    else:
+        axes = (0,)
+        re_shape = (1, num_feature)
+    return axes, re_shape
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Batch normalization layer over a 2D input.
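A quick illustration (plain Python; the names here are chosen for the example) of the contract the new _shape_infer establishes: 4-D NCHW inputs reduce over axes (0, 2, 3) and reshape the affine parameters to (1, C, 1, 1) so they broadcast per channel, while 2-D inputs reduce over the batch axis only.

    import numpy as np

    def shape_infer(x_shape, num_feature):
        # Mirrors the branching of _shape_infer above.
        if len(x_shape) == 4:
            return (0, 2, 3), (1, num_feature, 1, 1)
        return (0,), (1, num_feature)

    x = np.zeros((2, 8, 4, 4))
    axes, re_shape = shape_infer(x.shape, 8)
    gamma = np.ones(8).reshape(re_shape)            # (1, 8, 1, 1), broadcasts per channel
    mean = x.mean(axis=axes, keepdims=True)         # (1, 8, 1, 1)
    assert (gamma * (x - mean)).shape == x.shape

    assert shape_infer((2, 8), 8) == ((0,), (1, 8))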
@@ -474,6 +494,12 @@ class GroupNorm(Cell):
         num_channels (int): The number of channels per group.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         affine (bool): A bool value, this layer will has learnable affine parameters when set to true. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
 
     Inputs:
         - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
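To make the documented arguments concrete, here is a reference NumPy computation of group normalization (a sketch of the semantics, not the MindSpore kernel): channels are split into num_groups groups, each group is normalized with its own mean and variance using eps, and the per-channel gamma/beta governed by gamma_init / beta_init apply when affine is true.

    import numpy as np

    def group_norm_reference(x, num_groups, gamma, beta, eps=1e-5):
        """Reference group normalization for x of shape [N, C, H, W]."""
        n, c, h, w = x.shape
        g = x.reshape(n, num_groups, c // num_groups, h, w)
        mean = g.mean(axis=(2, 3, 4), keepdims=True)
        var = g.var(axis=(2, 3, 4), keepdims=True)
        y = ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
        # Per-channel affine transform; defaults correspond to gamma_init='ones', beta_init='zeros'.
        return y * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

    x = np.random.randn(2, 16, 4, 4).astype(np.float32)
    out = group_norm_reference(x, num_groups=4, gamma=np.ones(16), beta=np.zeros(16))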