support choice_with_mask functional api

This commit is contained in:
qinzheng 2022-09-14 17:42:57 +08:00
parent a6e607f069
commit 554f1885a1
5 changed files with 76 additions and 25 deletions

View File

@ -376,6 +376,7 @@ Tensor创建
:nosignatures:
:template: classtemplate.rst
mindspore.ops.choice_with_mask
mindspore.ops.gamma
mindspore.ops.laplace
mindspore.ops.multinomial

View File

@ -0,0 +1,26 @@
mindspore.ops.choice_with_mask
=====================================
.. py:function:: mindspore.ops.choice_with_mask(input_x, count=256, seed=0, seed2=0)
对输入进行随机取样,返回取样索引和掩码。
输入必须是秩不小于1的Tensor。如果其秩大于等于2则第一个维度指定样本数。
索引Tensor为取样的索引掩码Tensor表示索引Tensor中的哪些元素取值为True。
参数:
- **input_x** (Tensor[bool]) - 输入Tensorbool类型。秩必须大于等于1且小于等于5。
- **count** (int) - 取样数量必须大于0。默认值256。
- **seed** (int) - 随机种子。默认值0。
- **seed2** (int) - 随机种子2。默认值0。
返回:
两个Tensor第一个为索引另一个为掩码。
- **index** (Tensor) - 2维Tensor。
- **mask** (Tensor) - 1维Tensor。
异常:
- **TypeError** - `count` 不是int类型。
- **TypeError** - `seed``seed2` 不是int类型。
- **TypeError** - `input_x` 不是Tensor。

View File

@ -367,6 +367,7 @@ from .random_func import (
uniform_candidate_sampler,
random_poisson,
random_shuffle,
choice_with_mask
)
from .grad import (
grad_func,

View File

@ -23,7 +23,7 @@ from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_util
from ...common import dtype as mstype
from ...common.seed import _get_graph_seed
from ...common.tensor import Tensor
from ..operations.random_ops import RandomShuffle
from ..operations.random_ops import RandomShuffle, RandomChoiceWithMask
from .._primitive_cache import _get_cache_prim
from .._utils import get_broadcast_shape
@ -422,6 +422,51 @@ def random_shuffle(x, seed=0, seed2=0):
return output
def choice_with_mask(input_x, count=256, seed=0, seed2=0):
"""
Generates a random sample as index tensor with a mask tensor from a given tensor.
The input_x must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
the first dimension specifies the number of samples.
The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
sample, while the mask tensor denotes which elements in the index tensor are valid.
Args:
input_x (Tensor): The input tensor.
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.
Returns:
Two tensors, the first one is the index tensor and the other one is the mask tensor.
- **index** (Tensor) - The output shape is 2-D.
- **mask** (Tensor) - The output shape is 1-D.
Raises:
TypeError: If `count` is not an int.
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool))
>>> output_y, output_mask = ops.choice_with_mask(input_x)
>>> result = output_y.shape
>>> print(result)
(256, 2)
>>> result = output_mask.shape
>>> print(result)
(256,)
"""
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed, seed2=seed2)
output = choice_with_mask_(input_x)
return output
__all__ = [
'standard_laplace',
'random_categorical',
@ -431,5 +476,6 @@ __all__ = [
'uniform_candidate_sampler',
'random_poisson',
'random_shuffle',
'choice_with_mask'
]
__all__.sort()

View File

@ -638,30 +638,7 @@ class RandomChoiceWithMask(Primitive):
"""
Generates a random sample as index tensor with a mask tensor from a given tensor.
The input must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
the first dimension specifies the number of samples.
The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
sample, while the mask tensor denotes which elements in the index tensor are valid.
Args:
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
seed (int): Random seed. Default: 0.
seed2 (int): Random seed2. Default: 0.
Inputs:
- **input_x** (Tensor[bool]) - The input tensor.
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
Outputs:
Two tensors, the first one is the index tensor and the other one is the mask tensor.
- **index** (Tensor) - The output shape is 2-D.
- **mask** (Tensor) - The output shape is 1-D.
Raises:
TypeError: If `count` is not an int.
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `input_x` is not a Tensor.
Refer to :func:'mindspore.ops.choice_with_mask' for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``