add CN api docs & add scalar rate ST & add 1-D-tensor check for argument shape

This commit is contained in:
hangangqiang 2022-08-09 10:02:52 +08:00
parent f0a99ea3bd
commit b1623fc2e9
6 changed files with 119 additions and 73 deletions

View File

@ -352,7 +352,7 @@ Tensor创建
mindspore.ops.gamma
mindspore.ops.laplace
mindspore.ops.multinomial
mindspore.ops.poisson
mindspore.ops.random_poisson
mindspore.ops.random_categorical
mindspore.ops.random_gamma
mindspore.ops.standard_laplace

View File

@ -1,23 +0,0 @@
mindspore.ops.poisson
=====================
.. py:function:: mindspore.ops.poisson(shape, mean, seed=None)
根据泊松随机数分布生成随机数。
.. math::
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
参数:
- **shape** (tuple) - Tuple: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。
- **mean** (Tensor) - 均值μ分布参数。支持float32数据类型应大于0。
- **seed** (int) - 随机种子。取值须为非负数。默认值None等同于0。
返回:
Tensorshape应与输入 `shape``mean` 进行广播之后的shape相同。数据类型支持float32。
异常:
- **TypeError** - `shape` 不是Tuple。
- **TypeError** - `mean` 不是Tensor或数据类型非float32。
- **TypeError** - `seed` 不是int类型。

View File

@ -0,0 +1,29 @@
mindspore.ops.random_poisson
==========================
.. py:function:: mindspore.ops.random_poisson(shape, rate, seed=None, dtype=mstype.float32)
从各指定均值的泊松分布中,随机采样`shape`形状的随机数。
.. math::
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
参数:
- **shape** (Tensor) - 表示要从每个分布中采样的随机数张量的形状。必须是一个一维的张量且数据类型必须是`mindspore.dtype.int32`或者`mindspore.dtype.int64`
- **rate** (Tensor) - 泊松分布的 :math:`μ` 参数表示泊松分布的均值同时也是分布的方差。必须是一个张量且其数据类型必须是以下类型中的一种mindspore.dtype.int64mindspore.dtype.int32mindspore.dtype.float64mindspore.dtype.float32或者mindspore.dtype.float16。
- **seed** (int) - 随机数种子用于在随机数引擎中产生随机数。必须是一个非负的整数。默认值是None表示使用0作为随机数种子。
- **dtype** (mindspore.dtype) - 表示要生成的随机数张量的数据类型。必须是mindspore.dtype类型可以是以下值中的一种mindspore.dtype.int64mindspore.dtype.int32mindspore.dtype.float64mindspore.dtype.float32或者mindspore.dtype.float16。
返回:
返回一个张量,它的形状由入参`shape``rate`共同决定:`mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)`,它的数据类型由入参`dtype`决定。
异常:
- **TypeError** - 如果`shape`不是一个张量。
- **TypeError** - 如果`shape`张量的数据类型不是mindspore.dtype.int64或mindspore.dtype.int32。
- **ValueError** - 如果`shape`张量的形状不是一维的。
- **TypeError** - 如果`rate`不是一个张量。
- **TypeError** - 如果`rate`张量的数据类型不是mindspore.dtype.int64mindspore.dtype.int32mindspore.dtype.float64mindspore.dtype.float32或者mindspore.dtype.float16。
- **TypeError** - 如果`seed`不是一个非负整型。
- **TypeError** - 如果`dtype`不是mindspore.dtype.int64mindspore.dtype.int32mindspore.dtype.float64mindspore.dtype.float32或者mindspore.dtype.float16。
- **ValueError** - 如果`shape`张量中有非正数。

View File

@ -338,6 +338,7 @@ from .random_func import (
standard_normal,
random_gamma,
uniform_candidate_sampler,
random_poisson,
)
from .grad import (
grad_func,

View File

@ -336,28 +336,31 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
Args:
shape (Tensor): The shape of random tensor to be generated, 1-D `Tensor` whose dtype is mindspore.dtype.int32 or
mindspore.dtype.int64.
rate (Tensor): The μ parameter the distribution was constructed with. The parameter defines mean number of
occurrences of the event. It should be a `Tensor` whose dtype is mindspore.dtype.int64, mindspore.dtype.int32,
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose
dtype is mindspore.dtype.int32 or mindspore.dtype.int64.
rate (Tensor): The μ parameter the distribution was constructed with. It represents the mean of the distribution
and also the variance of the distribution. It should be a `Tensor` whose dtype is mindspore.dtype.int64,
mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
and must be non-negative. Default: None, which will be treated as 0.
dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32,
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32.
Returns:
Tensor. The shape should be `mindspore.concat([shape, mindspore.shape(mean)], axis=0)`. The data type should be
equal to argument `dtype`.
A Tensor whose shape is `mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)` and data type is equal to
argument `dtype`.
Raises:
TypeError: If `shape` is not a Tensor[mindspore.dtype.int64] nor a Tensor[mindspore.dtype.int32].
TypeError: If `rate` is not a Tensor or `rate` is a Tensor whose dtype is not in [mindspore.dtype.int64,
mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16].
TypeError: If `seed` is not an int.
TypeError: If `dtype` is not mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
mindspore.dtype.float32 nor mindspore.dtype.float16.
ValueError: If elements of input `shape` tensor is not positive.
TypeError: If `shape` is not a Tensor.
TypeError: If datatype of `shape` is not mindspore.dtype.int64 nor mindspore.dtype.int32.
ValueError: If shape of `shape` is not 1-D.
TypeError: If `rate` is not a Tensor nor a scalar.
TypeError: If datatype of `rate` is not in [mindspore.dtype.int64, mindspore.dtype.int32,
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16].
TypeError: If `seed` is not a non-negtive int.
TypeError: If `dtype` is not in [mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
mindspore.dtype.float32 nor mindspore.dtype.float16].
ValueError: If any element of input `shape` tensor is not positive.
Supported Platforms:
``CPU``
@ -365,20 +368,20 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
Examples:
>>> from mindspore import Tensor, ops
>>> import mindspore
>>> # case 1: It can be broadcast.
>>> shape = Tensor(np.array([4, 1]), mindspore.int32)
>>> rate = Tensor(np.array([5.0, 10.0]), mindspore.float32)
>>> # case 1: 1-D shape, 2-D rate, float64 output
>>> shape = Tensor(np.array([2, 2]), mindspore.int64)
>>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64)
>>> print(output.shape, output.dtype)
(4, 1, 2) Float64
>>> # case 2: It can not be broadcast. It is recommended to use the same shape.
>>> shape = Tensor(np.array([2, 2]), mindspore.int32)
>>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
(2, 2, 2, 2) float64
>>> # case 2: 1-D shape, scalar rate, int64 output
>>> shape = Tensor(np.array([2, 2]), mindspore.int64)
>>> rate = Tensor(5.0, mindspore.float64)
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64)
>>> print(output.shape, output.dtype)
(2, 2, 2, 2) Int64
(2, 2) Int64
"""
seed1, seed2 = _get_seed(seed, "poisson")
seed1, seed2 = _get_seed(seed, "random_poisson")
prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype)
value = prim_random_poisson(shape, rate)
return value
@ -391,5 +394,6 @@ __all__ = [
'standard_normal',
'random_gamma',
'uniform_candidate_sampler',
'random_poisson',
]
__all__.sort()

View File

@ -35,12 +35,14 @@ def test_poisson_function(dtype, shape_dtype, rate_dtype):
Expectation: Output shape is correct.
"""
# rate is a scalar Tensor
shape = Tensor(np.array([3, 5]), shape_dtype)
rate = Tensor(np.array([0.5]), rate_dtype)
rate = Tensor(0.5, rate_dtype)
output = R.random_poisson(shape, rate, seed=1, dtype=dtype)
assert output.shape == (3, 5, 1)
assert output.shape == (3, 5)
assert output.dtype == dtype
# rate is a 2-D Tensor
shape = Tensor(np.array([3, 2]), shape_dtype)
rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), rate_dtype)
output = R.random_poisson(shape, rate, seed=5, dtype=dtype)
@ -67,6 +69,25 @@ def test_poisson_function_shape_type_error():
assert False
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_poisson_function_shape_dim_error():
"""
Feature: Poisson functional interface
Description: Feed 2-D Tensor type `shape` into poisson functional interface.
Expectation: Except TypeError.
"""
shape = Tensor(np.array([[1, 2], [3, 5]]), ms.dtype.int32)
rate = Tensor(np.array([0.5]), ms.dtype.float32)
try:
R.random_poisson(shape, rate, seed=1)
except ValueError:
return
assert False
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@ -183,36 +204,44 @@ def test_poisson_function_out_dtype_error():
class PoissonNet(nn.Cell):
""" Network for test dynamic shape feature of poisson functional op. """
def __init__(self, out_dtype, axis=0):
def __init__(self, out_dtype, is_rate_scalar=False):
super().__init__()
self.odtype = out_dtype
self.unique = P.Unique()
self.gather = P.Gather()
self.axis = axis
self.is_rate_scalar = is_rate_scalar
def construct(self, x, y, indices):
shape, _ = self.unique(x)
rate = y
idx, _ = self.unique(indices)
rate = self.gather(y, idx, self.axis)
if not self.is_rate_scalar:
rate = self.gather(rate, idx, 0)
return R.random_poisson(shape, rate, seed=1, dtype=self.odtype)
class PoissonDSFactory:
""" Factory class for test dynamic shape feature of poisson functional op. """
def __init__(self, rate_dims, out_dtype, shape_dtype, rate_dtype):
self.rate_dims = rate_dims
def __init__(self, max_dims, rate_dims):
self.rate_random_range = 8
self.odtype = out_dtype
self.shape_dtype = shape_dtype
self.rate_dtype = rate_dtype
# shape tensor is a 1-D tensor, unique from shape_map.
self.shape_map = np.random.randint(1, 6, 30, dtype=np.int32)
# rate_map will be gathered as rate tensor.
rate_map_shape = np.random.randint(1, 6, self.rate_dims - 1, dtype=np.int32)
rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0)
self.rate_map = np.random.randn(*rate_map_shape)
self.odtype = ms.float32
self.shape_dtype = ms.int32
self.rate_dtype = ms.float32
self.is_rate_scalar = rate_dims == 0
# shape tensor is a 1-D tensor, uniqueed from shape_map.
self.shape_map = np.random.randint(1, max_dims, 30, dtype=np.int32)
if self.is_rate_scalar:
self.rate_map = np.abs(np.random.randn(1))[0]
else:
# rate_shape: [rate_random_range, xx, ..., xx], rank of rate_shape = rate_dims
rate_map_shape = np.random.randint(1, max_dims, rate_dims - 1, dtype=np.int32)
rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0)
# rate tensor will be gathered from rate_map.
self.rate_map = np.abs(np.random.randn(*rate_map_shape))
# indices array is used to gather rate_map to rate tensor.
self.indices = np.random.randint(1, self.rate_random_range, 4, dtype=np.int32)
indices_shape = np.random.randint(1, self.rate_random_range, 1, dtype=np.int32)[0]
self.indices = np.random.randint(1, self.rate_random_range, indices_shape, dtype=np.int32)
@staticmethod
def _np_unranked_unique(nparr):
@ -238,6 +267,8 @@ class PoissonDSFactory:
def _forward_numpy(self):
""" Get result of numpy """
shape = PoissonDSFactory._np_unranked_unique(self.shape_map)
if self.is_rate_scalar:
return shape
indices = PoissonDSFactory._np_unranked_unique(self.indices)
rate = self.rate_map[indices]
rate_shape = rate.shape
@ -247,23 +278,27 @@ class PoissonDSFactory:
def _forward_mindspore(self):
""" Get result of mindspore """
shape_map_tensor = Tensor(self.shape_map, dtype=self.shape_dtype)
rate_map_tensor = Tensor(self.rate_map, dtype=self.rate_dtype)
rate_tensor = Tensor(self.rate_map, dtype=self.rate_dtype)
indices_tensor = Tensor(self.indices, dtype=ms.dtype.int32)
net = PoissonNet(self.odtype)
output = net(shape_map_tensor, rate_map_tensor, indices_tensor)
net = PoissonNet(self.odtype, self.is_rate_scalar)
output = net(shape_map_tensor, rate_tensor, indices_tensor)
return output.shape, output.dtype
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize("rate_dims", [1, 2, 3, 4, 5, 6])
def test_poisson_function_dynamic_shape(rate_dims):
@pytest.mark.parametrize("max_dims", [2, 3, 4, 5, 6])
@pytest.mark.parametrize("rate_dims", [0, 1, 2, 3, 4, 5, 6])
def test_poisson_function_dynamic_shape(max_dims, rate_dims):
"""
Feature: Poisson functional interface
Description: Test dynamic shape feature of the poisson functional interface with 1D-8D rate.
Expectation: Output of mindspore poisson equal to numpy.
Feature: Dynamic shape of functional interface RandomPoisson.
Description:
1. Initialize a 1-D Tensor as input `shape` whose data type fixed to int32, whose data and shape are random.
2. Initialize a Tensor as input `rate` whose data type fixed to float32, whose data and shape are random.
3. Compare shape of output from MindSpore and Numpy.
Expectation: Output of MindSpore RandomPoisson equal to numpy.
"""
factory = PoissonDSFactory(rate_dims, ms.float32, ms.int32, ms.float32)
factory = PoissonDSFactory(max_dims, rate_dims)
factory.forward_compare()