forked from mindspore-Ecosystem/mindspore
parent
ddffb61c62
commit
38b49fb30e
|
@ -168,7 +168,7 @@ def asfarray_const(a, dtype=mstype.float32):
|
|||
a = _deep_tensor_to_nparray(a)
|
||||
a = onp.asarray(a)
|
||||
if a.dtype is onp.dtype('object'):
|
||||
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
|
||||
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
|
||||
a = Tensor.from_numpy(a)
|
||||
|
||||
return Tensor(a, dtype)
|
||||
|
@ -214,7 +214,7 @@ def asfarray(a, dtype=mstype.float32):
|
|||
if isinstance(a, Tensor):
|
||||
return a.astype(dtype)
|
||||
|
||||
return asfarray_const(a)
|
||||
return asfarray_const(a, dtype)
|
||||
|
||||
|
||||
def copy_(a):
|
||||
|
|
|
@ -30,7 +30,8 @@ from .utils_const import _check_axes_range, _check_start_normalize, \
|
|||
_check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \
|
||||
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
|
||||
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
|
||||
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem
|
||||
_tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem, \
|
||||
_raise_unimplemented_error
|
||||
|
||||
# According to official numpy reference, the dimension of a numpy array must be less
|
||||
# than 32
|
||||
|
|
|
@ -84,9 +84,6 @@ def less_equal(x1, x2, dtype=None):
|
|||
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -120,9 +117,6 @@ def less(x1, x2, dtype=None):
|
|||
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -155,9 +149,6 @@ def greater_equal(x1, x2, dtype=None):
|
|||
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -190,9 +181,6 @@ def greater(x1, x2, dtype=None):
|
|||
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -225,9 +213,6 @@ def equal(x1, x2, dtype=None):
|
|||
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -260,9 +245,6 @@ def isfinite(x, dtype=None):
|
|||
Tensor or scalar, true where `x` is not positive infinity, negative infinity,
|
||||
or NaN; false otherwise. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -296,9 +278,6 @@ def isnan(x, dtype=None):
|
|||
Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if
|
||||
`x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
|
@ -346,9 +325,6 @@ def isinf(x, dtype=None):
|
|||
Tensor or scalar, true where `x` is positive or negative infinity, false
|
||||
otherwise. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
|
@ -688,9 +664,6 @@ def logical_or(x1, x2, dtype=None):
|
|||
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
|
||||
scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -725,9 +698,6 @@ def logical_and(x1, x2, dtype=None):
|
|||
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
|
||||
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -762,9 +732,6 @@ def logical_xor(x1, x2, dtype=None):
|
|||
Boolean result of the logical AND operation applied to the elements of `x1` and `x2`;
|
||||
the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
|
|
@ -109,9 +109,6 @@ def count_nonzero(x, axis=None, keepdims=False):
|
|||
Tensor, indicating number of non-zero values in the `x` along a given axis.
|
||||
Otherwise, the total number of non-zero values in `x` is returned.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -217,9 +214,6 @@ def rad2deg(x, dtype=None):
|
|||
Tensor, the corresponding angle in degrees. This is a tensor scalar if `x`
|
||||
is a tensor scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -255,9 +249,6 @@ def add(x1, x2, dtype=None):
|
|||
Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar
|
||||
if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -297,9 +288,6 @@ def subtract(x1, x2, dtype=None):
|
|||
Tensor or scalar, the difference of `x1` and `x2`, element-wise. This is a
|
||||
scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -334,9 +322,6 @@ def multiply(x1, x2, dtype=None):
|
|||
Tensor or scalar, the product of `x1` and `x2`, element-wise. This is a scalar
|
||||
if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -380,9 +365,6 @@ def divide(x1, x2, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -422,9 +404,6 @@ def true_divide(x1, x2, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -462,9 +441,6 @@ def power(x1, x2, dtype=None):
|
|||
Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
|
||||
is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -507,9 +483,6 @@ def float_power(x1, x2, dtype=None):
|
|||
Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
|
||||
is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -538,9 +511,7 @@ def minimum(x1, x2, dtype=None):
|
|||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
Unlike numpy, when one of the elements is a NaN, the second element is
|
||||
always returned regardless of whether the second element is a NaN, instead
|
||||
of returning NaN.
|
||||
On Ascend, input arrays containing inf or NaN are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): first input tensor to be compared.
|
||||
|
@ -1166,9 +1137,6 @@ def square(x, dtype=None):
|
|||
Tensor or scalar, element-wise ``x*x``, of the same shape and dtype as `x`.
|
||||
This is a scalar if `x` is a scalar..
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1201,9 +1169,6 @@ def sqrt(x, dtype=None):
|
|||
square-root of each element in `x`. For negative elements, nan is returned.
|
||||
This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1242,9 +1207,6 @@ def reciprocal(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar, this is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1283,9 +1245,6 @@ def log(x, dtype=None):
|
|||
Tensor or scalar, the natural logarithm of `x`, element-wise. This is a
|
||||
scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1316,9 +1275,7 @@ def maximum(x1, x2, dtype=None):
|
|||
Note:
|
||||
Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
|
||||
not supported.
|
||||
Unlike numpy, when one of the elements is a NaN, the second element is
|
||||
always returned regardless of whether the second element is a NaN, instead
|
||||
of returning NaN.
|
||||
On Ascend, input arrays containing inf or NaN are not supported.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): Input array
|
||||
|
@ -1332,9 +1289,6 @@ def maximum(x1, x2, dtype=None):
|
|||
Tensor or scalar, the maximum of `x1` and `x2`, element-wise. This is a scalar
|
||||
if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1385,9 +1339,6 @@ def heaviside(x1, x2, dtype=None):
|
|||
Tensor or scalar, the output array, element-wise Heaviside step function
|
||||
of `x1`. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1562,9 +1513,6 @@ def hypot(x1, x2, dtype=None):
|
|||
Tensor or scalar, the hypotenuse of the triangle(s). This is a scalar if
|
||||
both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1614,9 +1562,6 @@ def floor(x, dtype=None):
|
|||
Tensor or scalar, the floor of each element in `x`. This is a scalar if `x`
|
||||
is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1648,9 +1593,6 @@ def floor_divide(x1, x2, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1709,9 +1651,6 @@ def remainder(x1, x2, dtype=None):
|
|||
Tensor or scalar, the element-wise remainder of the quotient
|
||||
``floor_divide(x1, x2)``. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1787,9 +1726,6 @@ def fmod(x1, x2, dtype=None):
|
|||
Tensor or scalar, the remainder of the division of `x1` by `x2`. This is a
|
||||
scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1822,9 +1758,6 @@ def trunc(x, dtype=None):
|
|||
Tensor or scalar, the truncated value of each element in `x`. This is a scalar if `x` is
|
||||
a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1859,9 +1792,6 @@ def exp(x, dtype=None):
|
|||
Tensor or scalar, element-wise exponential of `x`. This is a scalar if both
|
||||
`x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -1893,9 +1823,6 @@ def expm1(x, dtype=None):
|
|||
Tensor or scalar, element-wise exponential minus one, ``out = exp(x) - 1``.
|
||||
This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -2117,6 +2044,7 @@ def trapz(y, x=None, dx=1.0, axis=-1):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.arange(6).reshape(2, 3)
|
||||
>>> output = np.trapz(a, x=[-2, 1, 2], axis=1)
|
||||
>>> print(output)
|
||||
|
@ -2197,16 +2125,14 @@ def gcd(x1, x2, dtype=None):
|
|||
Tensor or scalar, the greatest common divisor of the absolute value of the inputs.
|
||||
This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> output = np.gcd(np.arange(6), np.array(20))
|
||||
>>> print(output)
|
||||
[20 1 2 1 4 5]
|
||||
[20 1 2 1 4 5]
|
||||
"""
|
||||
return _apply_tensor_op(_gcd, x1, x2, dtype=dtype)
|
||||
|
||||
|
@ -2229,16 +2155,14 @@ def lcm(x1, x2, dtype=None):
|
|||
Tensor or scalar, the lowest common multiple of the absolute value of the inputs.
|
||||
This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> output = np.lcm(np.arange(6), np.array(20))
|
||||
>>> print(output)
|
||||
[ 0 20 20 60 20 20]
|
||||
[ 0 20 20 60 20 20]
|
||||
"""
|
||||
def _lcm(x1, x2):
|
||||
"""Calculates lcm without applying keyword arguments"""
|
||||
|
@ -2290,7 +2214,7 @@ def convolve(a, v, mode='full'):
|
|||
>>> import mindspore.numpy as np
|
||||
>>> output = np.convolve([1., 2., 3., 4., 5.], [2., 3.], mode="valid")
|
||||
>>> print(output)
|
||||
[ 3. 6. 9. 12.]
|
||||
[ 3. 6. 9. 12.]
|
||||
"""
|
||||
if not isinstance(a, Tensor):
|
||||
a = asarray_const(a)
|
||||
|
@ -2406,6 +2330,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> output = np.cov([[2., 3., 4., 5.], [0., 2., 3., 4.], [7., 8., 9., 10.]])
|
||||
>>> print(output)
|
||||
[[1.6666666 2.1666667 1.6666666]
|
||||
|
@ -2509,6 +2434,10 @@ def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|||
if dtype is None:
|
||||
dtype = F.dtype(a)
|
||||
axes = _check_axis_valid(axis, ndim)
|
||||
if initial is not None:
|
||||
if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
|
||||
not isinstance(initial, (int, float, bool, Tensor))):
|
||||
_raise_type_error('initial should be scalar')
|
||||
|
||||
if _is_shape_empty(shape):
|
||||
if not axes:
|
||||
|
@ -2578,6 +2507,7 @@ def nansum(a, axis=None, dtype=None, keepdims=False):
|
|||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([[1, 1], [1, np.nan]])
|
||||
>>> output = np.nansum(a)
|
||||
>>> print(output)
|
||||
|
@ -2638,6 +2568,7 @@ def nanmean(a, axis=None, dtype=None, keepdims=False):
|
|||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([[1, np.nan], [3, 4]])
|
||||
>>> output = np.nanmean(a)
|
||||
>>> print(output)
|
||||
|
@ -2700,6 +2631,7 @@ def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False):
|
|||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([[1, np.nan], [3, 4]])
|
||||
>>> output = np.nanstd(a)
|
||||
>>> print(output)
|
||||
|
@ -2752,6 +2684,7 @@ def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False):
|
|||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([[1, np.nan], [3, 4]])
|
||||
>>> output = np.nanvar(a)
|
||||
>>> print(output)
|
||||
|
@ -2784,13 +2717,11 @@ def exp2(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar, element-wise 2 to the power `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x = np.array([2, 3]).astype(np.float32)
|
||||
>>> output = np.exp2(x)
|
||||
>>> print(output)
|
||||
|
@ -2817,6 +2748,7 @@ def kron(a, b):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> output = np.kron([1,10,100], [5,6,7])
|
||||
>>> print(output)
|
||||
[ 5 6 7 50 60 70 500 600 700]
|
||||
|
@ -2885,6 +2817,7 @@ def cross(a, b, axisa=- 1, axisb=- 1, axisc=- 1, axis=None):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x = np.array([[1,2,3], [4,5,6]])
|
||||
>>> y = np.array([[4,5,6], [1,2,3]])
|
||||
>>> output = np.cross(x, y)
|
||||
|
@ -2968,13 +2901,11 @@ def ceil(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar, the floor of each element in `x`. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
|
||||
>>> output = np.ceil(a)
|
||||
>>> print(output)
|
||||
|
@ -3086,6 +3017,7 @@ def cumsum(a, axis=None, dtype=None):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> output = np.cumsum(np.ones((3,3)), axis=0)
|
||||
>>> print(output)
|
||||
[[1. 1. 1.]
|
||||
|
@ -3141,6 +3073,7 @@ def nancumsum(a, axis=None, dtype=None):
|
|||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.array([[1, 2], [3, np.nan]])
|
||||
>>> output = np.nancumsum(a)
|
||||
>>> print(output)
|
||||
|
@ -3212,9 +3145,6 @@ def log1p(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3251,9 +3181,6 @@ def logaddexp(x1, x2, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3286,9 +3213,6 @@ def log2(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3329,9 +3253,6 @@ def logaddexp2(x1, x2, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3364,9 +3285,6 @@ def log10(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3407,9 +3325,6 @@ def sin(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3440,9 +3355,6 @@ def cos(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3575,9 +3487,6 @@ def arctan(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3607,9 +3516,6 @@ def sinh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
|
@ -3639,9 +3545,6 @@ def cosh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
|
@ -3671,9 +3574,6 @@ def tanh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3703,9 +3603,6 @@ def arcsinh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3735,9 +3632,6 @@ def arccosh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
|
@ -3767,9 +3661,6 @@ def arctanh(x, dtype=None):
|
|||
Returns:
|
||||
Tensor or scalar. This is a scalar if `x` is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
|
@ -3801,9 +3692,6 @@ def arctan2(x1, x2, dtype=None):
|
|||
Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar
|
||||
if both `x1` and `x2` are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
|
|
|
@ -472,6 +472,7 @@ def _make_tensor(val, dtype):
|
|||
return Tensor(val, dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _tuple_slice(tup, start, end):
|
||||
"""get sliced tuple from start and end."""
|
||||
return tup[start:end]
|
||||
|
|
|
@ -591,9 +591,9 @@ def matmul(x1, x2, dtype=None):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
|
||||
>>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
|
||||
>>> output = np.matmul(x1, x2)
|
||||
>>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
|
||||
>>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
|
||||
>>> output = ops.matmul(x1, x2)
|
||||
>>> print(output)
|
||||
[[[ 70. 76. 82. 88. 94.]
|
||||
[ 190. 212. 234. 256. 278.]
|
||||
|
|
|
@ -26,7 +26,7 @@ from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
|
|||
class Cases():
|
||||
def __init__(self):
|
||||
self.all_shapes = [
|
||||
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
|
||||
1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3]
|
||||
]
|
||||
self.onp_dtypes = [onp.int32, 'int32', int,
|
||||
onp.float32, 'float32', float,
|
||||
|
@ -94,18 +94,16 @@ class Cases():
|
|||
|
||||
self.mnp_prototypes = [
|
||||
mnp.ones((2, 3, 4)),
|
||||
mnp.ones((0, 3, 0, 2, 5)),
|
||||
mnp.ones((2, 7, 0)),
|
||||
mnp.ones(()),
|
||||
mnp.ones((1, 3, 1, 2, 5)),
|
||||
mnp.ones((2, 7, 1)),
|
||||
[mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])),
|
||||
]
|
||||
|
||||
self.onp_prototypes = [
|
||||
onp.ones((2, 3, 4)),
|
||||
onp.ones((0, 3, 0, 2, 5)),
|
||||
onp.ones((2, 7, 0)),
|
||||
onp.ones(()),
|
||||
onp.ones((1, 3, 1, 2, 5)),
|
||||
onp.ones((2, 7, 1)),
|
||||
[onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])),
|
||||
]
|
||||
|
@ -257,10 +255,6 @@ def test_full():
|
|||
expected = mnp.full((2, 2), [1, 2]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 0), onp.inf)
|
||||
expected = mnp.full((2, 0), mnp.inf).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 3), True)
|
||||
expected = mnp.full((2, 3), True).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
@ -579,29 +573,19 @@ def onp_diagonal(arr):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_diagonal():
|
||||
arr = rand_int(0, 0)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=1)
|
||||
|
||||
arr = rand_int(3, 5)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
for i in [-1, 0, 2]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=1)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=0)
|
||||
|
||||
arr = rand_int(7, 4, 9)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
for i in [-1, 0, 2]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-1)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-2, axis2=2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr,
|
||||
offset=i, axis1=-1, axis2=-2)
|
||||
|
||||
arr = rand_int(2, 5, 8, 1)
|
||||
match_res(mnp_diagonal, onp_diagonal, arr)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-3, axis2=2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=3)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=2, axis2=-1)
|
||||
|
||||
|
||||
def mnp_trace(arr):
|
||||
return mnp.trace(arr, offset=4, axis1=1, axis2=2)
|
||||
|
@ -618,27 +602,18 @@ def onp_trace(arr):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_trace():
|
||||
arr = rand_int(0, 0)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=1)
|
||||
|
||||
arr = rand_int(3, 5)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
for i in [-1, 0]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=1)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=0)
|
||||
|
||||
arr = rand_int(7, 4, 9)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
for i in [-1, 0, 2]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-1)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-2, axis2=2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-1, axis2=-2)
|
||||
|
||||
arr = rand_int(2, 5, 8, 1)
|
||||
match_res(mnp_trace, onp_trace, arr)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-3, axis2=2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=3)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1)
|
||||
|
||||
|
||||
def mnp_meshgrid(*xi):
|
||||
|
@ -712,7 +687,7 @@ def test_ogrid():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_diagflat():
|
||||
arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)]
|
||||
arrs = [rand_int(2, 3)]
|
||||
for arr in arrs:
|
||||
for i in [-2, 0, 7]:
|
||||
match_res(mnp.diagflat, onp.diagflat, arr, k=i)
|
||||
|
@ -725,8 +700,7 @@ def test_diagflat():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_diag():
|
||||
arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5),
|
||||
rand_int(3, 8), rand_int(9, 6)]
|
||||
arrs = [rand_int(7), rand_int(5, 5), rand_int(3, 8), rand_int(9, 6)]
|
||||
for arr in arrs:
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.diag, onp.diag, arr, k=i)
|
||||
|
|
|
@ -29,7 +29,7 @@ from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
|
|||
class Cases():
|
||||
def __init__(self):
|
||||
self.all_shapes = [
|
||||
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
|
||||
1, 2, (1,), (2,), (1, 2, 3), [1], [2], [1, 2, 3]
|
||||
]
|
||||
self.onp_dtypes = [onp.int32, 'int32', int,
|
||||
onp.float32, 'float32', float,
|
||||
|
@ -97,18 +97,12 @@ class Cases():
|
|||
|
||||
self.mnp_prototypes = [
|
||||
mnp.ones((2, 3, 4)),
|
||||
mnp.ones((0, 3, 0, 2, 5)),
|
||||
onp.ones((2, 7, 0)),
|
||||
onp.ones(()),
|
||||
[mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])),
|
||||
]
|
||||
|
||||
self.onp_prototypes = [
|
||||
onp.ones((2, 3, 4)),
|
||||
onp.ones((0, 3, 0, 2, 5)),
|
||||
onp.ones((2, 7, 0)),
|
||||
onp.ones(()),
|
||||
[onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])),
|
||||
]
|
||||
|
@ -794,11 +788,6 @@ def test_stack():
|
|||
for i in range(-4, 4):
|
||||
match_res(mnp.stack, onp.stack, arr, axis=i)
|
||||
|
||||
arr = rand_int(7, 4, 0, 3)
|
||||
match_res(mnp.stack, onp.stack, arr)
|
||||
for i in range(-4, 4):
|
||||
match_res(mnp.stack, onp.stack, arr, axis=i)
|
||||
|
||||
arrs = [rand_int(3, 4, 5) for i in range(10)]
|
||||
match_res(mnp.stack, onp.stack, arrs)
|
||||
match_res(mnp.stack, onp.stack, tuple(arrs))
|
||||
|
@ -806,13 +795,6 @@ def test_stack():
|
|||
for i in range(-4, 4):
|
||||
match_res(mnp.stack, onp.stack, arrs, axis=i)
|
||||
|
||||
arrs = [rand_int(3, 0, 5, 8, 0) for i in range(5)]
|
||||
match_res(mnp.stack, onp.stack, arrs)
|
||||
match_res(mnp.stack, onp.stack, tuple(arrs))
|
||||
match_res(mnp_stack, onp_stack, *arrs)
|
||||
for i in range(-6, 6):
|
||||
match_res(mnp.stack, onp.stack, arrs, axis=i)
|
||||
|
||||
|
||||
def mnp_roll(input_tensor):
|
||||
a = mnp.roll(input_tensor, -3)
|
||||
|
@ -868,28 +850,22 @@ def onp_moveaxis(a):
|
|||
def test_moveaxis():
|
||||
a = rand_int(2, 4, 5, 9, 6)
|
||||
match_res(mnp_moveaxis, onp_moveaxis, a)
|
||||
a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8)
|
||||
match_res(mnp_moveaxis, onp_moveaxis, a)
|
||||
|
||||
|
||||
def mnp_tile(x):
|
||||
a = mnp.tile(x, 0)
|
||||
b = mnp.tile(x, 1)
|
||||
c = mnp.tile(x, 3)
|
||||
d = mnp.tile(x, [5, 1])
|
||||
e = mnp.tile(x, (3, 1, 0))
|
||||
f = mnp.tile(x, [5, 1, 2, 3, 7])
|
||||
return a, b, c, d, e, f
|
||||
a = mnp.tile(x, 1)
|
||||
b = mnp.tile(x, 3)
|
||||
c = mnp.tile(x, [5, 1])
|
||||
d = mnp.tile(x, [5, 1, 2, 3, 7])
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def onp_tile(x):
|
||||
a = onp.tile(x, 0)
|
||||
b = onp.tile(x, 1)
|
||||
c = onp.tile(x, 3)
|
||||
d = onp.tile(x, [5, 1])
|
||||
e = onp.tile(x, (3, 1, 0))
|
||||
f = onp.tile(x, [5, 1, 2, 3, 7])
|
||||
return a, b, c, d, e, f
|
||||
a = onp.tile(x, 1)
|
||||
b = onp.tile(x, 3)
|
||||
c = onp.tile(x, [5, 1])
|
||||
d = onp.tile(x, [5, 1, 2, 3, 7])
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -901,8 +877,6 @@ def onp_tile(x):
|
|||
def test_tile():
|
||||
a = rand_int(2, 3, 4)
|
||||
match_res(mnp_tile, onp_tile, a)
|
||||
b = rand_int(5, 0, 8)
|
||||
match_res(mnp_tile, onp_tile, b)
|
||||
|
||||
|
||||
def mnp_broadcast_to(x):
|
||||
|
@ -1022,21 +996,13 @@ def test_fliplr():
|
|||
def mnp_split(input_tensor):
|
||||
a = mnp.split(input_tensor, indices_or_sections=1)
|
||||
b = mnp.split(input_tensor, indices_or_sections=3)
|
||||
c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6))
|
||||
d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1))
|
||||
e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10))
|
||||
f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1)
|
||||
return a, b, c, d, e, f
|
||||
return a, b
|
||||
|
||||
|
||||
def onp_split(input_array):
|
||||
a = onp.split(input_array, indices_or_sections=1)
|
||||
b = onp.split(input_array, indices_or_sections=3)
|
||||
c = onp.split(input_array, indices_or_sections=(-9, -8, 6))
|
||||
d = onp.split(input_array, indices_or_sections=(3, 2, 1))
|
||||
e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10))
|
||||
f = onp.split(input_array, indices_or_sections=[0, 2], axis=1)
|
||||
return a, b, c, d, e, f
|
||||
return a, b
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -1090,16 +1056,12 @@ def test_array_split():
|
|||
|
||||
def mnp_vsplit(input_tensor):
|
||||
a = mnp.vsplit(input_tensor, indices_or_sections=3)
|
||||
b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
return a
|
||||
|
||||
|
||||
def onp_vsplit(input_array):
|
||||
a = onp.vsplit(input_array, indices_or_sections=3)
|
||||
b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = onp.vsplit(input_array, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
return a
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -1123,16 +1085,12 @@ def test_vsplit():
|
|||
|
||||
def mnp_hsplit(input_tensor):
|
||||
a = mnp.hsplit(input_tensor, indices_or_sections=3)
|
||||
b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
return a
|
||||
|
||||
|
||||
def onp_hsplit(input_array):
|
||||
a = onp.hsplit(input_array, indices_or_sections=3)
|
||||
b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = onp.hsplit(input_array, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
return a
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -1156,17 +1114,11 @@ def test_hsplit():
|
|||
|
||||
def mnp_dsplit(input_tensor):
|
||||
a = mnp.dsplit(input_tensor, indices_or_sections=3)
|
||||
b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
|
||||
return a
|
||||
|
||||
def onp_dsplit(input_array):
|
||||
a = onp.dsplit(input_array, indices_or_sections=3)
|
||||
b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
|
||||
c = onp.dsplit(input_array, indices_or_sections=[0, 2])
|
||||
return a, b, c
|
||||
|
||||
return a
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
|
@ -37,13 +37,6 @@ class Cases():
|
|||
rand_int(1, 1),
|
||||
]
|
||||
|
||||
# empty arrays
|
||||
self.empty_arrs = [
|
||||
rand_int(0),
|
||||
rand_int(4, 0),
|
||||
rand_int(2, 0, 2),
|
||||
]
|
||||
|
||||
# arrays of the same size expanded across the 0th dimension
|
||||
self.expanded_arrs = [
|
||||
rand_int(2, 3),
|
||||
|
@ -244,8 +237,6 @@ def test_float_power():
|
|||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -687,11 +678,11 @@ def test_ptp():
|
|||
|
||||
|
||||
def mnp_add_dtype(x1, x2):
|
||||
return mnp.add(x1, x2, dtype=mnp.float16)
|
||||
return mnp.add(x1, x2, dtype=mnp.float32)
|
||||
|
||||
|
||||
def onp_add_dtype(x1, x2):
|
||||
return onp.add(x1, x2, dtype=onp.float16)
|
||||
return onp.add(x1, x2, dtype=onp.float32)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -927,8 +918,6 @@ def onp_maximum(x1, x2):
|
|||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -1410,24 +1399,22 @@ def mnp_diff(input_tensor):
|
|||
a = mnp.diff(input_tensor, 2, append=3.0)
|
||||
b = mnp.diff(input_tensor, 4, prepend=6, axis=-2)
|
||||
c = mnp.diff(input_tensor, 0, append=3.0, axis=-1)
|
||||
d = mnp.diff(input_tensor, 10, prepend=6)
|
||||
e = mnp.diff(input_tensor, 1, prepend=input_tensor)
|
||||
f = mnp.ediff1d(input_tensor, to_end=input_tensor)
|
||||
g = mnp.ediff1d(input_tensor)
|
||||
h = mnp.ediff1d(input_tensor, to_begin=3)
|
||||
return a, b, c, d, e, f, g, h
|
||||
d = mnp.diff(input_tensor, 1, prepend=input_tensor)
|
||||
e = mnp.ediff1d(input_tensor, to_end=input_tensor)
|
||||
f = mnp.ediff1d(input_tensor)
|
||||
g = mnp.ediff1d(input_tensor, to_begin=3)
|
||||
return a, b, c, d, e, f, g
|
||||
|
||||
|
||||
def onp_diff(input_array):
|
||||
a = onp.diff(input_array, 2, append=3.0)
|
||||
b = onp.diff(input_array, 4, prepend=6, axis=-2)
|
||||
c = onp.diff(input_array, 0, append=3.0, axis=-1)
|
||||
d = onp.diff(input_array, 10, prepend=6)
|
||||
e = onp.diff(input_array, 1, prepend=input_array)
|
||||
f = onp.ediff1d(input_array, to_end=input_array)
|
||||
g = onp.ediff1d(input_array)
|
||||
h = onp.ediff1d(input_array, to_begin=3)
|
||||
return a, b, c, d, e, f, g, h
|
||||
d = onp.diff(input_array, 1, prepend=input_array)
|
||||
e = onp.ediff1d(input_array, to_end=input_array)
|
||||
f = onp.ediff1d(input_array)
|
||||
g = onp.ediff1d(input_array, to_begin=3)
|
||||
return a, b, c, d, e, f, g
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -1926,7 +1913,6 @@ def test_mean():
|
|||
run_multi_test(mnp_mean, onp_mean, test_case.arrs, error=3)
|
||||
run_multi_test(mnp_mean, onp_mean, test_case.expanded_arrs, error=3)
|
||||
run_multi_test(mnp_mean, onp_mean, test_case.scalars, error=3)
|
||||
run_multi_test(mnp_mean, onp_mean, test_case.empty_arrs, error=3)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -1961,3 +1947,14 @@ def test_exception_add():
|
|||
def test_exception_mean():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.mean(to_tensor(test_case.arrs[0]), (-1, 0))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_exception_amax():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.amax(mnp.array([[1, 2], [3, 4]]).astype(mnp.float32), initial=[1.0, 2.0])
|
||||
|
|
Loading…
Reference in New Issue