From 76c8e10c4a9e486de3bdb9dfc470fd345029eda2 Mon Sep 17 00:00:00 2001 From: huangmengxi Date: Fri, 26 Feb 2021 17:40:17 +0800 Subject: [PATCH] bug fix --- mindspore/numpy/__init__.py | 6 +-- mindspore/numpy/array_creations.py | 51 +------------------ mindspore/numpy/array_ops.py | 3 ++ mindspore/numpy/logic_ops.py | 8 +-- mindspore/numpy/math_ops.py | 50 +++++++++++++++++- tests/st/numpy_native/test_array_creations.py | 20 -------- tests/st/numpy_native/test_math_ops.py | 22 +++++++- 7 files changed, 81 insertions(+), 79 deletions(-) diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py index 8ac9969a054..f14dca72d87 100644 --- a/mindspore/numpy/__init__.py +++ b/mindspore/numpy/__init__.py @@ -35,7 +35,7 @@ from .array_creations import copy_ as copy from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, linspace, logspace, eye, identity, empty, empty_like, ones_like, zeros_like, full_like, diagonal, tril, triu, - tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat, + tri, trace, meshgrid, mgrid, ogrid, diagflat, diag, diag_indices, ix_) from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, @@ -45,7 +45,7 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, - exp, expm1) + exp, expm1, cumsum) from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, isnan, isinf, isposinf, isneginf, isscalar) @@ -70,7 +70,7 @@ math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_d 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', - 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs'] + 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum'] logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 2d2b655b0fa..2ad65a80ad0 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -20,7 +20,6 @@ import numpy as onp from ..common import Tensor from ..common import dtype as mstype from ..ops import functional as F -from ..ops import operations as P from ..ops.primitive import constexpr from ..nn.layer.basic import tril as nn_tril from ..nn.layer.basic import triu as nn_triu @@ -31,7 +30,7 @@ from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray _expand, _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ - _raise_type_error, _expanded_shape, _check_axis_in_range, _check_is_float, _iota, \ + _raise_type_error, _expanded_shape, _check_is_float, _iota, \ _type_convert, _canonicalize_axis, _list_comprehensions, _ceil from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape from .dtypes import nan @@ -41,7 +40,6 @@ from .dtypes import nan MAX_NUMPY_DIMS = 32 # All types that can be accepted as "array_like" parameters in graph mode. ARRAY_TYPES = (int, float, bool, list, tuple, Tensor) -_cumsum_default = P.CumSum() def array(obj, dtype=None, copy=True, ndmin=0): @@ -1172,53 +1170,6 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None): return res -def cumsum(a, axis=None, dtype=None): - """ - Returns the cumulative sum of the elements along a given axis. - - Args: - a (Tensor): Input tensor. - axis (int, optional): Axis along which the cumulative sum is computed. The - default (None) is to compute the cumsum over the flattened array. - dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, - unless `a` has an integer dtype with a precision less than that of the - default platform integer. In that case, the default platform integer - is used. - - Returns: - Tensor. - - Raises: - TypeError: If input arguments have types not specified above. - ValueError: If axis is out of range. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> output = np.cumsum(np.ones((3,3)), axis=0) - >>> print(output) - [[1. 1. 1.] - [2. 2. 2.] - [3. 3. 3.]] - """ - _check_input_tensor(a) - original_dtype = F.dtype(a) - # If original array is int, and has precision less then int32, convert to int32 - if _check_same_type(original_dtype, mstype.bool_) or \ - _check_same_type(original_dtype, mstype.int8) or \ - _check_same_type(original_dtype, mstype.int16): - original_dtype = mstype.int32 - a = a.astype(mstype.float32) - if axis is None: - a = a.ravel() - axis = 0 - _check_axis_in_range(axis, a.ndim) - if dtype is not None and not _check_same_type(original_dtype, dtype): - return _cumsum_default(a, axis).astype(dtype, copy=False) - return _cumsum_default(a, axis).astype(original_dtype, copy=False) - - def _index(i, size, Cartesian=True): """If Cartesian=True, index 0 is swapped with index 1.""" if Cartesian: diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index 9c8fe7efc77..ec4d14bb54c 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -1905,8 +1905,11 @@ def repeat(a, repeats, axis=None): if repeats == 0: return _empty(F.dtype(a), (0,)) return C.repeat_elements(a, repeats, axis) + shape = F.shape(a) size = shape[axis] + if len(repeats) != size: + _raise_value_error('operands could not be broadcast together') subs = split(a, size, axis) repeated_subs = [] for sub, rep in zip(subs, repeats): diff --git a/mindspore/numpy/logic_ops.py b/mindspore/numpy/logic_ops.py index c294431220e..754dca4432f 100644 --- a/mindspore/numpy/logic_ops.py +++ b/mindspore/numpy/logic_ops.py @@ -361,7 +361,7 @@ def isnan(x, out=None, where=True, dtype=None): When `where` is provided, `out` must have a tensor value. `out` is not supported for storing the result, however it can be used in combination with `where` to set the value at indices for which `where` is set to False. - On GPU, the supported dtypes are np.float16, and np.float32. + Only np.float32 is currently supported. Args: x (Tensor): Input values. @@ -422,7 +422,7 @@ def isinf(x, out=None, where=True, dtype=None): When `where` is provided, `out` must have a tensor value. `out` is not supported for storing the result, however it can be used in combination with `where` to set the value at indices for which `where` is set to False. - On GPU, the supported dtypes are np.float16, and np.float32. + Only np.float32 is currently supported. Args: x (Tensor): Input values. @@ -477,7 +477,7 @@ def isposinf(x): Note: Numpy argument `out` is not supported. - On GPU, the supported dtypes are np.float16, and np.float32. + Only np.float32 is currently supported. Args: x (Tensor): Input values. @@ -507,7 +507,7 @@ def isneginf(x): Note: Numpy argument `out` is not supported. - On GPU, the supported dtypes are np.float16, and np.float32. + Only np.float32 is currently supported. Args: x (Tensor): Input values. diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 5b0208f3ae5..76540083757 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -32,7 +32,7 @@ from .array_ops import ravel, expand_dims from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ - _max, _is_shape_empty, _check_is_int, _expanded_shape + _max, _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ _check_input_tensor @@ -50,6 +50,7 @@ _reduce_min_default = P.ReduceMin() _reduce_min_keepdims = P.ReduceMin(True) _reduce_max_default = P.ReduceMax() _reduce_max_keepdims = P.ReduceMax(True) +_cumsum_default = P.CumSum() def absolute(x, out=None, where=True, dtype=None): """ @@ -2385,6 +2386,53 @@ def negative(a, out=None, where=True, dtype=None): return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype) +def cumsum(a, axis=None, dtype=None): + """ + Returns the cumulative sum of the elements along a given axis. + + Args: + a (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, + unless `a` has an integer dtype with a precision less than that of the + default platform integer. In that case, the default platform integer + is used. + + Returns: + Tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis is out of range. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.cumsum(np.ones((3,3)), axis=0) + >>> print(output) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 3.]] + """ + _check_input_tensor(a) + original_dtype = F.dtype(a) + # If original array is int, and has precision less then int32, convert to int32 + if _check_same_type(original_dtype, mstype.bool_) or \ + _check_same_type(original_dtype, mstype.int8) or \ + _check_same_type(original_dtype, mstype.int16): + original_dtype = mstype.int32 + a = a.astype(mstype.float32) + if axis is None: + a = a.ravel() + axis = 0 + _check_axis_in_range(axis, a.ndim) + if dtype is not None and not _check_same_type(original_dtype, dtype): + return _cumsum_default(a, axis).astype(dtype, copy=False) + return _cumsum_default(a, axis).astype(original_dtype, copy=False) + + def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): """Applies tensor operations based on fn""" _check_input_tensor(*args) diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index 431866e46f0..75c4754acf7 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -549,26 +549,6 @@ def test_tri_triu_tril(): match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) -@pytest.mark.level1 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_x86_gpu_training -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_cumsum(): - x = mnp.ones((16, 16), dtype="bool") - match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) - match_array(mnp.cumsum(x, axis=0).asnumpy(), - onp.cumsum(x.asnumpy(), axis=0)) - match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) - - x = rand_int(3, 4, 5) - match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), - onp.cumsum(x, dtype="bool")) - match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), - onp.cumsum(x, axis=-1)) - - def mnp_diagonal(arr): return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index a6d87d845da..1fc2b86e06c 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -20,7 +20,7 @@ import numpy as onp import mindspore.numpy as mnp from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \ - run_single_test, match_res, match_array + run_single_test, match_res, match_array, match_meta class Cases(): def __init__(self): @@ -1138,6 +1138,26 @@ def test_negative(): match_array(mnp_neg.asnumpy(), onp_neg, 1e-5) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cumsum(): + x = mnp.ones((16, 16), dtype="bool") + match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) + match_array(mnp.cumsum(x, axis=0).asnumpy(), + onp.cumsum(x.asnumpy(), axis=0)) + match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) + + x = rand_int(3, 4, 5) + match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), + onp.cumsum(x, dtype="bool")) + match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), + onp.cumsum(x, axis=-1)) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training