forked from mindspore-Ecosystem/mindspore
add some numpy method
This commit is contained in:
parent
e2eef76c99
commit
139ee01205
|
@ -34,7 +34,7 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res
|
|||
rot90, select, array_split, choose, size, array_str, apply_along_axis,
|
||||
piecewise, unravel_index, apply_over_axes)
|
||||
from .array_creations import copy_ as copy
|
||||
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
|
||||
from .array_creations import (array, asarray, asfarray, ones, zeros, full, randn, rand, randint, arange,
|
||||
linspace, logspace, eye, identity, empty, empty_like,
|
||||
ones_like, zeros_like, full_like, diagonal, tril, triu,
|
||||
tri, trace, meshgrid, mgrid, ogrid, diagflat,
|
||||
|
@ -83,10 +83,10 @@ array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes
|
|||
'repeat', 'rot90', 'select', 'array_split', 'choose', 'size', 'array_str',
|
||||
'apply_along_axis', 'piecewise', 'unravel_index', 'apply_over_axes']
|
||||
|
||||
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
|
||||
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
|
||||
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
|
||||
'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
|
||||
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'randn', 'rand',
|
||||
'randint', 'arange', 'linspace', 'logspace', 'eye', 'identity', 'empty',
|
||||
'empty_like', 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril',
|
||||
'triu', 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
|
||||
'diag_indices', 'ix_', 'indices', 'geomspace', 'vander', 'hamming',
|
||||
'hanning', 'bartlett', 'blackman', 'triu_indices', 'tril_indices',
|
||||
'triu_indices_from', 'tril_indices_from', 'histogram_bin_edges', 'pad']
|
||||
|
|
|
@ -21,6 +21,7 @@ import numpy as onp
|
|||
from .. import context
|
||||
from ..common import Tensor
|
||||
from ..common import dtype as mstype
|
||||
from ..common.seed import get_seed
|
||||
from ..ops import operations as P
|
||||
from ..ops import functional as F
|
||||
from ..ops.primitive import constexpr
|
||||
|
@ -38,7 +39,7 @@ from .utils_const import _raise_value_error, _empty, _max, _min, \
|
|||
_tuple_setitem
|
||||
from .array_ops import ravel, concatenate, broadcast_arrays, reshape, broadcast_to, flip, \
|
||||
apply_along_axis, where, moveaxis
|
||||
from .dtypes import nan, pi
|
||||
from .dtypes import nan, pi, dtype_map
|
||||
|
||||
# According to official numpy reference, the dimension of a numpy array must be less
|
||||
# than 32
|
||||
|
@ -379,6 +380,198 @@ def full(shape, fill_value, dtype=None):
|
|||
return _convert_64_to_32(empty_compile(dtype, shape))
|
||||
|
||||
|
||||
def _generate_shapes(shape):
|
||||
"""Generate shapes for randn and rand."""
|
||||
if not shape:
|
||||
size = (1,)
|
||||
elif len(shape) == 1:
|
||||
if isinstance(shape[0], int):
|
||||
size = shape
|
||||
elif isinstance(shape[0], list):
|
||||
size = tuple(shape[0])
|
||||
elif isinstance(shape[0], tuple):
|
||||
size = shape[0]
|
||||
else:
|
||||
raise TypeError("If the length of the argument 'shape' is 1, the type of the argument 'shape' must be "
|
||||
"one of ['int', 'list', 'tuple'], but got {}.".format(type(shape[0])))
|
||||
else:
|
||||
for index, value in enumerate(shape):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("If the length of the argument 'shape' is > 1, the type of the argument 'shape' must "
|
||||
"all be int, but got {} at index {}.".format(type(value), index))
|
||||
size = shape
|
||||
return size
|
||||
|
||||
|
||||
def _check_rand_type(dtype):
|
||||
"""Check type for randn and rand"""
|
||||
type_list = ['float', 'float16', 'float32', 'float64']
|
||||
if isinstance(dtype, str):
|
||||
if dtype not in type_list:
|
||||
raise ValueError("If the argument 'dtype' is str, it must be one of {}, but got {}."
|
||||
.format(type_list, dtype))
|
||||
try:
|
||||
dtype = dtype_map[dtype]
|
||||
except KeyError:
|
||||
raise KeyError("Unsupported dtype {}.".format(dtype))
|
||||
elif dtype not in (mstype.float64, mstype.float32, mstype.float16):
|
||||
raise ValueError("The argument 'dtype' must be 'mindspore.float64', 'mindspore.float32' or "
|
||||
"'mindspore.float16', but got {}.".format(dtype))
|
||||
|
||||
|
||||
def randn(*shape, dtype=mstype.float32):
|
||||
"""
|
||||
Returns a new Tensor with given shape and dtype, filled with a sample (or samples)
|
||||
from the standard normal distribution.
|
||||
|
||||
Args:
|
||||
*shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g.,
|
||||
:math:`(2, 3)` or :math:`2`.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must
|
||||
be float type. Default is :class:`mindspore.float32`.
|
||||
|
||||
Returns:
|
||||
Tensor, with the designated shape and dtype, filled with a sample (or samples)
|
||||
from the "standard normal" distribution.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If `dtype` is not float type.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> from mindspore import set_seed
|
||||
>>> set_seed(1)
|
||||
>>> print(np.randn((2,3)))
|
||||
[[ 0.30639967 -0.42438635 -0.20454668]
|
||||
[-0.4287376 1.3054721 0.64747655]]
|
||||
"""
|
||||
_check_rand_type(dtype)
|
||||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
stdnormal = P.StandardNormal(seed=seed)
|
||||
else:
|
||||
stdnormal = P.StandardNormal()
|
||||
return stdnormal(size).astype(dtype)
|
||||
|
||||
|
||||
def rand(*shape, dtype=mstype.float32):
|
||||
"""
|
||||
Returns a new Tensor with given shape and dtype, filled with random numbers from the
|
||||
uniform distribution on the interval :math:`[0, 1)`.
|
||||
|
||||
Args:
|
||||
*shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g.,
|
||||
:math:`(2, 3)` or :math:`2`.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must
|
||||
be float type. Default is :class:`mindspore.float32`.
|
||||
|
||||
Returns:
|
||||
Tensor, with the designated shape and dtype, filled with random numbers from the
|
||||
uniform distribution on the interval :math:`[0, 1)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If `dtype` is not float type.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> from mindspore import set_seed
|
||||
>>> set_seed(1)
|
||||
>>> print(np.rand((2,3)))
|
||||
[[4.1702199e-01 9.9718481e-01 7.2032452e-01]
|
||||
[9.3255734e-01 1.1438108e-04 1.2812445e-01]]
|
||||
"""
|
||||
_check_rand_type(dtype)
|
||||
size = _generate_shapes(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
uniformreal = P.UniformReal(seed=seed)
|
||||
else:
|
||||
uniformreal = P.UniformReal()
|
||||
return uniformreal(size).astype(dtype)
|
||||
|
||||
|
||||
def randint(minval, maxval=None, shape=None, dtype=mstype.int32):
|
||||
"""
|
||||
Return random integers from minval (inclusive) to maxval (exclusive). Return random integers from the
|
||||
discrete uniform distribution of the specified dtype in the “half-open” interval :math:`[minval, maxval)`.
|
||||
If maxval is None (the default), then results are from [0, maxval).
|
||||
|
||||
Args:
|
||||
minval(Union[int]): Start value of interval. The interval includes this value. When `maxval`
|
||||
is :class:`None`, `minval` must be greater than 0. When `maxval` is not :class:`None`,
|
||||
`minval` must be less than `maxval`.
|
||||
maxval(Union[int], optional): End value of interval. The interval does not include this value.
|
||||
shape (Union[int, tuple(int)]): Shape of the new tensor, e.g., :math:`(2, 3)` or :math:`2`.
|
||||
dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, it must
|
||||
be int type. Default is :class:`mindspore.int32`.
|
||||
|
||||
Returns:
|
||||
Tensor, with the designated shape and dtype, filled with random integers from minval (inclusive)
|
||||
to maxval (exclusive).
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If input arguments have values not specified above.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> from mindspore import set_seed
|
||||
>>> set_seed(1)
|
||||
>>> print(np.randint(1, 10, (2,3)))
|
||||
[[4 9 7]
|
||||
[9 1 2]]
|
||||
"""
|
||||
if not isinstance(minval, int):
|
||||
raise TypeError("For mindspore.numpy.randint, the type of the argument 'minval' must be int, "
|
||||
"but got {}.".format(type(minval)))
|
||||
if maxval is None:
|
||||
if minval <= 0:
|
||||
raise ValueError("For mindspore.numpy.randint, the argument 'minval' must be > 0 when the argument "
|
||||
"'maxval' is None, but got {}.".format(minval))
|
||||
maxval = minval
|
||||
minval = 0
|
||||
else:
|
||||
if not isinstance(maxval, int):
|
||||
raise TypeError("For mindspore.numpy.randint, the type of the argument 'maxval' must be int, "
|
||||
"but got {}.".format(type(maxval)))
|
||||
if minval >= maxval:
|
||||
raise ValueError("For mindspore.numpy.randint, the value of 'minval' must be greater than the value of "
|
||||
"'maxval', but got 'minval': {} and 'maxval': {}.".format(minval, maxval))
|
||||
if isinstance(dtype, str):
|
||||
if dtype not in ('int', 'int8', 'int16', 'int32', 'int64'):
|
||||
raise ValueError("For 'mindspore.numpy.randint', if the argument 'dtype' is str, it must be one of "
|
||||
"['int', 'int8', 'int16', 'int32', 'int64'], but got {}.".format(dtype))
|
||||
try:
|
||||
dtype = dtype_map[dtype]
|
||||
except KeyError:
|
||||
raise KeyError("Unsupported dtype {}.".format(dtype))
|
||||
elif dtype not in (mstype.int64, mstype.int32, mstype.int16, mstype.int8):
|
||||
raise ValueError("For 'mindspore.numpy.randint', the argument 'dtype' must be 'mindspore.int64', "
|
||||
"'mindspore.int32', 'mindspore.int16' or 'mindspore.int8', but got {}.".format(dtype))
|
||||
if shape is None:
|
||||
shape = (1,)
|
||||
else:
|
||||
shape = _check_shape(shape)
|
||||
seed = get_seed()
|
||||
if seed is not None:
|
||||
uniformint = P.UniformInt(seed=seed)
|
||||
else:
|
||||
uniformint = P.UniformInt()
|
||||
return uniformint(shape, Tensor(minval, mstype.int32), Tensor(maxval, mstype.int32)).astype(dtype)
|
||||
|
||||
|
||||
def arange(start, stop=None, step=None, dtype=None):
|
||||
"""
|
||||
Returns evenly spaced values within a given interval.
|
||||
|
|
|
@ -27,6 +27,7 @@ from mindspore.nn.grad.cell_grad import _JvpInner
|
|||
from mindspore.nn.grad.cell_grad import _VjpInner
|
||||
from mindspore.ops import _constants
|
||||
from mindspore.ops.primitive import constexpr
|
||||
import numpy as np
|
||||
from .primitive import Primitive
|
||||
from . import operations as P
|
||||
from .operations import _grad_ops
|
||||
|
@ -395,6 +396,71 @@ shard_fn = Shard()
|
|||
def shard(fn, in_axes, out_axes, device="Ascend", level=0):
|
||||
return shard_fn(fn, in_axes, out_axes, device, level)
|
||||
|
||||
|
||||
def arange(start=0, stop=None, step=1, rtype=None):
|
||||
"""
|
||||
Returns evenly spaced values within a given interval.
|
||||
|
||||
Args:
|
||||
start(Union[int, float]): Start value of interval. The interval includes this value. When
|
||||
`stop` is None, `start` must be greater than 0, and the interval is :math:`[0, start)`.
|
||||
When `stop` is not None, `start` must be less than `stop`.
|
||||
stop(Union[int, float], optional): End value of interval. The interval does not
|
||||
include this value. Default is None.
|
||||
step(Union[int, float], optional): Spacing between values. For any output
|
||||
`out`, this is the distance between two adjacent values, :math:`out[i+1] - out[i]`.
|
||||
The default step size is 1. If `step` is specified as a position argument,
|
||||
`start` must also be given.
|
||||
rtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor type.
|
||||
If rtype is None, the data type of the new tensor will be inferred from start,
|
||||
stop and step. Default is None.
|
||||
|
||||
Returns:
|
||||
Tensor with evenly spaced values.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If input arguments have values not specified above.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.ops as ops
|
||||
>>> print(ops.arange(0, 5, 1))
|
||||
[0 1 2 3 4]
|
||||
>>> print(ops.arange(3))
|
||||
[0 1 2]
|
||||
>>> print(ops.arange(start=0, stop=3))
|
||||
[0 1 2]
|
||||
>>> print(ops.arange(0, stop=3, step=0.5))
|
||||
[0. 0.5 1. 1.5 2. 2.5]
|
||||
"""
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
|
||||
arg_map = {"start": start, "stop": stop, "step": step}
|
||||
for arg in arg_map:
|
||||
try:
|
||||
arg_value = arg_map[arg]
|
||||
except KeyError:
|
||||
raise KeyError("Unsupported key {}, the key must be one of ['start', 'stop', 'step'].".format(arg_value))
|
||||
if not isinstance(arg_value, int) and not isinstance(arg_value, float):
|
||||
raise TypeError("For mindspore.ops.range, the argument '{}' must be int or float, but got {}."
|
||||
.format(arg, type(arg_value)))
|
||||
if start >= stop:
|
||||
raise ValueError("For mindspore.ops.range, the argument 'start' must be < 'stop', but got 'start': {}, "
|
||||
"'stop': {}.".format(start, stop))
|
||||
|
||||
if rtype is None:
|
||||
data = np.arange(start, stop, step)
|
||||
if data.dtype == int:
|
||||
rtype = mstype.int32
|
||||
else:
|
||||
rtype = mstype.float32
|
||||
return Tensor(np.arange(start, stop, step), dtype=rtype)
|
||||
|
||||
|
||||
def narrow(inputs, axis, start, length):
|
||||
"""
|
||||
Returns a narrowed tensor from input tensor.
|
||||
|
|
|
@ -17,7 +17,10 @@
|
|||
import pytest
|
||||
import numpy as onp
|
||||
import mindspore.numpy as mnp
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore import context
|
||||
from mindspore import set_seed
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
|
||||
match_all_arrays, run_multi_test, to_tensor
|
||||
|
@ -894,6 +897,126 @@ def test_histogram_bin_edges():
|
|||
match_res(mnp.histogram_bin_edges, onp.histogram_bin_edges, x, bins=10, range=(2, 20), error=3)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_randn():
|
||||
"""
|
||||
Feature: Numpy method randn.
|
||||
Description: Test numpy method randn.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
set_seed(1)
|
||||
t1 = mnp.randn(1, 2, 3)
|
||||
t2 = mnp.randn(1, 2, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randn(dtype="int32")
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randn(dtype=mstype.int32)
|
||||
with pytest.raises(TypeError):
|
||||
mnp.randn({1})
|
||||
with pytest.raises(TypeError):
|
||||
mnp.randn(1, 1.2, 2)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_rand():
|
||||
"""
|
||||
Feature: Numpy method rand.
|
||||
Description: Test numpy method rand.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
set_seed(1)
|
||||
t1 = mnp.rand(1, 2, 3)
|
||||
t2 = mnp.rand(1, 2, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
mnp.rand(dtype="int32")
|
||||
with pytest.raises(ValueError):
|
||||
mnp.rand(dtype=mstype.int32)
|
||||
with pytest.raises(TypeError):
|
||||
mnp.rand({1})
|
||||
with pytest.raises(TypeError):
|
||||
mnp.rand(1, 1.2, 2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_randint():
|
||||
"""
|
||||
Feature: Numpy method randint.
|
||||
Description: Test numpy method randint.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
set_seed(1)
|
||||
t1 = mnp.randint(1, 5, 3)
|
||||
t2 = mnp.randint(1, 5, 3)
|
||||
assert (t1.asnumpy() == t2.asnumpy()).all()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
mnp.randint(1.2)
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randint(0)
|
||||
with pytest.raises(TypeError):
|
||||
mnp.randint(1, 1.2)
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randint(2, 1)
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randint(1, dtype="float")
|
||||
with pytest.raises(ValueError):
|
||||
mnp.randint(1, dtype=mstype.float32)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ops_arange():
|
||||
"""
|
||||
Feature: Ops function arange.
|
||||
Description: Test ops function arange.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
actual = onp.arange(5)
|
||||
expected = F.arange(5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(0, 5)
|
||||
expected = F.arange(0, 5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(5, step=0.2)
|
||||
expected = F.arange(5, step=0.2).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(0.1, 0.9)
|
||||
expected = F.arange(0.1, 0.9).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
F.arange([1])
|
||||
with pytest.raises(ValueError):
|
||||
F.arange(10, 1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue