From 840280e7843ecaf4014bda8cbe4fdf76aa75b614 Mon Sep 17 00:00:00 2001 From: seatea Date: Mon, 30 Mar 2020 12:18:02 +0800 Subject: [PATCH] Correct the comments for `RandomChoiceWithMask` op. --- mindspore/ops/operations/random_ops.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index c8f59e898d2..9ef5b301f97 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -25,20 +25,23 @@ class RandomChoiceWithMask(PrimitiveWithInfer): """ Generates a random samply as index tensor with a mask tensor from a given tensor. - The input must be a tensor of rank >= 2, the first dimension specify the number of sample. - The index tensor and the mask tensor have the same and fixed shape. The index tensor denotes the index - of the nonzero sample, while the mask tensor denotes which element in the index tensor are valid. + The input must be a tensor of rank >= 1. If its rank >= 2, the first dimension specify the number of sample. + The index tensor and the mask tensor have the fixed shapes. The index tensor denotes the index of the nonzero + sample, while the mask tensor denotes which elements in the index tensor are valid. Args: - count (int): Number of items expected to get. Default: 256. - seed (int): Random seed. - seed2 (int): Random seed2. + count (int): Number of items expected to get and the number should be greater than 0. Default: 256. + seed (int): Random seed. Default: 0. + seed2 (int): Random seed2. Default: 0. Inputs: - - **input_x** (Tensor) - The input tensor. + - **input_x** (Tensor[bool]) - The input tensor. Outputs: - Tuple, two tensors, the first one is the index tensor and the other one is the mask tensor. + Two tensors, the first one is the index tensor and the other one is the mask tensor. + + - **index** (Tensor) - The output has shape between 2-D and 5-D. + - **mask** (Tensor) - The output has shape 1-D. Examples: >>> rnd_choice_mask = RandomChoiceWithMask()