diff --git a/docs/api/api_python/mindspore.ops.functional.rst b/docs/api/api_python/mindspore.ops.functional.rst
index 2d3065235de..f980f5446cc 100644
--- a/docs/api/api_python/mindspore.ops.functional.rst
+++ b/docs/api/api_python/mindspore.ops.functional.rst
@@ -252,6 +252,7 @@ Array操作
     mindspore.ops.gather
     mindspore.ops.gather_d
     mindspore.ops.gather_nd
+    mindspore.ops.nonzero
     mindspore.ops.rank
     mindspore.ops.reshape
     mindspore.ops.scatter_nd
diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst
index e0b7226f4b3..9f78746fda3 100644
--- a/docs/api/api_python/mindspore/mindspore.Tensor.rst
+++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst
@@ -530,6 +530,14 @@ mindspore.Tensor
 
         返回Tensor维度的数量。
 
+    .. py:method:: nonzero()
+
+        计算x中非零元素的下标。
+
+        **返回:**
+
+        Tensor，维度为2，类型为int64。
+
     .. py:method:: ptp(axis=None, keepdims=False)
 
         该函数名称是"peak to peak"的缩写。计算沿着axis的最大值与最小值的差值。
diff --git a/docs/api/api_python/ops/mindspore.ops.func_nonzero.rst b/docs/api/api_python/ops/mindspore.ops.func_nonzero.rst
new file mode 100644
index 00000000000..f995ab62ef9
--- /dev/null
+++ b/docs/api/api_python/ops/mindspore.ops.func_nonzero.rst
@@ -0,0 +1,19 @@
+mindspore.ops.nonzero
+=====================
+
+.. py:function:: mindspore.ops.nonzero(x)
+
+    计算x中非零元素的下标。
+
+    **参数:**
+
+    - **x** (Tensor) - nonzero的输入，任意维度的Tensor，秩应大于等于1。其数据类型为数值型和布尔型。
+
+    **返回:**
+
+    Tensor，维度为2，类型为int64。
+
+    **异常:**
+
+    - **TypeError** - `x` 不是Tensor。
+    - **ValueError** - `x` 的维度为0。
\ No newline at end of file
diff --git a/docs/api/api_python_en/mindspore.ops.functional.rst b/docs/api/api_python_en/mindspore.ops.functional.rst
index 50b26810798..27b7a93356c 100644
--- a/docs/api/api_python_en/mindspore.ops.functional.rst
+++ b/docs/api/api_python_en/mindspore.ops.functional.rst
@@ -251,6 +251,7 @@ Array Operation
     mindspore.ops.gather
     mindspore.ops.gather_d
     mindspore.ops.gather_nd
+    mindspore.ops.nonzero
     mindspore.ops.rank
     mindspore.ops.reshape
     mindspore.ops.scatter_nd
diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc
index e38e8fe8ec7..2ecb336884b 100644
--- a/mindspore/ccsrc/pipeline/jit/resource.cc
+++ b/mindspore/ccsrc/pipeline/jit/resource.cc
@@ -190,6 +190,7 @@ BuiltInTypeMap &GetMethodMap() {
                      {"swapaxes", std::string("swapaxes")},        // P.transpose()
                      {"narrow", std::string("narrow")},            // narrow()
                      {"masked_fill", std::string("masked_fill")},  // masked_fill()
+                     {"nonzero", std::string("nonzero")},          // nonzero()
                      {"expand_dims", std::string("expand_dims")},  // P.expand_dims()
                      {"squeeze", std::string("squeeze")},          // P.squeeze()
                      {"astype", std::string("astype")},            // P.cast()
diff --git a/mindspore/core/ops/non_zero.cc b/mindspore/core/ops/non_zero.cc
index 105a351e1b4..20fc3dd2d33 100644
--- a/mindspore/core/ops/non_zero.cc
+++ b/mindspore/core/ops/non_zero.cc
@@ -48,8 +48,8 @@ abstract::ShapePtr NonZeroInferShape(const PrimitivePtr &primitive, const std::v
 }
 
 TypePtr NonZeroInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
-                                         kUInt32, kUInt64, kFloat16, kFloat, kFloat64, kComplex64, kComplex128};
+  const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8,
+                                         kUInt16, kUInt32, kUInt64, kFloat16, kFloat, kFloat64};
   auto x_type = input_args[0]->BuildType();
   (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
   return std::make_shared<TensorType>(kInt64);
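Note (not part of the patch): the docstrings added in this change describe the result as a 2-D int64 tensor with one row of coordinates per non-zero element. As a reference point, that layout matches NumPy's argwhere, which can be used to sanity-check the example input used throughout the patch. A minimal NumPy sketch:

import numpy as np

# Reference-only check of the documented layout: one row of coordinates per
# non-zero element, which is exactly what np.argwhere produces.
x = np.array([[[1, 0], [-5, 0]]], dtype=np.int32)
print(np.argwhere(x))
# [[0 0 0]
#  [0 1 0]]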
diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py
index b67911fd806..9249acbc3ea 100644
--- a/mindspore/python/mindspore/_extends/parse/standard_method.py
+++ b/mindspore/python/mindspore/_extends/parse/standard_method.py
@@ -1578,6 +1578,13 @@ def tensor_sactter_div(input_x, indices, updates):
     return F.tensor_scatter_div(input_x, indices, updates)
 
 
+def nonzero(x):
+    """
+    Return a tensor of the positions of all non-zero values.
+    """
+    return F.nonzero(x)
+
+
 def coo_to_csr(x):
     """convert coo to csr."""
     row_indices = x.indices[:, 0]
diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py
index 11f9af5a25f..f2d0714b18c 100644
--- a/mindspore/python/mindspore/common/tensor.py
+++ b/mindspore/python/mindspore/common/tensor.py
@@ -2565,6 +2565,29 @@ class Tensor(Tensor_):
             repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis))
         return tensor_operator_registry.get('concatenate')(axis)(repeated_subs)
 
+    def nonzero(self):
+        """
+        Return a tensor of the positions of all non-zero values.
+
+        Returns:
+            Tensor, a 2-D tensor containing the positions of all non-zero values of the input.
+
+        Supported Platforms:
+            ``GPU``
+
+        Examples:
+            >>> import mindspore
+            >>> import numpy as np
+            >>> from mindspore import Tensor
+            >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
+            >>> output = x.nonzero()
+            >>> print(output)
+            [[0 0 0]
+             [0 1 0]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('nonzero')(self)
+
 
 class RowTensor(RowTensor_):
     """
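Note (not part of the patch): because the documented return value is an (N, rank) int64 index matrix, it can be fed straight into gather_nd to collect the non-zero values themselves. The sketch below is illustrative only and assumes a backend where NonZero is available (GPU, per the Supported Platforms note above); if the GatherNd kernel on a given backend does not accept int64 indices, cast them first.

import numpy as np
import mindspore
from mindspore import Tensor, ops

# Illustrative usage sketch: gather the non-zero values via the index matrix.
x = Tensor(np.array([[1, 0], [-5, 2]]), mindspore.int32)
indices = x.nonzero()               # expected shape (3, 2), dtype int64
values = ops.gather_nd(x, indices)  # expected: [ 1 -5  2]
print(indices)
print(values)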
diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py
index 513159ccda4..5793e5554a1 100644
--- a/mindspore/python/mindspore/ops/function/__init__.py
+++ b/mindspore/python/mindspore/ops/function/__init__.py
@@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
 from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_,
                          dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
                          tuple_to_array, expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast,
-                         masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min,
+                         masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min, nonzero,
                          space_to_batch_nd)
 from .parameter_func import assign, assign_add, assign_sub, index_add
 from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le, lerp,
diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py
index c7f9e37c7af..8c7f1df45bb 100644
--- a/mindspore/python/mindspore/ops/function/array_func.py
+++ b/mindspore/python/mindspore/ops/function/array_func.py
@@ -18,6 +18,37 @@
 from mindspore.ops.primitive import constexpr
 from mindspore.ops import operations as P
 import mindspore.common.dtype as mstype
+from ..operations.array_ops import NonZero
+
+
+eye_ = P.Eye()
+fill_ = P.Fill()
+ones_ = P.Ones()
+ones_like_ = P.OnesLike()
+tile_ = P.Tile()
+size_ = P.Size()
+shape_ = P.Shape()
+rank_ = P.Rank()
+tensor_shape = P.TensorShape()
+reshape_ = P.Reshape()
+tensor_slice = P.Slice()
+expand_dims_ = P.ExpandDims()
+transpose_ = P.Transpose()
+scatter_max_ = P.ScatterMax()
+scatter_min_ = P.ScatterMin()
+scatter_nd_ = P.ScatterNd()
+gather_ = P.Gather()
+gather_d_ = P.GatherD()
+gather_nd_ = P.GatherNd()
+tensor_scatter_add_ = P.TensorScatterAdd()
+nonzero_ = NonZero()
+scalar_cast_ = P.ScalarCast()
+tensor_scatter_div_ = P.TensorScatterDiv()
+scalar_to_array_ = P.ScalarToArray()
+scalar_to_tensor_ = P.ScalarToTensor()
+tuple_to_array_ = P.TupleToArray()
+masked_fill_ = P.MaskedFill()
+matrix_band_part_ = P.array_ops.MatrixBandPart()
 
 
 @constexpr
@@ -33,7 +64,6 @@ def get_x_shape(x_shape):
     return [-1]
 
 ##############################
-eye_ = P.Eye()
 def eye(n, m, t):
     """
     Creates a tensor with ones on the diagonal and zeros in the rest.
@@ -75,7 +105,6 @@ def eye(n, m, t):
     return eye_(n, m, t)
 
 
-matrix_band_part_ = P.array_ops.MatrixBandPart()
 def matrix_band_part(x, lower, upper):
     r"""
     Copy a tensor setting everything outside a central band in each innermost matrix to zero.
@@ -118,7 +147,6 @@ def matrix_band_part(x, lower, upper):
     return matrix_band_part_(x, lower, upper)
 
 
-fill_ = P.Fill()
 def fill(type, shape, value):
     """
     Create a Tensor of the specified shape and fill it with the specified value.
@@ -153,7 +181,6 @@ def fill(type, shape, value):
     return fill_(type, shape, value)
 
 
-ones_ = P.Ones()
 def ones(shape, type):
     r"""
     Creates a tensor filled with value ones.
@@ -188,7 +215,6 @@ def ones(shape, type):
     return ones_(shape, type)
 
 
-ones_like_ = P.OnesLike()
 def ones_like(input_x):
     """
     Returns a Tensor with a value of 1 and its shape and data type is the same as the input.
@@ -215,7 +241,6 @@ def ones_like(input_x):
     return ones_like_(input_x)
 
 
-tile_ = P.Tile()
 def tile(input_x, multiples):
     r"""
     Replicates an input tensor with given multiples times.
@@ -346,7 +371,6 @@ def unique(x):
     return y, idx
 
 
-size_ = P.Size()
 def size(input_x):
     r"""
     Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
@@ -375,7 +399,6 @@ def size(input_x):
     return size_(input_x)
 
 
-shape_ = P.Shape()
 def shape(input_x):
     """
     Returns the shape of the input tensor. And it used to be static shape.
@@ -406,7 +429,6 @@ def shape(input_x):
     return shape_(input_x)
 
 
-tensor_shape = P.TensorShape()
 def dyn_shape(input_x):
     """
     Returns the shape of the input tensor.
@@ -432,7 +454,6 @@ def dyn_shape(input_x):
     return tensor_shape(input_x)
 
 
-rank_ = P.Rank()
 def rank(input_x):
     """
     Returns the rank of a tensor.
@@ -463,7 +484,6 @@ def rank(input_x):
     return rank_(input_x)
 
 
-reshape_ = P.Reshape()
 def reshape(input_x, input_shape):
     """
     Rearranges the input Tensor based on the given shape.
@@ -498,7 +518,6 @@ def reshape(input_x, input_shape):
     return reshape_(input_x, input_shape)
 
 
-tensor_slice = P.Slice()
 def slice(input_x, begin, size):
     """
     Slices a tensor in the specified shape.
@@ -550,7 +569,6 @@ def slice(input_x, begin, size):
     return tensor_slice(input_x, begin, size)
 
 
-expand_dims_ = P.ExpandDims()
 def expand_dims(input_x, axis):
     """
     Adds an additional dimension to `input_x` at the given axis.
@@ -585,7 +603,6 @@ def expand_dims(input_x, axis):
     return expand_dims_(input_x, axis)
 
 
-transpose_ = P.Transpose()
 def transpose(input_x, input_perm):
     """
     Permutes the dimensions of the input tensor according to input permutation.
@@ -630,7 +647,6 @@ def transpose(input_x, input_perm):
     return transpose_(input_x, input_perm)
 
 
-scatter_max_ = P.ScatterMax()
 def scatter_max(input_x, indices, updates):
     r"""
     Using given values to update tensor value through the max operation, along with the input indices.
@@ -669,7 +685,6 @@ def scatter_max(input_x, indices, updates):
     return scatter_max_(input_x, indices, updates)
 
 
-scatter_min_ = P.ScatterMin()
 def scatter_min(input_x, indices, updates):
     """
     Updates the value of the input tensor through the minimum operation.
@@ -725,7 +740,6 @@ def scatter_min(input_x, indices, updates):
     return scatter_min_(input_x, indices, updates)
 
 
-scatter_nd_ = P.ScatterNd()
 def scatter_nd(indices, updates, shape):
     r"""
     Scatters a tensor into a new tensor depending on the specified indices.
@@ -827,7 +841,6 @@ def scatter_nd(indices, updates, shape):
     return scatter_nd_(indices, updates, shape)
 
 
-gather_ = P.Gather()
 def gather(input_params, input_indices, axis):
     r"""
     Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
@@ -903,7 +916,6 @@ def gather(input_params, input_indices, axis):
     return gather_(input_params, input_indices, axis)
 
 
-gather_d_ = P.GatherD()
 def gather_d(x, dim, index):
     """
     Gathers values along an axis specified by dim.
@@ -951,7 +963,6 @@ def gather_d(x, dim, index):
     return gather_d_(x, dim, index)
 
 
-gather_nd_ = P.GatherNd()
 def gather_nd(input_x, indices):
     r"""
     Gathers slices from a tensor by indices.
@@ -991,7 +1002,6 @@ def gather_nd(input_x, indices):
     return gather_nd_(input_x, indices)
 
 
-tensor_scatter_add_ = P.TensorScatterAdd()
 def tensor_scatter_add(input_x, indices, updates):
     """
     Creates a new tensor by adding the values from the positions in `input_x` indicated by
@@ -1099,12 +1109,44 @@ def space_to_batch_nd(input_x, block_size, paddings):
     """
     return P.SpaceToBatchND(block_size, paddings)(input_x)
+
+def nonzero(x):
+    """
+    Return a tensor of the positions of all non-zero values.
+
+    Args:
+        x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number or Bool.
+
+    Returns:
+        y (Tensor): A 2-D Tensor. The data type is int64.
+
+    Raises:
+        TypeError: If `x` is not Tensor.
+        ValueError: If dim of `x` equals 0.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.ops as ops
+        >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
+        >>> # nonzero is a functional interface and is called directly on x
+        >>> output = ops.nonzero(x)
+        >>> print(output)
+        [[0 0 0]
+         [0 1 0]]
+    """
+    return nonzero_(x)
+
+
 
 
 ##############################
 # Type Conversion Functions.
 ##############################
 
-scalar_cast_ = P.ScalarCast()
 def scalar_cast(input_x, input_y):
     """
     Casts the input scalar to another type.
@@ -1129,7 +1171,6 @@ def scalar_cast(input_x, input_y):
     """
     return scalar_cast_(input_x, input_y)
 
-tensor_scatter_div_ = P.TensorScatterDiv()
 def tensor_scatter_div(input_x, indices, updates):
     """
     Creates a new tensor by dividing the values from the positions in `input_x` indicated by
@@ -1179,7 +1220,6 @@ def tensor_scatter_div(input_x, indices, updates):
     return tensor_scatter_div_(input_x, indices, updates)
 
 
-scalar_to_array_ = P.ScalarToArray()
 def scalar_to_array(input_x):
     """
     Converts a scalar to a `Tensor`.
@@ -1209,7 +1249,6 @@ def scalar_to_array(input_x):
     return scalar_to_array_(input_x)
 
 
-scalar_to_tensor_ = P.ScalarToTensor()
 def scalar_to_tensor(input_x, dtype=mstype.float32):
     """
     Converts a scalar to a `Tensor`, and converts the data type to the specified type.
@@ -1237,7 +1276,6 @@ def scalar_to_tensor(input_x, dtype=mstype.float32):
     return scalar_to_tensor_(input_x, dtype)
 
 
-tuple_to_array_ = P.TupleToArray()
 def tuple_to_array(input_x):
     """
     Converts a tuple to a tensor.
@@ -1272,7 +1310,6 @@ def tuple_to_array(input_x):
     return tuple_to_array_(input_x)
 
 
-masked_fill_ = P.MaskedFill()
 def masked_fill(x, mask, value):
     """
     Fills elements of self tensor with value where mask is True.
@@ -1308,6 +1345,7 @@ def masked_fill(x, mask, value):
     """
     return masked_fill_(x, mask, value)
 
+
 __all__ = [
     'unique',
     'eye',
@@ -1341,6 +1379,7 @@ __all__ = [
     'masked_fill',
     'tensor_scatter_div',
     'scatter_max',
-    'scatter_min'
+    'scatter_min',
+    'nonzero',
 ]
 __all__.sort()
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 4241e6d7221..12aee45e1d5 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -984,6 +984,7 @@ tensor_operator_registry.register('index_add', P.IndexAdd)
 tensor_operator_registry.register('scatter_max', P.ScatterMax)
 tensor_operator_registry.register('scatter_min', P.ScatterMin)
 tensor_operator_registry.register('space_to_batch_nd', P.SpaceToBatchND)
+tensor_operator_registry.register('nonzero', nonzero)
 # ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
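Note (not part of the patch): the registry entry above, together with the standard_method.py wrapper and the resource.cc method-map entry earlier in this change, is what makes the Tensor-method form resolvable in both PyNative and graph mode. A minimal graph-mode sketch, illustrative only and assuming a GPU target:

import numpy as np
import mindspore
from mindspore import Tensor, nn

# Inside a compiled Cell, x.nonzero() should resolve through the method map
# added in resource.cc and the standard_method.nonzero wrapper (-> F.nonzero).
class Net(nn.Cell):
    def construct(self, x):
        return x.nonzero()

x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
print(Net()(x))
# expected:
# [[0 0 0]
#  [0 1 0]]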
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index f6fa430707b..f6a2a453215 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -7650,15 +7650,7 @@ class NonZero(Primitive):
     """
     Return a tensor of the positions of all non-zero values.
 
-    Inputs:
-        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number or Bool.
-
-    Outputs:
-        - **y** (Tensor), The shape of tensor is 2-D. The data type is int64.
-
-    Raises:
-        TypeError: If `x` is not Tensor.
-        ValueError: If 'x' dim equal to 0.
+    Refer to :func:`mindspore.ops.nonzero` for more detail.
 
     Supported Platforms:
         ``GPU``
@@ -7667,9 +7659,9 @@
         >>> import mindspore
         >>> import numpy as np
         >>> from mindspore import Tensor
-        >>> import mindspore.ops as ops
+        >>> from mindspore.ops.operations.array_ops import NonZero
         >>> x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
-        >>> nonzero = ops.NonZero()
+        >>> nonzero = NonZero()
         >>> output = nonzero(x)
         >>> print(output)
         [[0 0 0]
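Note (not part of the patch): after this change the primitive, functional, and Tensor-method entry points are all expected to return the same (N, rank) int64 index matrix. A quick illustrative comparison, assuming a GPU context:

import numpy as np
import mindspore
from mindspore import Tensor, ops
from mindspore.ops.operations.array_ops import NonZero

x = Tensor(np.array([[[1, 0], [-5, 0]]]), mindspore.int32)
print(NonZero()(x))    # primitive form (array_ops.NonZero)
print(ops.nonzero(x))  # functional form added in array_func.py
print(x.nonzero())     # Tensor method added in common/tensor.py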