forked from mindspore-Ecosystem/mindspore
add batchnorm3d
This commit is contained in:
parent
7a3f34247d
commit
9f950fb16c
|
@ -420,6 +420,100 @@ class BatchNorm2d(_BatchNorm):
|
|||
pass
|
||||
|
||||
|
||||
class BatchNorm3d(Cell):
|
||||
r"""
|
||||
Batch normalization layer over a 5D input.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This layer
|
||||
applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
|
||||
additional channel dimension) to avoid internal covariate shift.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
|
||||
Note:
|
||||
The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be
|
||||
changed after net was initilized.
|
||||
Note that the formula for updating the running_mean and running_var is
|
||||
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
|
||||
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
|
||||
|
||||
Args:
|
||||
num_features (int): `C` from an expected input of size (N, C, D, H, W).
|
||||
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
|
||||
momentum (float): A floating hyperparameter of the momentum for the
|
||||
running_mean and running_var computation. Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
|
||||
use the mean value and variance value of specified value. If None, the training process will use the mean
|
||||
and variance of current batch data and track the running mean and variance, the evaluation process will use
|
||||
the running mean and variance. Default: None.
|
||||
data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> net = nn.BatchNorm3d(num_features=3)
|
||||
>>> np.random.seed(0)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [16, 3, 10, 32, 32]), mindspore.float32)
|
||||
>>> output = net(input)
|
||||
>>> print(output.shape)
|
||||
(16, 3, 10, 32, 32)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
momentum=0.9,
|
||||
affine=True,
|
||||
gamma_init='ones',
|
||||
beta_init='zeros',
|
||||
moving_mean_init='zeros',
|
||||
moving_var_init='ones',
|
||||
use_batch_statistics=None,
|
||||
data_format='NCDHW'):
|
||||
super(BatchNorm3d, self).__init__()
|
||||
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
|
||||
self.bn2d = BatchNorm2d(num_features=num_features,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
gamma_init=gamma_init,
|
||||
beta_init=beta_init,
|
||||
moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init,
|
||||
use_batch_statistics=use_batch_statistics,
|
||||
data_format="NCHW")
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, input_x):
|
||||
x_shape = self.shape(input_x)
|
||||
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
|
||||
bn2d_out = self.bn2d(input_x)
|
||||
bn3d_out = self.reshape(bn2d_out, x_shape)
|
||||
return bn3d_out
|
||||
|
||||
|
||||
class GlobalBatchNorm(_BatchNorm):
|
||||
r"""
|
||||
Global normalization layer over a N-dimension input.
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.ops.operations import _grad_ops as G
|
|||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.nn.layer import normalization
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
|
@ -229,6 +230,18 @@ class Moments(nn.Cell):
|
|||
return mean, variance
|
||||
|
||||
|
||||
class BatchNorm3d(nn.Cell):
|
||||
"""BatchNorm3d net definition"""
|
||||
|
||||
def __init__(self, num_features):
|
||||
super(BatchNorm3d, self).__init__()
|
||||
self.bn3d = normalization.BatchNorm3d(num_features=num_features)
|
||||
|
||||
def construct(self, input_x):
|
||||
bn3d_out = self.bn3d(input_x)
|
||||
return bn3d_out
|
||||
|
||||
|
||||
class ClipByNorm(nn.Cell):
|
||||
"""ClipByNorm net definition"""
|
||||
|
||||
|
@ -1240,6 +1253,10 @@ test_case_math_ops = [
|
|||
'block': Moments(axis=(), keep_dims=False),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('BatchNorm3d', {
|
||||
'block': BatchNorm3d(num_features=3),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('Conv3D', {
|
||||
'block': Conv3D(out_channel=32, kernel_size=(4, 3, 3), mode=1, pad_mode='valid', pad=0,
|
||||
stride=1, dilation=1, group=1, data_format="NCDHW"),
|
||||
|
|
Loading…
Reference in New Issue