!28332 optimize the documentation of english API of ELU, ReLU, Tril etc

Merge pull request !28332 from chenweitao_295/ops_amend_en
This commit is contained in:
i-robot 2022-01-19 10:50:38 +00:00 committed by Gitee
commit e484889f53
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 177 additions and 157 deletions

View File

@ -97,11 +97,10 @@ class CELU(Cell):
class Softmax(Cell):
r"""
Softmax activation function.
Softmax activation function. It is a two-category function :class:`mindspore.nn.Sigmoid` in the promotion of
multi-classification, the purpose is to show the results of multi-classification in the form of probability.
Applies the Softmax function to an n-dimensional input Tensor.
The input is a Tensor of logits transformed with exponential function and then
Calculate the value of the exponential function for the elements of the input Tensor on the `axis`, and then
normalized to lie in range [0, 1] and sum up to 1.
Softmax is defined as:
@ -112,7 +111,8 @@ class Softmax(Cell):
where :math:`x_{i}` is the :math:`i`-th slice in the given dimension of the input Tensor.
Args:
axis (Union[int, tuple[int]]): The axis to apply Softmax operation, -1 means the last dimension. Default: -1.
axis (Union[int, tuple[int]]): The axis to apply Softmax operation, if the dimension of input `x` is x.ndim,
the range of axis is `[-x.ndim, x.ndim)`, -1 means the last dimension. Default: -1.
Inputs:
- **x** (Tensor) - The input of Softmax with data type of float16 or float32.
@ -130,11 +130,13 @@ class Softmax(Cell):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # axis = -1(default), and the sum of return value is 1.0.
>>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
>>> softmax = nn.Softmax()
>>> output = softmax(x)
>>> print(output)
[0.03168 0.01166 0.0861 0.636 0.2341 ]
>>> assert(1.0 == output.sum())
"""
def __init__(self, axis=-1):
@ -207,19 +209,20 @@ class ELU(Cell):
.. math::
E_{i} =
\begin{cases}
x, &\text{if } x \geq 0; \cr
\text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
x_i, &\text{if } x_i \geq 0; \cr
\alpha * (\exp(x_i) - 1), &\text{otherwise.}
\end{cases}
where :math:`x_i` represents the element of the input and :math:`\alpha` represents the `alpha` parameter.
The picture about ELU looks like this `ELU <https://en.wikipedia.org/wiki/
Activation_function#/media/File:Activation_elu.svg>`_.
Args:
alpha (float): The coefficient of negative factor whose type is float. Default: 1.0.
alpha (float): The alpha value of ELU, the data type is float. Default: 1.0.
Inputs:
- **x** (Tensor) - The input of ELU with data type of float16 or float32.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **x** (Tensor) - The input of ELU is a Tensor of any dimension with data type of float16 or float32.
Outputs:
Tensor, with the same type and shape as the `x`.
@ -263,11 +266,10 @@ class ReLU(Cell):
will be suppressed and the active neurons will stay the same.
The picture about ReLU looks like this `ReLU <https://en.wikipedia.org/wiki/
Activation_function#/media/File:Activation_rectified_linear.svg>`_.
Activation_function#/media/File:Activation_rectified_linear.svg>`_ .
Inputs:
- **x** (Tensor) - The input of ReLU. The data type is Number.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
- **x** (Tensor) - The input of ReLU is a Tensor of any dimension. The data type is `number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
Outputs:
Tensor, with the same type and shape as the `x`.
@ -343,12 +345,13 @@ class LeakyReLU(Cell):
r"""
Leaky ReLU activation function.
LeakyReLU is similar to ReLU, but LeakyReLU has a slope that makes it not equal to 0 at x < 0.
The activation function is defined as:
.. math::
\text{leaky_relu}(x) = \begin{cases}x, &\text{if } x \geq 0; \cr
\text{alpha} * x, &\text{otherwise.}\end{cases}
{\alpha} * x, &\text{otherwise.}\end{cases}
where :math:`\alpha` represents the `alpha` parameter.
See https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf
@ -356,8 +359,7 @@ class LeakyReLU(Cell):
alpha (Union[int, float]): Slope of the activation function at x < 0. Default: 0.2.
Inputs:
- **x** (Tensor) - The input of LeakyReLU.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
- **x** (Tensor) - The input of LeakyReLU is a Tensor of any dimension.
Outputs:
Tensor, has the same type and shape as the `x`.

View File

@ -175,13 +175,12 @@ class Dropout(Cell):
class Flatten(Cell):
r"""
Flatten layer for the input.
Flattens a tensor without changing dimension of batch size on the 0-th axis.
Flatten the dimensions other than the 0th dimension of the input Tensor.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened. The data type is Number.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions
- **x** (Tensor) - The input Tensor to be flattened. The data type is
`number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
The shape is :math:`(N, *)` , where :math:`*` means any number of additional dimensions
and the shape can't be ().
Outputs:
@ -1005,19 +1004,23 @@ def tril(x_shape, x_dtype, k):
class Tril(Cell):
"""
Returns a tensor with elements above the kth diagonal zeroed.
Returns a tensor, the elements above the specified main diagonal are set to zero.
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
Divide the matrix elements into upper and lower triangles along the main diagonal (including diagonals).
The parameter `k` controls the diagonal to be considered.
If diagonal = 0, all elements on and below the main diagonal are retained.
Positive values include as many diagonals above the main diagonal, and similarly,
negative values exclude as many diagonals below the main diagonal.
The parameter `k` controls the choice of diagonal.
If `k` = 0, split along the main diagonal and keep all the elements of the lower triangle.
If `k` > 0, select the diagonal `k` along the main diagonal upwards, and keep all the elements of the lower
triangle.
If `k` < 0, select the diagonal `k` along the main diagonal down, and keep all the elements of the lower
triangle.
Inputs:
- **x** (Tensor) - The input tensor. The data type is Number.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
- **k** (Int) - The index of diagonal. Default: 0
- **x** (Tensor) - The input tensor. The data type is
`number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_.
- **k** (Int) - The index of diagonal. Default: 0. If the dimensions of the input matrix are d1 and d2,
the range of k should be in [-min(d1, d2)+1, min(d1, d2)-1], and the output value will be the same as the
input `x` when `k` is out of range.
Outputs:
Tensor, has the same shape and type as input `x`.
@ -1030,6 +1033,7 @@ class Tril(Cell):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case1: k = 0
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
@ -1041,6 +1045,7 @@ class Tril(Cell):
[ 5 6 0 0]
[10 11 12 0]
[14 15 16 17]]
>>> # case2: k = 1
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
@ -1052,6 +1057,7 @@ class Tril(Cell):
[ 5 6 7 0]
[10 11 12 13]
[14 15 16 17]]
>>> # case3: k = 2
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
@ -1063,6 +1069,7 @@ class Tril(Cell):
[ 5 6 7 8]
[10 11 12 13]
[14 15 16 17]]
>>> # case4: k = -1
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],

View File

@ -879,23 +879,22 @@ class MatMul(Cell):
class Moments(Cell):
"""
Calculates the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `input_x` across axes.
If `input_x` is 1-D and axes = [0] this is just the mean and variance of a vector.
Calculate the mean and variance of the input `x` along the specified `axis`.
Args:
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: None.
keep_dims (bool): If true, The dimension of mean and variance are identical with input's.
If false, don't keep these dimensions. Default: None.
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis.
When the value is None, it means to calculate the mean and variance of all values of `x`. Default: None.
keep_dims (bool): If True, the calculation result will retain the dimension of `axis`,
and the dimensions of the mean and variance are the same as the input. If False or None,
the dimension of `axis` will be reduced. Default: None.
Inputs:
- **x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
:math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **x** (Tensor) - Tensor of any dimension used to calculate the mean and variance.
Only float16 and float32 are supported.
Outputs:
- **mean** (Tensor) - The mean of `x`, with the same data type as input `x`.
- **variance** (Tensor) - The variance of `x`, with the same data type as input `x`.
- **mean** (Tensor) - The mean value of `x` on `axis`, with the same data type as input `x`.
- **variance** (Tensor) - The variance of `x` on `axis`, with the same data type as input `x`.
Raises:
TypeError: If `axis` is not one of int, tuple, None.
@ -906,40 +905,38 @@ class Moments(Cell):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
>>> # case1: axis = 0, keep_dims=True
>>> x = Tensor(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), mindspore.float32)
>>> net = nn.Moments(axis=0, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
>>> net = nn.Moments(axis=1, keep_dims=True)
(Tensor(shape=[1, 2, 2], dtype=Float32, value=
[[[ 3.00000000e+00, 4.00000000e+00],
[ 5.00000000e+00, 6.00000000e+00]]]), Tensor(shape=[1, 2, 2], dtype=Float32, value=
[[[ 4.00000000e+00, 4.00000000e+00],
[ 4.00000000e+00, 4.00000000e+00]]]))
>>> # case2: axis = 1, keep_dims=True)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
[[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
>>> net = nn.Moments(axis=2, keep_dims=True)
(Tensor(shape=[2, 1, 2], dtype=Float32, value=
[[[ 2.00000000e+00, 3.00000000e+00]],
[[ 6.00000000e+00, 7.00000000e+00]]]), Tensor(shape=[2, 1, 2], dtype=Float32, value=
[[[ 1.00000000e+00, 1.00000000e+00]],
[[ 1.00000000e+00, 1.00000000e+00]]]))
>>> # case3: axis = 2, keep_dims=None(default)
>>> net = nn.Moments(axis=2)
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
[[[[ 2.00000000e+00, 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]]]),
Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
[[[[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]]]))
>>> net = nn.Moments(axis=3, keep_dims=True)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.50000000e+00, 3.50000000e+00],
[ 5.50000000e+00, 7.50000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.50000000e-01, 2.50000000e-01],
[ 2.50000000e-01, 2.50000000e-01]]))
>>> # case4: axis = None(default), keep_dims=None(default)
>>> net = nn.Moments()
>>> output = net(x)
>>> print(output)
(Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
[[[[ 2.50000000e+00],
[ 4.50000000e+00]]]]), Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
[[[[ 1.25000000e+00],
[ 1.25000000e+00]]]]))
(Tensor(shape=[], dtype=Float32, value= 4.5), Tensor(shape=[], dtype=Float32, value= 5.25))
"""
def __init__(self, axis=None, keep_dims=None):

View File

@ -46,6 +46,7 @@ class LossBase(Cell):
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self, reduction='mean'):
"""Initialize Loss."""
super(LossBase, self).__init__()
@ -156,6 +157,7 @@ class _Loss(LossBase):
"""
Base class for other losses.
"""
def __init__(self, reduction='mean'):
"""Initialize _Loss."""
log.warning("'_Loss' is deprecated from version 1.3 and "
@ -176,11 +178,10 @@ def _check_is_tensor(param_name, input_data, cls_name):
class L1Loss(LossBase):
r"""
L1Loss creates a criterion to measure the mean absolute error (MAE) between :math:`x` and :math:`y` element-wise,
where :math:`x` is the input Tensor and :math:`y` is the labels Tensor.
L1Loss is used to calculate the mean absolute error between the predicted value and the target value.
For simplicity, let :math:`x` and :math:`y` be 1-dimensional Tensor with length :math:`N`,
the unreduced loss (i.e. with argument reduction set to 'none') of :math:`x` and :math:`y` is given as:
Assuming that the :math:`x` and :math:`y` are 1-D Tensor, length: math:`N`, then calculate the loss of :math:`x` and
:math:`y` without dimensionality reduction (the reduction parameter is set to "none"). The formula is as follows:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with } l_n = \left| x_n - y_n \right|,
@ -196,21 +197,21 @@ class L1Loss(LossBase):
Args:
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
Default: "mean".
Default: "mean". If `reduction` is "mean" or "sum", then output a scalar Tensor, if `reduction` is "none",
the shape of the output Tensor is the broadcasted shape.
Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
additional dimensions.
- **labels** (Tensor) - Tensor of shape :math:`(N, *)`, same shape as the `logits` in common cases.
- **logits** (Tensor) - Predicted value, Tensor of any dimension.
- **labels** (Tensor) - Target value, same shape as the `logits` in common cases.
However, it supports the shape of `logits` is different from the shape of `labels`
and they should be broadcasted to each other.
Outputs:
Tensor, loss float tensor, the shape is zero if `reduction` is 'mean' or 'sum',
while the shape of output is the broadcasted shape if `reduction` is 'none'.
Tensor, data type is float.
Raises:
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
ValueError: If `logits` and `labels` have different shapes and cannot be broadcasted to each other.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -232,6 +233,7 @@ class L1Loss(LossBase):
[[0. 1. 2.]
[0. 0. 1.]]
"""
def __init__(self, reduction='mean'):
"""Initialize L1Loss."""
super(L1Loss, self).__init__(reduction)
@ -302,6 +304,7 @@ class MSELoss(LossBase):
[[0. 1. 4.]
[0. 0. 1.]]
"""
def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
@ -349,6 +352,7 @@ class RMSELoss(LossBase):
>>> print(output)
1.0
"""
def __init__(self):
"""Initialize RMSELoss."""
super(RMSELoss, self).__init__()
@ -418,6 +422,7 @@ class MAELoss(LossBase):
[[0. 1. 2.]
[0. 0. 1.]]
"""
def __init__(self, reduction='mean'):
"""Initialize MAELoss."""
super(MAELoss, self).__init__(reduction)
@ -432,43 +437,42 @@ class MAELoss(LossBase):
class SmoothL1Loss(LossBase):
r"""
A loss class for learning region proposals.
SmoothL1 loss function, if the absolute error element-wise between the predicted value and the target value
is less than the set threshold `beta`, the square term is used, otherwise the absolute error term is used.
SmoothL1Loss can be regarded as modified version of L1Loss or a combination of L1Loss and L2Loss.
L1Loss computes the element-wise absolute difference between two input tensors while L2Loss computes the
squared difference between two input tensors. L2Loss often leads to faster convergence but it is less
robust to outliers.
Given two input :math:`x,\ y` of length :math:`N`, the unreduced SmoothL1Loss can be described
as follows:
Given two input :math:`x,\ y`, the SmoothL1Loss can be described as follows:
.. math::
L_{i} =
\begin{cases}
\frac{0.5 (x_i - y_i)^{2}}{\text{beta}}, & \text{if } |x_i - y_i| < \text{beta} \\
|x_i - y_i| - 0.5 \text{beta}, & \text{otherwise. }
\frac{0.5 (x_i - y_i)^{2}}{\beta}, & \text{if } |x_i - y_i| < {\beta} \\
|x_i - y_i| - 0.5 {\beta}, & \text{otherwise.}
\end{cases}
Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear.
Its default value is 1.0. :math:`N` is the batch size. This function returns an
unreduced loss tensor.
Where :math:`{\beta}` represents the threshold `beta`.
.. note::
SmoothL1Loss can be regarded as modified version of L1Loss or a combination of L1Loss and L2Loss.
L1Loss computes the element-wise absolute difference between two input tensors while L2Loss computes the
squared difference between two input tensors. L2Loss often leads to faster convergence but it is less
robust to outliers, and the loss function has better robustness.
Args:
beta (float): A parameter used to control the point where the function will change from
quadratic to linear. Default: 1.0.
beta (float): The loss function calculates the threshold of the transformation between L1Loss and L2Loss.
Default: 1.0.
Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
additional dimensions. Data type must be float16 or float32.
- **labels** (Tensor) - Ground truth data, tensor of shape :math:`(N, *)`,
same shape and dtype as the `logits`.
- **logits** (Tensor) - Predictive value. Tensor of any dimension. Data type must be float16 or float32.
- **labels** (Tensor) - Ground truth data, same shape and dtype as the `logits`.
Outputs:
Tensor, loss float tensor, same shape and dtype as the `logits`.
Raises:
TypeError: If `beta` is not a float.
TypeError: If `logits` or `labels` are not Tensor.
TypeError: If dtype of `logits` or `labels` is neither float16 not float32.
TypeError: If dtype of `logits` or `labels` are not the same.
ValueError: If `beta` is less than or equal to 0.
ValueError: If shape of `logits` is not the same as `labels`.
@ -483,6 +487,7 @@ class SmoothL1Loss(LossBase):
>>> print(output)
[0. 0. 0.5]
"""
def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss."""
super(SmoothL1Loss, self).__init__()
@ -534,6 +539,7 @@ class SoftMarginLoss(LossBase):
>>> print(output)
0.6764238
"""
def __init__(self, reduction='mean'):
super(SoftMarginLoss, self).__init__()
self.soft_margin_loss = P.SoftMarginLoss(reduction)
@ -605,6 +611,7 @@ class SoftmaxCrossEntropyWithLogits(LossBase):
>>> print(output)
[30.]
"""
def __init__(self,
sparse=False,
reduction='none'):
@ -675,6 +682,7 @@ class DiceLoss(LossBase):
>>> print(output)
0.38596618
"""
def __init__(self, smooth=1e-5):
"""Initialize DiceLoss."""
super(DiceLoss, self).__init__()
@ -762,6 +770,7 @@ class MultiClassDiceLoss(LossBase):
>>> print(output)
0.54958105
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
"""Initialize MultiClassDiceLoss."""
super(MultiClassDiceLoss, self).__init__()
@ -802,7 +811,7 @@ class MultiClassDiceLoss(LossBase):
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]
return total_loss / label.shape[1]
class SampledSoftmaxLoss(LossBase):
@ -1161,6 +1170,7 @@ class CosineEmbeddingLoss(LossBase):
>>> print(output)
0.0003425479
"""
def __init__(self, margin=0.0, reduction="mean"):
"""Initialize CosineEmbeddingLoss."""
super(CosineEmbeddingLoss, self).__init__(reduction)

View File

@ -677,24 +677,23 @@ class DynamicShape(Primitive):
class Squeeze(PrimitiveWithInfer):
"""
Returns a tensor with the same data type but dimensions of 1 are removed based on `axis`.
Return the Tensor after deleting the dimension of size 1 in the specified `axis`.
If :math:`axis=()`, it will remove all the dimensions of size 1.
If `axis` is specified, it will remove the dimensions of size 1 in the given `axis`.
If `axis` is None, it will remove all the dimensions of size 1.
For example, if input is of shape: (A×1×B×C×1×D), then the out tensor will be of shape: (A×B×C×D);
When dim is given, a squeeze operation is done only in the given dimension.
If input is of shape: (A×1×B), squeeze(input, 0) leaves the tensor unchanged,
but squeeze(input, 1) will squeeze the tensor to the shape (A×B).
Please note that in dynamic graph mode, the output Tensor will share data with the input Tensor,
and there is no Tensor data copy process.
For example, if the dimension is not specified :math:`axis=()`, input shape is (A, 1, B, C, 1, D),
then the shape of the output Tensor is (A, B, C, D). If the dimension is specified, the squeeze operation
is only performed in the specified dimension. If input shape is (A, 1, B), input Tensor will not be
changed when :math:`axis=0` , but when :math:`axis=1` , the shape of the input Tensor will be changed to (A, B).
Note:
The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim]`.
- Please note that in dynamic graph mode, the output Tensor will share data with the input Tensor,
and there is no Tensor data copy process.
- The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim]`.
Args:
axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove
all the dimensions that are equal to 1. If specified, it must be int32 or int64.
all the dimensions of size 1 in the given axis parameter. If specified, it must be int32 or int64.
Default: (), an empty tuple.
Inputs:
@ -846,24 +845,25 @@ class Unique(Primitive):
class Gather(Primitive):
r"""
Returns a slice of the input tensor based on the specified indices and axis.
Slices the input tensor base on the indices at specified axis. See the following example for more clear.
Returns the slice of the input Tensor corresponding to the elements of `input_indices` on the specified `axis`.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The original Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape[axis])` which are only validated on CPU. The data type can be int32 or int64.
- **input_params** (Tensor) - The original Tensor. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **input_indices** (Tensor) - Index tensor to be sliced, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. The data type can be int32 or int64.
- **axis** (int) - Specifies the dimension index to gather indices.
.. note::
The value of input_indices must be in the range of `[0, input_param.shape[axis])`, and report an error if it
exceeds this range.
Outputs:
Tensor, the shape of tensor is
:math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`.
Raises:
TypeError: If `axis` is not an int.
TypeError: If `input_indices` is not an int type Tensor.
TypeError: If `input_indices` is not an int.
Supported Platforms:
@ -1266,17 +1266,17 @@ class Size(PrimitiveWithInfer):
class Fill(PrimitiveWithInfer):
"""
Creates a tensor filled with a scalar value.
Creates a tensor with shape described by the first argument and fills it with values in the second argument.
Create a Tensor of the specified shape and fill it with the specified value.
Inputs:
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
- **value** (scalar) - Value to fill the returned tensor. Only constant value is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. The data type only supports
`bool_ <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ and
`number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
- **shape** (tuple[int]) - The specified shape of output tensor.
- **value** (Union(number.Number, bool)) - Value to fill the returned tensor.
Outputs:
Tensor, has the same type and shape as input value.
Tensor.
Raises:
TypeError: If `shape` is not a tuple.

View File

@ -3737,14 +3737,7 @@ class NotEqual(_LogicBinaryOp):
class Greater(_LogicBinaryOp):
r"""
Computes the boolean value of :math:`x > y` element-wise.
Inputs of `x` and `y` comply with the implicit type conversion rules to make the data types consistent.
The inputs must be two tensors or one tensor and one scalar.
When the inputs are two tensors,
dtypes of them cannot be bool at the same time, and the shapes of them could be broadcast.
When the inputs are one tensor and one scalar,
the scalar could only be a constant.
Compare the value of the input parameters :math:`x,y` element-wise, and the output result is a bool value.
.. math::
@ -3754,13 +3747,23 @@ class Greater(_LogicBinaryOp):
\end{cases}
Note:
Broadcasting is supported.
- Inputs of `x` and `y` comply with the implicit type conversion rules to make the data types consistent.
- The inputs must be two tensors or one tensor and one scalar.
- When the inputs are two tensors, dtypes of them cannot be bool at the same time,
and the shapes of them can be broadcast.
- When the inputs are one tensor and one scalar, the scalar could only be a constant.
- Broadcasting is supported.
- If the input Tensor can be broadcast, the low dimension will be extended to the corresponding high dimension
in another input by copying the value of the dimension.
Inputs:
- **x** (Union[Tensor, Number, bool]) - The first input is a number or
a bool or a tensor whose data type is number or bool.
- **y** (Union[Tensor, Number, bool]) - The second input is a number or
a bool when the first input is a tensor or a tensor whose data type is number or bool.
- **x** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
a bool or a tensor whose data type is
`number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
`bool_ <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
- **y** (Union[Tensor, number.Number, bool]) - The second input, when the first input is a Tensor,
the second input should be a number.Number or bool value, or a Tensor whose data type is number or bool_.
When the first input is Scalar, the second input must be a Tensor whose data type is number or bool_.
Outputs:
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
@ -5237,14 +5240,13 @@ class Inv(Primitive):
out_i = \frac{1}{x_{i} }
Inputs:
- **x** (Tensor) - The shape of tensor is
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Must be one of the following types: float16, float32, int32.
- **x** (Tensor) - Tensor of any dimension. Must be one of the following types: float16, float32 or int32.
Outputs:
Tensor, has the same shape and data type as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not one of float16, float32, int32.
Supported Platforms:

View File

@ -708,7 +708,10 @@ class ReLUV2(Primitive):
class Elu(Primitive):
r"""
Computes exponential linear:
Exponential Linear Uint activation function.
Applies the exponential linear unit function element-wise.
The activation function is defined as:
.. math::
@ -718,15 +721,14 @@ class Elu(Primitive):
x & \text{if } x \gt 0\\
\end{array}\right.
The data type of input tensor must be float.
The picture about ELU looks like this `ELU <https://en.wikipedia.org/wiki/
Activation_function#/media/File:Activation_elu.svg>`_ .
Args:
alpha (float): The coefficient of negative factor whose type is float,
only support '1.0' currently. Default: 1.0.
alpha (float): The alpha value of ELU, the data type is float. Only support '1.0' currently. Default: 1.0.
Inputs:
- **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
additional dimensions, with float16 or float32 data type.
- **input_x** (Tensor) - The input of ELU is a Tensor of any dimension with data type of float16 or float32.
Outputs:
Tensor, has the same shape and data type as `input_x`.
@ -2090,10 +2092,8 @@ class Conv2DTranspose(Conv2DBackpropInput):
class BiasAdd(Primitive):
r"""
Returns sum of input and bias tensor.
Adds the 1-D bias tensor to the input tensor, and broadcasts the shape on all axis
except for the channel axis.
Returns the sum of the input Tensor and the bias Tensor. Before adding, the bias Tensor will be broadcasted to be
consistent with the shape of the input Tensor.
Args:
data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW'.
@ -2102,16 +2102,19 @@ class BiasAdd(Primitive):
Inputs:
- **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions.
The data type should be float16 or float32.
- **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
`bias` must be the same as `input_x`'s channel dimension. The data type should be float16 or float32.
- **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. C must be the same as channel dimension C of
`input_x`. The data type should be float16 or float32.
Outputs:
Tensor, with the same shape and data type as `input_x`.
Raises:
TypeError: If `data_format` is not a str.
ValueError: If value of `data_format` is not in the range of ['NHWC','NCHW','NCDHW'].
TypeError: If `input_x` or `bias` is not a Tensor.
TypeError: If dtype of `input_x` or `bias` is neither float16 nor float32.
TypeError: If dtype of `input_x` or `bias` is inconsistent.
TypeError: If dimension of `input_x` is not in the range [2, 5].
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -3131,15 +3134,13 @@ class L2Normalize(PrimitiveWithInfer):
This operator will normalize the input using the given axis. The function is shown as follows:
.. math::
\displaylines{{\text{output} = \frac{x}{\sqrt{\text{max}(\parallel x_i \parallel^p , \epsilon)} } } \\
{\parallel x_i \parallel^p = (\sum_{i}^{}\left | x_i \right | ^p )^{1/p}} }
\displaylines{{\text{output} = \frac{x}{\sqrt{\text{max}( \sum_{i}^{}\left | x_i \right | ^2, \epsilon)}}}}
where :math:`\epsilon` is epsilon and :math:`\sum_{i}^{}\left | x_i \right | ^p` calculates
along the dimension `axis`.
where :math:`\epsilon` is epsilon and :math:`\sum_{i}^{}\left | x_i \right | ^2` calculate the sum of squares of
the input `x` along the dimension `axis`.
Args:
axis (Union[list(int), tuple(int), int]): The starting axis for the input to apply the L2 Normalization.
Default: 0.
axis (Union[list(int), tuple(int), int]): Specify the axis for calculating the L2 norm. Default: 0.
epsilon (float): A small value added for numerical stability. Default: 1e-4.
Inputs:
@ -3154,6 +3155,7 @@ class L2Normalize(PrimitiveWithInfer):
TypeError: If `epsilon` is not a float.
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
ValueError: If dimension of `x` is not greater than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``