forked from mindspore-Ecosystem/mindspore
support norm
This commit is contained in:
parent
2ff013c3cc
commit
5dfffafe72
|
@ -31,6 +31,7 @@ mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
|
||||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
|
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
|
||||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad
|
||||||
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov
|
||||||
|
mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm
|
||||||
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context
|
||||||
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
mindspore/mindspore/python/mindspore/common/tensor.py:__init__
|
||||||
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
mindspore/mindspore/python/mindspore/common/parameter.py:set_data
|
||||||
|
|
|
@ -1,30 +1,55 @@
|
||||||
mindspore.ops.norm
|
mindspore.ops.norm
|
||||||
==================
|
==================
|
||||||
|
|
||||||
.. py:function:: mindspore.ops.norm(input_x, axis, p=2, keep_dims=False, epsilon=1e-12)
|
.. py:function:: mindspore.ops.norm(A, ord=None, dim=None, keepdim=False, *, dtype=None)
|
||||||
|
|
||||||
返回给定Tensor的矩阵范数或向量范数。
|
返回给定Tensor的矩阵范数或向量范数。
|
||||||
|
|
||||||
.. math::
|
该函数计算向量范数或者矩阵范数的规则如下:
|
||||||
output = sum(abs(input)**p)**(1/p)
|
|
||||||
|
- 如果 `dim` 是一个整型,将会计算向量范数。
|
||||||
|
|
||||||
|
- 如果 `dim` 是一个2-tuple,将会计算矩阵范数。
|
||||||
|
|
||||||
|
- 如果 `dim` 为None且 `ord` 为None,A将会被展平为1D并计算向量的2-范数。
|
||||||
|
|
||||||
|
- 如果 `dim` 为None且 `ord` 不为None,A必须为1D或者2D。
|
||||||
|
|
||||||
|
`ord` 为norm的计算模式。支持下列norm模式。
|
||||||
|
|
||||||
|
====================== ========================= ========================================================
|
||||||
|
`ord` 矩阵范数 向量范数
|
||||||
|
====================== ========================= ========================================================
|
||||||
|
`None` (默认值) Frobenius norm `2`-norm (参考最下方公式)
|
||||||
|
`'fro'` Frobenius norm 不支持
|
||||||
|
`'nuc'` nuclear norm 不支持
|
||||||
|
`inf` `max(sum(abs(x), dim=1))` `max(abs(x))`
|
||||||
|
`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))`
|
||||||
|
`0` 不支持 `sum(x != 0)`
|
||||||
|
`1` `max(sum(abs(x), dim=0))` 参考下方公式
|
||||||
|
`-1` `min(sum(abs(x), dim=0))` 参考下方公式
|
||||||
|
`2` 最大奇异值 参考下方公式
|
||||||
|
`-2` 最小奇异值 参考下方公式
|
||||||
|
其余int或float值 不支持 `sum(abs(x)^{ord})^{(1 / ord)}`
|
||||||
|
====================== ========================= ========================================================
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **input_x** (Tensor) - 输入Tensor。数据类型必须为float16或者float32。
|
- **A** (Tensor) - shape为 (*, n) 或者 (*, m, n)的Tensor,其中*是零个或多个batch维度。
|
||||||
- **axis** (Union[int, list, tuple]) - 指定要计算范数的输入维度。
|
- **ord** (Union[int, float, inf, -inf, 'fro', 'nuc'], 可选) - norm的模式。行为参考上表。默认值:None。
|
||||||
- **p** (int) - 范数的值。默认值:2。 `p` 大于等于0。
|
- **dim** (Union[int, Tuple(int)], 可选) - 计算向量范数或矩阵范数的维度。有关 `dim` = `None` 时的行为,请参见上文。默认值:None。
|
||||||
- **keep_dims** (bool) - 输出Tensor是否保留原有的维度。默认值:False。
|
- **keepdim** (bool) - 输出Tensor是否保留原有的维度。默认值:False。
|
||||||
- **epsilon** (float) - 用于保持数据稳定性的常量。默认值:1e-12。
|
|
||||||
|
关键字参数:
|
||||||
|
- **dtype** (:class:`mindspore.dtype`, 可选) - 如果指定,则在执行之前将输入Tensor转换为dtype类型,返回的Tensor类型也将为dtype。默认值:None。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Tensor,其数据类型与 `input_x` 相同,其维度信息取决于 `axis` 轴以及参数 `keep_dims` 。例如如果输入的大小为 `(2,3,4)` 轴为 `[0,1]` ,输出的维度为 `(4,)` 。
|
实值Tensor。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `input_x` 不是Tensor。
|
- **ValueError** - `dim` 超出范围。
|
||||||
- **TypeError** - `input_x` 的数据类型不是float16或者float32。
|
- **TypeError** - `dim` 既不是int也不是由int组成的tuple。
|
||||||
- **TypeError** - `axis` 不是int,tuple或者list。
|
- **TypeError** - `A` 是一个向量并且 `ord` 是str类型。
|
||||||
- **TypeError** - `p` 不是int。
|
- **ValueError** - `A` 是一个矩阵并且 `ord` 不是有效的取值。
|
||||||
- **TypeError** - `axis` 是tuple或者list但其元素不是int。
|
- **ValueError** - `A` 是一个矩阵并且 `ord` 为一个整型但是取值不为[1, -1, 2, -2]之一。
|
||||||
- **TypeError** - `keep_dims` 不是bool。
|
- **ValueError** - `dim` 的两个元素在标准化过后取值相同。
|
||||||
- **TypeError** - `epsilon` 不是float。
|
- **ValueError** - `dim` 的任意元素超出索引。
|
||||||
- **ValueError** - `axis` 的元素超出范围 `(-len(input_x.shape), len(input_x.shape))` ,其中 `input_x` 指当前Tensor 。
|
|
||||||
- **ValueError** - `axis` 的维度rank大于当前Tensor的维度rank。
|
|
||||||
|
|
|
@ -3312,9 +3312,11 @@ def lerp(start, end, weight):
|
||||||
return F.lerp(start, end, weight)
|
return F.lerp(start, end, weight)
|
||||||
|
|
||||||
|
|
||||||
def norm(input_x, axis, p=2, keep_dims=False, epsilon=1e-12):
|
# pylint: disable=redefined-builtin
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
    """
    Returns the matrix norm or vector norm of a given tensor.

    Thin pass-through: every argument is forwarded unchanged to `F.norm`,
    with `dtype` passed by keyword.
    """
    return F.norm(A, ord, dim, keepdim, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def renorm(input_x, p, dim, maxnorm):
|
def renorm(input_x, p, dim, maxnorm):
|
||||||
|
|
|
@ -1315,12 +1315,14 @@ class Tensor(Tensor_):
|
||||||
self._init_check()
|
self._init_check()
|
||||||
return tensor_operator_registry.get("negative")(self)
|
return tensor_operator_registry.get("negative")(self)
|
||||||
|
|
||||||
def norm(self, axis, p=2, keep_dims=False, epsilon=1e-12):
|
# pylint: disable=redefined-builtin
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
def norm(self, ord=None, dim=None, keepdim=False, *, dtype=None):
    """
    For details, please refer to :func:`mindspore.ops.norm`.

    The tensor itself is the input `A` of the functional form, so the method
    only takes `ord`, `dim`, `keepdim` and the keyword-only `dtype`.
    """
    self._init_check()
    # `self` fills the `A` slot of the registered function. The previous
    # signature also accepted a separate `A` and forwarded both, which shifted
    # every argument by one position and passed five positional arguments.
    # NOTE(review): assumes the registered 'norm' has the signature
    # (A, ord, dim, keepdim, *, dtype) of mindspore.ops.norm - confirm.
    return tensor_operator_registry.get('norm')(self, ord, dim, keepdim, dtype=dtype)
|
||||||
|
|
||||||
def renorm(self, p, dim, maxnorm):
|
def renorm(self, p, dim, maxnorm):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1889,7 +1889,7 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-08):
|
||||||
[0.4843164 0.81647635]
|
[0.4843164 0.81647635]
|
||||||
"""
|
"""
|
||||||
molecule = ops.sum(x1 * x2, dim=dim)
|
molecule = ops.sum(x1 * x2, dim=dim)
|
||||||
denominator = (ops.norm(x1, axis=dim, p=2) * ops.norm(x2, axis=dim, p=2)).clip(min=eps)
|
denominator = (ops.norm(x1, dim=dim, ord=2) * ops.norm(x2, dim=dim, ord=2)).clip(min=eps)
|
||||||
output = molecule / denominator
|
output = molecule / denominator
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -4712,6 +4712,8 @@ def reciprocal(x):
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[1. 0.5 0.25]
|
[1. 0.5 0.25]
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(x, Tensor):
|
||||||
|
raise TypeError(f"For reciprocal, the input must be a Tensor, but got {type(x)}.")
|
||||||
if not is_complex(x) and not ops.is_floating_point(x):
|
if not is_complex(x) and not ops.is_floating_point(x):
|
||||||
x = ops.cast(x, mstype.float32)
|
x = ops.cast(x, mstype.float32)
|
||||||
return _get_cache_prim(ops.Reciprocal)()(x)
|
return _get_cache_prim(ops.Reciprocal)()(x)
|
||||||
|
@ -7116,46 +7118,291 @@ def prod(x, axis=(), keep_dims=False):
|
||||||
return _get_cache_prim(P.ReduceProd)(keep_dims)(x, axis)
|
return _get_cache_prim(P.ReduceProd)(keep_dims)(x, axis)
|
||||||
|
|
||||||
|
|
||||||
def norm(input_x, axis, p=2, keep_dims=False, epsilon=1e-12):
|
def _multi_svd_norm(x, row_axis, col_axis, op):
|
||||||
|
y = moveaxis(x.astype(mstype.float32), (row_axis, col_axis), (-2, -1))
|
||||||
|
if op == 'amax':
|
||||||
|
return ops.svd(y, compute_uv=False).max(axis=-1)
|
||||||
|
if op == 'amin':
|
||||||
|
return ops.svd(y, compute_uv=False).min(axis=-1)
|
||||||
|
if op == 'sum':
|
||||||
|
return ops.svd(y, compute_uv=False).sum(axis=-1)
|
||||||
|
raise ValueError(f"For svd_norm, the op input must be one of ['amax', 'amin', 'sum'], but got f{op}")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_axis_index(axis, ndim):
    """
    Map a possibly negative axis index onto the range [0, ndim).

    Args:
        axis (int): Axis index, valid in [-ndim, ndim).
        ndim (int): Number of dimensions.

    Returns:
        int, the equivalent non-negative axis index.

    Raises:
        ValueError: If `axis` lies outside [-ndim, ndim).
    """
    # Comparisons are deliberately written un-chained, as in the rest of the file.
    too_large = axis >= ndim
    too_small = axis < -ndim
    if too_large or too_small:
        raise ValueError('For norm, the dim is out of range.')
    if axis < 0:
        return ndim + axis
    return axis
|
||||||
|
|
||||||
|
|
||||||
|
def moveaxis(x, source, destination):
    """
    Move axes of `x` to new positions; the other axes keep their relative order.

    Follows the semantics of `numpy.moveaxis`. The previous implementation
    swapped entries of the permutation pairwise, which is not a move: for a
    3-D input with source (0, 1) and destination (-2, -1) it produced the
    permutation (1, 2, 0), leaving axes (2, 0) in the last two positions
    instead of (0, 1).

    Args:
        x (Tensor): Input tensor.
        source (Union[tuple, list]): Original positions of the axes to move.
        destination (Union[tuple, list]): Target position for each moved axis,
            paired element-wise with `source`.

    Returns:
        Tensor, `x` transposed with the requested axes moved.
    """
    ndim = x.ndim
    # Negative indices count from the end.
    src = [s + ndim if s < 0 else s for s in source]
    dst = [d + ndim if d < 0 else d for d in destination]
    # Start from the untouched axes in their original order ...
    perm = [axis for axis in range(ndim) if axis not in src]
    # ... then insert each moved axis at its destination, lowest destination
    # first so that earlier insertions do not shift later ones.
    # NOTE(review): uses `sorted` and `list.insert`; confirm both are accepted
    # where this helper is compiled in graph mode.
    for d, s in sorted(zip(dst, src)):
        perm.insert(d, s)
    return ops.transpose(x, tuple(perm))
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
def _check_axis(axis, ord, ndim):
    """
    Validate `axis` and normalise it to a tuple.

    Args:
        axis (Union[int, tuple, None]): Axis or axes to reduce over.
        ord: Requested norm order; only inspected when `axis` is None.
        ndim (int): Number of dimensions of the input.

    Returns:
        tuple, `(axis, flag)`. `axis` is a tuple of ints. `flag` is True only
        when `axis` was None and either `ord` is None, or `ord` is 'fro' with
        `ndim` == 2, or `ord` is 2 with `ndim` == 1.

    Raises:
        TypeError: If `axis` is neither an int nor a tuple of int.
        ValueError: If `axis` is a tuple whose length is not 1 or 2.
        ValueError: If `axis` is None, `ord` does not qualify for the True
            flag and `ndim` is not 1 or 2.
    """
    if axis is None:
        axis = tuple(range(ndim))
        if (ord is None) or (ord == 'fro' and ndim == 2) or (ord == 2 and ndim == 1):
            return axis, True
        # With an explicit `ord` only vectors and matrices are meaningful.
        # Any other rank used to yield an axis tuple that no branch of the
        # caller handles.
        if ndim not in (1, 2):
            raise ValueError(f"For norm, the input must be 1D or 2D when dim is None and ord is not None, "
                             f"but got {ndim}D.")
        return axis, False
    if isinstance(axis, int):
        axis = (axis,)
    elif isinstance(axis, tuple):
        # An empty tuple selects nothing to reduce over; reject it with the
        # same error as a tuple that is too long.
        if not axis or len(axis) > 2:
            raise ValueError("For norm, the dimensions is out of range.")
        for item in axis:
            if not isinstance(item, int):
                raise TypeError(f'For norm, the dim should be int or tuple of int, but got {type(item)}')
    else:
        raise TypeError(f'For norm, the dim should be int or tuple of int, but got {type(axis)}')
    return axis, False
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
def _check_ord(ord, axis):
    """
    Check that `ord` is valid for the number of axes being reduced.

    Args:
        ord: Requested norm order.
        axis (tuple): Axes being reduced; only its length is used.

    Raises:
        TypeError: If one axis is reduced and `ord` is a str.
        ValueError: If two axes are reduced and `ord` is not one of
            2, -2, 1, -1, inf, -inf, 'fro', 'nuc', None.
    """
    axis_count = len(axis)
    if axis_count == 1 and isinstance(ord, str):
        raise TypeError(f"For norm, ord mode can not be str for vectors, but got {ord}.")
    if axis_count == 2:
        matrix_modes = [2, -2, 1, -1, float('inf'), -float('inf'), 'fro', 'nuc', None]
        if ord not in matrix_modes:
            raise ValueError(f"For norm, the ord mode must be in "
                             f"[2, -2, 1, -1, float('inf'), -float('inf'), 'fro', 'nuc', None] for matrices, "
                             f"but got {ord}.")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_dtype(d1, d2):
    """
    Choose the result dtype for two operand dtypes.

    Returns float32 when either operand is float32, otherwise the common
    dtype when both are equal.

    Raises:
        ValueError: If the dtypes differ and neither is float32.
    """
    if mstype.float32 in (d1, d2):
        result = mstype.float32
    elif d1 == d2:
        result = d1
    else:
        raise ValueError('the dtype is not supported.')
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def dot(a, b):
    """
    Dot product of two tensors.

    If either operand is 0-D the element-wise product is returned. Otherwise
    the last axis of `a` is contracted with the last axis of `b` (after the
    last two axes of `b` are swapped when `b` has at least two dimensions).
    The matmul is carried out in float32 and the result is cast to the dtype
    chosen by `_check_dtype`.

    Raises:
        ValueError: If the contracted axes have different lengths.
    """
    out_dtype = _check_dtype(a.dtype, b.dtype)
    rank_a = a.ndim
    rank_b = b.ndim
    if rank_a == 0 or rank_b == 0:
        return ops.tensor_mul(a, b)
    if rank_a > 0 and rank_b >= 2:
        # Bring the second-to-last axis of `b` to the end so that both
        # operands are contracted over their last axis.
        axes = ops.make_range(rank_b)
        swapped = axes[:-2] + (axes[-1],) + (axes[-2],)
        b = ops.transpose(b, swapped)

    inner = a.shape[-1]
    if inner != b.shape[-1]:
        raise ValueError('shapes are not aligned')
    lhs = a.reshape(-1, inner).astype(mstype.float32)
    rhs = b.reshape(-1, b.shape[-1]).astype(mstype.float32)

    product = ops.matmul(lhs, rhs.T)
    out_shape = a.shape[:-1] + b.shape[:-1]
    return product.reshape(out_shape).astype(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def norm(A, ord=None, dim=None, keepdim=False, *, dtype=None):
    r"""
    Returns the matrix norm or vector norm of a given tensor.

    Whether this function computes a vector or matrix norm is determined as follows:

    - If `dim` is an integer, the vector norm will be computed.

    - If `dim` is a 2-tuple, the matrix norm will be computed.

    - If `dim` = None and `ord` = None, A will be flattened to 1D and the 2-norm of the resulting vector
      will be computed.

    - If `dim` = None and `ord` != None, A must be 1D or 2D.

    `ord` defines the norm that is computed. The following norms are supported:

    ======================     ========================= ========================================================
    `ord`                      norm for matrices         norm for vectors
    ======================     ========================= ========================================================
    `None` (default)           Frobenius norm            `2`-norm (see below)
    `'fro'`                    Frobenius norm            -- not supported --
    `'nuc'`                    nuclear norm              -- not supported --
    `inf`                      `max(sum(abs(x), dim=1))` `max(abs(x))`
    `-inf`                     `min(sum(abs(x), dim=1))` `min(abs(x))`
    `0`                        -- not supported --       `sum(x != 0)`
    `1`                        `max(sum(abs(x), dim=0))` as below
    `-1`                       `min(sum(abs(x), dim=0))` as below
    `2`                        largest singular value    as below
    `-2`                       smallest singular value   as below
    other `int` or `float`     -- not supported --       `sum(abs(x)^{ord})^{(1 / ord)}`
    ======================     ========================= ========================================================

    Args:
        A (Tensor): Tensor of shape (*, n) or (*, m, n) where * is zero or more batch dimensions.
        ord (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): order of norm. Refer to the table above for
            behavior. Default: None.
        dim (Union[int, Tuple(int)], optional): dimensions over which to compute the vector or matrix norm.
            See above for the behavior when `dim` = None. Default: None.
        keepdim (bool): Whether the output tensors have dim retained or not. Default: False.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): If specified, the input tensor is cast to dtype before performing
            the operation, and the returned tensor's type will be dtype. Default: None.

    Returns:
        A real-valued tensor.

    Raises:
        ValueError: If `dim` is out of range.
        TypeError: If `dim` is neither an int nor a tuple of int.
        TypeError: If `A` is a vector and `ord` is a str.
        ValueError: If `A` is a matrix and `ord` is not in valid mode.
        ValueError: If `A` is a matrix and `ord` is an integer but not in [1, -1, 2, -2].
        ValueError: If two elements of `dim` are the same after normalization.
        ValueError: If any element of `dim` is out of range.
        ValueError: If the number of dimensions to reduce over is neither 1 nor 2.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> a = ops.arange(9, dtype=ms.float32) - 4
        >>> b = a.reshape((3, 3))
        >>> print(ops.norm(a))
        7.745967
        >>> print(ops.norm(b))
        7.745967
        >>> print(ops.norm(b, 'fro'))
        7.745967
        >>> print(ops.norm(a, float('inf')))
        4.0
        >>> print(ops.norm(b, float('inf')))
        9.0
        >>> print(ops.norm(a, -float('inf')))
        0.0
        >>> print(ops.norm(b, -float('inf')))
        2.0
        >>> print(ops.norm(a, 1))
        20.0
        >>> print(ops.norm(b, 1))
        7.0
        >>> print(ops.norm(a, -1))
        0.0
        >>> print(ops.norm(b, -1))
        6.0
        >>> print(ops.norm(a, 2))
        7.745967
        >>> print(ops.norm(b, 2))
        7.3484707
        >>> print(ops.norm(a, -2))
        0.0
        >>> print(ops.norm(a, 3))
        5.848036
        >>> print(ops.norm(a, -3))
        0.0
        >>> c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
        >>> print(ops.norm(c, dim=0))
        [1.4142135 2.236068 5. ]
        >>> print(ops.norm(c, dim=1))
        [3.7416575 4.2426405]
        >>> print(ops.norm(c, ord=1, dim=1))
        [6. 6.]
        >>> d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
        >>> print(ops.norm(d, dim=(1,2)))
        [ 3.7416575 11.224972 ]
        >>> print(ops.norm(d[0, :, :]), ops.norm(d[1, :, :]))
        3.7416575 11.224972
    """
    ndim = A.ndim
    dim, immediate = _check_axis(dim, ord, ndim)
    _check_ord(ord, dim)
    if dtype is not None:
        A = ops.cast(A, dtype)
    # Immediately handle some default, simple, fast, and common cases:
    # flatten and take the 2-norm of the resulting vector.
    if immediate:
        A = A.ravel()
        sqnorm = dot(A, A)
        ret = ops.sqrt(sqnorm)
        if keepdim:
            ret = ret.reshape(ndim * [1])
        return ret

    if len(dim) == 1:
        # Vector norms. Every branch reduces over `dim` and honours `keepdim`.
        if ord is None:
            # special case for speedup
            s = ops.conj(A) * A
            reduce_sum = _get_cache_prim(ops.ReduceSum)(keepdim)
            return ops.sqrt(reduce_sum(s, dim))
        if ord == float('inf'):
            return ops.abs(A).max(dim, keepdim)
        if ord == -float('inf'):
            return ops.abs(A).min(dim, keepdim)
        if ord == 0:
            # Number of non-zero entries along `dim`.
            nonzero = (A != 0).astype(mstype.float32)
            reduce_sum = _get_cache_prim(ops.ReduceSum)(keepdim)
            return reduce_sum(nonzero, dim)
        if isinstance(ord, int) and ord > 0:
            _lp_norm = _get_cache_prim(ops.LpNorm)(dim, ord, keepdim)
            return _lp_norm(A)
        # General case (negative or non-integer ord).
        # None of the str-type keywords for ord ('fro', 'nuc')
        # are valid for vectors
        absx = ops.abs(A)
        absx **= ord
        reduce_sum = _get_cache_prim(ops.ReduceSum)(keepdim)
        ret = reduce_sum(absx, dim)
        if isinstance(ord, Tensor):
            ret **= ops.reciprocal(ord)
        else:
            ret **= 1 / ord
        # Replace NaN entries by zero element-wise. A plain `if ops.isnan(ret)`
        # only works when `ret` holds a single element.
        return _get_cache_prim(ops.Select)()(ops.isnan(ret), ops.zeros_like(ret), ret)

    if len(dim) == 2:
        # Matrix norms over the two axes named by `dim`.
        row_axis, col_axis = dim
        row_axis = normalize_axis_index(row_axis, ndim)
        col_axis = normalize_axis_index(col_axis, ndim)
        if row_axis == col_axis:
            raise ValueError('For norm, the elements of dim can not be duplicate.')

        # After the first reduction removes one axis, the index of the other
        # axis shifts down by one if it came after the removed axis.
        row_after_col_reduced = row_axis - 1 if row_axis > col_axis else row_axis
        col_after_row_reduced = col_axis - 1 if col_axis > row_axis else col_axis

        if ord == float('inf'):
            ret = ops.reduce_sum(ops.abs(A), col_axis).max(row_after_col_reduced)
        elif ord == -float('inf'):
            ret = ops.reduce_sum(ops.abs(A), col_axis).min(row_after_col_reduced)
        elif ord == 1:
            ret = ops.reduce_sum(ops.abs(A), row_axis).max(col_after_row_reduced)
        elif ord == -1:
            ret = ops.reduce_sum(ops.abs(A), row_axis).min(col_after_row_reduced)
        elif ord == 2:
            ret = _multi_svd_norm(A, row_axis, col_axis, 'amax')
        elif ord == -2:
            ret = _multi_svd_norm(A, row_axis, col_axis, 'amin')
        elif ord == 'nuc':
            ret = _multi_svd_norm(A, row_axis, col_axis, 'sum')
        else:
            # `ord` is None or 'fro': Frobenius norm.
            ret = ops.sqrt(ops.reduce_sum((ops.conj(A) * A), dim))
        if keepdim:
            ret_shape = list(A.shape)
            ret_shape[dim[0]] = 1
            ret_shape[dim[1]] = 1
            ret = ret.reshape(ret_shape)
        return ret

    # Neither a vector nor a matrix reduction was selected; fail loudly
    # instead of returning None.
    raise ValueError(f"For norm, the number of dimensions to reduce over must be 1 or 2, but got {len(dim)}.")
|
||||||
|
|
||||||
|
|
||||||
def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
|
||||||
|
@ -9779,7 +10026,7 @@ def sum(x, dim=None, keepdim=False, *, dtype=None):
|
||||||
raise TypeError("For 'sum', 'keepdim' must be bool.")
|
raise TypeError("For 'sum', 'keepdim' must be bool.")
|
||||||
|
|
||||||
reduce_sum = P.ReduceSum(keep_dims=keepdim)
|
reduce_sum = P.ReduceSum(keep_dims=keepdim)
|
||||||
if dim:
|
if dim is not None:
|
||||||
out = reduce_sum(x, dim)
|
out = reduce_sum(x, dim)
|
||||||
else:
|
else:
|
||||||
out = reduce_sum(x)
|
out = reduce_sum(x)
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as ops
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
    """Cell whose `construct` forwards its inputs to `ops.norm`."""

    def construct(self, a, norm_ord=None, dim=None, keepdim=False, dtype=None):
        return ops.norm(a, norm_ord, dim, keepdim, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_norm_normal(mode):
    """
    Feature: norm
    Description: Verify the result of norm
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()

    def check(expected, *args, **kwargs):
        # Run the network and compare against the reference value.
        actual = net(*args, **kwargs).asnumpy()
        assert np.allclose(actual, np.array(expected))

    a = ops.arange(9, dtype=ms.float32) - 4
    b = a.reshape((3, 3))
    inf = float('inf')
    # (input, ord, expected) for the 1-D vector `a` and the 3x3 matrix `b`.
    ord_cases = [
        (a, None, 7.745967),
        (b, None, 7.745967),
        (a, inf, 4.0),
        (b, inf, 9.0),
        (a, -inf, 0.0),
        (b, -inf, 2.0),
        (a, 1, 20.0),
        (b, 1, 7.0),
        (a, 2, 7.745967),
        (b, 2, 7.3484707),
        (a, -1, 0.0),
        (b, -1, 6.0),
        (a, -2, 0.0),
        (a, 3, 5.848036),
        (a, -3, 0.0),
    ]
    for tensor, norm_ord, expected in ord_cases:
        if norm_ord is None:
            check(expected, tensor)
        else:
            check(expected, tensor, norm_ord)

    # Vector norms along one axis of a 2-D input.
    c = ms.Tensor([[1., 2., 3.], [-1, 1, 4]])
    check([1.4142135, 2.236068, 5.], c, dim=0)
    check([3.7416575, 4.2426405], c, dim=1)
    check([6., 6.], c, norm_ord=1, dim=1)

    # Batched matrix norm, and the same matrices taken one at a time.
    d = ops.arange(8, dtype=ms.float32).reshape(2, 2, 2)
    check([3.7416575, 11.224972], d, dim=(1, 2))

    per_matrix = net(d[0, :, :]).asnumpy(), net(d[1, :, :]).asnumpy()
    assert np.allclose(per_matrix, np.array([3.7416575, 11.224972]))
|
Loading…
Reference in New Issue