forked from mindspore-Ecosystem/mindspore
!47250 silu ops functional
Merge pull request !47250 from Henry Shi/ops_silu
This commit is contained in:
commit
9b6736d4be
|
@ -98,6 +98,7 @@ mindspore.ops
|
|||
mindspore.ops.rrelu
|
||||
mindspore.ops.selu
|
||||
mindspore.ops.sigmoid
|
||||
mindspore.ops.silu
|
||||
mindspore.ops.soft_shrink
|
||||
mindspore.ops.softmax
|
||||
mindspore.ops.softmin
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
mindspore.ops.silu
|
||||
==================
|
||||
|
||||
.. py:function:: mindspore.ops.silu(x)
|
||||
|
||||
激活函数SiLU(Sigmoid Linear Unit)。
|
||||
|
||||
该激活函数定义为:
|
||||
|
||||
.. math::
|
||||
\text{SiLU}(x) = x * \sigma(x),
|
||||
|
||||
其中 :math:`x_i` 是输入的元素, math:`\sigma(x)` Logistic Sigmoid函数。
|
||||
|
||||
.. math::
|
||||
|
||||
\text{sigma}(x_i) = \frac{1}{1 + \exp(-x_i)},
|
||||
|
||||
关于SiLU的图例见 `SiLU <https://en.wikipedia.org/wiki/Activation_function#/media/File:Swish.svg>`_ 。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 数据类型为float16, float32, float64, complex64 或 complex128的输入。任意维度的Tensor。
|
||||
|
||||
返回:
|
||||
Tensor,数据类型和shape与 `x` 的相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 的数据类型不是float16, float32, float64, complex64 或 complex128。
|
|
@ -99,6 +99,7 @@ Activation Functions
|
|||
mindspore.ops.rrelu
|
||||
mindspore.ops.selu
|
||||
mindspore.ops.sigmoid
|
||||
mindspore.ops.silu
|
||||
mindspore.ops.softsign
|
||||
mindspore.ops.soft_shrink
|
||||
mindspore.ops.softmax
|
||||
|
|
|
@ -651,10 +651,9 @@ class SiLU(Cell):
|
|||
def __init__(self):
|
||||
"""Initialize SiLU."""
|
||||
super(SiLU, self).__init__()
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
def construct(self, x):
|
||||
return self.sigmoid(x) * x
|
||||
return ops.function.silu(x)
|
||||
|
||||
|
||||
class Tanh(Cell):
|
||||
|
|
|
@ -419,6 +419,7 @@ from .nn_func import (
|
|||
hardtanh,
|
||||
huber_loss,
|
||||
softsign,
|
||||
silu,
|
||||
selu,
|
||||
softmax,
|
||||
softmin,
|
||||
|
|
|
@ -2198,6 +2198,28 @@ def soft_shrink(x, lambd=0.5):
|
|||
return soft_shrink_op(x)
|
||||
|
||||
|
||||
def silu(x):
|
||||
r"""
|
||||
Sigmoid Linear Unit.
|
||||
|
||||
Computes Sigmoid Linear Unit of input element-wise. The SiLU function is defined as:
|
||||
|
||||
.. math::
|
||||
\text{SiLU}(x) = x * \sigma(x)
|
||||
|
||||
where the Logistic Sigmoid function is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{sigma}(x_i) = \frac{1}{1 + \exp(-x_i)}
|
||||
|
||||
where :math:`x_i` is an element of the x.
|
||||
|
||||
For more details, please refer to mindspore.nn.SiLU.
|
||||
"""
|
||||
return sigmoid_(x)*x
|
||||
|
||||
|
||||
def selu(input_x):
|
||||
r"""
|
||||
Activation function SeLU (Scaled exponential Linear Unit).
|
||||
|
@ -4632,7 +4654,7 @@ def conv3d(inputs, weight, pad_mode="valid", padding=0, stride=1, dilation=1, gr
|
|||
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`,
|
||||
where `group` is the number of groups to split the input `x` in the channel dimension.
|
||||
|
||||
For more details, please refers to the paper `Gradient Based Learning Applied to Document
|
||||
For more details, please refer to the paper `Gradient Based Learning Applied to Document
|
||||
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_ .
|
||||
|
||||
Note:
|
||||
|
@ -5539,6 +5561,7 @@ __all__ = [
|
|||
'huber_loss',
|
||||
'softsign',
|
||||
'selu',
|
||||
'silu',
|
||||
'softmax',
|
||||
'softmin',
|
||||
'pdist',
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, nn
|
||||
import mindspore.ops.function as F
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return F.silu(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_net(mode):
|
||||
"""
|
||||
Feature: test silu op
|
||||
Description: verify the result of silu
|
||||
Expectation: assertion success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor(np.array([1, 2, 3, 4, 5]), ms.float32)
|
||||
silu = Net()
|
||||
output = silu(x)
|
||||
np_out = np.array([0.7310586, 1.7615942, 2.8577223, 3.9280553, 4.966536])
|
||||
assert np.allclose(output.asnumpy(), np_out)
|
Loading…
Reference in New Issue