!15 Correct the comments for `RandomChoiceWithMask` op.
Merge pull request !15 from seatea/randomchoicewithmask-doc
This commit is contained in:
commit
2753aa8768
|
@ -25,20 +25,23 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Generates a random samply as index tensor with a mask tensor from a given tensor.
|
Generates a random samply as index tensor with a mask tensor from a given tensor.
|
||||||
|
|
||||||
The input must be a tensor of rank >= 2, the first dimension specify the number of sample.
|
The input must be a tensor of rank >= 1. If its rank >= 2, the first dimension specify the number of sample.
|
||||||
The index tensor and the mask tensor have the same and fixed shape. The index tensor denotes the index
|
The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero
|
||||||
of the nonzero sample, while the mask tensor denotes which element in the index tensor are valid.
|
sample, while the mask tensor denotes which elements in the index tensor are valid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
count (int): Number of items expected to get. Default: 256.
|
count (int): Number of items expected to get and the number should be greater than 0. Default: 256.
|
||||||
seed (int): Random seed.
|
seed (int): Random seed. Default: 0.
|
||||||
seed2 (int): Random seed2.
|
seed2 (int): Random seed2. Default: 0.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - The input tensor.
|
- **input_x** (Tensor[bool]) - The input tensor.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tuple, two tensors, the first one is the index tensor and the other one is the mask tensor.
|
Two tensors, the first one is the index tensor and the other one is the mask tensor.
|
||||||
|
|
||||||
|
- **index** (Tensor) - The output has shape between 2-D and 5-D.
|
||||||
|
- **mask** (Tensor) - The output has shape 1-D.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> rnd_choice_mask = RandomChoiceWithMask()
|
>>> rnd_choice_mask = RandomChoiceWithMask()
|
||||||
|
|
Loading…
Reference in New Issue