fix bernoulli docs

This commit is contained in:
jiangshuqiang 2022-05-30 14:42:31 +08:00
parent b3d8d3e89b
commit 0a096070af
4 changed files with 20 additions and 15 deletions

View File

@ -1,7 +1,7 @@
mindspore.ops.bernoulli
mindspore.ops.bernoulli
=======================
.. py:class:: mindspore.ops.bernoulli(x, p=0.5, seed=-1)
.. py:function:: mindspore.ops.bernoulli(x, p=0.5, seed=-1)
以p的概率随机将输出的元素设置为0或1服从伯努利分布。

View File

@ -3544,16 +3544,17 @@ class Tensor(Tensor_):
``GPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> input_x = Tensor(np.array([1, 2, 3], mindspore.int8))
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int8)
>>> output = input_x.bernoulli(p=1.0)
>>> print(output)
[1, 1, 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0], mindspore.float32))
[1 1 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0]), mindspore.float32)
>>> output = input_x.bernoulli(input_p)
>>> print(output)
[0, 1, 1]
[0 1 1]
"""
self._init_check()
validator.check_is_int(seed, 'seed')

View File

@ -3105,14 +3105,18 @@ def bernoulli(x, p=0.5, seed=-1):
``GPU``
Examples:
>>> input_x = Tensor(np.array([1, 2, 3], mindspore.int8))
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int8)
>>> output = ops.bernoulli(input_x, p=1.0)
>>> print(output)
[1, 1, 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0], mindspore.float32))
[1 1 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0]), mindspore.float32)
>>> output = ops.bernoulli(input_x, input_p)
>>> print(output)
[0, 1, 1]
[0 1 1]
"""
bernoulli_ = Bernoulli(seed)
return bernoulli_(x, p)

View File

@ -5952,15 +5952,15 @@ class Bernoulli(Primitive):
``GPU``
Examples:
>>> input_x = Tensor(np.array([1, 2, 3], mindspore.int8))
>>> bernoulli = P.Bernoulli()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int8)
>>> bernoulli = Bernoulli()
>>> output = bernoulli(input_x, p=1.0)
>>> print(output)
[1, 1, 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0], mindspore.float32))
[1 1 1]
>>> input_p = Tensor(np.array([0.0, 1.0, 1.0]), mindspore.float32)
>>> output = bernoulli(input_x, input_p)
>>> print(output)
[0, 1, 1]
[0 1 1]
"""
@prim_attr_register