forked from mindspore-Ecosystem/mindspore
!42637 support glu
Merge pull request !42637 from lianliguang/glu_grad_ops
commit 268abc5fb1
@ -0,0 +1,20 @@
mindspore.nn.GLU
=================

.. py:class:: mindspore.nn.GLU(axis=-1)

    Gated Linear Unit activation function.

    .. math::
        {GLU}(a, b) = a \otimes \sigma(b)

    where :math:`a` is the first half of the input Tensor and :math:`b` is the other half.

    Args:
        - **axis** (int) - The axis along which to split the input. The data type is int. Default: -1.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. `x` must be evenly splittable into two halves along the `axis` dimension.

    Outputs:
        Tensor with the same dtype as the input `x`, whose shape is that of `x` with the `axis` dimension halved.
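The split-and-gate rule above is easy to cross-check outside MindSpore. Below is a minimal NumPy sketch of the same computation (the helper name glu_reference is ours, not part of any API):

    import numpy as np

    def glu_reference(x, axis=-1):
        """Reference GLU: split x into halves a, b along `axis`, return a * sigmoid(b)."""
        a, b = np.split(x, 2, axis=axis)       # requires x.shape[axis] to be even
        return a * (1.0 / (1.0 + np.exp(-b)))  # elementwise: a * sigmoid(b)

    x = np.random.randn(4, 6).astype(np.float32)
    print(glu_reference(x, axis=1).shape)      # (4, 3): the axis-1 size is halved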
@ -0,0 +1,18 @@
mindspore.ops.glu
=================

.. py:function:: mindspore.ops.glu(x, axis=-1)

    Gated Linear Unit activation function.

    .. math::
        {GLU}(a, b) = a \otimes \sigma(b)

    where :math:`a` is the first half of the split input Tensor and :math:`b` is the other half.

    Args:
        - **x** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. `x` must be evenly splittable into two halves along the `axis` dimension.
        - **axis** (int) - The axis along which to split the input. The data type is int. Default: -1.

    Returns:
        Tensor with the same dtype as the input `x`, whose shape is that of `x` with the `axis` dimension halved.
@ -55,7 +55,8 @@ __all__ = ['Softmin',
     'HShrink',
     'CELU',
     'Threshold',
-    'Mish'
+    'Mish',
+    'GLU'
     ]
@ -1434,6 +1435,43 @@ class Mish(Cell):
        return self.mish(input_x)


class GLU(Cell):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b) = a \otimes \sigma(b)`, where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Here :math:`\sigma` is the sigmoid function, and :math:`\otimes` is the Hadamard product.

    Args:
        axis (int): the dimension on which to split the input. Default: -1.

    Inputs:
        - **x** (Tensor) - :math:`(\ast_1, N, \ast_2)`, where `*` means any number of additional dimensions.

    Outputs:
        Tensor of shape :math:`(\ast_1, M, \ast_2)`, where :math:`M = N / 2`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> m = nn.GLU()
        >>> x = Tensor(np.random.randn(4, 2).astype(np.float32))
        >>> output = m(x)
    """

    def __init__(self, axis=-1):
        """Initialize GLU."""
        super().__init__()
        self.dim = axis
        self.split = P.Split(axis=axis, output_num=2)
        self.sigmoid = P.Sigmoid()

    def construct(self, x):
        # Split the input into two halves along `axis` and gate the first
        # half with the sigmoid of the second half.
        x1, x2 = self.split(x)
        x2 = self.sigmoid(x2)
        return x1 * x2
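For context, a minimal usage sketch of the new Cell (the imports and the float32 cast are assumptions mirroring the docstring example, not part of the diff):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # (4, 2) splits along the last axis into two (4, 1) halves;
    # the first half is gated by the sigmoid of the second.
    m = nn.GLU()
    x = Tensor(np.random.randn(4, 2).astype(np.float32))
    output = m(x)
    print(output.shape)  # (4, 1)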


_activation = {
    'softmin': Softmin,
    'softmax': Softmax,
@ -1458,7 +1496,7 @@ _activation = {
     'softshrink': SoftShrink,
     'hshrink': HShrink,
     'threshold': Threshold,
-    'mish': Mish
+    'mish': Mish,
 }
@ -322,6 +322,7 @@ from .nn_func import (
     relu,
     relu6,
     conv3d,
+    glu,
 )
 from .linalg_func import (
     svd,
@ -3122,6 +3122,45 @@ def conv3d(inputs, weight, pad_mode="valid", padding=0, stride=1, dilation=1, gr
    return output


def glu(x, axis=-1):
    r"""
    Computes GLU (Gated Linear Unit activation function) of the input tensor.

    .. math::

        {GLU}(a, b) = a \otimes \sigma(b)

    where :math:`a` is the first half of the input matrices and :math:`b` is the second half.

    Here :math:`\sigma` is the sigmoid function, and :math:`\otimes` is the Hadamard product.
    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_ .

    Args:
        x (Tensor): :math:`(\ast_1, N, \ast_2)`, where `*` means any number of additional dimensions.
        axis (int): the dimension on which to split the input. Default: -1.

    Returns:
        Tensor of shape :math:`(\ast_1, M, \ast_2)`, where :math:`M = N / 2`, with the same dtype as `x`.

    Raises:
        TypeError: If dtype of `x` is not a number type.
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.randn(4, 2).astype(np.float32))
        >>> output = ops.glu(x)
    """
    if not isinstance(x, Tensor) or x.size == 0:
        raise TypeError("glu does not support non-Tensor inputs or empty tensors")

    # Split the input into two halves along `axis` and gate the first
    # half with the sigmoid of the second half.
    split = _get_cache_prim(P.Split)(axis=axis, output_num=2)
    x, y = split(x)
    y = sigmoid_(y)
    return x * y
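Note that the guard above only rejects non-Tensor and empty inputs; the even-split requirement is enforced by the Split primitive itself. A quick sketch of that boundary (assuming glu is re-exported as mindspore.ops.glu, as the import hunk above suggests):

    import numpy as np
    from mindspore import Tensor, ops

    ok = ops.glu(Tensor(np.random.randn(4, 6).astype(np.float32)), axis=1)
    print(ok.shape)  # (4, 3): axis 1 is halved from 6 to 3

    # An odd axis length (e.g. 5) cannot be split into two equal halves,
    # so the underlying Split primitive raises an error at runtime.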


__all__ = [
    'adaptive_avg_pool1d',
    'adaptive_avg_pool2d',
@ -3170,5 +3209,6 @@ __all__ = [
     'relu',
     'relu6',
     'conv3d',
+    'glu'
 ]
 __all__.sort()