forked from mindspore-Ecosystem/mindspore
add tensor norm
This commit is contained in:
parent
057da5a11d
commit
4e71de40be
|
@ -1,6 +1,6 @@
|
|||
mindspore.Tensor.norm
|
||||
=====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.norm(axis, p=2, keep_dims=False, epsilon=1e-12)
|
||||
.. py:method:: mindspore.Tensor.norm(ord=None, dim=None, keepdim=False, *, dtype=None)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.norm`。
|
||||
|
|
|
@ -5,4 +5,33 @@
|
|||
|
||||
沿最后一个维度查找 `k` 个最大元素和对应的索引。
|
||||
|
||||
更多参考详见 :func:`mindspore.ops.top_k`。
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为False,它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大元素,并将其值和索引输出为Tensor。`values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
|
||||
|
||||
.. math::
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
如果两个比较的元素相同,则优先返回索引值较小的元素。
|
||||
|
||||
参数:
|
||||
- **sorted** (bool, 可选) - 如果为True,则获取的元素将按值降序排序。如果为False,则获取的元素将按值升序排序。默认值:True。
|
||||
|
||||
输入:
|
||||
- **input_x** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大元素的数量,必须为常量。
|
||||
|
||||
输出:
|
||||
由 `values` 和 `indices` 组成的tuple。
|
||||
|
||||
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `sorted` 不是bool。
|
||||
- **TypeError** - 如果 `input_x` 不是Tensor。
|
||||
- **TypeError** - 如果 `k` 不是int。
|
||||
- **TypeError** - 如果 `input_x` 的数据类型不是以下之一:float16、float32或int32。
|
|
@ -5,45 +5,44 @@ mindspore.ops.norm
|
|||
|
||||
返回给定Tensor的矩阵范数或向量范数。
|
||||
|
||||
该函数计算向量范数或者矩阵范数的规则如下:
|
||||
|
||||
- 如果 `dim` 是一个整型,将会计算向量范数。
|
||||
|
||||
- 如果 `dim` 是一个2-tuple,将会计算矩阵范数。
|
||||
|
||||
- 如果 `dim` 为None且 `ord` 为None,A将会被展平为1D并计算向量的2-范数。
|
||||
|
||||
- 如果 `dim` 为None且 `ord` 不为None,A必须为1D或者2D。
|
||||
|
||||
`ord` 为norm的计算模式。支持下列norm模式。
|
||||
|
||||
====================== ========================= ========================================================
|
||||
`ord` 矩阵范数 向量范数
|
||||
====================== ========================= ========================================================
|
||||
`None` (默认值) Frobenius norm `2`-norm (参考最下方公式)
|
||||
`'fro'` Frobenius norm 不支持
|
||||
`'nuc'` nuclear norm 不支持
|
||||
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
|
||||
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
|
||||
`0` 不支持 `sum(x != 0)`
|
||||
`1` `max(sum(abs(x), dim=0))` 参考下方公式
|
||||
`-1` `min(sum(abs(x), dim=0))` 参考下方公式
|
||||
`2` 最大奇异值 参考下方公式
|
||||
`-2` 最小奇异值 参考下方公式
|
||||
其余int或float值 不支持 `sum(abs(x)^{ord})^{(1 / ord)}`
|
||||
====================== ========================= ========================================================
|
||||
================= ================================== ==============================================
|
||||
`ord` 矩阵范数 向量范数
|
||||
================= ================================== ==============================================
|
||||
`None` (默认值) Frobenius norm `2`-norm (参考最下方公式)
|
||||
`'fro'` Frobenius norm 不支持
|
||||
`'nuc'` nuclear norm 不支持
|
||||
`inf` :math:`max(sum(abs(x), dim=1))` :math:`max(abs(x))`
|
||||
`-inf` :math:`min(sum(abs(x), dim=1))` :math:`min(abs(x))`
|
||||
`0` 不支持 :math:`sum(x != 0)`
|
||||
`1` :math:`max(sum(abs(x), dim=0))` 参考最下方公式
|
||||
`-1` :math:`min(sum(abs(x), dim=0))` 参考最下方公式
|
||||
`2` 最大奇异值 参考最下方公式
|
||||
`-2` 最小奇异值 参考最下方公式
|
||||
其余int或float值 不支持 :math:`sum(abs(x)^{ord})^{(1 / ord)}`
|
||||
================= ================================== ==============================================
|
||||
|
||||
参数:
|
||||
- **A** (Tensor) - shape为 (*, n) 或者 (*, m, n)的Tensor,其中*是零个或多个batch维度。
|
||||
- **A** (Tensor) - shape为 :math:`(*, n)` 或者 :math:`(*, m, n)` 的Tensor,其中*是零个或多个batch维度。
|
||||
- **ord** (Union[int, float, inf, -inf, 'fro', 'nuc'], 可选) - norm的模式。行为参考上表。默认值:None。
|
||||
- **dim** (Union[int, Tuple(int)], 可选) - 计算向量范数或矩阵范数的维度。有关 `dim` = `None` 时的行为,请参见上文。默认值:None。
|
||||
- **dim** (Union[int, Tuple(int)], 可选) - 计算向量范数或矩阵范数的维度。默认值:None。
|
||||
|
||||
- 当 `dim` 为int时,会按向量范数计算。
|
||||
|
||||
- 当 `dim` 为一个二元组时,会按矩阵范数计算。
|
||||
|
||||
- 当 `dim` 为None且 `ord` 为None,`A` 将会被展平为1D并计算向量的2-范数。
|
||||
|
||||
- 当 `dim` 为None且 `ord` 不为None,`A` 必须为1D或者2D。
|
||||
|
||||
- **keepdim** (bool) - 输出Tensor是否保留原有的维度。默认值:False。
|
||||
|
||||
关键字参数:
|
||||
- **dtype** (:class:`mindspore.dtype`, 可选) - 如果指定,则在执行之前将输入Tensor转换为dtype类型,返回的Tensor类型也将为dtype。默认值:None。
|
||||
- **dtype** (:class:`mindspore.dtype`, 可选) - 如果设置此参数,则会在执行之前将A转换为指定的类型,返回的Tensor类型也将为指定类型。默认值:None。
|
||||
|
||||
返回:
|
||||
实值Tensor。
|
||||
Tensor,在指定维度 `dim` 上进行范数计算的结果,与输入 `A` 的数据类型相同。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `dim` 超出范围。
|
||||
|
|
|
@ -6,9 +6,9 @@ mindspore.ops.topk
|
|||
沿给定维度查找 `k` 个最大或最小元素和对应的索引。
|
||||
|
||||
.. warning::
|
||||
- 如果 `sorted` 设置为'False',它将使用aicpu运算符,性能可能会降低。
|
||||
- 如果 `sorted` 设置为False,它将使用aicpu运算符,性能可能会降低。
|
||||
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大或最小元素,并将其值和索引输出为Tensor。因此, `values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
如果 `input_x` 是一维Tensor,则查找Tensor中 `k` 个最大或最小元素,并将其值和索引输出为Tensor。`values[k]` 是 `input_x` 中 `k` 个最大元素,其索引是 `indices[k]` 。
|
||||
|
||||
对于多维矩阵,计算给定维度中最大或最小的 `k` 个元素,因此:
|
||||
|
||||
|
@ -19,13 +19,13 @@ mindspore.ops.topk
|
|||
|
||||
参数:
|
||||
- **input_x** (Tensor) - 需计算的输入,数据类型必须为float16、float32或int32。
|
||||
- **k** (int) - 指定计算最大或最小元素的数量,需要是常量。
|
||||
- **k** (int) - 指定计算最大或最小元素的数量,必须为常量。
|
||||
- **dim** (int, 可选) - 需要排序的维度。默认值:None。
|
||||
- **largest** (bool, 可选) - 如果为False,则会返回前k个最小值。默认值:True。
|
||||
- **sorted** (bool, 可选) - 如果为True,则获取的元素将按值降序排序。默认值:True。
|
||||
- **sorted** (bool, 可选) - 如果为True,则获取的元素将按值降序排序。如果为False,则获取的元素将按值升序排序。默认值:True。
|
||||
|
||||
返回:
|
||||
2个Tensor组成的tuple, `values` 和 `indices` 。
|
||||
由 `values` 和 `indices` 组成的tuple。
|
||||
|
||||
- **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素或最小元素。
|
||||
- **indices** (Tensor) - `k` 最大元素的对应索引。
|
||||
|
|
|
@ -1317,12 +1317,12 @@ class Tensor(Tensor_):
|
|||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=invalid-name
|
||||
def norm(self, A, ord=None, dim=None, keepdim=False, *, dtype=None):
|
||||
def norm(self, ord=None, dim=None, keepdim=False, *, dtype=None):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.norm`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('norm')(self, A, ord, dim, keepdim, dtype=dtype)
|
||||
return tensor_operator_registry.get('norm')(self, ord, dim, keepdim, dtype=dtype)
|
||||
|
||||
def renorm(self, p, dim, maxnorm):
|
||||
"""
|
||||
|
|
|
@ -5619,10 +5619,10 @@ def topk(input_x, k, dim=None, largest=True, sorted=True):
|
|||
Finds values and indices of the `k` largest or smallest entries along a given dimension.
|
||||
|
||||
.. warning::
|
||||
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
|
||||
- If sorted is set to False, it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest or smallest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
|
||||
and outputs its value and index as a Tensor. values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
|
@ -5639,11 +5639,11 @@ def topk(input_x, k, dim=None, largest=True, sorted=True):
|
|||
k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
|
||||
dim (int, optional): The dimension to sort along. Default: None.
|
||||
largest (bool, optional): If largest is False then the k smallest elements are returned. Default: True.
|
||||
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
|
||||
Default: True.
|
||||
sorted (bool): If True, the obtained elements will be sorted by the values in descending order.
|
||||
If False, the obtained elements will be sorted by the values in ascending order. Default: True.
|
||||
|
||||
Returns:
|
||||
Tuple of 2 tensors, the values and the indices.
|
||||
A tuple consisting of `values` and `indexes`.
|
||||
|
||||
- values (Tensor): The `k` largest or smallest elements in each slice of the given dimension.
|
||||
- indices (Tensor): The indices of values within the last dimension of input.
|
||||
|
|
|
@ -7217,49 +7217,47 @@ def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
|
|||
r"""
|
||||
Returns the matrix norm or vector norm of a given tensor.
|
||||
|
||||
Whether this function computes a vector or matrix norm is determined as follows:
|
||||
`ord` is the calculation mode of norm. The following norm modes are supported.
|
||||
|
||||
- If `dim` is an integer, the vector norm will be computed.
|
||||
|
||||
- If `dim` is a 2-tuple, the matrix norm will be computed.
|
||||
|
||||
- If `dim` = None and `ord` = None, A will be flattened to 1D and the 2-norm of the resulting vector
|
||||
will be computed.
|
||||
|
||||
- If `dim` = None and `ord` != None, A must be 1D or 2D.
|
||||
|
||||
`ord` defines the norm that is computed. The following norms are supported:
|
||||
|
||||
====================== ========================= ========================================================
|
||||
`ord` norm for matrices norm for vectors
|
||||
====================== ========================= ========================================================
|
||||
`None` (default) Frobenius norm `2`-norm (see below)
|
||||
`'fro'` Frobenius norm -- not supported --
|
||||
`'nuc'` nuclear norm -- not supported --
|
||||
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
|
||||
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
|
||||
`0` -- not supported -- `sum(x != 0)`
|
||||
`1` `max(sum(abs(x), dim=0))` as below
|
||||
`-1` `min(sum(abs(x), dim=0))` as below
|
||||
`2` largest singular value as below
|
||||
`-2` smallest singular value as below
|
||||
other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}`
|
||||
====================== ========================= ========================================================
|
||||
====================== ========================= ========================================================
|
||||
`ord` norm for matrices norm for vectors
|
||||
====================== ========================= ========================================================
|
||||
`None` (default) Frobenius norm `2`-norm (see below)
|
||||
`'fro'` Frobenius norm -- not supported --
|
||||
`'nuc'` nuclear norm -- not supported --
|
||||
`inf` :math:`max(sum(abs(x), dim=1))` :math:`max(abs(x))`
|
||||
`-inf` :math:`min(sum(abs(x), dim=1))` :math:`min(abs(x))`
|
||||
`0` -- not supported -- :math:`sum(x != 0)`
|
||||
`1` :math:`max(sum(abs(x), dim=0))` as below
|
||||
`-1` :math:`min(sum(abs(x), dim=0))` as below
|
||||
`2` largest singular value as below
|
||||
`-2` smallest singular value as below
|
||||
other `int` or `float` -- not supported -- :math:`sum(abs(x)^{ord})^{(1 / ord)}`
|
||||
====================== ========================= ========================================================
|
||||
|
||||
Args:
|
||||
A (Tensor): Tensor of shape (*, n) or (*, m, n) where * is zero or more batch dimensions.
|
||||
ord (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): order of norm. refer to the table above for
|
||||
A (Tensor): Tensor of shape :math:`(*, n)` or :math:`(*, m, n)` where * is zero or more batch dimensions.
|
||||
ord (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): norm's mode. refer to the table above for
|
||||
behavior. Default: None.
|
||||
dim (Union[int, Tuple(int)], optional): dimensions over which to compute the vector or matrix norm.
|
||||
See above for the behavior when `dim` = None. Default: None.
|
||||
keepdim (bool): Whether the output tensors have dim retained or not. Default: False.
|
||||
dim (Union[int, Tuple(int)], optional): calculate the dimension of vector norm or matrix norm. Default: None.
|
||||
|
||||
- When `dim` is int, it will be calculated by vector norm.
|
||||
|
||||
- When `dim` is a 2-tuple, it will be calculated by matrix norm.
|
||||
|
||||
- If `dim` is None and `ord` is None, `A` will be flattened to 1D and the 2-norm
|
||||
of the vector will be calculated.
|
||||
|
||||
- If `dim` is None and `ord` is not None, `A` must be 1D or 2D.
|
||||
|
||||
keepdim (bool): whether the output Tensor retains the original dimension. Default: False.
|
||||
|
||||
Keyword Args:
|
||||
dtype (:class:`mindspore.dtype`, optional): If specified, the input tensor is cast to dtype before performing
|
||||
the operation, and the returned tensor’s type will be dtype. Default: None.
|
||||
dtype (:class:`mindspore.dtype`, optional):If this parameter is set, A will be converted to the specified type
|
||||
before execution, and the returned Tensor type will also be the specified type. Default: None.
|
||||
|
||||
Returns:
|
||||
A real-valued tensor.
|
||||
Tensor, the result of norm calculation on the specified dimension, is the same as the input data type.
|
||||
|
||||
Raises:
|
||||
ValueError: If `dim` is out of range.
|
||||
|
@ -7276,52 +7274,56 @@ def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
|
|||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore.ops as ops
|
||||
>>> a = ops.arange(9, dtype=ms.float32) - 4
|
||||
>>> b = a.reshape((3, 3))
|
||||
>>>print(ops.norm(a))
|
||||
7.745967
|
||||
>>>print(ops.norm(b))
|
||||
7.745967
|
||||
>>>print(ops.norm(b, 'fro'))
|
||||
7.745967
|
||||
>>>print(ops.norm(a, float('inf')))
|
||||
4.0
|
||||
>>>print(ops.norm(b, float('inf')))
|
||||
9.0
|
||||
>>>print(ops.norm(a, -float('inf')))
|
||||
>>> x = ops.arange(-12, 13, dtype=ms.float32)
|
||||
>>> y = x.reshape(5, 5)
|
||||
>>> print(ops.norm(x))
|
||||
36.05551
|
||||
>>> print(ops.norm(x, float('inf')))
|
||||
12.0
|
||||
>>> print(ops.norm(x, float('-inf')))
|
||||
0.0
|
||||
>>>print(ops.norm(b, -float('inf')))
|
||||
2.0
|
||||
>>>print(ops.norm(a, 1))
|
||||
20.0
|
||||
>>>print(ops.norm(b, 1))
|
||||
7.0
|
||||
>>>print(ops.norm(a, -1))
|
||||
>>> print(ops.norm(x, 0))
|
||||
24.0
|
||||
>>> print(ops.norm(x, 1))
|
||||
156.0
|
||||
>>> print(ops.norm(x, -1))
|
||||
0.0
|
||||
>>>print(ops.norm(b, -1))
|
||||
>>> print(ops.norm(x, 2))
|
||||
36.05551
|
||||
>>> print(ops.norm(x, -2))
|
||||
0.0
|
||||
>>> print(ops.norm(x, 3))
|
||||
23.000631
|
||||
>>> print(ops.norm(x, -3))
|
||||
0.0
|
||||
>>> print(ops.norm(y))
|
||||
36.05551
|
||||
>>> print(ops.norm(y, 'fro'))
|
||||
36.05551
|
||||
>>> print(ops.norm(y, 'nuc'))
|
||||
42.42641
|
||||
>>> print(ops.norm(y, float('inf')))
|
||||
50.0
|
||||
>>> print(ops.norm(y, float('-inf')))
|
||||
6.0
|
||||
>>>print(ops.norm(a, 2))
|
||||
7.745967
|
||||
>>>print(ops.norm(b, 2))
|
||||
7.3484707
|
||||
>>>print(ops.norm(a, -2))
|
||||
0.0
|
||||
>>>print(ops.norm(a, 3))
|
||||
5.848036
|
||||
>>>print(ops.norm(a, -3))
|
||||
0.0
|
||||
>>> c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
|
||||
>>> print(ops.norm(c, dim=0))
|
||||
[1.4142135 2.236068 5. ]
|
||||
>>> print(ops.norm(c, dim=1))
|
||||
[3.7416575 4.2426405]
|
||||
>>> print(ops.norm(c, ord=1, dim=1))
|
||||
[6. 6.]
|
||||
>>> d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
|
||||
>>> print(ops.norm(d, dim=(1,2)))
|
||||
[ 3.7416575 11.224972 ]
|
||||
>>> print(ops.norm(d[0, :, :]), norm(d[1, :, :]))
|
||||
3.7416575 11.224972
|
||||
>>> print(ops.norm(y, 1))
|
||||
32.0
|
||||
>>> print(ops.norm(y, -1))
|
||||
30.0
|
||||
>>> print(ops.norm(y, 2))
|
||||
35.355343
|
||||
>>> m = ms.Tensor([[1., -1., 2.], [-2., 3., -4.]])
|
||||
>>> print(ops.norm(m, dim=0))
|
||||
[2.236068 3.1622777 4.472136 ]
|
||||
>>> print(ops.norm(m, dim=1))
|
||||
[2.4494898 5.3851647]
|
||||
>>> print(ops.norm(m, ord=1, dim=1))
|
||||
[4. 9.]
|
||||
>>> n = ops.arange(27, dtype=ms.float32).reshape(3, 3, 3)
|
||||
>>> print(ops.norm(n, dim=(1, 2)))
|
||||
[14.282857 39.76179 66.45299 ]
|
||||
>>> print(ops.norm(n[0, :, :]), ops.norm(n[1, :, :]), ops.norm(n[2, :, :]))
|
||||
14.282857 39.76179 66.45299
|
||||
"""
|
||||
ndim = A.ndim
|
||||
dim, immediate = _check_axis(dim, ord, ndim)
|
||||
|
|
|
@ -7802,7 +7802,41 @@ class TopK(Primitive):
|
|||
"""
|
||||
Finds values and indices of the `k` largest entries along the last dimension.
|
||||
|
||||
Refer to :func:`mindspore.ops.top_k` for more details.
|
||||
.. warning::
|
||||
- If sorted is set to False, it will use the aicpu operator, the performance may be reduced.
|
||||
|
||||
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
|
||||
and outputs its value and index as a Tensor. values[`k`] is the `k` largest item in `input_x`,
|
||||
and its index is indices [`k`].
|
||||
|
||||
For a multi-dimensional matrix,
|
||||
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
|
||||
|
||||
.. math::
|
||||
|
||||
values.shape = indices.shape = input.shape[:-1] + [k].
|
||||
|
||||
If the two compared elements are the same, the one with the smaller index value is returned first.
|
||||
|
||||
Args:
|
||||
sorted (bool): If True, the obtained elements will be sorted by the values in descending order.
|
||||
If False, the obtained elements will be sorted by the values in ascending order. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Input to be computed, data type must be float16, float32 or int32.
|
||||
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
|
||||
|
||||
Outputs:
|
||||
A tuple consisting of `values` and `indexes`.
|
||||
|
||||
- **values** (Tensor) - The `k` largest elements in each slice of the last dimension.
|
||||
- **indices** (Tensor) - The indices of values within the last dimension of input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `sorted` is not a bool.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If `k` is not an int.
|
||||
TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, norm_ord=None, dim=None, keepdim=False, dtype=None):
|
||||
output = a.norm(norm_ord, dim, keepdim, dtype=dtype)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_norm_normal(mode):
|
||||
"""
|
||||
Feature: norm
|
||||
Description: Verify the result of norm
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
a = ops.arange(9, dtype=ms.float32) - 4
|
||||
b = a.reshape((3, 3))
|
||||
output1 = net(a)
|
||||
expect_output1 = np.array(7.745967)
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
|
||||
output2 = net(b)
|
||||
expect_output2 = np.array(7.745967)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
||||
|
||||
output3 = net(a, float('inf'))
|
||||
expect_output3 = np.array(4.0)
|
||||
assert np.allclose(output3.asnumpy(), expect_output3)
|
||||
|
||||
output4 = net(b, float('inf'))
|
||||
expect_output4 = np.array(9.0)
|
||||
assert np.allclose(output4.asnumpy(), expect_output4)
|
||||
|
||||
output5 = net(a, -float('inf'))
|
||||
expect_output5 = np.array(0.0)
|
||||
assert np.allclose(output5.asnumpy(), expect_output5)
|
||||
|
||||
output6 = net(b, -float('inf'))
|
||||
expect_output6 = np.array(2.0)
|
||||
assert np.allclose(output6.asnumpy(), expect_output6)
|
||||
|
||||
output7 = net(a, 1)
|
||||
expect_output7 = np.array(20.0)
|
||||
assert np.allclose(output7.asnumpy(), expect_output7)
|
||||
|
||||
output8 = net(b, 1)
|
||||
expect_output8 = np.array(7.0)
|
||||
assert np.allclose(output8.asnumpy(), expect_output8)
|
||||
|
||||
output9 = net(a, 2)
|
||||
expect_output9 = np.array(7.745967)
|
||||
assert np.allclose(output9.asnumpy(), expect_output9)
|
||||
|
||||
output10 = net(b, 2)
|
||||
expect_output10 = np.array(7.3484707)
|
||||
assert np.allclose(output10.asnumpy(), expect_output10)
|
||||
|
||||
output11 = net(a, -1)
|
||||
expect_output11 = np.array(0.0)
|
||||
assert np.allclose(output11.asnumpy(), expect_output11)
|
||||
|
||||
output12 = net(b, -1)
|
||||
expect_output12 = np.array(6.0)
|
||||
assert np.allclose(output12.asnumpy(), expect_output12)
|
||||
|
||||
output13 = net(a, -2)
|
||||
expect_output13 = np.array(0.0)
|
||||
assert np.allclose(output13.asnumpy(), expect_output13)
|
||||
|
||||
output15 = net(a, 3)
|
||||
expect_output15 = np.array(5.848036)
|
||||
assert np.allclose(output15.asnumpy(), expect_output15)
|
||||
|
||||
output16 = net(a, -3)
|
||||
expect_output16 = np.array(0.0)
|
||||
assert np.allclose(output16.asnumpy(), expect_output16)
|
||||
|
||||
c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
|
||||
output17 = net(c, dim=0)
|
||||
expect_output17 = np.array([1.4142135, 2.236068, 5.])
|
||||
assert np.allclose(output17.asnumpy(), expect_output17)
|
||||
|
||||
output18 = net(c, dim=1)
|
||||
expect_output18 = np.array([3.7416575, 4.2426405])
|
||||
assert np.allclose(output18.asnumpy(), expect_output18)
|
||||
|
||||
output19 = net(c, norm_ord=1, dim=1)
|
||||
expect_output19 = np.array([6., 6.])
|
||||
assert np.allclose(output19.asnumpy(), expect_output19)
|
||||
|
||||
d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
|
||||
|
||||
output20 = net(d, dim=(1, 2))
|
||||
expect_output20 = np.array([3.7416575, 11.224972])
|
||||
assert np.allclose(output20.asnumpy(), expect_output20)
|
||||
|
||||
output21 = net(d[0, :, :]).asnumpy(), net(d[1, :, :]).asnumpy()
|
||||
expect_output21 = np.array([3.7416575, 11.224972])
|
||||
assert np.allclose(output21, expect_output21)
|
Loading…
Reference in New Issue