forked from mindspore-Ecosystem/mindspore
Add new np interfaces
This commit is contained in:
parent
a1c3f55aca
commit
72b365c24b
|
@ -30,13 +30,14 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res
|
|||
ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d,
|
||||
column_stack, hstack, dstack, vstack, stack, unique, moveaxis,
|
||||
tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit,
|
||||
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat)
|
||||
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat,
|
||||
rot90, select, array_split)
|
||||
from .array_creations import copy_ as copy
|
||||
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
|
||||
linspace, logspace, eye, identity, empty, empty_like,
|
||||
ones_like, zeros_like, full_like, diagonal, tril, triu,
|
||||
tri, trace, meshgrid, mgrid, ogrid, diagflat,
|
||||
diag, diag_indices, ix_)
|
||||
diag, diag_indices, ix_, indices, geomspace, vander)
|
||||
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
|
||||
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan,
|
||||
numeric_types, PINF, NINF)
|
||||
|
@ -45,35 +46,51 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide
|
|||
matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin,
|
||||
hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero,
|
||||
positive, negative, clip, floor_divide, remainder, fix, fmod, trunc,
|
||||
exp, expm1, cumsum)
|
||||
exp, expm1, exp2, kron, promote_types, divmod_, diff, cbrt,
|
||||
cross, ceil, trapz, gcd, lcm, convolve, log1p, logaddexp, log2,
|
||||
logaddexp2, log10, ediff1d, nansum, nanmean, nanvar, nanstd, cumsum, nancumsum,
|
||||
sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh,
|
||||
arctanh, arctan2, cov)
|
||||
from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite,
|
||||
isnan, isinf, isposinf, isneginf, isscalar)
|
||||
isnan, isinf, isposinf, isneginf, isscalar, logical_and, logical_not,
|
||||
logical_or, logical_xor, in1d, isin, isclose)
|
||||
|
||||
mod = remainder
|
||||
fabs = absolute
|
||||
divmod = divmod_ # pylint: disable=redefined-builtin
|
||||
abs = absolute # pylint: disable=redefined-builtin
|
||||
max = amax # pylint: disable=redefined-builtin
|
||||
min = amin # pylint: disable=redefined-builtin
|
||||
|
||||
|
||||
array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
|
||||
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
|
||||
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis',
|
||||
'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit',
|
||||
'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take',
|
||||
'repeat']
|
||||
'repeat', 'rot90', 'select', 'array_split']
|
||||
|
||||
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
|
||||
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
|
||||
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
|
||||
'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
|
||||
'diag_indices', 'ix_', 'cumsum']
|
||||
'diag_indices', 'ix_', 'indices', 'geomspace', 'vander']
|
||||
|
||||
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power',
|
||||
'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal',
|
||||
'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum',
|
||||
'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad',
|
||||
'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide',
|
||||
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum']
|
||||
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'exp2', 'kron',
|
||||
'promote_types', 'divmod', 'diff', 'cbrt', 'cross', 'ceil', 'trapz',
|
||||
'abs', 'max', 'min', 'gcd', 'lcm', 'log1p', 'logaddexp', 'log2', 'logaddexp2', 'log10',
|
||||
'convolve', 'ediff1d', 'nansum', 'nanmean', 'nanvar', 'nanstd', 'cumsum',
|
||||
'nancumsum', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh',
|
||||
'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov']
|
||||
|
||||
logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite',
|
||||
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar']
|
||||
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar', 'logical_and', 'logical_not',
|
||||
'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose']
|
||||
|
||||
__all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types
|
||||
|
||||
|
|
|
@ -13,8 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""array operations, the function docs are adapted from Numpy API."""
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from ..common import Tensor
|
||||
|
@ -27,10 +25,11 @@ from .._c_expression import Tensor as Tensor_
|
|||
from .._c_expression.typing import Float
|
||||
|
||||
from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
|
||||
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar
|
||||
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
|
||||
_expand
|
||||
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
|
||||
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
|
||||
_raise_type_error, _expanded_shape, _check_is_float, _iota, \
|
||||
_raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \
|
||||
_type_convert, _canonicalize_axis, _list_comprehensions, _ceil
|
||||
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
|
||||
from .dtypes import nan
|
||||
|
@ -49,9 +48,8 @@ def array(obj, dtype=None, copy=True, ndmin=0):
|
|||
This function creates tensors from an array-like object.
|
||||
|
||||
Args:
|
||||
obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
|
||||
any form that can be converted to a `Tensor`. This includes lists, lists of
|
||||
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
|
||||
obj (Union[int, float, bool, list, tuple]): Input data, in any form that
|
||||
can be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
|
||||
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
|
||||
of the new tensor will be inferred from obj. Default is :class:`None`.
|
||||
|
@ -76,48 +74,21 @@ def array(obj, dtype=None, copy=True, ndmin=0):
|
|||
>>> print(np.array([1,2,3]))
|
||||
[1 2 3]
|
||||
"""
|
||||
if ndmin > 0:
|
||||
# Fall back to original numpy creation.
|
||||
if isinstance(obj, Tensor):
|
||||
obj = obj.asnumpy()
|
||||
return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin))
|
||||
res = asarray(obj, dtype)
|
||||
if ndmin > res.ndim:
|
||||
res = _expand(res, ndmin)
|
||||
|
||||
if not copy:
|
||||
return asarray(obj, dtype=dtype)
|
||||
if copy:
|
||||
res = copy_(res)
|
||||
elif dtype is not None and dtype != res.dtype:
|
||||
res = res.astype(dtype)
|
||||
|
||||
obj = deepcopy(obj)
|
||||
return asarray(obj, dtype=dtype)
|
||||
return res
|
||||
|
||||
|
||||
def asarray(a, dtype=None):
|
||||
"""
|
||||
Converts the input to tensor.
|
||||
|
||||
This function converts tensors from an array-like object.
|
||||
|
||||
Args:
|
||||
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
|
||||
any form that can be converted to a `Tensor`. This includes lists, lists of
|
||||
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
|
||||
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
|
||||
of the new tensor will be inferred from obj. Default is :class:`None`.
|
||||
|
||||
Returns:
|
||||
Tensor, generated tensor with the specified dtype.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If input `a` has different sizes at different dimensions.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> print(np.asarray([1,2,3]))
|
||||
[1 2 3]
|
||||
"""
|
||||
@constexpr
|
||||
def asarray_const(a, dtype=None):
|
||||
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
|
||||
_check_input_for_asarray(a)
|
||||
|
||||
if dtype is not None:
|
||||
|
@ -149,15 +120,59 @@ def asarray(a, dtype=None):
|
|||
dtype = mstype.pytype_to_dtype(a.dtype)
|
||||
a = Tensor.from_numpy(a)
|
||||
|
||||
# If a is already a tensor and we don't need to cast dtype, return a
|
||||
if isinstance(a, Tensor):
|
||||
if dtype is None or dtype == a.dtype:
|
||||
return a
|
||||
|
||||
return Tensor(a, dtype=dtype)
|
||||
|
||||
|
||||
asarray_const = constexpr(asarray)
|
||||
def asarray(a, dtype=None):
|
||||
"""
|
||||
Converts the input to tensor.
|
||||
|
||||
This function converts tensors from an array-like object.
|
||||
|
||||
Args:
|
||||
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
|
||||
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
|
||||
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
|
||||
of the new tensor will be inferred from obj. Default is :class:`None`.
|
||||
|
||||
Returns:
|
||||
Tensor, generated tensor with the specified dtype.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If input `a` has different sizes at different dimensions.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> print(np.asarray([1,2,3]))
|
||||
[1 2 3]
|
||||
"""
|
||||
if isinstance(a, Tensor):
|
||||
if dtype is None or dtype == a.dtype:
|
||||
return a
|
||||
return a.astype(dtype)
|
||||
return asarray_const(a, dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def asfarray_const(a, dtype=mstype.float32):
|
||||
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
|
||||
_check_input_for_asarray(a)
|
||||
if isinstance(a, (list, tuple)):
|
||||
# Convert all tuple/nested tuples to lists
|
||||
a = _deep_list(a)
|
||||
# Convert all tensor sub-elements to numpy arrays
|
||||
a = _deep_tensor_to_nparray(a)
|
||||
a = onp.asarray(a)
|
||||
if a.dtype is onp.dtype('object'):
|
||||
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
|
||||
a = Tensor.from_numpy(a)
|
||||
|
||||
return Tensor(a, dtype)
|
||||
|
||||
|
||||
def asfarray(a, dtype=mstype.float32):
|
||||
|
@ -167,9 +182,8 @@ def asfarray(a, dtype=mstype.float32):
|
|||
If non-float dtype is defined, this function will return a float32 tensor instead.
|
||||
|
||||
Args:
|
||||
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
|
||||
any form that can be converted to a `Tensor`. This includes lists, lists of
|
||||
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
|
||||
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
|
||||
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
|
||||
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
|
||||
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
|
||||
|
@ -190,27 +204,18 @@ def asfarray(a, dtype=mstype.float32):
|
|||
>>> print(np.asfarray([1,2,3]))
|
||||
[1. 2. 3.]
|
||||
"""
|
||||
_check_input_for_asarray(a)
|
||||
|
||||
if dtype is None:
|
||||
return asarray(a)
|
||||
|
||||
dtype = _check_dtype(dtype)
|
||||
if dtype not in (mstype.float16, mstype.float32, mstype.float64):
|
||||
# pylint: disable=consider-using-in
|
||||
if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64:
|
||||
dtype = mstype.float32
|
||||
|
||||
if isinstance(a, (list, tuple)):
|
||||
# Convert all tuple/nested tuples to lists
|
||||
a = _deep_list(a)
|
||||
# Convert all tensor sub-elements to numpy arrays
|
||||
a = _deep_tensor_to_nparray(a)
|
||||
a = onp.asarray(a)
|
||||
if a.dtype is onp.dtype('object'):
|
||||
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
|
||||
if isinstance(a, onp.ndarray):
|
||||
a = Tensor.from_numpy(a)
|
||||
if isinstance(a, Tensor):
|
||||
return a.astype(dtype)
|
||||
|
||||
return Tensor(a, dtype)
|
||||
return asfarray_const(a)
|
||||
|
||||
|
||||
def copy_(a):
|
||||
|
@ -218,9 +223,8 @@ def copy_(a):
|
|||
Returns a tensor copy of the given object.
|
||||
|
||||
Args:
|
||||
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
|
||||
any form that can be converted to a tensor. This includes lists, lists of
|
||||
tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
|
||||
a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can
|
||||
be converted to a `Tensor`. This includes Tensor, list, tuple and numbers.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data as `a`.
|
||||
|
@ -241,8 +245,16 @@ def copy_(a):
|
|||
"""
|
||||
if not isinstance(a, Tensor):
|
||||
a = asarray_const(a)
|
||||
return a.copy()
|
||||
|
||||
# The current implementation registers a new memory location for copied tensor by
|
||||
# doing some reduandent operations.
|
||||
origin_dtype = a.dtype
|
||||
if origin_dtype == mstype.bool_:
|
||||
return F.logical_not(F.logical_not(a))
|
||||
if origin_dtype != mstype.float64:
|
||||
a = a.astype("float32")
|
||||
a = a / ones_like(a)
|
||||
a = a.astype(origin_dtype)
|
||||
return a
|
||||
|
||||
def ones(shape, dtype=mstype.float32):
|
||||
"""
|
||||
|
@ -566,6 +578,65 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
|
|||
return F.tensor_pow(base, linspace_res).astype(dtype)
|
||||
|
||||
|
||||
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
|
||||
"""
|
||||
Returns numbers spaced evenly on a log scale (a geometric progression).
|
||||
|
||||
This is similar to logspace, but with endpoints specified directly. Each output sample
|
||||
is a constant multiple of the previous.
|
||||
|
||||
Args:
|
||||
start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence.
|
||||
stop (Union[int, list(int), tuple(int), tensor]): The final value of the sequence,
|
||||
unless endpoint is False. In that case, num + 1 values are spaced over the
|
||||
interval in log-space, of which all but the last (a sequence of length num) are
|
||||
returned.
|
||||
num (int, optional): Number of samples to generate. Default is 50.
|
||||
endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is
|
||||
not included. Default is True.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can
|
||||
be in format of np.float32, or `float32`.If `dtype` is None, infer the data
|
||||
type from other input arguments. Default is None.
|
||||
axis (int, optional): The axis in the result to store the samples. Relevant
|
||||
only if start or stop is array-like. By default (0), the samples will
|
||||
be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
|
||||
Default is 0.
|
||||
|
||||
Returns:
|
||||
Tensor, with samples equally spaced on a log scale.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> output = np.geomspace(1, 256, num=9)
|
||||
>>> print(output)
|
||||
[ 1. 2. 4. 8. 16. 32. 64. 128. 256.]
|
||||
>>> output = np.geomspace(1, 256, num=8, endpoint=False)
|
||||
>>> print(output)
|
||||
[ 1. 2. 4. 8. 16. 32. 64. 128.]
|
||||
"""
|
||||
start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis)
|
||||
root = num
|
||||
if endpoint:
|
||||
root -= 1
|
||||
bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1/(root)))
|
||||
exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root),
|
||||
num, endpoint=endpoint, dtype=dtype, axis=axis)
|
||||
shape = F.shape(bases)
|
||||
axis = axis + F.rank(bases) + 1 if axis < 0 else axis
|
||||
expanded_shape = _tuple_getitem(shape, axis, False) + (1,) + _tuple_getitem(shape, axis)
|
||||
bases = F.reshape(bases, expanded_shape)
|
||||
start = F.reshape(start, expanded_shape)
|
||||
res = F.tensor_mul(F.tensor_pow(bases, exponents), start)
|
||||
if dtype is not None:
|
||||
res = F.cast(res, dtype)
|
||||
return res
|
||||
|
||||
|
||||
def eye(N, M=None, k=0, dtype=mstype.float32):
|
||||
"""
|
||||
Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere.
|
||||
|
@ -757,7 +828,7 @@ def empty_like(prototype, dtype=None, shape=None):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
|
||||
>>> a = np.ones((4,1,2))
|
||||
>>> output = np.empty_like(a)
|
||||
>>> print(output)
|
||||
# result may vary
|
||||
|
@ -794,7 +865,7 @@ def ones_like(a, dtype=None, shape=None):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
|
||||
>>> a = np.ones((4,1,2))
|
||||
>>> output = np.ones_like(a)
|
||||
>>> print(output)
|
||||
[[[1. 1.]]
|
||||
|
@ -832,7 +903,7 @@ def zeros_like(a, dtype=None, shape=None):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
|
||||
>>> a = np.ones((4,1,2))
|
||||
>>> output = np.zeros_like(a)
|
||||
>>> print(output)
|
||||
[[[0. 0.]]
|
||||
|
@ -871,7 +942,7 @@ def full_like(a, fill_value, dtype=None, shape=None):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))]
|
||||
>>> a = np.ones((4,1,2))
|
||||
>>> output = np.full_like(a, 0.5)
|
||||
>>> print(output)
|
||||
[[[0.5 0.5]]
|
||||
|
@ -1175,9 +1246,8 @@ def _index(i, size, Cartesian=True):
|
|||
if Cartesian:
|
||||
if i == 1:
|
||||
return 0
|
||||
if i == 0:
|
||||
if size >= 2:
|
||||
return 1
|
||||
if i == 0 and size >= 2:
|
||||
return 1
|
||||
return i
|
||||
|
||||
|
||||
|
@ -1630,3 +1700,103 @@ def ix_(*args):
|
|||
return _raise_value_error('Cross index must be 1 dimensional')
|
||||
res += (F.reshape(arr, _expanded_shape(ndim, arr.size, i)),)
|
||||
return res
|
||||
|
||||
|
||||
def vander(x, N=None, increasing=False):
|
||||
"""
|
||||
Generates a Vandermonde matrix.
|
||||
|
||||
The columns of the output matrix are powers of the input vector. The order of
|
||||
the powers is determined by the increasing boolean argument. Specifically, when
|
||||
increasing is `False`, the i-th output column is the input vector raised element-wise
|
||||
to the power of :math:`N - i - 1`. Such a matrix with a geometric progression in each row
|
||||
is named for Alexandre-Theophile Vandermonde.
|
||||
|
||||
Args:
|
||||
x (Union[list, tuple, Tensor]): 1-D input array.
|
||||
N (int, optional): Number of columns in the output. If N is not specified, a
|
||||
square array is returned (``N = len(x)``).
|
||||
increasing (bool, optional): Order of the powers of the columns. If True, the
|
||||
powers increase from left to right, if False (the default) they are reversed.
|
||||
|
||||
Returns:
|
||||
Vandermonde matrix. If `increasing` is `False`, the first column is :math:`x^{(N-1)}`,
|
||||
the second :math:`x^{(N-2)}` and so forth. If `increasing` is `True`, the columns are
|
||||
:math:`x^0, x^1, ..., x^{(N-1)}`.
|
||||
|
||||
Raises:
|
||||
TypeError: If inputs have types not specified above.
|
||||
ValueError: If `x` is not 1-D, or `N` < 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> print(np.vander([1,2,3,4,5]))
|
||||
[[ 1 1 1 1 1]
|
||||
[ 16 8 4 2 1]
|
||||
[ 81 27 9 3 1]
|
||||
[256 64 16 4 1]
|
||||
[625 125 25 5 1]]
|
||||
"""
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = asarray_const(x)
|
||||
elif not isinstance(x, Tensor):
|
||||
_raise_type_error("Input x must be list, tuple or Tensor, but got ", x)
|
||||
if x.ndim != 1:
|
||||
_raise_value_error("Input x must be 1-D, but got dimension=", x.ndim)
|
||||
N = N or x.size
|
||||
if not isinstance(N, int):
|
||||
_raise_type_error("Input N must be an integer.")
|
||||
if N <= 0:
|
||||
_raise_value_error("Input N must > 0.")
|
||||
if not isinstance(increasing, bool):
|
||||
_raise_type_error("increasing must be a bool.")
|
||||
exponent = _iota(x.dtype, N, increasing)
|
||||
x = F.expand_dims(x, 1)
|
||||
exponent = F.expand_dims(exponent, 0)
|
||||
return F.tensor_pow(x, exponent)
|
||||
|
||||
|
||||
def indices(dimensions, dtype=mstype.int32, sparse=False):
|
||||
"""
|
||||
Returns an array representing the indices of a grid.
|
||||
|
||||
Computes an array where the subarrays contain index values 0, 1, …
|
||||
varying only along the corresponding axis.
|
||||
|
||||
Args:
|
||||
dimensions (tuple or list of ints): The shape of the grid.
|
||||
dtype (data type, optional): Data type of the result.
|
||||
sparse (boolean, optional): Defaults to False. Return a sparse
|
||||
representation of the grid instead of a dense representation.
|
||||
|
||||
Returns:
|
||||
Tensor or tuple of Tensor, If `sparse` is False, returns one array
|
||||
of grid indices, ``grid.shape = (len(dimensions),) + tuple(dimensions)``.
|
||||
If sparse is True, returns a tuple of arrays, with
|
||||
``grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)`` with
|
||||
``dimensions[i]`` in the `ith` place
|
||||
|
||||
Raises:
|
||||
TypeError: if input dimensions is not a tuple or list.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> grid = np.indices((2, 3))
|
||||
>>> print(indices)
|
||||
[Tensor(shape=[2, 3], dtype=Int32, value=
|
||||
[[0, 0, 0],
|
||||
[1, 1, 1]]), Tensor(shape=[2, 3], dtype=Int32, value=
|
||||
[[0, 1, 2],
|
||||
[0, 1, 2]])]
|
||||
"""
|
||||
if not isinstance(dimensions, (tuple, list)):
|
||||
_raise_type_error('Shape of the grid must be tuple or list')
|
||||
grids = ()
|
||||
for d in dimensions:
|
||||
grids += (arange(d, dtype=dtype),)
|
||||
return meshgrid(*grids, sparse=sparse, indexing='ij')
|
||||
|
|
|
@ -24,62 +24,19 @@ from ..ops.primitive import constexpr
|
|||
from ..nn import Cell
|
||||
|
||||
from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \
|
||||
_check_input_tensor, _broadcast_to
|
||||
_check_input_tensor, _broadcast_to, _to_tensor
|
||||
from .utils_const import _check_axes_range, _check_start_normalize, \
|
||||
_raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \
|
||||
_check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
|
||||
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
|
||||
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
|
||||
_tuple_getitem, _expanded_shape
|
||||
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem
|
||||
|
||||
# According to official numpy reference, the dimension of a numpy array must be less
|
||||
# than 32
|
||||
MAX_NUMPY_DIMS = 32
|
||||
|
||||
|
||||
@constexpr
|
||||
def _prepare_shape_for_expand_dims(shape, axes):
|
||||
"""
|
||||
Creates the expanded new shape based on the shape and given axes
|
||||
|
||||
Args:
|
||||
shape (tuple): the shape of the tensor
|
||||
axes Union(int, tuple(int), list(int)): the axes with dimensions expanded.
|
||||
|
||||
Returns:
|
||||
new_shape(tuple): the shape with dimensions expanded.
|
||||
"""
|
||||
|
||||
new_shape = []
|
||||
shape_idx = 0
|
||||
new_shape_length = len(shape)
|
||||
|
||||
# Convert to set
|
||||
if isinstance(axes, int):
|
||||
new_shape_length += 1
|
||||
if axes >= new_shape_length or axes < -new_shape_length:
|
||||
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}")
|
||||
axes = {axes}
|
||||
|
||||
elif isinstance(axes, (list, tuple)):
|
||||
new_shape_length += len(axes)
|
||||
for axis in axes:
|
||||
if axis >= new_shape_length or axis < -new_shape_length:
|
||||
raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}")
|
||||
axes = set(axes)
|
||||
|
||||
else:
|
||||
raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")
|
||||
|
||||
for new_shape_idx in range(new_shape_length):
|
||||
if new_shape_idx in axes or new_shape_idx - new_shape_length in axes:
|
||||
new_shape.append(1)
|
||||
else:
|
||||
new_shape.append(shape[shape_idx])
|
||||
shape_idx += 1
|
||||
return tuple(new_shape)
|
||||
|
||||
|
||||
def expand_dims(a, axis):
|
||||
"""
|
||||
Expands the shape of a tensor.
|
||||
|
@ -109,10 +66,15 @@ def expand_dims(a, axis):
|
|||
(1, 2, 2)
|
||||
"""
|
||||
_check_input_tensor(a)
|
||||
shape = F.shape(a)
|
||||
# yield expanded shape based on the axes
|
||||
new_shape = _prepare_shape_for_expand_dims(shape, axis)
|
||||
return F.reshape(a, new_shape)
|
||||
if not isinstance(axis, (int, tuple, list)):
|
||||
_raise_type_error("axis must be tuple, list or int, but got ", axis)
|
||||
if isinstance(axis, int):
|
||||
return F.expand_dims(a, axis)
|
||||
ndim = a.ndim + len(axis)
|
||||
axis = _canonicalize_axis(axis, ndim)
|
||||
for ax in axis:
|
||||
a = F.expand_dims(a, ax)
|
||||
return a
|
||||
|
||||
|
||||
def squeeze(a, axis=None):
|
||||
|
@ -1091,6 +1053,9 @@ def roll(a, shift, axis=None):
|
|||
Returns:
|
||||
Tensor, with the same shape as a.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If axis exceeds `a.ndim`, or `shift` and `axis` cannot broadcast.
|
||||
|
@ -1212,12 +1177,6 @@ def moveaxis(a, source, destination):
|
|||
return F.transpose(a, perm)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _seq_prod(seq1, seq2):
|
||||
"""Returns the element-wise product of seq1 and seq2."""
|
||||
return tuple(map(lambda x, y: x*y, seq1, seq2))
|
||||
|
||||
|
||||
def tile(a, reps):
|
||||
"""
|
||||
Constructs an array by repeating `a` the number of times given by `reps`.
|
||||
|
@ -1355,6 +1314,60 @@ def broadcast_arrays(*args):
|
|||
return res
|
||||
|
||||
|
||||
def array_split(x, indices_or_sections, axis=0):
|
||||
"""
|
||||
Splits a tensor into multiple sub-tensors.
|
||||
|
||||
Note:
|
||||
Currently, array_split only supports :class:`mindspore.float32` on ``CPU``.
|
||||
|
||||
The only difference between ``np.split`` and ``np.array_split`` is that
|
||||
``np.array_split`` allows indices_or_sections to be an integer that does not
|
||||
equally divide the axis. For a tensor of length l that should be split into
|
||||
n sections, it returns :math:`l % n` sub-arrays of size :math:`l//n + 1` and
|
||||
the rest of size :math:`l//n`.
|
||||
|
||||
Args:
|
||||
x (Tensor): A Tensor to be divided.
|
||||
indices_or_sections (Union[int, tuple(int), list(int)]):
|
||||
If integer, :math:`N`, the tensor will be divided into
|
||||
:math:`N` tensors along axis.
|
||||
If tuple(int), list(int) or of sorted integers,
|
||||
the entries indicate where along axis the array is split.
|
||||
For example, :math:`[2, 3]` would, for :math:`axis=0`, result in
|
||||
three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`.
|
||||
If an index exceeds the dimension of the array along axis,
|
||||
an empty sub-array is returned correspondingly.
|
||||
axis (int): The axis along which to split. Default: 0.
|
||||
|
||||
Returns:
|
||||
A list of sub-tensors.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument `indices_or_sections` is not integer,
|
||||
tuple(int) or list(int) or argument `axis` is not integer.
|
||||
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> input_x = np.arange(9).astype("float32")
|
||||
>>> output = np.array_split(input_x, 4)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3], dtype=Float32,
|
||||
value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
|
||||
Tensor(shape=[2], dtype=Float32,
|
||||
value= [ 3.00000000e+00, 4.00000000e+00]),
|
||||
Tensor(shape=[2], dtype=Float32,
|
||||
value= [ 5.00000000e+00, 6.00000000e+00]),
|
||||
Tensor(shape=[2], dtype=Float32,
|
||||
value= [ 7.00000000e+00, 8.00000000e+00]))
|
||||
"""
|
||||
return _split(x, indices_or_sections, opname="array_split", axis=axis)
|
||||
|
||||
|
||||
def split(x, indices_or_sections, axis=0):
|
||||
"""
|
||||
Splits a tensor into multiple sub-tensors along the given axis.
|
||||
|
@ -1380,9 +1393,12 @@ def split(x, indices_or_sections, axis=0):
|
|||
tuple(int) or list(int) or argument `axis` is not integer.
|
||||
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> input_x = np.arange(9).astype('float32')
|
||||
>>> input_x = np.arange(9).astype("float32")
|
||||
>>> output = np.split(input_x, 3)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3], dtype=Float32,
|
||||
|
@ -1392,13 +1408,32 @@ def split(x, indices_or_sections, axis=0):
|
|||
Tensor(shape=[3], dtype=Float32,
|
||||
value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
||||
"""
|
||||
return _split(x, indices_or_sections, opname="split", axis=axis)
|
||||
|
||||
|
||||
def _split(x, indices_or_sections, opname, axis=0):
|
||||
"""Splits a tensor based on ``np.split`` or ``np.array_split``."""
|
||||
_check_input_tensor(x)
|
||||
_ = _check_axis_type(axis, True, False, False)
|
||||
axis = _canonicalize_axis(axis, x.ndim)
|
||||
res = None
|
||||
arr_shape = x.shape
|
||||
length_along_dim = arr_shape[axis]
|
||||
if isinstance(indices_or_sections, int):
|
||||
_split = P.Split(axis, indices_or_sections)
|
||||
res = _split(x)
|
||||
if opname == "split" or length_along_dim % indices_or_sections == 0:
|
||||
res = P.Split(axis, indices_or_sections)(x)
|
||||
else:
|
||||
num_long_tensor = length_along_dim % indices_or_sections
|
||||
num_short_tensor = indices_or_sections - num_long_tensor
|
||||
length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
|
||||
length2 = length_along_dim - length1
|
||||
start1 = _list_comprehensions(F.rank(x), 0, True)
|
||||
size1 = _tuple_setitem(arr_shape, axis, length1)
|
||||
start2 = _tuple_setitem(start1, axis, length1)
|
||||
size2 = _tuple_setitem(arr_shape, axis, length2)
|
||||
res = P.Split(axis, num_long_tensor)(F.tensor_slice(x, start1, size1)) + \
|
||||
P.Split(axis, num_short_tensor)(F.tensor_slice(x, start2, size2))
|
||||
|
||||
elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections):
|
||||
res = _split_sub_tensors(x, indices_or_sections, axis)
|
||||
else:
|
||||
|
@ -1921,7 +1956,6 @@ def repeat(a, repeats, axis=None):
|
|||
if repeats == 0:
|
||||
return _empty(F.dtype(a), (0,))
|
||||
return C.repeat_elements(a, repeats, axis)
|
||||
|
||||
shape = F.shape(a)
|
||||
size = shape[axis]
|
||||
if len(repeats) != size:
|
||||
|
@ -1932,3 +1966,144 @@ def repeat(a, repeats, axis=None):
|
|||
if rep != 0:
|
||||
repeated_subs.append(C.repeat_elements(sub, rep, axis))
|
||||
return concatenate(repeated_subs, axis)
|
||||
|
||||
|
||||
def rot90(a, k=1, axes=(0, 1)):
|
||||
"""
|
||||
Rotates a tensor by 90 degrees in the plane specified by axes.
|
||||
Rotation direction is from the first towards the second axis.
|
||||
|
||||
Args:
|
||||
a (Tensor): Input tensor of two or more dimensions.
|
||||
k (int): Number of times the tensor is rotated by 90 degrees. Default: 1.
|
||||
axes (Union[tuple(int), list(int)]): The tensor is rotated in the plane
|
||||
defined by the axes. Default: `(0, 1)`.
|
||||
Axes must be different and with the shape of `(2,)`.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: if input `a` is not a Tensor or
|
||||
the argument `k` is not integer or
|
||||
the argument `axes` is not tuple of ints or list of ints.
|
||||
ValueError: if any axis is out of range or
|
||||
the length of `axes` is not `2`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.arange(24).reshape((2, 3, 4))
|
||||
>>> output = np.rot90(a)
|
||||
>>> print(output)
|
||||
[[[ 8 9 10 11]
|
||||
[20 21 22 23]]
|
||||
[[ 4 5 6 7]
|
||||
[16 17 18 19]]
|
||||
[[ 0 1 2 3]
|
||||
[12 13 14 15]]]
|
||||
>>> output = np.rot90(a, 3, (1, 2))
|
||||
>>> print(output)
|
||||
[[[ 8 4 0]
|
||||
[ 9 5 1]
|
||||
[10 6 2]
|
||||
[11 7 3]]
|
||||
[[20 16 12]
|
||||
[21 17 13]
|
||||
[22 18 14]
|
||||
[23 19 15]]]
|
||||
"""
|
||||
_check_input_tensor(a)
|
||||
|
||||
if not isinstance(k, int):
|
||||
_raise_type_error("integer argument expected, but got ", k)
|
||||
k = k % 4 if k >= 0 else 4 - (-k % 4)
|
||||
|
||||
if not isinstance(axes, (tuple, list)):
|
||||
_raise_type_error("tuple(ints) or list(ints) expected, but got ", axes)
|
||||
if len(axes) != 2:
|
||||
_raise_value_error("len(axes) must be 2.")
|
||||
axis1, axis2 = axes[0], axes[1]
|
||||
axis1 = _canonicalize_axis(axis1, a.ndim)
|
||||
axis2 = _canonicalize_axis(axis2, a.ndim)
|
||||
if axis1 == axis2:
|
||||
_raise_value_error('Axes must be different.')
|
||||
|
||||
if k == 0:
|
||||
return a
|
||||
if k == 2:
|
||||
return flip(flip(a, axis1), axis2)
|
||||
perm = _list_comprehensions(a.ndim)
|
||||
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
|
||||
if k == 1:
|
||||
return flip(transpose(a, perm), axis1)
|
||||
return flip(transpose(a, perm), axis2)
|
||||
|
||||
|
||||
def select(condlist, choicelist, default=0):
|
||||
"""
|
||||
Returns an array drawn from elements in `choicelist`, depending on conditions.
|
||||
|
||||
Args:
|
||||
condlist (array_like): The list of conditions which determine from which array
|
||||
in `choicelist` the output elements are taken. When multiple conditions are
|
||||
satisfied, the first one encountered in `condlist` is used.
|
||||
choicelist (array_like): The list of arrays from which the output elements are
|
||||
taken. It has to be of the same length as `condlist`.
|
||||
default (scalar, optional): The element inserted in output when all conditions
|
||||
evaluate to `False`.
|
||||
|
||||
Returns:
|
||||
Tensor, the output at position `m` is the `m-th` element of the array in
|
||||
`choicelist` where the `m-th` element of the corresponding array in `condlist`
|
||||
is `True`.
|
||||
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Raises:
|
||||
ValueError: if ``len(condlist) != len(choicelist)``.
|
||||
|
||||
Examples:
|
||||
>>> condlist = [[True, True, True, False, False],
|
||||
[False, False, True, False, True]]
|
||||
>>> choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]]
|
||||
>>> output = np.select(condlist, choicelist)
|
||||
>>> print(output)
|
||||
[ 0 1 2 0 16]
|
||||
"""
|
||||
condlist, choicelist = _to_tensor(condlist, choicelist)
|
||||
shape_cond = F.shape(condlist)
|
||||
shape_choice = F.shape(choicelist)
|
||||
if F.rank(condlist) == 0 or F.rank(condlist) == 0:
|
||||
_raise_value_error('input cannot be scalars')
|
||||
case_num = shape_cond[0]
|
||||
if shape_choice[0] != case_num:
|
||||
_raise_value_error('list of cases must be same length as list of conditions')
|
||||
|
||||
# performs broadcast over the cases in condlist and choicelist
|
||||
case_size = _infer_out_shape(shape_cond[1:], shape_choice[1:])
|
||||
shape_broadcasted = (case_num,) + case_size
|
||||
ndim = len(shape_broadcasted)
|
||||
shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(condlist), 1, True) +
|
||||
shape_cond[1:])
|
||||
condlist = _broadcast_to_shape(F.reshape(condlist, shape_cond_expanded), shape_broadcasted)
|
||||
shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(choicelist), 1, True) +
|
||||
shape_choice[1:])
|
||||
choicelist = _broadcast_to_shape(F.reshape(choicelist, shape_choice_expanded), shape_broadcasted)
|
||||
|
||||
slice_start = _list_comprehensions(ndim - 1, 0, True)
|
||||
slice_size = (1,) + case_size
|
||||
dtype = F.dtype(choicelist)
|
||||
if _get_device() == 'CPU' and not _check_is_float(dtype):
|
||||
# F.tensor_slice only supports float on CPU
|
||||
choicelist = F.cast(choicelist, mstype.float32)
|
||||
default_slice = F.fill(F.dtype(choicelist), slice_size, default)
|
||||
for i in range(case_num - 1, -1, -1):
|
||||
cond_slice = F.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size)
|
||||
choice_slice = F.tensor_slice(choicelist, (i,) + slice_start, slice_size)
|
||||
default_slice = F.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice)
|
||||
return F.reshape(default_slice, (case_size)).astype(dtype)
|
||||
|
|
|
@ -169,3 +169,16 @@ promotion_rule = {
|
|||
(bool_, float32): float32,
|
||||
(bool_, float64): float64,
|
||||
}
|
||||
|
||||
rule_for_trigonometric = {float16: float16,
|
||||
float32: float32,
|
||||
float64: float64,
|
||||
int8: float16,
|
||||
int16: float32,
|
||||
int32: float32,
|
||||
int64: float32,
|
||||
uint8: float16,
|
||||
uint16: float32,
|
||||
uint32: float32,
|
||||
uint64: float32,
|
||||
bool_: float16}
|
||||
|
|
|
@ -15,33 +15,29 @@
|
|||
"""logical operations, the function docs are adapted from Numpy API."""
|
||||
|
||||
|
||||
from .math_ops import _apply_tensor_op
|
||||
from ..ops import functional as F
|
||||
from ..ops.primitive import constexpr
|
||||
from ..common import dtype as mstype
|
||||
from ..common import Tensor
|
||||
from .._c_expression import typing
|
||||
|
||||
from .array_creations import zeros, ones
|
||||
from .utils import _check_input_tensor
|
||||
from .math_ops import _apply_tensor_op, absolute
|
||||
from .array_creations import zeros, ones, empty
|
||||
from .utils import _check_input_tensor, _to_tensor, _isnan
|
||||
from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape
|
||||
|
||||
|
||||
def not_equal(x1, x2, out=None, where=True, dtype=None):
|
||||
def not_equal(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns (x1 != x2) element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): First input tensor to be compared.
|
||||
x2 (Tensor): Second input tensor to be compared.
|
||||
out (Tensor or None, optional), default is None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -65,33 +61,21 @@ def not_equal(x1, x2, out=None, where=True, dtype=None):
|
|||
[False True]]
|
||||
"""
|
||||
_check_input_tensor(x1, x2)
|
||||
return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.not_equal, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def less_equal(x1, x2, out=None, where=True, dtype=None):
|
||||
def less_equal(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns the truth value of ``(x1 <= x2)`` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input array.
|
||||
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -113,33 +97,21 @@ def less_equal(x1, x2, out=None, where=True, dtype=None):
|
|||
[False True True]
|
||||
"""
|
||||
_check_input_tensor(x1, x2)
|
||||
return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.tensor_le, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def less(x1, x2, out=None, where=True, dtype=None):
|
||||
def less(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns the truth value of ``(x1 < x2)`` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): input array.
|
||||
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -160,33 +132,21 @@ def less(x1, x2, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
[ True False]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.tensor_lt, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def greater_equal(x1, x2, out=None, where=True, dtype=None):
|
||||
def greater_equal(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns the truth value of ``(x1 >= x2)`` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input array.
|
||||
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -207,33 +167,21 @@ def greater_equal(x1, x2, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
[ True True False]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.tensor_ge, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def greater(x1, x2, out=None, where=True, dtype=None):
|
||||
def greater(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns the truth value of ``(x1 > x2)`` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input array.
|
||||
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -254,33 +202,21 @@ def greater(x1, x2, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
[ True False]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.tensor_gt, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def equal(x1, x2, out=None, where=True, dtype=None):
|
||||
def equal(x1, x2, dtype=None):
|
||||
"""
|
||||
Returns the truth value of ``(x1 == x2)`` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input array.
|
||||
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -301,34 +237,22 @@ def equal(x1, x2, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
[ True True False]
|
||||
"""
|
||||
return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.equal, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def isfinite(x, out=None, where=True, dtype=None):
|
||||
def isfinite(x, dtype=None):
|
||||
"""
|
||||
Tests element-wise for finiteness (not infinity or not Not a Number).
|
||||
|
||||
The result is returned as a boolean array.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -351,37 +275,20 @@ def isfinite(x, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
False
|
||||
"""
|
||||
return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.isfinite, x, dtype=dtype)
|
||||
|
||||
|
||||
def _isnan(x):
|
||||
"""Computes isnan without applying keyword arguments."""
|
||||
return F.not_equal(x, x)
|
||||
|
||||
|
||||
def isnan(x, out=None, where=True, dtype=None):
|
||||
def isnan(x, dtype=None):
|
||||
"""
|
||||
Tests element-wise for NaN and return result as a boolean array.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -404,7 +311,7 @@ def isnan(x, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
False
|
||||
"""
|
||||
return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(_isnan, x, dtype=dtype)
|
||||
|
||||
|
||||
def _isinf(x):
|
||||
|
@ -419,31 +326,19 @@ def _isinf(x):
|
|||
return F.cast(res, mstype.bool_)
|
||||
|
||||
|
||||
def isinf(x, out=None, where=True, dtype=None):
|
||||
def isinf(x, dtype=None):
|
||||
"""
|
||||
Tests element-wise for positive or negative infinity.
|
||||
|
||||
Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False.
|
||||
|
||||
Note:
|
||||
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
out (Tensor or None, optional): defaults to None.
|
||||
where (Tensor or None, optional): For any non-default value of type other
|
||||
than :class:`Tensor` or :class:`None`, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is `True`, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default ``out=None``,
|
||||
locations within it where the condition is `False` will remain
|
||||
uninitialized.
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
|
@ -466,7 +361,7 @@ def isinf(x, out=None, where=True, dtype=None):
|
|||
>>> print(output)
|
||||
[ True True False False]
|
||||
"""
|
||||
return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(_isinf, x, dtype=dtype)
|
||||
|
||||
|
||||
def _is_sign_inf(x, fn):
|
||||
|
@ -562,7 +457,7 @@ def isscalar(element):
|
|||
element (any): Input argument, can be of any type and shape.
|
||||
|
||||
Returns:
|
||||
Boolean, True if `element` is a scalar type, False if it is not.
|
||||
Boolean, True if `element` is a scalar type, False if it is not.
|
||||
|
||||
Raises:
|
||||
TypeError: if the type of `element` is not supported by mindspore parser.
|
||||
|
@ -587,3 +482,302 @@ def isscalar(element):
|
|||
"""
|
||||
obj_type = F.typeof(element)
|
||||
return not isinstance(obj_type, Tensor) and _isscalar(obj_type)
|
||||
|
||||
|
||||
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
||||
"""
|
||||
Returns a boolean tensor where two tensors are element-wise equal within a tolerance.
|
||||
|
||||
The tolerance values are positive, typically very small numbers. The relative
|
||||
difference (:math:`rtol * abs(b)`) and the absolute difference `atol` are added together
|
||||
to compare against the absolute difference between `a` and `b`.
|
||||
|
||||
Note:
|
||||
For finite values, isclose uses the following equation to test whether two
|
||||
floating point values are equivalent.
|
||||
:math:`absolute(a - b) <= (atol + rtol * absolute(b))`
|
||||
|
||||
Args:
|
||||
a (Union[Tensor, list, tuple]): Input first tensor to compare.
|
||||
b (Union[Tensor, list, tuple]): Input second tensor to compare.
|
||||
rtol (Number): The relative tolerance parameter (see Note).
|
||||
atol (Number): The absolute tolerance parameter (see Note).
|
||||
equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in
|
||||
`a` will be considered equal to ``NaN`` in `b` in the output tensor.
|
||||
|
||||
Returns:
|
||||
A ``bool`` tensor of where `a` and `b` are equal within the given tolerance.
|
||||
|
||||
Raises:
|
||||
TypeError: If inputs have types not specified above.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> a = np.array([0,1,2,float('inf'),float('inf'),float('nan')])
|
||||
>>> b = np.array([0,1,-2,float('-inf'),float('inf'),float('nan')])
|
||||
>>> print(np.isclose(a, b))
|
||||
[ True True False False True False]
|
||||
>>> print(np.isclose(a, b, equal_nan=True))
|
||||
[ True True False False True True]
|
||||
"""
|
||||
a, b = _to_tensor(a, b)
|
||||
if not isinstance(rtol, (int, float, bool)) or not isinstance(atol, (int, float, bool)):
|
||||
_raise_type_error("rtol and atol are expected to be numbers.")
|
||||
if not isinstance(equal_nan, bool):
|
||||
_raise_type_error("equal_nan is expected to be bool.")
|
||||
|
||||
if _is_shape_empty(a.shape) or _is_shape_empty(b.shape):
|
||||
return empty(_infer_out_shape(a.shape, b.shape), dtype=mstype.bool_)
|
||||
rtol = _to_tensor(rtol).astype("float32")
|
||||
atol = _to_tensor(atol).astype("float32")
|
||||
res = absolute(a - b) <= (atol + rtol * absolute(b))
|
||||
# infs are treated as equal
|
||||
a_posinf = isposinf(a)
|
||||
b_posinf = isposinf(b)
|
||||
a_neginf = isneginf(a)
|
||||
b_neginf = isneginf(b)
|
||||
same_inf = F.logical_or(F.logical_and(a_posinf, b_posinf), F.logical_and(a_neginf, b_neginf))
|
||||
diff_inf = F.logical_or(F.logical_and(a_posinf, b_neginf), F.logical_and(a_neginf, b_posinf))
|
||||
res = F.logical_and(F.logical_or(res, same_inf), F.logical_not(diff_inf))
|
||||
both_nan = F.logical_and(_isnan(a), _isnan(b))
|
||||
if equal_nan:
|
||||
res = F.logical_or(both_nan, res)
|
||||
else:
|
||||
res = F.logical_and(F.logical_not(both_nan), res)
|
||||
return res
|
||||
|
||||
|
||||
def in1d(ar1, ar2, invert=False):
|
||||
"""
|
||||
Tests whether each element of a 1-D array is also present in a second array.
|
||||
|
||||
Returns a boolean array the same length as `ar1` that is True where an element
|
||||
of `ar1` is in `ar2` and False otherwise.
|
||||
|
||||
Note:
|
||||
Numpy argument `assume_unique` is not supported since the implementation does
|
||||
not rely on the uniqueness of the input arrays.
|
||||
|
||||
Args:
|
||||
ar1 (array_like): Input array with shape `(M,)`.
|
||||
ar2 (array_like): The values against which to test each value of `ar1`.
|
||||
invert (boolean, optional): If True, the values in the returned array are
|
||||
inverted (that is, False where an element of `ar1` is in `ar2` and True
|
||||
otherwise). Default is False.
|
||||
|
||||
Returns:
|
||||
Tensor, with shape `(M,)`. The values ``ar1[in1d]`` are in `ar2`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> test = np.array([0, 1, 2, 5, 0])
|
||||
>>> states = [0, 2]
|
||||
>>> mask = np.in1d(test, states)
|
||||
>>> print(mask)
|
||||
[ True False True False True]
|
||||
>>> mask = np.in1d(test, states, invert=True)
|
||||
>>> print(mask)
|
||||
[False True False True False]
|
||||
"""
|
||||
ar1, ar2 = _to_tensor(ar1, ar2)
|
||||
ar1 = F.expand_dims(ar1.ravel(), -1)
|
||||
ar2 = ar2.ravel()
|
||||
included = F.equal(ar1, ar2)
|
||||
# F.reduce_sum only supports float
|
||||
res = F.reduce_sum(included.astype(mstype.float32), -1).astype(mstype.bool_)
|
||||
if invert:
|
||||
res = F.equal(res, _to_tensor(False))
|
||||
return res
|
||||
|
||||
|
||||
def isin(element, test_elements, invert=False):
|
||||
"""
|
||||
Calculates element in `test_elements`, broadcasting over `element` only. Returns a
|
||||
boolean array of the same shape as `element` that is True where an element of
|
||||
`element` is in `test_elements` and False otherwise.
|
||||
|
||||
Note:
|
||||
Numpy argument `assume_unique` is not supported since the implementation does
|
||||
not rely on the uniqueness of the input arrays.
|
||||
|
||||
Args:
|
||||
element (array_like): Input array.
|
||||
test_elements (array_like): The values against which to test each value of
|
||||
`element`.
|
||||
invert (boolean, optional): If True, the values in the returned array are
|
||||
inverted, as if calculating `element` not in `test_elements`. Default is False.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as `element`. The values ``element[isin]`` are in
|
||||
`test_elements`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> element = 2*np.arange(4).reshape((2, 2))
|
||||
>>> test_elements = [1, 2, 4, 8]
|
||||
>>> mask = np.isin(element, test_elements)
|
||||
>>> print(mask)
|
||||
[[False True]
|
||||
[ True False]]
|
||||
>>> mask = np.isin(element, test_elements, invert=True)
|
||||
>>> print(mask)
|
||||
[[ True False]
|
||||
[False True]]
|
||||
"""
|
||||
res = in1d(element, test_elements, invert=invert)
|
||||
return F.reshape(res, F.shape(element))
|
||||
|
||||
|
||||
def logical_not(a, dtype=None):
|
||||
"""
|
||||
Computes the truth value of NOT `a` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
|
||||
Args:
|
||||
a (Tensor): The input tensor whose dtype is bool.
|
||||
dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar.
|
||||
Boolean result with the same shape as `a` of the NOT operation on elements of `a`.
|
||||
This is a scalar if `a` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor or its dtype is not bool.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([True, False])
|
||||
>>> output = np.logical_not(a)
|
||||
>>> print(output)
|
||||
[False True]
|
||||
"""
|
||||
return _apply_tensor_op(F.logical_not, a, dtype=dtype)
|
||||
|
||||
|
||||
def logical_or(x1, x2, dtype=None):
|
||||
"""
|
||||
Computes the truth value of `x1` OR `x2` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input tensor.
|
||||
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
|
||||
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x1 = np.array([True, False])
|
||||
>>> x2 = np.array([False, True])
|
||||
>>> output = np.logical_or(x1, x2)
|
||||
>>> print(output)
|
||||
[ True True]
|
||||
"""
|
||||
return _apply_tensor_op(F.logical_or, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def logical_and(x1, x2, dtype=None):
|
||||
"""
|
||||
Computes the truth value of `x1` AND `x2` element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input tensor.
|
||||
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar.
|
||||
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
|
||||
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x1 = np.array([True, False])
|
||||
>>> x2 = np.array([False, False])
|
||||
>>> output = np.logical_and(x1, x2)
|
||||
>>> print(output)
|
||||
[False False]
|
||||
"""
|
||||
return _apply_tensor_op(F.logical_and, x1, x2, dtype=dtype)
|
||||
|
||||
|
||||
def logical_xor(x1, x2, dtype=None):
|
||||
"""
|
||||
Computes the truth value of `x1` XOR `x2`, element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`,
|
||||
and `extobj` are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input tensor.
|
||||
x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be
|
||||
broadcastable to a common shape (which becomes the shape of the output).
|
||||
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar.
|
||||
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
|
||||
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x1 = np.array([True, False])
|
||||
>>> x2 = np.array([False, False])
|
||||
>>> output = np.logical_xor(x1, x2)
|
||||
>>> print(output)
|
||||
[True False]
|
||||
"""
|
||||
_check_input_tensor(x1)
|
||||
_check_input_tensor(x2)
|
||||
y1 = F.logical_or(x1, x2)
|
||||
y2 = F.logical_or(F.logical_not(x1), F.logical_not(x2))
|
||||
return _apply_tensor_op(F.logical_and, y1, y2, dtype=dtype)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -13,14 +13,11 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""internal utility functions"""
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from ..common import Tensor
|
||||
from ..ops import functional as F
|
||||
from ..common import dtype as mstype
|
||||
|
||||
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error
|
||||
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert
|
||||
|
||||
|
||||
def _deep_list(array_like):
|
||||
|
@ -56,9 +53,8 @@ def _deep_tensor_to_nparray(array_like):
|
|||
|
||||
def _check_input_for_asarray(array_like):
|
||||
"""check whether array_like argument is a valid type for np.asarray conversion"""
|
||||
if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
|
||||
_raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
|
||||
"or numpy.ndarray, but got ", array_like)
|
||||
if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)):
|
||||
_raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)
|
||||
|
||||
|
||||
def _is_scalar(shape):
|
||||
|
@ -121,6 +117,20 @@ def _convert_64_to_32(tensor):
|
|||
return tensor
|
||||
|
||||
|
||||
def _to_tensor(*args):
|
||||
"""Returns each input as Tensor"""
|
||||
res = ()
|
||||
for arg in args:
|
||||
if isinstance(arg, (int, float, bool, list, tuple)):
|
||||
arg = _convert_64_to_32(_type_convert(Tensor, arg))
|
||||
elif not isinstance(arg, Tensor):
|
||||
_raise_type_error("Expect input to be array like.")
|
||||
res += (arg,)
|
||||
if len(res) == 1:
|
||||
return res[0]
|
||||
return res
|
||||
|
||||
|
||||
def _get_dtype_from_scalar(*input_numbers):
|
||||
"""
|
||||
Get the final dtype from series of input numbers, compared with F.typeof, we
|
||||
|
@ -139,3 +149,8 @@ def _get_dtype_from_scalar(*input_numbers):
|
|||
if int_flag:
|
||||
return mstype.int32
|
||||
return mstype.float32
|
||||
|
||||
|
||||
def _isnan(x):
|
||||
"""Computes isnan."""
|
||||
return F.not_equal(x, x)
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# ============================================================================
|
||||
"""internal graph-compatible utility functions"""
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import zip_longest
|
||||
from collections import deque
|
||||
|
||||
import mindspore.context as context
|
||||
from ..ops import functional as F
|
||||
|
@ -24,7 +25,7 @@ from ..common import Tensor
|
|||
from .._c_expression import Tensor as Tensor_
|
||||
from .._c_expression import typing
|
||||
|
||||
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map
|
||||
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -110,44 +111,19 @@ def _get_device():
|
|||
return context.get_context('device_target')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _reverse_index(idx, arr):
|
||||
"""
|
||||
Returns 1 if shape[idx:] is broadcastable to shape_out[idx:],
|
||||
2 situations if the function returns 1:
|
||||
- 1. Tensor's shape has 1 at the designated dimension.
|
||||
- 2. Tensor's dimension is less than the designated idx. (The Tensor shape
|
||||
has been reversed)
|
||||
For both cases, 2 tensors are broadcastable.
|
||||
otherwise returns the element at position of shape
|
||||
"""
|
||||
if len(arr) <= idx:
|
||||
return 1
|
||||
return arr[-1 - idx]
|
||||
|
||||
|
||||
@constexpr
|
||||
def _infer_out_shape(*shapes):
|
||||
"""
|
||||
Returns shape of output after broadcasting
|
||||
Raises ValueError if shape1 and shape2 cannot be broadcast
|
||||
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
||||
"""
|
||||
shapes_unbroadcastable = False
|
||||
ndim_max = max(map(len, shapes))
|
||||
shape_out = [0]*ndim_max
|
||||
i = 0
|
||||
for i in range(ndim_max):
|
||||
shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes))
|
||||
for shape in shapes:
|
||||
if _reverse_index(i, shape) != shape_out[-1 - i]:
|
||||
if _reverse_index(i, shape) != 1:
|
||||
shapes_unbroadcastable = True
|
||||
break
|
||||
if shapes_unbroadcastable:
|
||||
break
|
||||
if not shapes_unbroadcastable:
|
||||
return tuple(shape_out)
|
||||
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
|
||||
shape_out = deque()
|
||||
reversed_shapes = map(reversed, shapes)
|
||||
for items in zip_longest(*reversed_shapes, fillvalue=1):
|
||||
max_size = 0 if 0 in items else max(items)
|
||||
if any(item not in (1, max_size) for item in items):
|
||||
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
|
||||
shape_out.appendleft(max_size)
|
||||
return tuple(shape_out)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -228,6 +204,21 @@ def _raise_value_error(info, param=None):
|
|||
raise ValueError(info + f"{param}")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _raise_runtime_error(info, param=None):
|
||||
"""
|
||||
Raise RuntimeError in both graph/pynative mode
|
||||
|
||||
Args:
|
||||
info(str): info string to display
|
||||
param(python obj): any object that can be recognized by graph mode. If is
|
||||
not None, then param's value information will be extracted and displayed.
|
||||
Default is None.
|
||||
"""
|
||||
if param is None:
|
||||
raise RuntimeError(info)
|
||||
raise RuntimeError(info + f"{param}")
|
||||
|
||||
@constexpr
|
||||
def _empty(dtype, shape):
|
||||
"""Returns an uninitialized array with dtype and shape."""
|
||||
|
@ -242,6 +233,9 @@ def _promote(dtype1, dtype2):
|
|||
return promotion_rule[dtype1, dtype2]
|
||||
return promotion_rule[dtype2, dtype1]
|
||||
|
||||
@constexpr
|
||||
def _promote_for_trigonometric(dtype):
|
||||
return rule_for_trigonometric[dtype]
|
||||
|
||||
@constexpr
|
||||
def _max(*args):
|
||||
|
@ -315,7 +309,7 @@ def _canonicalize_axis(axis, ndim):
|
|||
|
||||
axis = tuple([canonicalizer(axis) for axis in axis])
|
||||
if all(axis.count(el) <= 1 for el in axis):
|
||||
return axis if len(axis) > 1 else axis[0]
|
||||
return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
|
||||
raise ValueError(f"duplicate axes in {axis}.")
|
||||
|
||||
|
||||
|
@ -426,13 +420,37 @@ def _tuple_getitem(tup, idx, startswith=True):
|
|||
|
||||
|
||||
@constexpr
|
||||
def _iota(dtype, num):
|
||||
def _tuple_setitem(tup, idx, value):
|
||||
"""
|
||||
Returns a tuple with specified `idx` set to `value`.
|
||||
"""
|
||||
tup = list(tup)
|
||||
tup[idx] = value
|
||||
return tuple(tup)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _iota(dtype, num, increasing=True):
|
||||
"""Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
|
||||
# TODO: Change to P.Linspace when the kernel is implemented on CPU.
|
||||
return Tensor(list(range(int(num))), dtype)
|
||||
if increasing:
|
||||
return Tensor(list(range(int(num))), dtype)
|
||||
return Tensor(list(range(int(num)-1, -1, -1)), dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _ceil(number):
|
||||
"""Ceils the number in graph mode."""
|
||||
return math.ceil(number)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _seq_prod(seq1, seq2):
|
||||
"""Returns the element-wise product of seq1 and seq2."""
|
||||
return tuple(map(lambda x, y: x*y, seq1, seq2))
|
||||
|
||||
|
||||
@constexpr
|
||||
def _make_tensor(val, dtype):
|
||||
""" Returns the tensor with value `val` and dtype `dtype`."""
|
||||
return Tensor(val, dtype)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Implementation for internal polymorphism `not equal` operations."""
|
||||
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ...composite import base
|
||||
from ... import functional as F
|
||||
|
||||
|
@ -41,6 +42,21 @@ def _not_equal_scalar(x, y):
|
|||
return not F.scalar_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("mstype", "mstype")
|
||||
def _not_equal_mstype(x, y):
|
||||
"""
|
||||
Determine if two mindspore types are not equal.
|
||||
|
||||
Args:
|
||||
x (mstype): first input mindspore type.
|
||||
y (mstype): second input mindspore type.
|
||||
|
||||
Returns:
|
||||
bool, if x != y return true, x == y return false.
|
||||
"""
|
||||
return not const_utils.mstype_eq(x, y)
|
||||
|
||||
|
||||
@not_equal.register("String", "String")
|
||||
def _not_equal_string(x, y):
|
||||
"""
|
||||
|
|
|
@ -77,6 +77,7 @@ floormod = tensor_mod
|
|||
tensor_exp = P.Exp()
|
||||
exp = tensor_exp
|
||||
tensor_expm1 = P.Expm1()
|
||||
tensor_slice = P.Slice()
|
||||
strided_slice = P.StridedSlice()
|
||||
same_type_shape = P.SameTypeShape()
|
||||
check_bprop = P.CheckBprop()
|
||||
|
@ -94,6 +95,22 @@ tensor_slice = P.Slice()
|
|||
maximum = P.Maximum()
|
||||
minimum = P.Minimum()
|
||||
floor = P.Floor()
|
||||
logical_not = P.LogicalNot()
|
||||
logical_or = P.LogicalOr()
|
||||
logical_and = P.LogicalAnd()
|
||||
sin = P.Sin()
|
||||
cos = P.Cos()
|
||||
tan = P.Tan()
|
||||
asin = P.Asin()
|
||||
acos = P.ACos()
|
||||
atan = P.Atan()
|
||||
sinh = P.Sinh()
|
||||
cosh = P.Cosh()
|
||||
tanh = P.Tanh()
|
||||
asinh = P.Asinh()
|
||||
acosh = P.Acosh()
|
||||
atanh = P.Atanh()
|
||||
atan2 = P.Atan2()
|
||||
|
||||
scalar_to_array = P.ScalarToArray()
|
||||
scalar_to_tensor = P.ScalarToTensor()
|
||||
|
|
|
@ -2560,7 +2560,7 @@ class Acosh(PrimitiveWithInfer):
|
|||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> acosh = ops.Acosh()
|
||||
|
@ -2637,7 +2637,7 @@ class Asinh(PrimitiveWithInfer):
|
|||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> asinh = ops.Asinh()
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as onp
|
|||
import mindspore.numpy as mnp
|
||||
|
||||
from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
|
||||
match_all_arrays
|
||||
match_all_arrays, run_multi_test, to_tensor
|
||||
|
||||
|
||||
class Cases():
|
||||
|
@ -40,8 +40,8 @@ class Cases():
|
|||
|
||||
self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,),
|
||||
[(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member
|
||||
(100, 100)).astype(onp.float32),
|
||||
onp.random.random((100, 100)).astype(onp.bool)]
|
||||
(100, 100)).astype(onp.float32).tolist(),
|
||||
onp.random.random((100, 100)).astype(onp.bool).tolist()]
|
||||
|
||||
self.arrs = [
|
||||
rand_int(2),
|
||||
|
@ -138,8 +138,8 @@ def test_asarray():
|
|||
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
# Additional tests for nested tensor mixture
|
||||
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
|
@ -168,11 +168,11 @@ def test_array():
|
|||
assert arr4 is arr5
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
actual = onp.array(onp_input)
|
||||
expected = mnp.array(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
|
@ -202,11 +202,11 @@ def test_asfarray():
|
|||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
actual = onp.asfarray(onp_input)
|
||||
expected = mnp.asfarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
|
@ -373,14 +373,14 @@ def test_linspace():
|
|||
stop = onp.random.random([1, 5, 1]).astype("float32")
|
||||
actual = onp.linspace(start, stop, num=20, retstep=True,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
|
||||
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
|
||||
retstep=True, endpoint=False)
|
||||
match_array(actual[0], expected[0].asnumpy(), error=6)
|
||||
match_array(actual[1], expected[1].asnumpy(), error=6)
|
||||
|
||||
actual = onp.linspace(start, stop, num=20, retstep=True,
|
||||
endpoint=False, dtype=onp.int16)
|
||||
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
|
||||
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
|
||||
retstep=True, endpoint=False, dtype=mnp.int16)
|
||||
match_array(actual[0], expected[0].asnumpy(), error=6)
|
||||
match_array(actual[1], expected[1].asnumpy(), error=6)
|
||||
|
@ -388,7 +388,7 @@ def test_linspace():
|
|||
for axis in range(2):
|
||||
actual = onp.linspace(start, stop, num=20, retstep=False,
|
||||
endpoint=False, dtype=onp.float32, axis=axis)
|
||||
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
|
||||
expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20,
|
||||
retstep=False, endpoint=False, dtype=mnp.float32, axis=axis)
|
||||
match_array(actual, expected.asnumpy(), error=6)
|
||||
|
||||
|
@ -510,18 +510,18 @@ def test_full_like():
|
|||
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
|
||||
shape = onp.zeros_like(onp_proto).shape
|
||||
fill_value = rand_int()
|
||||
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
|
||||
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
for i in range(len(shape) - 1, 0, -1):
|
||||
fill_value = rand_int(*shape[i:])
|
||||
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
|
||||
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
fill_value = rand_int(1, *shape[i + 1:])
|
||||
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
|
||||
actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
|
@ -549,6 +549,21 @@ def test_tri_triu_tril():
|
|||
match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_nancumsum():
|
||||
x = rand_int(2, 3, 4, 5)
|
||||
x[0][2][1][3] = onp.nan
|
||||
x[1][0][2][4] = onp.nan
|
||||
x[1][1][1][1] = onp.nan
|
||||
match_res(mnp.nancumsum, onp.nancumsum, x)
|
||||
match_res(mnp.nancumsum, onp.nancumsum, x, axis=-2)
|
||||
match_res(mnp.nancumsum, onp.nancumsum, x, axis=0)
|
||||
match_res(mnp.nancumsum, onp.nancumsum, x, axis=3)
|
||||
|
||||
|
||||
def mnp_diagonal(arr):
|
||||
return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0)
|
||||
|
||||
|
@ -653,7 +668,7 @@ def test_meshgrid():
|
|||
(2, 3), 9), onp.full((4, 5, 6), 7))
|
||||
for i in range(len(xi)):
|
||||
arrs = xi[i:]
|
||||
mnp_arrs = map(mnp.asarray, arrs)
|
||||
mnp_arrs = map(to_tensor, arrs)
|
||||
for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)):
|
||||
match_all_arrays(mnp_res, onp_res)
|
||||
|
||||
|
@ -750,6 +765,68 @@ def test_ix_():
|
|||
match_res(mnp_ix_, onp_ix_, *test_arrs)
|
||||
|
||||
|
||||
def mnp_indices():
|
||||
a = mnp.indices((2, 3))
|
||||
b = mnp.indices((2, 3, 4), sparse=True)
|
||||
return a, b
|
||||
|
||||
|
||||
def onp_indices():
|
||||
a = onp.indices((2, 3))
|
||||
b = onp.indices((2, 3, 4), sparse=True)
|
||||
return a, b
|
||||
|
||||
|
||||
def test_indices():
|
||||
run_multi_test(mnp_indices, onp_indices, ())
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_geomspace():
|
||||
start = onp.arange(1, 7).reshape(2, 3)
|
||||
end = [1000, 2000, 3000]
|
||||
match_array(mnp.geomspace(1, 256, num=9).asnumpy(),
|
||||
onp.geomspace(1, 256, num=9), error=1)
|
||||
match_array(mnp.geomspace(1, 256, num=8, endpoint=False).asnumpy(),
|
||||
onp.geomspace(1, 256, num=8, endpoint=False), error=1)
|
||||
match_array(mnp.geomspace(to_tensor(start), end, num=4).asnumpy(),
|
||||
onp.geomspace(start, end, num=4), error=1)
|
||||
match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False).asnumpy(),
|
||||
onp.geomspace(start, end, num=4, endpoint=False), error=1)
|
||||
match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=-1).asnumpy(),
|
||||
onp.geomspace(start, end, num=4, axis=-1), error=1)
|
||||
match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False, axis=-1).asnumpy(),
|
||||
onp.geomspace(start, end, num=4, endpoint=False, axis=-1), error=1)
|
||||
|
||||
start = onp.arange(1, 1 + 2*3*4*5).reshape(2, 3, 4, 5)
|
||||
end = [1000, 2000, 3000, 4000, 5000]
|
||||
for i in range(-5, 5):
|
||||
match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=i).asnumpy(),
|
||||
onp.geomspace(start, end, num=4, axis=i), error=1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vander():
|
||||
arrs = [rand_int(i + 3) for i in range(3)]
|
||||
for i in range(3):
|
||||
mnp_vander = mnp.vander(to_tensor(arrs[i]))
|
||||
onp_vander = onp.vander(arrs[i])
|
||||
match_all_arrays(mnp_vander, onp_vander)
|
||||
mnp_vander = mnp.vander(to_tensor(arrs[i]), N=2, increasing=True)
|
||||
onp_vander = onp.vander(arrs[i], N=2, increasing=True)
|
||||
match_all_arrays(mnp_vander, onp_vander)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
|
@ -23,7 +23,7 @@ import mindspore.numpy as mnp
|
|||
from mindspore.nn import Cell
|
||||
|
||||
from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
|
||||
rand_bool, match_res, run_multi_test
|
||||
rand_bool, match_res, run_multi_test, to_tensor
|
||||
|
||||
|
||||
class Cases():
|
||||
|
@ -139,7 +139,7 @@ def onp_transpose(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_transpose():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_transposed = onp_transpose(onp_array)
|
||||
m_transposed = mnp_transpose(mnp_array)
|
||||
check_all_results(o_transposed, m_transposed)
|
||||
|
@ -170,7 +170,7 @@ def onp_expand_dims(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_expand_dims():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_expanded = onp_expand_dims(onp_array)
|
||||
m_expanded = mnp_expand_dims(mnp_array)
|
||||
check_all_results(o_expanded, m_expanded)
|
||||
|
@ -205,13 +205,13 @@ def onp_squeeze(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_squeeze():
|
||||
onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_squeezed = onp_squeeze(onp_array)
|
||||
m_squeezed = mnp_squeeze(mnp_array)
|
||||
check_all_results(o_squeezed, m_squeezed)
|
||||
|
||||
onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_squeezed = onp_squeeze(onp_array)
|
||||
m_squeezed = mnp_squeeze(mnp_array)
|
||||
check_all_results(o_squeezed, m_squeezed)
|
||||
|
@ -246,7 +246,7 @@ def onp_rollaxis(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_rollaxis():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_rolled = onp_rollaxis(onp_array)
|
||||
m_rolled = mnp_rollaxis(mnp_array)
|
||||
check_all_results(o_rolled, m_rolled)
|
||||
|
@ -281,7 +281,7 @@ def onp_swapaxes(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_swapaxes():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_swaped = onp_swapaxes(onp_array)
|
||||
m_swaped = mnp_swapaxes(mnp_array)
|
||||
check_all_results(o_swaped, m_swaped)
|
||||
|
@ -324,7 +324,7 @@ def onp_reshape(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_reshape():
|
||||
onp_array = onp.random.random((2, 3, 4)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_reshaped = onp_reshape(onp_array)
|
||||
m_reshaped = mnp_reshape(mnp_array)
|
||||
check_all_results(o_reshaped, m_reshaped)
|
||||
|
@ -349,7 +349,7 @@ def onp_ravel(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_ravel():
|
||||
onp_array = onp.random.random((2, 3, 4)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_ravel = onp_ravel(onp_array)
|
||||
m_ravel = mnp_ravel(mnp_array).asnumpy()
|
||||
match_array(o_ravel, m_ravel)
|
||||
|
@ -380,7 +380,7 @@ def onp_concatenate(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_concatenate():
|
||||
onp_array = onp.random.random((5, 4, 3, 2)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_concatenate = onp_concatenate(onp_array)
|
||||
m_concatenate = mnp_concatenate(mnp_array)
|
||||
check_all_results(o_concatenate, m_concatenate)
|
||||
|
@ -407,8 +407,8 @@ def onp_append(arr1, arr2):
|
|||
def test_append():
|
||||
onp_array = onp.random.random((4, 3, 2)).astype('float32')
|
||||
onp_value = onp.random.random((4, 3, 2)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_value = mnp.asarray(onp_value)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
mnp_value = to_tensor(onp_value)
|
||||
onp_res = onp_append(onp_array, onp_value)
|
||||
mnp_res = mnp_append(mnp_array, mnp_value)
|
||||
check_all_results(onp_res, mnp_res)
|
||||
|
@ -424,13 +424,13 @@ def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5):
|
|||
onp_array1 = onp.random.randint(
|
||||
low=low, high=high, size=shape).astype(onp.float32)
|
||||
onp_array_lst.append(onp_array1)
|
||||
mnp_array_lst.append(mnp.asarray(onp_array1))
|
||||
mnp_array_lst.append(to_tensor(onp_array1))
|
||||
if axis is not None and axis < ndim:
|
||||
new_shape[axis] += onp.random.randint(2)
|
||||
onp_array2 = onp.random.randint(
|
||||
low=low, high=high, size=new_shape).astype(onp.float32)
|
||||
onp_array_lst.append(onp_array2)
|
||||
mnp_array_lst.append(mnp.asarray(onp_array2))
|
||||
mnp_array_lst.append(to_tensor(onp_array2))
|
||||
return onp_array_lst, mnp_array_lst
|
||||
|
||||
# Test np.xstack
|
||||
|
@ -656,7 +656,7 @@ def onp_ndarray_flatten(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_ndarray_flatten():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_flatten = onp_ndarray_flatten(onp_array)
|
||||
m_flatten = mnp_ndarray_flatten(mnp_array)
|
||||
check_all_results(o_flatten, m_flatten)
|
||||
|
@ -687,7 +687,7 @@ def onp_ndarray_transpose(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_ndarray_transpose():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_transposed = onp_ndarray_transpose(onp_array)
|
||||
m_transposed = mnp_ndarray_transpose(mnp_array)
|
||||
check_all_results(o_transposed, m_transposed)
|
||||
|
@ -716,7 +716,7 @@ def onp_ndarray_astype(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_ndarray_astype():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_astype = onp_ndarray_astype(onp_array)
|
||||
m_astype = mnp_ndarray_astype(mnp_array)
|
||||
for arr1, arr2 in zip(o_astype, m_astype):
|
||||
|
@ -747,7 +747,7 @@ def mnp_concatenate_type_promotion(mnp_array1, mnp_array2, mnp_array3, mnp_array
|
|||
@pytest.mark.env_onecard
|
||||
def test_concatenate_type_promotion():
|
||||
onp_array = onp.random.random((5, 1)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
mnp_array = to_tensor(onp_array)
|
||||
onp_array1 = onp_array.astype(onp.float16)
|
||||
onp_array2 = onp_array.astype(onp.bool_)
|
||||
onp_array3 = onp_array.astype(onp.float32)
|
||||
|
@ -1049,7 +1049,7 @@ def test_split():
|
|||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
o_split = onp_split(onp_arr)
|
||||
m_split = mnp_split(mnp_arr)
|
||||
|
@ -1058,6 +1058,36 @@ def test_split():
|
|||
match_array(expect, actual.asnumpy())
|
||||
|
||||
|
||||
def mnp_array_split(input_tensor):
|
||||
a = mnp.array_split(input_tensor, indices_or_sections=4, axis=2)
|
||||
b = mnp.array_split(input_tensor, indices_or_sections=3, axis=1)
|
||||
c = mnp.array_split(input_tensor, indices_or_sections=6)
|
||||
return a, b, c
|
||||
|
||||
|
||||
def onp_array_split(input_array):
|
||||
a = onp.array_split(input_array, indices_or_sections=4, axis=2)
|
||||
b = onp.array_split(input_array, indices_or_sections=3, axis=1)
|
||||
c = onp.array_split(input_array, indices_or_sections=6)
|
||||
return a, b, c
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_array_split():
|
||||
onp_arr = onp.random.randint(1, 5, size=(9, 7, 13)).astype('float32')
|
||||
mnp_arr = to_tensor(onp_arr)
|
||||
o_split = onp_split(onp_arr)
|
||||
m_split = mnp_split(mnp_arr)
|
||||
for expect_lst, actual_lst in zip(o_split, m_split):
|
||||
for expect, actual in zip(expect_lst, actual_lst):
|
||||
match_array(expect, actual.asnumpy())
|
||||
|
||||
|
||||
def mnp_vsplit(input_tensor):
|
||||
a = mnp.vsplit(input_tensor, indices_or_sections=3)
|
||||
b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
|
||||
|
@ -1082,7 +1112,7 @@ def test_vsplit():
|
|||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
o_vsplit = onp_vsplit(onp_arr)
|
||||
m_vsplit = mnp_vsplit(mnp_arr)
|
||||
|
@ -1115,7 +1145,7 @@ def test_hsplit():
|
|||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
o_hsplit = onp_hsplit(onp_arr)
|
||||
m_hsplit = mnp_hsplit(mnp_arr)
|
||||
|
@ -1148,7 +1178,7 @@ def test_dsplit():
|
|||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
mnp_arrs = [to_tensor(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
o_dsplit = onp_dsplit(onp_arr)
|
||||
m_dsplit = mnp_dsplit(mnp_arr)
|
||||
|
@ -1248,6 +1278,29 @@ def test_repeat():
|
|||
run_multi_test(mnp_repeat, onp_repeat, (x,))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_select():
|
||||
choicelist = rand_int(2, 3, 4, 5)
|
||||
condlist = choicelist > 2
|
||||
match_res(mnp.select, onp.select, condlist, choicelist)
|
||||
match_res(mnp.select, onp.select, condlist, choicelist, default=10)
|
||||
|
||||
condlist = rand_bool(5, 4, 1, 3)
|
||||
choicelist = rand_int(5, 3)
|
||||
match_res(mnp.select, onp.select, condlist, choicelist)
|
||||
match_res(mnp.select, onp.select, condlist, choicelist, default=10)
|
||||
|
||||
condlist = rand_bool(3, 1, 7)
|
||||
choicelist = rand_int(3, 5, 2, 1)
|
||||
match_res(mnp.select, onp.select, condlist, choicelist)
|
||||
match_res(mnp.select, onp.select, condlist, choicelist, default=10)
|
||||
|
||||
|
||||
class ReshapeExpandSqueeze(Cell):
|
||||
def __init__(self):
|
||||
super(ReshapeExpandSqueeze, self).__init__()
|
||||
|
@ -1333,7 +1386,7 @@ def test_swapaxes_exception():
|
|||
@pytest.mark.env_onecard
|
||||
def test_tensor_flatten():
|
||||
lst = [[1.0, 2.0], [3.0, 4.0]]
|
||||
tensor_list = mnp.asarray(lst)
|
||||
tensor_list = to_tensor(lst)
|
||||
assert tensor_list.flatten().asnumpy().tolist() == [1.0, 2.0, 3.0, 4.0]
|
||||
assert tensor_list.flatten(order='F').asnumpy().tolist() == [
|
||||
1.0, 3.0, 2.0, 4.0]
|
||||
|
@ -1347,7 +1400,7 @@ def test_tensor_flatten():
|
|||
@pytest.mark.env_onecard
|
||||
def test_tensor_reshape():
|
||||
lst = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
|
||||
tensor_list = mnp.asarray(lst)
|
||||
tensor_list = to_tensor(lst)
|
||||
with pytest.raises(TypeError):
|
||||
tensor_list = tensor_list.reshape({0, 1, 2})
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -1364,7 +1417,7 @@ def test_tensor_reshape():
|
|||
@pytest.mark.env_onecard
|
||||
def test_tensor_squeeze():
|
||||
lst = [[[1.0], [2.0], [3.0]]]
|
||||
tensor_list = mnp.asarray(lst)
|
||||
tensor_list = to_tensor(lst)
|
||||
with pytest.raises(TypeError):
|
||||
tensor_list = tensor_list.squeeze(1.2)
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -1381,7 +1434,7 @@ def test_tensor_squeeze():
|
|||
@pytest.mark.env_onecard
|
||||
def test_tensor_ravel():
|
||||
lst = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
|
||||
tensor_list = mnp.asarray(lst)
|
||||
tensor_list = to_tensor(lst)
|
||||
assert tensor_list.ravel().shape == (8,)
|
||||
assert tensor_list.ravel().asnumpy().tolist() == [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
|
||||
|
@ -1395,9 +1448,47 @@ def test_tensor_ravel():
|
|||
@pytest.mark.env_onecard
|
||||
def test_tensor_swapaxes():
|
||||
lst = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
tensor_list = mnp.asarray(lst)
|
||||
tensor_list = to_tensor(lst)
|
||||
with pytest.raises(TypeError):
|
||||
tensor_list = tensor_list.swapaxes(0, (1,))
|
||||
with pytest.raises(ValueError):
|
||||
tensor_list = tensor_list.swapaxes(0, 3)
|
||||
assert tensor_list.swapaxes(0, 1).shape == (3, 2)
|
||||
|
||||
|
||||
def mnp_rot90(input_tensor):
|
||||
a = mnp.rot90(input_tensor)
|
||||
b = mnp.rot90(input_tensor, 2)
|
||||
c = mnp.rot90(input_tensor, 3)
|
||||
d = mnp.rot90(input_tensor, 4)
|
||||
e = mnp.rot90(input_tensor, 5, (0, -1))
|
||||
f = mnp.rot90(input_tensor, 1, (2, 0))
|
||||
g = mnp.rot90(input_tensor, -3, (-1, -2))
|
||||
h = mnp.rot90(input_tensor, 3, (2, 1))
|
||||
return a, b, c, d, e, f, g, h
|
||||
|
||||
|
||||
def onp_rot90(input_array):
|
||||
a = onp.rot90(input_array)
|
||||
b = onp.rot90(input_array, 2)
|
||||
c = onp.rot90(input_array, 3)
|
||||
d = onp.rot90(input_array, 4)
|
||||
e = onp.rot90(input_array, 5, (0, -1))
|
||||
f = onp.rot90(input_array, 1, (2, 0))
|
||||
g = onp.rot90(input_array, -3, (-1, -2))
|
||||
h = onp.rot90(input_array, 3, (2, 1))
|
||||
return a, b, c, d, e, f, g, h
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_rot90():
|
||||
onp_array = rand_int(3, 4, 5).astype('float32')
|
||||
mnp_array = to_tensor(onp_array)
|
||||
o_rot = onp_rot90(onp_array)
|
||||
m_rot = mnp_rot90(mnp_array)
|
||||
check_all_results(o_rot, m_rot)
|
||||
|
|
|
@ -19,7 +19,8 @@ import numpy as onp
|
|||
|
||||
import mindspore.numpy as mnp
|
||||
|
||||
from .utils import rand_int, run_binop_test, match_res
|
||||
from .utils import rand_int, rand_bool, run_binop_test, run_logical_test, match_res, \
|
||||
match_all_arrays, to_tensor
|
||||
|
||||
|
||||
class Cases():
|
||||
|
@ -55,6 +56,15 @@ class Cases():
|
|||
rand_int(8, 1, 6, 1)
|
||||
]
|
||||
|
||||
# Boolean arrays
|
||||
self.boolean_arrs = [
|
||||
rand_bool(),
|
||||
rand_bool(5),
|
||||
rand_bool(6, 1),
|
||||
rand_bool(7, 1, 5),
|
||||
rand_bool(8, 1, 6, 1)
|
||||
]
|
||||
|
||||
# array which contains infs and nans
|
||||
self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]])
|
||||
|
||||
|
@ -246,10 +256,147 @@ def test_isneginf():
|
|||
match_res(mnp_isneginf, onp_isneginf, test_case.infs)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_isscalar():
|
||||
assert mnp.isscalar(1) == onp.isscalar(1)
|
||||
assert mnp.isscalar(2.3) == onp.isscalar(2.3)
|
||||
assert mnp.isscalar([4.5]) == onp.isscalar([4.5])
|
||||
assert mnp.isscalar(False) == onp.isscalar(False)
|
||||
assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True))
|
||||
assert mnp.isscalar(to_tensor(True)) == onp.isscalar(onp.array(True))
|
||||
assert mnp.isscalar('numpy') == onp.isscalar('numpy')
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_isclose():
|
||||
a = [0, 1, 2, float('inf'), float('inf'), float('nan')]
|
||||
b = [0, 1, -2, float('-inf'), float('inf'), float('nan')]
|
||||
match_all_arrays(mnp.isclose(a, b), onp.isclose(a, b))
|
||||
match_all_arrays(mnp.isclose(a, b, equal_nan=True), onp.isclose(a, b, equal_nan=True))
|
||||
|
||||
a = rand_int(2, 3, 4, 5)
|
||||
diff = (onp.random.random((2, 3, 4, 5)).astype("float32") - 0.5) / 1000
|
||||
b = a + diff
|
||||
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3), onp.isclose(a, b, atol=1e-3))
|
||||
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3, rtol=1e-4),
|
||||
onp.isclose(a, b, atol=1e-3, rtol=1e-4))
|
||||
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-2, rtol=1e-6),
|
||||
onp.isclose(a, b, atol=1e-2, rtol=1e-6))
|
||||
|
||||
a = rand_int(2, 3, 4, 5)
|
||||
b = rand_int(4, 5)
|
||||
match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b)), onp.isclose(a, b))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_in1d():
|
||||
xi = [rand_int(), rand_int(1), rand_int(10)]
|
||||
yi = [rand_int(), rand_int(1), rand_int(10)]
|
||||
for x in xi:
|
||||
for y in yi:
|
||||
match_res(mnp.in1d, onp.in1d, x, y)
|
||||
match_res(mnp.in1d, onp.in1d, x, y, invert=True)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_isin():
|
||||
xi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)]
|
||||
yi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)]
|
||||
for x in xi:
|
||||
for y in yi:
|
||||
match_res(mnp.in1d, onp.in1d, x, y)
|
||||
match_res(mnp.in1d, onp.in1d, x, y, invert=True)
|
||||
|
||||
|
||||
def mnp_logical_or(x1, x2):
|
||||
return mnp.logical_or(x1, x2)
|
||||
|
||||
|
||||
def onp_logical_or(x1, x2):
|
||||
return onp.logical_or(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logical_or():
|
||||
run_logical_test(mnp_logical_or, onp_logical_or, test_case)
|
||||
|
||||
|
||||
def mnp_logical_xor(x1, x2):
|
||||
return mnp.logical_xor(x1, x2)
|
||||
|
||||
|
||||
def onp_logical_xor(x1, x2):
|
||||
return onp.logical_xor(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logical_xor():
|
||||
run_logical_test(mnp_logical_xor, onp_logical_xor, test_case)
|
||||
|
||||
|
||||
def mnp_logical_and(x1, x2):
|
||||
return mnp.logical_and(x1, x2)
|
||||
|
||||
|
||||
def onp_logical_and(x1, x2):
|
||||
return onp.logical_and(x1, x2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logical_and():
|
||||
run_logical_test(mnp_logical_and, onp_logical_and, test_case)
|
||||
|
||||
|
||||
def mnp_logical_not(x):
|
||||
return mnp.logical_not(x)
|
||||
|
||||
|
||||
def onp_logical_not(x):
|
||||
return onp.logical_not(x)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logical_not():
|
||||
for arr in test_case.boolean_arrs:
|
||||
expected = onp_logical_not(arr)
|
||||
actual = mnp_logical_not(to_tensor(arr))
|
||||
onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist())
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -15,6 +15,7 @@
|
|||
"""utility functions for mindspore.numpy st tests"""
|
||||
import functools
|
||||
import numpy as onp
|
||||
from mindspore import Tensor
|
||||
import mindspore.numpy as mnp
|
||||
|
||||
|
||||
|
@ -90,7 +91,9 @@ def rand_bool(*shape):
|
|||
|
||||
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
|
||||
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
|
||||
mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
|
||||
dtype = kwargs.get('dtype', mnp.float32)
|
||||
kwargs.pop('dtype', None)
|
||||
mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
|
||||
error = kwargs.get('error', 0)
|
||||
kwargs.pop('error', None)
|
||||
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
|
||||
|
@ -151,15 +154,32 @@ def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
|
|||
|
||||
|
||||
def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
|
||||
mnp_arrs = map(mnp.asarray, arrs)
|
||||
mnp_arrs = map(Tensor, arrs)
|
||||
for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
|
||||
match_array(actual.asnumpy(), expected, error)
|
||||
match_all_arrays(actual, expected, error)
|
||||
|
||||
|
||||
def run_single_test(mnp_fn, onp_fn, arr, error=0):
|
||||
mnp_arr = mnp.asarray(arr)
|
||||
mnp_arr = Tensor(arr)
|
||||
for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
|
||||
if isinstance(expected, tuple):
|
||||
for actual_arr, expected_arr in zip(actual, expected):
|
||||
match_array(actual_arr.asnumpy(), expected_arr, error)
|
||||
match_array(actual.asnumpy(), expected, error)
|
||||
|
||||
|
||||
def run_logical_test(mnp_fn, onp_fn, test_case):
|
||||
for x1 in test_case.boolean_arrs:
|
||||
for x2 in test_case.boolean_arrs:
|
||||
match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)
|
||||
|
||||
def to_tensor(obj, dtype=None):
|
||||
if dtype is None:
|
||||
res = Tensor(obj)
|
||||
if res.dtype == mnp.float64:
|
||||
res = res.astype(mnp.float32)
|
||||
if res.dtype == mnp.int64:
|
||||
res = res.astype(mnp.int32)
|
||||
else:
|
||||
res = Tensor(obj, dtype)
|
||||
return res
|
||||
|
|
Loading…
Reference in New Issue