diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 56a3537923c..5ce2830113f 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3423,7 +3423,7 @@ class ComputeAccidentalHits(PrimitiveWithCheck): - **true_classes** (Tensor) - The target classes. With data type of int32 or int64 and shape [batch_size, num_true]. - **sampled_candidates** (Tensor) - The sampled_candidates output of CandidateSampler, - with shape [num_sampled] and the same type as true_classes. + with data type of int32 or int64 and shape [num_sampled]. Outputs: Tuple of 3 Tensors. diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 3a477da37c9..8fbcd355bd8 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -602,9 +602,10 @@ class LogUniformCandidateSampler(PrimitiveWithInfer): Args: num_true (int): The number of target classes per training example. Default: 1. num_sampled (int): The number of classes to randomly sample. Default: 5. - unique (bool): Determines whether sample with rejection. If unique is True, - all sampled classes in a batch are unique. Default: True. - range_max (int): The number of possible classes. Default: 5. + unique (bool): Determines whether sample with rejection. If `unique` is True, + all sampled classes in a batch are unique. Default: True. + range_max (int): The number of possible classes. When `unique` is True, + `range_max` must be greater than or equal to `num_sampled`. Default: 5. seed (int): Random seed, must be non-negative. Inputs: @@ -644,6 +645,7 @@ class LogUniformCandidateSampler(PrimitiveWithInfer): Validator.check_value_type("seed", seed, [int], self.name) self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name) self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name) + Validator.check_number("range_max", range_max, 1, Rel.GE, self.name) if unique: Validator.check("range_max", range_max, "num_sampled", num_sampled, Rel.GE, self.name) self.range_max = range_max