From 139ee0120532c59020744cdf88470938fa247dda Mon Sep 17 00:00:00 2001 From: liuyang_655 Date: Tue, 8 Mar 2022 22:59:17 -0500 Subject: [PATCH] add some numpy method --- mindspore/python/mindspore/numpy/__init__.py | 10 +- .../python/mindspore/numpy/array_creations.py | 195 +++++++++++++++++- mindspore/python/mindspore/ops/functional.py | 66 ++++++ tests/st/numpy_native/test_array_creations.py | 123 +++++++++++ 4 files changed, 388 insertions(+), 6 deletions(-) diff --git a/mindspore/python/mindspore/numpy/__init__.py b/mindspore/python/mindspore/numpy/__init__.py index 05c7b299a3b..87fe1603522 100644 --- a/mindspore/python/mindspore/numpy/__init__.py +++ b/mindspore/python/mindspore/numpy/__init__.py @@ -34,7 +34,7 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res rot90, select, array_split, choose, size, array_str, apply_along_axis, piecewise, unravel_index, apply_over_axes) from .array_creations import copy_ as copy -from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, +from .array_creations import (array, asarray, asfarray, ones, zeros, full, randn, rand, randint, arange, linspace, logspace, eye, identity, empty, empty_like, ones_like, zeros_like, full_like, diagonal, tril, triu, tri, trace, meshgrid, mgrid, ogrid, diagflat, @@ -83,10 +83,10 @@ array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes 'repeat', 'rot90', 'select', 'array_split', 'choose', 'size', 'array_str', 'apply_along_axis', 'piecewise', 'unravel_index', 'apply_over_axes'] -array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', - 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', - 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', - 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', +array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'randn', 'rand', + 'randint', 'arange', 'linspace', 'logspace', 'eye', 'identity', 'empty', + 'empty_like', 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', + 'triu', 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', 'diag_indices', 'ix_', 'indices', 'geomspace', 'vander', 'hamming', 'hanning', 'bartlett', 'blackman', 'triu_indices', 'tril_indices', 'triu_indices_from', 'tril_indices_from', 'histogram_bin_edges', 'pad'] diff --git a/mindspore/python/mindspore/numpy/array_creations.py b/mindspore/python/mindspore/numpy/array_creations.py index ccc4c8fd2c9..02c31cbf5af 100644 --- a/mindspore/python/mindspore/numpy/array_creations.py +++ b/mindspore/python/mindspore/numpy/array_creations.py @@ -21,6 +21,7 @@ import numpy as onp from .. import context from ..common import Tensor from ..common import dtype as mstype +from ..common.seed import get_seed from ..ops import operations as P from ..ops import functional as F from ..ops.primitive import constexpr @@ -38,7 +39,7 @@ from .utils_const import _raise_value_error, _empty, _max, _min, \ _tuple_setitem from .array_ops import ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \ apply_along_axis, where, moveaxis -from .dtypes import nan, pi +from .dtypes import nan, pi, dtype_map # According to official numpy reference, the dimension of a numpy array must be less # than 32 @@ -379,6 +380,198 @@ def full(shape, fill_value, dtype=None): return _convert_64_to_32(empty_compile(dtype, shape)) +def _generate_shapes(shape): + """Generate shapes for randn and rand.""" + if not shape: + size = (1,) + elif len(shape) == 1: + if isinstance(shape[0], int): + size = shape + elif isinstance(shape[0], list): + size = tuple(shape[0]) + elif isinstance(shape[0], tuple): + size = shape[0] + else: + raise TypeError("If the length of the argument 'shape' is 1, the type of the argument 'shape' must be " + "one of ['int', 'list', 'tuple'], but got {}.".format(type(shape[0]))) + else: + for index, value in enumerate(shape): + if not isinstance(value, int): + raise TypeError("If the length of the argument 'shape' is > 1, the type of the argument 'shape' must " + "all be int, but got {} at index {}.".format(type(value), index)) + size = shape + return size + + +def _check_rand_type(dtype): + """Check type for randn and rand""" + type_list = ['float', 'float16', 'float32', 'float64'] + if isinstance(dtype, str): + if dtype not in type_list: + raise ValueError("If the argument 'dtype' is str, it must be one of {}, but got {}." + .format(type_list, dtype)) + try: + dtype = dtype_map[dtype] + except KeyError: + raise KeyError("Unsupported dtype {}.".format(dtype)) + elif dtype not in (mstype.float64, mstype.float32, mstype.float16): + raise ValueError("The argument 'dtype' must be 'mindspore.float64', 'mindspore.float32' or " + "'mindspore.float16', but got {}.".format(dtype)) + + +def randn(*shape, dtype=mstype.float32): + """ + Returns a new Tensor with given shape and dtype, filled with a sample (or samples) + from the standard normal distribution. + + Args: + *shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., + :math:`(2, 3)` or :math:`2`. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must + be float type. Default is :class:`mindspore.float32`. + + Returns: + Tensor, with the designated shape and dtype, filled with a sample (or samples) + from the "standard normal" distribution. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If `dtype` is not float type. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> from mindspore import set_seed + >>> set_seed(1) + >>> print(np.randn((2,3))) + [[ 0.30639967 -0.42438635 -0.20454668] + [-0.4287376 1.3054721 0.64747655]] + """ + _check_rand_type(dtype) + size = _generate_shapes(shape) + seed = get_seed() + if seed is not None: + stdnormal = P.StandardNormal(seed=seed) + else: + stdnormal = P.StandardNormal() + return stdnormal(size).astype(dtype) + + +def rand(*shape, dtype=mstype.float32): + """ + Returns a new Tensor with given shape and dtype, filled with random numbers from the + uniform distribution on the interval :math:`[0, 1)`. + + Args: + *shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., + :math:`(2, 3)` or :math:`2`. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must + be float type. Default is :class:`mindspore.float32`. + + Returns: + Tensor, with the designated shape and dtype, filled with random numbers from the + uniform distribution on the interval :math:`[0, 1)`. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If `dtype` is not float type. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> from mindspore import set_seed + >>> set_seed(1) + >>> print(np.rand((2,3))) + [[4.1702199e-01 9.9718481e-01 7.2032452e-01] + [9.3255734e-01 1.1438108e-04 1.2812445e-01]] + """ + _check_rand_type(dtype) + size = _generate_shapes(shape) + seed = get_seed() + if seed is not None: + uniformreal = P.UniformReal(seed=seed) + else: + uniformreal = P.UniformReal() + return uniformreal(size).astype(dtype) + + +def randint(minval, maxval=None, shape=None, dtype=mstype.int32): + """ + Return random integers from minval (inclusive) to maxval (exclusive). Return random integers from the + discrete uniform distribution of the specified dtype in the “half-open” interval :math:`[minval, maxval)`. + If maxval is None (the default), then results are from [0, maxval). + + Args: + minval(Union[int]): Start value of interval. The interval includes this value. When `maxval` + is :class:`None`, `minval` must be greater than 0. When `maxval` is not :class:`None`, + `minval` must be less than `maxval`. + maxval(Union[int], optional): End value of interval. The interval does not include this value. + shape (Union[int, tuple(int)]): Shape of the new tensor, e.g., :math:`(2, 3)` or :math:`2`. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must + be int type. Default is :class:`mindspore.int32`. + + Returns: + Tensor, with the designated shape and dtype, filled with random integers from minval (inclusive) + to maxval (exclusive). + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input arguments have values not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> from mindspore import set_seed + >>> set_seed(1) + >>> print(np.randint(1, 10, (2,3))) + [[4 9 7] + [9 1 2]] + """ + if not isinstance(minval, int): + raise TypeError("For mindspore.numpy.randint, the type of the argument 'minval' must be int, " + "but got {}.".format(type(minval))) + if maxval is None: + if minval <= 0: + raise ValueError("For mindspore.numpy.randint, the argument 'minval' must be > 0 when the argument " + "'maxval' is None, but got {}.".format(minval)) + maxval = minval + minval = 0 + else: + if not isinstance(maxval, int): + raise TypeError("For mindspore.numpy.randint, the type of the argument 'maxval' must be int, " + "but got {}.".format(type(maxval))) + if minval >= maxval: + raise ValueError("For mindspore.numpy.randint, the value of 'minval' must be greater than the value of " + "'maxval', but got 'minval': {} and 'maxval': {}.".format(minval, maxval)) + if isinstance(dtype, str): + if dtype not in ('int', 'int8', 'int16', 'int32', 'int64'): + raise ValueError("For 'mindspore.numpy.randint', if the argument 'dtype' is str, it must be one of " + "['int', 'int8', 'int16', 'int32', 'int64'], but got {}.".format(dtype)) + try: + dtype = dtype_map[dtype] + except KeyError: + raise KeyError("Unsupported dtype {}.".format(dtype)) + elif dtype not in (mstype.int64, mstype.int32, mstype.int16, mstype.int8): + raise ValueError("For 'mindspore.numpy.randint', the argument 'dtype' must be 'mindspore.int64', " + "'mindspore.int32', 'mindspore.int16' or 'mindspore.int8', but got {}.".format(dtype)) + if shape is None: + shape = (1,) + else: + shape = _check_shape(shape) + seed = get_seed() + if seed is not None: + uniformint = P.UniformInt(seed=seed) + else: + uniformint = P.UniformInt() + return uniformint(shape, Tensor(minval, mstype.int32), Tensor(maxval, mstype.int32)).astype(dtype) + + def arange(start, stop=None, step=None, dtype=None): """ Returns evenly spaced values within a given interval. diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index 692131045ac..876157f1426 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -27,6 +27,7 @@ from mindspore.nn.grad.cell_grad import _JvpInner from mindspore.nn.grad.cell_grad import _VjpInner from mindspore.ops import _constants from mindspore.ops.primitive import constexpr +import numpy as np from .primitive import Primitive from . import operations as P from .operations import _grad_ops @@ -395,6 +396,71 @@ shard_fn = Shard() def shard(fn, in_axes, out_axes, device="Ascend", level=0): return shard_fn(fn, in_axes, out_axes, device, level) + +def arange(start=0, stop=None, step=1, rtype=None): + """ + Returns evenly spaced values within a given interval. + + Args: + start(Union[int, float]): Start value of interval. The interval includes this value. When + `stop` is None, `start` must be greater than 0, and the interval is :math:`[0, start)`. + When `stop` is not None, `start` must be less than `stop`. + stop(Union[int, float], optional): End value of interval. The interval does not + include this value. Default is None. + step(Union[int, float], optional): Spacing between values. For any output + `out`, this is the distance between two adjacent values, :math:`out[i+1] - out[i]`. + The default step size is 1. If `step` is specified as a position argument, + `start` must also be given. + rtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor type. + If rtype is None, the data type of the new tensor will be inferred from start, + stop and step. Default is None. + + Returns: + Tensor with evenly spaced values. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input arguments have values not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.ops as ops + >>> print(ops.arange(0, 5, 1)) + [0 1 2 3 4] + >>> print(ops.arange(3)) + [0 1 2] + >>> print(ops.arange(start=0, stop=3)) + [0 1 2] + >>> print(ops.arange(0, stop=3, step=0.5)) + [0. 0.5 1. 1.5 2. 2.5] + """ + if stop is None: + start, stop = 0, start + + arg_map = {"start": start, "stop": stop, "step": step} + for arg in arg_map: + try: + arg_value = arg_map[arg] + except KeyError: + raise KeyError("Unsupported key {}, the key must be one of ['start', 'stop', 'step'].".format(arg_value)) + if not isinstance(arg_value, int) and not isinstance(arg_value, float): + raise TypeError("For mindspore.ops.range, the argument '{}' must be int or float, but got {}." + .format(arg, type(arg_value))) + if start >= stop: + raise ValueError("For mindspore.ops.range, the argument 'start' must be < 'stop', but got 'start': {}, " + "'stop': {}.".format(start, stop)) + + if rtype is None: + data = np.arange(start, stop, step) + if data.dtype == int: + rtype = mstype.int32 + else: + rtype = mstype.float32 + return Tensor(np.arange(start, stop, step), dtype=rtype) + + def narrow(inputs, axis, start, length): """ Returns a narrowed tensor from input tensor. diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index ece4201d6a5..05503250aa2 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -17,7 +17,10 @@ import pytest import numpy as onp import mindspore.numpy as mnp +import mindspore.ops.functional as F from mindspore import context +from mindspore import set_seed +from mindspore.common import dtype as mstype from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ match_all_arrays, run_multi_test, to_tensor @@ -894,6 +897,126 @@ def test_histogram_bin_edges(): match_res(mnp.histogram_bin_edges, onp.histogram_bin_edges, x, bins=10, range=(2, 20), error=3) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_randn(): + """ + Feature: Numpy method randn. + Description: Test numpy method randn. + Expectation: No exception. + """ + set_seed(1) + t1 = mnp.randn(1, 2, 3) + t2 = mnp.randn(1, 2, 3) + assert (t1.asnumpy() == t2.asnumpy()).all() + + with pytest.raises(ValueError): + mnp.randn(dtype="int32") + with pytest.raises(ValueError): + mnp.randn(dtype=mstype.int32) + with pytest.raises(TypeError): + mnp.randn({1}) + with pytest.raises(TypeError): + mnp.randn(1, 1.2, 2) + + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_rand(): + """ + Feature: Numpy method rand. + Description: Test numpy method rand. + Expectation: No exception. + """ + set_seed(1) + t1 = mnp.rand(1, 2, 3) + t2 = mnp.rand(1, 2, 3) + assert (t1.asnumpy() == t2.asnumpy()).all() + + with pytest.raises(ValueError): + mnp.rand(dtype="int32") + with pytest.raises(ValueError): + mnp.rand(dtype=mstype.int32) + with pytest.raises(TypeError): + mnp.rand({1}) + with pytest.raises(TypeError): + mnp.rand(1, 1.2, 2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_randint(): + """ + Feature: Numpy method randint. + Description: Test numpy method randint. + Expectation: No exception. + """ + set_seed(1) + t1 = mnp.randint(1, 5, 3) + t2 = mnp.randint(1, 5, 3) + assert (t1.asnumpy() == t2.asnumpy()).all() + + with pytest.raises(TypeError): + mnp.randint(1.2) + with pytest.raises(ValueError): + mnp.randint(0) + with pytest.raises(TypeError): + mnp.randint(1, 1.2) + with pytest.raises(ValueError): + mnp.randint(2, 1) + with pytest.raises(ValueError): + mnp.randint(1, dtype="float") + with pytest.raises(ValueError): + mnp.randint(1, dtype=mstype.float32) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ops_arange(): + """ + Feature: Ops function arange. + Description: Test ops function arange. + Expectation: No exception. + """ + actual = onp.arange(5) + expected = F.arange(5).asnumpy() + match_array(actual, expected) + + actual = onp.arange(0, 5) + expected = F.arange(0, 5).asnumpy() + match_array(actual, expected) + + actual = onp.arange(5, step=0.2) + expected = F.arange(5, step=0.2).asnumpy() + match_array(actual, expected) + + actual = onp.arange(0.1, 0.9) + expected = F.arange(0.1, 0.9).asnumpy() + match_array(actual, expected) + + with pytest.raises(TypeError): + F.arange([1]) + with pytest.raises(ValueError): + F.arange(10, 1) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training