!40332 gumbel softmax fix bug

Merge pull request !40332 from TuDouNi/master
This commit is contained in:
i-robot 2022-08-16 01:15:10 +00:00 committed by Gitee
commit 743ea9f307
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 2 deletions

View File

@ -4749,9 +4749,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
_check_positive_float(tau, "tau", "gumbel_softmax")
if hard:
_check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax")
_check_int_range(dim, -1, len(logits.shape), 'dim', "gumbel_softmax")
else:
_check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax")
_check_int_range(dim, -len(logits.shape), len(logits.shape), 'dim', "gumbel_softmax")
shape_op = _get_cache_prim(P.Shape)()
cast_op = _get_cache_prim(P.Cast)()