add CN api docs & add scalar rate ST & add 1-D-tensor check for argument shape
This commit is contained in:
parent
f0a99ea3bd
commit
b1623fc2e9
|
@ -352,7 +352,7 @@ Tensor创建
|
||||||
mindspore.ops.gamma
|
mindspore.ops.gamma
|
||||||
mindspore.ops.laplace
|
mindspore.ops.laplace
|
||||||
mindspore.ops.multinomial
|
mindspore.ops.multinomial
|
||||||
mindspore.ops.poisson
|
mindspore.ops.random_poisson
|
||||||
mindspore.ops.random_categorical
|
mindspore.ops.random_categorical
|
||||||
mindspore.ops.random_gamma
|
mindspore.ops.random_gamma
|
||||||
mindspore.ops.standard_laplace
|
mindspore.ops.standard_laplace
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
mindspore.ops.poisson
|
|
||||||
=====================
|
|
||||||
|
|
||||||
.. py:function:: mindspore.ops.poisson(shape, mean, seed=None)
|
|
||||||
|
|
||||||
根据泊松随机数分布生成随机数。
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
|
|
||||||
|
|
||||||
参数:
|
|
||||||
- **shape** (tuple) - Tuple: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。
|
|
||||||
- **mean** (Tensor) - 均值μ,分布参数。支持float32数据类型,应大于0。
|
|
||||||
- **seed** (int) - 随机种子。取值须为非负数。默认值:None,等同于0。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
Tensor,shape应与输入 `shape` 与 `mean` 进行广播之后的shape相同。数据类型支持float32。
|
|
||||||
|
|
||||||
异常:
|
|
||||||
- **TypeError** - `shape` 不是Tuple。
|
|
||||||
- **TypeError** - `mean` 不是Tensor或数据类型非float32。
|
|
||||||
- **TypeError** - `seed` 不是int类型。
|
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
mindspore.ops.random_poisson
|
||||||
|
==========================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.random_poisson(shape, rate, seed=None, dtype=mstype.float32)
|
||||||
|
|
||||||
|
从各指定均值的泊松分布中,随机采样`shape`形状的随机数。
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **shape** (Tensor) - 表示要从每个分布中采样的随机数张量的形状。必须是一个一维的张量且数据类型必须是`mindspore.dtype.int32`或者`mindspore.dtype.int64`。
|
||||||
|
- **rate** (Tensor) - 泊松分布的 :math:`μ` 参数,表示泊松分布的均值,同时也是分布的方差。必须是一个张量,且其数据类型必须是以下类型中的一种:mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。
|
||||||
|
- **seed** (int) - 随机数种子,用于在随机数引擎中产生随机数。必须是一个非负的整数。默认值是None,表示使用0作为随机数种子。
|
||||||
|
- **dtype** (mindspore.dtype) - 表示要生成的随机数张量的数据类型。必须是mindspore.dtype类型,可以是以下值中的一种:mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
返回一个张量,它的形状由入参`shape`和`rate`共同决定:`mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)`,它的数据类型由入参`dtype`决定。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果`shape`不是一个张量。
|
||||||
|
- **TypeError** - 如果`shape`张量的数据类型不是mindspore.dtype.int64或mindspore.dtype.int32。
|
||||||
|
- **ValueError** - 如果`shape`张量的形状不是一维的。
|
||||||
|
- **TypeError** - 如果`rate`不是一个张量。
|
||||||
|
- **TypeError** - 如果`rate`张量的数据类型不是mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。
|
||||||
|
- **TypeError** - 如果`seed`不是一个非负整型。
|
||||||
|
- **TypeError** - 如果`dtype`不是mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。
|
||||||
|
- **ValueError** - 如果`shape`张量中有非正数。
|
|
@ -338,6 +338,7 @@ from .random_func import (
|
||||||
standard_normal,
|
standard_normal,
|
||||||
random_gamma,
|
random_gamma,
|
||||||
uniform_candidate_sampler,
|
uniform_candidate_sampler,
|
||||||
|
random_poisson,
|
||||||
)
|
)
|
||||||
from .grad import (
|
from .grad import (
|
||||||
grad_func,
|
grad_func,
|
||||||
|
|
|
@ -336,28 +336,31 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
|
||||||
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
|
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape (Tensor): The shape of random tensor to be generated, 1-D `Tensor` whose dtype is mindspore.dtype.int32 or
|
shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose
|
||||||
mindspore.dtype.int64.
|
dtype is mindspore.dtype.int32 or mindspore.dtype.int64.
|
||||||
rate (Tensor): The μ parameter the distribution was constructed with. The parameter defines mean number of
|
rate (Tensor): The μ parameter the distribution was constructed with. It represents the mean of the distribution
|
||||||
occurrences of the event. It should be a `Tensor` whose dtype is mindspore.dtype.int64, mindspore.dtype.int32,
|
and also the variance of the distribution. It should be a `Tensor` whose dtype is mindspore.dtype.int64,
|
||||||
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
|
mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16.
|
||||||
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
|
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
|
||||||
and must be non-negative. Default: None, which will be treated as 0.
|
and must be non-negative. Default: None, which will be treated as 0.
|
||||||
dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32,
|
dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32,
|
||||||
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32.
|
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor. The shape should be `mindspore.concat([shape, mindspore.shape(mean)], axis=0)`. The data type should be
|
A Tensor whose shape is `mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)` and data type is equal to
|
||||||
equal to argument `dtype`.
|
argument `dtype`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `shape` is not a Tensor[mindspore.dtype.int64] nor a Tensor[mindspore.dtype.int32].
|
TypeError: If `shape` is not a Tensor.
|
||||||
TypeError: If `rate` is not a Tensor or `rate` is a Tensor whose dtype is not in [mindspore.dtype.int64,
|
TypeError: If datatype of `shape` is not mindspore.dtype.int64 nor mindspore.dtype.int32.
|
||||||
mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16].
|
ValueError: If shape of `shape` is not 1-D.
|
||||||
TypeError: If `seed` is not an int.
|
TypeError: If `rate` is not a Tensor nor a scalar.
|
||||||
TypeError: If `dtype` is not mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
|
TypeError: If datatype of `rate` is not in [mindspore.dtype.int64, mindspore.dtype.int32,
|
||||||
mindspore.dtype.float32 nor mindspore.dtype.float16.
|
mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16].
|
||||||
ValueError: If elements of input `shape` tensor is not positive.
|
TypeError: If `seed` is not a non-negtive int.
|
||||||
|
TypeError: If `dtype` is not in [mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64,
|
||||||
|
mindspore.dtype.float32 nor mindspore.dtype.float16].
|
||||||
|
ValueError: If any element of input `shape` tensor is not positive.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``CPU``
|
``CPU``
|
||||||
|
@ -365,20 +368,20 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mindspore import Tensor, ops
|
>>> from mindspore import Tensor, ops
|
||||||
>>> import mindspore
|
>>> import mindspore
|
||||||
>>> # case 1: It can be broadcast.
|
>>> # case 1: 1-D shape, 2-D rate, float64 output
|
||||||
>>> shape = Tensor(np.array([4, 1]), mindspore.int32)
|
>>> shape = Tensor(np.array([2, 2]), mindspore.int64)
|
||||||
>>> rate = Tensor(np.array([5.0, 10.0]), mindspore.float32)
|
>>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
|
||||||
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64)
|
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64)
|
||||||
>>> print(output.shape, output.dtype)
|
>>> print(output.shape, output.dtype)
|
||||||
(4, 1, 2) Float64
|
(2, 2, 2, 2) float64
|
||||||
>>> # case 2: It can not be broadcast. It is recommended to use the same shape.
|
>>> # case 2: 1-D shape, scalar rate, int64 output
|
||||||
>>> shape = Tensor(np.array([2, 2]), mindspore.int32)
|
>>> shape = Tensor(np.array([2, 2]), mindspore.int64)
|
||||||
>>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
|
>>> rate = Tensor(5.0, mindspore.float64)
|
||||||
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64)
|
>>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64)
|
||||||
>>> print(output.shape, output.dtype)
|
>>> print(output.shape, output.dtype)
|
||||||
(2, 2, 2, 2) Int64
|
(2, 2) Int64
|
||||||
"""
|
"""
|
||||||
seed1, seed2 = _get_seed(seed, "poisson")
|
seed1, seed2 = _get_seed(seed, "random_poisson")
|
||||||
prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype)
|
prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype)
|
||||||
value = prim_random_poisson(shape, rate)
|
value = prim_random_poisson(shape, rate)
|
||||||
return value
|
return value
|
||||||
|
@ -391,5 +394,6 @@ __all__ = [
|
||||||
'standard_normal',
|
'standard_normal',
|
||||||
'random_gamma',
|
'random_gamma',
|
||||||
'uniform_candidate_sampler',
|
'uniform_candidate_sampler',
|
||||||
|
'random_poisson',
|
||||||
]
|
]
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -35,12 +35,14 @@ def test_poisson_function(dtype, shape_dtype, rate_dtype):
|
||||||
Expectation: Output shape is correct.
|
Expectation: Output shape is correct.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# rate is a scalar Tensor
|
||||||
shape = Tensor(np.array([3, 5]), shape_dtype)
|
shape = Tensor(np.array([3, 5]), shape_dtype)
|
||||||
rate = Tensor(np.array([0.5]), rate_dtype)
|
rate = Tensor(0.5, rate_dtype)
|
||||||
output = R.random_poisson(shape, rate, seed=1, dtype=dtype)
|
output = R.random_poisson(shape, rate, seed=1, dtype=dtype)
|
||||||
assert output.shape == (3, 5, 1)
|
assert output.shape == (3, 5)
|
||||||
assert output.dtype == dtype
|
assert output.dtype == dtype
|
||||||
|
|
||||||
|
# rate is a 2-D Tensor
|
||||||
shape = Tensor(np.array([3, 2]), shape_dtype)
|
shape = Tensor(np.array([3, 2]), shape_dtype)
|
||||||
rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), rate_dtype)
|
rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), rate_dtype)
|
||||||
output = R.random_poisson(shape, rate, seed=5, dtype=dtype)
|
output = R.random_poisson(shape, rate, seed=5, dtype=dtype)
|
||||||
|
@ -67,6 +69,25 @@ def test_poisson_function_shape_type_error():
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
def test_poisson_function_shape_dim_error():
|
||||||
|
"""
|
||||||
|
Feature: Poisson functional interface
|
||||||
|
Description: Feed 2-D Tensor type `shape` into poisson functional interface.
|
||||||
|
Expectation: Except TypeError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
shape = Tensor(np.array([[1, 2], [3, 5]]), ms.dtype.int32)
|
||||||
|
rate = Tensor(np.array([0.5]), ms.dtype.float32)
|
||||||
|
try:
|
||||||
|
R.random_poisson(shape, rate, seed=1)
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@ -183,36 +204,44 @@ def test_poisson_function_out_dtype_error():
|
||||||
|
|
||||||
class PoissonNet(nn.Cell):
|
class PoissonNet(nn.Cell):
|
||||||
""" Network for test dynamic shape feature of poisson functional op. """
|
""" Network for test dynamic shape feature of poisson functional op. """
|
||||||
def __init__(self, out_dtype, axis=0):
|
def __init__(self, out_dtype, is_rate_scalar=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.odtype = out_dtype
|
self.odtype = out_dtype
|
||||||
self.unique = P.Unique()
|
self.unique = P.Unique()
|
||||||
self.gather = P.Gather()
|
self.gather = P.Gather()
|
||||||
self.axis = axis
|
self.is_rate_scalar = is_rate_scalar
|
||||||
|
|
||||||
def construct(self, x, y, indices):
|
def construct(self, x, y, indices):
|
||||||
shape, _ = self.unique(x)
|
shape, _ = self.unique(x)
|
||||||
|
rate = y
|
||||||
idx, _ = self.unique(indices)
|
idx, _ = self.unique(indices)
|
||||||
rate = self.gather(y, idx, self.axis)
|
if not self.is_rate_scalar:
|
||||||
|
rate = self.gather(rate, idx, 0)
|
||||||
return R.random_poisson(shape, rate, seed=1, dtype=self.odtype)
|
return R.random_poisson(shape, rate, seed=1, dtype=self.odtype)
|
||||||
|
|
||||||
|
|
||||||
class PoissonDSFactory:
|
class PoissonDSFactory:
|
||||||
""" Factory class for test dynamic shape feature of poisson functional op. """
|
""" Factory class for test dynamic shape feature of poisson functional op. """
|
||||||
def __init__(self, rate_dims, out_dtype, shape_dtype, rate_dtype):
|
def __init__(self, max_dims, rate_dims):
|
||||||
self.rate_dims = rate_dims
|
|
||||||
self.rate_random_range = 8
|
self.rate_random_range = 8
|
||||||
self.odtype = out_dtype
|
self.odtype = ms.float32
|
||||||
self.shape_dtype = shape_dtype
|
self.shape_dtype = ms.int32
|
||||||
self.rate_dtype = rate_dtype
|
self.rate_dtype = ms.float32
|
||||||
# shape tensor is a 1-D tensor, unique from shape_map.
|
self.is_rate_scalar = rate_dims == 0
|
||||||
self.shape_map = np.random.randint(1, 6, 30, dtype=np.int32)
|
# shape tensor is a 1-D tensor, uniqueed from shape_map.
|
||||||
# rate_map will be gathered as rate tensor.
|
self.shape_map = np.random.randint(1, max_dims, 30, dtype=np.int32)
|
||||||
rate_map_shape = np.random.randint(1, 6, self.rate_dims - 1, dtype=np.int32)
|
|
||||||
|
if self.is_rate_scalar:
|
||||||
|
self.rate_map = np.abs(np.random.randn(1))[0]
|
||||||
|
else:
|
||||||
|
# rate_shape: [rate_random_range, xx, ..., xx], rank of rate_shape = rate_dims
|
||||||
|
rate_map_shape = np.random.randint(1, max_dims, rate_dims - 1, dtype=np.int32)
|
||||||
rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0)
|
rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0)
|
||||||
self.rate_map = np.random.randn(*rate_map_shape)
|
# rate tensor will be gathered from rate_map.
|
||||||
|
self.rate_map = np.abs(np.random.randn(*rate_map_shape))
|
||||||
# indices array is used to gather rate_map to rate tensor.
|
# indices array is used to gather rate_map to rate tensor.
|
||||||
self.indices = np.random.randint(1, self.rate_random_range, 4, dtype=np.int32)
|
indices_shape = np.random.randint(1, self.rate_random_range, 1, dtype=np.int32)[0]
|
||||||
|
self.indices = np.random.randint(1, self.rate_random_range, indices_shape, dtype=np.int32)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _np_unranked_unique(nparr):
|
def _np_unranked_unique(nparr):
|
||||||
|
@ -238,6 +267,8 @@ class PoissonDSFactory:
|
||||||
def _forward_numpy(self):
|
def _forward_numpy(self):
|
||||||
""" Get result of numpy """
|
""" Get result of numpy """
|
||||||
shape = PoissonDSFactory._np_unranked_unique(self.shape_map)
|
shape = PoissonDSFactory._np_unranked_unique(self.shape_map)
|
||||||
|
if self.is_rate_scalar:
|
||||||
|
return shape
|
||||||
indices = PoissonDSFactory._np_unranked_unique(self.indices)
|
indices = PoissonDSFactory._np_unranked_unique(self.indices)
|
||||||
rate = self.rate_map[indices]
|
rate = self.rate_map[indices]
|
||||||
rate_shape = rate.shape
|
rate_shape = rate.shape
|
||||||
|
@ -247,23 +278,27 @@ class PoissonDSFactory:
|
||||||
def _forward_mindspore(self):
|
def _forward_mindspore(self):
|
||||||
""" Get result of mindspore """
|
""" Get result of mindspore """
|
||||||
shape_map_tensor = Tensor(self.shape_map, dtype=self.shape_dtype)
|
shape_map_tensor = Tensor(self.shape_map, dtype=self.shape_dtype)
|
||||||
rate_map_tensor = Tensor(self.rate_map, dtype=self.rate_dtype)
|
rate_tensor = Tensor(self.rate_map, dtype=self.rate_dtype)
|
||||||
indices_tensor = Tensor(self.indices, dtype=ms.dtype.int32)
|
indices_tensor = Tensor(self.indices, dtype=ms.dtype.int32)
|
||||||
net = PoissonNet(self.odtype)
|
net = PoissonNet(self.odtype, self.is_rate_scalar)
|
||||||
output = net(shape_map_tensor, rate_map_tensor, indices_tensor)
|
output = net(shape_map_tensor, rate_tensor, indices_tensor)
|
||||||
return output.shape, output.dtype
|
return output.shape, output.dtype
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.parametrize("rate_dims", [1, 2, 3, 4, 5, 6])
|
@pytest.mark.parametrize("max_dims", [2, 3, 4, 5, 6])
|
||||||
def test_poisson_function_dynamic_shape(rate_dims):
|
@pytest.mark.parametrize("rate_dims", [0, 1, 2, 3, 4, 5, 6])
|
||||||
|
def test_poisson_function_dynamic_shape(max_dims, rate_dims):
|
||||||
"""
|
"""
|
||||||
Feature: Poisson functional interface
|
Feature: Dynamic shape of functional interface RandomPoisson.
|
||||||
Description: Test dynamic shape feature of the poisson functional interface with 1D-8D rate.
|
Description:
|
||||||
Expectation: Output of mindspore poisson equal to numpy.
|
1. Initialize a 1-D Tensor as input `shape` whose data type fixed to int32, whose data and shape are random.
|
||||||
|
2. Initialize a Tensor as input `rate` whose data type fixed to float32, whose data and shape are random.
|
||||||
|
3. Compare shape of output from MindSpore and Numpy.
|
||||||
|
Expectation: Output of MindSpore RandomPoisson equal to numpy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
factory = PoissonDSFactory(rate_dims, ms.float32, ms.int32, ms.float32)
|
factory = PoissonDSFactory(max_dims, rate_dims)
|
||||||
factory.forward_compare()
|
factory.forward_compare()
|
||||||
|
|
Loading…
Reference in New Issue