forked from mindspore-Ecosystem/mindspore
!12656 numpy-native repeat add checking for repeats shape
From: @jachua Reviewed-by: Signed-off-by:
This commit is contained in:
commit
692d158f5c
|
@ -35,7 +35,7 @@ from .array_creations import copy_ as copy
|
|||
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
|
||||
linspace, logspace, eye, identity, empty, empty_like,
|
||||
ones_like, zeros_like, full_like, diagonal, tril, triu,
|
||||
tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat,
|
||||
tri, trace, meshgrid, mgrid, ogrid, diagflat,
|
||||
diag, diag_indices, ix_)
|
||||
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
|
||||
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan,
|
||||
|
@ -45,7 +45,7 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide
|
|||
matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin,
|
||||
hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero,
|
||||
positive, negative, clip, floor_divide, remainder, fix, fmod, trunc,
|
||||
exp, expm1)
|
||||
exp, expm1, cumsum)
|
||||
from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite,
|
||||
isnan, isinf, isposinf, isneginf, isscalar)
|
||||
|
||||
|
@ -70,7 +70,7 @@ math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_d
|
|||
'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum',
|
||||
'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad',
|
||||
'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide',
|
||||
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs']
|
||||
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum']
|
||||
|
||||
logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite',
|
||||
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar']
|
||||
|
|
|
@ -20,7 +20,6 @@ import numpy as onp
|
|||
from ..common import Tensor
|
||||
from ..common import dtype as mstype
|
||||
from ..ops import functional as F
|
||||
from ..ops import operations as P
|
||||
from ..ops.primitive import constexpr
|
||||
from ..nn.layer.basic import tril as nn_tril
|
||||
from ..nn.layer.basic import triu as nn_triu
|
||||
|
@ -31,7 +30,7 @@ from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray
|
|||
_expand, _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar
|
||||
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
|
||||
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
|
||||
_raise_type_error, _expanded_shape, _check_axis_in_range, _check_is_float, _iota, \
|
||||
_raise_type_error, _expanded_shape, _check_is_float, _iota, \
|
||||
_type_convert, _canonicalize_axis, _list_comprehensions, _ceil
|
||||
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape
|
||||
from .dtypes import nan
|
||||
|
@ -41,7 +40,6 @@ from .dtypes import nan
|
|||
MAX_NUMPY_DIMS = 32
|
||||
# All types that can be accepted as "array_like" parameters in graph mode.
|
||||
ARRAY_TYPES = (int, float, bool, list, tuple, Tensor)
|
||||
_cumsum_default = P.CumSum()
|
||||
|
||||
|
||||
def array(obj, dtype=None, copy=True, ndmin=0):
|
||||
|
@ -1172,53 +1170,6 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
|
|||
return res
|
||||
|
||||
|
||||
def cumsum(a, axis=None, dtype=None):
|
||||
"""
|
||||
Returns the cumulative sum of the elements along a given axis.
|
||||
|
||||
Args:
|
||||
a (Tensor): Input tensor.
|
||||
axis (int, optional): Axis along which the cumulative sum is computed. The
|
||||
default (None) is to compute the cumsum over the flattened array.
|
||||
dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`,
|
||||
unless `a` has an integer dtype with a precision less than that of the
|
||||
default platform integer. In that case, the default platform integer
|
||||
is used.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If axis is out of range.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> output = np.cumsum(np.ones((3,3)), axis=0)
|
||||
>>> print(output)
|
||||
[[1. 1. 1.]
|
||||
[2. 2. 2.]
|
||||
[3. 3. 3.]]
|
||||
"""
|
||||
_check_input_tensor(a)
|
||||
original_dtype = F.dtype(a)
|
||||
# If original array is int, and has precision less then int32, convert to int32
|
||||
if _check_same_type(original_dtype, mstype.bool_) or \
|
||||
_check_same_type(original_dtype, mstype.int8) or \
|
||||
_check_same_type(original_dtype, mstype.int16):
|
||||
original_dtype = mstype.int32
|
||||
a = a.astype(mstype.float32)
|
||||
if axis is None:
|
||||
a = a.ravel()
|
||||
axis = 0
|
||||
_check_axis_in_range(axis, a.ndim)
|
||||
if dtype is not None and not _check_same_type(original_dtype, dtype):
|
||||
return _cumsum_default(a, axis).astype(dtype, copy=False)
|
||||
return _cumsum_default(a, axis).astype(original_dtype, copy=False)
|
||||
|
||||
|
||||
def _index(i, size, Cartesian=True):
|
||||
"""If Cartesian=True, index 0 is swapped with index 1."""
|
||||
if Cartesian:
|
||||
|
|
|
@ -1905,8 +1905,11 @@ def repeat(a, repeats, axis=None):
|
|||
if repeats == 0:
|
||||
return _empty(F.dtype(a), (0,))
|
||||
return C.repeat_elements(a, repeats, axis)
|
||||
|
||||
shape = F.shape(a)
|
||||
size = shape[axis]
|
||||
if len(repeats) != size:
|
||||
_raise_value_error('operands could not be broadcast together')
|
||||
subs = split(a, size, axis)
|
||||
repeated_subs = []
|
||||
for sub, rep in zip(subs, repeats):
|
||||
|
|
|
@ -361,7 +361,7 @@ def isnan(x, out=None, where=True, dtype=None):
|
|||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
|
@ -422,7 +422,7 @@ def isinf(x, out=None, where=True, dtype=None):
|
|||
When `where` is provided, `out` must have a tensor value. `out` is not supported
|
||||
for storing the result, however it can be used in combination with `where` to set
|
||||
the value at indices for which `where` is set to False.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
|
@ -477,7 +477,7 @@ def isposinf(x):
|
|||
|
||||
Note:
|
||||
Numpy argument `out` is not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
|
@ -507,7 +507,7 @@ def isneginf(x):
|
|||
|
||||
Note:
|
||||
Numpy argument `out` is not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
Only np.float32 is currently supported.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input values.
|
||||
|
|
|
@ -32,7 +32,7 @@ from .array_ops import ravel, expand_dims
|
|||
from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
|
||||
_check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
|
||||
_raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \
|
||||
_max, _is_shape_empty, _check_is_int, _expanded_shape
|
||||
_max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range
|
||||
from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \
|
||||
_check_input_tensor
|
||||
|
||||
|
@ -50,6 +50,7 @@ _reduce_min_default = P.ReduceMin()
|
|||
_reduce_min_keepdims = P.ReduceMin(True)
|
||||
_reduce_max_default = P.ReduceMax()
|
||||
_reduce_max_keepdims = P.ReduceMax(True)
|
||||
_cumsum_default = P.CumSum()
|
||||
|
||||
def absolute(x, out=None, where=True, dtype=None):
|
||||
"""
|
||||
|
@ -2385,6 +2386,53 @@ def negative(a, out=None, where=True, dtype=None):
|
|||
return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def cumsum(a, axis=None, dtype=None):
|
||||
"""
|
||||
Returns the cumulative sum of the elements along a given axis.
|
||||
|
||||
Args:
|
||||
a (Tensor): Input tensor.
|
||||
axis (int, optional): Axis along which the cumulative sum is computed. The
|
||||
default (None) is to compute the cumsum over the flattened array.
|
||||
dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`,
|
||||
unless `a` has an integer dtype with a precision less than that of the
|
||||
default platform integer. In that case, the default platform integer
|
||||
is used.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If axis is out of range.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> output = np.cumsum(np.ones((3,3)), axis=0)
|
||||
>>> print(output)
|
||||
[[1. 1. 1.]
|
||||
[2. 2. 2.]
|
||||
[3. 3. 3.]]
|
||||
"""
|
||||
_check_input_tensor(a)
|
||||
original_dtype = F.dtype(a)
|
||||
# If original array is int, and has precision less then int32, convert to int32
|
||||
if _check_same_type(original_dtype, mstype.bool_) or \
|
||||
_check_same_type(original_dtype, mstype.int8) or \
|
||||
_check_same_type(original_dtype, mstype.int16):
|
||||
original_dtype = mstype.int32
|
||||
a = a.astype(mstype.float32)
|
||||
if axis is None:
|
||||
a = a.ravel()
|
||||
axis = 0
|
||||
_check_axis_in_range(axis, a.ndim)
|
||||
if dtype is not None and not _check_same_type(original_dtype, dtype):
|
||||
return _cumsum_default(a, axis).astype(dtype, copy=False)
|
||||
return _cumsum_default(a, axis).astype(original_dtype, copy=False)
|
||||
|
||||
|
||||
def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None):
|
||||
"""Applies tensor operations based on fn"""
|
||||
_check_input_tensor(*args)
|
||||
|
|
|
@ -549,26 +549,6 @@ def test_tri_triu_tril():
|
|||
match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cumsum():
|
||||
x = mnp.ones((16, 16), dtype="bool")
|
||||
match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
|
||||
match_array(mnp.cumsum(x, axis=0).asnumpy(),
|
||||
onp.cumsum(x.asnumpy(), axis=0))
|
||||
match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
|
||||
|
||||
x = rand_int(3, 4, 5)
|
||||
match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(),
|
||||
onp.cumsum(x, dtype="bool"))
|
||||
match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(),
|
||||
onp.cumsum(x, axis=-1))
|
||||
|
||||
|
||||
def mnp_diagonal(arr):
|
||||
return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0)
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as onp
|
|||
import mindspore.numpy as mnp
|
||||
|
||||
from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \
|
||||
run_single_test, match_res, match_array
|
||||
run_single_test, match_res, match_array, match_meta
|
||||
|
||||
class Cases():
|
||||
def __init__(self):
|
||||
|
@ -1138,6 +1138,26 @@ def test_negative():
|
|||
match_array(mnp_neg.asnumpy(), onp_neg, 1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cumsum():
|
||||
x = mnp.ones((16, 16), dtype="bool")
|
||||
match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
|
||||
match_array(mnp.cumsum(x, axis=0).asnumpy(),
|
||||
onp.cumsum(x.asnumpy(), axis=0))
|
||||
match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
|
||||
|
||||
x = rand_int(3, 4, 5)
|
||||
match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(),
|
||||
onp.cumsum(x, dtype="bool"))
|
||||
match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(),
|
||||
onp.cumsum(x, axis=-1))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue