!8804 Prevent int64 from converting to int32 in UniformCandidateSampler

From: @TFbunny
Reviewed-by: @robingrosman
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-21 05:04:15 +08:00 committed by Gitee
commit 0b992c077b
1 changed files with 3 additions and 1 deletions

View File

@ -541,7 +541,8 @@ class UniformCandidateSampler(PrimitiveWithInfer):
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique. unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes, must be non-negative. range_max (int): The number of possible classes, must be non-negative.
seed (int): Random seed, must be non-negative. Default: 0. seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
seed will be replaced with a randomly generated value. Default: 0.
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
Inputs: Inputs:
@ -580,6 +581,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
self.num_sampled = num_sampled self.num_sampled = num_sampled
def infer_dtype(self, true_classes_type): def infer_dtype(self, true_classes_type):
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name)
return (true_classes_type, mstype.float32, mstype.float32) return (true_classes_type, mstype.float32, mstype.float32)
def infer_shape(self, true_classes_shape): def infer_shape(self, true_classes_shape):