!35227 gumbel softmax

Merge pull request !35227 from TuDouNi/gumbelsoftmax2
This commit is contained in:
i-robot 2022-06-07 10:59:53 +00:00 committed by Gitee
commit e85d0cc246
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 107 additions and 2 deletions

@ -320,6 +320,7 @@ Array操作
mindspore.ops.select
mindspore.ops.shape
mindspore.ops.size
mindspore.ops.gumbel_softmax
mindspore.ops.space_to_batch_nd
mindspore.ops.tensor_scatter_add
mindspore.ops.tensor_scatter_min

@ -0,0 +1,26 @@
mindspore.ops.gumbel_softmax
============================

.. py:function:: mindspore.ops.gumbel_softmax(logits, tau=1, hard=False, dim=-1)

Returns a Tensor sampled from the Gumbel-Softmax distribution. When `hard = True`, the returned Tensor is discretized as one-hot; when `hard = False`, it is the result of a softmax over dimension `dim`.

**Parameters:**

- **logits** (Tensor) - The input, an unnormalized log-probability distribution. Only float16 and float32 are supported.
- **tau** (float) - The positive scalar temperature. Default: 1.0.
- **hard** (bool) - If True, returns a one-hot discretized Tensor that can still be differentiated in backpropagation. Default: False.
- **dim** (int) - The dimension along which softmax is applied. Default: -1.

**Returns:**

Tensor, with the same shape and dtype as the input `logits`.

**Raises:**

- **TypeError** - `logits` is not a Tensor.
- **TypeError** - The dtype of `logits` is neither float16 nor float32.
- **TypeError** - `tau` is not a float.
- **TypeError** - `hard` is not a bool.
- **TypeError** - `dim` is not an int.
- **ValueError** - `tau` is not a positive number.
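
A minimal usage sketch (mirroring the example added to the English docstring in this patch; the input values are illustrative only):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
>>> print(output.shape)
(2, 3)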

@ -198,7 +198,8 @@ from .math_func import (
isreal,
rad2deg,
truncate_div,
truncate_mod
truncate_mod,
gumbel_softmax,
)
from .nn_func import (
deformable_conv2d,

@ -17,9 +17,11 @@
import math
import numpy as np
import mindspore.ops as ops
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
BesselK1e, Renorm)
from ...common import dtype as mstype
@ -3232,6 +3234,80 @@ def renorm(input_x, p, dim, maxnorm):
    return renorm_(input_x)


@constexpr
def _check_attr_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    validator.check_value_type(param_name, input_dtype, allow_dtypes, cls_name)


@constexpr
def _check_positive_float(arg_value, arg_name, cls_name):
    validator.check_positive_float(arg_value, arg_name, cls_name)


def gumbel_softmax(logits, tau=1, hard=False, dim=-1):
r"""
Returns the samples from the Gumbel-Softmax distribution and optionally discretizes. If `hard = True`, the returned
samples will be one-hot, otherwise it will be probability distributions that sum to 1 across `dim`.
Args:
logits (Tensor): Unnormalized log probabilities. The data type must be float16 or float32.
tau (float): Non-negative scalar temperature. Default: 1.0.
hard (bool): if `True`, the returned samples will be discretized as one-hot vectors, but will be differentiated
as if it is the soft sample in autograd. Default: False.
dim (int): Dim for softmax to compute. Default: -1.
Returns:
Tensor, has the same dtype and shape as `logits`.
Raises:
TypeError: If `logits` is not a Tensor.
TypeError: If dtype of `logits` is not one of: float16, float32.
TypeError: If `tau` is not an float.
TypeError: If `hard` is not a bool.
TypeError: If `dim` is not a int.
ValueError: If If `tau` is not positive.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> output = ops.gumbel_softmax(input_x, 1.0, True, -1)
>>> print(output.shape)
(2, 3)
"""
    if not isinstance(logits, (Tensor, Tensor_)):
        raise TypeError("The input logits must be a Tensor.")
    dtype_op = P.DType()
    logits_dtype = dtype_op(logits)
    _check_input_dtype("logits", logits_dtype, [mstype.float16, mstype.float32], "gumbel_softmax")
    _check_attr_dtype("tau", tau, [float], "gumbel_softmax")
    _check_attr_dtype("hard", hard, [bool], "gumbel_softmax")
    _check_attr_dtype("dim", dim, [int], "gumbel_softmax")
    _check_positive_float(tau, "tau", "gumbel_softmax")

    shape_op = P.Shape()
    cast_op = P.Cast()
    log_op = P.Log()
    const_op = P.ScalarToArray()
    softmax_op = P.Softmax(dim)
    onehot_op = P.OneHot(dim)

    sample_shape = shape_op(logits)
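    # Sample Gumbel(0, 1) noise as g = -log(-log(u)) with u ~ Uniform(0, 1),
    # then soften the perturbed, temperature-scaled logits with softmax.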
    uniform = C.uniform(sample_shape, const_op(0.0), const_op(1.0))
    uniform = cast_op(uniform, logits_dtype)
    gumbel = neg_tensor(log_op(neg_tensor(log_op(uniform))))
    gumbel = (logits + gumbel) / tau
    y_soft = softmax_op(gumbel)
    if hard:
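        # Straight-through estimator: the forward output is one-hot, while the
        # backward pass uses the gradient of the soft sample y_soft.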
        index = y_soft.argmax(axis=dim)
        y_hard = onehot_op(index, sample_shape[dim], Tensor(1, logits_dtype), Tensor(0, logits_dtype))
        ret = y_hard - ops.stop_gradient(y_soft) + y_soft
    else:
        ret = y_soft
    return ret

__all__ = [
'addn',
'absolute',
@ -3335,6 +3411,7 @@ __all__ = [
'deg2rad',
'rad2deg',
'truncate_div',
'truncate_mod'
'truncate_mod',
'gumbel_softmax'
]
__all__.sort()
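
For context on what the new function computes: it perturbs the logits with Gumbel(0, 1) noise g = -log(-log(u)), u ~ Uniform(0, 1), takes softmax((logits + g) / tau), and, when hard=True, emits the one-hot argmax while routing gradients through the soft sample (y_hard - ops.stop_gradient(y_soft) + y_soft). A forward-only NumPy sketch of the same computation, for illustration and not part of the patch:

import numpy as np

def gumbel_softmax_reference(logits, tau=1.0, hard=False, axis=-1, rng=None):
    """Illustrative NumPy re-implementation of the forward pass only."""
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    z = (logits + gumbel) / tau
    z = z - z.max(axis=axis, keepdims=True)  # numerically stable softmax
    y_soft = np.exp(z) / np.exp(z).sum(axis=axis, keepdims=True)
    if not hard:
        return y_soft
    # One-hot of the argmax; the framework version keeps gradients flowing
    # through y_soft via y_hard - stop_gradient(y_soft) + y_soft.
    index = y_soft.argmax(axis=axis)
    y_hard = np.zeros_like(y_soft)
    np.put_along_axis(y_hard, np.expand_dims(index, axis), 1.0, axis)
    return y_hard

logits = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
print(gumbel_softmax_reference(logits, tau=1.0, hard=True))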