diff --git a/docs/api/api_python/mindspore.ops.functional.rst b/docs/api/api_python/mindspore.ops.functional.rst index 72fcaf62caa..d5e2a09e0c5 100644 --- a/docs/api/api_python/mindspore.ops.functional.rst +++ b/docs/api/api_python/mindspore.ops.functional.rst @@ -352,7 +352,7 @@ Tensor创建 mindspore.ops.gamma mindspore.ops.laplace mindspore.ops.multinomial - mindspore.ops.poisson + mindspore.ops.random_poisson mindspore.ops.random_categorical mindspore.ops.random_gamma mindspore.ops.standard_laplace diff --git a/docs/api/api_python/ops/mindspore.ops.func_poisson.rst b/docs/api/api_python/ops/mindspore.ops.func_poisson.rst deleted file mode 100644 index 80e379933c9..00000000000 --- a/docs/api/api_python/ops/mindspore.ops.func_poisson.rst +++ /dev/null @@ -1,23 +0,0 @@ -mindspore.ops.poisson -===================== - -.. py:function:: mindspore.ops.poisson(shape, mean, seed=None) - - 根据泊松随机数分布生成随机数。 - - .. math:: - \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!} - - 参数: - - **shape** (tuple) - Tuple: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。 - - **mean** (Tensor) - 均值μ,分布参数。支持float32数据类型,应大于0。 - - **seed** (int) - 随机种子。取值须为非负数。默认值:None,等同于0。 - - 返回: - Tensor,shape应与输入 `shape` 与 `mean` 进行广播之后的shape相同。数据类型支持float32。 - - 异常: - - **TypeError** - `shape` 不是Tuple。 - - **TypeError** - `mean` 不是Tensor或数据类型非float32。 - - **TypeError** - `seed` 不是int类型。 - diff --git a/docs/api/api_python/ops/mindspore.ops.func_random_poisson.rst b/docs/api/api_python/ops/mindspore.ops.func_random_poisson.rst new file mode 100644 index 00000000000..9c113314549 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_random_poisson.rst @@ -0,0 +1,29 @@ +mindspore.ops.random_poisson +========================== + +.. py:function:: mindspore.ops.random_poisson(shape, rate, seed=None, dtype=mstype.float32) + + 从各指定均值的泊松分布中,随机采样`shape`形状的随机数。 + + .. math:: + + \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!} + + 参数: + - **shape** (Tensor) - 表示要从每个分布中采样的随机数张量的形状。必须是一个一维的张量且数据类型必须是`mindspore.dtype.int32`或者`mindspore.dtype.int64`。 + - **rate** (Tensor) - 泊松分布的 :math:`μ` 参数,表示泊松分布的均值,同时也是分布的方差。必须是一个张量,且其数据类型必须是以下类型中的一种:mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。 + - **seed** (int) - 随机数种子,用于在随机数引擎中产生随机数。必须是一个非负的整数。默认值是None,表示使用0作为随机数种子。 + - **dtype** (mindspore.dtype) - 表示要生成的随机数张量的数据类型。必须是mindspore.dtype类型,可以是以下值中的一种:mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。 + + 返回: + 返回一个张量,它的形状由入参`shape`和`rate`共同决定:`mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)`,它的数据类型由入参`dtype`决定。 + + 异常: + - **TypeError** - 如果`shape`不是一个张量。 + - **TypeError** - 如果`shape`张量的数据类型不是mindspore.dtype.int64或mindspore.dtype.int32。 + - **ValueError** - 如果`shape`张量的形状不是一维的。 + - **TypeError** - 如果`rate`不是一个张量。 + - **TypeError** - 如果`rate`张量的数据类型不是mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。 + - **TypeError** - 如果`seed`不是一个非负整型。 + - **TypeError** - 如果`dtype`不是mindspore.dtype.int64,mindspore.dtype.int32,mindspore.dtype.float64,mindspore.dtype.float32或者mindspore.dtype.float16。 + - **ValueError** - 如果`shape`张量中有非正数。 diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index f306ee413d3..cc762d19950 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -338,6 +338,7 @@ from .random_func import ( standard_normal, random_gamma, uniform_candidate_sampler, + random_poisson, ) from .grad import ( grad_func, diff --git a/mindspore/python/mindspore/ops/function/random_func.py b/mindspore/python/mindspore/ops/function/random_func.py index 63c709a523b..a7345ab9669 100755 --- a/mindspore/python/mindspore/ops/function/random_func.py +++ b/mindspore/python/mindspore/ops/function/random_func.py @@ -336,28 +336,31 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32): \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!} Args: - shape (Tensor): The shape of random tensor to be generated, 1-D `Tensor` whose dtype is mindspore.dtype.int32 or - mindspore.dtype.int64. - rate (Tensor): The μ parameter the distribution was constructed with. The parameter defines mean number of - occurrences of the event. It should be a `Tensor` whose dtype is mindspore.dtype.int64, mindspore.dtype.int32, - mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. + shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose + dtype is mindspore.dtype.int32 or mindspore.dtype.int64. + rate (Tensor): The μ parameter the distribution was constructed with. It represents the mean of the distribution + and also the variance of the distribution. It should be a `Tensor` whose dtype is mindspore.dtype.int64, + mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers and must be non-negative. Default: None, which will be treated as 0. dtype (mindspore.dtype): The data type of output: mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16. Default: mindspore.dtype.float32. Returns: - Tensor. The shape should be `mindspore.concat([shape, mindspore.shape(mean)], axis=0)`. The data type should be - equal to argument `dtype`. + A Tensor whose shape is `mindspore.concat([`shape`, mindspore.shape(`rate`)], axis=0)` and data type is equal to + argument `dtype`. Raises: - TypeError: If `shape` is not a Tensor[mindspore.dtype.int64] nor a Tensor[mindspore.dtype.int32]. - TypeError: If `rate` is not a Tensor or `rate` is a Tensor whose dtype is not in [mindspore.dtype.int64, - mindspore.dtype.int32, mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16]. - TypeError: If `seed` is not an int. - TypeError: If `dtype` is not mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64, - mindspore.dtype.float32 nor mindspore.dtype.float16. - ValueError: If elements of input `shape` tensor is not positive. + TypeError: If `shape` is not a Tensor. + TypeError: If datatype of `shape` is not mindspore.dtype.int64 nor mindspore.dtype.int32. + ValueError: If shape of `shape` is not 1-D. + TypeError: If `rate` is not a Tensor nor a scalar. + TypeError: If datatype of `rate` is not in [mindspore.dtype.int64, mindspore.dtype.int32, + mindspore.dtype.float64, mindspore.dtype.float32 or mindspore.dtype.float16]. + TypeError: If `seed` is not a non-negtive int. + TypeError: If `dtype` is not in [mindspore.dtype.int64, mindspore.dtype.int32, mindspore.dtype.float64, + mindspore.dtype.float32 nor mindspore.dtype.float16]. + ValueError: If any element of input `shape` tensor is not positive. Supported Platforms: ``CPU`` @@ -365,20 +368,20 @@ def random_poisson(shape, rate, seed=None, dtype=mstype.float32): Examples: >>> from mindspore import Tensor, ops >>> import mindspore - >>> # case 1: It can be broadcast. - >>> shape = Tensor(np.array([4, 1]), mindspore.int32) - >>> rate = Tensor(np.array([5.0, 10.0]), mindspore.float32) + >>> # case 1: 1-D shape, 2-D rate, float64 output + >>> shape = Tensor(np.array([2, 2]), mindspore.int64) + >>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32) >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64) >>> print(output.shape, output.dtype) - (4, 1, 2) Float64 - >>> # case 2: It can not be broadcast. It is recommended to use the same shape. - >>> shape = Tensor(np.array([2, 2]), mindspore.int32) - >>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32) + (2, 2, 2, 2) float64 + >>> # case 2: 1-D shape, scalar rate, int64 output + >>> shape = Tensor(np.array([2, 2]), mindspore.int64) + >>> rate = Tensor(5.0, mindspore.float64) >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64) >>> print(output.shape, output.dtype) - (2, 2, 2, 2) Int64 + (2, 2) Int64 """ - seed1, seed2 = _get_seed(seed, "poisson") + seed1, seed2 = _get_seed(seed, "random_poisson") prim_random_poisson = P.random_ops.RandomPoisson(seed1, seed2, dtype) value = prim_random_poisson(shape, rate) return value @@ -391,5 +394,6 @@ __all__ = [ 'standard_normal', 'random_gamma', 'uniform_candidate_sampler', + 'random_poisson', ] __all__.sort() diff --git a/tests/st/ops/cpu/test_random_poisson_op.py b/tests/st/ops/cpu/test_random_poisson_op.py index 85ad57bfa18..fef744c5045 100644 --- a/tests/st/ops/cpu/test_random_poisson_op.py +++ b/tests/st/ops/cpu/test_random_poisson_op.py @@ -35,12 +35,14 @@ def test_poisson_function(dtype, shape_dtype, rate_dtype): Expectation: Output shape is correct. """ + # rate is a scalar Tensor shape = Tensor(np.array([3, 5]), shape_dtype) - rate = Tensor(np.array([0.5]), rate_dtype) + rate = Tensor(0.5, rate_dtype) output = R.random_poisson(shape, rate, seed=1, dtype=dtype) - assert output.shape == (3, 5, 1) + assert output.shape == (3, 5) assert output.dtype == dtype + # rate is a 2-D Tensor shape = Tensor(np.array([3, 2]), shape_dtype) rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), rate_dtype) output = R.random_poisson(shape, rate, seed=5, dtype=dtype) @@ -67,6 +69,25 @@ def test_poisson_function_shape_type_error(): assert False +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_cpu +def test_poisson_function_shape_dim_error(): + """ + Feature: Poisson functional interface + Description: Feed 2-D Tensor type `shape` into poisson functional interface. + Expectation: Except TypeError. + """ + + shape = Tensor(np.array([[1, 2], [3, 5]]), ms.dtype.int32) + rate = Tensor(np.array([0.5]), ms.dtype.float32) + try: + R.random_poisson(shape, rate, seed=1) + except ValueError: + return + assert False + + @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_x86_cpu @@ -183,36 +204,44 @@ def test_poisson_function_out_dtype_error(): class PoissonNet(nn.Cell): """ Network for test dynamic shape feature of poisson functional op. """ - def __init__(self, out_dtype, axis=0): + def __init__(self, out_dtype, is_rate_scalar=False): super().__init__() self.odtype = out_dtype self.unique = P.Unique() self.gather = P.Gather() - self.axis = axis + self.is_rate_scalar = is_rate_scalar def construct(self, x, y, indices): shape, _ = self.unique(x) + rate = y idx, _ = self.unique(indices) - rate = self.gather(y, idx, self.axis) + if not self.is_rate_scalar: + rate = self.gather(rate, idx, 0) return R.random_poisson(shape, rate, seed=1, dtype=self.odtype) class PoissonDSFactory: """ Factory class for test dynamic shape feature of poisson functional op. """ - def __init__(self, rate_dims, out_dtype, shape_dtype, rate_dtype): - self.rate_dims = rate_dims + def __init__(self, max_dims, rate_dims): self.rate_random_range = 8 - self.odtype = out_dtype - self.shape_dtype = shape_dtype - self.rate_dtype = rate_dtype - # shape tensor is a 1-D tensor, unique from shape_map. - self.shape_map = np.random.randint(1, 6, 30, dtype=np.int32) - # rate_map will be gathered as rate tensor. - rate_map_shape = np.random.randint(1, 6, self.rate_dims - 1, dtype=np.int32) - rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0) - self.rate_map = np.random.randn(*rate_map_shape) + self.odtype = ms.float32 + self.shape_dtype = ms.int32 + self.rate_dtype = ms.float32 + self.is_rate_scalar = rate_dims == 0 + # shape tensor is a 1-D tensor, uniqueed from shape_map. + self.shape_map = np.random.randint(1, max_dims, 30, dtype=np.int32) + + if self.is_rate_scalar: + self.rate_map = np.abs(np.random.randn(1))[0] + else: + # rate_shape: [rate_random_range, xx, ..., xx], rank of rate_shape = rate_dims + rate_map_shape = np.random.randint(1, max_dims, rate_dims - 1, dtype=np.int32) + rate_map_shape = np.append(np.array([self.rate_random_range]), rate_map_shape, axis=0) + # rate tensor will be gathered from rate_map. + self.rate_map = np.abs(np.random.randn(*rate_map_shape)) # indices array is used to gather rate_map to rate tensor. - self.indices = np.random.randint(1, self.rate_random_range, 4, dtype=np.int32) + indices_shape = np.random.randint(1, self.rate_random_range, 1, dtype=np.int32)[0] + self.indices = np.random.randint(1, self.rate_random_range, indices_shape, dtype=np.int32) @staticmethod def _np_unranked_unique(nparr): @@ -238,6 +267,8 @@ class PoissonDSFactory: def _forward_numpy(self): """ Get result of numpy """ shape = PoissonDSFactory._np_unranked_unique(self.shape_map) + if self.is_rate_scalar: + return shape indices = PoissonDSFactory._np_unranked_unique(self.indices) rate = self.rate_map[indices] rate_shape = rate.shape @@ -247,23 +278,27 @@ class PoissonDSFactory: def _forward_mindspore(self): """ Get result of mindspore """ shape_map_tensor = Tensor(self.shape_map, dtype=self.shape_dtype) - rate_map_tensor = Tensor(self.rate_map, dtype=self.rate_dtype) + rate_tensor = Tensor(self.rate_map, dtype=self.rate_dtype) indices_tensor = Tensor(self.indices, dtype=ms.dtype.int32) - net = PoissonNet(self.odtype) - output = net(shape_map_tensor, rate_map_tensor, indices_tensor) + net = PoissonNet(self.odtype, self.is_rate_scalar) + output = net(shape_map_tensor, rate_tensor, indices_tensor) return output.shape, output.dtype @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_x86_cpu -@pytest.mark.parametrize("rate_dims", [1, 2, 3, 4, 5, 6]) -def test_poisson_function_dynamic_shape(rate_dims): +@pytest.mark.parametrize("max_dims", [2, 3, 4, 5, 6]) +@pytest.mark.parametrize("rate_dims", [0, 1, 2, 3, 4, 5, 6]) +def test_poisson_function_dynamic_shape(max_dims, rate_dims): """ - Feature: Poisson functional interface - Description: Test dynamic shape feature of the poisson functional interface with 1D-8D rate. - Expectation: Output of mindspore poisson equal to numpy. + Feature: Dynamic shape of functional interface RandomPoisson. + Description: + 1. Initialize a 1-D Tensor as input `shape` whose data type fixed to int32, whose data and shape are random. + 2. Initialize a Tensor as input `rate` whose data type fixed to float32, whose data and shape are random. + 3. Compare shape of output from MindSpore and Numpy. + Expectation: Output of MindSpore RandomPoisson equal to numpy. """ - factory = PoissonDSFactory(rate_dims, ms.float32, ms.int32, ms.float32) + factory = PoissonDSFactory(max_dims, rate_dims) factory.forward_compare()