diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index 654f0142a97..a609d5b8840 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -4749,9 +4749,9 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1): _check_attr_dtype("dim", dim, [int], "gumbel_softmax") _check_positive_float(tau, "tau", "gumbel_softmax") if hard: - _check_int_range(dim, -1, len(logits), 'dim', "gumbel_softmax") + _check_int_range(dim, -1, len(logits.shape), 'dim', "gumbel_softmax") else: - _check_int_range(dim, -len(logits), len(logits), 'dim', "gumbel_softmax") + _check_int_range(dim, -len(logits.shape), len(logits.shape), 'dim', "gumbel_softmax") shape_op = _get_cache_prim(P.Shape)() cast_op = _get_cache_prim(P.Cast)()