Add Group Normalization
commit ebe6efff71 (parent 2ca9e4481e)
@@ -18,7 +18,7 @@ Layer.
The high-level components(Cells) used to construct the neural network.
"""
from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
-from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
+from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
from .container import SequentialCell, CellList
from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
           'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
-          'BatchNorm1d', 'BatchNorm2d', 'LayerNorm',
+          'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
           'SequentialCell', 'CellList',
           'Conv2d', 'Conv2dTranspose',
           'LSTM',
@@ -18,8 +18,9 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as DT
+import mindspore.common.dtype as mstype
import mindspore.context as context
+from mindspore._checkparam import check_int_positive, check_bool, check_typename
from mindspore._extends import cell_attr_register
from ..cell import Cell
@@ -58,7 +59,7 @@ class _BatchNorm(Cell):
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
-           self.momentum = Tensor(1.0 - momentum, DT.float32)
+           self.momentum = Tensor(1.0 - momentum, mstype.float32)
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps)
        else:
@@ -289,3 +290,71 @@ class LayerNorm(Cell):
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s


class GroupNorm(Cell):
    r"""
    Group Normalization over a mini-batch of inputs.

    Group Normalization is widely used in visual recognition tasks. It normalizes each
    training example independently, as described in the paper
    `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
    divides the channels into groups and computes the mean and variance within each group
    for normalization, and its accuracy is stable over a wide range of batch sizes.
    It can be described using the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_groups (int): The number of groups to be divided along the channel dimension.
        num_channels (int): The number of channels in the input. It must be divisible by `num_groups`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        affine (bool): When set to True, this layer has learnable affine parameters gamma and beta.
            Default: True.

    Inputs:
        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].

    Outputs:
        Tensor, the normalized, scaled and offset tensor, which has the same shape and data type as `input_x`.

    Examples:
        >>> group_norm_op = nn.GroupNorm(16, 64)
        >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
        >>> group_norm_op(x)
    """
    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
        super(GroupNorm, self).__init__()
        self.num_groups = check_int_positive(num_groups)
        self.num_channels = check_int_positive(num_channels)
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        self.eps = Tensor(check_typename('eps', eps, (float,)), mstype.float32)
        self.affine = check_bool(affine)

        # Per-channel affine parameters, broadcastable over H and W.
        gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
        beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
        if self.affine:
            self.gamma = Parameter(gamma, name='gamma')
            self.beta = Parameter(beta, name='beta')
        else:
            self.gamma = gamma
            self.beta = beta
        self.shape = F.shape
        self.reshape = F.reshape
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = F.square
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.sqrt = P.Sqrt()

    def construct(self, x):
        batch, channel, height, width = self.shape(x)
        # Group the channels: (N, C, H, W) -> (N, G, C//G * H * W).
        x = self.reshape(x, (batch, self.num_groups, channel * height * width // self.num_groups))
        mean = self.reduce_mean(x, 2)
        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
        std = self.sqrt(var + self.eps)
        x = (x - mean) / std
        # Restore the original layout and apply the per-channel affine transform.
        x = self.reshape(x, (batch, channel, height, width))
        output = x * self.gamma + self.beta
        return output

    def extend_repr(self):
        return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)

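The construct above groups the channels, computes a per-group mean and variance over the flattened group axis, normalizes, and then applies the per-channel affine parameters. The NumPy sketch below makes the reshape and reduction axes concrete; the function name is hypothetical, the unbiased (group_size - 1) variance mirrors the implementation above, and none of it is part of the commit.

# Reference sketch of the GroupNorm computation in plain NumPy (not part of the commit).
import numpy as np

def group_norm_reference(x, num_groups, gamma, beta, eps=1e-5):
    """NumPy sketch of the computation performed by GroupNorm.construct above."""
    n, c, h, w = x.shape
    group_size = c // num_groups * h * w
    # (N, C, H, W) -> (N, G, C//G * H * W): statistics are shared within each group.
    xg = x.reshape(n, num_groups, group_size)
    mean = xg.mean(axis=2, keepdims=True)
    # The Cell above uses an unbiased (group_size - 1) variance estimate.
    var = ((xg - mean) ** 2).sum(axis=2, keepdims=True) / (group_size - 1)
    x_hat = ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    # gamma and beta have shape (C, 1, 1) and broadcast over H and W.
    return x_hat * gamma + beta

x = np.random.rand(2, 64, 8, 8).astype(np.float32)
gamma = np.ones((64, 1, 1), np.float32)
beta = np.zeros((64, 1, 1), np.float32)
y = group_norm_reference(x, num_groups=16, gamma=gamma, beta=beta)
assert y.shape == x.shape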
@@ -56,3 +56,15 @@ def test_compile():
    net = Net()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    _executor.compile(net, input_data)


class GroupNet(nn.Cell):
    def __init__(self):
        super(GroupNet, self).__init__()
        self.group_bn = nn.GroupNorm(16, 64)

    def construct(self, x):
        return self.group_bn(x)


def test_compile_groupnorm():
    net = nn.GroupNorm(16, 64)
    input_data = Tensor(np.random.rand(1, 64, 256, 256).astype(np.float32))
    _executor.compile(net, input_data)
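As a side note on the two limiting choices of `num_groups` (a general property of Group Normalization, not something this commit tests): a single group normalizes over all of C*H*W, similar to Layer Normalization applied to [C, H, W], while one channel per group gives per-channel, instance-style statistics. A hedged sketch:

# Limiting cases of num_groups (illustrative only, not part of the commit).
import mindspore.nn as nn

layer_like = nn.GroupNorm(1, 64)      # one group: statistics over all of C*H*W
instance_like = nn.GroupNorm(64, 64)  # one channel per group: per-channel statistics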