ops batch_norm, bias_add, binary_cross_entropy, bmm functional api

panfengfeng 2022-09-19 11:50:04 +08:00
parent d9b371e6a1
commit 4391f3bd32
8 changed files with 135 additions and 61 deletions


@@ -0,0 +1,41 @@
mindspore.ops.batch_norm
========================
.. py:function:: mindspore.ops.batch_norm(input_x, running_mean, running_var, weight, bias, training=False, momentum=0.1, eps=1e-5)
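Normalizes the input data and updates the parameters.
Batch normalization is widely used in convolutional neural networks. This operation normalizes the input to avoid internal covariate shift, as described in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. Training uses mini-batch data and learned parameters; the learned parameters appear in the following formula: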
.. math::
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
where :math:`\gamma` is `weight`, :math:`\beta` is `bias`, :math:`\epsilon` is `eps`, :math:`mean` is the mean of :math:`x`, and :math:`variance` is the variance of :math:`x`.
.. warning::
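- If this operation is used for inference and the outputs "reserve_space_1" and "reserve_space_2" are available, the value of "reserve_space_1" equals "mean" and the value of "reserve_space_2" equals "variance".
- On Ascend 310, the result accuracy fails to reach 1‰ because of the square root instruction.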
.. note::
- If `training` is False, `running_mean`, `running_var`, `weight`, and `bias` are Tensors.
- If `training` is True, `running_mean`, `running_var`, `weight`, and `bias` are Parameters.
Parameters:
- **input_x** (Tensor) - The input data, a Tensor of shape :math:`(N, C)` with data type float16 or float32.
- **running_mean** (Union[Tensor, Parameter]) - Shape :math:`(C,)`, with the same data type as `weight`.
- **running_var** (Union[Tensor, Parameter]) - Shape :math:`(C,)`, with the same data type as `weight`.
- **weight** (Union[Tensor, Parameter]) - Shape :math:`(C,)`, with data type float16 or float32.
- **bias** (Union[Tensor, Parameter]) - Shape :math:`(C,)`, with the same data type as `weight`.
- **training** (bool) - If `training` is `True`, `running_mean` and `running_var` are computed during training.
If `training` is `False`, they are loaded from the checkpoint during inference. Default: False.
- **momentum** (float) - The momentum used for the running mean and running variance (for example, :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). The momentum value must be in [0, 1]. Default: 0.1.
- **eps** (float) - A value added to the denominator to ensure numerical stability. Default: 1e-5.
Returns:
Tensor, with the same data type and shape as `input_x`; the shape is :math:`(N, C)`.
Raises:
- **TypeError** - `training` is not a bool.
- **TypeError** - The data type of `eps` or `momentum` is not float.
- **TypeError** - `input_x`, `weight`, `bias`, `running_mean`, or `running_var` is not a Tensor.
- **TypeError** - The data type of `input_x` or `weight` is neither float16 nor float32.
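
The formula and the momentum update above can be illustrated with a small plain-Python sketch. This is not MindSpore's implementation: `batch_norm_reference` is a hypothetical helper that works on nested lists, it assumes the biased (divide-by-N) batch variance, and it returns the updated running statistics alongside the output purely for illustration, whereas the documented function returns only the output Tensor.

```python
import math


def batch_norm_reference(input_x, running_mean, running_var, weight, bias,
                         training=False, momentum=0.1, eps=1e-5):
    """Plain-Python sketch of the documented formula for input of shape (N, C).

    Hypothetical helper, not the MindSpore kernel. Assumes biased variance.
    Returns (output, new_running_mean, new_running_var).
    """
    n = len(input_x)
    c = len(weight)
    if training:
        # Per-channel statistics of the current mini-batch.
        mean = [sum(row[j] for row in input_x) / n for j in range(c)]
        var = [sum((row[j] - mean[j]) ** 2 for row in input_x) / n
               for j in range(c)]
        # new_running = (1 - momentum) * running + momentum * current
        new_mean = [(1 - momentum) * r + momentum * m
                    for r, m in zip(running_mean, mean)]
        new_var = [(1 - momentum) * r + momentum * v
                   for r, v in zip(running_var, var)]
    else:
        # Inference: normalize with the stored running statistics.
        mean, var = running_mean, running_var
        new_mean, new_var = list(running_mean), list(running_var)
    # y = (x - mean) / sqrt(variance + eps) * gamma + beta
    out = [[(row[j] - mean[j]) / math.sqrt(var[j] + eps) * weight[j] + bias[j]
            for j in range(c)] for row in input_x]
    return out, new_mean, new_var
```

With `training=False` the sketch normalizes with the supplied `running_mean` and `running_var` and leaves them unchanged, which matches the note that they are loaded from a checkpoint during inference.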


@@ -0,0 +1,19 @@
mindspore.ops.bias_add
======================
.. py:function:: mindspore.ops.bias_add(input_x, bias)
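Returns the sum of the input Tensor and the bias Tensor. Before the addition, the bias Tensor is broadcast to the shape of the input Tensor.
Parameters:
- **input_x** (Tensor) - The input Tensor. The shape can have 2 to 5 dimensions. The data type should be float16 or float32.
- **bias** (Tensor) - The bias Tensor, with shape :math:`(C)`. C must equal the channel dimension C of `input_x`. The data type should be float16 or float32.
Returns:
Tensor, with the same shape and data type as `input_x`.
Raises:
- **TypeError** - `input_x` or `bias` is not a Tensor.
- **TypeError** - The data type of `input_x` or `bias` is neither float16 nor float32.
- **TypeError** - The data types of `input_x` and `bias` are inconsistent.
- **TypeError** - The number of dimensions of `input_x` is not in the range [2, 5].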
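
The broadcasting described above can be sketched in plain Python on nested lists. `bias_add_reference` is a hypothetical helper, not the MindSpore operator, and the sketch assumes the channel dimension is axis 1 (an NC... layout), which the text above does not state explicitly. It does not validate the 2 to 5 dimension limit or the data types.

```python
def bias_add_reference(input_x, bias, channel_axis=1):
    """Add bias[c] to every element of channel c of a nested-list tensor.

    Hypothetical helper for illustration. Assumes the channel dimension
    is `channel_axis` (default 1), which is an assumption of this sketch.
    """
    def add(node, depth, b):
        if not isinstance(node, list):
            # Leaf element: add the bias selected at the channel axis.
            return node + b
        if depth == channel_axis:
            if len(node) != len(bias):
                raise ValueError("bias length must match the channel dimension")
            # Select one bias value per channel and carry it downward.
            return [add(child, depth + 1, bias[i])
                    for i, child in enumerate(node)]
        return [add(child, depth + 1, b) for child in node]

    return add(input_x, 0, 0.0)
```

For a 2-D input of shape (N, C) each column receives its own bias value; for higher-rank inputs the same value is added to every element of the corresponding channel.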


@@ -0,0 +1,39 @@
mindspore.ops.binary_cross_entropy
==================================
.. py:function:: mindspore.ops.binary_cross_entropy(logits, labels, weight=None, reduction='mean')
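Computes the binary cross entropy loss between the predicted values and the target values.
Let the input `logits` be :math:`x`, the input `labels` be :math:`y`, and the output be :math:`\ell(x, y)`. Then,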
.. math::
L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
where :math:`L` denotes the loss values of all batch_sizes, :math:`l` denotes the loss value of one batch_size, and :math:`n` denotes one batch_size in the range 1 to N.
.. math::
\ell(x, y) = \begin{cases}
L, & \text{if reduction} = \text{'none';}\\
\operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
\end{cases}
.. warning::
The value of `x` must be in the range 0 to 1, and the value of `y` must be `0` or `1`.
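Parameters:
- **logits** (Tensor) - The input predicted values, a Tensor of any dimension. Its data type is float16 or float32.
- **labels** (Tensor) - The input target values, with the same shape as `logits`. The data type is float16 or float32.
- **weight** (Tensor) - Specifies the weight of the binary cross entropy for each batch. Broadcasting is supported so that its shape matches the shape of `logits`. The data type must be float16 or float32. Default: None.
- **reduction** (str) - Specifies how the output is reduced. The value is 'mean', 'sum', or 'none', case-insensitive. If 'none', no reduction is performed. Default: 'mean'.
Returns:
Tensor or Scalar. If `reduction` is 'none', the output is a Tensor with the same shape and data type as the input `logits`. Otherwise, the output is a Scalar.
Raises:
- **TypeError** - The input `logits`, `labels`, or `weight` is not a Tensor.
- **TypeError** - The data type of the input `logits`, `labels`, or `weight` is neither float16 nor float32.
- **ValueError** - `reduction` is not 'none', 'mean', or 'sum'.
- **ValueError** - The shape of `labels` differs from that of `logits` or `weight`.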
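
The loss formula and the three reduction modes can be checked by hand with a plain-Python sketch. `binary_cross_entropy_reference` is a hypothetical helper operating on flat lists, not the MindSpore operator. It assumes every prediction lies strictly between 0 and 1, since `math.log(0)` raises an error, and it treats a missing `weight` as all ones.

```python
import math


def binary_cross_entropy_reference(logits, labels, weight=None, reduction="mean"):
    """Plain-Python sketch of the documented loss for flat lists.

    Hypothetical helper for illustration; predictions must be strictly
    inside (0, 1) in this sketch.
    """
    mode = reduction.lower()  # the text above says the value is case-insensitive
    if mode not in ("none", "mean", "sum"):
        raise ValueError("reduction must be 'none', 'mean' or 'sum'")
    if len(labels) != len(logits):
        raise ValueError("labels must have the same shape as logits")
    if weight is None:
        weight = [1.0] * len(logits)  # assumption: no weight means w_n = 1
    # l_n = -w_n * [y_n * log(x_n) + (1 - y_n) * log(1 - x_n)]
    losses = [-w * (y * math.log(x) + (1 - y) * math.log(1 - x))
              for x, y, w in zip(logits, labels, weight)]
    if mode == "none":
        return losses
    total = sum(losses)
    return total / len(losses) if mode == "mean" else total
```

For example, predictions of 0.5 against labels 1 and 0 each contribute a loss of log 2, so the 'mean' reduction gives log 2 and the 'sum' reduction gives 2 log 2.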


@@ -0,0 +1,19 @@
mindspore.ops.bmm
=================
.. py:function:: mindspore.ops.bmm(input_x, mat2)
Performs matrix multiplication of two Tensors along the batch dimensions.
.. math::
\text{output}[..., :, :] = \text{matrix}(x[..., :, :]) * \text{matrix}(y[..., :, :])
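Parameters:
- **input_x** (Tensor) - The first Tensor to be multiplied. Its shape is :math:`(*B, N, C)`, where :math:`*B` denotes the batch size, which can be multi-dimensional, and :math:`N` and :math:`C` are the sizes of the last two dimensions.
- **mat2** (Tensor) - The second Tensor to be multiplied. Its shape is :math:`(*B, C, M)`.
Returns:
Tensor, with shape :math:`(*B, N, M)`.
Raises:
- **ValueError** - The length of the shape of `input_x` is not equal to the length of the shape of `mat2`, or the length of the shape of `input_x` is less than `3`.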
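
The shape rule :math:`(*B, N, C) \times (*B, C, M) \to (*B, N, M)` can be illustrated with a plain-Python sketch. `bmm_reference` is a hypothetical helper, not the MindSpore operator, and it is restricted to a single batch dimension (3-D nested lists) to keep the example short.

```python
def bmm_reference(input_x, mat2):
    """Batch matrix multiply for nested lists of shape (B, N, C) and (B, C, M).

    Hypothetical helper for illustration; handles one batch dimension only.
    """
    if len(input_x) != len(mat2):
        raise ValueError("batch sizes must match")
    out = []
    for a, b in zip(input_x, mat2):
        if len(a[0]) != len(b):
            raise ValueError("inner dimensions must match")
        rows, inner, cols = len(a), len(b), len(b[0])
        # Ordinary matrix product for this batch element.
        out.append([[sum(a[i][k] * b[k][j] for k in range(inner))
                     for j in range(cols)] for i in range(rows)])
    return out
```

Each batch element is multiplied independently, so the output keeps the batch size of the inputs and takes N from `input_x` and M from `mat2`.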


@@ -5197,7 +5197,7 @@ def bmm(input_x, mat2):
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
Raises:
ValueError: If length of shape of `input_x` is not equal to length of shape of `y` or
ValueError: If length of shape of `input_x` is not equal to length of shape of `mat2` or
length of shape of `input_x` is less than `3`.
Supported Platforms:


@@ -2891,29 +2891,22 @@ def batch_norm(input_x, running_mean, running_var, weight, bias, training=False,
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon, :math:`mean` is the
where :math:`\gamma` is `weight`, :math:`\beta` is `bias`, :math:`\epsilon` is `eps`, :math:`mean` is the
mean of `input_x`, :math:`variance` is the variance of `input_x`.
.. warning::
- For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
.. note::
- If `training` is `False`, `weight`, `bias`, `running_mean` and `running_var` are Tensors.
- If `training` is `True`, `weight`, `bias`, `running_mean` and `running_var` are Parameters.
Args:
If running_mean is `False`, `scale`, `bias`, `mean` and `variance` are Tensors.
input_x (Tensor): Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
running_mean (Tensor): Tensor of shape :math:`(C,)`, has the same data type with `scale`.
running_var (Tensor): Tensor of shape :math:`(C,)`, has the same data type with `scale`.
weight (Tensor): Tensor of shape :math:`(C,)`, with float16 or float32 data type.
bias (Tensor): Tensor of shape :math:`(C,)`, has the same data type with `scale`.
If `training` is `True`, `scale`, `bias`, `mean` and `variance` are Parameters.
input_x (Tensor): Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
running_mean (Parameter): Parameter of shape :math:`(C,)`, has the same data type with `scale`.
running_var (Parameter): Parameter of shape :math:`(C,)`, has the same data type with `scale`.
weight (Parameter): Parameter of shape :math:`(C,)`, with float16 or float32 data type.
bias (Parameter): Parameter of shape :math:`(C,)`, has the same data type with `scale`.
input_x (Tensor): Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
running_mean (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
running_var (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
weight (Union[Tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
bias (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`.
training (bool): If `training` is `True`, `mean` and `variance` are computed during training.
If `training` is `False`, they're loaded from checkpoint during inference. Default: False.
momentum (float): The hyper parameter to compute moving average for `running_mean` and `running_var`
@@ -2926,9 +2919,9 @@ def batch_norm(input_x, running_mean, running_var, weight, bias, training=False,
Raises:
TypeError: If `training` is not a bool.
TypeError: If dtype of `epsilon` or `momentum` is not float.
TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
TypeError: If dtype of `eps` or `momentum` is not float.
TypeError: If `input_x`, `weight`, `bias`, `running_mean` or `running_var` is not a Tensor.
TypeError: If dtype of `input_x`, `weight` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
@@ -3022,10 +3015,10 @@ def binary_cross_entropy(logits, labels, weight=None, reduction='mean'):
Otherwise, it is a scalar Tensor.
Raises:
TypeError: If `logits`, `labels` or `weight` is not a Tensor.
TypeError: If dtype of `logits`, `labels` or `weight` (if given) is neither float16 nor float32.
ValueError: If `reduction` is not one of 'none', 'mean' or 'sum'.
ValueError: If shape of `labels` is not the same as `logits` or `weight` (if given).
TypeError: If `logits`, `labels` or `weight` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``


@@ -2002,28 +2002,7 @@ class Argmax(PrimitiveWithInfer):
"""
Returns the indices of the maximum value of a tensor across the axis.
If the shape of input tensor is :math:`(x_1, ..., x_N)`, the shape of the output tensor will be
:math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
Args:
axis (int): Axis where the Argmax operation applies to. Default: -1.
output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
Default: `mindspore.dtype.int32`.
Inputs:
- **input_x** (Tensor) - Input tensor. :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Support data type list as follows:
- Ascend: Float16, Float32.
- GPU: Float16, Float32.
- CPU: Float16, Float32, Float64.
Outputs:
Tensor, indices of the max value of input tensor across the axis.
Raises:
TypeError: If `axis` is not an int.
TypeError: If `output_type` is neither int32 nor int64.
Refer to :func:`mindspore.ops.argmax` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``


@@ -1593,23 +1593,7 @@ class AccumulateNV2(Primitive):
"""
Computes accumulation of all input tensors element-wise.
AccumulateNV2 is similar to AddN, but there is a significant difference
among them: AccumulateNV2 will not wait for all of its inputs to be ready
before summing. That is to say, AccumulateNV2 is able to save
memory when inputs are ready at different time since the minimum temporary
storage is proportional to the output size rather than the input size.
Inputs:
- **x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
is made up of multiple tensors whose dtype is number to be added together.
Each element of tuple or list should have the same shape.
Outputs:
Tensor, has the same shape and dtype as each entry of the `x`.
Raises:
TypeError: If `x` is neither tuple nor list.
ValueError: If there is an input element with a different shape.
Refer to :func:`mindspore.ops.accumulate_n` for more detail.
Supported Platforms:
``Ascend``