forked from mindspore-Ecosystem/mindspore
!42630 support choice_with_mask functional api
Merge pull request !42630 from qinzheng/code_docs_random_mask
This commit is contained in:
commit
bea1722736
|
@ -377,6 +377,7 @@ Tensor创建
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.choice_with_mask
|
||||
mindspore.ops.gamma
|
||||
mindspore.ops.laplace
|
||||
mindspore.ops.multinomial
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
mindspore.ops.choice_with_mask
|
||||
=====================================
|
||||
|
||||
.. py:function:: mindspore.ops.choice_with_mask(input_x, count=256, seed=0, seed2=0)
|
||||
|
||||
对输入进行随机取样,返回取样索引和掩码。
|
||||
|
||||
输入必须是秩不小于1的Tensor。如果其秩大于等于2,则第一个维度指定样本数。
|
||||
索引Tensor为取样的索引,掩码Tensor表示索引Tensor中的哪些元素取值为True。
|
||||
|
||||
参数:
|
||||
- **input_x** (Tensor[bool]) - 输入Tensor,bool类型。秩必须大于等于1且小于等于5。
|
||||
- **count** (int) - 取样数量,必须大于0。默认值:256。
|
||||
- **seed** (int) - 随机种子。默认值:0。
|
||||
- **seed2** (int) - 随机种子2。默认值:0。
|
||||
|
||||
返回:
|
||||
两个Tensor,第一个为索引,另一个为掩码。
|
||||
|
||||
- **index** (Tensor) - 2维Tensor。
|
||||
- **mask** (Tensor) - 1维Tensor。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `count` 不是int类型。
|
||||
- **TypeError** - `seed` 或 `seed2` 不是int类型。
|
||||
- **TypeError** - `input_x` 不是Tensor。
|
|
@ -368,6 +368,7 @@ from .random_func import (
|
|||
uniform_candidate_sampler,
|
||||
random_poisson,
|
||||
random_shuffle,
|
||||
choice_with_mask
|
||||
)
|
||||
from .grad import (
|
||||
grad_func,
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_util
|
|||
from ...common import dtype as mstype
|
||||
from ...common.seed import _get_graph_seed
|
||||
from ...common.tensor import Tensor
|
||||
from ..operations.random_ops import RandomShuffle
|
||||
from ..operations.random_ops import RandomShuffle, RandomChoiceWithMask
|
||||
from .._primitive_cache import _get_cache_prim
|
||||
from .._utils import get_broadcast_shape
|
||||
|
||||
|
@ -422,6 +422,51 @@ def random_shuffle(x, seed=0, seed2=0):
|
|||
return output
|
||||
|
||||
|
||||
def choice_with_mask(input_x, count=256, seed=0, seed2=0):
|
||||
"""
|
||||
Generates a random sample as index tensor with a mask tensor from a given tensor.
|
||||
|
||||
The input_x must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
|
||||
the first dimension specifies the number of samples.
|
||||
The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
|
||||
sample, while the mask tensor denotes which elements in the index tensor are valid.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The input tensor.
|
||||
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
|
||||
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
|
||||
seed (int): Random seed. Default: 0.
|
||||
seed2 (int): Random seed2. Default: 0.
|
||||
|
||||
Returns:
|
||||
Two tensors, the first one is the index tensor and the other one is the mask tensor.
|
||||
|
||||
- **index** (Tensor) - The output shape is 2-D.
|
||||
- **mask** (Tensor) - The output shape is 1-D.
|
||||
|
||||
Raises:
|
||||
TypeError: If `count` is not an int.
|
||||
TypeError: If neither `seed` nor `seed2` is an int.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool))
|
||||
>>> output_y, output_mask = ops.choice_with_mask(input_x)
|
||||
>>> result = output_y.shape
|
||||
>>> print(result)
|
||||
(256, 2)
|
||||
>>> result = output_mask.shape
|
||||
>>> print(result)
|
||||
(256,)
|
||||
"""
|
||||
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed, seed2=seed2)
|
||||
output = choice_with_mask_(input_x)
|
||||
return output
|
||||
|
||||
|
||||
__all__ = [
|
||||
'standard_laplace',
|
||||
'random_categorical',
|
||||
|
@ -431,5 +476,6 @@ __all__ = [
|
|||
'uniform_candidate_sampler',
|
||||
'random_poisson',
|
||||
'random_shuffle',
|
||||
'choice_with_mask'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -638,30 +638,7 @@ class RandomChoiceWithMask(Primitive):
|
|||
"""
|
||||
Generates a random sample as index tensor with a mask tensor from a given tensor.
|
||||
|
||||
The input must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
|
||||
the first dimension specifies the number of samples.
|
||||
The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
|
||||
sample, while the mask tensor denotes which elements in the index tensor are valid.
|
||||
|
||||
Args:
|
||||
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
|
||||
seed (int): Random seed. Default: 0.
|
||||
seed2 (int): Random seed2. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[bool]) - The input tensor.
|
||||
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
|
||||
|
||||
Outputs:
|
||||
Two tensors, the first one is the index tensor and the other one is the mask tensor.
|
||||
|
||||
- **index** (Tensor) - The output shape is 2-D.
|
||||
- **mask** (Tensor) - The output shape is 1-D.
|
||||
|
||||
Raises:
|
||||
TypeError: If `count` is not an int.
|
||||
TypeError: If neither `seed` nor `seed2` is an int.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
Refer to :func:'mindspore.ops.choice_with_mask' for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue