forked from mindspore-Ecosystem/mindspore
!41788 [assistant][ops]Add Diff
Merge pull request !41788 from zy26/Diff
This commit is contained in:
commit
db4e6b9a22
|
@ -313,6 +313,7 @@ from .math_func import (
|
|||
block_diag,
|
||||
atleast_1d,
|
||||
dstack,
|
||||
diff,
|
||||
atleast_2d,
|
||||
cartesian_prod,
|
||||
atleast_3d,
|
||||
|
|
|
@ -5541,23 +5541,94 @@ def dstack(inputs):
|
|||
[3. 6.]]]
|
||||
"""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
raise TypeError(f"For 'dstack', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' must be list or tuple of tensors, but got {type(inputs)}")
|
||||
if not inputs:
|
||||
raise TypeError(f"For 'dstack', 'inputs' can not be empty.")
|
||||
raise TypeError(f"For 'dstack', 'diff', 'inputs' can not be empty.")
|
||||
trans_inputs = ()
|
||||
for tensor in inputs:
|
||||
if not isinstance(tensor, Tensor):
|
||||
raise TypeError(f"For 'dstack', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
||||
raise TypeError(f"For 'dstack', 'diff', each elements of 'inputs' must be Tensor, but got {type(tensor)}")
|
||||
if tensor.ndim <= 1:
|
||||
tensor = _expand(tensor, 2)
|
||||
if tensor.ndim == 2:
|
||||
tensor = P.ExpandDims()(tensor, 2)
|
||||
trans_inputs += (tensor,)
|
||||
if not trans_inputs:
|
||||
raise ValueError("For 'dstack', at least one tensor is needed to concatenate.")
|
||||
raise ValueError("For 'dstack', 'diff', at least one tensor is needed to concatenate.")
|
||||
return P.Concat(2)(trans_inputs)
|
||||
|
||||
|
||||
def diff(x, n=1, axis=-1, prepend=None, append=None):
|
||||
r"""
|
||||
Calculates the n-th discrete difference along the given axis.
|
||||
|
||||
The first difference is given by :math:`out[i] = a[i+1] - a[i]` along the given axis,
|
||||
higher differences are calculated by using `diff` iteratively.
|
||||
|
||||
Note:
|
||||
Zero-shaped Tensor is not supported, a value error is raised if
|
||||
an empty Tensor is encountered.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
Full support for signed integers, partial support for floats and complex numbers
|
||||
n (int, optional): The number of times values are differenced. If zero,
|
||||
the input is returned as-is. Default: 1. Currently only 1 is supported.
|
||||
axis (int, optional): The axis along which the difference is taken, default
|
||||
is the last axis. Default: -1.
|
||||
prepend (Tensor, optional): Values to prepend or append to a along
|
||||
`axis` prior to performing the difference. Scalar values are expanded to
|
||||
arrays with length 1 in the direction of `axis` and the shape of the input
|
||||
array in along all other axes. Otherwise the dimension and shape must
|
||||
match `x` except along `axis`. Default: None.
|
||||
append (Tensor, optional): Values to prepend or append to a along
|
||||
`axis` prior to performing the difference. Scalar values are expanded to
|
||||
arrays with length 1 in the direction of `axis` and the shape of the input
|
||||
array in along all other axes. Otherwise the dimension and shape must
|
||||
match `x` except along `axis`. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, the n-th differences. The shape of the output is the same as a except along
|
||||
`axis` where the dimension is smaller by `n`. The type of the output is the same
|
||||
as the type of the difference between any two elements of `x`. This is the same
|
||||
as the type of `x` in most cases.
|
||||
|
||||
Raises:
|
||||
TypeError: If the data type of the elementes in 'x' is uint16, uint32 or uint64.
|
||||
TypeError: If `x` is not a tensor.
|
||||
TypeError: If the dimension 'x' is less than 1.
|
||||
RuntimeError: If `n` is not 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([1, 3, -1, 0, 4])
|
||||
>>> out = ops.diff(x)
|
||||
>>> print(out.asnumpy())
|
||||
[ 2 -4 1 4]
|
||||
"""
|
||||
if not isinstance(x, Tensor):
|
||||
raise TypeError(f"For 'diff', 'x' must be a tensor, but got {type(x)}")
|
||||
if x.ndim < 1:
|
||||
raise TypeError(f"For 'diff', the dimension 'x' must be at least 1, but got {x.ndim}")
|
||||
if n != 1:
|
||||
raise RuntimeError(f"For 'diff', 'n' must be 1, but got {n}")
|
||||
if x.dtype in (mstype.uint16, mstype.uint32, mstype.uint64):
|
||||
msg = f"For 'diff', the data type of the elements in 'x' cannot be uint16, uint32, uint64, but got {x.dtype}"
|
||||
raise TypeError(msg)
|
||||
if prepend is not None and append is not None:
|
||||
x = ops.Concat(axis)((prepend, x, append))
|
||||
elif append is not None:
|
||||
x = ops.Concat(axis)((x, append))
|
||||
elif prepend is not None:
|
||||
x = ops.Concat(axis)((prepend, x))
|
||||
a = [i for i in range(x.shape[axis])]
|
||||
a1 = x.gather(Tensor(a[:-1]), axis)
|
||||
a2 = x.gather(Tensor(a[1:]), axis)
|
||||
return a2 - a1
|
||||
|
||||
|
||||
def tril_indices(row, col, offset=0, dtype=mstype.int64):
|
||||
r"""
|
||||
Returns the indices of the lower triangular part of a row-by- col matrix in a 2-by-N Tensor,
|
||||
|
@ -9138,6 +9209,7 @@ __all__ = [
|
|||
'block_diag',
|
||||
'atleast_1d',
|
||||
'dstack',
|
||||
'diff',
|
||||
'atleast_2d',
|
||||
'cartesian_prod',
|
||||
'atleast_3d',
|
||||
|
|
|
@ -450,6 +450,15 @@ class DstackFunc(nn.Cell):
|
|||
return self.dstack([x1, x2])
|
||||
|
||||
|
||||
class DiffFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(DiffFunc, self).__init__()
|
||||
self.diff = ops.diff
|
||||
|
||||
def construct(self, x):
|
||||
return self.diff(x)
|
||||
|
||||
|
||||
class AtLeast2DFunc(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AtLeast2DFunc, self).__init__()
|
||||
|
@ -874,6 +883,10 @@ test_case_math_ops = [
|
|||
'desc_inputs': [Tensor(np.array([1, 2, 3]), ms.float32),
|
||||
Tensor(np.array([4, 5, 6]), ms.float32)]
|
||||
}),
|
||||
('Diff', {
|
||||
'block': DiffFunc(),
|
||||
'desc_inputs': [Tensor(np.array([1, 3, -1, 0, 4]), ms.int32)]
|
||||
}),
|
||||
('AtLeast2D', {
|
||||
'block': AtLeast2DFunc(),
|
||||
'desc_inputs': [Tensor(np.array([[1, 1, 1], [1, 1, 1]]), ms.float64),
|
||||
|
|
Loading…
Reference in New Issue