!40332 gumbel softmax fix bug
Merge pull request !40332 from TuDouNi/master
This commit is contained in:
commit
743ea9f307
|
@ -4749,9 +4749,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
|||
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
|
||||
_check_positive_float(tau, "tau", "gumbel_softmax")
|
||||
if hard:
|
||||
_check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax")
|
||||
_check_int_range(dim, -1, len(logits.shape), 'dim', "gumbel_softmax")
|
||||
else:
|
||||
_check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax")
|
||||
_check_int_range(dim, -len(logits.shape), len(logits.shape), 'dim', "gumbel_softmax")
|
||||
|
||||
shape_op = _get_cache_prim(P.Shape)()
|
||||
cast_op = _get_cache_prim(P.Cast)()
|
||||
|
|
Loading…
Reference in New Issue