!36327 add document for InstanceNorm1d and InstanceNorm3d

Merge pull request !36327 from zhujingxuan/code_docs_instance_norm
i-robot 2022-06-24 09:08:36 +00:00 committed by Gitee
commit 824246769a
6 changed files with 244 additions and 11 deletions


@@ -164,7 +164,9 @@ Dropout Layer
mindspore.nn.BatchNorm3d
mindspore.nn.GlobalBatchNorm
mindspore.nn.GroupNorm
mindspore.nn.InstanceNorm1d
mindspore.nn.InstanceNorm2d
mindspore.nn.InstanceNorm3d
mindspore.nn.LayerNorm
mindspore.nn.SyncBatchNorm


@@ -0,0 +1,50 @@
mindspore.nn.InstanceNorm1d
============================
.. py:class:: mindspore.nn.InstanceNorm1d(num_features, eps=1e-5, momentum=0.1, affine=True, gamma_init='ones', beta_init='zeros')
Instance Normalization layer over a 3D input.

This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with an additional channel dimension), as described in the paper `Instance Normalization:
The Missing Ingredient for Fast Stylization <https://arxiv.org/abs/1607.08022>`_ .

It rescales and recenters the features using a mini-batch of data and the learned parameters, as described in the following formula.

.. math::
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

where :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `num_features` if `affine` is True. The standard deviation is calculated via the biased estimator.

This layer uses instance statistics computed from input data in both training and evaluation modes.

InstanceNorm1d is similar to BatchNorm1d; the difference is that InstanceNorm1d is applied on each channel of channeled data such as RGB images, while BatchNorm1d is usually applied on each batch of batched data.

.. note::
    The formula for updating the running mean and running variance is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

**Parameters:**

- **num_features** (int) - The number of channels, i.e. `C` in the input Tensor of shape :math:`(N, C, L)`.
- **eps** (float) - A value added to the denominator for numerical stability. Default: 1e-5.
- **momentum** (float) - The momentum used for the running mean and running variance computation. Default: 0.1.
- **affine** (bool) - When set to True, the gamma and beta parameters can be learned. Default: True.
- **gamma_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initializer for the gamma parameter. The str values refer to the function `initializer`, including 'zeros', 'ones', etc. Default: 'ones'.
- **beta_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initializer for the beta parameter. The str values refer to the function `initializer`, including 'zeros', 'ones', etc. Default: 'zeros'.

**Inputs:**

- **x** (Tensor) - Tensor of shape :math:`(N, C, L)`. Data type: float16 or float32.

**Outputs:**

Tensor, the normalized, scaled and offset Tensor of shape :math:`(N, C, L)`, with the same type and shape as `x`.

**Raises:**

- **TypeError** - `num_features` is not an int.
- **TypeError** - The type of `eps` is not float.
- **TypeError** - The type of `momentum` is not float.
- **TypeError** - `affine` is not a bool.
- **TypeError** - The types of `gamma_init`/`beta_init` are not the same, or the initialized element type is not float32.
- **ValueError** - `num_features` is less than 1.
- **ValueError** - `momentum` is not in the range [0, 1].
- **KeyError** - Any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does not exist.
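To make the normalization concrete, here is a minimal sketch (not part of this commit) that reproduces the formula above with NumPy and checks it against the layer output. It assumes a GPU build of MindSpore, the only platform listed for these layers, and the tolerance is an arbitrary choice:

>>> import mindspore
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> # Instance statistics: mean and (biased) variance over the spatial axis L only,
>>> # computed separately for every (sample, channel) pair.
>>> x = np.random.randn(2, 3, 5).astype(np.float32)
>>> mean = x.mean(axis=2, keepdims=True)
>>> var = x.var(axis=2, keepdims=True)  # np.var is the biased estimator by default
>>> expected = (x - mean) / np.sqrt(var + 1e-5)  # gamma=1, beta=0 at initialization
>>> net = nn.InstanceNorm1d(3)
>>> output = net(Tensor(x))
>>> print(np.allclose(output.asnumpy(), expected, atol=1e-4))
True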


@@ -12,11 +12,11 @@ mindspore.nn.InstanceNorm2d
.. math::
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
where :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `num_features` if `affine` is True. The standard deviation is calculated via the biased estimator.

This layer uses instance statistics computed from input data in both training and evaluation modes.

InstanceNorm2d is similar to BatchNorm2d; the difference is that InstanceNorm2d is applied on each channel of channeled data such as RGB images, while BatchNorm2d is usually applied on each batch of batched data.

.. note::
    The formula for updating the running mean and running variance is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
@@ -41,11 +41,10 @@ mindspore.nn.InstanceNorm2d
**Raises:**

- **TypeError** - `num_features` is not an int.
- **TypeError** - The type of `eps` is not float.
- **TypeError** - The type of `momentum` is not float.
- **TypeError** - `affine` is not a bool.
- **TypeError** - The types of `gamma_init`/`beta_init` are not the same, or the initialized element type is not float32.
- **ValueError** - `num_features` is less than 1.
- **ValueError** - `momentum` is not in the range [0, 1].
- **KeyError** - Any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does not exist.
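The momentum convention in the note above is easy to read backwards: with the default momentum=0.1, most of the weight goes to the new observation, not to the old running estimate. A tiny sketch (not part of this commit; the channel values are hypothetical):

>>> import numpy as np
>>> momentum = 0.1
>>> running_mean = np.zeros(3, dtype=np.float32)  # one running estimate per channel
>>> x_t = np.array([0.5, -0.2, 1.0], dtype=np.float32)  # newly observed statistic
>>> # hat_x_new = (1 - momentum) * x_t + momentum * hat_x
>>> running_mean = (1 - momentum) * x_t + momentum * running_mean
>>> print(running_mean)
[ 0.45 -0.18  0.9 ]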


@@ -0,0 +1,50 @@
mindspore.nn.InstanceNorm3d
============================
.. py:class:: mindspore.nn.InstanceNorm3d(num_features, eps=1e-5, momentum=0.1, affine=True, gamma_init='ones', beta_init='zeros')
Instance Normalization layer over a 5D input.

This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with an additional channel dimension), as described in the paper `Instance Normalization:
The Missing Ingredient for Fast Stylization <https://arxiv.org/abs/1607.08022>`_ .

It rescales and recenters the features using a mini-batch of data and the learned parameters, as described in the following formula.

.. math::
    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

where :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `num_features` if `affine` is True. The standard deviation is calculated via the biased estimator.

This layer uses instance statistics computed from input data in both training and evaluation modes.

InstanceNorm3d is similar to BatchNorm3d; the difference is that InstanceNorm3d is applied on each channel of channeled data such as RGB images, while BatchNorm3d is usually applied on each batch of batched data.

.. note::
    The formula for updating the running mean and running variance is :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

**Parameters:**

- **num_features** (int) - The number of channels, i.e. `C` in the input Tensor of shape :math:`(N, C, D, H, W)`.
- **eps** (float) - A value added to the denominator for numerical stability. Default: 1e-5.
- **momentum** (float) - The momentum used for the running mean and running variance computation. Default: 0.1.
- **affine** (bool) - When set to True, the gamma and beta parameters can be learned. Default: True.
- **gamma_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initializer for the gamma parameter. The str values refer to the function `initializer`, including 'zeros', 'ones', etc. Default: 'ones'.
- **beta_init** (Union[Tensor, str, Initializer, numbers.Number]) - Initializer for the beta parameter. The str values refer to the function `initializer`, including 'zeros', 'ones', etc. Default: 'zeros'.

**Inputs:**

- **x** (Tensor) - Tensor of shape :math:`(N, C, D, H, W)`. Data type: float16 or float32.

**Outputs:**

Tensor, the normalized, scaled and offset Tensor of shape :math:`(N, C, D, H, W)`, with the same type and shape as `x`.

**Raises:**

- **TypeError** - `num_features` is not an int.
- **TypeError** - The type of `eps` is not float.
- **TypeError** - The type of `momentum` is not float.
- **TypeError** - `affine` is not a bool.
- **TypeError** - The types of `gamma_init`/`beta_init` are not the same, or the initialized element type is not float32.
- **ValueError** - `num_features` is less than 1.
- **ValueError** - `momentum` is not in the range [0, 1].
- **KeyError** - Any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does not exist.
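As with the 1D variant, a short sketch (not part of this commit, again assuming a GPU build of MindSpore) shows that every (sample, channel) slice of the 5D input is normalized independently over its D, H and W axes:

>>> import mindspore
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.InstanceNorm3d(3)
>>> x = Tensor(np.random.randn(2, 3, 5, 2, 2).astype(np.float32))
>>> output = net(x)
>>> print(output.shape)
(2, 3, 5, 2, 2)
>>> # Each (sample, channel) slice now has mean ~0 (and variance ~1, up to eps).
>>> print(np.allclose(output.asnumpy().mean(axis=(2, 3, 4)), 0.0, atol=1e-4))
True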


@@ -163,7 +163,9 @@ Normalization Layer
mindspore.nn.BatchNorm3d
mindspore.nn.GlobalBatchNorm
mindspore.nn.GroupNorm
mindspore.nn.InstanceNorm1d
mindspore.nn.InstanceNorm2d
mindspore.nn.InstanceNorm3d
mindspore.nn.LayerNorm
mindspore.nn.SyncBatchNorm


@@ -933,6 +933,71 @@ class _InstanceNorm(Cell):
class InstanceNorm1d(_InstanceNorm):
r"""
Instance Normalization layer over a 3D input.
This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the features using a mini-batch
of data and the learned parameters, as described in the following formula.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable parameter vectors of size num_features if affine is True.
The standard-deviation is calculated via the biased estimator.
This layer uses instance statistics computed from input data in both training and evaluation modes.
InstanceNorm1d and BatchNorm1d are very similar, but have some differences. InstanceNorm1d is applied on each
channel of channeled data like RGB images, but BatchNorm1d is usually applied on each batch of batched data.
Note:
The formula for updating the running_mean and running_var is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
Args:
num_features (int): `C` from an expected input of size (N, C, L).
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
momentum (float): A floating hyperparameter of the momentum for the
running_mean and running_var computation. Default: 0.1.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, L)`. Data type: float16 or float32.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, L)`. Same type and
shape as `x`.
Supported Platforms:
``GPU``
Raises:
TypeError: If the type of `num_features` is not int.
TypeError: If the type of `eps` is not float.
TypeError: If the type of `momentum` is not float.
TypeError: If the type of `affine` is not bool.
TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
float32.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer`
does not exist.
Examples:
>>> import mindspore
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.InstanceNorm1d(3)
>>> x = Tensor(np.ones([2, 3, 5]), mindspore.float32)
>>> output = net(x)
>>> print(output.shape)
(2, 3, 5)
"""
def __init__(self,
@@ -964,8 +1029,8 @@ class InstanceNorm2d(_InstanceNorm):
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable parameter vectors of size num_features if affine is True.
The standard-deviation is calculated via the biased estimator.
This layer uses instance statistics computed from input data in both training and evaluation modes.
@@ -999,10 +1064,10 @@ class InstanceNorm2d(_InstanceNorm):
``GPU``
Raises:
TypeError: If the type of `num_features` is not int.
TypeError: If the type of `eps` is not float.
TypeError: If the type of `momentum` is not float.
TypeError: If the type of `affine` is not bool.
TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
float32.
ValueError: If `num_features` is less than 1.
@@ -1042,6 +1107,71 @@ class InstanceNorm2d(_InstanceNorm):
class InstanceNorm3d(_InstanceNorm):
r"""
Instance Normalization layer over a 5D input.
This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the features using a mini-batch
of data and the learned parameters, as described in the following formula.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable parameter vectors of size num_features if affine is True.
The standard-deviation is calculated via the biased estimator.
This layer uses instance statistics computed from input data in both training and evaluation modes.
InstanceNorm3d and BatchNorm3d are very similar, but have some differences. InstanceNorm3d is applied on each
channel of channeled data like RGB images, but BatchNorm3d is usually applied on each batch of batched data.
Note:
The formula for updating the running_mean and running_var is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
Args:
num_features (int): `C` from an expected input of size (N, C, D, H, W).
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
momentum (float): A floating hyperparameter of the momentum for the
running_mean and running_var computation. Default: 0.1.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C, D, H, W)`. Data type: float16 or float32.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, D, H, W)`. Same type and
shape as `x`.
Supported Platforms:
``GPU``
Raises:
TypeError: If the type of `num_features` is not int.
TypeError: If the type of `eps` is not float.
TypeError: If the type of `momentum` is not float.
TypeError: If the type of `affine` is not bool.
TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
float32.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer`
does not exist.
Examples:
>>> import mindspore
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> net = nn.InstanceNorm3d(3)
>>> x = Tensor(np.ones([2, 3, 5, 2, 2]), mindspore.float32)
>>> output = net(x)
>>> print(output.shape)
(2, 3, 5, 2, 2)
"""
def __init__(self,