!40173 fix gumbel_softmax Rel error in graphmode

Merge pull request !40173 from TuDouNi/master
Author: i-robot (committed by Gitee)
Date: 2022-08-11 07:58:39 +00:00
Commit: b3487206fd
GPG Key ID: 173E9B9CA92EEF8F (no known key found for this signature in database)
1 changed file with 4 additions and 4 deletions


@@ -4702,8 +4702,8 @@ def _check_positive_float(arg_value, arg_name, cls_name):
 @constexpr
-def _check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
-    validator.check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name, prim_name)
+def _check_int_range(arg_value, lower_limit, upper_limit, arg_name=None, prim_name=None):
+    validator.check_int_range(arg_value, lower_limit, upper_limit, Rel.INC_LEFT, arg_name, prim_name)
 def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
@@ -4750,9 +4750,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
     _check_attr_dtype("dim", dim, [int], "gumbel_softmax")
     _check_positive_float(tau, "tau", "gumbel_softmax")
     if hard:
-        _check_int_range(dim, -1, len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
+        _check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax")
     else:
-        _check_int_range(dim, -len(logits), len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
+        _check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax")
     shape_op = _get_cache_prim(P.Shape)()
     cast_op = _get_cache_prim(P.Cast)()
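
The patch removes the Rel enum from the @constexpr helper's signature and fixes Rel.INC_LEFT inside the validator call instead, so the enum no longer has to be passed as an argument when gumbel_softmax is checked under graph mode. Below is a minimal sketch of the kind of graph-mode call this fix targets; the GumbelNet cell name, the input shape, and the tau/hard/dim values are illustrative assumptions, not part of the patch.

# Minimal sketch (not part of the patch): calling gumbel_softmax inside a
# graph-mode network, the scenario the fix is aimed at. Names and shapes
# here are illustrative assumptions.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.GRAPH_MODE)

class GumbelNet(nn.Cell):
    def construct(self, logits):
        # hard=True returns one-hot samples with a straight-through gradient
        return ops.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

logits = Tensor(np.random.randn(4, 10), ms.float32)
out = GumbelNet()(logits)
print(out.shape)  # (4, 10)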