!39754 gumbelsoftmax add check
Merge pull request !39754 from TuDouNi/gumbelsoftmax2
This commit is contained in:
commit
062f987e2a
|
@ -57,6 +57,7 @@ from ...common import dtype as mstype
|
|||
from ...common.tensor import Tensor
|
||||
from ..._c_expression import Tensor as Tensor_
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from .._primitive_cache import _get_cache_prim
|
||||
|
||||
|
||||
|
@ -4586,6 +4587,11 @@ def _check_positive_float(arg_value, arg_name, cls_name):
|
|||
validator.check_positive_float(arg_value, arg_name, cls_name)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
||||
validator.check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name, prim_name)
|
||||
|
||||
|
||||
def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
||||
r"""
|
||||
Returns the samples from the Gumbel-Softmax distribution and optionally discretizes. If `hard = True`, the returned
|
||||
|
@ -4620,6 +4626,8 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
|||
"""
|
||||
if not isinstance(logits, (Tensor, Tensor_)):
|
||||
raise TypeError("The input logits must be tensor")
|
||||
if logits.shape == ():
|
||||
raise ValueError("For gumbel_softmax, the 0-D input is not supported.")
|
||||
dtype_op = _get_cache_prim(P.DType)()
|
||||
logits_dtype = dtype_op(logits)
|
||||
_check_input_dtype("logits", logits_dtype, [mstype.float16, mstype.float32], "gumbel_softmax")
|
||||
|
@ -4627,6 +4635,10 @@ def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
|
|||
_check_attr_dtype("hard", hard, [bool], "gumbel_softmax")
|
||||
_check_attr_dtype("dim", dim, [int], "gumbel_softmax")
|
||||
_check_positive_float(tau, "tau", "gumbel_softmax")
|
||||
if hard:
|
||||
_check_int_range(dim, -1, len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
|
||||
else:
|
||||
_check_int_range(dim, -len(logits), len(logits), Rel.INC_LEFT, 'dim', "gumbel_softmax")
|
||||
|
||||
shape_op = _get_cache_prim(P.Shape)()
|
||||
cast_op = _get_cache_prim(P.Cast)()
|
||||
|
|
Loading…
Reference in New Issue