forked from mindspore-Ecosystem/mindspore
!399 Add Global Batch Normalization
Merge pull request !399 from JichenZhao/syncbn
This commit is contained in:
commit
b9b056f40d
|
@ -18,7 +18,7 @@ Layer.
|
|||
The high-level components(Cells) used to construct the neural network.
|
||||
"""
|
||||
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
|
||||
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
|
||||
from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm
|
||||
from .container import SequentialCell, CellList
|
||||
from .conv import Conv2d, Conv2dTranspose
|
||||
from .lstm import LSTM
|
||||
|
@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM, PSNR
|
|||
|
||||
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
|
||||
'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
|
||||
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
|
||||
'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm',
|
||||
'SequentialCell', 'CellList',
|
||||
'Conv2d', 'Conv2dTranspose',
|
||||
'LSTM',
|
||||
|
|
|
@ -20,8 +20,11 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.context as context
|
||||
from mindspore._checkparam import check_int_positive, check_bool, check_typename
|
||||
from mindspore._checkparam import check_bool, check_typename
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
from mindspore.communication import management
|
||||
from mindspore._checkparam import check_int_positive
|
||||
from ..cell import Cell
|
||||
|
||||
|
||||
|
@ -30,6 +33,7 @@ class _BatchNorm(Cell):
|
|||
@cell_attr_register
|
||||
def __init__(self,
|
||||
num_features,
|
||||
group=1,
|
||||
eps=1e-5,
|
||||
momentum=0.9,
|
||||
affine=True,
|
||||
|
@ -56,6 +60,21 @@ class _BatchNorm(Cell):
|
|||
gamma_init, num_features), name="gamma", requires_grad=affine)
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, num_features), name="beta", requires_grad=affine)
|
||||
self.group = check_int_positive(group)
|
||||
if self.group != 1:
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
self.device_list = [i for i in range(0, self.rank_size)]
|
||||
self.rank_list = self.list_group(self.device_list, self.group)
|
||||
self.rank_list_idx = len(self.rank_list)
|
||||
for i in range(self.rank_list_idx):
|
||||
if self.rank_id in self.rank_list[i] and self.group != 1:
|
||||
self.is_global = True
|
||||
management.create_group('group' + str(i), self.rank_list[i])
|
||||
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
|
||||
self.shape = P.Shape()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.square = P.Square()
|
||||
|
||||
if context.get_context("enable_ge"):
|
||||
self.is_ge_backend = True
|
||||
|
@ -82,9 +101,40 @@ class _BatchNorm(Cell):
|
|||
def _check_data_dim(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def list_group(self, world_rank, group_size):
|
||||
if group_size > get_group_size():
|
||||
raise ValueError("group size can not be greater than local rank size, group size is {}, "
|
||||
"local_rank_size is {}".format(group_size, get_group_size()))
|
||||
if len(world_rank) % group_size != 0:
|
||||
raise ValueError("please make your group size correct.")
|
||||
world_rank_list = zip(*(iter(world_rank),) *group_size)
|
||||
group_list = [list(i) for i in world_rank_list]
|
||||
return group_list
|
||||
|
||||
def construct(self, x):
|
||||
if self.training and self.use_batch_statistics:
|
||||
if self.is_ge_backend:
|
||||
if self.is_global:
|
||||
x_mean = self.reduce_mean(x)
|
||||
x_mean_square = self.reduce_mean(self.square(x))
|
||||
global_batch_mean = self.all_reduce(x_mean) / self.group
|
||||
global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
|
||||
global_mean = global_batch_mean
|
||||
global_var = global_batch_mean_square - self.square(global_batch_mean)
|
||||
y, batch_mean, batch_var, _, _ = \
|
||||
self.bn_train(x,
|
||||
self.gamma,
|
||||
self.beta,
|
||||
None,
|
||||
None)
|
||||
|
||||
mean_sub = self.sub_mean(self.moving_mean, global_mean)
|
||||
temp_mean = self.mul_mean(mean_sub, self.momentum)
|
||||
mean_sub2 = self.sub_var(self.moving_variance, global_var)
|
||||
temp_variance = self.mul_var(mean_sub2, self.momentum)
|
||||
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
|
||||
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
|
||||
else:
|
||||
y, batch_mean, batch_var, _, _ = \
|
||||
self.bn_train(x,
|
||||
self.gamma,
|
||||
|
@ -221,6 +271,55 @@ class BatchNorm2d(_BatchNorm):
|
|||
pass
|
||||
|
||||
|
||||
class GlobalBatchNorm(_BatchNorm):
|
||||
r"""
|
||||
Global normalization layer over a N-dimension input.
|
||||
|
||||
Global Normalization is cross device synchronized batch normalization. Batch Normalization implementation
|
||||
only normalize the data within each device. Global normalization will normalize the input within the group.
|
||||
It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
|
||||
Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
||||
feature using a mini-batch of data and the learned parameters which can be described in the following formula.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
|
||||
Args:
|
||||
num_features (int): `C` from an expected input of size (N, C, H, W).
|
||||
group (int): The number of device in each group.
|
||||
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
|
||||
momentum (float): A floating hyperparameter of the momentum for the
|
||||
running_mean and running_var computation. Default: 0.9.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
|
||||
the mean value and variance value of specified value. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
|
||||
>>> global_bn_op(input)
|
||||
"""
|
||||
def _check_data_dim(self, x):
|
||||
if x.dim == 0:
|
||||
pass
|
||||
|
||||
class LayerNorm(Cell):
|
||||
r"""
|
||||
Applies Layer Normalization over a mini-batch of inputs.
|
||||
|
|
Loading…
Reference in New Issue