Add Group Normalization

zhaojichen 2020-04-16 01:42:32 -04:00
parent 2ca9e4481e
commit ebe6efff71
3 changed files with 85 additions and 4 deletions

View File

@@ -18,7 +18,7 @@ Layer.
 The high-level components(Cells) used to construct the neural network.
 """
 from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
-from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
+from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
 from .container import SequentialCell, CellList
 from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM
 __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
-           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm',
+           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
            'SequentialCell', 'CellList',
            'Conv2d', 'Conv2dTranspose',
            'LSTM',

View File

@@ -18,8 +18,9 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as DT
+import mindspore.common.dtype as mstype
 import mindspore.context as context
+from mindspore._checkparam import check_int_positive, check_bool, check_typename
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
@@ -58,7 +59,7 @@ class _BatchNorm(Cell):
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, DT.float32)
+            self.momentum = Tensor(1.0 - momentum, mstype.float32)
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -289,3 +290,71 @@ class LayerNorm(Cell):
         s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
             self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
         return s
+
+
+class GroupNorm(Cell):
+    r"""
+    Group Normalization over a mini-batch of inputs.
+
+    Group Normalization is described in the paper `Group Normalization
+    <https://arxiv.org/pdf/1803.08494.pdf>`_. It divides the channels into groups
+    and computes within each group the mean and variance for normalization. Because
+    the computation is independent of the batch dimension, its accuracy is stable
+    over a wide range of batch sizes. It can be described using the following formula.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        num_groups (int): The number of groups into which the channel dimension is divided.
+        num_channels (int): The number of input channels, which must be divisible by `num_groups`.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        affine (bool): If True, this layer has learnable affine parameters gamma and beta. Default: True.
+
+    Inputs:
+        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
+
+    Outputs:
+        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
+
+    Examples:
+        >>> group_norm_op = nn.GroupNorm(16, 64)
+        >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
+        >>> group_norm_op(x)
+    """
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
+        super(GroupNorm, self).__init__()
+        self.num_groups = check_int_positive(num_groups)
+        self.num_channels = check_int_positive(num_channels)
+        if num_channels % num_groups != 0:
+            raise ValueError("num_channels must be divisible by num_groups")
+        self.eps = Tensor(check_typename('eps', eps, (float,)), mstype.float32)
+        self.affine = check_bool(affine)
+
+        gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
+        beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
+        if self.affine:
+            self.gamma = Parameter(gamma, name='gamma')
+            self.beta = Parameter(beta, name='beta')
+        else:
+            self.gamma = gamma
+            self.beta = beta
+        self.shape = F.shape
+        self.reshape = F.reshape
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
+        self.square = F.square
+        self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.sqrt = P.Sqrt()
+
+    def construct(self, x):
+        batch, channel, height, width = self.shape(x)
+        # Regroup the channels so that each group is normalized independently.
+        x = self.reshape(x, (batch, self.num_groups, channel * height * width // self.num_groups))
+        mean = self.reduce_mean(x, 2)
+        # Bessel-corrected variance over the (N - 1) elements of each group.
+        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
+        std = self.sqrt(var + self.eps)
+        x = (x - mean) / std
+        x = self.reshape(x, (batch, channel, height, width))
+        # Per-channel affine transform; gamma and beta have shape [C, 1, 1] and broadcast over H, W.
+        output = x * self.gamma + self.beta
+        return output
+
+    def extend_repr(self):
+        return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
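For reference, the math in construct above can be sanity-checked against a plain NumPy implementation of the same formula. The sketch below is a hypothetical helper, not part of this commit; it uses ddof=1 so the variance matches the (N - 1) denominator in construct.

import numpy as np

def group_norm_reference(x, num_groups, gamma, beta, eps=1e-5):
    """NumPy reference for group normalization on an [N, C, H, W] array."""
    batch, channel, height, width = x.shape
    # Regroup channels: each of the num_groups groups is normalized on its own.
    g = x.reshape(batch, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = g.var(axis=2, keepdims=True, ddof=1)  # Bessel-corrected, matching (N - 1) above
    g = (g - mean) / np.sqrt(var + eps)
    x = g.reshape(batch, channel, height, width)
    # gamma and beta have shape [C, 1, 1] and broadcast over H and W.
    return x * gamma + beta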

View File

@@ -56,3 +56,15 @@ def test_compile():
     net = Net()
     input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
     _executor.compile(net, input_data)
+
+
+class GroupNet(nn.Cell):
+    def __init__(self):
+        super(GroupNet, self).__init__()
+        # GroupNorm requires num_groups and num_channels: 16 groups over 64 channels.
+        self.group_bn = nn.GroupNorm(16, 64)
+
+    def construct(self, x):
+        return self.group_bn(x)
+
+
+def test_compile_groupnorm():
+    net = nn.GroupNorm(16, 64)
+    input_data = Tensor(np.random.rand(1, 64, 256, 256).astype(np.float32))
+    _executor.compile(net, input_data)
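As a usage sketch (hypothetical, not part of this commit), the new cell can also be exercised eagerly in PyNative mode and compared against the NumPy helper sketched earlier; group_norm_reference is the assumed name of that helper.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor

context.set_context(mode=context.PYNATIVE_MODE)
net = nn.GroupNorm(4, 8)  # 4 groups over 8 channels
x = np.random.rand(2, 8, 4, 4).astype(np.float32)
out = net(Tensor(x)).asnumpy()
# With default affine parameters, gamma is all ones and beta is all zeros.
expected = group_norm_reference(x, 4, np.ones([8, 1, 1], np.float32),
                                np.zeros([8, 1, 1], np.float32))
assert np.allclose(out, expected, atol=1e-4)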