add tensor norm

This commit is contained in:
fengyihang 2023-01-18 11:08:01 +08:00
parent 057da5a11d
commit 4e71de40be
9 changed files with 313 additions and 122 deletions

View File

@ -1,6 +1,6 @@
mindspore.Tensor.norm
=====================
.. py:method:: mindspore.Tensor.norm(axis, p=2, keep_dims=False, epsilon=1e-12)
.. py:method:: mindspore.Tensor.norm(ord=None, dim=None, keepdim=False, *, dtype=None)
详情请参考 :func:`mindspore.ops.norm`

View File

@ -5,4 +5,33 @@
沿最后一个维度查找 `k` 个最大元素和对应的索引。
更多参考详见 :func:`mindspore.ops.top_k`
.. warning::
- 如果 `sorted` 设置为False它将使用aicpu运算符性能可能会降低。
如果 `input_x` 是一维Tensor则查找Tensor中 `k` 个最大元素并将其值和索引输出为Tensor。`values[k]``input_x``k` 个最大元素,其索引是 `indices[k]`
对于多维矩阵,计算每行中最大的 `k` 个元素(沿最后一个维度的相应向量),因此:
.. math::
values.shape = indices.shape = input.shape[:-1] + [k].
如果两个比较的元素相同,则优先返回索引值较小的元素。
参数:
- **sorted** (bool, 可选) - 如果为True则获取的元素将按值降序排序。如果为False则获取的元素将按值升序排序。默认值True。
输入:
- **input_x** (Tensor) - 需计算的输入数据类型必须为float16、float32或int32。
- **k** (int) - 指定计算最大元素的数量,必须为常量。
输出:
`values``indices` 组成的tuple。
- **values** (Tensor) - 最后一个维度的每个切片中的 `k` 最大元素。
- **indices** (Tensor) - `k` 最大元素的对应索引。
异常:
- **TypeError** - 如果 `sorted` 不是bool。
- **TypeError** - 如果 `input_x` 不是Tensor。
- **TypeError** - 如果 `k` 不是int。
- **TypeError** - 如果 `input_x` 的数据类型不是以下之一float16、float32或int32。

View File

@ -5,45 +5,44 @@ mindspore.ops.norm
返回给定Tensor的矩阵范数或向量范数。
该函数计算向量范数或者矩阵范数的规则如下:
- 如果 `dim` 是一个整型,将会计算向量范数。
- 如果 `dim` 是一个2-tuple将会计算矩阵范数。
- 如果 `dim` 为None且 `ord` 为NoneA将会被展平为1D并计算向量的2-范数。
- 如果 `dim` 为None且 `ord` 不为NoneA必须为1D或者2D。
`ord` 为norm的计算模式。支持下列norm模式。
====================== ========================= ========================================================
`ord` 矩阵范数 向量范数
====================== ========================= ========================================================
`None` (默认值) Frobenius norm `2`-norm (参考最下方公式)
`'fro'` Frobenius norm 不支持
`'nuc'` nuclear norm 不支持
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
`0` 不支持 `sum(x != 0)`
`1` `max(sum(abs(x), dim=0))` 参考下方公式
`-1` `min(sum(abs(x), dim=0))` 参考下方公式
`2` 最大奇异值 参考下方公式
`-2` 最小奇异值 参考下方公式
其余int或float值 不支持 `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================= ========================================================
================= ================================== ==============================================
`ord` 矩阵范数 向量范数
================= ================================== ==============================================
`None` (默认值) Frobenius norm `2`-norm (参考最下方公式)
`'fro'` Frobenius norm 不支持
`'nuc'` nuclear norm 不支持
`inf` :math:`max(sum(abs(x), dim=1))` :math:`max(abs(x))`
`-inf` :math:`min(sum(abs(x), dim=1))` :math:`min(abs(x))`
`0` 不支持 :math:`sum(x != 0)`
`1` :math:`max(sum(abs(x), dim=0))` 参考最下方公式
`-1` :math:`min(sum(abs(x), dim=0))` 参考最下方公式
`2` 最大奇异值 参考下方公式
`-2` 最小奇异值 参考下方公式
其余int或float值 不支持 :math:`sum(abs(x)^{ord})^{(1 / ord)}`
================= ================================== ==============================================
参数:
- **A** (Tensor) - shape为 (*, n) 或者 (*, m, n)的Tensor其中*是零个或多个batch维度。
- **A** (Tensor) - shape为 :math:`(*, n)` 或者 :math:`(*, m, n)` 的Tensor其中*是零个或多个batch维度。
- **ord** (Union[int, float, inf, -inf, 'fro', 'nuc'], 可选) - norm的模式。行为参考上表。默认值None。
- **dim** (Union[int, Tuple(int)], 可选) - 计算向量范数或矩阵范数的维度。有关 `dim` = `None` 时的行为请参见上文。默认值None。
- **dim** (Union[int, Tuple(int)], 可选) - 计算向量范数或矩阵范数的维度。默认值None。
- 当 `dim` 为int时会按向量范数计算。
- 当 `dim` 为一个二元组时,会按矩阵范数计算。
- 当 `dim` 为None且 `ord` 为None`A` 将会被展平为1D并计算向量的2-范数。
- 当 `dim` 为None且 `ord` 不为None`A` 必须为1D或者2D。
- **keepdim** (bool) - 输出Tensor是否保留原有的维度。默认值False。
关键字参数:
- **dtype** (:class:`mindspore.dtype`, 可选) - 如果指定则在执行之前将输入Tensor转换为dtype类型返回的Tensor类型也将为dtype。默认值None。
- **dtype** (:class:`mindspore.dtype`, 可选) - 如果设置此参数则会在执行之前将A转换为指定的类型返回的Tensor类型也将为指定类型。默认值None。
返回:
实值Tensor。
Tensor,在指定维度 `dim` 上进行范数计算的结果,与输入 `A` 的数据类型相同
异常:
- **ValueError** - `dim` 超出范围。

View File

@ -6,9 +6,9 @@ mindspore.ops.topk
沿给定维度查找 `k` 个最大或最小元素和对应的索引。
.. warning::
- 如果 `sorted` 设置为'False'它将使用aicpu运算符性能可能会降低。
- 如果 `sorted` 设置为False它将使用aicpu运算符性能可能会降低。
如果 `input_x` 是一维Tensor则查找Tensor中 `k` 个最大或最小元素并将其值和索引输出为Tensor。因此, `values[k]``input_x``k` 个最大元素,其索引是 `indices[k]`
如果 `input_x` 是一维Tensor则查找Tensor中 `k` 个最大或最小元素并将其值和索引输出为Tensor。`values[k]``input_x``k` 个最大元素,其索引是 `indices[k]`
对于多维矩阵,计算给定维度中最大或最小的 `k` 个元素,因此:
@ -19,13 +19,13 @@ mindspore.ops.topk
参数:
- **input_x** (Tensor) - 需计算的输入数据类型必须为float16、float32或int32。
- **k** (int) - 指定计算最大或最小元素的数量,需要是常量。
- **k** (int) - 指定计算最大或最小元素的数量,必须为常量。
- **dim** (int, 可选) - 需要排序的维度。默认值None。
- **largest** (bool, 可选) - 如果为False则会返回前k个最小值。默认值True。
- **sorted** (bool, 可选) - 如果为True则获取的元素将按值降序排序。默认值True。
- **sorted** (bool, 可选) - 如果为True则获取的元素将按值降序排序。如果为False则获取的元素将按值升序排序。默认值True。
返回:
2个Tensor组成的tuple `values``indices`
`values``indices` 组成的tuple
- **values** (Tensor) - 给定维度的每个切片中的 `k` 最大元素或最小元素。
- **indices** (Tensor) - `k` 最大元素的对应索引。

View File

@ -1317,12 +1317,12 @@ class Tensor(Tensor_):
# pylint: disable=redefined-builtin
# pylint: disable=invalid-name
def norm(self, A, ord=None, dim=None, keepdim=False, *, dtype=None):
def norm(self, ord=None, dim=None, keepdim=False, *, dtype=None):
"""
For details, please refer to :func:`mindspore.ops.norm`.
"""
self._init_check()
return tensor_operator_registry.get('norm')(self, A, ord, dim, keepdim, dtype=dtype)
return tensor_operator_registry.get('norm')(self, ord, dim, keepdim, dtype=dtype)
def renorm(self, p, dim, maxnorm):
"""

View File

@ -5619,10 +5619,10 @@ def topk(input_x, k, dim=None, largest=True, sorted=True):
Finds values and indices of the `k` largest or smallest entries along a given dimension.
.. warning::
- If sorted is set to 'False', it will use the aicpu operator, the performance may be reduced.
- If sorted is set to False, it will use the aicpu operator, the performance may be reduced.
If the `input_x` is a one-dimensional Tensor, finds the `k` largest or smallest entries in the Tensor,
and outputs its value and index as a Tensor. Therefore, values[`k`] is the `k` largest item in `input_x`,
and outputs its value and index as a Tensor. values[`k`] is the `k` largest item in `input_x`,
and its index is indices [`k`].
For a multi-dimensional matrix,
@ -5639,11 +5639,11 @@ def topk(input_x, k, dim=None, largest=True, sorted=True):
k (int): The number of top or bottom elements to be computed along the last dimension, constant input is needed.
dim (int, optional): The dimension to sort along. Default: None.
largest (bool, optional): If largest is False then the k smallest elements are returned. Default: True.
sorted (bool, optional): If true, the obtained elements will be sorted by the values in descending order.
Default: True.
sorted (bool): If True, the obtained elements will be sorted by the values in descending order.
If False, the obtained elements will be sorted by the values in ascending order. Default: True.
Returns:
Tuple of 2 tensors, the values and the indices.
A tuple consisting of `values` and `indexes`.
- values (Tensor): The `k` largest or smallest elements in each slice of the given dimension.
- indices (Tensor): The indices of values within the last dimension of input.

View File

@ -7217,49 +7217,47 @@ def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
r"""
Returns the matrix norm or vector norm of a given tensor.
Whether this function computes a vector or matrix norm is determined as follows:
`ord` is the calculation mode of norm. The following norm modes are supported.
- If `dim` is an integer, the vector norm will be computed.
- If `dim` is a 2-tuple, the matrix norm will be computed.
- If `dim` = None and `ord` = None, A will be flattened to 1D and the 2-norm of the resulting vector
will be computed.
- If `dim` = None and `ord` != None, A must be 1D or 2D.
`ord` defines the norm that is computed. The following norms are supported:
====================== ========================= ========================================================
`ord` norm for matrices norm for vectors
====================== ========================= ========================================================
`None` (default) Frobenius norm `2`-norm (see below)
`'fro'` Frobenius norm -- not supported --
`'nuc'` nuclear norm -- not supported --
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
`0` -- not supported -- `sum(x != 0)`
`1` `max(sum(abs(x), dim=0))` as below
`-1` `min(sum(abs(x), dim=0))` as below
`2` largest singular value as below
`-2` smallest singular value as below
other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================= ========================================================
====================== ========================= ========================================================
`ord` norm for matrices norm for vectors
====================== ========================= ========================================================
`None` (default) Frobenius norm `2`-norm (see below)
`'fro'` Frobenius norm -- not supported --
`'nuc'` nuclear norm -- not supported --
`inf` :math:`max(sum(abs(x), dim=1))` :math:`max(abs(x))`
`-inf` :math:`min(sum(abs(x), dim=1))` :math:`min(abs(x))`
`0` -- not supported -- :math:`sum(x != 0)`
`1` :math:`max(sum(abs(x), dim=0))` as below
`-1` :math:`min(sum(abs(x), dim=0))` as below
`2` largest singular value as below
`-2` smallest singular value as below
other `int` or `float` -- not supported -- :math:`sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================= ========================================================
Args:
A (Tensor): Tensor of shape (*, n) or (*, m, n) where * is zero or more batch dimensions.
ord (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): order of norm. refer to the table above for
A (Tensor): Tensor of shape :math:`(*, n)` or :math:`(*, m, n)` where * is zero or more batch dimensions.
ord (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): norm's mode. refer to the table above for
behavior. Default: None.
dim (Union[int, Tuple(int)], optional): dimensions over which to compute the vector or matrix norm.
See above for the behavior when `dim` = None. Default: None.
keepdim (bool): Whether the output tensors have dim retained or not. Default: False.
dim (Union[int, Tuple(int)], optional): calculate the dimension of vector norm or matrix norm. Default: None.
- When `dim` is int, it will be calculated by vector norm.
- When `dim` is a 2-tuple, it will be calculated by matrix norm.
- If `dim` is None and `ord` is None, `A` will be flattened to 1D and the 2-norm
of the vector will be calculated.
- If `dim` is None and `ord` is not None, `A` must be 1D or 2D.
keepdim (bool): whether the output Tensor retains the original dimension. Default: False.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): If specified, the input tensor is cast to dtype before performing
the operation, and the returned tensors type will be dtype. Default: None.
dtype (:class:`mindspore.dtype`, optional):If this parameter is set, A will be converted to the specified type
before execution, and the returned Tensor type will also be the specified type. Default: None.
Returns:
A real-valued tensor.
Tensor, the result of norm calculation on the specified dimension, is the same as the input data type.
Raises:
ValueError: If `dim` is out of range.
@ -7276,52 +7274,56 @@ def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> a = ops.arange(9, dtype=ms.float32) - 4
>>> b = a.reshape((3, 3))
>>>print(ops.norm(a))
7.745967
>>>print(ops.norm(b))
7.745967
>>>print(ops.norm(b, 'fro'))
7.745967
>>>print(ops.norm(a, float('inf')))
4.0
>>>print(ops.norm(b, float('inf')))
9.0
>>>print(ops.norm(a, -float('inf')))
>>> x = ops.arange(-12, 13, dtype=ms.float32)
>>> y = x.reshape(5, 5)
>>> print(ops.norm(x))
36.05551
>>> print(ops.norm(x, float('inf')))
12.0
>>> print(ops.norm(x, float('-inf')))
0.0
>>>print(ops.norm(b, -float('inf')))
2.0
>>>print(ops.norm(a, 1))
20.0
>>>print(ops.norm(b, 1))
7.0
>>>print(ops.norm(a, -1))
>>> print(ops.norm(x, 0))
24.0
>>> print(ops.norm(x, 1))
156.0
>>> print(ops.norm(x, -1))
0.0
>>>print(ops.norm(b, -1))
>>> print(ops.norm(x, 2))
36.05551
>>> print(ops.norm(x, -2))
0.0
>>> print(ops.norm(x, 3))
23.000631
>>> print(ops.norm(x, -3))
0.0
>>> print(ops.norm(y))
36.05551
>>> print(ops.norm(y, 'fro'))
36.05551
>>> print(ops.norm(y, 'nuc'))
42.42641
>>> print(ops.norm(y, float('inf')))
50.0
>>> print(ops.norm(y, float('-inf')))
6.0
>>>print(ops.norm(a, 2))
7.745967
>>>print(ops.norm(b, 2))
7.3484707
>>>print(ops.norm(a, -2))
0.0
>>>print(ops.norm(a, 3))
5.848036
>>>print(ops.norm(a, -3))
0.0
>>> c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
>>> print(ops.norm(c, dim=0))
[1.4142135 2.236068 5. ]
>>> print(ops.norm(c, dim=1))
[3.7416575 4.2426405]
>>> print(ops.norm(c, ord=1, dim=1))
[6. 6.]
>>> d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
>>> print(ops.norm(d, dim=(1,2)))
[ 3.7416575 11.224972 ]
>>> print(ops.norm(d[0, :, :]), norm(d[1, :, :]))
3.7416575 11.224972
>>> print(ops.norm(y, 1))
32.0
>>> print(ops.norm(y, -1))
30.0
>>> print(ops.norm(y, 2))
35.355343
>>> m = ms.Tensor([[1., -1., 2.], [-2., 3., -4.]])
>>> print(ops.norm(m, dim=0))
[2.236068 3.1622777 4.472136 ]
>>> print(ops.norm(m, dim=1))
[2.4494898 5.3851647]
>>> print(ops.norm(m, ord=1, dim=1))
[4. 9.]
>>> n = ops.arange(27, dtype=ms.float32).reshape(3, 3, 3)
>>> print(ops.norm(n, dim=(1, 2)))
[14.282857 39.76179 66.45299 ]
>>> print(ops.norm(n[0, :, :]), ops.norm(n[1, :, :]), ops.norm(n[2, :, :]))
14.282857 39.76179 66.45299
"""
ndim = A.ndim
dim, immediate = _check_axis(dim, ord, ndim)

View File

@ -7802,7 +7802,41 @@ class TopK(Primitive):
"""
Finds values and indices of the `k` largest entries along the last dimension.
Refer to :func:`mindspore.ops.top_k` for more details.
.. warning::
- If sorted is set to False, it will use the aicpu operator, the performance may be reduced.
If the `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
and outputs its value and index as a Tensor. values[`k`] is the `k` largest item in `input_x`,
and its index is indices [`k`].
For a multi-dimensional matrix,
calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
.. math::
values.shape = indices.shape = input.shape[:-1] + [k].
If the two compared elements are the same, the one with the smaller index value is returned first.
Args:
sorted (bool): If True, the obtained elements will be sorted by the values in descending order.
If False, the obtained elements will be sorted by the values in ascending order. Default: True.
Inputs:
- **input_x** (Tensor) - Input to be computed, data type must be float16, float32 or int32.
- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
Outputs:
A tuple consisting of `values` and `indexes`.
- **values** (Tensor) - The `k` largest elements in each slice of the last dimension.
- **indices** (Tensor) - The indices of values within the last dimension of input.
Raises:
TypeError: If `sorted` is not a bool.
TypeError: If `input_x` is not a Tensor.
TypeError: If `k` is not an int.
TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -0,0 +1,127 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, a, norm_ord=None, dim=None, keepdim=False, dtype=None):
output = a.norm(norm_ord, dim, keepdim, dtype=dtype)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_norm_normal(mode):
"""
Feature: norm
Description: Verify the result of norm
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
a = ops.arange(9, dtype=ms.float32) - 4
b = a.reshape((3, 3))
output1 = net(a)
expect_output1 = np.array(7.745967)
assert np.allclose(output1.asnumpy(), expect_output1)
output2 = net(b)
expect_output2 = np.array(7.745967)
assert np.allclose(output2.asnumpy(), expect_output2)
output3 = net(a, float('inf'))
expect_output3 = np.array(4.0)
assert np.allclose(output3.asnumpy(), expect_output3)
output4 = net(b, float('inf'))
expect_output4 = np.array(9.0)
assert np.allclose(output4.asnumpy(), expect_output4)
output5 = net(a, -float('inf'))
expect_output5 = np.array(0.0)
assert np.allclose(output5.asnumpy(), expect_output5)
output6 = net(b, -float('inf'))
expect_output6 = np.array(2.0)
assert np.allclose(output6.asnumpy(), expect_output6)
output7 = net(a, 1)
expect_output7 = np.array(20.0)
assert np.allclose(output7.asnumpy(), expect_output7)
output8 = net(b, 1)
expect_output8 = np.array(7.0)
assert np.allclose(output8.asnumpy(), expect_output8)
output9 = net(a, 2)
expect_output9 = np.array(7.745967)
assert np.allclose(output9.asnumpy(), expect_output9)
output10 = net(b, 2)
expect_output10 = np.array(7.3484707)
assert np.allclose(output10.asnumpy(), expect_output10)
output11 = net(a, -1)
expect_output11 = np.array(0.0)
assert np.allclose(output11.asnumpy(), expect_output11)
output12 = net(b, -1)
expect_output12 = np.array(6.0)
assert np.allclose(output12.asnumpy(), expect_output12)
output13 = net(a, -2)
expect_output13 = np.array(0.0)
assert np.allclose(output13.asnumpy(), expect_output13)
output15 = net(a, 3)
expect_output15 = np.array(5.848036)
assert np.allclose(output15.asnumpy(), expect_output15)
output16 = net(a, -3)
expect_output16 = np.array(0.0)
assert np.allclose(output16.asnumpy(), expect_output16)
c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
output17 = net(c, dim=0)
expect_output17 = np.array([1.4142135, 2.236068, 5.])
assert np.allclose(output17.asnumpy(), expect_output17)
output18 = net(c, dim=1)
expect_output18 = np.array([3.7416575, 4.2426405])
assert np.allclose(output18.asnumpy(), expect_output18)
output19 = net(c, norm_ord=1, dim=1)
expect_output19 = np.array([6., 6.])
assert np.allclose(output19.asnumpy(), expect_output19)
d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
output20 = net(d, dim=(1, 2))
expect_output20 = np.array([3.7416575, 11.224972])
assert np.allclose(output20.asnumpy(), expect_output20)
output21 = net(d[0, :, :]).asnumpy(), net(d[1, :, :]).asnumpy()
expect_output21 = np.array([3.7416575, 11.224972])
assert np.allclose(output21, expect_output21)