diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 22e786b22f0..f4e07fae103 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -19,7 +19,8 @@ import inspect import math from enum import Enum from functools import reduce, wraps -from itertools import repeat +from itertools import repeat, zip_longest +from collections import deque from collections.abc import Iterable import numpy as np from mindspore import log as logger @@ -683,13 +684,13 @@ class Validator: def check_swapaxes_axis(axes, ndim): """Check all the axes argument for tensor.swapaxes""" if isinstance(axes, int): - check_axis_in_range(axes, ndim) + Validator.check_axis_in_range(axes, ndim) return axes % ndim if isinstance(axes, (tuple, list)): for axis in axes: if not isinstance(axis, int): raise TypeError(f"axis argument should be integer, but got {type(axis)}.") - check_axis_in_range(axis, ndim) + Validator.check_axis_in_range(axis, ndim) axes = tuple(map(lambda x: x % ndim, axes)) return axes raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.") @@ -742,6 +743,95 @@ class Validator: raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') return axis % ndim + @staticmethod + def check_axis_valid(axes, ndim): + """ + Checks axes are valid given ndim, and returns axes that can be passed + to the built-in operator (non-negative, int or tuple) + """ + if axes is None: + axes = tuple(range(ndim)) + return axes + if isinstance(axes, (tuple, list)): + for axis in axes: + Validator.check_axis_in_range(axis, ndim) + axes = tuple(map(lambda x: x % ndim, axes)) + if any(axes.count(el) > 1 for el in axes): + raise ValueError('duplicate value in "axis"') + return axes + Validator.check_axis_in_range(axes, ndim) + return (axes % ndim,) + + @staticmethod + def max_(*args): + return max(*args) + + @staticmethod + def min_(*args): + return min(*args) + + @staticmethod + def expanded_shape(ndim, axis_size, axis): + """ + Returns a shape with size = 1 for all dimensions + except at axis. + """ + return tuple(axis_size if i == axis else 1 for i in range(ndim)) + + @staticmethod + def tuple_slice(tup, start, end): + """get sliced tuple from start and end.""" + return tup[start:end] + + @staticmethod + def infer_out_shape(*shapes): + """ + Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. + """ + shape_out = deque() + reversed_shapes = map(reversed, shapes) + for items in zip_longest(*reversed_shapes, fillvalue=1): + max_size = 0 if 0 in items else max(items) + if any(item not in (1, max_size) for item in items): + raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') + shape_out.appendleft(max_size) + return tuple(shape_out) + + @staticmethod + def get_log2_size(size): + return math.ceil(math.log2(size)) + + @staticmethod + def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): + """Check axis argument type.""" + if type_int and isinstance(axis, int): + return True + if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)): + for ax in axis: + if not isinstance(ax, int): + raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.") + return True + + type_str = "" + if type_int: type_str += "int, " + if type_tuple: type_str += "tuple, " + if type_list: type_str += "list, " + raise TypeError(f"Axis should be {type_str}but got {type(axis)}.") + + @staticmethod + def check_and_canonicalize_axes(axes, ndim): + """Check whether the types and values of input axes are valid.""" + axes = axes if isinstance(axes, tuple) else (axes,) + new_axes = () + for ax in axes: + if not isinstance(ax, int): + raise TypeError((f"Each axis should be integer, but got {type(ax)} in {axes}.")) + if not -ndim <= ax < ndim: + raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}') + ax = ax if ax >= 0 else ax + ndim + new_axes += (ax,) + return new_axes + def check_input_format(input_param): """Judge input format.""" @@ -770,13 +860,6 @@ def _expand_tuple(n_dimensions): return convert -def check_axis_in_range(axis, ndim): - """Checks axes are with the bounds of ndim""" - if -ndim <= axis < ndim: - return True - raise ValueError(f'axis {axis} is out of bounds for tensor of dimension {ndim}') - - def _check_data_type_valid(data, valid_type): """Check data type valid.""" if valid_type is None: diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index b2c2f622a5f..036f3724cfd 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -25,10 +25,13 @@ from ..._checkparam import Validator as validator from ...ops import functional as F from ...ops import operations as P from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \ - zeros_like, ones_like + zeros_like, ones_like, repeat_elements from ...ops.composite.base import _append +from ...ops.composite.multitype_ops import _constexpr_utils as const_utils +from ...ops.composite.multitype_ops import _compile_utils as compile_utils from ...ops.primitive import constexpr + __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] shape_ = P.Shape() @@ -36,12 +39,17 @@ dtype_ = P.DType() abs_ = P.Abs() ndim_ = P.Rank() size_ = P.Size() +cumsum_ = P.CumSum() +_reduce_sum_default = P.ReduceSum() +_reduce_sum_keepdims = P.ReduceSum(True) +_mean_keepdims = P.ReduceMean(True) itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1, mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2, mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} +nan_tensor = Tensor(float('nan'), dtype=mstype.float32) def mean(x, axis=(), keep_dims=False): """ @@ -147,7 +155,7 @@ def strides_(x): return strides -def astype(x, dtype, copy=True): +def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name """ Return a copy of the tensor, casted to a specified type. @@ -354,10 +362,10 @@ def swapaxes(x, axis1, axis2): new_perm = None if axis2 + 1 < x.ndim: new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ - perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] + perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] else: new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ - perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] return F.transpose(x, new_perm) @@ -467,14 +475,923 @@ def argmin(x, axis=None): return P.Argmax(axis)(F.neg_tensor(x)) -def getitem(data, item): +def cumsum(x, axis=None, dtype=None): + """ + Returns the cumulative sum of the elements along a given axis. + + Note: + If ``x.dtype`` is :class:`int8`, :class:`int16` or :class:`bool`, the result + `dtype` will be elevated to :class:`int32`. + + Args: + x (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as original, + tensor, unless it has an integer dtype with a precision less than :class:`float32`. + In that case, :class:`float32` is used. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.ones((3,3)).astype("float32")) + >>> output = a.cumsum(0) + >>> print(output) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 3.]] + """ + original_dtype = x.dtype + # If original tensor is int, and has precision less then int32, convert + # to int32 + if x.dtype in (mstype.int8, mstype.int16, mstype.uint8, mstype.int16): + x = x.astype(mstype.int32) + if axis is None: + x = x.ravel() + axis = 0 + check_axis_in_range_const(axis, x.ndim) + if dtype is not None and original_dtype != dtype: + return cumsum_(x, axis).astype(dtype, copy=False) + return cumsum_(x, axis) + + +def copy(x): + """ + Returns a copy of the tensor. + + Note: + The current implementation does not support `order` argument. + + Args: + x (Tensor): Input tensor. + + Returns: + Copied tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.ones((3,3)).astype("float32")) + >>> output = a.copy() + >>> print(output) + [[1. 1. 1.] + [1. 1. 1.] + [1. 1. 1.]] + """ + if x.size == 0: + return x + origin_dtype = x.dtype + if origin_dtype == mstype.bool_: + return F.logical_not(F.logical_not(x)) + if origin_dtype != mstype.float64: + x = x.astype(mstype.float32) + x = x / 1.0 + x = x.astype(origin_dtype) + return x + + +def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin + """ + Returns the maximum of a tensor or maximum along an axis. + + Args: + x (Tensor): Input Tensor. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the maximum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The minimum value of an output element. Must be present to allow + computation on empty slice. + where (boolean Tensor, optional): defaults to True. + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. If non-default value + is passed, initial must also be provided. + + Returns: + Tensor or scalar, maximum of input tensor. If `axis` is None, the result is a scalar + value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.numpy as np + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> output = a.max() + >>> print(output) + 3.0 + """ + return compile_utils.reduce_(x, P.ReduceMax(keepdims), cmp_fn=F.maximum, + axis=axis, keepdims=keepdims, initial=initial, where=where) + + +def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin + """ + Returns the minimum of a tensor or minimum along an axis. + + Args: + a (Tensor): Input data. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the minimum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The maximum value of an output element. Must be present to allow + computation on empty slice. + where (boolean Tensor, optional): defaults to True. + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. If non-default value + is passed, initial must also be provided. + + Returns: + Tensor or scalar, minimum of `a`. If axis is None, the result is a scalar + value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.numpy as np + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> output = a.min() + >>> print(output) + 0.0 + """ + return compile_utils.reduce_(x, P.ReduceMin(keepdims), cmp_fn=F.minimum, + axis=axis, keepdims=keepdims, initial=initial, where=where) + + +def resize(x, *new_shape): + """ + Changes shape and size of array in-place. + + Note: + Instead of changing the size of the input array and returns nothing as in numpy, + this method returns a new Tensor with the input size. + Numpy argument `refcheck` is not supported. + + Args: + new_shape (Union[ints, tuple of ints]): Shape of resized array. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import numpy as np + >>> x = np.array([[0, 1], [2, 3]]) + >>> x = x.resize(2, 3) + >>> print(x) + [[0 1 2] + [3 0 0]] + """ + if not new_shape: + return x + if len(new_shape) == 1: + if isinstance(new_shape[0], tuple): + new_shape = new_shape[0] + flattened = x.ravel() + cur_size = F.shape_mul(x.shape) + new_size = F.shape_mul(new_shape) + diff_size = new_size - cur_size + if diff_size > 0: + pad_val = F.fill(x.dtype, (diff_size,), 0) + res = P.Concat()((flattened, pad_val)) + else: + res = flattened[:new_size] + return res.reshape(new_shape) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + """ + Returns specified diagonals. + + Args: + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int, optional): Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis. + + Returns: + Tensor, if `a` is 2-D, then `a` 1-D array containing the diagonal. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.arange(4).reshape(2,2) + >>> print(a) + [[0 1] + [2 3]] + >>> output = a.diagonal() + >>> print(output) + [0 3] + """ + ndim = x.ndim + if ndim < 2: + const_utils.raise_value_error('diagonal requires an array of at least two dimensions') + dtype = x.dtype + + axes = check_axis_valid((axis1, axis2), ndim) + perm = () + for i in range(ndim): + if i not in axes: + perm += (i,) + perm += axes + x = x.transpose(perm) + + shape = x.shape + n, m = shape[-2:] + + e = F.eye(n, m, dtype) + if offset >= m or offset <= -n: + e = F.fill(dtype, (n, m), 0) + elif offset != 0: + e = e.astype(mstype.float32) + if offset > 0: + e_left = F.fill(dtype, (n, offset), 0) + e_right = e[..., 0:m-offset:1] + e = P.Concat(1)((e_left, e_right)).astype(dtype) + elif offset < 0: + e_upper = F.fill(dtype, (-offset, m), 0) + e_lower = e[0:n+offset:1, ...] + e = P.Concat(0)((e_upper, e_lower)).astype(dtype) + e = P.BroadcastTo(shape)(e) + + prod = F.tensor_mul(x, e) + res = F.reduce_sum(prod.astype(mstype.float32), -1) + + begin = () + for i in range(ndim-2): + begin += (0,) + last_dim_begin = max_(0, -offset) + begin += (last_dim_begin,) + size = res.shape[:-1] + last_dim_end = min_( + shape[-2], max_(0, shape[-1] - offset)) - last_dim_begin + if last_dim_end <= 0: + return Tensor([]) + size += (last_dim_end,) + res = F.tensor_slice(res, begin, size) + return res.astype(dtype) + + +def trace(x, offset=0, axis1=0, axis2=1, dtype=None): + """ + Returns the sum along diagonals of the array. + + Args: + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int, optional): Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, sum_along_diagonals. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.eye(3) + >>> print(x.trace()) + 3.0 + """ + d = x.diagonal(offset, axis1=axis1, axis2=axis2) + shape = d.shape + if dtype is None: + dtype = d.dtype + if shape[-1] == 0: + return F.fill(dtype, shape[:-1], 0) + res = F.reduce_sum(d.astype(mstype.float32), -1) + return res.astype(dtype) + + +def take(x, indices, axis=None, mode='clip'): + """ + Takes elements from an array along an axis. + + Args: + a (Tensor): Source array with shape `(Ni…, M, Nk…)`. + indices (Tensor): The indices with shape `(Nj...)` of the values to extract. + axis (int, optional): The axis over which to select values. By default, + the flattened input array is used. + mode (‘raise’, ‘wrap’, ‘clip’, optional): + - edge: Pads with the edge values of `arr`. + - raise: Raises an error; + - wrap: Wraps around; + - clip: Clips to the range. `clip` mode means that all indices that are + too large are replaced by the index that addresses the last element + along that axis. Note that this disables indexing with negative numbers. + + Returns: + Tensor, the indexed result. + + Raises: + ValueError: if axis is out of range. + TypeError: if the input is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> output = a.take(indices) + >>> print(output) + [4 3 6] + """ + if mode not in ('raise', 'wrap', 'clip'): + const_utils.raise_value_error('raise should be one of "raise", "wrap", or "clip"') + if axis is None: + a = x.ravel() + axis = 0 + else: + a = x + ndim = a.ndim + axis = check_axis_in_range_const(axis, ndim) + + shape_a = a.shape + shape_indices = indices.shape + size_indices = indices.size + indices = compile_utils.check_indices(shape_a[axis], indices, mode) + + # reshapes indices to shape (Ni..., Nj..., Nk) + shape_ni = tuple_slice(shape_a, None, axis) + shape_nk = tuple_slice(shape_a, axis + 1, None) + shape_out = shape_ni + shape_indices + shape_nk + shape_indices = expanded_shape(ndim, size_indices, axis) + indices = indices.reshape(shape_indices) + shape_indices = shape_ni + (indices.size,) + shape_nk + indices = P.BroadcastTo(shape_indices)(indices) + + res = F.gather_d(a, axis, indices) + return res.reshape(shape_out) + + +def choose(x, choices, mode='clip'): + """ + Construct an array from an index array and a list of arrays to choose from. + + Args: + choices (sequence of arrays): Choice arrays. `a` and all of the `choices` must + be broadcastable to the same shape. If `choices` is itself an array, then + its outermost dimension (i.e., the one corresponding to ``choices.shape[0]``) + is taken as defining the “sequence”. + mode (‘raise’, ‘wrap’, ‘clip’, optional): Specifies how indices outside + ``[0, n-1]`` will be treated: + + ‘raise’ – raise an error (default); + + ‘wrap’ – wrap around; + + ‘clip’ – clip to the range. ‘clip’ mode means that all indices that are + too large are replaced by the index that addresses the last element + along that axis. Note that this disables indexing with negative numbers. + + Returns: + Tensor, the merged result. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + ValueError: if ``len(condlist) != len(choicelist)``. + + Examples: + >>> import mindspore.numpy as np + >>> choices = [[0, 1, 2, 3], [10, 11, 12, 13], + [20, 21, 22, 23], [30, 31, 32, 33]] + >>> x = np.array([2, 3, 1, 0]) + >>> print(x.choose(choices)) + [20 31 12 3] + """ + if check_is_tensor(F.typeof(choices)): + shape_choice = infer_out_shape(x.shape, choices.shape[1:]) + choices = P.BroadcastTo((choices.shape[0],) + shape_choice)(choices) + else: + # broadcasts choices to the same shape if choices is a sequence + choicelist = [] + shapes = () + for choice in choices: + if not check_is_tensor(F.typeof(choice)): + choice = const_utils.make_tensor(choice) + shapes += (choice.shape,) + choicelist.append(choice) + shape_choice = infer_out_shape(x.shape, *shapes) + tmp = [] + for choice in choicelist: + tmp.append(P.BroadcastTo(shape_choice)(choice)) + choices = F.stack(tmp) + + if x.ndim == 0 or choices.ndim == 0: + const_utils.raise_value_error('input cannot be scalars') + a = P.BroadcastTo(shape_choice)(x) + dtype = choices.dtype + # adjusts dtype for F.tensor_mul and F.gather_nd + a = a.astype(mstype.int32) + choices = choices.astype(mstype.int32) + a = compile_utils.check_indices(choices.shape[0], a, mode, allow_negative_index=False) + + grids = [] + ndim = len(a.shape) + for i in range(ndim): + dim_grid = const_utils.make_tensor(F.make_range(a.shape[i]), mstype.int32) + dim_shape = expanded_shape(ndim, a.shape[i], i) + dim_grid = P.BroadcastTo(a.shape)(dim_grid.reshape(dim_shape)) + grids.append(dim_grid) + grid = P.Stack(-1)(grids) + indices = P.Concat(-1)((a.reshape(a.shape + (1,)), grid)) + return F.gather_nd(choices, indices).astype(dtype) + + +def searchsorted(x, v, side='left', sorter=None): + """ + Finds indices where elements should be inserted to maintain order. + + Args: + v (Union[int, float, bool, list, tuple, Tensor]): Values to insert into `a`. + side ('left', 'right', optional): If ‘left’, the index of the first suitable + location found is given. If ‘right’, return the last such index. If there is + no suitable index, return either 0 or N (where N is the length of `a`). + sorter (Union[int, float, bool, list, tuple, Tensor]): 1-D optional array of + integer indices that sort array `a` into ascending order. They are typically + the result of argsort. + + Returns: + Tensor, array of insertion points with the same shape as `v`. + + Raises: + ValueError: if argument for `side` or `sorter` is invalid. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import numpy as np + >>> x = np.array([1,2,3,4,5]) + >>> print(x.searchsorted(3)) + 2 + """ + if side not in ('left', 'right'): + const_utils.raise_value_error('invalid value for keyword "side"') + a = x.astype(mstype.float32) + if not check_is_tensor(F.typeof(v)): + v = const_utils.make_tensor(v) + shape = v.shape + if sorter is not None: + if sorter.ndim != 1 or sorter.size != a.size: + const_utils.raise_value_error('sorter must be 1-D array with the same size as `a`') + sorter = const_utils.make_tensor(sorter) + sorter = sorter.reshape(sorter.shape + (1,)) + a = F.gather_nd(a, sorter) + less_op = F.tensor_le if side == 'left' else F.tensor_lt + i = F.fill(mstype.int32, shape, 0) + j = F.fill(mstype.int32, shape, a.size) + + sort_range = F.make_range(get_log2_size(F.shape_mul(shape) + 1)) + for _ in sort_range: + mid = (i - F.neg_tensor(j))//2 + mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,)))) + i = F.select(mask, i, mid) + j = F.select(mask, mid, j) + return j + + +def fill(x, value): + """ + Fills the array with a scalar value. + + Note: + Unlike Numpy, tensor.fill() will always returns a new tensor, instead of + filling the original tensor. + + Args: + value (Union[None, int, float, bool]): All elements of a will be assigned this value. + + Returns: + Tensor, with the original dtype and shape as input tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If `shape` has entries < 0. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> print(a.fill(1.0)) + [[1. 1.] + [1. 1.]] + """ + if value is None: + if x.dtype not in (mstype.float16, mstype.float32, mstype.float64): + const_utils.raise_type_error("If None is used as value, the original Tensor's dtype must be float.") + value = nan_tensor + return F.tile(value, x.shape).astype(x.dtype) + if not isinstance(value, (int, float, bool)): + const_utils.raise_type_error("input value must be a scalar.") + return F.fill(x.dtype, x.shape, value) + + +def ptp(x, axis=None, keepdims=False): + """ + The name of the function comes from the acronym for ‘peak to peak’. + + Note: + Numpy arguments `dtype` and `out` are not supported. + + Args: + x (Tensor): Input tensor. + axis (Union[None, int, tuple(int)]): Axis or axes along which the range is computed. + The default is to compute the variance of the flattened array. Default: None. + keepdims (bool): Default is False. + + Returns: + Tensor. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import Tensor + >>> x = Tensor([[4.0, 9.0, 2.0, 10.0], [6.0, 9.0, 7.0, 12.0]]).astype("float32") + >>> print(x.ptp(axis=1)) + [8. 6.] + >>> print(x.ptp(axis=0)) + [2. 0. 5. 2.] + """ + if not isinstance(keepdims, bool): + const_utils.raise_type_error('keepdims should be boolean') + if axis is None: + axis = () + else: + check_axis_type(axis, True, True, False) + axis = check_axis_valid(axis, x.ndim) + + return x.max(axis, keepdims) - x.min(axis, keepdims) + + +def clip(x, xmin, xmax, dtype=None): + """ + Clips (limits) the values in an array. + + Given an interval, values outside the interval are clipped to the interval edges. + For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0, + and values larger than 1 become 1. + + Note: + Currently, clip with `nan` is not supported. + + Args: + x (Tensor): Tensor containing elements to clip. + xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed + on lower interval edge. Not more than one of `xmin` and `xmax` may be None. + xmax (Tensor, scalar, None): Maximum value. If None, clipping is not performed + on upper interval edge. Not more than one of `xmin` and `xmax` may be None. + If `xmin` or `xmax` are tensors, then the three tensors will be broadcasted + to match their shapes. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, a tensor with the elements of `x`, but where values + < `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import Tensor + >>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32") + >>> output = x.clip(x, 0, 2) + >>> print(output) + [1 2 2 0 0 2 2 0] + """ + if xmin is None and xmax is None: + const_utils.raise_value_error("One of max or min must be given.") + is_scalar = False + if xmin is not None: + xmin = const_utils.make_tensor(xmin).astype(x.dtype) + if x.ndim == 0 and xmin.ndim == 0: + x = F.maximum(x.reshape((1,)), xmin).squeeze() + else: + x = F.maximum(x, xmin) + if xmax is not None: + xmax = const_utils.make_tensor(xmax).astype(x.dtype) + if x.ndim == 0 and xmax.ndim == 0: + x = F.minimum(x.reshape((1,)), xmax).squeeze() + else: + x = F.minimum(x, xmax) + if is_scalar: + return x.squeeze() + if dtype is not None and dtype != x.dtype: + return x.astype(dtype) + return x + + +def var(x, axis=None, ddof=0, keepdims=False): + """ + Compute the variance along the specified axis. + The variance is the average of the squared deviations from the mean, i.e., + :math:`var = mean(abs(x - x.mean())**2)`. + + Return the variance, which is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `dtype`, `out` and `where` are not supported. + + Args: + x (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the variance is computed. + The default is to compute the variance of the flattened array. Default: `None`. + ddof (int): Means Delta Degrees of Freedom. Default: 0. + The divisor used in calculations is :math:`N - ddof`, where :math:`N` represents the number of elements. + keepdims (bool): Default: `False`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Returns: + Standard deviation tensor. + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([1., 2., 3., 4.]) + >>> print(input_x.var()) + 1.25 + """ + if 0 in x.shape: + return nan_tensor.astype(x.dtype) + if not isinstance(ddof, int) or not isinstance(keepdims, int): + const_utils.raise_type_error("integer argument expected") + + if axis is None: + axis = () + else: + axis = check_and_canonicalize_axes(axis, x.ndim) + x_mean = _mean_keepdims(x, axis) + x_sub = F.tensor_sub(x, x_mean) + x_pow = F.tensor_pow(x_sub, 2) + if keepdims: + x_sum = _reduce_sum_keepdims(x_pow, axis) + else: + x_sum = _reduce_sum_default(x_pow, axis) + + if axis == (): + axis = F.make_range(x.ndim) + nums = 1 + for ax in axis: + nums *= x.shape[ax] + return F.tensor_div(x_sum, nums - ddof) + + +def std(x, axis=None, ddof=0, keepdims=False): + """ + Compute the standard deviation along the specified axis. + The standard deviation is the square root of the average of the squared deviations + from the mean, i.e., :math:`std = sqrt(mean(abs(x - x.mean())**2))`. + + Return the standard deviation, which is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `dtype`, `out` and `where` are not supported. + + Args: + x (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the standard + deviation is computed. Default: `None`. + + If `None`, compute the standard deviation of the flattened array. + ddof (int): Means Delta Degrees of Freedom. The divisor used in calculations is :math:`N - ddof`, + where :math:`N` represents the number of elements. Default: 0. + keepdims: Default: `False`. + + Returns: + Standard deviation tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([1., 2., 3., 4.]) + >>> print(input_x.std()) + 1.118034 + """ + x_var = var(x, axis, ddof, keepdims) + return F.tensor_pow(x_var, 0.5) + + +def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disable=redefined-builtin + """ + Return sum of array elements over a given axis. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and + `extobj` are not supported. + + Args: + x (Union[int, float, bool, list, tuple, Tensor]): Elements to sum. + axis (Union[None, int, tuple(int)]): Axis or axes along which a sum is performed. Default: None. + If None, sum all of the elements of the input array. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints, a sum is performed on all of the axes specified in the tuple + instead of a single axis or all the axes as before. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + keepdims (bool): If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast correctly against the input array. + If the default value is passed, then keepdims will not be passed through to the sum method of + sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not + implement keepdims any exceptions will be raised. + initial (scalar): Starting value for the sum. + + Returns: + Tensor. A tensor with the same shape as input, with the specified axis removed. + If input tensor is a 0-d array, or if axis is None, a scalar is returned. + + Raises: + TypeError: If input is not array_like or `axis` is not int or tuple of ints or + `keepdims` is not integer or `initial` is not scalar. + ValueError: If any axis is out of range or duplicate axes exist. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([-1, 0, 1]).astype('int32') + >>> print(input_x.sum()) + 0 + >>> input_x = np.arange(10).reshape(2, 5).astype('float32') + >>> print(input_x.sum(axis=1)) + [10. 35.] + """ + dtype = x.dtype if dtype is None else dtype + if not isinstance(keepdims, int): + const_utils.raise_type_error("integer argument expected") + if initial is not None and not isinstance(initial, (int, float, bool)): + const_utils.raise_type_error("initial argument should be a scalar.") + if axis is None: + axis = () + else: + axis = check_and_canonicalize_axes(axis, x.ndim) + + if x.dtype == mstype.bool_: + x = x.astype("int32") + if 0 in x.shape: + x = Tensor([0], x.dtype) + if keepdims: + res = _reduce_sum_keepdims(x, axis) + else: + res = _reduce_sum_default(x, axis) + if initial is not None: + res += initial + return res.astype(dtype) + + +def repeat(x, repeats, axis=None): + """ + Repeat elements of an array. + + Args: + x (Tensor): Input tensor. + repeats (Union[int, tuple, list]): The number of repetitions for each element. + `repeats` is broadcasted to fit the shape of the given axis. + axis (int, optional): The axis along which to repeat values. By default, + use the flattened input tensor, and return a flat output tensor. + + Returns: + Tensor, has the same shape as input tensor except along the given axis. + + Raises: + ValueError: if axis is out of range. + TypeError: if input is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array(3) + >>> print(x.repeat(4)) + [3 3 3 3] + >>> x = np.array([[1,2],[3,4]]) + >>> print(x.repeat(2)) + [1 1 2 2 3 3 4 4] + >>> print(x.repeat(3, axis=1)) + [[1 1 1 2 2 2] + [3 3 3 4 4 4]] + >>> print(x.repeat([1,2], axis=0)) + [[1 2] + [3 4] + [3 4]] + """ + if not isinstance(repeats, (tuple, list)): + repeats = (repeats,) + for element in repeats: + if not isinstance(element, int): + const_utils.raise_type_error("Each element should be integer") + if axis is None: + x = ravel(x) + axis = 0 + if not isinstance(axis, int): + const_utils.raise_type_error('axes should be integers') + check_axis_in_range_const(axis, x.ndim) + axis = axis + x.ndim if axis < 0 else axis + + if len(repeats) == 1: + repeats = repeats[0] + if repeats == 0: + return empty_tensor(x.dtype) + return repeat_elements(x, repeats, axis) + size = x.shape[axis] + if len(repeats) != size: + const_utils.raise_value_error('operands could not be broadcast together') + subs = P.Split(axis, size)(x) + repeated_subs = [] + for sub, rep in zip(subs, repeats): + if rep != 0: + repeated_subs.append(repeat_elements(sub, rep, axis)) + return P.Concat(axis)(repeated_subs) + + +def getitem(data, index): """Implementation of `getitem`.""" - return data.__getitem__(item) + return data.__getitem__(index) -def setitem(data, item, value): +def setitem(data, index, value): """Implementation of `setitem`.""" - return data.__setitem__(item, value) + return data.__setitem__(index, value) + + +def item(data, *args): + """Implementation of `item`.""" + return compile_utils.tensor_item(data, *args) + + +def itemset(data, *args): + """Implementation of `itemset`.""" + return compile_utils.tensor_itemset(data, *args) def ms_iter(xs): @@ -573,6 +1490,11 @@ def while_cond(x): return x +@constexpr +def empty_tensor(dtype): + return Tensor([], dtype) + + @constexpr def check_type_same(x_type, base_type): """Check x_type is same as base_type.""" @@ -688,7 +1610,15 @@ check_flatten_order_const = constexpr(validator.check_flatten_order) check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis) prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze) check_axis_in_range_const = constexpr(validator.check_axis_in_range) - +check_axis_valid = constexpr(validator.check_axis_valid) +max_ = constexpr(validator.max_) +min_ = constexpr(validator.min_) +expanded_shape = constexpr(validator.expanded_shape) +tuple_slice = constexpr(validator.tuple_slice) +infer_out_shape = constexpr(validator.infer_out_shape) +get_log2_size = constexpr(validator.get_log2_size) +check_axis_type = constexpr(validator.check_axis_type) +check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes) def tensor_bool(x): """tensor as condition, if is constant, return immediate bool value""" @@ -817,6 +1747,7 @@ def list_hasnext(xs): return len(xs) > 0 +# pylint: disable=redefined-outer-name def list_append(self_, item): return _append(self_, item) diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index bd2bc15a279..65a6597ded0 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -175,7 +175,8 @@ BuiltInTypeMap &GetMethodMap() { {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, {"__ms_iter__", std::string("array_iter")}, // C.array_iter {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"item", std::string("item")}, // P.item, + {"itemset", std::string("itemset")}, // P.itemset, {"transpose", std::string("transpose")}, // P.transpose {"flatten", std::string("flatten")}, // P.reshape(,-1) {"reshape", std::string("reshape")}, // P.reshape() @@ -183,9 +184,26 @@ BuiltInTypeMap &GetMethodMap() { {"swapaxes", std::string("swapaxes")}, // P.transpose() {"squeeze", std::string("squeeze")}, // P.squeeze() {"astype", std::string("astype")}, // P.cast() - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool - {"argmax", std::string("argmax")}, // P.Argmax() - {"argmin", std::string("argmin")}, // P.Argmax() + {"cumsum", std::string("cumsum")}, // P.cumsum() + {"copy", std::string("copy")}, // copy() + {"max", std::string("max")}, // P.reduce_max() + {"min", std::string("min")}, // P.reduce_min() + {"fill", std::string("fill")}, // P.fill() + {"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min() + {"clip", std::string("clip")}, // P.maximum(P.minimum) + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + {"argmax", std::string("argmax")}, // P.Argmax() + {"argmin", std::string("argmin")}, // P.Argmax() + {"resize", std::string("resize")}, // P.Reshape() + {"choose", std::string("choose")}, // P.Select() + {"diagonal", std::string("diagonal")}, // P.Eye() + {"searchsorted", std::string("searchsorted")}, // P.Select() + {"take", std::string("take")}, // P.GatherNd() + {"trace", std::string("trace")}, // P.Eye() + {"var", std::string("var")}, // P.ReduceSum + {"std", std::string("std")}, // P.ReduceSum + {"sum", std::string("sum")}, // P.ReduceSum + {"repeat", std::string("repeat")}, // C.repeat_elements }}, {kObjectTypeRowTensorType, { diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index f8b659fe5f6..ffe3a82f155 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -331,6 +331,16 @@ class Tensor(Tensor_): """Convert numpy array to Tensor without copy data.""" return Tensor(Tensor_.from_numpy(array)) + def item(self, index=None): + """Getitem from the Tensor with the index.""" + output = tensor_operator_registry.get('item')(self, index) + return output + + def itemset(self, *args): + """Setitem from the Tensor with the index.""" + output = tensor_operator_registry.get('itemset')(self, *args) + return output + def asnumpy(self): """Convert tensor to numpy array.""" self.init_check() @@ -751,6 +761,317 @@ class Tensor(Tensor_): # P.Argmin is currently not supported return tensor_operator_registry.get('argmax')(axis)(tensor_operator_registry.get('__neg__')(a)) + def cumsum(self, axis=None, dtype=None): + """ + Returns the cumulative sum of the elements along a given axis. + + Note: + If ``self.dtype`` is :class:`int8`, :class:`int16` or :class:`bool`, the result + `dtype` will be elevated to :class:`int32`. + + Args: + self (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as original, + tensor, unless it has an integer dtype with a precision less than :class:`float32`. + In that case, :class:`float32` is used. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.ones((3,3)).astype("float32")) + >>> output = a.cumsum(0) + >>> print(output) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 3.]] + """ + x = self + original_dtype = x.dtype + # If original tensor is int, and has precision less then int32, convert to int32 + if mstype.issubclass_(x.dtype, mstype.int_) and x.itemsize < 4: + x = x.astype(mstype.int32) + if axis is None: + x = x.ravel() + axis = 0 + validator.check_axis_in_range(axis, x.ndim) + if dtype is not None and original_dtype != dtype: + return tensor_operator_registry.get('cumsum')()(x, axis).astype(dtype, copy=False) + return tensor_operator_registry.get('cumsum')()(x, axis) + + def copy(self): + """ + Returns a copy of the tensor. + + Note: + The current implementation does not support `order` argument. + + Args: + self (Tensor): Input tensor. + + Returns: + Copied tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.ones((3,3)).astype("float32")) + >>> output = a.copy() + >>> print(output) + [[1. 1. 1.] + [1. 1. 1.] + [1. 1. 1.]] + """ + if self.size == 0: + return self + origin_dtype = self.dtype + x = self + logical_not_op = tensor_operator_registry.get('logical_not')() + if origin_dtype == mstype.bool_: + return logical_not_op(logical_not_op(x)) + if origin_dtype != mstype.float64: + x = x.astype("float32") + x = x / 1.0 + x = x.astype(origin_dtype) + return x + + def max(self, axis=None, keepdims=False, initial=None, where=True): + """ + Returns the maximum of a tensor or maximum along an axis. + + Args: + self (Tensor): Input Tensor. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the maximum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The minimum value of an output element. Must be present to allow + computation on empty slice. + where (boolean Tensor, optional): defaults to True. + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. If non-default value + is passed, initial must also be provided. + + Returns: + Tensor or scalar, maximum of input tensor. If `axis` is None, the result is a scalar + value. If `axis` is given, the result is an array of dimension ``self.ndim - 1``. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.numpy as np + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> output = a.max() + >>> print(output) + 3.0 + """ + reduce_ = tensor_operator_registry.get("reduce") + reduce_max = tensor_operator_registry.get("reduce_max") + maximum = tensor_operator_registry.get("maximum") + return reduce_(self, reduce_max(keepdims), cmp_fn=maximum(), axis=axis, keepdims=keepdims, + initial=initial, where=where) + + def min(self, axis=None, keepdims=False, initial=None, where=True): + """ + Returns the minimum of a tensor or minimum along an axis. + + Args: + self (Tensor): Input data. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the minimum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The maximum value of an output element. Must be present to allow + computation on empty slice. + where (boolean Tensor, optional): defaults to True. + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. If non-default value + is passed, initial must also be provided. + + Returns: + Tensor or scalar, minimum of input tensor. If axis is None, the result is a scalar + value. If `axis` is given, the result is an array of dimension ``self.ndim - 1``. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.numpy as np + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> output = a.min() + >>> print(output) + 0.0 + """ + reduce_ = tensor_operator_registry.get("reduce") + reduce_min = tensor_operator_registry.get("reduce_min") + minimum = tensor_operator_registry.get("minimum") + return reduce_(self, reduce_min(keepdims), cmp_fn=minimum(), axis=axis, keepdims=keepdims, + initial=initial, where=where) + + def fill(self, value): + """ + Fills the array with a scalar value. + + Note: + Unlike Numpy, tensor.fill() will always returns a new tensor, instead of + filling the original tensor. + + Args: + value (Union[None, int, float, bool]): All elements of a will be assigned this value. + + Returns: + Tensor, with the original dtype and shape as input tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If `shape` has entries < 0. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32')) + >>> print(a.fill(1.0)) + [[1. 1.] + [1. 1.]] + """ + if value is None: + if self.dtype not in (mstype.float16, mstype.float32, mstype.float64): + raise TypeError("If None is used as value, the original Tensor's dtype must be float.") + value = Tensor(float('nan')).astype("float32") + return tensor_operator_registry.get("tile")()(value, self.shape).astype(self.dtype) + if not isinstance(value, (int, float, bool)): + raise TypeError("input value must be a scalar.") + return tensor_operator_registry.get("fill")(self.dtype, self.shape, value) + + def ptp(self, axis=None, keepdims=False): + """ + The name of the function comes from the acronym for ‘peak to peak’. + + Note: + Numpy arguments `dtype` and `out` are not supported. + + Args: + self (Tensor): Input tensor. + axis (Union[None, int, tuple(int)]): Axis or axes along which the range is computed. + The default is to compute the variance of the flattened array. Default: None. + keepdims (bool): Default is False. + + Returns: + Tensor. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import Tensor + >>> x = Tensor([[4.0, 9.0, 2.0, 10.0], [6.0, 9.0, 7.0, 12.0]]).astype("float32") + >>> print(x.ptp(axis=1)) + [8. 6.] + >>> print(x.ptp(axis=0)) + [2. 0. 5. 2.] + """ + if not isinstance(keepdims, bool): + raise TypeError('keepdims should be boolean') + if axis is None: + axis = () + else: + validator.check_axis_type(axis, True, True, False) + axis = validator.check_axis_valid(axis, self.ndim) + + return self.max(axis, keepdims) - self.min(axis, keepdims) + + def clip(self, xmin, xmax, dtype=None): + """ + Clips (limits) the values in a Tensor. + + Given an interval, values outside the interval are clipped to the interval edges. + For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0, + and values larger than 1 become 1. + + Note: + Currently, clip with `nan` is not supported. + + Args: + self (Tensor): Tensor containing elements to clip. + xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed + on lower interval edge. Not more than one of `xmin` and `xmax` may be None. + xmax (Tensor, scalar, None): Maximum value. If None, clipping is not performed + on upper interval edge. Not more than one of `xmin` and `xmax` may be None. + If `xmin` or `xmax` are tensors, then the three tensors will be broadcasted + to match their shapes. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, a tensor with the elements of input tensor, but where values + < `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import Tensor + >>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32") + >>> output = x.clip(x, 0, 2) + >>> print(output) + [1 2 2 0 0 2 2 0] + """ + if xmin is None and xmax is None: + raise ValueError("One of max or min must be given.") + x = self + # F.maximum/minimum does not support when both operands are scalar + if xmin is not None: + xmin = Tensor(xmin).astype(x.dtype) + if x.ndim == 0 and xmin.ndim == 0: + x = tensor_operator_registry.get("maximum")()(x.reshape((1,)), xmin).squeeze() + else: + x = tensor_operator_registry.get("maximum")()(x, xmin) + if xmax is not None: + xmax = Tensor(xmax).astype(x.dtype) + if x.ndim == 0 and xmax.ndim == 0: + x = tensor_operator_registry.get("minimum")()(x.reshape((1,)), xmax).squeeze() + else: + x = tensor_operator_registry.get("minimum")()(x, xmax) + if dtype is not None and dtype != x.dtype: + return x.astype(dtype) + return x def init_check(self): if self.has_init: @@ -821,6 +1142,583 @@ class Tensor(Tensor_): " Please use init_data") return self.init_data(slice_index, shape, opt_shard_group) + def resize(self, *new_shape): + """ + Changes shape and size of array in-place. + + Note: + Instead of changing the size of the input array and returns nothing as in numpy, + this method returns a new Tensor with the input size. + Numpy argument `refcheck` is not supported. + + Args: + new_shape (Union[ints, tuple of ints]): Shape of resized array. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import numpy as np + >>> x = np.array([[0, 1], [2, 3]]) + >>> x = x.resize(2, 3) + >>> print(x) + [[0 1 2] + [3 0 0]] + """ + if not new_shape: + return self + if len(new_shape) == 1: + if isinstance(new_shape[0], tuple): + new_shape = new_shape[0] + flattened = self.ravel() + cur_size = flattened.size + new_size = tensor_operator_registry.get('shape_mul')(new_shape) + diff_size = new_size - cur_size + if diff_size > 0: + pad_val = tensor_operator_registry.get('fill')(self.dtype, (diff_size,), 0) + res = tensor_operator_registry.get('concatenate')(0)((flattened, pad_val)) + else: + res = flattened[:new_size] + return res.reshape(new_shape) + + def diagonal(self, offset=0, axis1=0, axis2=1): + """ + Returns specified diagonals. + + Args: + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int, optional): Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis. + + Returns: + Tensor, if `a` is 2-D, then `a` 1-D array containing the diagonal. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.arange(4).reshape(2,2) + >>> print(a) + [[0 1] + [2 3]] + >>> output = a.diagonal() + >>> print(output) + [0 3] + """ + ndim = self.ndim + if ndim < 2: + raise ValueError('diagonal requires an array of at least two dimensions') + dtype = self.dtype + + axes = validator.check_axis_valid((axis1, axis2), ndim) + perm = () + for i in range(ndim): + if i not in axes: + perm += (i,) + perm += axes + a = self.transpose(perm) + + shape = a.shape + n, m = shape[-2:] + + e = tensor_operator_registry.get('eye')(n, m, dtype) + if offset >= m or offset <= -n: + e = tensor_operator_registry.get('fill')(dtype, (n, m), 0) + elif offset != 0: + e = e.astype(mstype.float32) + if offset > 0: + e_left = tensor_operator_registry.get('fill')(dtype, (n, offset), 0) + e_right = e[..., 0:m-offset:1] + e = tensor_operator_registry.get('concatenate')(1)((e_left, e_right)).astype(dtype) + elif offset < 0: + e_upper = tensor_operator_registry.get('fill')(dtype, (-offset, m), 0) + e_lower = e[0:n+offset:1, ...] + e = tensor_operator_registry.get('concatenate')(0)((e_upper, e_lower)).astype(dtype) + e = tensor_operator_registry.get('broadcast_to')(shape)(e) + + prod = tensor_operator_registry.get('__mul__')(a, e) + res = tensor_operator_registry.get('reduce_sum')(prod.astype(mstype.float32), -1) + + begin = () + for i in range(ndim-2): + begin += (0,) + last_dim_begin = max(0, -offset) + begin += (last_dim_begin,) + size = res.shape[:-1] + last_dim_end = min( + shape[-2], max(0, shape[-1] - offset)) - last_dim_begin + if last_dim_end <= 0: + return Tensor([]) + size += (last_dim_end,) + res = tensor_operator_registry.get('tensor_slice')(res, begin, size) + return res.astype(dtype) + + def trace(self, offset=0, axis1=0, axis2=1, dtype=None): + """ + Returns the sum along diagonals of the array. + + Args: + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int, optional): Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, sum_along_diagonals. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.eye(3) + >>> print(x.trace()) + 3.0 + """ + d = self.diagonal(offset, axis1=axis1, axis2=axis2) + shape = d.shape + if dtype is None: + dtype = d.dtype + if shape[-1] == 0: + return tensor_operator_registry.get('fill')(dtype, shape[:-1], 0) + res = tensor_operator_registry.get('reduce_sum')(d.astype(mstype.float32), -1) + return res.astype(dtype) + + def take(self, indices, axis=None, mode='clip'): + """ + Takes elements from an array along an axis. + + Args: + a (Tensor): Source array with shape `(Ni…, M, Nk…)`. + indices (Tensor): The indices with shape `(Nj...)` of the values to extract. + axis (int, optional): The axis over which to select values. By default, + the flattened input array is used. + mode (‘raise’, ‘wrap’, ‘clip’, optional): + - edge: Pads with the edge values of `arr`. + - raise: Raises an error; + - wrap: Wraps around; + - clip: Clips to the range. `clip` mode means that all indices that are + too large are replaced by the index that addresses the last element + along that axis. Note that this disables indexing with negative numbers. + + Returns: + Tensor, the indexed result. + + Raises: + ValueError: if axis is out of range. + TypeError: if the input is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> output = a.take(indices) + >>> print(output) + [4 3 6] + """ + if mode not in ('raise', 'wrap', 'clip'): + raise ValueError('raise should be one of "raise", "wrap", or "clip"') + if axis is None: + a = self.ravel() + axis = 0 + else: + a = self + ndim = a.ndim + validator.check_axis_in_range(axis, ndim) + axis = axis + ndim if axis < 0 else axis + + shape_a = a.shape + shape_indices = indices.shape + size_indices = indices.size + indices = tensor_operator_registry.get('check_indices')(shape_a[axis], indices, mode) + + # reshapes indices to shape (Ni..., Nj..., Nk) + shape_ni = shape_a[:axis] + shape_nk = shape_a[axis + 1:] + shape_out = shape_ni + shape_indices + shape_nk + shape_indices = tuple(size_indices if i == axis else 1 for i in range(ndim)) + indices = indices.reshape(shape_indices) + shape_indices = shape_ni + (indices.size,) + shape_nk + indices = tensor_operator_registry.get('broadcast_to')(shape_indices)(indices) + + res = tensor_operator_registry.get('gather_d')(a, axis, indices) + return res.reshape(shape_out) + + def choose(self, choices, mode='clip'): + """ + Construct an array from an index array and a list of arrays to choose from. + + Args: + choices (Union[tuple, list, Tensor]): Choice arrays. `a` and all of the `choices` must + be broadcastable to the same shape. If `choices` is itself an array, then + its outermost dimension (i.e., the one corresponding to ``choices.shape[0]``) + is taken as defining the “sequence”. + mode (‘raise’, ‘wrap’, ‘clip’, optional): Specifies how indices outside + ``[0, n-1]`` will be treated: + + ‘raise’ – raise an error (default); + + ‘wrap’ – wrap around; + + ‘clip’ – clip to the range. ‘clip’ mode means that all indices that are + too large are replaced by the index that addresses the last element + along that axis. Note that this disables indexing with negative numbers. + + Returns: + Tensor, the merged result. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + ValueError: if ``len(condlist) != len(choicelist)``. + + Examples: + >>> import mindspore.numpy as np + >>> choices = [[0, 1, 2, 3], [10, 11, 12, 13], + [20, 21, 22, 23], [30, 31, 32, 33]] + >>> x = np.array([2, 3, 1, 0]) + >>> print(x.choose(choices)) + [20 31 12 3] + """ + if isinstance(choices, Tensor): + shape_choice = validator.infer_out_shape(self.shape, choices.shape[1:]) + choices = tensor_operator_registry.get('broadcast_to')((choices.shape[0],) + shape_choice)(choices) + else: + # broadcasts choices to the same shape if choices is a sequence + choicelist = [] + shapes = () + for choice in choices: + if not isinstance(choice, Tensor): + choice = tensor_operator_registry.get('make_tensor')(choice) + shapes += (choice.shape,) + choicelist.append(choice) + shape_choice = validator.infer_out_shape(self.shape, *shapes) + tmp = [] + for choice in choicelist: + tmp.append(tensor_operator_registry.get('broadcast_to')(shape_choice)(choice)) + choices = tensor_operator_registry.get('stack')(0)(tmp) + + if self.ndim == 0 or choices.ndim == 0: + raise ValueError('input cannot be scalars') + a = tensor_operator_registry.get('broadcast_to')(shape_choice)(self) + dtype = choices.dtype + # adjusts dtype for F.tensor_mul and F.gather_nd + a = a.astype(mstype.int32) + choices = choices.astype(mstype.int32) + a = tensor_operator_registry.get('check_indices')(choices.shape[0], a, mode, allow_negative_index=False) + + grids = [] + ndim = len(a.shape) + for i in range(ndim): + dim_grid = Tensor(list(range(a.shape[i])), mstype.int32) + dim_shape = validator.expanded_shape(ndim, a.shape[i], i) + dim_grid = tensor_operator_registry.get('broadcast_to')(a.shape)(dim_grid.reshape(dim_shape)) + grids.append(dim_grid) + grid = tensor_operator_registry.get('stack')(-1)(grids) + indices = tensor_operator_registry.get('concatenate')(-1)((a.reshape(a.shape + (1,)), grid)) + return tensor_operator_registry.get('gather_nd')(choices, indices).astype(dtype) + + def searchsorted(self, v, side='left', sorter=None): + """ + Finds indices where elements should be inserted to maintain order. + + Args: + v (Union[int, float, bool, list, tuple, Tensor]): Values to insert into `a`. + side ('left', 'right', optional): If ‘left’, the index of the first suitable + location found is given. If ‘right’, return the last such index. If there is + no suitable index, return either 0 or N (where N is the length of `a`). + sorter (Union[int, float, bool, list, tuple, Tensor]): 1-D optional array of + integer indices that sort array `a` into ascending order. They are typically + the result of argsort. + + Returns: + Tensor, array of insertion points with the same shape as `v`. + + Raises: + ValueError: if argument for `side` or `sorter` is invalid. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore import numpy as np + >>> x = np.array([1,2,3,4,5]) + >>> print(x.searchsorted(3)) + 2 + """ + if side not in ('left', 'right'): + raise ValueError(f'{side} is an invalid value for keyword "side"') + a = self.astype(mstype.float32) + if not isinstance(v, Tensor): + v = tensor_operator_registry.get('make_tensor')(v) + shape = v.shape + if sorter is not None: + if sorter.ndim != 1 or sorter.size != a.size: + raise ValueError('sorter must be 1-D array with the same size as `a`') + sorter = tensor_operator_registry.get('make_tensor')(sorter) + sorter = sorter.reshape(sorter.shape + (1,)) + a = tensor_operator_registry.get('gather_nd')(a, sorter) + less_op = tensor_operator_registry.get('__le__') if side == 'left' else tensor_operator_registry.get('__lt__') + i = tensor_operator_registry.get('fill')(mstype.int32, shape, 0) + j = tensor_operator_registry.get('fill')(mstype.int32, shape, a.size) + + sort_range = tuple(range(validator.get_log2_size( + tensor_operator_registry.get('shape_mul')(shape) + 1))) + for _ in sort_range: + mid = (i - -j)//2 + mask = less_op(v, tensor_operator_registry.get('gather_nd')(a, mid.reshape(mid.shape + (1,)))) + i = tensor_operator_registry.get('select')(mask, i, mid) + j = tensor_operator_registry.get('select')(mask, mid, j) + return j + + def var(self, axis=None, ddof=0, keepdims=False): + """ + Compute the variance along the specified axis. + The variance is the average of the squared deviations from the mean, i.e., + :math:`var = mean(abs(x - x.mean())**2)`. + + Return the variance, which is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `dtype`, `out` and `where` are not supported. + + Args: + self (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the variance is computed. + The default is to compute the variance of the flattened array. Default: `None`. + ddof (int): Means Delta Degrees of Freedom. Default: 0. + The divisor used in calculations is :math:`N - ddof`, where :math:`N` represents the number of elements. + keepdims (bool): Default: `False`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Returns: + Standard deviation tensor. + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([1., 2., 3., 4.]) + >>> output = input_x.var() + >>> print(output) + 1.25 + """ + if 0 in self.shape: + return Tensor(float('nan'), self.dtype) + if not isinstance(ddof, int): + raise TypeError(f"integer argument expected, but got {type(ddof)}") + if not isinstance(keepdims, int): + raise TypeError(f"integer argument expected, but got {type(keepdims)}") + + if axis is None: + axis = () + else: + axis = validator.check_and_canonicalize_axes(axis, self.ndim) + x_mean = tensor_operator_registry.get('mean')(True)(self, axis) + x_sub = tensor_operator_registry.get('__sub__')(self, x_mean) + x_pow = tensor_operator_registry.get('__pow__')(x_sub, 2) + x_sum = tensor_operator_registry.get('sum')(bool(keepdims))(x_pow, axis) + nums = 1 + if axis == (): + nums = self.size + else: + for ax in axis: + nums *= self.shape[ax] + return tensor_operator_registry.get('__truediv__')(x_sum, nums - ddof) + + def std(self, axis=None, ddof=0, keepdims=False): + """ + Compute the standard deviation along the specified axis. + The standard deviation is the square root of the average of the squared deviations + from the mean, i.e., :math:`std = sqrt(mean(abs(x - x.mean())**2))`. + + Return the standard deviation, which is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `dtype`, `out` and `where` are not supported. + + Args: + self (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the standard + deviation is computed. Default: `None`. + + If `None`, compute the standard deviation of the flattened array. + ddof (int): Means Delta Degrees of Freedom. The divisor used in calculations is :math:`N - ddof`, + where :math:`N` represents the number of elements. Default: 0. + keepdims: Default: `False`. + + Returns: + Standard deviation tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([1., 2., 3., 4.]) + >>> output = input_x.std() + >>> print(output) + 1.118034 + """ + x_var = self.var(axis, ddof, keepdims) + return tensor_operator_registry.get('__pow__')(x_var, 0.5) + + def sum(self, axis=None, dtype=None, keepdims=False, initial=None): + """ + Return sum of array elements over a given axis. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and + `extobj` are not supported. + + Args: + self (Union[int, float, bool, list, tuple, Tensor]): Elements to sum. + axis (Union[None, int, tuple(int)]): Axis or axes along which a sum is performed. Default: None. + If None, sum all of the elements of the input array. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints, a sum is performed on all of the axes specified in the tuple + instead of a single axis or all the axes as before. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + keepdims (bool): If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast correctly against the input array. + If the default value is passed, then keepdims will not be passed through to the sum method of + sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not + implement keepdims any exceptions will be raised. + initial (scalar): Starting value for the sum. + + Returns: + Tensor. A tensor with the same shape as input, with the specified axis removed. + If input tensor is a 0-d array, or if axis is None, a scalar is returned. + + Raises: + TypeError: If input is not array_like or `axis` is not int or tuple of ints or + `keepdims` is not integer or `initial` is not scalar. + ValueError: If any axis is out of range or duplicate axes exist. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([-1, 0, 1]).astype('int32') + >>> print(input_x.sum()) + 0 + >>> input_x = np.arange(10).reshape(2, 5).astype('float32') + >>> print(input_x.sum(axis=1)) + [10. 35.] + """ + dtype = self.dtype if dtype is None else dtype + if not isinstance(keepdims, int): + raise TypeError(f"integer argument expected, but got {type(keepdims)}") + if initial is not None and not isinstance(initial, (int, float, bool)): + raise TypeError("initial argument should be a scalar.") + if axis is None: + axis = () + else: + axis = validator.check_and_canonicalize_axes(axis, self.ndim) + + input_x = self.astype(mstype.int32) if self.dtype == mstype.bool_ else self + if 0 in self.shape: + input_x = Tensor([0], self.dtype) + res = tensor_operator_registry.get('sum')(bool(keepdims))(input_x, axis) + if initial is not None: + res += initial + return res.astype(dtype) + + def repeat(self, repeats, axis=None): + """ + Repeat elements of an array. + + Args: + self (Tensor): Input tensor. + repeats (Union[int, tuple, list]): The number of repetitions for each element. + `repeats` is broadcasted to fit the shape of the given axis. + axis (int, optional): The axis along which to repeat values. By default, + use the flattened input tensor, and return a flat output tensor. + + Returns: + Tensor, has the same shape as input tensor except along the given axis. + + Raises: + ValueError: if axis is out of range. + TypeError: if input is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array(3) + >>> print(x.repeat(4)) + [3 3 3 3] + >>> x = np.array([[1,2],[3,4]]) + >>> print(x.repeat(2)) + [1 1 2 2 3 3 4 4] + >>> print(x.repeat(3, axis=1)) + [[1 1 1 2 2 2] + [3 3 3 4 4 4]] + >>> print(x.repeat([1,2], axis=0)) + [[1 2] + [3 4] + [3 4]] + """ + if not isinstance(repeats, (tuple, list)): + repeats = (repeats,) + for element in repeats: + if not isinstance(element, int): + raise TypeError(f"Each element in {repeats} should be integer, but got {type(element)}.") + input_x = self + if axis is None: + input_x = self.ravel() + axis = 0 + if axis is not None and not isinstance(axis, int): + raise TypeError(f'axes should be integers, not {type(axis)}') + validator.check_axis_in_range(axis, input_x.ndim) + axis = axis + input_x.ndim if axis < 0 else axis + + if len(repeats) == 1: + repeats = repeats[0] + if repeats == 0: + return Tensor_(input_x.dtype, (0,)) + return tensor_operator_registry.get('repeat_elements')(input_x, repeats, axis) + size = input_x.shape[axis] + if len(repeats) != size: + raise ValueError('operands could not be broadcast together') + subs = tensor_operator_registry.get('split')(axis, size)(input_x) + repeated_subs = [] + for sub, rep in zip(subs, repeats): + if rep != 0: + repeated_subs.append(tensor_operator_registry.get('repeat_elements')(sub, rep, axis)) + return tensor_operator_registry.get('concatenate')(axis)(repeated_subs) + class RowTensor: """ diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 185d5299bf9..7c8e2ca7bf3 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -28,14 +28,14 @@ from ..nn.layer.basic import triu as nn_triu from .._c_expression import Tensor as Tensor_ from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ - _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \ + _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \ _expand, _to_tensor, _slice_along_axis, _callable -from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ +from .utils_const import _raise_value_error, _empty, _max, _min, \ _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ _raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \ _canonicalize_axis, _list_comprehensions, _ceil, _tuple_slice, _raise_unimplemented_error, \ _tuple_setitem -from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \ +from .array_ops import ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \ apply_along_axis, where from .dtypes import nan, pi @@ -254,20 +254,8 @@ def copy_(a): [[1. 1.] [1. 1.]] """ - if not isinstance(a, Tensor): - a = asarray_const(a) - if a.size == 0: - return a - # The current implementation registers a new memory location for copied tensor by - # doing some reduandent operations. - origin_dtype = a.dtype - if origin_dtype == mstype.bool_: - return F.logical_not(F.logical_not(a)) - if origin_dtype != mstype.float64: - a = a.astype("float32") - a = a / ones_like(a) - a = a.astype(origin_dtype) - return a + a = asarray(a) + return a.copy() def ones(shape, dtype=mstype.float32): @@ -1136,51 +1124,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1): [[0 6] [1 7]] """ - ndim = F.rank(a) - if ndim < 2: - return _raise_value_error('diagonal requires an array of at least two dimensions') - dtype = F.dtype(a) - - if _is_shape_empty(F.shape(a)): - return _empty(dtype, (0,)) - - cast_type = dtype - if not _check_is_float(dtype): - # reduce_sum only supports float types - cast_type = mstype.float32 - a = F.cast(a, cast_type) - - axes = _check_axis_valid((axis1, axis2), ndim) - perm = () - for i in range(ndim): - if i not in axes: - perm += (i,) - perm += axes - a = transpose(a, perm) - - shape = F.shape(a) - n, m = shape[-2:] - e = eye(n, m, offset, cast_type) - e = _broadcast_to_shape(e, F.shape(a)) - - prod = F.tensor_mul(a, e) - res = F.reduce_sum(prod, -1) - - begin = () - for i in range(ndim-2): - begin += (0,) - last_dim_begin = _max(0, -offset) - begin += (last_dim_begin,) - size = F.shape(res)[:-1] - last_dim_end = _min( - shape[-2], _max(0, shape[-1] - offset)) - last_dim_begin - if last_dim_end <= 0: - return _empty(dtype, size + (0,)) - size += (last_dim_end,) - res = F.tensor_slice(res, begin, size) - if not _check_same_type(cast_type, dtype): - res = F.cast(res, dtype) - return res + return a.diagonal(offset=offset, axis1=axis1, axis2=axis2) def trace(a, offset=0, axis1=0, axis2=1, dtype=None): @@ -1236,22 +1180,7 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None): >>> print(output) (2, 3) """ - d = diagonal(a, offset, axis1=axis1, axis2=axis2) - shape = F.shape(d) - if dtype is None: - dtype = F.dtype(d) - if shape[-1] == 0: - return _empty(dtype, shape[:-1]) - - cast_type = dtype - if not _check_is_float(dtype): - # reduce sum only supports float types - cast_type = mstype.float32 - d = F.cast(d, cast_type) - res = F.reduce_sum(d, -1) - if not _check_same_type(cast_type, dtype): - res = F.cast(res, dtype) - return res + return a.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) def _index(i, size, Cartesian=True): diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index b44760e18a1..e38ca87a12d 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -19,7 +19,6 @@ from ..common import dtype as mstype from ..common import Tensor from ..ops import operations as P from ..ops import functional as F -from ..ops import composite as C from ..ops.primitive import constexpr from ..nn import Cell @@ -1885,30 +1884,7 @@ def take(a, indices, axis=None, mode='clip'): [5 7]] """ _check_input_tensor(a, indices) - if mode not in ('raise', 'wrap', 'clip'): - _raise_value_error('raise should be one of "raise", "wrap", or "clip"') - if axis is None: - a = ravel(a) - axis = 0 - ndim = F.rank(a) - axis = _check_axis_in_range(axis, ndim) - - shape_a = F.shape(a) - shape_indices = F.shape(indices) - size_indices = indices.size - indices = _check_indices(shape_a[axis], indices, mode) - - # reshapes indices to shape (Ni..., Nj..., Nk) - shape_ni = _tuple_slice(shape_a, None, axis) - shape_nk = _tuple_slice(shape_a, axis + 1, None) - shape_out = shape_ni + shape_indices + shape_nk - shape_indices = _expanded_shape(ndim, size_indices, axis) - indices = F.reshape(indices, shape_indices) - shape_indices = shape_ni + (indices.size,) + shape_nk - indices = _broadcast_to_shape(indices, shape_indices) - - res = F.gather_d(a, axis, indices) - return F.reshape(res, shape_out) + return a.take(indices, axis=axis, mode=mode) def repeat(a, repeats, axis=None): @@ -1952,30 +1928,8 @@ def repeat(a, repeats, axis=None): [3 4] [3 4]] """ - _check_input_tensor(a) - if not isinstance(repeats, (tuple, list)): - repeats = (repeats,) - _check_element_int(repeats) - if axis is None: - a = ravel(a) - axis = 0 - ndim = F.rank(a) - axis = _check_axis_in_range(axis, ndim) - if len(repeats) == 1: - repeats = repeats[0] - if repeats == 0: - return _empty(F.dtype(a), (0,)) - return C.repeat_elements(a, repeats, axis) - shape = F.shape(a) - dims = shape[axis] - if len(repeats) != dims: - _raise_value_error('operands could not be broadcast together') - subs = split(a, dims, axis) - repeated_subs = [] - for sub, rep in zip(subs, repeats): - if rep != 0: - repeated_subs.append(C.repeat_elements(sub, rep, axis)) - return concatenate(repeated_subs, axis) + a = _to_tensor(a) + return a.repeat(repeats, axis) def rot90(a, k=1, axes=(0, 1)): @@ -2072,12 +2026,12 @@ def select(condlist, choicelist, default=0): `choicelist` where the `m-th` element of the corresponding array in `condlist` is `True`. - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - Raises: ValueError: if ``len(condlist) != len(choicelist)``. + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + Examples: >>> import mindspore.numpy as np >>> condlist = [[True, True, True, False, False], \ @@ -2186,12 +2140,12 @@ def choose(a, choices, mode='clip'): Returns: Tensor, the merged result. - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - Raises: ValueError: if ``len(condlist) != len(choicelist)``. + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + Examples: >>> import mindspore.numpy as np >>> choices = [[0, 1, 2, 3], [10, 11, 12, 13], diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 58eeae01e7e..44448d3416b 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -41,8 +41,8 @@ from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ _check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \ _make_tensor, _promote_for_trigonometric, _raise_runtime_error, _max, _type_convert, \ _raise_unimplemented_error, _abs, _in -from .utils import _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ - _check_input_tensor, _to_tensor, _isnan, _convert_bool_to_int, _to_tensor_origin_dtype +from .utils import _expand, _broadcast_to, _broadcast_to_shape, _check_input_tensor, \ + _to_tensor, _isnan, _to_tensor_origin_dtype ZERO_TENSOR = asarray_const(0) @@ -869,7 +869,7 @@ def std(x, axis=None, ddof=0, keepdims=False): otherwise over the specified axis. Note: - Numpy arguments `dtype` and `out` are not supported. + Numpy arguments `dtype`, `out` and `where` are not supported. Args: x (Tensor): A Tensor to be calculated. @@ -894,34 +894,8 @@ def std(x, axis=None, ddof=0, keepdims=False): >>> print(output) 1.118034 """ - if _is_shape_empty(x.shape): - return full((), nan, F.dtype(x)) - - if not isinstance(ddof, int): - _raise_type_error("integer argument expected, but got ", ddof) - if not isinstance(keepdims, int): - _raise_type_error("integer argument expected, but got ", keepdims) - if axis is None: - axis = () - else: - _check_axis_type(axis, True, True, False) - axis = _canonicalize_axis(axis, x.ndim) - - x_mean = _mean_keepdims(x, axis) - x_sub = F.tensor_sub(x, x_mean) - x_pow = F.tensor_pow(x_sub, 2) - if keepdims: - x_sum = _reduce_sum_keepdims(x_pow, axis) - else: - x_sum = _reduce_sum_default(x_pow, axis) - - if isinstance(axis, int): - nums = x.shape[axis] - else: - nums = _get_size(x, axis) - - x_std = F.tensor_pow(F.tensor_div(x_sum, nums - ddof), 0.5) - return x_std + x = _to_tensor(x) + return x.std(axis, ddof, keepdims) def var(x, axis=None, ddof=0, keepdims=False): @@ -934,7 +908,7 @@ def var(x, axis=None, ddof=0, keepdims=False): otherwise over the specified axis. Note: - Numpy arguments `dtype` and `out` are not supported. + Numpy arguments `dtype`, `out` and `where` are not supported. Args: x (Tensor): A Tensor to be calculated. @@ -957,11 +931,8 @@ def var(x, axis=None, ddof=0, keepdims=False): >>> print(output) 1.25 """ - if _is_shape_empty(x.shape): - return full((), nan, F.dtype(x)) - - x_std = std(x, axis, ddof, keepdims) - return F.tensor_pow(x_std, 2) + x = _to_tensor(x) + return x.var(axis, ddof, keepdims) def ptp(x, axis=None, keepdims=False): @@ -996,21 +967,7 @@ def ptp(x, axis=None, keepdims=False): [2. 0. 5. 2.] """ _check_input_tensor(x) - if not isinstance(keepdims, bool): - _raise_type_error('keepdims should be boolean') - if axis is None: - axis = () - else: - _check_axis_type(axis, True, True, False) - axis = _check_axis_valid(axis, x.ndim) - - if keepdims: - x_min = _reduce_min_keepdims(x, axis) - x_max = _reduce_max_keepdims(x, axis) - else: - x_min = _reduce_min_default(x, axis) - x_max = _reduce_max_default(x, axis) - return F.tensor_sub(x_max, x_min) + return x.ptp(axis, keepdims) def average(x, axis=None, weights=None, returned=False): @@ -1445,8 +1402,7 @@ def amax(a, axis=None, keepdims=False, initial=None, where=True): >>> print(output) [-1. 3.] """ - return _reduce(a, P.ReduceMax(keepdims), cmp_fn=F.maximum, axis=axis, keepdims=keepdims, - initial=initial, where=where) + return a.max(axis, keepdims, initial, where) def amin(a, axis=None, keepdims=False, initial=None, where=True): @@ -1501,8 +1457,7 @@ def amin(a, axis=None, keepdims=False, initial=None, where=True): >>> print(output) [10. 1.] """ - return _reduce(a, P.ReduceMin(keepdims), cmp_fn=F.minimum, axis=axis, keepdims=keepdims, - initial=initial, where=where) + return a.min(axis, keepdims, initial, where) def hypot(x1, x2, dtype=None): @@ -2278,6 +2233,8 @@ def _handle_inputs(cov_input, rowvar): _raise_value_error("input array has dimension more than 2.") cov_input = cov_input.astype("float32") cov_input = _expand(cov_input, 2) + if not isinstance(rowvar, bool): + _raise_type_error("input rowvar should be boolean.") if not rowvar and cov_input.shape[0] != 1: cov_input = cov_input.T return cov_input @@ -2467,6 +2424,7 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, if initial is not None: initial = full(shape, initial, dtype) a = cmp_fn(a, initial) + if isinstance(where, Tensor): if initial is None: return _raise_value_error('initial value must be provided for where masks') @@ -3133,20 +3091,7 @@ def cumsum(a, axis=None, dtype=None): [3. 3. 3.]] """ _check_input_tensor(a) - original_dtype = F.dtype(a) - # If original tensor is int, and has precision less then int32, convert to int32 - if _check_same_type(original_dtype, mstype.bool_) or \ - _check_same_type(original_dtype, mstype.int8) or \ - _check_same_type(original_dtype, mstype.int16): - original_dtype = mstype.int32 - a = a.astype(mstype.float32) - if axis is None: - a = a.ravel() - axis = 0 - _check_axis_in_range(axis, a.ndim) - if dtype is not None and not _check_same_type(original_dtype, dtype): - return _cumsum_default(a, axis).astype(dtype, copy=False) - return _cumsum_default(a, axis).astype(original_dtype, copy=False) + return a.cumsum(axis, dtype) def nancumsum(a, axis=None, dtype=None): @@ -3196,7 +3141,7 @@ def nancumsum(a, axis=None, dtype=None): [3. 3.]] """ a = F.select(_isnan(a), zeros(F.shape(a), F.dtype(a)), a) - return cumsum(a, axis=axis, dtype=dtype) + return a.cumsum(axis, dtype) def cbrt(x, dtype=None): @@ -4079,28 +4024,8 @@ def sum_(a, axis=None, dtype=None, keepdims=False, initial=None): >>> print(np.sum(x, axis=1)) [10. 35.] """ - if not isinstance(keepdims, int): - _raise_type_error("integer argument expected, but got ", keepdims) - if initial is not None and not isinstance(initial, (int, float, bool)): - _raise_type_error("initial argument should be a scalar.") - if axis is None: - axis = () - else: - _check_axis_type(axis, True, True, False) - axis = _canonicalize_axis(axis, a.ndim) - a = _convert_bool_to_int(_to_tensor(a)) - if _is_shape_empty(a.shape): - a = F.fill(a.dtype, (1,), 0) - - if keepdims: - res = _reduce_sum_keepdims(a, axis) - else: - res = _reduce_sum_default(a, axis) - if initial is not None: - res += initial - if dtype is not None and not _check_same_type(F.dtype(res), dtype): - res = F.cast(res, dtype) - return res + a = _to_tensor(a) + return a.sum(axis, dtype, keepdims, initial) @constexpr @@ -4327,6 +4252,7 @@ def searchsorted(a, v, side='left', sorter=None): ``Ascend`` ``GPU`` ``CPU`` Examples: + >>> from mindspore import numpy as np >>> print(np.searchsorted([1,2,3,4,5], 3)) 2 >>> print(np.searchsorted([1,2,3,4,5], 3, side='right')) @@ -4726,7 +4652,7 @@ def histogram(a, bins=10, range=None, weights=None, density=False): # pylint: di if density: count = F.cast(count, mstype.float32) count = count/diff(bin_edges)/F.reduce_sum(count) - return count, bin_edges + return count.astype(mstype.int32), bin_edges @constexpr @@ -4865,7 +4791,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False): # pyl shape = _expanded_shape(ndim, dedges[i].size, i) count /= _to_tensor(dedges[i]).reshape(shape) count /= s - return count, bin_edges + return count.astype(mstype.int32), bin_edges def histogram2d(x, y, bins=10, range=None, weights=None, density=False): # pylint: disable=redefined-builtin @@ -4929,7 +4855,7 @@ def histogram2d(x, y, bins=10, range=None, weights=None, density=False): # pylin 5.33333349e+00, 6.00000000e+00])) """ count, bin_edges = histogramdd((x, y), bins=bins, range=range, weights=weights, density=density) - return count, bin_edges[0], bin_edges[1] + return count.astype(mstype.int32), bin_edges[0], bin_edges[1] def matrix_power(a, n): diff --git a/mindspore/ops/composite/array_ops.py b/mindspore/ops/composite/array_ops.py index 20187a9c8b1..352f1c50bb7 100644 --- a/mindspore/ops/composite/array_ops.py +++ b/mindspore/ops/composite/array_ops.py @@ -15,6 +15,7 @@ """array Operations.""" from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils from mindspore.common import dtype as mstype +from mindspore.common._register_for_tensor import tensor_operator_registry from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.ops.primitive import constexpr @@ -104,6 +105,8 @@ def repeat_elements(x, rep, axis=0): return x_rep +tensor_operator_registry.register('repeat_elements', repeat_elements) + @constexpr def _check_sequence_mask_input_len(input_shape): diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 7143b4ec9c3..8caf9f3e7fc 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -137,6 +137,68 @@ tensor_operator_registry.register('__pow__', _tensor_pow) tensor_operator_registry.register('__floordiv__', _tensor_floordiv) +def tensor_item(data, *args): + """Tensor getitem by index whose dtype is int or tuple with int.""" + # transform a.item(tuple(int)) -> a.item(int1,int2...intN) + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + + args_types = hyper_map(F.typeof, args) + if not args or const_utils.judge_index_type(args_types[0], mstype.type_none): + if data.shape == (1,): + return data + const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar") + + if not const_utils.judge_indexes_types(args_types, mstype.int64): + const_utils.raise_type_error("The index object cannot be interpreted as an integer") + + if len(args) == data.ndim: + return _tensor_getitem_by_tuple_slice(data, args) + if len(args) > 1: + const_utils.raise_value_error("Incorrect number of indices for array") + return _tensor_index_by_integer(F.reshape(data, (-1,)), args[0]) + + +def tensor_itemset(data, *args): + """Tensor setitem by index and value.""" + if not args: + const_utils.raise_value_error("itemset must have at least one argument") + if len(args) == 2: + if const_utils.judge_index_type(F.typeof(args[0]), mstype.int64): + return tensor_itemset_by_number_with_number(data, args[0], args[1]) + if isinstance(args[0], tuple): + return tensor_itemset_by_tuple_with_number(data, args[0], args[1]) + const_utils.raise_type_error("The index object cannot be interpreted as an integer") + if len(args) > 2: + const_utils.raise_value_error("incorrect number of indices for array") + return tensor_itemset_with_number(data, args[0]) + + +tensor_operator_registry.register("item", tensor_item) +tensor_operator_registry.register("itemset", tensor_itemset) + + +def tensor_itemset_with_number(data, number_value): + if not const_utils.judge_index_type(F.typeof(number_value), mstype.number_type): + const_utils.raise_index_error("The Tensor could only use the number value for itemset") + if data.shape != (1,): + const_utils.raise_index_error("The Tensor without shape (1,) could not use the itemset api with only one args") + return const_utils.make_tensor((number_value,), F.dtype(data)) + + +def tensor_itemset_by_number_with_number(data, int_index, number_value): + flatten_data = F.reshape(data, (-1,)) + itemset_data = tensor_setitem_by_number_with_number(flatten_data, int_index, number_value) + res_data = F.reshape(itemset_data, F.shape(data)) + return res_data + + +def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value): + if len(tuple_index) != data.ndim: + const_utils.raise_value_error("incorrect number of indices for array") + return tensor_setitem_by_tuple_with_number(data, tuple_index, nubmer_value) + + def _broadcast(broadcast_shape, x): """Broadcast tensor to the required shape.""" if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)): @@ -148,10 +210,10 @@ def _broadcast(broadcast_shape, x): return x -def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): +def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item): """Transform indexing tensor to the required.""" - x = _broadcast(broadcast_shape, x) - return _broadcast(final_shape, F.reshape(x, new_shape)) + item = _broadcast(broadcast_shape, item) + return _broadcast(final_shape, F.reshape(item, new_shape)) def _transform_ellipsis_to_slice(data, tuple_index, op_name): @@ -636,12 +698,12 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): return data indexes_types = hyper_map(F.typeof, tuple_index) - contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) if contain_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) + indices = _generate_indices_from_tuple_of_tensor(tuple_index, op_name) else: - indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM, idx_advanced) + indices = _generate_indices_from_tuple(data, tuple_index, op_name, idx_advanced) if indices is False: return data updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) @@ -824,3 +886,65 @@ def format_index(idx, data_shape, cur_dim): # does not take bool tensor into account since it's currently not supported idx = F.select(idx < 0, idx + data_shape[cur_dim], idx) return idx + + +def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None): + """ + Applies comparison based on cmp_fn and reduction based on reduce_fn. + If cmp_fn is None, only reduction is performed. + """ + + shape = F.shape(a) + ndim = F.rank(a) + if dtype is None: + dtype = F.dtype(a) + axes = const_utils.check_axis_valid_const(axis, ndim) + if initial is not None: + if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or + not isinstance(initial, (int, float, bool, Tensor))): + const_utils.raise_type_error('initial should be scalar') + + if F.shape_mul(shape) == 0: + const_utils.raise_value_error('zero-size tensors are not supported.') + + if initial is not None: + initial = F.fill(dtype, shape, initial) + a = cmp_fn(a, initial) + + if isinstance(where, Tensor): + if initial is None: + const_utils.raise_value_error('initial value must be provided for where masks') + ndim_orig = F.rank(a) + a = F.select(where, a, initial) + axes = const_utils.real_axes(ndim_orig, F.rank(a), axes) + + return reduce_fn(a, axes).astype(dtype) + + +tensor_operator_registry.register("reduce", reduce_) + + +def check_indices(dims, indices, mode, allow_negative_index=True): + """Checks whether indices are out of bounds.""" + shape = F.shape(indices) + dtype = F.dtype(indices) + if not allow_negative_index: + lowerbounds = F.fill(dtype, shape, 0) + else: + lowerbounds = F.fill(dtype, shape, -dims) + upperbounds = F.fill(dtype, shape, dims - 1) + out_of_lowerbounds = F.tensor_lt(indices, lowerbounds) + out_of_upperbounds = F.tensor_gt(indices, upperbounds) + if mode == 'raise': + const_utils.raise_unimplemented_error('"raise" mode is not implemented') + if mode == 'wrap': + bounds = F.fill(dtype, shape, dims) + quotient = F.tensor_floordiv(indices, bounds) + prod = F.tensor_mul(bounds, quotient) + return F.tensor_sub(indices, prod) + zeros = F.fill(dtype, shape, 0) + clipped = F.select(out_of_lowerbounds, zeros, indices) + clipped = F.select(out_of_upperbounds, upperbounds, clipped) + return clipped + +tensor_operator_registry.register('check_indices', check_indices) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 5e016dbd95f..f6de370381c 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -15,6 +15,8 @@ """constexpr util""" from itertools import compress +from functools import partial +import operator import numpy as np @@ -22,7 +24,9 @@ from ...primitive import constexpr from .... import log as logger from ....common import dtype as mstype from ....common.tensor import Tensor +from ....common._register_for_tensor import tensor_operator_registry from ....ops import _utils as op_utils +from ...._checkparam import Validator as validator ALL_TENSOR = 0 NO_TENSOR = 1 @@ -58,6 +62,11 @@ def raise_type_error(msg): raise TypeError(msg) +@constexpr +def raise_unimplemented_error(msg): + raise NotImplementedError(msg) + + @constexpr def check_equal(param1, param2, msg="{},{}"): """Checks whether the two parameters are equal or not.""" @@ -165,6 +174,8 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=-1): return Tensor(a, dtype) +tensor_operator_registry.register('make_tensor', make_tensor) + @constexpr def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8): @@ -234,8 +245,7 @@ def is_same_type(inst, type_): @constexpr def check_valid_dim(dim, name): if dim not in (1, 2): - raise ValueError( - f"For {name}, inputs dim must be 1d or 2d") + raise ValueError(f"For {name}, inputs dim must be 1d or 2d") @constexpr @@ -249,8 +259,12 @@ def judge_index_type(index_type, target_type): def judge_indexes_types(dtypes, target_type): """Check a tuple of tensor data type.""" for dtype in dtypes: - if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): - return False + if isinstance(target_type, (list, tuple)): + if dtype not in target_type: + return False + else: + if dtype != target_type: + return False return True @@ -547,7 +561,7 @@ def check_number_index_type(number): @constexpr def get_stride_info_from_slice(data_shape, slice_index): """Get stride info from a python slice""" - begin, end, step = get_slice_stride(data_shape[0], slice_index) + begin, end, step = get_slice_stride(slice_index, data_shape[0]) begin_strides = [begin] end_strides = [end] step_strides = [step] @@ -571,7 +585,7 @@ def get_stride_info_from_integer(data_shape, number): return tuple(begin_strides), tuple(end_strides), tuple(step_strides) -def get_slice_stride(dim_size, index_slice): +def get_slice_stride(index_slice, dim_size): """Get slice stride info""" step = 1 if index_slice.step is None else index_slice.step start_default = 0 @@ -591,20 +605,20 @@ def get_stride_info_from_tuple(data_shape, tuple_index): tuple_index_len = len(tuple_index) data_dim = len(data_shape) shrink_axis, index_count, ellipsis_count = 0, 0, 0 - for idx, item in enumerate(tuple_index): - if isinstance(item, slice): - start, stop, step = get_slice_stride(data_shape[idx], item) + for index, dim_size in zip(tuple_index, data_shape): + if isinstance(index, slice): + start, stop, step = get_slice_stride(index, dim_size) begin_strides.append(start) end_strides.append(stop) step_strides.append(step) index_count = index_count + 1 - elif isinstance(item, int): - begin_strides.append(item) - end_strides.append(item + 1) + elif isinstance(index, int): + begin_strides.append(index) + end_strides.append(index + 1) step_strides.append(1) shrink_axis = shrink_axis + (1 << index_count) index_count = index_count + 1 - elif item is ...: + elif index is ...: ellipsis_count = ellipsis_count + 1 if ellipsis_count > 1: raise IndexError("An index can have only one ellipsis (...)") @@ -616,10 +630,10 @@ def get_stride_info_from_tuple(data_shape, tuple_index): index_count = index_count + ellipsis_range_size else: raise IndexError("Not supported index data type, got ", - item, " type is ", type(item)) - for item in range(index_count, data_dim): + index, " type is ", type(item)) + for index in range(index_count, data_dim): begin_strides.append(0) - end_strides.append(data_shape[item]) + end_strides.append(data_shape[index]) step_strides.append(1) return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis @@ -773,3 +787,15 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim @constexpr def check_slice_empty(start, stop, step): return (start - stop)*step >= 0 + + +@constexpr +def real_axes(ndim_orig, ndim_out, axes_orig): + """Returns the real axes to be reduced after performing broadcast""" + _diff = ndim_out - ndim_orig + axes = tuple(range(_diff)) + axes_orig = map(partial(operator.add, _diff), axes_orig) + return axes + tuple(axes_orig) + + +check_axis_valid_const = constexpr(validator.check_axis_valid) diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index bee58dd3e18..73f6ac04cfd 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -402,7 +402,6 @@ def _tensor_setitem_by_slice_with_tuple(data, input_slice, value): return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value) - @setitem.register("Tensor", "Number", "Number") def _tensor_setitem_by_number_with_number(data, index, value): """ diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 2abf1f35676..c5f1b78aa86 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -232,6 +232,16 @@ tensor_operator_registry.register('transpose', P.Transpose) tensor_operator_registry.register('broadcast_to', P.BroadcastTo) tensor_operator_registry.register('matmul', P.MatMul) tensor_operator_registry.register('argmax', P.Argmax) +tensor_operator_registry.register('cumsum', P.CumSum) +tensor_operator_registry.register('reduce_max', P.ReduceMax) +tensor_operator_registry.register('reduce_min', P.ReduceMin) +tensor_operator_registry.register('maximum', P.Maximum) +tensor_operator_registry.register('minimum', P.Minimum) +tensor_operator_registry.register('fill', P.Fill) +tensor_operator_registry.register('tile', P.Tile) +tensor_operator_registry.register('logical_not', P.LogicalNot) +tensor_operator_registry.register('sum', P.ReduceSum) +tensor_operator_registry.register('split', P.Split) # ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__ne__', not_equal) @@ -245,6 +255,18 @@ tensor_operator_registry.register('shape', shape) tensor_operator_registry.register('squeeze', squeeze) # support GE backend for no compare operators tensor_operator_registry.register('cast', cast) +tensor_operator_registry.register('shape_mul', shape_mul) +tensor_operator_registry.register('fill', fill) +tensor_operator_registry.register('concatenate', P.Concat) +tensor_operator_registry.register('eye', eye) +tensor_operator_registry.register('reduce_sum', reduce_sum) +tensor_operator_registry.register('tensor_slice', tensor_slice) +tensor_operator_registry.register('select', select) +tensor_operator_registry.register('gather_d', gather_d) +tensor_operator_registry.register('gather_nd', gather_nd) +tensor_operator_registry.register('stack', P.Stack) +tensor_operator_registry.register('log', log) +tensor_operator_registry.register('floor', floor) __all__ = [name for name in dir() if name[0] != "_"] __all__.remove('Primitive') diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index 6c882cfca59..817aeaf3c05 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -805,6 +805,21 @@ def test_vander(): match_all_arrays(mnp_vander, onp_vander, error=1e-4) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_fill(): + x = rand_int(2, 1, 4).astype(onp.float32) + mnp_x = to_tensor(x) + x.fill(6) + match_all_arrays(mnp_x.fill(6), x) + x.fill(None) + match_all_arrays(mnp_x.fill(None), x) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training diff --git a/tests/st/numpy_native/test_array_ops.py b/tests/st/numpy_native/test_array_ops.py index 585fb6556d1..ce27a983a31 100644 --- a/tests/st/numpy_native/test_array_ops.py +++ b/tests/st/numpy_native/test_array_ops.py @@ -1529,6 +1529,29 @@ def test_apply_along_axis(): match_all_arrays(mnp_res, onp_res) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_resize(): + x = rand_int(3, 5) + mnp_x = to_tensor(x) + + x.resize(2, 4, refcheck=False) + mnp_x = mnp_x.resize(2, 4) + match_array(mnp_x.asnumpy(), x) + + x.resize((3, 1), refcheck=False) + mnp_x = mnp_x.resize((3, 1)) + match_array(mnp_x.asnumpy(), x) + + x.resize(7, 4, refcheck=False) + mnp_x = mnp_x.resize(7, 4) + match_array(mnp_x.asnumpy(), x) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -1616,3 +1639,27 @@ def test_apply_over_axes(): for expected, actual in zip(onp_apply_over_axes(x), mnp_apply_over_axes(to_tensor(x))): match_array(actual.asnumpy(), expected, error=5) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_choose(): + x = rand_int(2, 1, 4).astype(onp.int32) + mnp_x = to_tensor(x) + y = rand_int(3, 2, 5, 4).astype(onp.int32) + match_res(mnp_x.choose, x.choose, y, mode='wrap') + match_res(mnp_x.choose, x.choose, y, mode='clip') + + x = rand_int(5, 3, 1, 7).astype(onp.int32) + mnp_x = to_tensor(x) + y1 = rand_int(7).astype(onp.int32) + y2 = rand_int(1, 3, 1).astype(onp.int32) + y3 = rand_int(5, 1, 1, 7).astype(onp.int32) + onp_arrays = (y1, y2, y3) + mnp_arrays = tuple(map(to_tensor, (y1, y2, y3))) + match_all_arrays(mnp_x.choose(mnp_arrays, mode='wrap'), x.choose(onp_arrays, mode='wrap')) + match_all_arrays(mnp_x.choose(mnp_arrays, mode='clip'), x.choose(onp_arrays, mode='clip')) diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index 3395ebedb2d..46840c1a259 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -942,14 +942,20 @@ def mnp_clip(x): a = mnp.clip(x, to_tensor(10.0), to_tensor([2,])) b = mnp.clip(x, 0, 1) c = mnp.clip(x, to_tensor(0), to_tensor(10), dtype=mnp.float32) - return a, b, c + d = x.clip(to_tensor(10.0), to_tensor([2,])) + e = x.clip(0, 1) + f = x.clip(to_tensor(0), to_tensor(10), dtype=mnp.float32) + return a, b, c, d, e, f def onp_clip(x): a = onp.clip(x, onp.asarray(10.0), onp.asarray([2,])) b = onp.clip(x, 0, 1) c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float32) - return a, b, c + d = x.clip(onp.asarray(10.0), onp.asarray([2,])) + e = x.clip(0, 1) + f = x.clip(onp.asarray(0), onp.asarray(10), dtype=onp.float32) + return a, b, c, d, e, f @pytest.mark.level1 @@ -2730,3 +2736,20 @@ def test_correlate(): mnp_res = mnp_correlate(a, v) onp_res = onp_correlate(a, v) match_all_arrays(mnp_res, onp_res) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_searchsorted(): + x = onp.arange(-10, 10) + mnp_x = to_tensor(x) + y = onp.random.randint(-15, 15, size=(2, 3, 4)) + onp.random.choice([0, 0.5], (2, 3, 4)) + sorter = onp.random.shuffle(onp.arange(20)) + match_res(mnp_x.searchsorted, x.searchsorted, y) + match_res(mnp_x.searchsorted, x.searchsorted, y, side='right') + match_res(mnp_x.searchsorted, x.searchsorted, y, sorter=sorter) + match_res(mnp_x.searchsorted, x.searchsorted, y, side='right', sorter=sorter) diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_getitem.py similarity index 85% rename from tests/st/pynative/test_tensor_index.py rename to tests/st/pynative/test_tensor_getitem.py index f9a4cd5b2d9..88419714ef1 100644 --- a/tests/st/pynative/test_tensor_index.py +++ b/tests/st/pynative/test_tensor_getitem.py @@ -16,7 +16,8 @@ import numpy as np import pytest -from mindspore import Tensor, Parameter +from mindspore import Tensor +from mindspore import Parameter from mindspore import context from mindspore import dtype as mstype from mindspore.nn import Cell @@ -28,7 +29,7 @@ grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True) def setup_module(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + context.set_context(mode=context.PYNATIVE_MODE) class NetWorkSlicePositive(Cell): @@ -50,6 +51,7 @@ class NetWorkSlicePositive(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_slice_positive(): net = NetWorkSlicePositive() @@ -77,12 +79,17 @@ class NetWorkSliceEllipsis(Cell): return ret0, ret1, ret2, ret3 -def Xtest_slice_ellipsis(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_slice_ellipsis(): net = NetWorkSliceEllipsis() input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32) input_0 = Tensor(input_np) output0, output1, output2, output3 = net(input_0) - assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3])) + assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([2, 7, 8])) assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9])) assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9])) assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9])) @@ -104,7 +111,12 @@ class NetWorkReduceDimension(Cell): return ret1, ret2, ret3, ret4 -def Xtest_reduce_dimension(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_reduce_dimension(): net = NetWorkReduceDimension() input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) input_0 = Tensor(input_np) @@ -115,6 +127,11 @@ def Xtest_reduce_dimension(): assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10])) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard class NetWorkSliceStep(Cell): def __init__(self): super(NetWorkSliceStep, self).__init__() @@ -127,12 +144,16 @@ class NetWorkSliceStep(Cell): return ret1, ret2 -def Xtest_step_negative(): - net = NetWorkSliceEllipsis() +@pytest.mark.level0 +# ascend op stridedslice has bug, and has not been fixed. +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_step_negative(): + net = NetWorkSliceStep() input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) input_0 = Tensor(input_np) output1, output2 = net(input_0) - assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10])) + assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 5, 10])) assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5])) @@ -153,8 +174,9 @@ class TensorGetItemByThreeTensors(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def Xtest_getitem_by_tensors(): +def test_getitem_by_tensors(): """This testcase may encounter a sync stream error occasionally""" net = TensorGetItemByThreeTensors() input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) @@ -194,6 +216,7 @@ class TensorGetItemByMixedTensorsBasicCase(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_getitem_by_mixed_tensors(): const0 = np.ones((3, 4, 5, 3), np.float32) @@ -218,6 +241,109 @@ def test_getitem_by_mixed_tensors(): assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5)) +class TensorItemByNone(Cell): + def construct(self, tensor): + ret = tensor.item() + return ret + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_item_by_none(): + net = TensorItemByNone() + input_1d_np = np.ndarray([1]).astype(np.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + output_ms = net(input_1d_ms) + assert np.all(output_ms.asnumpy() == input_1d_np.item()) + + with pytest.raises(ValueError): + net(input_3d_ms) + + +class TensorItemByItem(Cell): + def construct(self, tensor, index): + ret = tensor.item(index) + return ret + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_item_by_int(): + net = TensorItemByItem() + input_1d_np = np.ndarray([1]).astype(np.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + + input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + index_np_1, index_np_2, index_np_3, index_np_4 = 0, 1.0, 30, 60 + + output_1d_ms = net(input_1d_ms, index_np_1) + output_3d_ms_1 = net(input_3d_ms, index_np_1) + output_3d_ms_2 = net(input_3d_ms, index_np_3) + + assert np.all(output_1d_ms.asnumpy() == input_1d_np.item(index_np_1)) + assert np.all(output_3d_ms_1.asnumpy() == input_3d_np.item(index_np_1)) + assert np.all(output_3d_ms_2.asnumpy() == input_3d_np.item(index_np_3)) + + with pytest.raises(TypeError): + net(input_1d_ms, index_np_2) + + with pytest.raises(IndexError): + net(input_1d_ms, index_np_3) + + with pytest.raises(TypeError): + net(input_3d_ms, index_np_2) + + with pytest.raises(IndexError): + net(input_3d_ms, index_np_4) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_item_by_tuple(): + net = TensorItemByItem() + input_1d_np = np.ndarray([1]).astype(np.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + index_np_1 = (0,) + index_np_2 = (1, 2) + index_np_3 = (1, 2, 3) + index_np_4 = (3, 4, 4) + index_np_5 = (1, 2, 3, 4) + + output_1d_ms = net(input_1d_ms, index_np_1) + output_3d_ms = net(input_3d_ms, index_np_3) + assert np.all(output_1d_ms.asnumpy() == input_1d_np.item(index_np_1)) + assert np.all(output_3d_ms.asnumpy() == input_3d_np.item(index_np_3)) + + with pytest.raises(ValueError): + net(input_1d_ms, index_np_2) + + with pytest.raises(ValueError): + net(input_3d_ms, index_np_2) + + with pytest.raises(IndexError): + net(input_3d_ms, index_np_4) + + with pytest.raises(ValueError): + net(input_3d_ms, index_np_5) + + class TensorSetItemByMixedTensors_0(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_0, self).__init__() @@ -236,6 +362,7 @@ class TensorSetItemByMixedTensors_0(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_mixed_tensors_0(): value = 88.0 @@ -253,6 +380,11 @@ def test_setitem_by_mixed_tensors_0(): assert np.all(out.asnumpy() == (input_np + const)) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard class TensorSetItemByMixedTensors_1(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_1, self).__init__() @@ -270,6 +402,7 @@ class TensorSetItemByMixedTensors_1(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_mixed_tensors_1(): value = 88.0 @@ -287,6 +420,11 @@ def test_setitem_by_mixed_tensors_1(): assert np.all(out.asnumpy() == (input_np + const)) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard class TensorSetItemByMixedTensors_2(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_2, self).__init__() @@ -304,6 +442,7 @@ class TensorSetItemByMixedTensors_2(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_mixed_tensors_2(): value = 88.0 @@ -327,7 +466,12 @@ class TensorGetItemByMixedTensorsIndexError(Cell): return ret -def test_getitem_by_mixedtensor_exception(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_getitem_by_mixed_tensor_exception(): input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32) index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32) index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32) @@ -352,6 +496,7 @@ class TensorSetItemByOneTensorWithNumber(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_one_tensor_with_number(): value = 0.0 @@ -380,6 +525,7 @@ class TensorSetItemByOneTensorWithTensor(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_one_tensor_with_tensor(): net = TensorSetItemByOneTensorWithTensor() @@ -410,6 +556,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_one_tensor_with_tuple_number(): value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7) @@ -438,6 +585,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_one_tensor_with_tuple_tensors(): net = TensorSetItemByOneTensorWithTupleOfTensor() @@ -472,7 +620,9 @@ class TensorSetItemByTensorsWithNumber(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard +@pytest.mark.level0 def test_setitem_by_tensors_with_number(): value = 0.0 net = TensorSetItemByTensorsWithNumber(value) @@ -504,6 +654,7 @@ class TensorSetItemByTensorsWithTensor(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_tensors_with_tensor(): net = TensorSetItemByTensorsWithTensor() @@ -537,6 +688,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_tensors_with_tensor_error(): index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32) @@ -565,6 +717,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +# GPU op has bug, and has not been fixed. @pytest.mark.env_onecard def test_setitem_by_tensors_with_tuple_of_number(): value = (0.0, 1.1, 2.2, 3.3, 4.4) @@ -597,6 +750,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +# GPU op has bug, and has not been fixed. @pytest.mark.env_onecard def test_setitem_by_tensors_with_tuple_of_tensor(): value_0 = np.zeros((4, 5)) @@ -634,6 +788,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_setitem_by_tensor_with_tuple_of_tensor_error(): net = TensorSetItemByTensorsWithTupleOfTensorNumberError() @@ -648,6 +803,11 @@ def test_setitem_by_tensor_with_tuple_of_tensor_error(): net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard def test_setitem_grad(): class Net(Cell): def __init__(self): @@ -720,14 +880,15 @@ class TensorAssignWithSlice(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_assign_slice_value_1(): net = TensorAssignWithSlice() a = np.arange(60).reshape(3, 4, 5) - ck = np.arange(60).reshape(3, 4, 5) b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32) - tb = Tensor(b, dtype=mstype.float32) + ck = np.arange(60).reshape(3, 4, 5) ta = Tensor(a, dtype=mstype.float32) + tb = Tensor(b, dtype=mstype.float32) tck = Tensor(ck, dtype=mstype.float32) out = net(ta, tb, tck) a[1:3, ::] = b @@ -745,6 +906,7 @@ def test_tensor_assign_slice_value_1(): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_assign_slice_value_2(): net2 = TensorAssignWithSlice2() @@ -768,6 +930,7 @@ def test_tensor_assign_slice_value_2(): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_assign_exception(): net = TensorAssignWithSlice() @@ -939,6 +1102,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_assign_bool_index_0(): a = np.arange(60).reshape(3, 4, 5) @@ -960,6 +1124,7 @@ def test_tensor_assign_bool_index_0(): @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_assign_bool_index_1(): a = np.arange(60).reshape(3, 4, 5) @@ -977,6 +1142,11 @@ def test_tensor_assign_bool_index_1(): assert np.all(out.asnumpy() == res) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard def test_tensor_assign_bool_index_exception(): a = np.arange(60).reshape(3, 4, 5) b = a > 5 @@ -1015,7 +1185,12 @@ def test_tensor_assign_bool_index_exception(): net4(Ta, u_scalar) -def Xtest_tensor_slice_reduce_out_of_bounds_neg(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_slice_reduce_out_of_bounds_neg(): class NetWork(Cell): def __init__(self): super(NetWork, self).__init__() @@ -1029,11 +1204,15 @@ def Xtest_tensor_slice_reduce_out_of_bounds_neg(): net = NetWork() with pytest.raises(IndexError) as ex: net(input_tensor) - assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str( - ex.value) + assert "begin should be in [-6, 6), but got stride: 1, begin: -7." in str(ex.value) -def Xtest_tensor_slice_reduce_out_of_bounds_positive(): +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_slice_reduce_out_of_bounds_positive(): class NetWork(Cell): def __init__(self): super(NetWork, self).__init__() @@ -1047,12 +1226,13 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive(): net = NetWork() with pytest.raises(IndexError) as ex: net(input_tensor) - assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value) + assert "begin should be in [-6, 6), but got stride: 1, begin: 6." in str(ex.value) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_tensor_range(): a = np.arange(4*5*6).reshape(4, 5, 6).astype(np.float32) diff --git a/tests/st/pynative/test_tensor_setitem.py b/tests/st/pynative/test_tensor_setitem.py index aa031be79c4..2490a20c188 100644 --- a/tests/st/pynative/test_tensor_setitem.py +++ b/tests/st/pynative/test_tensor_setitem.py @@ -18,10 +18,11 @@ import pytest from mindspore import Tensor, context from mindspore.nn import Cell +from mindspore import dtype as mstype def setup_module(): - context.set_context(mode=context.GRAPH_MODE) + context.set_context(mode=context.PYNATIVE_MODE) def setup_testcase(input_np, case_fn): @@ -47,6 +48,7 @@ class TensorSetItemByList(Cell): x[[0, 1], ..., [0, 1]] = 4 return x + class NumpySetItemByList(): def __call__(self, x): x[[0, 1], [1, 2], [1, 3]] = [3, 4] @@ -54,6 +56,7 @@ class NumpySetItemByList(): x[[0, 1], ..., [0, 1]] = 4 return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -61,6 +64,7 @@ class NumpySetItemByList(): @pytest.mark.env_onecard def test_setitem_by_list(): x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): x[[0, 1], [1, 2], [1, 3]] = [3, 4] x[([0, 1], [0, 2], [1, 1])] = [10, 5] @@ -76,6 +80,7 @@ def test_setitem_by_list(): @pytest.mark.env_onecard def test_setitem_with_sequence(): x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): x[...] = [3] x[..., 1] = ([1, 2, 3], [4, 5, 6]) @@ -92,6 +97,7 @@ def test_setitem_with_sequence(): @pytest.mark.env_onecard def test_setitem_dtype(): x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): x[...] = 3 x[..., 1] = 3.0 @@ -108,6 +114,7 @@ def test_setitem_dtype(): @pytest.mark.env_onecard def test_setitem_by_tuple_with_int(): x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): x[..., 2, False, 1] = -1 x[0, True, 0, None, True] = -2 @@ -124,6 +131,7 @@ def test_setitem_by_tuple_with_int(): @pytest.mark.env_onecard def test_setitem_by_tuple_with_list(): x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): x[..., 2, False, 1] = [-1] x[0, True, 0, None, True] = [-2, -2, -2, -2] @@ -141,6 +149,7 @@ def test_setitem_by_tuple_with_list(): @pytest.mark.env_onecard def test_setitem_by_nested_unit_list(): x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): x[[[[0]]], True] = -1 x[[1], ..., [[[[2]]]]] = -2 @@ -158,6 +167,7 @@ def test_setitem_with_broadcast(): x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32) v1 = onp.full((1, 4, 5), -1).tolist() v2 = onp.full((4, 1, 6), -2).tolist() + def cases(x): x[..., 4] = v1 x[0, 2] = v2 @@ -174,6 +184,7 @@ def test_setitem_with_broadcast(): @pytest.mark.env_onecard def test_setitem_mul_by_scalar(): x = onp.ones((4, 5), dtype=onp.float32) + def cases(x): x[1, :] = x[1, :]*2 x[:, 2] = x[:, 3]*3.0 @@ -188,6 +199,7 @@ def test_setitem_mul_by_scalar(): @pytest.mark.env_onecard def test_setitem_by_slice(): x = onp.ones((3, 4, 5), dtype=onp.float32) + def cases(x): x[1:2] = 2 x[-3:1] = 3 @@ -207,6 +219,7 @@ def test_setitem_by_slice(): @pytest.mark.env_onecard def test_setitem_by_tuple_of_slices(): x = onp.ones((3, 4, 5), dtype=onp.float32) + def cases(x): x[1:2, 2] = 2 x[0, -4:1] = 3 @@ -217,6 +230,47 @@ def test_setitem_by_tuple_of_slices(): setup_testcase(x, cases) +class TensorItemSetWithNumber(Cell): + def construct(self, tensor, number_value): + ret = tensor.itemset(number_value) + return ret + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_itemset_with_number(): + net = TensorItemSetWithNumber() + input_1d_np = onp.ndarray([1]).astype(onp.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + + input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + value_np_1, value_np_2 = 1, 2.0 + + output_1d_ms_1 = net(input_1d_ms, value_np_1) + output_1d_ms_2 = net(input_1d_ms, value_np_2) + + input_1d_np.itemset(value_np_1) + assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np) + input_1d_np.itemset(value_np_2) + assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np) + + with pytest.raises(IndexError): + net(input_3d_ms, value_np_1) + with pytest.raises(IndexError): + net(input_3d_ms, value_np_2) + + +class TensorItemSetByItemWithNumber(Cell): + def construct(self, tensor, index, number_value): + ret = tensor.itemset(index, number_value) + return ret + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -231,3 +285,111 @@ def test_setitem_dim_expand(): x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]] return x setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_itemset_by_number_with_number(): + net = TensorItemSetByItemWithNumber() + input_1d_np = onp.ndarray([1]).astype(onp.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + + input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + index_np_1, index_np_2, index_np_3, index_np_4 = 0, 30, 60, 2.0 + value_np_1, value_np_2 = 1, 2.0 + + output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1) + output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2) + output_3d_ms_1 = net(input_3d_ms, index_np_1, value_np_1) + output_3d_ms_2 = net(output_3d_ms_1, index_np_1, value_np_2) + output_3d_ms_3 = net(output_3d_ms_2, index_np_2, value_np_1) + output_3d_ms_4 = net(output_3d_ms_3, index_np_2, value_np_2) + + input_1d_np.itemset(index_np_1, value_np_1) + assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np) + input_1d_np.itemset(index_np_1, value_np_2) + assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np) + input_3d_np.itemset(index_np_1, value_np_1) + assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np) + input_3d_np.itemset(index_np_1, value_np_2) + assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np) + input_3d_np.itemset(index_np_2, value_np_1) + assert onp.all(output_3d_ms_3.asnumpy() == input_3d_np) + input_3d_np.itemset(index_np_2, value_np_2) + assert onp.all(output_3d_ms_4.asnumpy() == input_3d_np) + + with pytest.raises(IndexError): + net(input_1d_ms, index_np_2, value_np_1) + with pytest.raises(IndexError): + net(input_1d_ms, index_np_2, value_np_2) + with pytest.raises(TypeError): + net(input_1d_ms, index_np_4, value_np_1) + with pytest.raises(TypeError): + net(input_1d_ms, index_np_4, value_np_2) + with pytest.raises(IndexError): + net(input_3d_ms, index_np_3, value_np_1) + with pytest.raises(IndexError): + net(input_3d_ms, index_np_3, value_np_2) + with pytest.raises(TypeError): + net(input_3d_ms, index_np_4, value_np_1) + with pytest.raises(TypeError): + net(input_3d_ms, index_np_4, value_np_2) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_itemset_by_tuple_with_number(): + net = TensorItemSetByItemWithNumber() + input_1d_np = onp.ndarray([1]).astype(onp.float32) + input_1d_ms = Tensor(input_1d_np, mstype.float32) + + input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32) + input_3d_ms = Tensor(input_3d_np, mstype.float32) + + index_np_1, index_np_2, index_np_3, index_np_4, index_np_5 = (0,), (1, 2), (1, 1, 0), (3, 4, 5), (1, 2, 3, 4) + value_np_1, value_np_2 = 1, 2.0 + + output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1) + input_1d_np.itemset(index_np_1, value_np_1) + assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np) + + output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2) + input_1d_np.itemset(index_np_1, value_np_2) + assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np) + + output_3d_ms_1 = net(input_3d_ms, index_np_3, value_np_1) + input_3d_np.itemset(index_np_3, value_np_1) + assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np) + + output_3d_ms_2 = net(input_3d_ms, index_np_3, value_np_2) + input_3d_np.itemset(index_np_3, value_np_2) + assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np) + + with pytest.raises(ValueError): + net(input_1d_ms, index_np_2, value_np_1) + with pytest.raises(ValueError): + net(input_1d_ms, index_np_2, value_np_2) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_1, value_np_1) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_1, value_np_2) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_2, value_np_1) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_2, value_np_2) + with pytest.raises(IndexError): + net(input_3d_ms, index_np_4, value_np_1) + with pytest.raises(IndexError): + net(input_3d_ms, index_np_4, value_np_2) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_5, value_np_1) + with pytest.raises(ValueError): + net(input_3d_ms, index_np_5, value_np_2) diff --git a/tests/ut/python/ops/test_tensor_fancy_index.py b/tests/ut/python/ops/test_tensor_getitem.py similarity index 52% rename from tests/ut/python/ops/test_tensor_fancy_index.py rename to tests/ut/python/ops/test_tensor_getitem.py index 383a08e063b..b5dc13651cb 100644 --- a/tests/ut/python/ops/test_tensor_fancy_index.py +++ b/tests/ut/python/ops/test_tensor_getitem.py @@ -20,6 +20,10 @@ from mindspore import Tensor from mindspore import context from mindspore import dtype as mstype from mindspore.nn import Cell +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward \ + import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ + pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception class NetWorkFancyIndex(Cell): @@ -31,6 +35,18 @@ class NetWorkFancyIndex(Cell): return tensor[self.index] +class TensorItemByNone(Cell): + def construct(self, tensor): + ret = tensor.item() + return ret + + +class TensorItemByItem(Cell): + def construct(self, tensor, index): + ret = tensor.item(index) + return ret + + def test_tensor_fancy_index_integer_list(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) index = [0, 2, 1] @@ -102,3 +118,72 @@ def test_tensor_fancy_index_integer_list_tuple_bool_mixed_error(): input_me = Tensor(input_np, dtype=mstype.float32) with pytest.raises(IndexError): net(input_me) + + +input_1d_np = np.ndarray([1]).astype(np.float32) +input_1d_ms = Tensor(input_1d_np, mstype.float32) +input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) +input_3d_ms = Tensor(input_3d_np, mstype.float32) +index_np_1, index_np_2, index_np_3, index_np_4 = 0, 1.0, 30, 60 +tuple_index_np_1, tuple_index_np_2, tuple_index_np_3, tuple_index_np_4, tuple_index_np_5 = \ + (0,), (1, 2), (1, 2, 3), (3, 4, 4), (1, 2, 3, 4) + +test_cases = [ + ('TensorItemByNone', {'block': TensorItemByNone(), 'desc_inputs': [input_1d_ms],}), + ('1dTensorItemByInt', {'block': TensorItemByItem(), 'desc_inputs': [input_1d_ms, index_np_1],}), + ('3dTensorItemByInt', {'block': TensorItemByItem(), 'desc_inputs': [input_3d_ms, index_np_1],}), + ('3dTensorItemByInt2', {'block': TensorItemByItem(), 'desc_inputs': [input_3d_ms, index_np_3],}), + ('1dTensorItemByTuple', {'block': TensorItemByItem(), 'desc_inputs': [input_1d_ms, tuple_index_np_1],}), + ('3dTensorItemByTuple', {'block': TensorItemByItem(), 'desc_inputs': [input_3d_ms, tuple_index_np_3],}), +] + + +test_error_cases = [ + ('TensorItemByNoneForMulDimsTensor', { + 'block': (TensorItemByNone(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms] + }), + ('TensorItemByFloatError', { + 'block': (TensorItemByItem(), {'exception': TypeError}), + 'desc_inputs': [input_1d_ms, index_np_2] + }), + ('TensorItemByFloatError2', { + 'block': (TensorItemByItem(), {'exception': TypeError}), + 'desc_inputs': [input_3d_ms, index_np_2] + }), + ('TensorItemByIntOverBoundary', { + 'block': (TensorItemByItem(), {'exception': IndexError}), + 'desc_inputs': [input_1d_ms, index_np_3] + }), + ('TensorItemByIntOverBoundary2', { + 'block': (TensorItemByItem(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, index_np_4] + }), + ('1dTensorItemBy2dTuple', { + 'block': (TensorItemByItem(), {'exception': ValueError}), + 'desc_inputs': [input_1d_ms, tuple_index_np_2] + }), + ('3dTensorItemBy2dTuple', { + 'block': (TensorItemByItem(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_2] + }), + ('3dTensorItemBy3dTupleOutOfBoundary', { + 'block': (TensorItemByItem(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_4] + }), + ('3dTensorItemBy4dTuple', { + 'block': (TensorItemByItem(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_5] + }) +] + + +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) +def test_exec(): + context.set_context(mode=context.GRAPH_MODE) + return test_cases + + +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +def test_check_exception(): + return test_error_cases diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 5d90331e358..e761c924dae 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -639,6 +639,30 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): return a +class TensorItemSetWithNumber(Cell): + def construct(self, tensor, number_value): + ret = tensor.itemset(number_value) + return ret + + +class TensorItemSetByItemWithNumber(Cell): + def construct(self, tensor, index, number_value): + ret = tensor.itemset(index, number_value) + return ret + + +input_1d_np = np.ndarray([1]).astype(np.float32) +input_1d_ms = Tensor(input_1d_np, mstype.float32) + +input_3d_np = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) +input_3d_ms = Tensor(input_3d_np, mstype.float32) + +index_np_1, index_np_2, index_np_3, index_np_4 = 0, 30, 60, 2.0 +tuple_index_np_1, tuple_index_np_2, tuple_index_np_3, tuple_index_np_4, tuple_index_np_5 = \ + (0,), (1, 2), (1, 2, 3), (3, 4, 4), (1, 2, 3, 4) +value_np_1, value_np_2 = 1, 2.0 + + a = np.arange(60).reshape(3, 4, 5) ck = np.arange(60).reshape(3, 4, 5) a4 = np.arange(60).reshape(3, 2, 2, 5) @@ -934,9 +958,57 @@ test_cases = [ Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], }), + ('1dTensorItemSetWithInt', { + 'block': TensorItemSetWithNumber(), + 'desc_inputs': [input_1d_ms, value_np_1] + }), + ('1dTensorItemSetWithFloat', { + 'block': TensorItemSetWithNumber(), + 'desc_inputs': [input_1d_ms, value_np_2] + }), + ('1dTensorItemSetByIntWithInt', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_1d_ms, index_np_1, value_np_1] + }), + ('1dTensorItemSetByIntWithFloat', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_1d_ms, index_np_1, value_np_2] + }), + ('3dTensorItemSetByIntWithInt', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, index_np_1, value_np_1] + }), + ('3dTensorItemSetByIntWithFloat', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, index_np_1, value_np_2] + }), + ('3dTensorItemSetByIntWithInt2', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, index_np_2, value_np_1] + }), + ('3dTensorItemSetByIntWithFloat2', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, index_np_2, value_np_2] + }), + ('1dTensorItemSetBy1dTupleWithInt', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_1d_ms, tuple_index_np_1, value_np_1] + }), + ('1dTensorItemSetBy1dTupleWithFloat', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_1d_ms, tuple_index_np_1, value_np_2] + }), + ('3dTensorItemSetBy3dTupleWithInt', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, tuple_index_np_3, value_np_1] + }), + ('3dTensorItemSetBy3dTupleWithFloat', { + 'block': TensorItemSetByItemWithNumber(), + 'desc_inputs': [input_3d_ms, tuple_index_np_3, value_np_2] + }), ] -raise_error_set = [ +test_error_cases = [ ('TensorGetItemByOneTensorDtypeError', { 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}), 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), @@ -1137,6 +1209,86 @@ raise_error_set = [ 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32), Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('3dTensorItemSetWithInt', { + 'block': (TensorItemSetWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, value_np_1] + }), + ('3dTensorItemSetWithFloat', { + 'block': (TensorItemSetWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, value_np_2] + }), + ('1dTensorItemSetByOverflowIntWithInt', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_1d_ms, index_np_2, value_np_1] + }), + ('1dTensorItemSetByOverflowIntWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_1d_ms, index_np_2, value_np_2] + }), + ('1dTensorItemSetByFloatWithInt', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': TypeError}), + 'desc_inputs': [input_1d_ms, index_np_4, value_np_1] + }), + ('1dTensorItemSetByFLoatWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': TypeError}), + 'desc_inputs': [input_1d_ms, index_np_4, value_np_2] + }), + ('3dTensorItemSetByOverflowIntWithInt', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, index_np_3, value_np_1] + }), + ('3dTensorItemSetByOverflowIntWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': IndexError}), + 'desc_inputs': [input_3d_ms, index_np_3, value_np_2] + }), + ('3dTensorItemSetByFloatIntWithInt', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': TypeError}), + 'desc_inputs': [input_3d_ms, index_np_4, value_np_1] + }), + ('3dTensorItemSetByFloatWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': TypeError}), + 'desc_inputs': [input_3d_ms, index_np_4, value_np_2] + }), + ('1dTensorItemSetBy2dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_1d_ms, tuple_index_np_2, value_np_1] + }), + ('1dTensorItemSetBy2dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_1d_ms, tuple_index_np_2, value_np_2] + }), + ('3dTensorItemSetBy1dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_1, value_np_1] + }), + ('3dTensorItemSetBy1dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_1, value_np_2] + }), + ('3dTensorItemSetBy2dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_2, value_np_1] + }), + ('3dTensorItemSetBy2dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_2, value_np_2] + }), + ('3dTensorItemSetBy3dTupleOverFloawWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_4, value_np_1] + }), + ('3dTensorItemSetBy3dTupleOverFloawWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_4, value_np_2] + }), + ('3dTensorItemSetBy4dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_5, value_np_1] + }), + ('3dTensorItemSetBy4dTupleWithFloat', { + 'block': (TensorItemSetByItemWithNumber(), {'exception': ValueError}), + 'desc_inputs': [input_3d_ms, tuple_index_np_5, value_np_2] }) ] @@ -1149,7 +1301,7 @@ def test_exec(): @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) def test_check_exception(): - return raise_error_set + return test_error_cases def test_tensor_slice_reduce_out_of_bounds_neg():