!20937 Improve operator ease of use, part 7: ops/composite

Merge pull request !20937 from dinglinhe/code_docs_dlh_ms_I3R3BX_8_composite
i-robot 2021-07-29 01:25:50 +00:00 committed by Gitee
commit 18872a6475
5 changed files with 276 additions and 87 deletions


@ -74,13 +74,20 @@ def repeat_elements(x, rep, axis=0):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case 1 : repeat on axis 0
>>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
>>> output = C.repeat_elements(x, rep = 2, axis = 0)
>>> output = ops.repeat_elements(x, rep = 2, axis = 0)
>>> print(output)
[[0 1 2]
[0 1 2]
[3 4 5]
[3 4 5]]
>>> # case 2 : repeat on axis 1
>>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
>>> output = ops.repeat_elements(x, rep = 2, axis = 1)
>>> print(output)
[[0 0 1 1 2 2]
[3 3 4 4 5 5]]
"""
const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
rep = _check_positive_int(rep, "rep", "repeat_elements")
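For a quick sanity check of the new examples without a MindSpore runtime, `np.repeat` mirrors the semantics of `repeat_elements` here; a minimal NumPy-only sketch:

import numpy as np

x = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.int32)
# case 1: repeating along axis 0 duplicates whole rows
print(np.repeat(x, 2, axis=0))
# case 2: repeating along axis 1 duplicates each element within a row
print(np.repeat(x, 2, axis=1))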
@ -131,7 +138,6 @@ def sequence_mask(lengths, maxlen=None):
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
Must be type int32 or int64.
- **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
type as elements in `lengths`.
@ -147,13 +153,30 @@ def sequence_mask(lengths, maxlen=None):
``GPU``
Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
>>> output = C.sequence_mask(x, 3)
>>> # case 1: When maxlen is assigned
>>> x = Tensor(np.array([1, 2, 3, 4]))
>>> output = ops.sequence_mask(x, 5)
>>> print(output)
[[[True False False]
[True True True]]
[[True True False]
[False False False]]]
[[ True False False False False]
[ True True False False False]
[ True True True False False]
[ True True True True False]]
>>> # case 2: When there is a 0 in x
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
>>> output = ops.sequence_mask(x, 5)
>>> print(output)
[[[ True False False False False]
[ True True True False False]]
[[ True True False False False]
[False False False False False]]]
>>> # case 3: When maxlen is not assigned
>>> x = Tensor(np.array([[1, 3], [2, 4]]))
>>> output = ops.sequence_mask(x)
>>> print(output)
[[[ True False False False]
[ True True True False]]
[[ True True False False]
[ True True True True]]]
"""
argmax_op = P.ArgMaxWithValue()
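The three cases above can be reproduced with plain NumPy broadcasting; a minimal sketch with a hypothetical helper name (not MindSpore API), assuming only NumPy:

import numpy as np

def sequence_mask_np(lengths, maxlen=None):
    lengths = np.asarray(lengths)
    if maxlen is None:
        # case 3: when maxlen is not assigned, it defaults to the largest length
        maxlen = int(lengths.max())
    # broadcasting compares the range [0, maxlen) against each length
    return np.arange(maxlen) < lengths[..., None]

print(sequence_mask_np(np.array([1, 2, 3, 4]), 5))  # matches case 1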


@ -480,8 +480,10 @@ class HyperMap(HyperMap_):
Args:
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
the operations should be put in the first input of the instance.
reverse (bool): `reverse` is the flag to decide if apply the operation reversely. Only supported in graph mode.
the operations should be put in the first input of the instance. Default: None.
reverse (bool): In some scenarios, the optimizer needs to apply operations in reverse
order to improve parallel performance; general users can ignore this. `reverse` is the flag
that decides whether to apply the operation in reverse. Only supported in graph mode. Default: False.
Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same length.
@ -503,13 +505,13 @@ class HyperMap(HyperMap_):
Examples:
>>> from mindspore import dtype as mstype
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
... (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
>>> # square all the tensor in the nested list
>>>
>>> square = MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
... return F.square(x)
... return ops.square(x)
>>>
>>> common_map = HyperMap()
>>> output = common_map(square, nest_tensor_list)
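Conceptually, HyperMap applies the registered operation leaf-wise over a nested structure; a plain-Python sketch of that idea (hypothetical helper, not the MindSpore implementation):

def hyper_map(fn, data):
    # recurse into tuples/lists, apply fn at the leaves
    if isinstance(data, (tuple, list)):
        return type(data)(hyper_map(fn, item) for item in data)
    return fn(data)

nested = ((1.0, 2.0), (3.0, 4.0))
print(hyper_map(lambda v: v * v, nested))
# ((1.0, 4.0), (9.0, 16.0))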
@ -554,7 +556,9 @@ class Map(Map_):
Args:
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
the operations should be put in the first input of the instance. Default: None
reverse (bool): `reverse` is the flag to decide if apply the operation reversely. Only supported in graph mode.
reverse (bool): In some scenarios, the optimizer needs to apply operations in reverse
order to improve parallel performance; general users can ignore this. `reverse` is the flag
that decides whether to apply the operation in reverse. Only supported in graph mode. Default: False.
Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
@ -574,7 +578,7 @@ class Map(Map_):
>>> square = MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
... return F.square(x)
... return ops.square(x)
>>>
>>> common_map = Map()
>>> output = common_map(square, tensor_list)


@ -51,25 +51,21 @@ def clip_by_value(x, clip_value_min, clip_value_max):
'clip_value_min' needs to be less than or equal to 'clip_value_max'.
Args:
x (Tensor): Input data.
x (Tensor): Input data. The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
clip_value_min (Tensor): The minimum value.
clip_value_max (Tensor): The maximum value.
Returns:
Tensor, a clipped Tensor.
Tensor, a clipped Tensor. It has the same shape and data type as `x`.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import composite as C
>>> import mindspore.common.dtype as mstype
>>> min_value = Tensor(5, mstype.float32)
>>> max_value = Tensor(20, mstype.float32)
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mstype.float32)
>>> output = C.clip_by_value(x, min_value, max_value)
>>> min_value = Tensor(5, mindspore.float32)
>>> max_value = Tensor(20, mindspore.float32)
>>> x = Tensor(np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]]), mindspore.float32)
>>> output = ops.clip_by_value(x, min_value, max_value)
>>> print(output)
[[ 5. 20. 5. 7.]
[ 5. 11. 6. 20.]]
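The printed output can be cross-checked with `np.clip`, which has the same elementwise semantics; a NumPy-only sketch:

import numpy as np

x = np.array([[1., 25., 5., 7.], [4., 11., 6., 21.]], dtype=np.float32)
print(np.clip(x, 5, 20))
# [[ 5. 20.  5.  7.]
#  [ 5. 11.  6. 20.]]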
@ -148,12 +144,15 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
Input 'x' should be a tuple or list of tensors. Otherwise, it will raise an error.
Args:
x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
clip_norm (Union(float, int)): The clipping ratio, it should be greater than 0. Default: 1.0
use_norm (None): The global norm. Default: None. Currently only none is supported.
x (Union(tuple[Tensor], list[Tensor])): Input data to clip.
The shape of each Tensor in the tuple is :math:`(N,*)` where :math:`*` means
any number of additional dimensions.
clip_norm (Union(float, int)): The clipping ratio; it should be greater than 0. Default: 1.0.
use_norm (None): The global norm. Default: None. Currently only None is supported.
Returns:
tuple[Tensor], a clipped Tensor.
tuple[Tensor], a clipped Tensor. It has the same data type as `x`, and each Tensor
in the output tuple has the same shape as the corresponding input Tensor.
Supported Platforms:
``Ascend`` ``GPU``
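The rule being documented is standard global-norm clipping: each tensor is scaled by clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm over all inputs. A minimal NumPy sketch under that assumption (hypothetical helper name):

import numpy as np

def clip_by_global_norm_np(tensors, clip_norm=1.0):
    # L2 norm taken jointly over every element of every tensor
    global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in tensors))
    scale = clip_norm / max(global_norm, clip_norm)
    return [t * scale for t in tensors]

grads = [np.array([2.0, 2.0]), np.array([1.0])]
print(clip_by_global_norm_np(grads, clip_norm=1.0))  # global_norm is 3.0, so each value is scaled by 1/3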


@ -47,24 +47,46 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
Args:
x (Tensor): Input data used to count the non-zero elements. The shape is
:math:`(N,*)` where :math:`*` means any number of additional dimensions.
axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
Default: (), reduce all dimensions.
keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions. Default: False.
dtype (Union[Number, mstype.bool\_]): The data type of the output tensor. Only constant value is allowed.
Default: mstype.int32
dtype (Union[Number, mindspore.bool\_]): The data type of the output tensor. Only constant value is allowed.
Default: mindspore.int32
Returns:
Tensor, number of nonzero element. The data type is dtype.
Tensor, the number of nonzero elements. The data type is `dtype`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
>>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
>>> # case 1: each parameter is specified.
>>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
>>> print(nonzero_num)
[[3]]
>>> # case 2: all parameters use their default values.
>>> nonzero_num = ops.count_nonzero(x=x)
>>> print(nonzero_num)
3
>>> # case 3: axis is specified as 0.
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
>>> print(nonzero_num)
[1 2 0]
>>> # case 4: axis is specified as 1.
>>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
>>> print(nonzero_num)
[1 2]
>>> # case 5: keep_dims is specified as True.
>>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
>>> print(nonzero_num)
[[3]]
>>> # case 6: both keep_dims and axis are specified.
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
>>> print(nonzero_num)
[[1 2 0]]
"""
const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
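All six cases map directly onto `np.count_nonzero`, which supports the same axis/keep_dims combinations; a NumPy-only cross-check:

import numpy as np

x = np.array([[0, 1, 0], [1, 1, 0]], dtype=np.float32)
print(np.count_nonzero(x))                              # 3, as in case 2
print(np.count_nonzero(x, axis=0))                      # [1 2 0], as in case 3
print(np.count_nonzero(x, axis=(0, 1), keepdims=True))  # [[3]], as in case 1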
@ -239,7 +261,7 @@ def tensor_dot(x1, x2, axes):
Examples:
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
>>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
>>> output = ops.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
>>> print(output)
[[2. 2. 2.]
[2. 2. 2.]
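The example corresponds one-to-one with `np.tensordot`; a NumPy-only cross-check:

import numpy as np

x1 = np.ones((1, 2, 3), dtype=np.float32)
x2 = np.ones((3, 1, 2), dtype=np.float32)
# contract axes (0, 1) of x1 against axes (1, 2) of x2
print(np.tensordot(x1, x2, axes=((0, 1), (1, 2))))  # a 3x3 matrix of 2s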
@ -303,7 +325,9 @@ def dot(x1, x2):
Inputs:
- **x1** (Tensor) - First tensor in Dot op with datatype float16 or float32
The rank must be greater than or equal to 2.
- **x2** (Tensor) - Second tensor in Dot op with datatype float16 or float32
The rank must be greater than or equal to 2.
Outputs:
Tensor, dot product of x1 and x2.
@ -319,10 +343,48 @@ def dot(x1, x2):
Examples:
>>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
>>> output = C.dot(input_x1, input_x2)
>>> output = ops.dot(input_x1, input_x2)
>>> print(output)
[[[3. 3.]]
[[3. 3.]]]
>>> print(output.shape)
(2, 1, 2)
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
>>> output = ops.dot(input_x1, input_x2)
>>> print(output)
[[[[3. 3.]]
[[3. 3.]]]]
>>> print(output.shape)
(1, 2, 1, 2)
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> output = ops.dot(input_x1, input_x2)
>>> print(output)
[[[[3. 3.]
[3. 3.]]
[[3. 3.]
[3. 3.]]]]
>>> print(output.shape)
(1, 2, 2, 2)
>>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
>>> output = ops.dot(input_x1, input_x2)
>>> print(output)
[[[[[3. 3.]]
[[3. 3.]]]
[[[3. 3.]]
[[3. 3.]]]]
[[[[3. 3.]]
[[3. 3.]]]
[[[3. 3.]]
[[3. 3.]]]]
[[[[3. 3.]]
[[3. 3.]]]
[[[3. 3.]]
[[3. 3.]]]]]
>>> print(output.shape)
(3, 2, 2, 1, 2)
"""
shape_op = P.Shape()
reshape_op = P.Reshape()
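For the shapes in these examples, dot behaves like a tensordot that contracts the last axis of `x1` with the second-to-last axis of `x2`; a NumPy sketch under that assumption (hypothetical helper name):

import numpy as np

def dot_np(x1, x2):
    return np.tensordot(x1, x2, axes=(x1.ndim - 1, x2.ndim - 2))

print(dot_np(np.ones((2, 3)), np.ones((1, 3, 2))).shape)        # (2, 1, 2)
print(dot_np(np.ones((3, 2, 3)), np.ones((2, 1, 3, 2))).shape)  # (3, 2, 2, 1, 2)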
@ -459,16 +521,18 @@ def batch_dot(x1, x2, axes=None):
output = x1[batch, :] * x2[batch, :]
Inputs:
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
be same as x1's.
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32; the rank of `x1` must be greater
than or equal to 2.
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
be the same as `x1`'s, and the rank of `x2` must be greater than or equal to 2.
- **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
`a` input shape and last N dims from `b` input shape in order as axes for each respectively.
`a` input shape and the last N dimensions from `b` input shape, in order, as the axes for each respectively.
Outputs:
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)
Tensor, batch dot product of `x1` and `x2`. For example: the output shape for input `x1`
with shape (batch, d1, axes, d2) and `x2` with shape (batch, d3, axes, d4) is (batch, d1, d2, d3, d4),
where d1, d2, d3 and d4 are arbitrary sizes.
Raises:
TypeError: If the types of `x1` and `x2` are not the same.
@ -485,15 +549,35 @@ def batch_dot(x1, x2, axes=None):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (-1, -2)
>>> output = C.batch_dot(input_x1, input_x2, axes)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[[3. 3.]
[3. 3.]]
[[3. 3.]
[3. 3.]]]
>>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (1, 2)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[2. 2. 2.]
[2. 2. 2.]]
>>> print(output.shape)
(2, 3)
>>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(6, 2, 3, 5, 8)
>>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(2, 2, 5, 5)
"""
transpose_op = P.Transpose()
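With default axes, the shapes above are consistent with a per-batch tensordot over the last axis of `x1` and the second-to-last axis of `x2`; a NumPy sketch under that assumption (hypothetical helper name):

import numpy as np

def batch_dot_np(x1, x2):
    # iterate over the shared batch dimension, contracting per batch element
    return np.stack([np.tensordot(a, b, axes=(a.ndim - 1, b.ndim - 2))
                     for a, b in zip(x1, x2)])

print(batch_dot_np(np.ones((6, 2, 3, 4)), np.ones((6, 5, 4, 8))).shape)
# (6, 2, 3, 5, 8), matching the example above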
@ -637,7 +721,11 @@ def matmul(x1, x2, dtype=None):
Args:
x1 (Tensor): Input tensor, scalar not allowed.
The last dimension of `x1` must be the same size as the second last dimension of `x2`,
and the shapes of `x1` and `x2` must be broadcastable.
x2 (Tensor): Input tensor, scalar not allowed.
The last dimension of `x1` must be the same size as the second last dimension of `x2`,
and the shapes of `x1` and `x2` must be broadcastable.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
@ -648,11 +736,13 @@ def matmul(x1, x2, dtype=None):
Raises:
ValueError: If the last dimension of `x1` is not the same size as the
second-to-last dimension of `x2`, or if a scalar value is passed in.
ValueError: If the shapes of `x1` and `x2` cannot be broadcast together.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case 1: the broadcast mechanism is applied
>>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
>>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
>>> output = ops.matmul(x1, x2)
@ -663,6 +753,16 @@ def matmul(x1, x2, dtype=None):
[[ 430. 484. 538. 592. 646.]
[ 550. 620. 690. 760. 830.]
[ 670. 756. 842. 928. 1014.]]]
>>> print(output.shape)
(2, 3, 5)
>>> # case 2: the rank of `x2` is 1
>>> x1 = Tensor(np.ones([1, 2]), mindspore.float32)
>>> x2 = Tensor(np.ones([2,]), mindspore.float32)
>>> output = ops.matmul(x1, x2)
>>> print(output)
[2.]
>>> print(output.shape)
(1,)
"""
# performs type promotion
dtype1 = F.dtype(x1)
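The broadcast behavior in case 1 mirrors `np.matmul`; a NumPy-only cross-check:

import numpy as np

x1 = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(np.float32)
# the 2-D x2 is broadcast against the leading batch dimension of x1
x2 = np.arange(4 * 5).reshape(4, 5).astype(np.float32)
print(np.matmul(x1, x2).shape)  # (2, 3, 5)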


@ -34,6 +34,7 @@ def normal(shape, mean, stddev, seed=None):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak,
with data type in [int8, int16, int32, int64, float16, float32].
stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0,
@ -50,17 +51,27 @@ def normal(shape, mean, stddev, seed=None):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops.composite as C
>>> from mindspore.common import dtype as mstype
>>> shape = (3, 1, 2)
>>> mean = Tensor(np.array([[3, 4], [5, 6]]), mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> output = C.normal(shape, mean, stddev, seed=5)
>>> mean = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> stddev = Tensor(1.0, mindspore.float32)
>>> output = ops.normal(shape, mean, stddev, seed=5)
>>> result = output.shape
>>> print(result)
(3, 2, 2)
>>> shape = (3, 1, 3)
>>> mean = Tensor(np.array([[3, 4, 3], [3, 5, 6]]), mindspore.float32)
>>> stddev = Tensor(1.0, mindspore.float32)
>>> output = ops.normal(shape, mean, stddev, seed=5)
>>> result = output.shape
>>> print(result)
(3, 2, 3)
>>> shape = (3, 1, 3)
>>> mean = Tensor(np.array([[1, 2, 3], [3, 4, 3], [3, 5, 6]]), mindspore.float32)
>>> stddev = Tensor(1.0, mindspore.float32)
>>> output = ops.normal(shape, mean, stddev, seed=5)
>>> result = output.shape
>>> print(result)
(3, 3, 3)
"""
mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
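The output shape in each case is the broadcast of `shape` with the parameter shapes, which `np.broadcast_shapes` (NumPy >= 1.20) can verify; an editorial cross-check:

import numpy as np

print(np.broadcast_shapes((3, 1, 2), (2, 2)))  # (3, 2, 2), as in case 1
print(np.broadcast_shapes((3, 1, 3), (3, 3)))  # (3, 3, 3), as in case 3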
@ -83,6 +94,7 @@ def laplace(shape, mean, lambda_param, seed=None):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
With float32 data type.
lambda_param (Tensor): The parameter used for controlling the variance of this random distribution. The
@ -91,7 +103,7 @@ def laplace(shape, mean, lambda_param, seed=None):
Default: None, which will be treated as 0.
Returns:
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and lambda_param.
Tensor. The shape should be the broadcasted shape of input `shape` and shapes of `mean` and `lambda_param`.
The dtype is float32.
Supported Platforms:
@ -100,11 +112,11 @@ def laplace(shape, mean, lambda_param, seed=None):
Examples:
>>> from mindspore import Tensor
>>> from mindspore.ops import composite as C
>>> import mindspore.common.dtype as mstype
>>> import mindspore
>>> shape = (2, 3)
>>> mean = Tensor(1.0, mstype.float32)
>>> lambda_param = Tensor(1.0, mstype.float32)
>>> output = C.laplace(shape, mean, lambda_param, seed=5)
>>> mean = Tensor(1.0, mindspore.float32)
>>> lambda_param = Tensor(1.0, mindspore.float32)
>>> output = ops.laplace(shape, mean, lambda_param, seed=5)
>>> print(output.shape)
(2, 3)
"""
@ -128,6 +140,8 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means any number of additional dimensions,
and the length of :math:`(N,*)` should be less than 8 in the broadcast operation.
minval (Tensor): The distribution parameter `a`.
It defines the minimum possible generated value, with int32 or float32 data type.
If dtype is int32, only one number is allowed.
@ -138,7 +152,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
must be non-negative. Default: None, which will be treated as 0.
dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete
uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
supports these two data types. Default: mstype.float32.
supports these two data types. Default: mindspore.float32.
Returns:
Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
@ -156,20 +170,17 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
``Ascend`` ``GPU``
Examples:
>>> import mindspore.ops.composite as C
>>> from mindspore.common import dtype as mstype
>>> from mindspore import Tensor
>>> # For discrete uniform distribution, only one number is allowed for both minval and maxval:
>>> shape = (4, 2)
>>> minval = Tensor(1, mstype.int32)
>>> maxval = Tensor(2, mstype.int32)
>>> output = C.uniform(shape, minval, maxval, seed=5, dtype=mstype.int32)
>>> minval = Tensor(1, mindspore.int32)
>>> maxval = Tensor(2, mindspore.int32)
>>> output = ops.uniform(shape, minval, maxval, seed=5, dtype=mindspore.int32)
>>>
>>> # For continuous uniform distribution, minval and maxval can be multi-dimensional:
>>> shape = (3, 1, 2)
>>> minval = Tensor(np.array([[3, 4], [5, 6]]), mstype.float32)
>>> maxval = Tensor([8.0, 10.0], mstype.float32)
>>> output = C.uniform(shape, minval, maxval, seed=5)
>>> minval = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> maxval = Tensor([8.0, 10.0], mindspore.float32)
>>> output = ops.uniform(shape, minval, maxval, seed=5)
>>> result = output.shape
>>> print(result)
(3, 2, 2)
@ -196,13 +207,14 @@ def gamma(shape, alpha, beta, seed=None):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type.
beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type.
seed (int): Seed is used as entropy source for the random number engines to generate
pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.
Returns:
Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes
Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
of `alpha` and `beta`.
The dtype is float32.
@ -216,16 +228,48 @@ def gamma(shape, alpha, beta, seed=None):
``Ascend``
Examples:
>>> from mindspore import Tensor
>>> import mindspore.ops.composite as C
>>> from mindspore.common import dtype as mstype
>>> # case 1: alpha_shape is (2, 2)
>>> shape = (3, 1, 2)
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mstype.float32)
>>> beta = Tensor(np.array([1.0]), mstype.float32)
>>> output = C.gamma(shape, alpha, beta, seed=5)
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> beta = Tensor(np.array([1.0]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> result = output.shape
>>> print(result)
(3, 2, 2)
>>> # case 2: alpha_shape is (2, 3), so shape is (3, 1, 3)
>>> shape = (3, 1, 3)
>>> alpha = Tensor(np.array([[1, 3, 4], [2, 5, 6]]), mindspore.float32)
>>> beta = Tensor(np.array([1.0]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> result = output.shape
>>> print(result)
(3, 2, 3)
>>> # case 3: beta_shape is (2,), the output is different.
>>> shape = (3, 1, 2)
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> beta = Tensor(np.array([1.0, 2]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> print(output)
[[[ 2.2132034 5.8855834]
[ 3.3981476 7.5805717]]
[[ 3.3981476 7.5805717]
[ 3.7190282 19.941492]]
[[ 2.9512358 2.5969937]
[ 3.786061 5.160872]]]
>>> # case 4: beta_shape is (2, 1), the output is different.
>>> shape = (3, 1, 2)
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> print(output)
[[[ 5.6085486 7.8280783]
[15.97684 16.116285]]
[[ 1.8347423 1.713663]
[ 3.2434065 15.667398]]
[[ 4.2922077 7.3365674]
[ 5.3876944 13.159832]]]
"""
seed1, seed2 = _get_seed(seed, "gamma")
random_gamma = P.Gamma(seed1, seed2)
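The shape rule (not the values) can be cross-checked with NumPy's gamma sampler; note NumPy is parameterized by scale = 1/beta, so only the broadcasting is being verified here:

import numpy as np

rng = np.random.default_rng(5)
alpha = np.array([[3., 4.], [5., 6.]])
beta = np.array([1.0])
# the (2, 2) alpha broadcasts into the requested (3, 2, 2) output
out = rng.gamma(shape=alpha, scale=1.0 / beta, size=(3, 2, 2))
print(out.shape)  # (3, 2, 2), as in case 1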
@ -243,6 +287,7 @@ def poisson(shape, mean, seed=None):
Args:
shape (tuple): The shape of random tensor to be generated.
The format is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
and must be non-negative. Default: None, which will be treated as 0.
@ -260,16 +305,20 @@ def poisson(shape, mean, seed=None):
``Ascend``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.ops.composite as C
>>> from mindspore.common import dtype as mstype
>>> # case 1: It can be broadcast.
>>> shape = (4, 1)
>>> mean = Tensor(np.array([5.0, 10.0]), mstype.float32)
>>> output = C.poisson(shape, mean, seed=5)
>>> mean = Tensor(np.array([5.0, 10.0]), mindspore.float32)
>>> output = ops.poisson(shape, mean, seed=5)
>>> result = output.shape
>>> print(result)
(4, 2)
>>> # case 2: It cannot be broadcast. It is recommended to use the same shape.
>>> shape = (2, 2)
>>> mean = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
>>> output = ops.poisson(shape, mean, seed=5)
>>> result = output.shape
>>> print(result)
(2, 2)
"""
seed1, seed2 = _get_seed(seed, "poisson")
random_poisson = P.Poisson(seed1, seed2)
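Case 1's broadcast behavior matches NumPy's Poisson sampler (the sampled values will differ; this only verifies the shape):

import numpy as np

rng = np.random.default_rng(5)
out = rng.poisson(lam=np.array([5.0, 10.0]), size=(4, 2))  # lam broadcasts over the rows
print(out.shape)  # (4, 2)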
@ -287,7 +336,7 @@ def multinomial(inputs, num_sample, replacement=True, seed=None):
but must be non-negative, finite and have a non-zero sum.
Args:
inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
x (Tensor): The input tensor containing probabilities; it must be 1- or 2-dimensional, with
float32 data type.
num_sample (int): Number of samples to draw.
replacement (bool, optional): Whether to draw with replacement or not, default True.
@ -299,7 +348,7 @@ def multinomial(inputs, num_sample, replacement=True, seed=None):
The dtype is float32.
Raises:
TypeError: If `inputs` is not a Tensor whose dtype is not float32.
TypeError: If `x` is not a Tensor or its dtype is not float32.
TypeError: If `num_sample` is not an int.
TypeError: If `seed` is neither an int nor None.
@ -307,13 +356,27 @@ def multinomial(inputs, num_sample, replacement=True, seed=None):
``GPU``
Examples:
>>> from mindspore import Tensor
>>> import mindspore.ops.composite as C
>>> from mindspore.common import dtype as mstype
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
>>> output = C.multinomial(input, 2, True)
>>> # case 1: The output is random, and the length of the output is the same as num_sample.
>>> x = Tensor([0, 9, 4, 0], mindspore.float32)
>>> output = ops.multinomial(x, 2)
>>> # print(output)
>>> # [1 2] or [2 1]
>>> # Over repeated runs, [1 2] appears more often than [2 1], because the value
>>> # corresponding to index 1 is larger than the value at index 2.
>>> print(len(output))
2
>>> # case 2: The output is random, and the length of the output is the same as num_sample.
>>> # replacement is True by default.
>>> # Indices whose probability value is 0 (here 0 and 3) are never returned.
>>> x = Tensor([0, 9, 4, 0], mindspore.float32)
>>> output = ops.multinomial(x, 4)
>>> print(output)
[1 1]
[1 1 2 1]
>>> # case 3: num_sample == len(x) == 4, and replacement is True, so the same element can be drawn repeatedly.
>>> x = Tensor([0, 9, 4, 0], mindspore.float32)
>>> output = ops.multinomial(x, 4, True)
>>> print(output)
[1 1 2 2]
"""
shape = P.Shape()
reshape = P.Reshape()
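The sampling behavior the examples describe (indices drawn in proportion to unnormalized weights, zero-weight indices never returned) can be sketched with NumPy, where rng.choice with replace=True plays the role of replacement=True; an editorial sketch, not the MindSpore implementation:

import numpy as np

rng = np.random.default_rng(5)
p = np.array([0, 9, 4, 0], dtype=np.float64)
# normalize the weights into probabilities before sampling
out = rng.choice(len(p), size=4, replace=True, p=p / p.sum())
print(out)  # indices 0 and 3 never appear because their probability is 0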