forked from mindspore-Ecosystem/mindspore
!40173 fix gumbel_softmax Rel error in graphmode
Merge pull request !40173 from TuDouNi/master
This commit is contained in:
commit
b3487206fd
|
@ -4702,8 +4702,8 @@ def _check_positive_float(arg_value, arg_name, cls_name):
|
|||
|
||||
|
||||
@constexpr
|
||||
def _check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
||||
validator.check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name, prim_name)
|
||||
def _check_int_range(arg_value, lower_limit, upper_limit, arg_name=None, prim_name=None):
|
||||
validator.check_int_range(arg_value, lower_limit, upper_limit, Rel.INC_LEFT, arg_name, prim_name)
|
||||
|
||||
|
||||
def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
||||
|
@ -4750,9 +4750,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
|||
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
|
||||
_check_positive_float(tau, "tau", "gumbel_softmax")
|
||||
if hard:
|
||||
_check_int_range(dim, -1, len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
|
||||
_check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax")
|
||||
else:
|
||||
_check_int_range(dim, -len(logits), len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
|
||||
_check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax")
|
||||
|
||||
shape_op = _get_cache_prim(P.Shape)()
|
||||
cast_op = _get_cache_prim(P.Cast)()
|
||||
|
|
Loading…
Reference in New Issue