!14668 Add more numpy interfaces

From: @yanglf1121
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-23 11:18:34 +08:00 committed by Gitee
commit f0c4b043e8
22 changed files with 5487 additions and 151 deletions

View File

@ -728,6 +728,15 @@ class Validator:
raise ValueError(f"axis {axes} has shape entry {s} > 1, cannot be squeezed.")
return tuple(new_shape)
@staticmethod
def check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
if not isinstance(axis, int):
raise TypeError(f'axes should be integers, not {type(axis)}')
if not -ndim <= axis < ndim:
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
return axis % ndim
def check_input_format(input_param):
"""Judge input format."""

View File

@ -148,7 +148,34 @@ def strides_(x):
def astype(x, dtype, copy=True):
"""Implementation of `astype`."""
"""
Return a copy of the tensor, casted to a specified type.
Args:
dtype (Union[:class:`mindspore.dtype`, str]): Designated tensor dtype, can be in format
of :class:`mindspore.dtype.float32` or `float32`.
Default: :class:`mindspore.dtype.float32`.
copy (bool, optional): By default, astype always returns a newly allocated
tensor. If this is set to false, the input tensor is returned instead
of a copy if possible. Default: True.
Returns:
Tensor, with the designated dtype.
Raises:
TypeError: If `dtype` has types not specified above, or values cannot be understood.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,2,1), dtype=np.float32))
>>> x = x.astype("int32")
>>> print(x.dtype)
Int32
"""
dtype = check_astype_dtype_const(dtype)
if not copy and dtype == x.dtype:
return x
@ -156,7 +183,40 @@ def astype(x, dtype, copy=True):
def transpose(x, *axis):
"""Implementation of `transpose`."""
r"""
Return a view of the tensor with axes transposed.
For a 1-D tensor this has no effect, as a transposed vector is simply the
same vector. For a 2-D tensor, this is a standard matrix transpose. For a
n-D tensor, if axes are given, their order indicates how the axes are permuted.
If axes are not provided and tensor.shape = (i[0], i[1],...i[n-2], i[n-1]),
then tensor.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]).
Args:
axes(Union[None, tuple(int), list(int), int], optional): If axes is None or
blank, tensor.transpose() will reverse the order of the axes. If axes is tuple(int)
or list(int), tensor.transpose() will transpose the tensor to the new axes order.
If axes is int, this form is simply intended as a convenience alternative to the
tuple/list form.
Returns:
Tensor, has the same dimension as input tensor, with axes suitably permuted.
Raises:
TypeError: If input arguments have types not specified above.
ValueError: If the number of `axes` is not euqal to a.ndim.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,3), dtype=np.float32))
>>> x = x.transpose()
>>> print(x.shape)
(3, 2, 1)
"""
ndim = F.rank(x)
perm = check_transpose_axis_const(axis, ndim)
return F.transpose(x, perm)
@ -167,27 +227,86 @@ T_ = transpose
def reshape(x, *shape):
"""Implementation of `reshape`."""
"""
Give a new shape to a tensor without changing its data.
Args:
shape(Union[int, tuple(int), list(int)]): The new shape should be compatible
with the original shape. If an integer, then the result will be a 1-D
array of that length. One shape dimension can be -1. In this case, the
value is inferred from the length of the array and remaining dimensions.
Returns:
Tensor, with new specified shape.
Raises:
TypeError: If new_shape is not integer, list or tuple, or `x` is not tensor.
ValueError: If new_shape is not compatible with the original shape.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=mstype.float32)
>>> output = np.reshape(x, (3, 2))
>>> print(output)
[[-0.1 0.3]
[ 3.6 0.4]
[ 0.5 -3.2]]
"""
new_shape = check_reshape_shp_const(shape)
return F.reshape(x, new_shape)
def ravel(x):
"""Implementation of `ravel`."""
"""
Return a contiguous flattened tensor.
Returns:
Tensor, a 1-D tensor, containing the same elements of the input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = x.ravel()
>>> print(output.shape)
(24,)
"""
return reshape(x, (-1,))
def flatten(x, order='C'):
"""
Returns a copy of the array collapsed into one dimension.
r"""
Return a copy of the tensor collapsed into one dimension.
Args:
order (str, optional): Can choose between `C` and `F`. `C` means to
flatten in row-major (C-style) order. F means to flatten in column-major
(Fortran- style) order. Only `C` and `F` are supported.
order (str, optional): Can choose between 'C' and 'F'. 'C' means to
flatten in row-major (C-style) order. 'F' means to flatten in column-major
(Fortran-style) order. Only 'C' and 'F' are supported. Default: 'C'.
Returns:
Tensor, has the same data type as x.
Tensor, has the same data type as input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
TypeError: If `order` is not string type.
ValueError: If `order` is string type, but not 'C' or 'F'.
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = x.flatten()
>>> print(output.shape)
(24,)
"""
order = check_flatten_order_const(order)
if order == 'C':
@ -200,14 +319,29 @@ def flatten(x, order='C'):
def swapaxes(x, axis1, axis2):
"""
Interchanges two axes of a tensor.
Interchange two axes of a tensor.
Args:
axis1 (int): First axis.
axis2 (int): Second axis.
Returns:
Transposed tensor, has the same data type as the original tensor x.
Transposed tensor, has the same data type as the input.
Raises:
TypeError: If `axis1` or `axis2` is not integer.
ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = np.swapaxes(x, 0, 2)
>>> print(output.shape)
(4,3,2)
"""
axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)
@ -230,13 +364,28 @@ def swapaxes(x, axis1, axis2):
def squeeze(x, axis=None):
"""
Removes single-dimensional entries from the shape of an tensor.
Remove single-dimensional entries from the shape of a tensor.
Args:
axis: Union[None, int, list(int), tuple(list)]. Default is None.
axis (Union[None, int, list(int), tuple(int)], optional): Default is None.
Returns:
Tensor, with all or a subset of the dimensions of length 1 removed.
Raises:
TypeError: If input arguments have types not specified above.
ValueError: If specified axis has shape entry :math:`> 1`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,2,1), dtype=np.float32))
>>> x = x.squeeze()
>>> print(x.shape)
(2, 2)
"""
shape = F.shape(x)
if axis is None:
@ -246,6 +395,78 @@ def squeeze(x, axis=None):
return F.reshape(x, new_shape)
def argmax(x, axis=None):
"""
Returns the indices of the maximum values along an axis.
Args:
axis (int, optional): By default, the index is into
the flattened array, otherwise along the specified axis.
Returns:
Tensor, array of indices into the array. It has the same
shape as a.shape with the dimension along axis removed.
Raises:
ValueError: if axis is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
>>> print(np.argmax(a))
5
"""
# P.Argmax only supports float
x = x.astype(mstype.float32)
if axis is None:
x = ravel(x)
axis = 0
else:
axis = check_axis_in_range_const(axis, F.rank(x))
return P.Argmax(axis)(x)
def argmin(x, axis=None):
"""
Returns the indices of the minimum values along an axis.
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input array.
axis (int, optional): By default, the index is into
the flattened array, otherwise along the specified axis.
Returns:
Tensor, array of indices into the array. It has the same
shape as a.shape with the dimension along axis removed.
Raises:
ValueError: if axis is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
>>> print(np.argmin(a))
0
"""
# P.Argmax only supports float
x = x.astype(mstype.float32)
if axis is None:
x = ravel(x)
axis = 0
else:
axis = check_axis_in_range_const(axis, F.rank(x))
# P.Argmin is currently not supported
return P.Argmax(axis)(F.neg_tensor(x))
def getitem(data, item):
"""Implementation of `getitem`."""
return data.__getitem__(item)
@ -466,6 +687,7 @@ check_reshape_shp_const = constexpr(validator.check_reshape_shp)
check_flatten_order_const = constexpr(validator.check_flatten_order)
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
check_axis_in_range_const = constexpr(validator.check_axis_in_range)
def tensor_bool(x):

View File

@ -184,6 +184,8 @@ BuiltInTypeMap &GetMethodMap() {
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"astype", std::string("astype")}, // P.cast()
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
{"argmax", std::string("argmax")}, // P.Argmax()
{"argmin", std::string("argmin")}, // P.Argmax()
}},
{kObjectTypeRowTensorType,
{

View File

@ -466,6 +466,21 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same dimension as input tensor, with axes suitably permuted.
Raises:
TypeError: If input arguments have types not specified above.
ValueError: If the number of `axes` is not euqal to a.ndim.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,3), dtype=np.float32))
>>> x = x.transpose()
>>> print(x.shape)
(3, 2, 1)
"""
self.init_check()
perm = validator.check_transpose_axis(axes, self.ndim)
@ -483,6 +498,23 @@ class Tensor(Tensor_):
Returns:
Tensor, with new specified shape.
Raises:
TypeError: If new_shape is not integer, list or tuple, or `x` is not tensor.
ValueError: If new_shape is not compatible with the original shape.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=mstype.float32)
>>> output = np.reshape(x, (3, 2))
>>> print(output)
[[-0.1 0.3]
[ 3.6 0.4]
[ 0.5 -3.2]]
"""
self.init_check()
new_shape = validator.check_reshape_shp(shape)
@ -494,6 +526,17 @@ class Tensor(Tensor_):
Returns:
Tensor, a 1-D tensor, containing the same elements of the input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = x.ravel()
>>> print(output.shape)
(24,)
"""
self.init_check()
reshape_op = tensor_operator_registry.get('reshape')()
@ -510,6 +553,21 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same data type as input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
TypeError: If `order` is not string type.
ValueError: If `order` is string type, but not 'C' or 'F'.
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = x.flatten()
>>> print(output.shape)
(24,)
"""
self.init_check()
reshape_op = tensor_operator_registry.get('reshape')()
@ -532,6 +590,21 @@ class Tensor(Tensor_):
Returns:
Transposed tensor, has the same data type as the input.
Raises:
TypeError: If `axis1` or `axis2` is not integer.
ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = np.swapaxes(x, 0, 2)
>>> print(output.shape)
(4,3,2)
"""
self.init_check()
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
@ -561,6 +634,21 @@ class Tensor(Tensor_):
Returns:
Tensor, with all or a subset of the dimensions of length 1 removed.
Raises:
TypeError: If input arguments have types not specified above.
ValueError: If specified axis has shape entry :math:`> 1`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,2,1), dtype=np.float32))
>>> x = x.squeeze()
>>> print(x.shape)
(2, 2)
"""
self.init_check()
if axis is None:
@ -582,6 +670,20 @@ class Tensor(Tensor_):
Returns:
Tensor, with the designated dtype.
Raises:
TypeError: If `dtype` has types not specified above, or values cannot be understood.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((1,2,2,1), dtype=np.float32))
>>> x = x.astype("int32")
>>> print(x.dtype)
Int32
"""
self.init_check()
dtype = validator.check_astype_dtype(dtype)
@ -589,6 +691,77 @@ class Tensor(Tensor_):
return self
return tensor_operator_registry.get('cast')(self, dtype)
def argmax(self, axis=None):
"""
Returns the indices of the maximum values along an axis.
Args:
axis (int, optional): By default, the index is into
the flattened array, otherwise along the specified axis.
Returns:
Tensor, array of indices into the array. It has the same
shape as a.shape with the dimension along axis removed.
Raises:
ValueError: if axis is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
>>> print(np.argmax(a))
5
"""
# P.Argmax only supports float
a = self.astype(mstype.float32)
if axis is None:
a = a.ravel()
axis = 0
else:
axis = validator.check_axis_in_range(axis, a.ndim)
return tensor_operator_registry.get('argmax')(axis)(a)
def argmin(self, axis=None):
"""
Returns the indices of the minimum values along an axis.
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input array.
axis (int, optional): By default, the index is into
the flattened array, otherwise along the specified axis.
Returns:
Tensor, array of indices into the array. It has the same
shape as a.shape with the dimension along axis removed.
Raises:
ValueError: if axis is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
>>> print(np.argmin(a))
0
"""
# P.Argmax only supports float
a = self.astype(mstype.float32)
if axis is None:
a = a.ravel()
axis = 0
else:
axis = validator.check_axis_in_range(axis, a.ndim)
# P.Argmin is currently not supported
return tensor_operator_registry.get('argmax')(axis)(tensor_operator_registry.get('__neg__')(a))
def init_check(self):
if self.has_init:
self.init_data()

View File

@ -31,15 +31,18 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res
column_stack, hstack, dstack, vstack, stack, unique, moveaxis,
tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit,
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat,
rot90, select, array_split)
rot90, select, array_split, choose, size, array_str, apply_along_axis,
piecewise, unravel_index, apply_over_axes)
from .array_creations import copy_ as copy
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
linspace, logspace, eye, identity, empty, empty_like,
ones_like, zeros_like, full_like, diagonal, tril, triu,
tri, trace, meshgrid, mgrid, ogrid, diagflat,
diag, diag_indices, ix_, indices, geomspace, vander)
diag, diag_indices, ix_, indices, geomspace, vander, hamming,
hanning, bartlett, blackman, triu_indices, tril_indices,
triu_indices_from, tril_indices_from, histogram_bin_edges, pad)
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan,
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, pi,
numeric_types, PINF, NINF)
from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide, power,
dot, outer, tensordot, absolute, std, var, average, minimum,
@ -50,31 +53,43 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide
cross, ceil, trapz, gcd, lcm, convolve, log1p, logaddexp, log2,
logaddexp2, log10, ediff1d, nansum, nanmean, nanvar, nanstd, cumsum, nancumsum,
sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh,
arctanh, arctan2, cov)
arctanh, arctan2, cov, multi_dot, nanmax, nanmin, argmax, argmin, searchsorted,
interp, sum_, corrcoef, gradient, sign, copysign, digitize, bincount, histogram,
histogramdd, histogram2d, matrix_power, around, polyadd, polysub, polyval,
polyder, polymul, polyint, result_type, unwrap, cumprod, ravel_multi_index,
norm, bitwise_and, bitwise_or, bitwise_xor, invert, rint, correlate, radians)
from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite,
isnan, isinf, isposinf, isneginf, isscalar, logical_and, logical_not,
logical_or, logical_xor, in1d, isin, isclose)
logical_or, logical_xor, in1d, isin, isclose, signbit, sometrue,
array_equal, array_equiv)
mod = remainder
fabs = absolute
round = around # pylint: disable=redefined-builtin
divmod = divmod_ # pylint: disable=redefined-builtin
del divmod_
abs = absolute # pylint: disable=redefined-builtin
max = amax # pylint: disable=redefined-builtin
min = amin # pylint: disable=redefined-builtin
sum = sum_ # pylint: disable=redefined-builtin
del sum_
bitwise_not = invert
array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis',
'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit',
'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take',
'repeat', 'rot90', 'select', 'array_split']
'repeat', 'rot90', 'select', 'array_split', 'choose', 'size', 'array_str',
'apply_along_axis', 'piecewise', 'unravel_index', 'apply_over_axes']
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
'diag_indices', 'ix_', 'indices', 'geomspace', 'vander']
'diag_indices', 'ix_', 'indices', 'geomspace', 'vander', 'hamming',
'hanning', 'bartlett', 'blackman', 'triu_indices', 'tril_indices',
'triu_indices_from', 'tril_indices_from', 'histogram_bin_edges', 'pad']
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power',
'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal',
@ -86,11 +101,17 @@ math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_d
'abs', 'max', 'min', 'gcd', 'lcm', 'log1p', 'logaddexp', 'log2', 'logaddexp2', 'log10',
'convolve', 'ediff1d', 'nansum', 'nanmean', 'nanvar', 'nanstd', 'cumsum',
'nancumsum', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh',
'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov']
'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov', 'multi_dot', 'nanmax', 'nanmin',
'argmax', 'argmin', 'searchsorted', 'interp', 'sum', 'corrcoef', 'gradient', 'sign',
'copysign', 'radians', 'digitize', 'bincount', 'histogram', 'histogramdd', 'histogram2d',
'polyadd', 'polysub', 'polyval', 'polyder', 'polymul', 'polyint', 'result_type',
'unwrap', 'cumprod', 'ravel_multi_index', 'norm', 'bitwise_and', 'bitwise_or',
'bitwise_xor', 'invert', 'bitwise_not', 'rint', "correlate"]
logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite',
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar', 'logical_and', 'logical_not',
'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose']
'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose', 'signbit', 'sometrue',
'array_equal', 'array_equiv']
__all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types

View File

@ -13,10 +13,14 @@
# limitations under the License.
# ============================================================================
"""array operations, the function docs are adapted from Numpy API."""
import math
import operator
import numpy as onp
from ..common import Tensor
from ..common import dtype as mstype
from ..ops import operations as P
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..nn.layer.basic import tril as nn_tril
@ -25,13 +29,15 @@ from .._c_expression import Tensor as Tensor_
from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
_expand
_expand, _to_tensor, _slice_along_axis, _callable
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
_raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_getitem, _tuple_slice
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
from .dtypes import nan
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_slice, _raise_unimplemented_error, \
_tuple_setitem
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \
apply_along_axis, where
from .dtypes import nan, pi
# According to official numpy reference, the dimension of a numpy array must be less
# than 32
@ -39,6 +45,9 @@ MAX_NUMPY_DIMS = 32
# All types that can be accepted as "array_like" parameters in graph mode.
ARRAY_TYPES = (int, float, bool, list, tuple, Tensor)
_reduce_min_keepdims = P.ReduceMin(True)
_reduce_max_keepdims = P.ReduceMax(True)
_reduce_mean_keepdims = P.ReduceMean(True)
def array(obj, dtype=None, copy=True, ndmin=0):
"""
@ -255,6 +264,7 @@ def copy_(a):
a = a.astype(origin_dtype)
return a
def ones(shape, dtype=mstype.float32):
"""
Returns a new tensor of given shape and type, filled with ones.
@ -626,7 +636,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
num, endpoint=endpoint, dtype=dtype, axis=axis)
shape = F.shape(bases)
axis = axis + F.rank(bases) + 1 if axis < 0 else axis
expanded_shape = _tuple_getitem(shape, axis, False) + (1,) + _tuple_getitem(shape, axis)
expanded_shape = _tuple_slice(shape, None, axis) + (1,) + _tuple_slice(shape, axis, None)
bases = F.reshape(bases, expanded_shape)
start = F.reshape(start, expanded_shape)
res = F.tensor_mul(F.tensor_pow(bases, exponents), start)
@ -1768,7 +1778,7 @@ def indices(dimensions, dtype=mstype.int32, sparse=False):
Args:
dimensions (tuple or list of ints): The shape of the grid.
dtype (data type, optional): Data type of the result.
dtype (:class:`mindspore.dtype`, optional): Data type of the result.
sparse (boolean, optional): Defaults to False. Return a sparse
representation of the grid instead of a dense representation.
@ -1801,3 +1811,724 @@ def indices(dimensions, dtype=mstype.int32, sparse=False):
for d in dimensions:
grids += (arange(d, dtype=dtype),)
return meshgrid(*grids, sparse=sparse, indexing='ij')
def _check_window_size(x):
"""Returns True if window size is greater than 1."""
if not isinstance(x, int):
_raise_type_error('the number fo points should be an int')
return x > 1
def bartlett(M):
"""
Returns the Bartlett window.
The Bartlett window is very similar to a triangular window, except that the
end points are at zero. It is often used in signal processing for tapering a
signal, without generating too much ripple in the frequency domain.
Args:
M (int): Number of points in the output window. If zero or less, an empty
array is returned.
Returns:
Tensor, the triangular window, with the maximum value normalized to one
(the value one appears only if the number of samples is odd), with the
first and last samples equal to zero.
Raises:
TypeError: if `M` is not an int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.bartlett(12))
[0. 0.18181819 0.36363637 0.5454545 0.72727275 0.9090909
0.9090909 0.72727275 0.5454545 0.36363637 0.18181819 0. ]
"""
if not _check_window_size(M):
return ones(_max(0, M))
n = _iota(mstype.float32, M)
m_minus_one = _to_tensor(M - 1)
return _to_tensor(1) - F.absolute(_to_tensor(2)*n - m_minus_one)/m_minus_one
def blackman(M):
"""
Returns the Blackman window.
The Blackman window is a taper formed by using the first three terms of a
summation of cosines. It was designed to have close to the minimal leakage
possible. It is close to optimal, only slightly worse than a Kaiser window.
Args:
M (int): Number of points in the output window. If zero or less, an empty
array is returned.
Returns:
Tensor, the window, with the maximum value normalized to one (the value
one appears only if the number of samples is odd).
Raises:
TypeError: if `M` is not an int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.hamming(12))
[0.08000001 0.15302339 0.34890914 0.6054648 0.841236 0.9813669
0.9813668 0.8412359 0.6054647 0.34890908 0.15302327 0.08000001]
"""
if not _check_window_size(M):
return ones(_max(0, M))
n_doubled = arange(1 - M, M, 2, dtype=mstype.float32)
return (_to_tensor(0.42) + _to_tensor(0.5)*F.cos(_to_tensor(pi/(M - 1))*n_doubled) +
_to_tensor(0.08)*F.cos(_to_tensor(2*pi/(M - 1))*n_doubled))
def hamming(M):
"""
Returns the Hamming window.
The Hamming window is a taper formed by using a weighted cosine.
Args:
M (int): Number of points in the output window. If zero or less, an empty
array is returned.
Returns:
Tensor, the window, with the maximum value normalized to one (the value
one appears only if the number of samples is odd).
Raises:
TypeError: if `M` is not an int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.hamming(12))
[0.08000001 0.15302339 0.34890914 0.6054648 0.841236 0.9813669
0.9813668 0.8412359 0.6054647 0.34890908 0.15302327 0.08000001]
"""
if not _check_window_size(M):
return ones(_max(0, M))
n = _iota(mstype.float32, M)
return _to_tensor(0.54) - _to_tensor(0.46)*F.cos(_to_tensor(2*pi/(M - 1))*n)
def hanning(M):
"""
Returns the Hanning window.
The Hanning window is a taper formed by using a weighted cosine.
Args:
M (int): Number of points in the output window. If zero or less, an empty
array is returned.
Returns:
Tensor, the window, with the maximum value normalized to one (the value
one appears only if the number of samples is odd).
Raises:
TypeError: if `M` is not an int.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.hanning(12))
[0. 0.07937324 0.29229254 0.5711574 0.8274304 0.9797465
0.97974646 0.82743025 0.5711573 0.29229245 0.07937312 0. ]
"""
if not _check_window_size(M):
return ones(_max(0, M))
n = _iota(mstype.float32, M)
return _to_tensor(0.5) - _to_tensor(0.5)*F.cos(_to_tensor(2*pi/(M - 1))*n)
@constexpr
def tri_indices(n, k=0, m=None, upper=True):
"""Returns triu/tril indices in o(nm) time."""
if not isinstance(n, (int, float, bool)):
raise TypeError("Input n must be a number.")
if not isinstance(k, (int, float, bool)):
raise TypeError("Input k must be a number.")
if m is None:
m = n
elif not isinstance(m, (int, float, bool)):
raise TypeError("Input m must be a number.")
if upper:
compare = operator.ge
else:
compare = operator.le
x_coordinate = []
y_coordinate = []
# math.ceil is used to match numpy's behaviour
for i in range(math.ceil(n)):
curr_limit = i + k
for j in range(math.ceil(m)):
if compare(j, curr_limit):
x_coordinate.append(i)
y_coordinate.append(j)
return asarray_const(x_coordinate), asarray_const(y_coordinate)
def triu_indices(n, k=0, m=None):
"""
Returns the indices for the upper-triangle of an (n, m) array.
Args:
n (int): The size of the arrays for which the returned indices will be valid.
k (int, optional): Diagonal offset.
m (int, optional): The column dimension of the arrays for which the returned
arrays will be valid. By default `m` is taken equal to `n`.
Returns:
The indices for the triangle. The returned tuple contains two tensors, each
with the indices along one dimension of the tensor.
Raises:
TypeError: if `n`, `k`, `m` are not numbers.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.triu_indices(3))
(Tensor(shape=[6], dtype=Int32, value= [0, 0, 0, 1, 1, 2]),
Tensor(shape=[6], dtype=Int32, value= [0, 1, 2, 1, 2, 2]))
"""
return tri_indices(n, k, m, True)
def tril_indices(n, k=0, m=None):
"""
Returns the indices for the lower-triangle of an (n, m) array.
Args:
n (int): The size of the arrays for which the returned indices will be valid.
k (int, optional): Diagonal offset.
m (int, optional): The column dimension of the arrays for which the returned
arrays will be valid. By default `m` is taken equal to `n`.
Returns:
The indices for the triangle. The returned tuple contains two tensors, each
with the indices along one dimension of the tensor.
Raises:
TypeError: if `n`, `k`, `m` are not numbers.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> print(np.tril_indices(3))
(Tensor(shape=[6], dtype=Int32, value= [0, 1, 1, 2, 2, 2]),
Tensor(shape=[6], dtype=Int32, value= [0, 0, 1, 0, 1, 2]))
"""
return tri_indices(n, k, m, False)
def triu_indices_from(arr, k=0):
"""
Returns the indices for the upper-triangle of `arr`.
Args:
arr (Union[Tensor, list, tuple]): 2-dimensional array.
k (int, optional): Diagonal offset.
Returns:
triu_indices_from, tuple of 2 tensor, shape(N)
Indices for the upper-triangle of `arr`.
Raises:
TypeError: if `arr` cannot be converted to tensor, or `k` is not a number.
ValueError: if `arr` cannot be converted to a 2-dimensional tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> tensor = np.ones((3,3))
>>> print(np.triu_indices_from(tensor))
(Tensor(shape=[6], dtype=Int32, value= [0, 0, 0, 1, 1, 2]),
Tensor(shape=[6], dtype=Int32, value= [0, 1, 2, 1, 2, 2]))
"""
arr = asarray(arr)
if arr.ndim != 2:
_raise_value_error("input array must be 2-d")
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])
def tril_indices_from(arr, k=0):
"""
Returns the indices for the lower-triangle of `arr`.
Args:
arr (Union[Tensor, list, tuple]): 2-dimensional array.
k (int, optional): Diagonal offset.
Returns:
triu_indices_from, tuple of 2 tensor, shape(N)
Indices for the upper-triangle of `arr`.
Raises:
TypeError: if `arr` cannot be converted to tensor, or `k` is not a number.
ValueError: if `arr` cannot be converted to a 2-dimensional tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> tensor = np.ones((3,3))
>>> print(np.tril_indices_from(tensor))
(Tensor(shape=[6], dtype=Int32, value= [0, 1, 1, 2, 2, 2]),
Tensor(shape=[6], dtype=Int32, value= [0, 0, 1, 0, 1, 2]))
"""
arr = asarray(arr)
if arr.ndim != 2:
_raise_value_error("input array must be 2-d")
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable=redefined-builtin
"""
Function to calculate only the edges of the bins used by the histogram function.
Note:
String values for `bins` is not supported.
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input data. The histogram
is computed over the flattened array.
bins ((Union[int, tuple, list, Tensor])): If `bins` is an int, it defines the number
of equal-width bins in the given range (10, by default). If `bins` is a
sequence, it defines the bin edges, including the rightmost edge,
allowing for non-uniform bin widths.
range((float, float), optional): The lower and upper range of the bins. If
not provided, `range` is simply ``(a.min(), a.max())``. Values outside
the range are ignored. The first element of the range must be less than
or equal to the second.
Returns:
Tensor, the edges to pass into `histogram`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
TypeError: if `bins` is an array and not one-dimensional.
Examples:
>>> import mindspore.numpy as np
>>> arr = np.array([0, 0, 0, 1, 2, 3, 3, 4, 5])
>>> print(np.histogram_bin_edges(arr, bins=2))
[0. 2.5 5. ]
"""
if isinstance(bins, (tuple, list, Tensor)):
bins = _to_tensor(bins)
if F.rank(bins) != 1:
_raise_value_error('`bins` must be 1d, when an array')
return bins
if isinstance(bins, str):
# linspace does not support Tensor for num
_raise_unimplemented_error('string value for `bins` not implemented')
a = _to_tensor(a).ravel().astype(mstype.float32)
if range is None:
start = F.reduce_min(a)
end = F.reduce_max(a)
else:
start, end = _to_tensor(*range)
no_range = (end - start) == 0
start = where(no_range, start - 0.5, start)
end = where(no_range, end + 0.5, end)
return linspace(start, end, bins + 1)
def _pad_empty(arr, pad_width):
"""
pads the array with constant values, used in mode: "empty"
"""
dtype = arr.dtype
for i in range(arr.ndim):
shape = arr.shape
pad_before = ()
pad_after = ()
# To avoid any memory issues, we don't make tensor with 0s in their shapes
if pad_width[i][0] > 0:
pad_before += (empty(_tuple_setitem(shape, i, pad_width[i][0]), dtype=dtype),)
if pad_width[i][1] > 0:
pad_after += (empty(_tuple_setitem(shape, i, pad_width[i][1]), dtype=dtype),)
tensor_with_pad = pad_before + (arr,) + pad_after
arr = concatenate(tensor_with_pad, axis=i)
return arr
def _pad_constant(arr, pad_width, value):
"""
pads the array with constant values, used in mode: "constant"
"""
dtype = arr.dtype
for i in range(arr.ndim):
shape = arr.shape
pad_before = ()
pad_after = ()
# To avoid any memory issues, we don't make tensor with 0s in their shapes
if pad_width[i][0] > 0:
pad_before += (full(_tuple_setitem(shape, i, pad_width[i][0]), value[i][0], dtype=dtype),)
if pad_width[i][1] > 0:
pad_after += (full(_tuple_setitem(shape, i, pad_width[i][1]), value[i][1], dtype=dtype),)
tensor_with_pad = pad_before + (arr,) + pad_after
arr = concatenate(tensor_with_pad, axis=i)
return arr
def _pad_statistic(arr, pad_width, stat_length, stat_op):
"""
pads the array with values calculated along the given axis, used in mode: "maximum",
"minimum", "mean"
"""
ndim = arr.ndim
shape = arr.shape
if stat_length is None:
stat_length = _make_stat_length(shape)
else:
stat_length = _convert_pad_to_nd(stat_length, ndim)
stat_length = _limit_stat_length(stat_length, shape)
for i in range(ndim):
pad_before = stat_op(_slice_along_axis(arr, i, 0, stat_length[i][0]), i)
pad_before = (F.tile(pad_before, _tuple_setitem((1,)*ndim, i, pad_width[i][0])),)
pad_after = stat_op(_slice_along_axis(arr, i, shape[i]-stat_length[i][1], shape[i]), i)
pad_after = (F.tile(pad_after, _tuple_setitem((1,)*ndim, i, pad_width[i][1])),)
tensor_with_pad = pad_before + (arr,) + pad_after
arr = concatenate(tensor_with_pad, axis=i)
return arr
def _pad_edge(arr, pad_width):
"""pad_edge is equivalent to pad_statistic with stat_lenght=1, used in mode:"edge"."""
def identity_op(arr, axis):
return arr
return _pad_statistic(arr, pad_width, 1, identity_op)
def _pad_wrap(arr, pad_width):
"""The behaviour of wrap mode is consistent with jax.numpy, used in mode:"wrap"."""
ndim = arr.ndim
shape = arr.shape
for i in range(ndim):
padsize_before = pad_width[i][0] % shape[i]
padsize_after = pad_width[i][1] % shape[i]
total_repeats = pad_width[i][0] // shape[i] + 1 + pad_width[i][1] // shape[i]
tensor_with_pad = ()
# To avoid any memory issues, we don't make tensor with 0s in their shapes
if padsize_before > 0:
tensor_with_pad += (_slice_along_axis(arr, i, shape[i]-padsize_before, shape[i]),)
tensor_with_pad += (F.tile(arr, _tuple_setitem((1,)*ndim, i, total_repeats)),)
if padsize_after > 0:
tensor_with_pad += (_slice_along_axis(arr, i, 0, padsize_after),)
arr = concatenate(tensor_with_pad, axis=i)
return arr
def _pad_linear(arr, pad_width, end_values):
"""Pads the arr with linear range values, used in mode: "linear_ramp"."""
ndim = arr.ndim
shape = arr.shape
dtype = arr.dtype
end_values = _convert_pad_to_nd(end_values, ndim)
for i in range(ndim):
# shape [..., 1, ...]
left_value = _slice_along_axis(arr, i, 0, 1)
right_value = _slice_along_axis(arr, i, shape[i]-1, shape[i])
pad_before = ()
pad_after = ()
if pad_width[i][0] > 0:
# shape [..., pad_width[i][0], ...]
pad_before = (linspace(end_values[i][0], left_value, num=pad_width[i][0],
endpoint=False, dtype=dtype, axis=i).squeeze(i+1),)
if pad_width[i][1] > 0:
# shape [..., pad_width[i][1], ...]
pad_after = linspace(right_value, end_values[i][1], num=pad_width[i][1]+1,
endpoint=True, dtype=dtype, axis=i).squeeze(i+1)
pad_after = (_slice_along_axis(pad_after, i, 1, pad_width[i][1]+1),)
tensor_with_pad = pad_before + (arr,) + pad_after
arr = concatenate(tensor_with_pad, axis=i)
return arr
def _pad_symmetric(arr, pad_width, reflect_type):
"""pad the array with symmetric paddings"""
for i in range(arr.ndim):
array_length = arr.shape[i]
has_pad_before = (pad_width[i][0] > 0)
has_pad_after = (pad_width[i][1] > 0)
edge_before = _slice_along_axis(arr, i, 0, 1)
edge_end = _slice_along_axis(arr, i, array_length-1, array_length)
times_to_pad_before = pad_width[i][0] // array_length + 1
additional_pad_before = pad_width[i][0] % array_length
times_to_pad_after = pad_width[i][1] // array_length + 1
additional_pad_after = pad_width[i][1] % array_length
curr_pad = None
if has_pad_before:
# Deal with paddings before the original array
for times in range(times_to_pad_before):
if times < times_to_pad_before - 1:
endpoint = array_length
else:
endpoint = additional_pad_before
if endpoint != 0:
curr_pad = _slice_along_axis(arr, i, 0, endpoint)
curr_pad = flip(curr_pad, axis=i)
if reflect_type == "odd":
curr_pad = 2 * edge_before - curr_pad
arr = P.Concat(i)((curr_pad, arr))
edge_before = _slice_along_axis(arr, i, 0, 1)
if has_pad_after:
# Deal with paddings after the original array
for times in range(times_to_pad_after):
if times < times_to_pad_after - 1:
startpoint = arr.shape[i] - array_length
else:
startpoint = arr.shape[i] - additional_pad_after
if startpoint != arr.shape[i]:
curr_pad = _slice_along_axis(arr, i, startpoint, arr.shape[i])
curr_pad = flip(curr_pad, axis=i)
if reflect_type == "odd":
curr_pad = 2 * edge_end - curr_pad
arr = P.Concat(i)((arr, curr_pad))
edge_end = _slice_along_axis(arr, i, arr.shape[i]-1, arr.shape[i])
return arr
def _pad_reflect(arr, pad_width, reflect_type):
"""
pad the array with reflect paddings, this is very similar to symmetric paddings,
but differs at how edges are selected.
"""
# pylint: disable=too-many-nested-blocks
for i in range(arr.ndim):
array_length = arr.shape[i]
if array_length == 1:
total_repeats = pad_width[i][0] + pad_width[i][1] + 1
arr = F.tile(arr, _tuple_setitem((1,)*arr.ndim, i, total_repeats))
else:
has_pad_before = (pad_width[i][0] > 0)
has_pad_after = (pad_width[i][1] > 0)
edge_before = _slice_along_axis(arr, i, 0, 1)
edge_end = _slice_along_axis(arr, i, array_length-1, array_length)
pad_size = array_length - 1
times_to_pad_before = pad_width[i][0] // pad_size + 1
additional_pad_before = pad_width[i][0] % pad_size
times_to_pad_after = pad_width[i][1] // pad_size + 1
additional_pad_after = pad_width[i][1] % pad_size
curr_pad = None
if has_pad_before:
# Deal with paddings before the original array
for times in range(times_to_pad_before):
if times < times_to_pad_before - 1:
endpoint = array_length
else:
endpoint = additional_pad_before + 1
if endpoint != 1:
curr_pad = _slice_along_axis(arr, i, 1, endpoint)
curr_pad = flip(curr_pad, axis=i)
if reflect_type == "odd":
curr_pad = 2 * edge_before - curr_pad
arr = P.Concat(i)((curr_pad, arr))
edge_before = _slice_along_axis(arr, i, 0, 1)
if has_pad_after:
# Deal with paddings after the original array
for times in range(times_to_pad_after):
if times < times_to_pad_after - 1:
startpoint = arr.shape[i] - array_length
else:
startpoint = arr.shape[i] - additional_pad_after - 1
if startpoint != arr.shape[i]-1:
curr_pad = _slice_along_axis(arr, i, startpoint, arr.shape[i]-1)
curr_pad = flip(curr_pad, axis=i)
if reflect_type == "odd":
curr_pad = 2 * edge_end - curr_pad
arr = P.Concat(i)((arr, curr_pad))
edge_end = _slice_along_axis(arr, i, arr.shape[i]-1, arr.shape[i])
return arr
def _pad_func(arr, pad_width, func, **kwargs):
"""applies padding function over different axis."""
# first creates a padded array with fixed length.
arr_dim = arr.ndim
pad_width = _convert_pad_to_nd(pad_width, arr_dim)
arr = _pad_empty(arr, pad_width)
for i in range(arr_dim):
# function signature: padding_func(tensor, iaxis_pad_width, iaxis, kwargs)
arr = apply_along_axis(func, i, arr, pad_width[i], i, kwargs)
return arr
@constexpr
def _make_stat_length(shape):
"""converts the stat_length values."""
return tuple((shape[i], shape[i]) for i, _ in enumerate(shape))
@constexpr
def _limit_stat_length(stat_length, shape):
"""limits the stat_length to current array length along given dimension."""
return tuple((min(stat_pair[0], shape[i]), min(stat_pair[1], shape[i])) for i, stat_pair in enumerate(stat_length))
@constexpr
def _convert_pad_to_nd(pad_values, ndim):
"""broadcasts the pad_values to (ndim * 2)"""
if not isinstance(pad_values, (int, list, tuple, Tensor)):
raise TypeError(
"pad_width, stat_length, constant_values or end_values should only be int, list, tuple or tensor")
pad_tensor = _to_tensor(pad_values)
pad_shape = pad_tensor.shape
if not pad_shape:
pad_values = tuple((((pad_values,) * 2) for i in range(ndim)))
elif pad_shape == (1,):
pad_values = tuple((tuple(pad_values) * 2) for i in range(ndim))
elif pad_shape == (2,):
pad_values = tuple(tuple(pad_values) for i in range(ndim))
elif pad_shape == (1, 2):
pad_values = tuple(tuple(pad_values[0]) for i in range(ndim))
elif pad_shape == (ndim, 2):
pad_values = tuple(tuple(pad_pair) for pad_pair in pad_values)
else:
raise ValueError(f"input values must be able to broadcast to {(ndim, 2)}")
return pad_values
def pad(arr, pad_width, mode="constant", stat_length=None, constant_values=0,
end_values=0, reflect_type="even", **kwargs):
"""
Pads an array.
Note:
Currently, `median` mode is not supported. `reflect` and `symmetric` mode
only supports GPU backend.
Args:
arr (Union[list, tuple, Tensor]): The array to pad.
pad_width (Union[int, tuple, list]): Number of values padded to the edges of
each axis. :class:`((before_1, after_1), ... (before_N, after_N))` creates
unique pad widths for each axis. :class:`((before, after),)` yields same
before and after pad for each axis. :class:`(pad,)` or int is a shortcut
for :class:`before = after = pad width` for all axes.
mode (string, optional):
One of the following string values:
- constant (default): Pads with a constant value.
- edge: Pads with the edge values of `arr`.
- linear_ramp: Pads with the linear ramp between end_value and the `arr` edge value.
- maximum: Pads with the maximum value of all or part of the vector along each axis.
- mean: Pads with the mean value of all or part of the vector along each axis.
- median: Pads with the median value of all or part of the vector along each axis.
- minimum: Pads with the minimum value of all or part of the vector along each axis.
- reflect: Pads with the reflection of the vector mirrored on the first
and last values of the vector along each axis.
- symmetric: Pads with the reflection of the vector mirrored along the edge
of the `arr`.
- wrap: Pads with the wrap of the vector along the axis. The first values
are used to pad the end and the end values are used to pad the beginning.
- empty: Pads with undefined values.
- <function>: The padding function, if used, should modify and return a new 1-d tensor.
It has the following signature: :class:`padding_func(tensor, iaxis_pad_width, iaxis, kwargs)`
stat_length (Union[tuple, list, int], optional): Used in \'maximum\', \'mean\',
\'median\', and \'minimum\'. Number of values at edge of each axis used
to calculate the statistic value. :class:`((before_1, after_1), ... (before_N, after_N))`
creates unique statistic lengths for each axis. :class:`((before, after),)`
yields same before and after statistic lengths for each axis. :class:`(stat_length,)`
or int is a shortcut for :class:`before = after = statistic length` for all
axes. Default is :class:`None`, to use the entire axis.
constant_values (Union[tuple, list, int], optional):
Used in :class:`constant mode`. The values to set the padded values for each
axis. :class:`((before_1, after_1), ... (before_N, after_N))` creates unique pad
constants for each axis. :class:`((before, after),)` yields same before and
after constants for each axis. :class:`(constant,)` or :class:`constant` is
a shortcut for :class:`before = after = constant` for all axes. Default is 0.
end_values (Union[tuple, list, int], optional): Used in 'linear_ramp'. The values
used for the ending value of the linear_ramp and that will form the edge of
the padded `arr`. :class:`((before_1, after_1), ... (before_N, after_N))`
unique end values for each axis. :class`((before, after),)` yields same before
and after end values for each axis. :class:`(constant,)` or :class:`constant`
is a shortcut for :class:`before = after = constant` for all axes. Default is 0.
reflect_type(string, optional) can choose between \'even\' and \'odd\'. Used in
\'reflect\', and \'symmetric\'. The \'even\' style is the default with an
unaltered reflection around the edge value. For the \'odd\' style, the extended
part of the `arr` is created by subtracting the reflected values from two times
the edge value.
Returns:
Padded tensor of rank equal to `arr` with shape increased according to `pad_width`.
Raises:
TypeError: if `arr`, `pad_width`, `stat_length`, `constant_values` or `end_values`
have types not specified above.
ValueError: if `mode` cannot be recognized, or if `pad_width`, `stat_length`,
`constant_values`, `end_values` cannot broadcast to :class:`(arr.ndim, 2)`,
or if keyword arguments got unexpected inputs.
NotImplementedError: if mode is function or '/median'/.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> tensor = np.array([1., 2., 3., 4., 5.])
>>> print(np.pad(tensor, (3, 4)))
[0. 0. 0. 1. 2. 3. 4. 5. 0. 0. 0. 0.]
>>> print(np.pad(tensor, (3, 4), mode="wrap"))
[3. 4. 5. 1. 2. 3. 4. 5. 1. 2. 3. 4.]
>>> >>> print(np.pad(tensor, (3, 4), mode="linear_ramp", end_values=(10, 10)))
[10. 7. 4. 1. 2. 3. 4. 5. 6.25 7.5 8.75 10. ]
"""
arr = _to_tensor(arr)
if arr.ndim == 0:
return arr
pad_width = _convert_pad_to_nd(pad_width, arr.ndim)
stat_func = {"maximum": _reduce_max_keepdims,
"minimum": _reduce_min_keepdims,
"mean": _reduce_mean_keepdims,
"median": "not implemented"}
if mode not in ("constant", "maximum", "minimum", "mean", "median", "edge",
"wrap", "linear_ramp", "symmetric", "reflect", "empty") and \
not _callable(arr, mode):
_raise_value_error("Input mode not supported.")
if mode == "constant":
constant_values = _convert_pad_to_nd(constant_values, arr.ndim)
return _pad_constant(arr, pad_width, constant_values)
if mode in ("maximum", "minimum", "mean", "median"):
# TODO: support median mode once P.Sort/P.Median is supported on GPU/CPU
if mode == "median":
_raise_unimplemented_error("median mode is not supported yet")
return _pad_statistic(arr, pad_width, stat_length, stat_func[mode])
if mode == "edge":
return _pad_edge(arr, pad_width)
if mode == "wrap":
return _pad_wrap(arr, pad_width)
if mode == "linear_ramp":
return _pad_linear(arr, pad_width, end_values)
if mode == "symmetric":
return _pad_symmetric(arr, pad_width, reflect_type)
if mode == "reflect":
return _pad_reflect(arr, pad_width, reflect_type)
if mode == 'empty':
return _pad_empty(arr, pad_width)
return _pad_func(arr, pad_width, mode, **kwargs)

View File

@ -24,14 +24,14 @@ from ..ops.primitive import constexpr
from ..nn import Cell
from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \
_check_input_tensor, _broadcast_to, _to_tensor
_check_input_tensor, _broadcast_to, _to_tensor, _callable
from .utils_const import _check_axes_range, _check_start_normalize, \
_raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \
_check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem, \
_raise_unimplemented_error
_tuple_slice, _expanded_shape, _seq_prod, _tuple_setitem, _iota, \
_raise_unimplemented_error, _cumprod, _get_device
# According to official numpy reference, the dimension of a numpy array must be less
# than 32
@ -697,6 +697,7 @@ def where(condition, x=None, y=None):
[7 5]
[7 5]]]
"""
condition, x, y = _to_tensor(condition, x, y)
# type promotes input tensors
dtype1 = F.dtype(x)
dtype2 = F.dtype(y)
@ -1781,16 +1782,15 @@ def take_along_axis(arr, indices, axis):
ndim = F.rank(arr)
if ndim != F.rank(indices):
_raise_value_error('`indices` and `arr` must have the same number of dimensions')
_check_axis_in_range(axis, ndim)
axis = axis + ndim if axis < 0 else axis
axis = _check_axis_in_range(axis, ndim)
shape_arr = F.shape(arr)
shape_indices = F.shape(indices)
# broadcasts indices against the shape of arr except at axis
indices = _broadcast_to(indices, _tuple_getitem(shape_indices, axis, False),
_tuple_getitem(shape_arr, axis, False), ndim)
indices = _broadcast_to(indices, _tuple_getitem(shape_arr, axis + 1, False) +
_tuple_getitem(shape_indices, axis + 1), shape_arr, ndim)
indices = _broadcast_to(indices, _tuple_slice(shape_indices, None, axis),
_tuple_slice(shape_arr, None, axis), ndim)
indices = _broadcast_to(indices, _tuple_slice(shape_arr, None, axis + 1) +
_tuple_slice(shape_indices, axis + 1, None), shape_arr, ndim)
return F.gather_d(arr, axis, indices)
@ -1801,18 +1801,21 @@ def _mod(x, y):
return F.tensor_sub(x, prod)
def _check_indices(size, indices, mode):
def _check_indices(dims, indices, mode, allow_negative_index=True):
"""Checks whether indices are out of bounds."""
shape = F.shape(indices)
dtype = F.dtype(indices)
lowerbounds = F.fill(dtype, shape, -size)
upperbounds = F.fill(dtype, shape, size - 1)
if not allow_negative_index:
lowerbounds = F.fill(dtype, shape, 0)
else:
lowerbounds = F.fill(dtype, shape, -dims)
upperbounds = F.fill(dtype, shape, dims - 1)
out_of_lowerbounds = F.tensor_lt(indices, lowerbounds)
out_of_upperbounds = F.tensor_gt(indices, upperbounds)
if mode == 'raise':
_raise_unimplemented_error('"raise" mode is not implemented')
if mode == 'wrap':
return _mod(indices, F.fill(dtype, shape, size))
return _mod(indices, F.fill(mstype.float32, shape, dims)).astype(dtype)
zeros = F.fill(dtype, shape, 0)
clipped = F.select(out_of_lowerbounds, zeros, indices)
clipped = F.select(out_of_upperbounds, upperbounds, clipped)
@ -1878,8 +1881,7 @@ def take(a, indices, axis=None, mode='clip'):
a = ravel(a)
axis = 0
ndim = F.rank(a)
_check_axis_in_range(axis, ndim)
axis = axis + ndim if axis < 0 else axis
axis = _check_axis_in_range(axis, ndim)
shape_a = F.shape(a)
shape_indices = F.shape(indices)
@ -1887,8 +1889,8 @@ def take(a, indices, axis=None, mode='clip'):
indices = _check_indices(shape_a[axis], indices, mode)
# reshapes indices to shape (Ni..., Nj..., Nk)
shape_ni = _tuple_getitem(shape_a, axis, False)
shape_nk = _tuple_getitem(shape_a, axis + 1)
shape_ni = _tuple_slice(shape_a, None, axis)
shape_nk = _tuple_slice(shape_a, axis + 1, None)
shape_out = shape_ni + shape_indices + shape_nk
shape_indices = _expanded_shape(ndim, size_indices, axis)
indices = F.reshape(indices, shape_indices)
@ -1948,18 +1950,17 @@ def repeat(a, repeats, axis=None):
a = ravel(a)
axis = 0
ndim = F.rank(a)
_check_axis_in_range(axis, ndim)
axis = axis + ndim if axis < 0 else axis
axis = _check_axis_in_range(axis, ndim)
if len(repeats) == 1:
repeats = repeats[0]
if repeats == 0:
return _empty(F.dtype(a), (0,))
return C.repeat_elements(a, repeats, axis)
shape = F.shape(a)
size = shape[axis]
if len(repeats) != size:
dims = shape[axis]
if len(repeats) != dims:
_raise_value_error('operands could not be broadcast together')
subs = split(a, size, axis)
subs = split(a, dims, axis)
repeated_subs = []
for sub, rep in zip(subs, repeats):
if rep != 0:
@ -2046,11 +2047,13 @@ def select(condlist, choicelist, default=0):
Returns an array drawn from elements in `choicelist`, depending on conditions.
Args:
condlist (array_like): The list of conditions which determine from which array
in `choicelist` the output elements are taken. When multiple conditions are
satisfied, the first one encountered in `condlist` is used.
choicelist (array_like): The list of arrays from which the output elements are
taken. It has to be of the same length as `condlist`.
condlist (Union[int, float, bool, list, tuple, Tensor]): The list of conditions
which determine from which array in `choicelist` the output elements are
taken. When multiple conditions are satisfied, the first one encountered in
`condlist` is used.
choicelist (Union[int, float, bool, list, tuple, Tensor]): The list of arrays
from which the output elements are taken. It has to be of the same length as
`condlist`.
default (scalar, optional): The element inserted in output when all conditions
evaluate to `False`.
@ -2059,7 +2062,6 @@ def select(condlist, choicelist, default=0):
`choicelist` where the `m-th` element of the corresponding array in `condlist`
is `True`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -2067,7 +2069,9 @@ def select(condlist, choicelist, default=0):
ValueError: if ``len(condlist) != len(choicelist)``.
Examples:
>>> condlist = [[True, True, True, False, False], [False, False, True, False, True]]
>>> import mindspore.numpy as np
>>> condlist = [[True, True, True, False, False], \
[False, False, True, False, True]]
>>> choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]]
>>> output = np.select(condlist, choicelist)
>>> print(output)
@ -2076,32 +2080,481 @@ def select(condlist, choicelist, default=0):
condlist, choicelist = _to_tensor(condlist, choicelist)
shape_cond = F.shape(condlist)
shape_choice = F.shape(choicelist)
if F.rank(condlist) == 0 or F.rank(condlist) == 0:
if F.rank(condlist) == 0 or F.rank(choicelist) == 0:
_raise_value_error('input cannot be scalars')
case_num = shape_cond[0]
if shape_choice[0] != case_num:
_raise_value_error('list of cases must be same length as list of conditions')
case_size_cond = _tuple_slice(shape_cond, 1, None)
case_size_choice = _tuple_slice(shape_choice, 1, None)
# performs broadcast over the cases in condlist and choicelist
case_size = _infer_out_shape(shape_cond[1:], shape_choice[1:])
case_size = _infer_out_shape(case_size_cond, case_size_choice)
shape_broadcasted = (case_num,) + case_size
ndim = len(shape_broadcasted)
shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(condlist), 1, True) +
shape_cond[1:])
case_size_cond)
condlist = _broadcast_to_shape(F.reshape(condlist, shape_cond_expanded), shape_broadcasted)
shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(choicelist), 1, True) +
shape_choice[1:])
case_size_choice)
choicelist = _broadcast_to_shape(F.reshape(choicelist, shape_choice_expanded), shape_broadcasted)
slice_start = _list_comprehensions(ndim - 1, 0, True)
slice_size = (1,) + case_size
dtype = F.dtype(choicelist)
if _get_device() == 'CPU' and not _check_is_float(dtype):
# F.tensor_slice only supports float on CPU
choicelist = F.cast(choicelist, mstype.float32)
default_slice = F.fill(F.dtype(choicelist), slice_size, default)
if isinstance(default, Tensor):
default_slice = default.astype(F.dtype(choicelist)).reshape(slice_size)
else:
default_slice = F.fill(F.dtype(choicelist), slice_size, default)
for i in range(case_num - 1, -1, -1):
cond_slice = F.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size)
choice_slice = F.tensor_slice(choicelist, (i,) + slice_start, slice_size)
default_slice = F.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice)
return F.reshape(default_slice, (case_size)).astype(dtype)
@constexpr
def _get_grid(shape):
"""Returns a grid representing all the indices for an array with the given shape."""
grids = []
ndim = len(shape)
for i in range(ndim):
dim_grid = _iota(mstype.int32, shape[i])
dim_shape = _expanded_shape(ndim, shape[i], i)
dim_grid = _broadcast_to_shape(dim_grid.reshape(dim_shape), shape)
grids.append(dim_grid)
return stack(grids, -1)
def choose(a, choices, mode='clip'):
"""
Construct an array from an index array and a list of arrays to choose from.
Given an index array `a`` of integers and a sequence of n arrays (choices),
`a` and each choice array are first broadcast, as necessary, to arrays of a
common shape; calling these `Ba` and `Bchoices[i], i = 0,,n-1` we have that,
necessarily, ``Ba.shape == Bchoices[i].shape`` for each `i`. Then, a new array
with ``shape Ba.shape`` is created as follows:
- if ``mode='raise'`` (the default), then, first of all, each element of `a`
(and thus `Ba`) must be in the range `[0, n-1]`; now, suppose that `i`
(in that range) is the value at the `(j0, j1, ..., jm)` position in
`Ba` - then the value at the same position in the new array is the
value in ``Bchoices[i]`` at that same position;
- if ``mode='wrap'``, values in `a` (and thus `Ba`) may be any (signed)
integer; modular arithmetic is used to map integers outside the
range ``[0, n-1]`` back into that range; and then the new array is
constructed as above;
- if ``mode='clip'``, values in `a` (and thus `Ba`) may be any (signed) integer;
negative integers are mapped to 0; values greater than `n-1` are mapped to
`n-1`; and then the new array is constructed as above.
Note:
Numpy argument `out` is not supported.
``mode = 'raise'`` is not supported, and the default mode is 'clip' instead.
Args:
a (int array): This array must contain integers in ``[0, n-1]``, where `n` is
the number of choices, unless ``mode=wrap`` or ``mode=clip``, in which
cases any integers are permissible.
choices (sequence of arrays): Choice arrays. `a` and all of the `choices` must
be broadcastable to the same shape. If `choices` is itself an array, then
its outermost dimension (i.e., the one corresponding to ``choices.shape[0]``)
is taken as defining the sequence.
mode (raise, wrap, clip, optional): Specifies how indices outside
``[0, n-1]`` will be treated:
raise raise an error (default);
wrap wrap around;
clip clip to the range. clip mode means that all indices that are
too large are replaced by the index that addresses the last element
along that axis. Note that this disables indexing with negative numbers.
Returns:
Tensor, the merged result.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
ValueError: if ``len(condlist) != len(choicelist)``.
Examples:
>>> import mindspore.numpy as np
>>> choices = [[0, 1, 2, 3], [10, 11, 12, 13],
[20, 21, 22, 23], [30, 31, 32, 33]]
>>> print(np.choose([2, 3, 1, 0], choices))
[20 31 12 3]
>>> print(np.choose([2, 4, 1, 0], choices, mode='clip'))
[20 31 12 3]
>>> print(np.choose([2, 4, 1, 0], choices, mode='wrap'))
[20 1 12 3]
>>> a = [[1, 0, 1], [0, 1, 0], [1, 0, 1]]
>>> choices = [-10, 10]
>>> print(np.choose(a, choices))
[[ 10 -10 10]
[-10 10 -10]
[ 10 -10 10]]
>>> a = np.array([0, 1]).reshape((2,1,1))
>>> c1 = np.array([1, 2, 3]).reshape((1,3,1))
>>> c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5))
>>> print(np.choose(a, (c1, c2)))
[[[ 1 1 1 1 1]
[ 2 2 2 2 2]
[ 3 3 3 3 3]]
[[-1 -2 -3 -4 -5]
[-1 -2 -3 -4 -5]
[-1 -2 -3 -4 -5]]]
"""
a = _to_tensor(a)
if isinstance(choices, (tuple, list)):
# broadcasts choices to the same shape if choices is a sequence
choices = _to_tensor(*choices)
shapes = ()
for choice in choices:
shapes += (F.shape(choice),)
shape_choice = _infer_out_shape(F.shape(a), *shapes)
tmp = []
for choice in choices:
tmp.append(broadcast_to(choice, shape_choice))
choices = stack(tmp)
else:
choices = _to_tensor(choices)
shape_choice = _infer_out_shape(F.shape(a), F.shape(choices)[1:])
choices = broadcast_to(choices, (F.shape(choices)[0],) + shape_choice)
if F.rank(a) == 0 or F.rank(choices) == 0:
_raise_value_error('input cannot be scalars')
a = broadcast_to(a, shape_choice)
dtype = F.dtype(choices)
# adjusts dtype for F.tensor_mul and F.gather_nd
a = a.astype(mstype.int32)
choices = choices.astype(mstype.int32)
a = _check_indices(F.shape(choices)[0], a, mode, allow_negative_index=False)
grid = _get_grid(F.shape(a))
indices = concatenate((a.reshape(F.shape(a) + (1,)), grid), -1)
return F.gather_nd(choices, indices).astype(dtype)
def size(a, axis=None):
"""
Returns the number of elements along a given axis.
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input data.
axis (int): Axis along which the elements are counted. Default: None.
If None, give the total number of elements.
Returns:
Number of elements along the specified axis.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
TypeError: If input is not array_like or `axis` is not int or tuple of ints.
ValueError: If any axis is out of range or duplicate axes exist.
Examples:
>>> import mindspore.numpy as np
>>> x = np.arange(10).reshape(2, 5).astype('float32')
>>> print(np.size(x))
10
>>> print(np.size(x, axis=1))
5
"""
a = _to_tensor(a)
if axis is None:
return a.size
if not isinstance(axis, int):
_raise_type_error("axis argument should be integer.")
axis = _canonicalize_axis(axis, a.ndim)
return a.shape[axis]
def array_str(a):
"""
Returns a string representation of the data in an array.
The data in the array is returned as a single string.
This function is similar to array_repr, the difference being that array_repr also
returns information on the kind of array and its data type.
Note:
Numpy argument `max_line_width`, `precision` and `suppress_small` are not supported.
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input data.
Returns:
String.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
TypeError: If input is not array_like.
Examples:
>>> import mindspore.numpy as np
>>> x = np.arange(5)
>>> np.array_str(x)
'[0 1 2 3 4]'
"""
a = _to_tensor(a)
return a.__str__()
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
Applies a function to 1-D slices along the given axis.
Executes ``func1d(a, *args, **kwargs)`` where `func1d` operates on 1-D arrays and `a` is a
1-D slice of arr along axis.
Args:
func1d (function): Maps `(M,) -> (Nj)`. This function should accept 1-D arrays. It is
applied to 1-D slices of arr along the specified axis.
axis (int): Axis along which arr is sliced.
arr (Tensor): Input array with shape `(Ni, M, Nk)`.
args (any): Additional arguments to `func1d`.
kwargs (any): Additional named arguments to `func1d`.
Returns:
Tensor with shape `(Ni, Nj, Nk)`, the output array. Its shape is identical to the
shape of `arr`, except along the `axis` dimension. This axis is removed, and replaced
with new dimensions equal to the shape of the return value of `func1d`. So if `func1d`
returns a scalar, the output will have one fewer dimensions than `arr`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
ValueError: if axis is out of the range.
Examples:
>>> import mindspore.numpy as np
>>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> print(np.apply_along_axis(np.diag, -1, b))
[[[1 0 0]
[0 2 0]
[0 0 3]]
[[4 0 0]
[0 5 0]
[0 0 6]]
[[7 0 0]
[0 8 0]
[0 0 9]]]
"""
ndim = F.rank(arr)
shape = F.shape(arr)
axis = _check_axis_in_range(axis, ndim)
arr = moveaxis(arr, axis, -1)
arr = F.reshape(arr, (-1, F.shape(arr)[-1]))
slices = []
for i in range(F.shape(arr)[0]):
slices.append(func1d(arr[i], *args, **kwargs))
stacked_slices = stack(slices)
shape_stacked = (_tuple_slice(shape, None, axis) + _tuple_slice(shape, axis + 1, None) +
_tuple_slice(F.shape(stacked_slices), 1, None))
res = F.reshape(stacked_slices, shape_stacked)
# moves the dimensions returned by `func1d` back to `axis`
ndim_func = F.rank(res) - ndim + 1
if ndim_func >= 1:
res = moveaxis(res, F.make_range(ndim - 1, F.rank(res)),
F.make_range(axis, axis + ndim_func))
return res
def _stack_arrays(arrs):
"""Stacks a sequence of Tensor"""
if isinstance(arrs, (tuple, list)):
tensor_list = []
for arr in arrs:
tensor_list.append(_to_tensor(arr))
return stack(tensor_list)
return atleast_1d(_to_tensor(arrs))
def piecewise(x, condlist, funclist, *args, **kw):
"""
Evaluates a piecewise-defined function.
Given a set of conditions and corresponding functions, evaluate each function on the input
data wherever its condition is true.
Args:
x (Union[int, float, bool, list, tuple, Tensor]): The input domain.
condlist (Union[bool, list of bool Tensor]): Each boolean array corresponds to a
function in `funclist`. Wherever `condlist[i]` is True, `funclist[i](x)` is used as
the output value. Each boolean array in `condlist` selects a piece of `x`, and
should therefore be of the same shape as `x`. The length of `condlist` must
correspond to that of `funclist`. If one extra function is given, i.e. if
``len(funclist) == len(condlist) + 1``, then that extra function is the default
value, used wherever all conditions are false.
funclist (Union[list of callables, list of scalars]): Each function is evaluated over
`x` wherever its corresponding condition is True. It should take a 1d array as input
and give an 1d array or a scalar value as output. If, instead of a callable, a scalar
is provided then a constant function ``(lambda x: scalar)`` is assumed.
args (any): Any further arguments given to `piecewise` are passed to the functions upon
execution, i.e., if called ``piecewise(..., ..., 1, 'a')``, then each function is
called as ``f(x, 1, 'a')``.
kw (any): Keyword arguments used in calling `piecewise` are passed to the functions upon
execution, i.e., if called ``piecewise(..., ..., alpha=1)``, then each function is
called as ``f(x, alpha=1)``.
Returns:
Tensor, the output is the same shape and type as `x` and is found by calling the
functions in `funclist` on the appropriate portions of `x`, as defined by the boolean
arrays in `condlist`. Portions not covered by any condition have a default value of 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
ValueError: if length of `funclist` is not in ``(len(condlist), len(condlist) + 1)``
Examples:
>>> import mindspore.numpy as np
>>> x = np.linspace(-2.5, 2.5, 6)
>>> print(np.piecewise(x, [x < 0, x >= 0], [-1, 1]))
[2.5 1.5 0.5 0.5 1.5 2.5]
"""
x = _to_tensor(x)
choicelist = funclist
if isinstance(funclist, (tuple, list)):
if _callable(x, funclist[0]):
choicelist = []
for func in funclist:
choicelist.append(func(x, *args, **kw))
condlist = _stack_arrays(condlist)
choicelist = _stack_arrays(choicelist)
default = 0
n1 = len(condlist)
n2 = len(funclist)
if n1 + 1 == n2:
default = choicelist[-1]
choicelist = choicelist[:-1]
elif n1 != n2:
_raise_value_error('the number of choices should be either equal to conditions or ', n1 + 1)
return select(condlist, choicelist, default=default)
def unravel_index(indices, shape, order='C'):
"""
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Note:
Out-of-bound indices are clipped by the boundaries of `shape` instead of raising
an error.
Args:
indices (Union[int, float, bool, list, tuple, Tensor]): An integer array whose elements
are indices into the flattened version of an array of dimensions shape.
shape (tuple of ints): The shape of the array to use for unraveling indices.
order (Union['C', 'F'], optional): Determines whether the indices should be viewed as
indexing in row-major (C-style) or column-major (Fortran-style) order.
Returns:
Tensor, each array in the tuple has the same shape as the indices array.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Raises:
ValueError: if `order` is not 'C' or 'F'.
Examples:
>>> import mindspore.numpy as np
>>> print(np.unravel_index([22, 41, 37], (7,6)))
(Tensor(shape=[3], dtype=Int32, value= [3, 6, 6]),
Tensor(shape=[3], dtype=Int32, value= [4, 5, 1]))
>>> print(np.unravel_index([31, 41, 13], (7,6), order='F'))
(Tensor(shape=[3], dtype=Int32, value= [3, 6, 6]),
Tensor(shape=[3], dtype=Int32, value= [4, 5, 1]))
"""
indices = _to_tensor(indices)
if order not in ('C', 'F'):
_raise_value_error('invalid order. Expected "C" or "F"')
if isinstance(shape, int):
shape = (shape,)
ndim = F.rank(indices)
if order == 'F':
sizes = _cumprod(shape)
else:
sizes = _cumprod(shape[::-1])
sizes = _to_tensor(sizes[::-1] + (1,))
sizes = F.reshape(sizes, (-1,) + _list_comprehensions(ndim, 1, True))
total_size = sizes[0]
indices = where(indices > total_size - 1, total_size - 1, indices)
if _get_device() == 'GPU':
dtype = F.dtype(total_size)
lowerbounds = (-(total_size.astype(mstype.float32))).astype(dtype)
else:
lowerbounds = -total_size
indices = where(indices < lowerbounds, lowerbounds, indices)
res = _mod(indices, sizes[:-1])//sizes[1:]
num = len(res)
if ndim == 0 and num == 1:
return res.ravel()
if order == 'F':
r = range(num - 1, -1, -1)
else:
r = range(num)
subs = ()
for i in r:
subs += (res[i],)
return subs
def apply_over_axes(func, a, axes):
"""
Applies a function repeatedly over multiple axes.
`func` is called as `res = func(a, axis)`, where `axis` is the first element of `axes`.
The result `res` of the function call must have either the same dimensions as `a` or
one less dimension. If `res` has one less dimension than `a`, a dimension is inserted before `axis`.
The call to `func` is then repeated for each axis in `axes`, with `res` as the first argument.
Args:
func (function): This function must take two arguments, `func(a, axis)`.
a (Union[int, float, bool, list, tuple, Tensor]): Input tensor.
axes (Union[int, list, tuple]): Axes over which `func` is applied; the elements must be integers.
Returns:
Tensor. The number of dimensions is the same as `a`, but the shape can be different.
This depends on whether `func` changes the shape of its output with respect to its input.
Raises:
TypeError: If input `a` is not array_like or `axes` is not int or sequence of ints.
ValueError: If any axis is out of range or duplicate axes exist.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> x = np.arange(10).reshape(2, 5).astype('float32')
>>> print(x)
[[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]]
>>> print(np.apply_over_axes(np.sum, x, axes=0))
[[ 5. 7. 9. 11. 13.]]
"""
a = _to_tensor(a)
if isinstance(axes, int):
axes = (axes,)
res = a
for axis in axes:
res = func(res, axis=axis)
res = F.expand_dims(res, axis) if res.ndim != a.ndim else res
if res.ndim != a.ndim:
_raise_value_error("function is not returning a tensor of the correct shape")
return res

View File

@ -16,15 +16,14 @@
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import typing
from .math_ops import _apply_tensor_op, absolute
from .array_creations import zeros, ones, empty
from .array_creations import zeros, ones, empty, asarray
from .utils import _check_input_tensor, _to_tensor, _isnan
from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape
from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape, _check_same_type, \
_check_axis_type, _canonicalize_axis, _can_broadcast, _isscalar
def not_equal(x1, x2, dtype=None):
@ -410,13 +409,6 @@ def isneginf(x):
return _is_sign_inf(x, F.tensor_lt)
@constexpr
def _isscalar(x):
"""Returns True if x is a scalar type"""
return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
typing.Bool, typing.String))
def isscalar(element):
"""
Returns True if the type of element is a scalar type.
@ -534,8 +526,9 @@ def in1d(ar1, ar2, invert=False):
not rely on the uniqueness of the input arrays.
Args:
ar1 (array_like): Input array with shape `(M,)`.
ar2 (array_like): The values against which to test each value of `ar1`.
ar1 (Union[int, float, bool, list, tuple, Tensor]): Input array with shape `(M,)`.
ar2 (Union[int, float, bool, list, tuple, Tensor]): The values against which
to test each value of `ar1`.
invert (boolean, optional): If True, the values in the returned array are
inverted (that is, False where an element of `ar1` is in `ar2` and True
otherwise). Default is False.
@ -746,3 +739,167 @@ def logical_xor(x1, x2, dtype=None):
y1 = F.logical_or(x1, x2)
y2 = F.logical_or(F.logical_not(x1), F.logical_not(x2))
return _apply_tensor_op(F.logical_and, y1, y2, dtype=dtype)
def array_equal(a1, a2, equal_nan=False):
"""
Returns `True` if input arrays have same shapes and all elements equal.
Note:
In mindpsore, a bool tensor is returned instead, since in Graph mode, the
value cannot be traced and computed at compile time.
Args:
a1/a2 (Union[int, float, bool, list, tuple, Tensor]): Input arrays.
equal_nan (bool): Whether to compare NaNs as equal.
Returns:
Scalar bool tensor, value is `True` if inputs are equal, `False` otherwise.
Raises:
TypeError: If inputs have types not specified above.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> a = [0,1,2]
>>> b = [[0,1,2], [0,1,2]]
>>> print(np.array_equal(a,b))
False
"""
a1 = asarray(a1)
a2 = asarray(a2)
if not isinstance(equal_nan, bool):
_raise_type_error("equal_nan must be bool.")
if a1.shape == a2.shape:
res = equal(a1, a2)
if equal_nan:
res = logical_or(res, logical_and(isnan(a1), isnan(a2)))
return res.all()
return _to_tensor(False)
def array_equiv(a1, a2):
"""
Returns `True` if input arrays are shape consistent and all elements equal.
Shape consistent means they are either the same shape, or one input array can
be broadcasted to create the same shape as the other one.
Note:
In mindpsore, a bool tensor is returned instead, since in Graph mode, the
value cannot be traced and computed at compile time.
Args:
a1/a2 (Union[int, float, bool, list, tuple, Tensor]): Input arrays.
Returns:
Scalar bool tensor, value is `True` if inputs are equivalent, `False` otherwise.
Raises:
TypeError: If inputs have types not specified above.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> a = [0,1,2]
>>> b = [[0,1,2], [0,1,2]]
>>> print(np.array_equiv(a,b))
True
"""
a1 = asarray(a1)
a2 = asarray(a2)
if _can_broadcast(a1.shape, a2.shape):
return equal(a1, a2).all()
return _to_tensor(False)
def signbit(x, dtype=None):
"""
Returns element-wise True where signbit is set (less than zero).
Note:
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and
`extobj` are not supported.
Args:
x (Union[int, float, bool, list, tuple, Tensor]): The input value(s).
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor.
Raises:
TypeError: If input is not array_like or `dtype` is not `None` or `bool`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> x = np.array([1, -2.3, 2.1]).astype('float32')
>>> output = np.signbit(x)
>>> print(output)
[False True False]
"""
if dtype is not None and not _check_same_type(dtype, mstype.bool_):
_raise_type_error("Casting was not allowed for signbit.")
x = _to_tensor(x)
res = F.less(x, 0)
if dtype is not None and not _check_same_type(F.dtype(res), dtype):
res = F.cast(res, dtype)
return res
def sometrue(a, axis=None, keepdims=False):
"""
Tests whether any array element along a given axis evaluates to True.
Returns single boolean unless axis is not None
Args:
a (Union[int, float, bool, list, tuple, Tensor]): Input tensor or object that can be converted to an array.
axis (Union[None, int, tuple(int)]): Axis or axes along which a logical OR reduction is
performed. Default: None.
If None, perform a logical OR over all the dimensions of the input array.
If negative, it counts from the last to the first axis.
If tuple of ints, a reduction is performed on multiple axes, instead of a single axis or
all the axes as before.
keepdims (bool): Default: False.
If True, the axes which are reduced are left in the result as dimensions with size one.
With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the any method of
sub-classes of ndarray, however any non-default value will be. If the sub-class method does not
implement keepdims any exceptions will be raised.
Returns:
Returns single boolean unless axis is not None
Raises:
TypeError: If input is not array_like or `axis` is not int or tuple of ints or
`keepdims` is not integer or `initial` is not scalar.
ValueError: If any axis is out of range or duplicate axes exist.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> x = np.array([1, -2.3, 2.1]).astype('float32')
>>> output = np.signbit(x)
>>> print(output)
[False True False]
"""
if not isinstance(keepdims, int):
_raise_type_error("integer argument expected, but got ", keepdims)
if axis is not None:
_check_axis_type(axis, True, True, False)
axis = _canonicalize_axis(axis, a.ndim)
a = _to_tensor(a)
keepdims = keepdims not in (0, False)
return F.not_equal(a, 0).any(axis, keepdims)

File diff suppressed because it is too large Load Diff

View File

@ -13,11 +13,14 @@
# limitations under the License.
# ============================================================================
"""internal utility functions"""
import types
from ..common import Tensor
from ..ops import functional as F
from ..common import dtype as mstype
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
_tuple_setitem, _callable_const
def _deep_list(array_like):
@ -154,3 +157,52 @@ def _get_dtype_from_scalar(*input_numbers):
def _isnan(x):
"""Computes isnan."""
return F.not_equal(x, x)
def _convert_bool_to_int(tensor):
"""Convert tensor with bool type to int32."""
if tensor.dtype == mstype.bool_:
return tensor.astype("int32")
return tensor
def _slice_along_axis(f, axis, slice_start, slice_end):
"""
Slice a tensor along a given axis
Args:
f (Tensor): Input Tensor.
axis (int): Specified axis.
slice_start (int): The start of the slice.
slice_end (int): The end of the slice.
Returns:
Sliced tensor.
"""
index_start = (0,) * f.ndim
index_end = f.shape
slice_size = slice_end - slice_start
index_start = _tuple_setitem(index_start, axis, slice_start)
index_end = _tuple_setitem(index_end, axis, slice_size)
return F.tensor_slice(f, index_start, index_end)
def _to_tensor_origin_dtype(*args):
"""Returns each input as Tensor and remains original dtype."""
res = []
for arg in args:
if isinstance(arg, (int, float, bool, list, tuple)):
arg = _type_convert(Tensor, arg)
elif not isinstance(arg, Tensor):
_raise_type_error("Expect input to be array like.")
res.append(arg)
if len(res) == 1:
return res[0]
return res
def _callable(tensor, obj):
"""Returns True if `obj` is a function."""
if F.isconstant(tensor):
return isinstance(obj, types.FunctionType)
return _callable_const(F.typeof(obj))

View File

@ -14,8 +14,9 @@
# ============================================================================
"""internal graph-compatible utility functions"""
import math
from itertools import zip_longest
from itertools import zip_longest, accumulate
from collections import deque
import operator
import mindspore.context as context
from ..ops import functional as F
@ -126,6 +127,18 @@ def _infer_out_shape(*shapes):
return tuple(shape_out)
@constexpr
def _can_broadcast(*shapes):
"""
Returns Ture if shapes can broadcast, False if they cannot.
"""
try:
_infer_out_shape(*shapes)
except ValueError:
return False
return True
@constexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
@ -133,6 +146,7 @@ def _check_axis_in_range(axis, ndim):
raise TypeError(f'axes should be integers, not {type(axis)}')
if not -ndim <= axis < ndim:
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
return axis % ndim
@constexpr
@ -145,14 +159,11 @@ def _check_axis_valid(axes, ndim):
axes = F.make_range(ndim)
return axes
if isinstance(axes, (tuple, list)):
for axis in axes:
_check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes))
axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
if any(axes.count(el) > 1 for el in axes):
raise ValueError('duplicate value in "axis"')
return axes
_check_axis_in_range(axes, ndim)
return (axes % ndim,)
return (_check_axis_in_range(axes, ndim),)
@constexpr
@ -397,7 +408,7 @@ def _type_convert(force, obj):
@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False):
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
"""
Generates a new list/tuple by list comprehension.
@ -416,7 +427,9 @@ def _list_comprehensions(obj, item=None, return_tuple=False):
lst = obj
if isinstance(obj, int):
lst = range(obj)
if item is None:
if make_none:
res = [None for _ in lst]
elif item is None:
res = [i for i in lst]
else:
res = [item for i in lst]
@ -425,17 +438,6 @@ def _list_comprehensions(obj, item=None, return_tuple=False):
return res
@constexpr
def _tuple_getitem(tup, idx, startswith=True):
"""
Returns a slice from tup starting with idx. If startswith is False,
returns a lice from tup ending with idx instead.
"""
if startswith:
return tup[idx:]
return tup[:idx]
@constexpr
def _tuple_setitem(tup, idx, value):
"""
@ -471,7 +473,7 @@ def _seq_prod(seq1, seq2):
@constexpr
def _make_tensor(val, dtype):
""" Returns the tensor with value `val` and dtype `dtype`."""
"""Returns the tensor with value `val` and dtype `dtype`."""
return Tensor(val, dtype)
@ -479,3 +481,26 @@ def _make_tensor(val, dtype):
def _tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]
@constexpr
def _isscalar(x):
"""Returns True if x is a scalar type"""
return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
typing.Bool, typing.String))
@constexpr
def _cumprod(x):
return tuple(accumulate(x, operator.mul))
@constexpr
def _in(x, y):
return x in y
@constexpr
def _callable_const(x):
"""Returns true if x is a function in graph mode."""
return isinstance(x, typing.Function)

View File

@ -778,9 +778,8 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
@constexpr
def mstype_eq(x, y):
if x == y:
return True
return False
"""Determine whether the input `x` equals `y`."""
return x == y
@constexpr
@ -841,3 +840,26 @@ def tuple_slice(tup, start, end):
@constexpr
def expanded_shape(shape, expand_size):
return (1,)*expand_size + shape
@constexpr
def sequence_mul_int(seq, number):
"""
Make a new list with native python syntax.
Args:
seq (Union[list, tuple]): Input sequence.
y (int): Input number.
Returns:
New sequence, has the same type as `seq`.
"""
if not isinstance(number, int):
raise TypeError(f"can't multiply sequence by non-int of type {type(number)}")
return seq * number
@constexpr
def check_in_sequence(x, y):
"""Determine whether the input `x` is in the sequence `y`."""
return x in y

View File

@ -130,3 +130,33 @@ def _tensor_in_tuple(x, y):
bool, if x in y return true, x not in y return false.
"""
return compile_utils.tensor_in_sequence(x, y)
@in_.register("mstype", "List")
def _mstype_in_list(x, y):
"""
Determine if a mindspore type is in a list.
Args:
x: mstype
y: List
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.check_in_sequence(x, y)
@in_.register("mstype", "Tuple")
def _mstype_in_tuple(x, y):
"""
Determine if a mindspore type is in a tuple.
Args:
x: mstype
y: Tuple
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.check_in_sequence(x, y)

View File

@ -15,6 +15,7 @@
"""Implementation for internal polymorphism `mul` operations."""
from . import _constexpr_utils as const_utils
from ...composite import base
from ... import functional as F
@ -68,3 +69,47 @@ def _tensor_mul_scalar(x, y):
Tensor, has the same dtype as x.
"""
return F.tensor_mul(x, y)
@mul.register("List", "Number")
def _list_mul_scalar(x, y):
"""
Returns x * y where x is a list and y is a number. y must be integer.
Outputs:
List.
"""
return const_utils.sequence_mul_int(x, y)
@mul.register("Tuple", "Number")
def _tuple_mul_scalar(x, y):
"""
Returns x * y where x is a tuple and y is a number. y must be integer.
Outputs:
Tuple.
"""
return const_utils.sequence_mul_int(x, y)
@mul.register("Number", "List")
def _scalar_mul_list(x, y):
"""
Returns x * y where x is a number and y is a list. x must be integer.
Outputs:
List.
"""
return const_utils.sequence_mul_int(y, x)
@mul.register("Number", "Tuple")
def _scalar_mul_tuple(x, y):
"""
Returns x * y where x is a number and y is a tuple. x must be integer.
Outputs:
Tuple.
"""
return const_utils.sequence_mul_int(y, x)

View File

@ -130,3 +130,33 @@ def _tensor_not_in_tuple(x, y):
bool, if x not in y return true, x in y return false.
"""
return not compile_utils.tensor_in_sequence(x, y)
@not_in_.register("mstype", "List")
def _mstype_not_in_list(x, y):
"""
Determine if a mindspore type is not in a list.
Args:
x: mstype
y: List
Returns:
bool, if x not in y return true, x in y return false.
"""
return not const_utils.check_in_sequence(x, y)
@not_in_.register("mstype", "Tuple")
def _mstype_not_in_tuple(x, y):
"""
Determine if a mindspore type is not in a tuple.
Args:
x: mstype
y: Tuple
Returns:
bool, if x not in y return true, x in y return false.
"""
return not const_utils.check_in_sequence(x, y)

View File

@ -86,6 +86,10 @@ square = P.Square()
sqrt = P.Sqrt()
log = P.Log()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
reduce_prod = P.ReduceProd()
tensor_slice = P.Slice()
maximum = P.Maximum()
minimum = P.Minimum()
@ -106,6 +110,10 @@ asinh = P.Asinh()
acosh = P.Acosh()
atanh = P.Atanh()
atan2 = P.Atan2()
bitwise_and = P.BitwiseAnd()
bitwise_or = P.BitwiseOr()
bitwise_xor = P.BitwiseXor()
invert = P.Invert()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
@ -227,6 +235,8 @@ tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)
tensor_operator_registry.register('argmax', P.Argmax)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)

View File

@ -2658,7 +2658,7 @@ class Acosh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``GPU``
Examples:
>>> acosh = ops.Acosh()
@ -2735,7 +2735,7 @@ class Asinh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``GPU``
Examples:
>>> asinh = ops.Asinh()

View File

@ -805,6 +805,127 @@ def test_vander():
match_all_arrays(mnp_vander, onp_vander, error=1e-4)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bartlett():
for i in [-3, -1, 0, 1, 5, 6, 10, 15]:
match_all_arrays(mnp.bartlett(i), onp.bartlett(i), error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_blackman():
for i in [-3, -1, 0, 1, 5, 6, 10, 15]:
match_all_arrays(mnp.blackman(i), onp.blackman(i), error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hamming():
for i in [-3, -1, 0, 1, 5, 6, 10, 15]:
match_all_arrays(mnp.hamming(i), onp.hamming(i), error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hanning():
for i in [-3, -1, 0, 1, 5, 6, 10, 15]:
match_all_arrays(mnp.hanning(i), onp.hanning(i), error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_triu_indices():
m = rand_int().tolist()
n = rand_int().tolist()
k = rand_int().tolist()
mnp_res = mnp.triu_indices(n, k, m)
onp_res = onp.triu_indices(n, k, m)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tril_indices():
m = rand_int().tolist()
n = rand_int().tolist()
k = rand_int().tolist()
mnp_res = mnp.tril_indices(n, k, m)
onp_res = onp.tril_indices(n, k, m)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_triu_indices_from():
m = int(rand_int().tolist())
n = int(rand_int().tolist())
t = mnp.asarray(rand_int(m, n).tolist())
k = rand_int().tolist()
mnp_res = mnp.triu_indices_from(t, k)
onp_res = onp.triu_indices_from(t.asnumpy(), k)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tril_indices_from():
m = int(rand_int().tolist())
n = int(rand_int().tolist())
t = mnp.asarray(rand_int(m, n).tolist())
k = rand_int().tolist()
mnp_res = mnp.tril_indices_from(t, k)
onp_res = onp.tril_indices_from(t.asnumpy(), k)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_histogram_bin_edges():
x = onp.random.randint(-10, 10, 10)
for bins in [(1, 2, 3), [2], 1, 5, 10]:
# pylint: disable=redefined-builtin
for range in [None, (3, 3), (2, 20)]:
match_res(mnp.histogram_bin_edges, onp.histogram_bin_edges, x, bins=bins, range=range, error=3)
match_res(mnp.histogram_bin_edges, onp.histogram_bin_edges, x, onp.arange(5))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -836,3 +957,90 @@ def test_linspace_exception():
def test_empty_like_exception():
with pytest.raises(ValueError):
mnp.empty_like([[1, 2, 3], [4, 5]])
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad():
x_np = onp.random.random([2, 3, 4]).astype("float32")
x_ms = mnp.asarray(x_np.tolist())
# pad constant
mnp_res = mnp.pad(x_ms, ((1, 1), (2, 2), (3, 4)))
onp_res = onp.pad(x_np, ((1, 1), (2, 2), (3, 4)))
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.pad(x_ms, ((1, 1), (2, 3), (4, 5)), constant_values=((3, 4), (5, 6), (7, 8)))
onp_res = onp.pad(x_np, ((1, 1), (2, 3), (4, 5)), constant_values=((3, 4), (5, 6), (7, 8)))
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad statistic
mnp_res = mnp.pad(x_ms, ((1, 1), (2, 2), (3, 4)), mode="mean", stat_length=((1, 2), (2, 10), (3, 4)))
onp_res = onp.pad(x_np, ((1, 1), (2, 2), (3, 4)), mode="mean", stat_length=((1, 2), (2, 10), (3, 4)))
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad edge
mnp_res = mnp.pad(x_ms, ((1, 1), (2, 2), (3, 4)), mode="edge")
onp_res = onp.pad(x_np, ((1, 1), (2, 2), (3, 4)), mode="edge")
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad wrap
mnp_res = mnp.pad(x_ms, ((1, 1), (2, 2), (3, 4)), mode="wrap")
onp_res = onp.pad(x_np, ((1, 1), (2, 2), (3, 4)), mode="wrap")
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad linear_ramp
mnp_res = mnp.pad(x_ms, ((1, 3), (5, 2), (3, 0)), mode="linear_ramp", end_values=((0, 10), (9, 1), (-10, 99)))
onp_res = onp.pad(x_np, ((1, 3), (5, 2), (3, 0)), mode="linear_ramp", end_values=((0, 10), (9, 1), (-10, 99)))
match_all_arrays(mnp_res, onp_res, error=1e-5)
def pad_with_msfunc(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
vector[:pad_width[0]] = pad_value
vector[-pad_width[1]:] = pad_value
return vector
def pad_with_npfunc(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
vector[:pad_width[0]] = pad_value
vector[-pad_width[1]:] = pad_value
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_gpu():
x_np = onp.random.random([2, 1, 4, 3]).astype("float32")
x_ms = mnp.asarray(x_np.tolist())
# pad symmetric odd
mnp_res = mnp.pad(x_ms, ((10, 3), (5, 2), (3, 0), (2, 6)), mode='symmetric', reflect_type='odd')
onp_res = onp.pad(x_np, ((10, 3), (5, 2), (3, 0), (2, 6)), mode='symmetric', reflect_type='odd')
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad symmetric even
mnp_res = mnp.pad(x_ms, ((10, 13), (5, 12), (3, 0), (2, 6)), mode='symmetric', reflect_type='even')
onp_res = onp.pad(x_np, ((10, 13), (5, 12), (3, 0), (2, 6)), mode='symmetric', reflect_type='even')
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad reflect odd
mnp_res = mnp.pad(x_ms, ((10, 3), (5, 2), (3, 0), (2, 6)), mode='reflect', reflect_type='odd')
onp_res = onp.pad(x_np, ((10, 3), (5, 2), (3, 0), (2, 6)), mode='reflect', reflect_type='odd')
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad reflect even
mnp_res = mnp.pad(x_ms, ((10, 13)), mode='reflect', reflect_type='even')
onp_res = onp.pad(x_np, ((10, 13)), mode='reflect', reflect_type='even')
match_all_arrays(mnp_res, onp_res, error=1e-5)
# pad func
x_np = onp.random.random([2, 4]).astype("float32")
x_ms = mnp.asarray(x_np.tolist())
mnp_res = mnp.pad(x_ms, ((5, 5)), mode=pad_with_msfunc, padder=99)
onp_res = onp.pad(x_np, ((5, 5)), mode=pad_with_npfunc, padder=99)
match_all_arrays(mnp_res, onp_res, error=1e-5)

View File

@ -23,7 +23,7 @@ import mindspore.numpy as mnp
from mindspore.nn import Cell
from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
rand_bool, match_res, run_multi_test, to_tensor
rand_bool, match_res, run_multi_test, to_tensor, match_all_arrays
class Cases():
@ -1253,6 +1253,22 @@ def test_select():
match_res(mnp.select, onp.select, condlist, choicelist, default=10)
def test_choose():
x = rand_int(2, 1, 4).astype(onp.int32)
y = rand_int(3, 2, 5, 4).astype(onp.int32)
match_res(mnp.choose, onp.choose, x, y, mode='wrap')
match_res(mnp.choose, onp.choose, x, y, mode='clip')
x = rand_int(5, 3, 1, 7).astype(onp.int32)
y1 = rand_int(7).astype(onp.int32)
y2 = rand_int(1, 3, 1).astype(onp.int32)
y3 = rand_int(5, 1, 1, 7).astype(onp.int32)
onp_arrays = (x, (y1, y2, y3))
mnp_arrays = (to_tensor(x), tuple(map(to_tensor, (y1, y2, y3))))
match_all_arrays(mnp.choose(*mnp_arrays, mode='wrap'), onp.choose(*onp_arrays, mode='wrap'))
match_all_arrays(mnp.choose(*mnp_arrays, mode='clip'), onp.choose(*onp_arrays, mode='clip'))
class ReshapeExpandSqueeze(Cell):
def __init__(self):
super(ReshapeExpandSqueeze, self).__init__()
@ -1444,3 +1460,159 @@ def test_rot90():
o_rot = onp_rot90(onp_array)
m_rot = mnp_rot90(mnp_array)
check_all_results(o_rot, m_rot)
def mnp_size(x):
a = mnp.size(x)
b = mnp.size(x, axis=0)
return a, b
def onp_size(x):
a = onp.size(x)
b = onp.size(x, axis=0)
return a, b
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_size():
onp_arr = onp.random.rand(2, 3, 4).astype('float32')
mnp_arr = to_tensor(onp_arr)
for actual, expected in zip(mnp_size(mnp_arr), onp_size(onp_arr)):
match_array(actual, expected)
def mnp_array_str(x):
return mnp.array_str(x)
def onp_array_str(x):
return onp.array_str(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_array_str():
onp_arr = onp.random.rand(2, 3, 4).astype('float32')
mnp_arr = to_tensor(onp_arr)
for actual, expected in zip(mnp_size(mnp_arr), onp_size(onp_arr)):
match_array(actual, expected)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_along_axis():
onp_arr = rand_int(5, 3, 7)
mnp_arr = to_tensor(onp_arr)
for i in range(-3, 3):
mnp_res = mnp.apply_along_axis(mnp.diag, i, mnp_arr)
onp_res = onp.apply_along_axis(onp.diag, i, onp_arr)
match_all_arrays(mnp_res, onp_res)
mnp_res = mnp.apply_along_axis(lambda x: x[0], 2, mnp_arr)
onp_res = onp.apply_along_axis(lambda x: x[0], 2, onp_arr)
match_all_arrays(mnp_res, onp_res)
mnp_res = mnp.apply_along_axis(lambda x, y, offset=0: (x[4] - y)*offset, 2, mnp_arr, 1, offset=3)
onp_res = onp.apply_along_axis(lambda x, y, offset=0: (x[4] - y)*offset, 2, onp_arr, 1, offset=3)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_piecewise():
x = rand_int(2, 4)
mnp_x = to_tensor(x)
condlist = [x < 2, x == 2, x > 2]
mnp_condlist = [mnp_x < 2, mnp_x == 2, mnp_x > 2]
funclist = [lambda x, offset=0: x - offset, lambda x, offset=0: x, lambda x, offset=0: x*offset]
mnp_res = mnp.piecewise(mnp_x, mnp_condlist, funclist, offset=2)
onp_res = onp.piecewise(x, condlist, funclist, offset=2)
match_all_arrays(mnp_res, onp_res)
funclist = [-1, 0, 1]
mnp_res = mnp.piecewise(mnp_x, mnp_condlist, funclist)
onp_res = onp.piecewise(x, condlist, funclist)
match_all_arrays(mnp_res, onp_res)
condlist = [x > 10, x < 0]
mnp_x = to_tensor(x)
mnp_condlist = [mnp_x > 10, mnp_x < 0]
funclist = [lambda x: x - 2, lambda x: x - 1, lambda x: x*2]
mnp_res = mnp.piecewise(mnp_x, mnp_condlist, funclist)
onp_res = onp.piecewise(x, condlist, funclist)
match_all_arrays(mnp_res, onp_res)
x = 2
condlist = True
funclist = [lambda x: x - 1]
mnp_res = mnp.piecewise(x, condlist, funclist)
onp_res = onp.piecewise(x, condlist, funclist)
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unravel_index():
shapes = [(), 1, 3, (5, 1), (2, 6, 3)]
dims = [(5, 4, 7), (5*4, 7), 5*4*7]
for shape in shapes:
x = onp.random.randint(0, 5*4*7, shape)
for dim in dims:
for order in ('C', 'F'):
mnp_res = mnp.unravel_index(to_tensor(x), dim, order=order)
onp_res = onp.unravel_index(x, dim, order=order)
match_all_arrays(mnp_res, onp_res)
def mnp_apply_over_axes(x):
a = mnp.apply_over_axes(mnp.sum, x, axes=0)
b = mnp.apply_over_axes(mnp.sum, x, axes=(0, 1))
c = mnp.apply_over_axes(mnp.std, x, axes=1)
d = mnp.apply_over_axes(mnp.mean, x, axes=(-1,))
return a, b, c, d
def onp_apply_over_axes(x):
a = onp.apply_over_axes(onp.sum, x, axes=0)
b = onp.apply_over_axes(onp.sum, x, axes=(0, 1))
c = onp.apply_over_axes(onp.std, x, axes=1)
d = onp.apply_over_axes(onp.mean, x, axes=(-1,))
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_over_axes():
arrs = [
onp.random.rand(2, 2).astype('float32'),
onp.random.rand(3, 2, 2).astype('float32'),
onp.random.rand(5, 4, 3, 3).astype('float32'),
]
for x in arrs:
for expected, actual in zip(onp_apply_over_axes(x),
mnp_apply_over_axes(to_tensor(x))):
match_array(actual.asnumpy(), expected, error=5)

View File

@ -398,3 +398,91 @@ def test_logical_not():
expected = onp_logical_not(arr)
actual = mnp_logical_not(to_tensor(arr))
onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist())
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_array_equal():
a = [0, 1, 2, float('inf'), float('nan')]
b = [0, 1, 2, float('inf'), float('nan')]
match_all_arrays(mnp.array_equal(a, b), onp.array_equal(a, b))
a = [0, 1, 2]
b = [[0, 1, 2], [0, 1, 2]]
assert mnp.array_equal(a, b) == onp.array_equal(a, b)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_array_equiv():
a = [0, 1, 2, float('inf'), float('nan')]
b = [0, 1, 2, float('inf'), float('nan')]
match_all_arrays(mnp.array_equal(a, b), onp.array_equal(a, b))
a = [0, 1, 2]
b = [[0, 1, 2], [0, 1, 2]]
assert mnp.array_equal(a, b) == onp.array_equal(a, b)
def mnp_signbit(*arrs):
arr1 = arrs[0]
arr2 = arrs[1]
a = mnp.signbit(arr1)
b = mnp.signbit(arr2, dtype=mnp.bool_)
return a, b
def onp_signbit(*arrs):
arr1 = arrs[0]
arr2 = arrs[1]
a = onp.signbit(arr1)
b = onp.signbit(arr2, dtype='bool')
return a, b
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_signbit():
onp_arrs = [onp.arange(-10, 10).astype('float32'), onp.arange(-10, 10).astype('int32')]
mnp_arrs = [mnp.arange(-10, 10).astype('float32'), mnp.arange(-10, 10).astype('int32')]
for actual, expected in zip(mnp_signbit(*mnp_arrs), onp_signbit(*onp_arrs)):
onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist())
def mnp_sometrue(x):
a = mnp.sometrue(x)
b = mnp.sometrue(x, axis=0)
c = mnp.sometrue(x, axis=(0, -1))
d = mnp.sometrue(x, axis=(0, 1), keepdims=True)
e = mnp.sometrue(x, axis=(0, 1), keepdims=-1)
f = mnp.sometrue(x, axis=(0, 1), keepdims=0)
return a, b, c, d, e, f
def onp_sometrue(x):
a = onp.sometrue(x)
b = onp.sometrue(x, axis=0)
c = onp.sometrue(x, axis=(0, -1))
d = onp.sometrue(x, axis=(0, 1), keepdims=True)
e = onp.sometrue(x, axis=(0, 1), keepdims=-1)
f = onp.sometrue(x, axis=(0, 1), keepdims=0)
return a, b, c, d, e, f
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sometrue():
onp_arr = onp.full((3, 2), [True, False])
mnp_arr = to_tensor(onp_arr)
for actual, expected in zip(mnp_sometrue(mnp_arr), onp_sometrue(onp_arr)):
onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist())

View File

@ -18,6 +18,7 @@ import pytest
import numpy as onp
import mindspore.numpy as mnp
from mindspore.common.dtype import dtype_to_nptype
from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \
run_single_test, match_res, match_array, match_meta, match_all_arrays, to_tensor
@ -600,14 +601,14 @@ def test_outer():
@pytest.mark.env_onecard
def test_type_promotion():
arr = rand_int(2, 3)
onp_sum = onp_add(arr, arr)
onp_res = onp_add(arr, arr)
a = to_tensor(arr, dtype=mnp.float16)
b = to_tensor(arr, dtype=mnp.float32)
c = to_tensor(arr, dtype=mnp.int32)
match_array(mnp_add(a, b).asnumpy(), onp_sum)
match_array(mnp_add(b, c).asnumpy(), onp_sum)
match_array(mnp_add(a, b).asnumpy(), onp_res)
match_array(mnp_add(b, c).asnumpy(), onp_res)
def mnp_absolute(x):
@ -1817,6 +1818,93 @@ def test_lcm():
match_res(mnp_lcm, onp_lcm, x, y)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_exception_innner():
with pytest.raises(ValueError):
mnp.inner(to_tensor(test_case.arrs[0]),
to_tensor(test_case.arrs[1]))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_exception_add():
with pytest.raises(ValueError):
mnp.add(to_tensor(test_case.arrs[1]), to_tensor(test_case.arrs[2]))
def mnp_nanmax(x):
a = mnp.nanmax(x)
b = mnp.nanmax(x, keepdims=True)
c = mnp.nanmax(x, axis=-2)
d = mnp.nanmax(x, axis=0, keepdims=True)
e = mnp.nanmax(x, axis=(-2, 3))
f = mnp.nanmax(x, axis=(-3, -1), keepdims=True)
return a, b, c, d, e, f
def onp_nanmax(x):
a = onp.nanmax(x)
b = onp.nanmax(x, keepdims=True)
c = onp.nanmax(x, axis=-2)
d = onp.nanmax(x, axis=0, keepdims=True)
e = onp.nanmax(x, axis=(-2, 3))
f = onp.nanmax(x, axis=(-3, -1), keepdims=True)
return a, b, c, d, e, f
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nanmax():
x = rand_int(2, 3, 4, 5)
x[0][2][1][3] = onp.nan
x[1][0][2][4] = onp.nan
x[1][1][1][1] = onp.nan
run_multi_test(mnp_nanmax, onp_nanmax, (x,))
def mnp_nanmin(x):
a = mnp.nanmin(x)
b = mnp.nanmin(x, keepdims=True)
c = mnp.nanmin(x, axis=-2)
d = mnp.nanmin(x, axis=0, keepdims=True)
e = mnp.nanmin(x, axis=(-2, 3))
f = mnp.nanmin(x, axis=(-3, -1), keepdims=True)
return a, b, c, d, e, f
def onp_nanmin(x):
a = onp.nanmin(x)
b = onp.nanmin(x, keepdims=True)
c = onp.nanmin(x, axis=-2)
d = onp.nanmin(x, axis=0, keepdims=True)
e = onp.nanmin(x, axis=(-2, 3))
f = onp.nanmin(x, axis=(-3, -1), keepdims=True)
return a, b, c, d, e, f
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nanmin():
x = rand_int(2, 3, 4, 5)
x[0][2][1][3] = onp.nan
x[1][0][2][4] = onp.nan
x[1][1][1][1] = onp.nan
run_multi_test(mnp_nanmin, onp_nanmin, (x,))
def mnp_nansum(x):
a = mnp.nansum(x)
b = mnp.nansum(x, keepdims=True)
@ -1927,10 +2015,17 @@ def test_mean():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_exception_innner():
with pytest.raises(ValueError):
mnp.inner(to_tensor(test_case.arrs[0]),
to_tensor(test_case.arrs[1]))
def test_corrcoef():
x = onp.random.random((3, 4)).tolist()
mnp_res = mnp.corrcoef(x)
onp_res = onp.corrcoef(x)
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.corrcoef(x[0])
onp_res = onp.corrcoef(x[0])
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.corrcoef(x, rowvar=False)
onp_res = onp.corrcoef(x, rowvar=False)
match_all_arrays(mnp_res, onp_res, error=1e-5)
@pytest.mark.level1
@ -1939,9 +2034,227 @@ def test_exception_innner():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_exception_add():
with pytest.raises(ValueError):
mnp.add(to_tensor(test_case.arrs[1]), to_tensor(test_case.arrs[2]))
def test_multi_dot():
arrays = [rand_int(3), rand_int(3, 5), rand_int(5, 2), rand_int(2, 7), rand_int(7)]
mnp_arrays = [to_tensor(arr) for arr in arrays]
match_all_arrays(mnp.multi_dot(mnp_arrays), onp.linalg.multi_dot(arrays))
match_all_arrays(mnp.multi_dot(mnp_arrays[1:]), onp.linalg.multi_dot(arrays[1:]))
match_all_arrays(mnp.multi_dot(mnp_arrays[:-1]), onp.linalg.multi_dot(arrays[:-1]))
match_all_arrays(mnp.multi_dot(mnp_arrays[1:-1]), onp.linalg.multi_dot(arrays[1:-1]))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gradient():
f = onp.random.random((3, 4, 5)).tolist()
mnp_res = mnp.gradient(f)
onp_res = onp.gradient(f)
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.gradient(f, axis=1)
onp_res = onp.gradient(f, axis=1)
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.gradient(f, -3, axis=(-1, 1))
onp_res = onp.gradient(f, -3, axis=(-1, 1))
match_all_arrays(mnp_res, onp_res, error=1e-5)
mnp_res = mnp.gradient(f, -3, 5, axis=(-1, 0))
onp_res = onp.gradient(f, -3, 5, axis=(-1, 0))
match_all_arrays(mnp_res, onp_res, error=1e-5)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_argmax():
match_res(mnp.argmax, onp.argmax, rand_int())
match_res(mnp.argmax, onp.argmax, rand_int(3))
match_res(mnp.argmax, onp.argmax, rand_int(1, 1, 1))
x = onp.random.choice(onp.arange(-100, 100), size=(2, 3, 4, 5), replace=False)
match_res(mnp.argmax, onp.argmax, x)
for i in range(-4, 4):
match_res(mnp.argmax, onp.argmax, x, axis=i)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_argmin():
match_res(mnp.argmin, onp.argmin, rand_int())
match_res(mnp.argmin, onp.argmin, rand_int(3))
match_res(mnp.argmin, onp.argmin, rand_int(1, 1, 1))
x = rand_int(2, 3, 4, 5)
match_res(mnp.argmin, onp.argmin, x)
for i in range(-4, 4):
match_res(mnp.argmin, onp.argmin, x, axis=i)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_searchsorted():
x = onp.arange(-10, 10)
y = onp.random.randint(-15, 15, size=(2, 3, 4)) + onp.random.choice([0, 0.5], (2, 3, 4))
sorter = onp.random.shuffle(onp.arange(20))
match_res(mnp.searchsorted, onp.searchsorted, x, y)
match_res(mnp.searchsorted, onp.searchsorted, x, y, side='right')
match_res(mnp.searchsorted, onp.searchsorted, x, y, sorter=sorter)
match_res(mnp.searchsorted, onp.searchsorted, x, y, side='right', sorter=sorter)
@pytest.mark.level2
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_interp():
x = onp.random.randint(-15, 15, size=(2, 3, 4)) + onp.random.choice([0, 0.5], (2, 3, 4))
xp = onp.arange(-10, 10)
fp = onp.random.uniform(-50, 50, 20)
match_res(mnp.interp, onp.interp, x, xp, fp, error=3)
match_res(mnp.interp, onp.interp, x, xp, fp, left=onp.random.rand(), error=3)
match_res(mnp.interp, onp.interp, x, xp, fp, right=onp.random.rand(), error=3)
match_res(mnp.interp, onp.interp, x, xp, fp, left=onp.random.rand(), right=onp.random.rand(), error=3)
@pytest.mark.level2
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_digitize():
bins = onp.random.randint(-10, 10, size=10)
bins.sort()
x = onp.random.randint(-15, 15, size=(2, 3, 4)) + onp.random.choice([0, 0.5], (2, 3, 4))
match_res(mnp.digitize, onp.digitize, x, [])
match_res(mnp.digitize, onp.digitize, [], [])
match_res(mnp.digitize, onp.digitize, [], bins)
match_res(mnp.digitize, onp.digitize, x, bins)
match_res(mnp.digitize, onp.digitize, x, bins, right=True)
bins = onp.flip(bins)
match_res(mnp.digitize, onp.digitize, x, bins)
match_res(mnp.digitize, onp.digitize, x, bins, right=True)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bincount():
x = onp.random.randint(0, 10, 20)
weights = onp.random.randn(20)
match_res(mnp.bincount, onp.bincount, x)
match_res(mnp.bincount, onp.bincount, x, minlength=25)
match_res(mnp.bincount, onp.bincount, x, weights, error=3)
match_res(mnp.bincount, onp.bincount, x, weights, minlength=25, error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_histogram():
x = onp.random.randint(-10, 10, 10)
weights = onp.random.randn(10)
for bins in [(1, 2, 3), [2], 1, 5, 10]:
# pylint: disable=redefined-builtin
for range in [None, (3, 3), (2, 20)]:
match_res(mnp.histogram, onp.histogram, x, bins=bins, range=range, error=3)
match_res(mnp.histogram, onp.histogram, x, bins=bins, range=range, density=True, error=3)
mnp_res = mnp.histogram(to_tensor(x), bins=bins, range=range, weights=to_tensor(weights))
onp_res = onp.histogram(x, bins=bins, range=range, weights=weights)
match_all_arrays(mnp_res, onp_res, error=3)
mnp_res = mnp.histogram(to_tensor(x), bins=bins, range=range,
weights=to_tensor(weights), density=True)
onp_res = onp.histogram(x, bins=bins, range=range, weights=weights, density=True)
match_all_arrays(mnp_res, onp_res, error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_histogramdd():
x = onp.random.randint(-10, 10, (5, 3))
y = [onp.random.randint(-10, 10, 5), onp.random.randint(-10, 10, 5), onp.random.randint(-10, 10, 5)]
mnp_y = list(map(to_tensor, y))
weights = onp.random.randn(5)
for bins in [(15, 4, 9), 10, [onp.arange(5).tolist(), onp.arange(3, 6).tolist(),
onp.arange(10, 20).tolist()]]:
# pylint: disable=redefined-builtin
for range in [None, [[0, 5], [2, 7], [1, 3]]]:
mnp_res = mnp.histogramdd(to_tensor(x), bins=bins, range=range)
onp_res = onp.histogramdd(x, bins=bins, range=range)
match_all_arrays(mnp_res[0], onp_res[0], error=3)
match_all_arrays(mnp_res[1], onp_res[1], error=3)
mnp_res = mnp.histogramdd(to_tensor(x), bins=bins, range=range, density=True)
onp_res = onp.histogramdd(x, bins=bins, range=range, density=True)
match_all_arrays(mnp_res[0], onp_res[0], error=3)
match_all_arrays(mnp_res[1], onp_res[1], error=3)
mnp_res = mnp.histogramdd(to_tensor(x), bins=bins, range=range, weights=to_tensor(weights))
onp_res = onp.histogramdd(x, bins=bins, range=range, weights=weights)
match_all_arrays(mnp_res[0], onp_res[0], error=3)
match_all_arrays(mnp_res[1], onp_res[1], error=3)
mnp_res = mnp.histogramdd(to_tensor(x), bins=bins, range=range,
weights=to_tensor(weights), density=True)
mnp_res = mnp.histogramdd(mnp_y, bins=bins, range=range, weights=to_tensor(weights),
density=True)
onp_res = onp.histogramdd(y, bins, range=range, weights=weights, density=True)
match_all_arrays(mnp_res[0], onp_res[0], error=3)
match_all_arrays(mnp_res[1], onp_res[1], error=3)
bins = onp.arange(24).reshape(3, 8)
mnp_res = mnp.histogramdd(to_tensor(x), bins=to_tensor(bins))
onp_res = onp.histogramdd(x, bins=bins)
match_all_arrays(mnp_res[0], onp_res[0], error=3)
match_all_arrays(mnp_res[1], onp_res[1], error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_histogram2d():
x = onp.random.randint(-10, 10, 10)
y = onp.random.randint(-10, 10, 10)
weights = onp.random.randn(10)
for bins in [(5, 7), 4, [onp.arange(5).tolist(), onp.arange(2, 10).tolist()], [8, [1, 2, 3]]]:
# pylint: disable=redefined-builtin
for range in [None, [(3, 3), (2, 20)]]:
match_res(mnp.histogram2d, onp.histogram2d, x, y, bins=bins, range=range, error=3)
match_res(mnp.histogram2d, onp.histogram2d, x, y, bins=bins, range=range, density=True,
error=3)
mnp_res = mnp.histogram2d(to_tensor(x), to_tensor(y), bins=bins, range=range,
weights=to_tensor(weights))
onp_res = onp.histogram2d(x, y, bins=bins, range=range, weights=weights)
match_all_arrays(mnp_res, onp_res, error=3)
mnp_res = mnp.histogram2d(to_tensor(x), to_tensor(y), bins=bins, range=range,
weights=to_tensor(weights), density=True)
onp_res = onp.histogram2d(x, y, bins=bins, range=range, weights=weights, density=True)
match_all_arrays(mnp_res, onp_res, error=3)
@pytest.mark.level1
@ -1955,6 +2268,277 @@ def test_exception_mean():
mnp.mean(to_tensor(test_case.arrs[0]), (-1, 0))
def mnp_sum(x):
a = mnp.sum(x)
b = mnp.sum(x, axis=0)
c = mnp.sum(x, axis=(0, 1))
d = mnp.sum(x, keepdims=True)
e = mnp.sum(x, initial=-1)
f = mnp.sum(x, initial=1)
g = mnp.sum(x, axis=(0, 2, -2), keepdims=True, initial=0.5, dtype=mnp.float64)
return a, b, c, d, e, f, g
def onp_sum(x):
a = onp.sum(x)
b = onp.sum(x, axis=0)
c = onp.sum(x, axis=(0, 1))
d = onp.sum(x, keepdims=True)
e = onp.sum(x, initial=-1)
f = onp.sum(x, initial=1)
g = onp.sum(x, axis=(0, 2, -2), keepdims=True, initial=0.5, dtype=onp.float64)
return a, b, c, d, e, f, g
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sum():
onp_arr = onp.random.rand(2, 3, 4).astype('float32')
mnp_arr = to_tensor(onp_arr)
for actual, expected in zip(mnp_sum(mnp_arr), onp_sum(onp_arr)):
match_array(actual.asnumpy(), expected, error=5)
def mnp_sign(x):
return mnp.sign(x)
def onp_sign(x):
return onp.sign(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sign():
onp_arr = [
onp.array(3.5).astype('float32'),
onp.arange(-5, 5).astype('float32'),
onp.random.rand(2, 3, 4).astype('float32')
]
mnp_arr = list(map(to_tensor, onp_arr))
for onp_x, mnp_x in zip(onp_arr, mnp_arr):
expected = onp_sign(onp_x)
actual = mnp_sign(mnp_x)
match_array(actual.asnumpy(), expected, error=5)
def mnp_copysign(x, y):
return mnp.copysign(x, y)
def onp_copysign(x, y):
return onp.copysign(x, y)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_copysign():
onp_arr = [[onp.array([1, -1, 2, -3]).astype('float32'),
onp.array([1, -1, -1, 1]).astype('float32')],
[onp.random.rand(2, 3, 4).astype('float32'),
onp.random.rand(2, 3, 4).astype('float32')]]
mnp_arr = list(map(to_tensor, onp_arr))
for onp_x, mnp_x in zip(onp_arr, mnp_arr):
expected = onp_copysign(onp_x[0], onp_x[1])
actual = mnp_copysign(mnp_x[0], mnp_x[1])
match_array(actual.asnumpy(), expected, error=5)
def mnp_matrix_power(x):
a = mnp.matrix_power(x, 0)
b = mnp.matrix_power(x, 1)
c = mnp.matrix_power(x, 2)
d = mnp.matrix_power(x, 3)
return a, b, c, d
def onp_matrix_power(x):
a = onp.linalg.matrix_power(x, 0)
b = onp.linalg.matrix_power(x, 1)
c = onp.linalg.matrix_power(x, 2)
d = onp.linalg.matrix_power(x, 3)
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matrix_power():
arrs = [
onp.random.rand(2, 2).astype('float32'),
onp.random.rand(3, 2, 2).astype('float32'),
onp.random.rand(5, 4, 3, 3).astype('float32'),
]
for x in arrs:
onp_res = onp_matrix_power(x)
mnp_res = mnp_matrix_power(to_tensor(x))
for expected, actual in zip(onp_res, mnp_res):
match_array(actual.asnumpy(), expected, error=5)
def mnp_around(x):
a = mnp.around(x)
b = mnp.around(x, 1)
c = mnp.around(x, 2)
d = mnp.around(x, 3)
return a, b, c, d
def onp_around(x):
a = onp.around(x)
b = onp.around(x, 1)
c = onp.around(x, 2)
d = onp.around(x, 3)
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_around():
arrs = [
onp.random.rand(2, 2).astype('float32'),
onp.random.rand(3, 2, 2).astype('float32'),
onp.random.rand(5, 4, 3, 3).astype('float32'),
]
for x in arrs:
onp_res = onp_around(x)
mnp_res = mnp_around(to_tensor(x))
for expected, actual in zip(onp_res, mnp_res):
match_array(actual.asnumpy(), expected, error=5)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polyadd():
arrs = [rand_int(), rand_int(1), rand_int(3), rand_int(7)]
for x in arrs:
for y in arrs:
match_res(mnp.polyadd, onp.polyadd, x, y)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polysub():
arrs = [rand_int(), rand_int(1), rand_int(3), rand_int(7)]
for x in arrs:
for y in arrs:
match_res(mnp.polysub, onp.polysub, x, y, error=1)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polyval():
polys = [rand_int(1), rand_int(3), rand_int(7)]
arrs = [rand_int(), rand_int(1), rand_int(3), rand_int(2, 3, 1), rand_int(1, 5, 4)]
for p in polys:
for x in arrs:
match_res(mnp.polyval, onp.polyval, p, x, error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polyder():
poly = rand_int(7)
for i in range(5):
match_res(mnp.polyder, onp.polyder, poly, m=i)
@pytest.mark.level2
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polymul():
arrs = [rand_int(), rand_int(1), rand_int(3), rand_int(7)]
for x in arrs:
for y in arrs:
match_res(mnp.polymul, onp.polymul, x, y)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_polyint():
poly = rand_int(7)
match_res(mnp.polyint, onp.polyint, poly, m=1, k=7, error=3)
match_res(mnp.polyint, onp.polyint, poly, m=1, k=[9], error=3)
match_res(mnp.polyint, onp.polyint, poly, m=3, k=2, error=3)
for i in range(5):
match_res(mnp.polyint, onp.polyint, poly, m=i, k=rand_int(i).tolist(), error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_result_type():
x = ('?', True, mnp.uint16, mnp.ones((2, 3)).astype(mnp.int32), 'float')
y = ('?', True, onp.uint16, onp.ones((2, 3)).astype(onp.int32), 'float')
for i in range(4):
mnp_args = x[:i + 1]
actual = dtype_to_nptype(mnp.result_type(*mnp_args))
onp_args = y[:i + 1]
expected = onp.result_type(*onp_args)
if expected == onp.int64:
expected = onp.int
elif expected == onp.float64:
expected = onp.float32
assert actual == expected
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unwrap():
x = onp.linspace(onp.linspace((0, 1), (10, 15), 5), onp.linspace((0, 2), (3*onp.pi, 7*onp.pi), 5), 7)
x[5:2] += onp.pi
for i in range(-3, 3):
match_res(mnp.unwrap, onp.unwrap, x, axis=i, error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -1964,3 +2548,185 @@ def test_exception_mean():
def test_exception_amax():
with pytest.raises(TypeError):
mnp.amax(mnp.array([[1, 2], [3, 4]]).astype(mnp.float32), initial=[1.0, 2.0])
def mnp_cumprod(x):
a = mnp.cumprod(x)
b = mnp.cumprod(x, axis=0)
c = mnp.cumprod(x, axis=1)
return a, b, c
def onp_cumprod(x):
a = onp.cumprod(x)
b = onp.cumprod(x, axis=0)
c = onp.cumprod(x, axis=1)
return a, b, c
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cumprod():
mnp_x = mnp.arange(1, 7).reshape(2, 3)
tensors = [mnp_x.astype('bool'),
mnp_x.astype('uint8'),
mnp_x.astype('int16'),
mnp_x.astype('float16'),
mnp_x.astype('float32')]
for x in tensors:
onp_res = onp_cumprod(x.asnumpy())
mnp_res = mnp_cumprod(x)
for expected, actual in zip(onp_res, mnp_res):
match_array(actual.asnumpy(), expected, error=5)
def mnp_ravel_multi_index(x):
a = mnp.ravel_multi_index(x, (7, 6))
b = mnp.ravel_multi_index(x, (7, 6), order='F')
c = mnp.ravel_multi_index(x, (4, 6), mode='clip')
d = mnp.ravel_multi_index(x, (4, 4), mode='wrap')
return a, b, c, d
def onp_ravel_multi_index(x):
a = onp.ravel_multi_index(x, (7, 6))
b = onp.ravel_multi_index(x, (7, 6), order='F')
c = onp.ravel_multi_index(x, (4, 6), mode='clip')
d = onp.ravel_multi_index(x, (4, 4), mode='wrap')
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ravel_multi_index():
x = mnp.array([[3, 6, 6], [4, 5, 1]])
onp_res = onp_ravel_multi_index(x.asnumpy())
mnp_res = mnp_ravel_multi_index(x)
for expected, actual in zip(onp_res, mnp_res):
match_array(actual.asnumpy(), expected, error=5)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_norm():
arrs = [rand_int(1), rand_int(9), rand_int(6, 4), rand_int(5, 2, 3, 7)]
for x in arrs:
for keepdims in [True, False]:
match_res(mnp.norm, onp.linalg.norm, x, keepdims=keepdims, error=3)
axes = [None, -1, 1, 2]
order = [None, float('inf'), -float('inf'), 0, 1, -1, 2, -2, 3.7, -5, 3]
for x, axis in zip(arrs, axes):
# pylint: disable=redefined-builtin
for ord in order:
for keepdims in [True, False]:
match_res(mnp.norm, onp.linalg.norm, x, ord=ord, axis=axis, keepdims=keepdims, error=3)
x = rand_int(3, 6, 4, 5)
axes = [(0, 1), (0, 3), (1, 3), (2, 3)]
order = [None, 'fro', float('inf'), -float('inf'), 1, -1]
for axis in axes:
# pylint: disable=redefined-builtin
for ord in order:
for keepdims in [True, False]:
match_res(mnp.norm, onp.linalg.norm, x, ord=ord, axis=axis, keepdims=keepdims, error=3)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bitwise_and():
arrs = [onp.random.randint(-100, 100, ()), onp.random.randint(-100, 100, (1,)),
onp.random.randint(-100, 100, (5,)), onp.random.randint(-100, 100, (3, 1)),
onp.random.randint(-100, 100, (4, 1, 5))]
for x in arrs:
for y in arrs:
match_res(mnp.bitwise_and, onp.bitwise_and, x, y, dtype=mnp.int32)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bitwise_or():
arrs = [onp.random.randint(-100, 100, ()), onp.random.randint(-100, 100, (1,)),
onp.random.randint(-100, 100, (5,)), onp.random.randint(-100, 100, (3, 1)),
onp.random.randint(-100, 100, (4, 1, 5))]
for x in arrs:
for y in arrs:
match_res(mnp.bitwise_or, onp.bitwise_or, x, y, dtype=mnp.int32)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bitwise_xor():
arrs = [onp.random.randint(-100, 100, ()), onp.random.randint(-100, 100, (1,)),
onp.random.randint(-100, 100, (5,)), onp.random.randint(-100, 100, (3, 1)),
onp.random.randint(-100, 100, (4, 1, 5))]
for x in arrs:
for y in arrs:
match_res(mnp.bitwise_xor, onp.bitwise_xor, x, y, dtype=mnp.int32)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_invert():
x = onp.random.randint(-100, 100, (2, 3))
match_res(mnp.invert, onp.invert, x, dtype=mnp.int16)
match_res(mnp.invert, onp.invert, x.astype(onp.uint16), dtype=mnp.uint16)
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.env_onecard
def test_rint():
arrs = [
onp.random.rand(2, 2).astype('float32'),
onp.random.rand(3, 2, 2).astype('float32'),
onp.random.rand(5, 4, 3, 3).astype('float32'),
]
for x in arrs:
for expected, actual in zip(onp.rint(x), mnp.rint(to_tensor(x))):
match_array(actual.asnumpy(), expected, error=5)
def mnp_correlate(a, v):
a = mnp.correlate(a, v, mode="valid")
b = mnp.correlate(a, v, mode="full")
c = mnp.correlate(a, v, mode="same")
d = mnp.correlate(a, v)
return a, b, c, d
def onp_correlate(a, v):
a = onp.correlate(a, v, mode="valid")
b = onp.correlate(a, v, mode="full")
c = onp.correlate(a, v, mode="same")
d = onp.correlate(a, v)
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_correlate():
first_sequences = [[1], [1, 2], [0, 0, 0, 1], [1, 2, 3, 4, 5]]
second_sequences = [[2], [0, 1], [1, 2, 3]]
for a in first_sequences:
for v in second_sequences:
mnp_res = mnp_correlate(a, v)
onp_res = onp_correlate(a, v)
match_all_arrays(mnp_res, onp_res)

View File

@ -24,7 +24,7 @@ def match_array(actual, expected, error=0):
if isinstance(actual, int):
actual = onp.asarray(actual)
if isinstance(expected, int):
if isinstance(expected, (int, tuple)):
expected = onp.asarray(expected)
if error > 0:
@ -91,11 +91,9 @@ def rand_bool(*shape):
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
dtype = kwargs.get('dtype', mnp.float32)
kwargs.pop('dtype', None)
dtype = kwargs.pop('dtype', mnp.float32)
mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
error = kwargs.get('error', 0)
kwargs.pop('error', None)
error = kwargs.pop('error', 0)
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
onp_res = onp_fn(*arrs, **kwargs)
match_all_arrays(mnp_res, onp_res, error=error)
@ -173,6 +171,7 @@ def run_logical_test(mnp_fn, onp_fn, test_case):
for x2 in test_case.boolean_arrs:
match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)
def to_tensor(obj, dtype=None):
if dtype is None:
res = Tensor(obj)