!39754 gumbelsoftmax add check

Merge pull request !39754 from TuDouNi/gumbelsoftmax2
This commit is contained in:
i-robot 2022-08-08 01:50:33 +00:00 committed by Gitee
commit 062f987e2a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 12 additions and 0 deletions

View File

@ -57,6 +57,7 @@ from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._c_expression import Tensor as Tensor_
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from .._primitive_cache import _get_cache_prim
@ -4586,6 +4587,11 @@ def _check_positive_float(arg_value, arg_name, cls_name):
validator.check_positive_float(arg_value, arg_name, cls_name)
@constexpr
def _check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
validator.check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name, prim_name)
def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
r"""
Returns the samples from the Gumbel-Softmax distribution and optionally discretizes. If `hard = True`, the returned
@ -4620,6 +4626,8 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
"""
if not isinstance(logits, (Tensor, Tensor_)):
raise TypeError("The input logits must be tensor")
if logits.shape == ():
raise ValueError("For gumbel_softmax, the 0-D input is not supported.")
dtype_op = _get_cache_prim(P.DType)()
logits_dtype = dtype_op(logits)
_check_input_dtype("logits", logits_dtype, [mstype.float16, mstype.float32], "gumbel_softmax")
@ -4627,6 +4635,10 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
_check_attr_dtype("hard", hard, [bool], "gumbel_softmax")
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
_check_positive_float(tau, "tau", "gumbel_softmax")
if hard:
_check_int_range(dim, -1, len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
else:
_check_int_range(dim, -len(logits), len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
shape_op = _get_cache_prim(P.Shape)()
cast_op = _get_cache_prim(P.Cast)()