!8804 Prevent int64 from converting to int32 in UniformCandidateSampler
From: @TFbunny Reviewed-by: @robingrosman Signed-off-by:
This commit is contained in:
commit
0b992c077b
|
@ -541,7 +541,8 @@ class UniformCandidateSampler(PrimitiveWithInfer):
|
|||
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
|
||||
unique (bool): Whether all sampled classes in a batch are unique.
|
||||
range_max (int): The number of possible classes, must be non-negative.
|
||||
seed (int): Random seed, must be non-negative. Default: 0.
|
||||
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
|
||||
seed will be replaced with a randomly generated value. Default: 0.
|
||||
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
|
||||
|
||||
Inputs:
|
||||
|
@ -580,6 +581,7 @@ class UniformCandidateSampler(PrimitiveWithInfer):
|
|||
self.num_sampled = num_sampled
|
||||
|
||||
def infer_dtype(self, true_classes_type):
|
||||
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name)
|
||||
return (true_classes_type, mstype.float32, mstype.float32)
|
||||
|
||||
def infer_shape(self, true_classes_shape):
|
||||
|
|
Loading…
Reference in New Issue