forked from mindspore-Ecosystem/mindspore
!35361 add ops.logsumexp and fix nan problem
Merge pull request !35361 from 吕昱峰(Nate.River)/logsumexp
This commit is contained in:
commit
ad8dab99a3
@ -125,6 +125,7 @@ functional operators are initialized Primitives and can be called directly as functions
    mindspore.ops.logical_and
    mindspore.ops.logical_not
    mindspore.ops.logical_or
    mindspore.ops.logsumexp
    mindspore.ops.mul
    mindspore.ops.neg
    mindspore.ops.pow
@ -0,0 +1,27 @@
mindspore.ops.logsumexp
=======================

.. py:function:: mindspore.ops.logsumexp(x, axis, keep_dims=False)

    Reduces a dimension of the Tensor by computing the exponential of every element along that dimension and then taking the logarithm of the sum.

    .. math::

        logsumexp(x) = \log(\sum(e^{x - x_{max}})) + x_{max}

    .. note::
        On Ascend, the dimension of the input Tensor must be less than or equal to 8; on CPU, it must be less than 8.

    **Parameters:**

    - **x** (Tensor) - The input Tensor of any dimension. The data type is float16 or float32.
    - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. If `()` is passed, all dimensions are reduced.
    - **keep_dims** (bool) - Whether to keep the reduced dimensions. If True, each reduced axis keeps length 1 in the output; otherwise the reduced dimensions are removed. Default: False.

    **Returns:**

    Tensor, with the same dtype as `x`.

    - If axis is () and keep_dims is False, the output is a 0-D Tensor.
    - If axis is an int, e.g. 2, and keep_dims is False, the shape of the output is :math:`(x_1, x_3, ..., x_R)`.
    - If axis is a tuple(int), e.g. (2, 3), and keep_dims is False, the shape of the output is :math:`(x_1, x_4, ..., x_R)`.
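The formula above is the standard max-subtraction trick for computing log-sum-exp without overflow. For reference, a minimal NumPy sketch of the same reduction (not part of this PR; `np_logsumexp` is an illustrative helper name):

    import numpy as np

    def np_logsumexp(x, axis=None, keep_dims=False):
        # Subtract the global max before exponentiating, then add it back after the log.
        x_max = np.max(x)
        sumexp = np.sum(np.exp(x - x_max), axis=axis, keepdims=keep_dims)
        return np.log(sumexp) + x_max

    x = np.random.randn(3, 4, 5, 6).astype(np.float32)
    print(np_logsumexp(x, axis=1, keep_dims=True).shape)  # (3, 1, 5, 6)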
@ -125,6 +125,7 @@ Element-by-Element Operations
    mindspore.ops.logical_and
    mindspore.ops.logical_not
    mindspore.ops.logical_or
    mindspore.ops.logsumexp
    mindspore.ops.mul
    mindspore.ops.neg
    mindspore.ops.pow
@ -107,10 +107,11 @@ class ReduceLogSumExp(Cell):
         self.log = P.Log()
 
     def construct(self, x):
-        exp = self.exp(x)
+        x_max = x.max()
+        exp = self.exp(x - x_max)
         sumexp = self.sum(exp, self.axis)
         logsumexp = self.log(sumexp)
-        return logsumexp
+        return logsumexp + x_max
 
 
 class Range(Cell):
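This is the nan fix named in the commit title: without the max subtraction, `self.exp(x)` overflows to inf for large inputs, and that inf can propagate into nan in later arithmetic (for example in gradients). A rough NumPy illustration of the difference (not from this PR):

    import numpy as np

    x = np.array([1000.0, 1001.0], dtype=np.float32)

    # Naive form: exp(1000.0) overflows float32 to inf, so the result is inf
    # (and downstream arithmetic such as gradients can turn it into nan).
    naive = np.log(np.sum(np.exp(x)))

    # Stable form used by the fix: exp(x - x_max) stays within [0, 1].
    x_max = x.max()
    stable = np.log(np.sum(np.exp(x - x_max))) + x_max

    print(naive, stable)  # inf  ~1001.31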
@ -154,6 +154,7 @@ from .math_func import (
    logical_not,
    logical_or,
    logical_and,
    logsumexp,
    sin,
    cos,
    tan,
@ -958,6 +958,51 @@ def logical_and(x):
    return logical_and_(x)


def logsumexp(x, axis, keep_dims=False):
    r"""
    Reduces a dimension of a tensor by calculating the exponential of all elements in the dimension,
    then calculating the logarithm of the sum.

    .. math::

        logsumexp(x) = \log(\sum(e^{x - x_{max}})) + x_{max}

    Args:
        x (Tensor): The input tensor. With float16 or float32 data type.
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. If `()` is given,
            all dimensions are reduced. Only constant value is allowed.
        keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
            If False, don't keep these dimensions.
            Default: False.

    Outputs:
        Tensor, has the same dtype as `x`.

        - If axis is () and keep_dims is False,
          the output is a 0-D tensor representing the log-sum-exp of all elements in the input tensor.
        - If axis is int, set as 2, and keep_dims is False,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int), set as (2, 3), and keep_dims is False,
          the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> output = ops.logsumexp(x, 1, keep_dims=True)
        >>> print(output.shape)
        (3, 1, 5, 6)
    """
    x_max = x.max()
    x_exp = P.Exp()(x - x_max)
    x_sumexp = P.ReduceSum(keep_dims)(x_exp, axis)
    x_logsumexp = P.Log()(x_sumexp)
    return x_logsumexp + x_max


def sin(x):
    r"""
    Computes sine of the input element-wise.
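As a quick sanity check against plain NumPy (not part of the PR; it assumes a MindSpore build that already contains this change):

    import numpy as np
    from mindspore import Tensor, context, ops

    context.set_context(mode=context.PYNATIVE_MODE)  # run ops eagerly

    x_np = np.random.randn(3, 4).astype(np.float32)
    expected = np.log(np.sum(np.exp(x_np - x_np.max()), axis=1)) + x_np.max()

    out = ops.logsumexp(Tensor(x_np), 1)  # functional interface added by this PR
    assert np.allclose(out.asnumpy(), expected, atol=1e-5)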
@ -3065,6 +3110,7 @@ __all__ = [
    'logical_not',
    'logical_or',
    'logical_and',
    'logsumexp',
    'sin',
    'cos',
    'tan',