forked from mindspore-Ecosystem/mindspore
Add tensor and functional interface for NonZero operator.
This commit is contained in:
parent
fc011ae8a3
commit
1b8a7b0a8f
|
@ -252,6 +252,7 @@ Array操作
|
|||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_nd
|
||||
mindspore.ops.nonzero
|
||||
mindspore.ops.rank
|
||||
mindspore.ops.reshape
|
||||
mindspore.ops.scatter_nd
|
||||
|
|
|
@ -530,6 +530,14 @@ mindspore.Tensor
|
|||
|
||||
返回Tensor维度的数量。
|
||||
|
||||
.. py:method:: nonzero()
|
||||
|
||||
计算x中非零元素的下标。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,维度为2,类型为int64。
|
||||
|
||||
.. py:method:: ptp(axis=None, keepdims=False)
|
||||
|
||||
该函数名称是"peak to peak"的缩写。计算沿着axis的最大值与最小值的差值。
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.nonzero
|
||||
=====================
|
||||
|
||||
.. py:function:: mindspore.ops.nonzero(x)
|
||||
|
||||
计算x中非零元素的下标。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **x** (Tensor) - nonzero的输入,任意维度的Tensor,秩应大于1。其数据类型为数值型和布尔型。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,维度为2,类型为int64。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `x` 不是Tensor。
|
||||
- **ValueError** - `x` 的维度为0。
|
|
@ -251,6 +251,7 @@ Array Operation
|
|||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_nd
|
||||
mindspore.ops.nonzero
|
||||
mindspore.ops.rank
|
||||
mindspore.ops.reshape
|
||||
mindspore.ops.scatter_nd
|
||||
|
|
|
@ -190,6 +190,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"swapaxes", std::string("swapaxes")}, // P.transpose()
|
||||
{"narrow", std::string("narrow")}, // narrow()
|
||||
{"masked_fill", std::string("masked_fill")}, // masked_fill()
|
||||
{"nonzero", std::string("nonzero")}, // nonzero()
|
||||
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
|
|
|
@ -48,8 +48,8 @@ abstract::ShapePtr NonZeroInferShape(const PrimitivePtr &primitive, const std::v
|
|||
}
|
||||
|
||||
TypePtr NonZeroInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
|
||||
kUInt32, kUInt64, kFloat16, kFloat, kFloat64, kComplex64, kComplex128};
|
||||
const std::set valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8,
|
||||
kUInt16, kUInt32, kUInt64, kFloat16, kFloat, kFloat64};
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
return std::make_shared<TensorType>(kInt64);
|
||||
|
|
|
@ -1578,6 +1578,13 @@ def tensor_sactter_div(input_x, indices, updates):
|
|||
return F.tensor_scatter_div(input_x, indices, updates)
|
||||
|
||||
|
||||
def nonzero(x):
|
||||
"""
|
||||
Return a tensor of the positions of all non-zero values.
|
||||
"""
|
||||
return F.nonzero(x)
|
||||
|
||||
|
||||
def coo_to_csr(x):
|
||||
"""convert coo to csr."""
|
||||
row_indices = x.indices[:, 0]
|
||||
|
|
|
@ -2565,6 +2565,28 @@ class Tensor(Tensor_):
|
|||
repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis))
|
||||
return tensor_operator_registry.get('concatenate')(axis)(repeated_subs)
|
||||
|
||||
def nonzero(self):
|
||||
"""
|
||||
Return a tensor of the positions of all non-zero values.
|
||||
|
||||
Returns:
|
||||
Tensor, a 2-D tensor, containing the positions of all non-zero values of the input.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
|
||||
>>> output = x.nonzero()
|
||||
>>> print(output)
|
||||
[[0 0 0]
|
||||
[0 1 0]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('nonzero')(self)
|
||||
|
||||
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
|
|
|
@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
|
|||
from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_,
|
||||
dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
|
||||
tuple_to_array, expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast,
|
||||
masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min,
|
||||
masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min, nonzero,
|
||||
space_to_batch_nd)
|
||||
from .parameter_func import assign, assign_add, assign_sub, index_add
|
||||
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le, lerp,
|
||||
|
|
|
@ -18,6 +18,37 @@
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from ..operations.array_ops import NonZero
|
||||
|
||||
|
||||
eye_ = P.Eye()
|
||||
fill_ = P.Fill()
|
||||
ones_ = P.Ones()
|
||||
ones_like_ = P.OnesLike()
|
||||
tile_ = P.Tile()
|
||||
size_ = P.Size()
|
||||
shape_ = P.Shape()
|
||||
rank_ = P.Rank()
|
||||
tensor_shape = P.TensorShape()
|
||||
reshape_ = P.Reshape()
|
||||
tensor_slice = P.Slice()
|
||||
expand_dims_ = P.ExpandDims()
|
||||
transpose_ = P.Transpose()
|
||||
scatter_max_ = P.ScatterMax()
|
||||
scatter_min_ = P.ScatterMin()
|
||||
scatter_nd_ = P.ScatterNd()
|
||||
gather_ = P.Gather()
|
||||
gather_d_ = P.GatherD()
|
||||
gather_nd_ = P.GatherNd()
|
||||
tensor_scatter_add_ = P.TensorScatterAdd()
|
||||
nonzero_ = NonZero()
|
||||
scalar_cast_ = P.ScalarCast()
|
||||
tensor_scatter_div_ = P.TensorScatterDiv()
|
||||
scalar_to_array_ = P.ScalarToArray()
|
||||
scalar_to_tensor_ = P.ScalarToTensor()
|
||||
tuple_to_array_ = P.TupleToArray()
|
||||
masked_fill_ = P.MaskedFill()
|
||||
matrix_band_part_ = P.array_ops.MatrixBandPart()
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -33,7 +64,6 @@ def get_x_shape(x_shape):
|
|||
##############################
|
||||
|
||||
|
||||
eye_ = P.Eye()
|
||||
def eye(n, m, t):
|
||||
"""
|
||||
Creates a tensor with ones on the diagonal and zeros in the rest.
|
||||
|
@ -75,7 +105,6 @@ def eye(n, m, t):
|
|||
return eye_(n, m, t)
|
||||
|
||||
|
||||
matrix_band_part_ = P.array_ops.MatrixBandPart()
|
||||
def matrix_band_part(x, lower, upper):
|
||||
r"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
@ -118,7 +147,6 @@ def matrix_band_part(x, lower, upper):
|
|||
return matrix_band_part_(x, lower, upper)
|
||||
|
||||
|
||||
fill_ = P.Fill()
|
||||
def fill(type, shape, value):
|
||||
"""
|
||||
Create a Tensor of the specified shape and fill it with the specified value.
|
||||
|
@ -153,7 +181,6 @@ def fill(type, shape, value):
|
|||
return fill_(type, shape, value)
|
||||
|
||||
|
||||
ones_ = P.Ones()
|
||||
def ones(shape, type):
|
||||
r"""
|
||||
Creates a tensor filled with value ones.
|
||||
|
@ -188,7 +215,6 @@ def ones(shape, type):
|
|||
return ones_(shape, type)
|
||||
|
||||
|
||||
ones_like_ = P.OnesLike()
|
||||
def ones_like(input_x):
|
||||
"""
|
||||
Returns a Tensor with a value of 1 and its shape and data type is the same as the input.
|
||||
|
@ -215,7 +241,6 @@ def ones_like(input_x):
|
|||
return ones_like_(input_x)
|
||||
|
||||
|
||||
tile_ = P.Tile()
|
||||
def tile(input_x, multiples):
|
||||
r"""
|
||||
Replicates an input tensor with given multiples times.
|
||||
|
@ -346,7 +371,6 @@ def unique(x):
|
|||
return y, idx
|
||||
|
||||
|
||||
size_ = P.Size()
|
||||
def size(input_x):
|
||||
r"""
|
||||
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
|
||||
|
@ -375,7 +399,6 @@ def size(input_x):
|
|||
return size_(input_x)
|
||||
|
||||
|
||||
shape_ = P.Shape()
|
||||
def shape(input_x):
|
||||
"""
|
||||
Returns the shape of the input tensor. And it used to be static shape.
|
||||
|
@ -406,7 +429,6 @@ def shape(input_x):
|
|||
return shape_(input_x)
|
||||
|
||||
|
||||
tensor_shape = P.TensorShape()
|
||||
def dyn_shape(input_x):
|
||||
"""
|
||||
Returns the shape of the input tensor.
|
||||
|
@ -432,7 +454,6 @@ def dyn_shape(input_x):
|
|||
return tensor_shape(input_x)
|
||||
|
||||
|
||||
rank_ = P.Rank()
|
||||
def rank(input_x):
|
||||
"""
|
||||
Returns the rank of a tensor.
|
||||
|
@ -463,7 +484,6 @@ def rank(input_x):
|
|||
return rank_(input_x)
|
||||
|
||||
|
||||
reshape_ = P.Reshape()
|
||||
def reshape(input_x, input_shape):
|
||||
"""
|
||||
Rearranges the input Tensor based on the given shape.
|
||||
|
@ -498,7 +518,6 @@ def reshape(input_x, input_shape):
|
|||
return reshape_(input_x, input_shape)
|
||||
|
||||
|
||||
tensor_slice = P.Slice()
|
||||
def slice(input_x, begin, size):
|
||||
"""
|
||||
Slices a tensor in the specified shape.
|
||||
|
@ -550,7 +569,6 @@ def slice(input_x, begin, size):
|
|||
return tensor_slice(input_x, begin, size)
|
||||
|
||||
|
||||
expand_dims_ = P.ExpandDims()
|
||||
def expand_dims(input_x, axis):
|
||||
"""
|
||||
Adds an additional dimension to `input_x` at the given axis.
|
||||
|
@ -585,7 +603,6 @@ def expand_dims(input_x, axis):
|
|||
return expand_dims_(input_x, axis)
|
||||
|
||||
|
||||
transpose_ = P.Transpose()
|
||||
def transpose(input_x, input_perm):
|
||||
"""
|
||||
Permutes the dimensions of the input tensor according to input permutation.
|
||||
|
@ -630,7 +647,6 @@ def transpose(input_x, input_perm):
|
|||
return transpose_(input_x, input_perm)
|
||||
|
||||
|
||||
scatter_max_ = P.ScatterMax()
|
||||
def scatter_max(input_x, indices, updates):
|
||||
r"""
|
||||
Using given values to update tensor value through the max operation, along with the input indices.
|
||||
|
@ -669,7 +685,6 @@ def scatter_max(input_x, indices, updates):
|
|||
return scatter_max_(input_x, indices, updates)
|
||||
|
||||
|
||||
scatter_min_ = P.ScatterMin()
|
||||
def scatter_min(input_x, indices, updates):
|
||||
"""
|
||||
Updates the value of the input tensor through the minimum operation.
|
||||
|
@ -725,7 +740,6 @@ def scatter_min(input_x, indices, updates):
|
|||
return scatter_min_(input_x, indices, updates)
|
||||
|
||||
|
||||
scatter_nd_ = P.ScatterNd()
|
||||
def scatter_nd(indices, updates, shape):
|
||||
r"""
|
||||
Scatters a tensor into a new tensor depending on the specified indices.
|
||||
|
@ -827,7 +841,6 @@ def scatter_nd(indices, updates, shape):
|
|||
return scatter_nd_(indices, updates, shape)
|
||||
|
||||
|
||||
gather_ = P.Gather()
|
||||
def gather(input_params, input_indices, axis):
|
||||
r"""
|
||||
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
|
||||
|
@ -903,7 +916,6 @@ def gather(input_params, input_indices, axis):
|
|||
return gather_(input_params, input_indices, axis)
|
||||
|
||||
|
||||
gather_d_ = P.GatherD()
|
||||
def gather_d(x, dim, index):
|
||||
"""
|
||||
Gathers values along an axis specified by dim.
|
||||
|
@ -951,7 +963,6 @@ def gather_d(x, dim, index):
|
|||
return gather_d_(x, dim, index)
|
||||
|
||||
|
||||
gather_nd_ = P.GatherNd()
|
||||
def gather_nd(input_x, indices):
|
||||
r"""
|
||||
Gathers slices from a tensor by indices.
|
||||
|
@ -991,7 +1002,6 @@ def gather_nd(input_x, indices):
|
|||
return gather_nd_(input_x, indices)
|
||||
|
||||
|
||||
tensor_scatter_add_ = P.TensorScatterAdd()
|
||||
def tensor_scatter_add(input_x, indices, updates):
|
||||
"""
|
||||
Creates a new tensor by adding the values from the positions in `input_x` indicated by
|
||||
|
@ -1099,12 +1109,44 @@ def space_to_batch_nd(input_x, block_size, paddings):
|
|||
"""
|
||||
return P.SpaceToBatchND(block_size, paddings)(input_x)
|
||||
|
||||
|
||||
def nonzero(x):
|
||||
"""
|
||||
Return a tensor of the positions of all non-zero values.
|
||||
|
||||
Args:
|
||||
x (int): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number or Bool.
|
||||
|
||||
Returns:
|
||||
y (Tensor): The shape of tensor is 2-D. The data type is int64.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not Tensor.
|
||||
ValueError: If 'x' dim equal to 0.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.ops as ops
|
||||
>>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
|
||||
>>> nonzero = ops.nonzero()
|
||||
>>> output = nonzero(x)
|
||||
>>> print(output)
|
||||
[[0 0 0]
|
||||
[0 1 0]]
|
||||
"""
|
||||
return nonzero_(x)
|
||||
|
||||
|
||||
##############################
|
||||
# Type Conversion Functions.
|
||||
##############################
|
||||
|
||||
|
||||
scalar_cast_ = P.ScalarCast()
|
||||
def scalar_cast(input_x, input_y):
|
||||
"""
|
||||
Casts the input scalar to another type.
|
||||
|
@ -1129,7 +1171,6 @@ def scalar_cast(input_x, input_y):
|
|||
"""
|
||||
return scalar_cast_(input_x, input_y)
|
||||
|
||||
tensor_scatter_div_ = P.TensorScatterDiv()
|
||||
def tensor_scatter_div(input_x, indices, updates):
|
||||
"""
|
||||
Creates a new tensor by dividing the values from the positions in `input_x` indicated by
|
||||
|
@ -1179,7 +1220,6 @@ def tensor_scatter_div(input_x, indices, updates):
|
|||
return tensor_scatter_div_(input_x, indices, updates)
|
||||
|
||||
|
||||
scalar_to_array_ = P.ScalarToArray()
|
||||
def scalar_to_array(input_x):
|
||||
"""
|
||||
Converts a scalar to a `Tensor`.
|
||||
|
@ -1209,7 +1249,6 @@ def scalar_to_array(input_x):
|
|||
return scalar_to_array_(input_x)
|
||||
|
||||
|
||||
scalar_to_tensor_ = P.ScalarToTensor()
|
||||
def scalar_to_tensor(input_x, dtype=mstype.float32):
|
||||
"""
|
||||
Converts a scalar to a `Tensor`, and converts the data type to the specified type.
|
||||
|
@ -1237,7 +1276,6 @@ def scalar_to_tensor(input_x, dtype=mstype.float32):
|
|||
return scalar_to_tensor_(input_x, dtype)
|
||||
|
||||
|
||||
tuple_to_array_ = P.TupleToArray()
|
||||
def tuple_to_array(input_x):
|
||||
"""
|
||||
Converts a tuple to a tensor.
|
||||
|
@ -1272,7 +1310,6 @@ def tuple_to_array(input_x):
|
|||
return tuple_to_array_(input_x)
|
||||
|
||||
|
||||
masked_fill_ = P.MaskedFill()
|
||||
def masked_fill(x, mask, value):
|
||||
"""
|
||||
Fills elements of self tensor with value where mask is True.
|
||||
|
@ -1308,6 +1345,7 @@ def masked_fill(x, mask, value):
|
|||
"""
|
||||
return masked_fill_(x, mask, value)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'eye',
|
||||
|
@ -1341,6 +1379,7 @@ __all__ = [
|
|||
'masked_fill',
|
||||
'tensor_scatter_div',
|
||||
'scatter_max',
|
||||
'scatter_min'
|
||||
'scatter_min',
|
||||
'nonzero',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -984,6 +984,7 @@ tensor_operator_registry.register('index_add', P.IndexAdd)
|
|||
tensor_operator_registry.register('scatter_max', P.ScatterMax)
|
||||
tensor_operator_registry.register('scatter_min', P.ScatterMin)
|
||||
tensor_operator_registry.register('space_to_batch_nd', P.SpaceToBatchND)
|
||||
tensor_operator_registry.register('nonzero', nonzero)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
|
|
|
@ -7650,15 +7650,7 @@ class NonZero(Primitive):
|
|||
"""
|
||||
Return a tensor of the positions of all non-zero values.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number or Bool.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor), The shape of tensor is 2-D. The data type is int64.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not Tensor.
|
||||
ValueError: If 'x' dim equal to 0.
|
||||
Refer to :func:`mindspore.ops.expand_dims` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
@ -7667,9 +7659,9 @@ class NonZero(Primitive):
|
|||
>>> import mindspore
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore.ops.operations.array_ops import NonZero
|
||||
>>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
|
||||
>>> nonzero = ops.NonZero()
|
||||
>>> nonzero = NonZero()
|
||||
>>> output = nonzero(x)
|
||||
>>> print(output)
|
||||
[[0 0 0]
|
||||
|
|
Loading…
Reference in New Issue