!832 fix globalbatchnorm bug

Merge pull request !832 from JichenZhao/syncbn
mindspore-ci-bot 2020-04-29 22:28:28 +08:00 committed by Gitee
commit 64c87170fd
1 changed file with 46 additions and 20 deletions


@@ -74,8 +74,12 @@ class _BatchNorm(Cell):
                 management.create_group('group' + str(i), self.rank_list[i])
                 self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
         self.shape = P.Shape()
-        self.reduce_mean = P.ReduceMean()
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.square = P.Square()
+        self.sqrt = P.Sqrt()
+        self.cast = P.Cast()
+        self.dtype = P.DType()
+        self.reshape = P.Reshape()
 
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
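The switch to `P.ReduceMean(keep_dims=True)` is what lets the reduced statistics broadcast against the NCHW input later in `_global_sync`: without `keep_dims`, the per-channel mean collapses to rank 1 and cannot be combined element-wise with `x` or the reshaped `gamma`/`beta`. A minimal NumPy sketch of the shape difference (illustration only, not part of the patch):

```python
import numpy as np

x = np.ones((8, 4, 2, 2))   # NCHW input
axes = (0, 2, 3)            # reduce over batch and spatial axes

print(np.mean(x, axis=axes).shape)                 # (4,)          -- rank lost
print(np.mean(x, axis=axes, keepdims=True).shape)  # (1, 4, 1, 1)  -- broadcasts against x
```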
@@ -112,29 +116,34 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
+    def _global_sync(self, x, axes, re_shape):
+        """calculate global batch normalization output"""
+        x_mean = self.reduce_mean(x, axes)
+        x_mean_square = self.reduce_mean(self.square(x), axes)
+        global_batch_mean = self.all_reduce(x_mean) / self.group
+        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+        global_mean = global_batch_mean
+        global_var = global_batch_mean_square - self.square(global_mean)
+        var_sqrt = self.sqrt(global_var + self.eps)
+        mean_first = (x - global_mean) / var_sqrt
+        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+        mean_sub2 = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
+        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
+
+        return y
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    x_mean = self.reduce_mean(x)
-                    x_mean_square = self.reduce_mean(self.square(x))
-                    global_batch_mean = self.all_reduce(x_mean) / self.group
-                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-                    global_mean = global_batch_mean
-                    global_var = global_batch_mean_square - self.square(global_batch_mean)
-                    y, batch_mean, batch_var, _, _ = \
-                        self.bn_train(x,
-                                      self.gamma,
-                                      self.beta,
-                                      None,
-                                      None)
-
-                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
-                    temp_mean = self.mul_mean(mean_sub, self.momentum)
-                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
-                    temp_variance = self.mul_var(mean_sub2, self.momentum)
-                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
+                    y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,
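The refactored `_global_sync` leans on the identity Var[x] = E[x²] − (E[x])²: averaging the per-device means and per-device means of squares across the group (an `AllReduce(SUM)` divided by the group size) reproduces the statistics of the combined batch, provided every device holds an equally sized shard. A NumPy sketch of that identity, simulating the AllReduce with a plain Python sum (illustration only; `group` and the shard shapes are made up):

```python
import numpy as np

group = 2
shards = [np.random.randn(8, 4) for _ in range(group)]  # equal-sized per-device batches

# Per-device statistics, as each rank computes them locally.
means = [s.mean(axis=0) for s in shards]
mean_squares = [(s ** 2).mean(axis=0) for s in shards]

# AllReduce(SUM) across the group, then divide by the group size.
global_mean = sum(means) / group
global_var = sum(mean_squares) / group - global_mean ** 2

# Matches the statistics of the full (concatenated) batch.
full = np.concatenate(shards, axis=0)
assert np.allclose(global_mean, full.mean(axis=0))
assert np.allclose(global_var, full.var(axis=0))
```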
@@ -172,6 +181,17 @@ def _channel_check(channel, num_channel):
     if channel != num_channel:
         raise ValueError("the input channel is not equal to num_channel")
 
+@constexpr
+def _shape_infer(x_shape, num_feature):
+    """global batch normalization shape and axes infer"""
+    if len(x_shape) == 4:
+        axes = (0, 2, 3)
+        re_shape = (1, num_feature, 1, 1)
+    else:
+        axes = (0,)
+        re_shape = (1, num_feature)
+    return axes, re_shape
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Batch normalization layer over a 2D input.
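The new `_shape_infer` helper runs at graph-compile time via `@constexpr` and derives both the reduction axes and the broadcast shape for `gamma`/`beta` from the input rank. Ignoring the `@constexpr` wrapper, its pure-Python logic behaves as follows (illustrative call results):

```python
# 4D NCHW input: reduce over batch and spatial axes; params broadcast as (1, C, 1, 1).
_shape_infer((32, 64, 14, 14), 64)   # -> ((0, 2, 3), (1, 64, 1, 1))

# Any other rank (e.g. a 2D [N, C] input): reduce over the batch axis only.
_shape_infer((32, 64), 64)           # -> ((0,), (1, 64))
```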
@@ -474,6 +494,12 @@ class GroupNorm(Cell):
         num_channels (int): The number of input channels.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         affine (bool): A bool value. When set to True, this layer has learnable affine parameters. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
 
     Inputs:
         - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
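For reference, a minimal construction of the layer these docstring entries describe; a sketch with illustrative argument values, relying only on the parameters named above:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# 4 groups over 16 input channels; gamma/beta use the documented default initializers.
group_norm = nn.GroupNorm(num_groups=4, num_channels=16, eps=1e-5, affine=True,
                          gamma_init='ones', beta_init='zeros')
x = Tensor(np.ones([1, 16, 8, 8]).astype(np.float32))
y = group_norm(x)   # output keeps the input shape [N, C, H, W]
```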