forked from mindspore-Ecosystem/mindspore
!35227 gumbel softmax
Merge pull request !35227 from TuDouNi/gumbelsoftmax2
commit e85d0cc246
@@ -320,6 +320,7 @@ Array Operations

    mindspore.ops.select
    mindspore.ops.shape
    mindspore.ops.size
    mindspore.ops.gumbel_softmax
    mindspore.ops.space_to_batch_nd
    mindspore.ops.tensor_scatter_add
    mindspore.ops.tensor_scatter_min
@@ -0,0 +1,26 @@
mindspore.ops.gumbel_softmax
============================

.. py:function:: mindspore.ops.gumbel_softmax(logits, tau=1, hard=False, dim=-1)

    Returns a Tensor sampled from the Gumbel-Softmax distribution. When `hard = True`, a one-hot discrete Tensor is returned; when `hard = False`, a Tensor softmaxed along `dim` is returned.

    **Parameters:**

    - **logits** (Tensor) - The input, an unnormalized log-probability distribution. Only float16 and float32 are supported.
    - **tau** (float) - Non-negative scalar temperature. Default: 1.0.
    - **hard** (bool) - If True, returns a one-hot discrete Tensor that still supports backpropagation. Default: False.
    - **dim** (int) - The dimension along which softmax is computed. Default: -1.

    **Returns:**

    Tensor, with the same shape and dtype as the input `logits`.

    **Raises:**

    - **TypeError** - `logits` is not a Tensor.
    - **TypeError** - The dtype of `logits` is neither float16 nor float32.
    - **TypeError** - `tau` is not a float.
    - **TypeError** - `hard` is not a bool.
    - **TypeError** - `dim` is not an int.
    - **ValueError** - `tau` is not positive.
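A minimal usage sketch of the new interface (not part of the diff; assumes a MindSpore build that already includes this change):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> logits = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> # hard=False: soft samples forming a probability distribution along `dim`.
>>> soft = ops.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)
>>> # hard=True: one-hot samples whose gradients follow the soft samples.
>>> onehot = ops.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
>>> print(soft.shape, onehot.shape)
(2, 3) (2, 3)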
@@ -198,7 +198,8 @@ from .math_func import (
    isreal,
    rad2deg,
    truncate_div,
    truncate_mod
    truncate_mod,
    gumbel_softmax,
)
from .nn_func import (
    deformable_conv2d,
@@ -17,9 +17,11 @@

import math
import numpy as np
import mindspore.ops as ops
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
                                   BesselK1e, Renorm)
from ...common import dtype as mstype
@@ -3232,6 +3234,80 @@ def renorm(input_x, p, dim, maxnorm):
    return renorm_(input_x)


@constexpr
def _check_attr_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    validator.check_value_type(param_name, input_dtype, allow_dtypes, cls_name)


@constexpr
def _check_positive_float(arg_value, arg_name, cls_name):
    validator.check_positive_float(arg_value, arg_name, cls_name)


def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
    r"""
    Returns the samples from the Gumbel-Softmax distribution and optionally discretizes. If `hard = True`, the
    returned samples will be one-hot, otherwise they will be probability distributions that sum to 1 across `dim`.

    Args:
        logits (Tensor): Unnormalized log probabilities. The data type must be float16 or float32.
        tau (float): Non-negative scalar temperature. Default: 1.0.
        hard (bool): If `True`, the returned samples will be discretized as one-hot vectors, but will be
            differentiated as if they were the soft samples in autograd. Default: False.
        dim (int): Dim for softmax to compute. Default: -1.

    Returns:
        Tensor, has the same dtype and shape as `logits`.

    Raises:
        TypeError: If `logits` is not a Tensor.
        TypeError: If dtype of `logits` is not one of: float16, float32.
        TypeError: If `tau` is not a float.
        TypeError: If `hard` is not a bool.
        TypeError: If `dim` is not an int.
        ValueError: If `tau` is not positive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
        >>> print(output.shape)
        (2, 3)
    """
    if not isinstance(logits, (Tensor, Tensor_)):
        raise TypeError("The input logits must be a Tensor.")
    dtype_op = P.DType()
    logits_dtype = dtype_op(logits)
    _check_input_dtype("logits", logits_dtype, [mstype.float16, mstype.float32], "gumbel_softmax")
    _check_attr_dtype("tau", tau, [float], "gumbel_softmax")
    _check_attr_dtype("hard", hard, [bool], "gumbel_softmax")
    _check_attr_dtype("dim", dim, [int], "gumbel_softmax")
    _check_positive_float(tau, "tau", "gumbel_softmax")

    shape_op = P.Shape()
    cast_op = P.Cast()
    log_op = P.Log()
    const_op = P.ScalarToArray()
    softmax_op = P.Softmax(dim)
    onehot_op = P.OneHot(dim)

    sample_shape = shape_op(logits)
    # Sample Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1).
    uniform = C.uniform(sample_shape, const_op(0.0), const_op(1.0))
    uniform = cast_op(uniform, logits_dtype)
    gumbel = neg_tensor(log_op(neg_tensor(log_op(uniform))))
    gumbel = (logits + gumbel) / tau
    y_soft = softmax_op(gumbel)
    if hard:
        # Straight-through estimator: the forward pass returns the one-hot sample,
        # while gradients flow through the soft sample.
        index = y_soft.argmax(axis=dim)
        y_hard = onehot_op(index, sample_shape[dim], Tensor(1, logits_dtype), Tensor(0, logits_dtype))
        ret = y_hard - ops.stop_gradient(y_soft) + y_soft
    else:
        ret = y_soft
    return ret


__all__ = [
    'addn',
    'absolute',
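The `hard=True` branch above relies on the straight-through trick: `y_hard - ops.stop_gradient(y_soft) + y_soft` evaluates to `y_hard` in the forward pass (the last two terms cancel numerically), while the gradient bypasses the non-differentiable argmax/one-hot step and flows through `y_soft` only. For illustration, a forward-only numpy sketch of the same sampling procedure (hypothetical helper name, no autograd, one-hot step assumes the last axis):

import numpy as np

def gumbel_softmax_forward(logits, tau=1.0, hard=False, axis=-1):
    """Forward-only sketch of Gumbel-Softmax sampling."""
    u = np.random.uniform(size=logits.shape)           # u ~ Uniform(0, 1)
    g = -np.log(-np.log(u))                            # Gumbel(0, 1) noise
    y = (logits + g) / tau
    y_soft = np.exp(y - y.max(axis=axis, keepdims=True))
    y_soft /= y_soft.sum(axis=axis, keepdims=True)     # softmax along `axis`
    if hard:
        # One-hot of the argmax; the framework version keeps this step
        # differentiable via y_hard - stop_gradient(y_soft) + y_soft.
        return np.eye(logits.shape[-1])[y_soft.argmax(axis=-1)]
    return y_soft

print(gumbel_softmax_forward(np.array([[-0.1, 0.3, 3.6]]), hard=True))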
@@ -3335,6 +3411,7 @@ __all__ = [
    'deg2rad',
    'rad2deg',
    'truncate_div',
    'truncate_mod'
    'truncate_mod',
    'gumbel_softmax'
]
__all__.sort()