gumbel softmax

This commit is contained in:
ttudu 2022-08-12 15:09:09 +08:00
parent 2f1fc1ec2c
commit 4b5d455cc8
1 changed files with 2 additions and 2 deletions

View File

@ -4749,9 +4749,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
_check_positive_float(tau, "tau", "gumbel_softmax")
if hard:
_check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax")
_check_int_range(dim, -1, len(logits.shape), 'dim', "gumbel_softmax")
else:
_check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax")
_check_int_range(dim, -len(logits.shape), len(logits.shape), 'dim', "gumbel_softmax")
shape_op = _get_cache_prim(P.Shape)()
cast_op = _get_cache_prim(P.Cast)()