!47243 [API] Add rand functions

Merge pull request !47243 from shaojunsong/feature/rand
This commit is contained in:
i-robot 2023-01-11 02:01:03 +00:00 committed by Gitee
commit 165f31036d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 580 additions and 1 deletions

View File

@ -406,6 +406,12 @@ Tensor创建
mindspore.ops.gamma
mindspore.ops.laplace
mindspore.ops.multinomial
mindspore.ops.rand
mindspore.ops.rand_like
mindspore.ops.randint
mindspore.ops.randint_like
mindspore.ops.randn
mindspore.ops.randn_like
mindspore.ops.random_poisson
mindspore.ops.random_categorical
mindspore.ops.random_gamma

View File

@ -0,0 +1,20 @@
mindspore.ops.rand
===================
.. py:function:: mindspore.ops.rand(*size, dtype=None, seed=None)
返回一个Tensorshape和dtype由输入决定其元素为服从均匀分布的 :math:`[0, 1)` 区间的数字。
参数:
- **size** (Union[int, tuple(int), list(int)]) - 输出的Tensor的shape例如:math:`(2, 3)` or :math:`2`
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 指定的输出Tensor的dtype必须是float类型。如果是None`mindspore.float32` 会被使用。默认值None。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
返回:
Tensorshape和dtype由输入决定其元素为服从均匀分布的 :math:`[0, 1)` 区间的数字。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **ValueError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。

View File

@ -0,0 +1,20 @@
mindspore.ops.rand_like
========================
.. py:function:: mindspore.ops.rand_like(x, seed=None, *, dtype=None)
返回一个Tensorshape和dtype由输入决定其元素为服从均匀分布的 :math:`[0, 1)` 区间的数字。
参数:
- **x** (Tensor)输入的Tensor用来决定输出Tensor的shape和默认的dtype。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 指定的输出Tensor的dtype必须是float类型。如果是None`x` 的dtype会被使用。默认值None。
返回:
Tensorshape和dtype由输入决定其元素为服从均匀分布的 :math:`[0, 1)` 区间的数字。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **ValueError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。

View File

@ -0,0 +1,24 @@
mindspore.ops.randint
======================
.. py:function:: mindspore.ops.randint(low, high, size, seed=None, *, dtype=None)
返回一个Tensor其元素为 [ `low` , `high` ) 区间的随机整数。
参数:
- **low** (int) - 随机区间的起始值。
- **high** (int) - 随机区间的结束值。
- **size** (tuple) - 新Tensor的shape。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 指定的Tensor dtype必须是int类型的dtype。如果是None将会使用 `mindspore.int64` 。默认值None。
返回:
Tensorshape和dtype被输入指定其元素为 [ `low` , `high` ) 区间的随机整数。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **TypeError** - 如果 `size` 不是tuple。
- **TypeError** - 如果 `low``high` 不是整数。
- **ValueError** - 如果 `dtype` 不是一个 `mstype.int_type` 类型。

View File

@ -0,0 +1,23 @@
mindspore.ops.randint_like
===========================
.. py:function:: mindspore.ops.randint_like(x, low, high, *, dtype=None, seed=None)
返回一个Tensor其元素为 [ `low` , `high` ) 区间的随机整数。
参数:
- **x** (Tensor) - 输入的Tensor用来决定输出Tensor的shape和默认的dtype。
- **low** (int) - 随机区间的起始值。
- **high** (int) - 随机区间的结束值。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 指定的Tensor dtype必须是int类型的dtype。如果是None将会使用 `x` 的dtype。默认值None。
返回:
Tensorshape和dtype被输入指定其元素为 [ `low` , `high` ) 区间的随机整数。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **TypeError** - 如果 `low``high` 不是整数。
- **ValueError** - 如果 `dtype` 不是 `mstype.int_type` 类型。

View File

@ -0,0 +1,21 @@
mindspore.ops.randn
====================
.. py:function:: mindspore.ops.randn(*size, dtype=None, seed=None)
返回一个Tensorshape和dtype由输入决定其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
参数:
- **size** (Union[int, tuple(int), list(int)]) - 输出的Tensor的shape例如:math:`(2, 3)` or :math:`2`
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 需求的输出Tensor的dtype必须是float类型。如果是None`mindspore.float32` 会被使用。默认值None。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
返回:
Tensorshape和dtype由输入决定其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **ValueError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。
- **ValueError** - 如果 `size` 包含不合理的数字。

View File

@ -0,0 +1,20 @@
mindspore.ops.randn_like
=========================
.. py:function:: mindspore.ops.randn_like(x, seed=None, *, dtype=None)
返回一个Tensorshape和dtype由输入决定其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
参数:
- **x** (Tensor) - 输入的Tensor用来决定输出Tensor的shape和默认的dtype。
- **seed** (int可选) - 随机种子必须大于或等于0。默认值None值将取0。
关键字参数:
- **dtype** (:class:`mindspore.dtype`,可选) - 需求的输出Tensor的dtype必须是float类型。如果是None`x` 的dtype会被使用。默认值None。
返回:
Tensorshape和dtype由输入决定其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
异常:
- **TypeError** - 如果 `seed` 不是非负整数。
- **ValueError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。

View File

@ -406,6 +406,12 @@ Randomly Generating Functions
mindspore.ops.gamma
mindspore.ops.laplace
mindspore.ops.multinomial
mindspore.ops.rand
mindspore.ops.rand_like
mindspore.ops.randint
mindspore.ops.randint_like
mindspore.ops.randn
mindspore.ops.randn_like
mindspore.ops.random_poisson
mindspore.ops.random_categorical
mindspore.ops.random_gamma

View File

@ -516,6 +516,12 @@ from .random_func import (
gamma,
poisson,
multinomial,
rand,
rand_like,
randn,
randn_like,
randint,
randint_like
)
from .grad import (
grad_func,

View File

@ -806,6 +806,311 @@ def gamma(shape, alpha, beta, seed=None):
return value
@constexpr
def _generate_shapes(shape):
"""Generate shapes for randn and rand."""
if not shape:
size = (1,)
elif len(shape) == 1:
if isinstance(shape[0], int):
size = shape
elif isinstance(shape[0], list):
size = tuple(shape[0])
elif isinstance(shape[0], tuple):
size = shape[0]
else:
raise TypeError("If the length of the argument 'shape' is 1, the type of the argument 'shape' must be "
"one of ['int', 'list', 'tuple'], but got ", shape[0])
else:
for value in shape:
if not isinstance(value, int):
raise TypeError("If the length of the argument 'shape' is > 1, the type of the argument 'shape' must "
"all be int, but got ", value)
size = shape
return size
@_function_forbid_reuse
def rand(*size, dtype=None, seed=None):
r"""
Returns a new Tensor with given shape and dtype, filled with random numbers from the uniform distribution on the
interval :math:`[0, 1)`.
Args:
*size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g. :math:`(2, 3)` or :math:`2`.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
`mindspore.float32` will be applied. Default: None.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Returns:
Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
the interval :math:`[0, 1)`.
Raises:
TypeError: `seed` is not a non-negative integer.
ValueError: If `dtype` is not a `mstype.float_type` type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.ops as ops
>>> print(ops.rand((2,3)))
[[4.1702199e-01 9.9718481e-01 7.2032452e-01]
[9.3255734e-01 1.1438108e-04 1.2812445e-01]]
"""
if dtype is None:
dtype = mstype.float32
elif dtype not in mstype.float_type:
raise ValueError(f"For 'rand', the 'dtype' must be a float type, but got {dtype}.")
shape = _generate_shapes(size)
cast_ = P.Cast()
seed1, seed2 = _get_seed(seed, 'rand')
rand_op = P.UniformReal(seed1, seed2)
output = rand_op(shape)
return cast_(output, dtype)
@_function_forbid_reuse
def rand_like(x, seed=None, *, dtype=None):
r"""
Returns a new Tensor with the shape and dtype as `x`, filled with random numbers from the uniform distribution on
the interval :math:`[0, 1)`.
Args:
x (Tensor): Input Tensor to specify the output shape and its default dtype.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
the same dtype of `x` will be applied. Default: None.
Returns:
Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
the interval :math:`[0, 1)`.
Raises:
TypeError: If `seed` is not a non-negative integer.
ValueError: If `dtype` is not a `mstype.float_type` type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> from mindspore import Tensor, ops
>>> a = Tensor([[2, 3, 4], [1, 2, 3]])
>>> print(ops.rand_like(a, dtype=ms.float32))
[[4.1702199e-01 9.9718481e-01 7.2032452e-01]
[9.3255734e-01 1.1438108e-04 1.2812445e-01]]
"""
if dtype is None:
dtype = x.dtype
elif dtype not in mstype.float_type:
raise ValueError(f"For 'rand_like', the 'dtype' must be a float type, but got {dtype}.")
shape = x.shape
cast_ = P.Cast()
seed1, seed2 = _get_seed(seed, 'rand_like')
rand_op = P.UniformReal(seed1, seed2)
output = rand_op(shape)
return cast_(output, dtype)
@_function_forbid_reuse
def randn(*size, dtype=None, seed=None):
r"""
Returns a new Tensor with given shape and dtype, filled with a sample (or samples)
from the standard normal distribution.
Args:
*size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., :math:`(2, 3)` or :math:`2`.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
:class:`mindspore.float32` will be used. Default: None.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Returns:
Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
"standard normal" distribution.
Raises:
TypeError: `seed` is not a non-negative integer.
ValueError: If `dtype` is not a `mstype.float_type`.
ValueError: If `size` contains invalid number.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.ops as ops
>>> print(ops.randn((2,3)))
[[ 0.30639967 -0.42438635 -0.20454668]
[-0.4287376 1.3054721 0.64747655]]
"""
if dtype is None:
dtype = mstype.float32
elif dtype not in mstype.float_type:
raise ValueError(f"For 'randn', the 'dtype' must be a float type, but got {dtype}.")
shape = _generate_shapes(size)
cast_ = P.Cast()
seed1, seed2 = _get_seed(seed, 'randn')
rand_op = P.StandardNormal(seed1, seed2)
output = rand_op(shape)
return cast_(output, dtype)
@_function_forbid_reuse
def randn_like(x, seed=None, *, dtype=None):
r"""
Returns a new Tensor with given shape and dtype, filled with a sample (or samples) from the standard normal
distribution.
Args:
x (Tensor): Input Tensor to specify the output shape and its default dtype.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
:class:`mindspore.float32` will be used. Default: None.
Returns:
Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
"standard normal" distribution.
Raises:
TypeError: `seed` is not a non-negative integer.
ValueError: If `dtype` is not a `mstype.float_type`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> from mindspore import Tensor, ops
>>> a = Tensor([[1, 2, 3], [4, 5, 6]])
>>> print(ops.randn_like(x, dtype=ms.float32))
[[ 0.30639967 -0.42438635 -0.20454668]
[-0.4287376 1.3054721 0.64747655]]
"""
if dtype is None:
dtype = x.dtype
elif dtype not in mstype.float_type:
raise ValueError(f"For 'randn_like', the 'dtype' must be a float type, but got {dtype}.")
shape = x.shape
cast_ = P.Cast()
seed1, seed2 = _get_seed(seed, 'randn_like')
rand_op = P.StandardNormal(seed1, seed2)
output = rand_op(shape)
return cast_(output, dtype)
@_function_forbid_reuse
def randint(low, high, size, seed=None, *, dtype=None):
r"""
Return a Tensor whose elements are random integers from low (inclusive) to high (exclusive).
Args:
low (int): Start value of interval.
high (int): End value of interval.
size (tuple): Shape of the new tensor.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If None,
`mindspore.int64` will be used. Default: None.
Returns:
Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
to high (exclusive).
Raises:
TypeError: `seed` is not a non-negative integer.
TypeError: `size` is not a tuple.
TypeError: `low` or `high` is not an integer.
ValueError: If `dtype` is not a `mstype.int_type`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.ops as ops
>>> print(ops.randint(1, 10, (2,3)))
[[4 9 7]
[9 1 2]]
"""
if dtype is None:
dtype = mstype.int64
elif dtype not in mstype.int_type:
raise ValueError(f"For 'randint', the 'dtype' must be an int type, but got {dtype}.")
if not isinstance(size, tuple):
raise ValueError(f"For 'randint', the input 'size' must be a tuple, but got {size}.")
if not isinstance(low, int) or not isinstance(high, int):
raise TypeError(f"For 'randint', 'low' and 'high' must be an int, but got {type(low)} and {type(high)}.")
seed1, seed2 = _get_seed(seed, 'randint')
cast_ = P.Cast()
rand_op = P.UniformInt(seed1, seed2)
low_ = Tensor(low, mstype.int32)
high_ = Tensor(high, mstype.int32)
output = rand_op(size, low_, high_)
return cast_(output, dtype)
@_function_forbid_reuse
def randint_like(x, low, high, seed=None, *, dtype=None):
r"""
Returns a tensor with the same shape as Tensor `x` filled with random integers generated uniformly between
low (inclusive) and high (exclusive).
Args:
x (Tensor): Input Tensor to specify the output shape and its default dtype.
low(int): Start value of interval.
high(int): End value of interval.
seed (int, optional): Random seed, must be greater or equal to 0. Default: None, and 0 will be used.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If None,
:class:`mindspore.int64` will be used. Default is :class:`mindspore.int64`.
Returns:
Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
to high (exclusive).
Raises:
TypeError: `seed` is not a non-negative integer.
TypeError: `low` or `high` is not an integer.
ValueError: If `dtype` is not a `mstype.int_type`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, ops
>>> a = Tensor([[1, 2, 3], [3, 2, 1]])
>>> print(ops.randint_like(a, 1, 10))
[[4 9 7]
[9 1 2]]
"""
if dtype is None:
dtype = x.dtype
elif dtype not in mstype.int_type:
raise ValueError(f"For 'randint_like', the 'dtype' must be an int type, but got {dtype}.")
if not isinstance(low, int) or not isinstance(high, int):
raise TypeError(f"For 'randint_like', 'low' and 'high' must be an int, but got {type(low)} and {type(high)}.")
size = x.shape
seed1, seed2 = _get_seed(seed, 'randint_like')
rand_op = P.UniformInt(seed1, seed2)
cast_ = P.Cast()
low_ = Tensor(low, mstype.int32)
high_ = Tensor(high, mstype.int32)
output = rand_op(size, low_, high_)
return cast_(output, dtype)
@_function_forbid_reuse
def poisson(shape, mean, seed=None):
r"""
@ -953,6 +1258,7 @@ def _check_shape(input_shape):
__all__ = [
'standard_laplace', 'random_categorical', 'uniform', 'standard_normal', 'random_gamma',
'uniform_candidate_sampler', 'random_poisson', 'log_uniform_candidate_sampler', 'shuffle', 'choice_with_mask',
'normal', 'laplace', 'gamma', 'poisson', 'multinomial'
'normal', 'laplace', 'gamma', 'poisson', 'multinomial', 'rand', 'rand_like', 'randn', 'randn_like', 'randint',
'randint_like'
]
__all__.sort()

View File

@ -0,0 +1,127 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Rand(nn.Cell):
def construct(self, size, dtype):
return ops.rand(size, dtype=dtype)
class RandLike(nn.Cell):
def construct(self, x, dtype):
return ops.rand_like(x, dtype=dtype)
class Randn(nn.Cell):
def construct(self, size, dtype):
return ops.randn(size, dtype=dtype)
class RandnLike(nn.Cell):
def construct(self, x, dtype):
return ops.randn_like(x, dtype=dtype)
class RandInt(nn.Cell):
def construct(self, low, high, size, dtype):
return ops.randint(low, high, size, dtype=dtype)
class RandIntLike(nn.Cell):
def construct(self, x, low, high, dtype):
return ops.randint_like(x, low, high, dtype=dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.float32])
def test_rand_functions(mode, dtype):
r"""
Feature: ops.rand, ops.randn, ops.rand_like, ops.randn_like
Description: Verify the result of ops.rand, ops.randn, ops.rand_like, ops.randn_like
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.float16)
size = (2, 3)
net1 = Rand()
net2 = Randn()
net3 = RandLike()
net4 = RandnLike()
out1 = net1(size, dtype)
out2 = net2(size, dtype)
out3 = net3(x, dtype)
out4 = net4(x, dtype)
if dtype is None:
assert out1.dtype == ms.float32
assert out2.dtype == ms.float32
assert out3.dtype == ms.float16
assert out4.dtype == ms.float16
else:
assert out1.dtype == dtype
assert out2.dtype == dtype
assert out3.dtype == dtype
assert out4.dtype == dtype
assert out1.shape == size
assert out2.shape == size
assert out3.shape == x.shape
assert out4.shape == x.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [None, ms.int32])
def test_randint_functions(mode, dtype):
r"""
Feature: ops.randint, ops.randint_like
Description: Verify the result of ops.randint, ops.randint_like
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), ms.int32)
net = RandInt()
net2 = RandIntLike()
out = net(0, 10, (2, 3), dtype=dtype)
out2 = net2(x, low=0, high=15, dtype=dtype)
if dtype is None:
assert out.dtype == ms.int64
assert out2.dtype == ms.int32
else:
assert out.dtype == dtype
assert out2.dtype == dtype
assert out.shape == (2, 3)
assert out2.shape == x.shape
assert out.max() < 10 and out.min() >= 0
assert out2.max() < 15 and out2.min() >= 0