forked from mindspore-Ecosystem/mindspore
!47237 for nn.function: rrelu and logsigmoid
Merge pull request !47237 from 于振华/nn_api_rrelu_logsigmoid_1226
This commit is contained in:
commit
9fb72f3b68
|
@ -90,7 +90,9 @@ mindspore.ops
|
|||
mindspore.ops.hardshrink
|
||||
mindspore.ops.hardswish
|
||||
mindspore.ops.log_softmax
|
||||
mindspore.ops.logsigmoid
|
||||
mindspore.ops.mish
|
||||
mindspore.ops.rrelu
|
||||
mindspore.ops.selu
|
||||
mindspore.ops.sigmoid
|
||||
mindspore.ops.soft_shrink
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.ops.logsigmoid
|
||||
=============================
|
||||
|
||||
.. py:function:: mindspore.ops.logsigmoid(x)
|
||||
|
||||
logsigmoid激活函数。
|
||||
|
||||
按元素计算logsigmoid激活函数。输入是任意格式的Tensor。
|
||||
|
||||
logsigmoid定义为:
|
||||
|
||||
.. math::
|
||||
\text{logsigmoid}(x_{i}) = log(\frac{1}{1 + \exp(-x_i)}),
|
||||
|
||||
其中,:math:`x_i` 是输入Tensor的一个元素。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - logsigmoid的输入,数据类型为float16或float32。shape为 :math:`(N,*)` ,其中 :math:`*` 表示任意的附加维度。
|
||||
|
||||
返回:
|
||||
Tensor,数据类型和shape与 `x` 的相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 的数据类型既不是float16也不是float32。
|
|
@ -0,0 +1,31 @@
|
|||
mindspore.ops.rrelu
|
||||
===================
|
||||
|
||||
.. py:function:: mindspore.ops.rrelu(x, lower=1 / 8, upper=1 / 3)
|
||||
|
||||
Randomized Leaky ReLU激活函数。
|
||||
|
||||
该激活函数定义如下:
|
||||
|
||||
.. math::
|
||||
\text{rrelu}(x_{ji}) = \begin{cases}x_{ji}, &\text{if } x_{ji} \geq 0; \cr
|
||||
{\alpha_{ji}} * x_{ji}, &\text{otherwise.}\end{cases}
|
||||
|
||||
其中,:math:`\alpha_{ji}` ~ :math:`U(l, u)`, :math:`l \le u`.
|
||||
|
||||
更多细节详见 `Empirical Evaluation of Rectified Activations in Convolution Network <https://arxiv.org/pdf/1505.00853.pdf>`_。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 计算RReLU的任意维度的Tensor。
|
||||
- **lower** (Union[int, float]) - x<0时激活函数的斜率的下界,默认值:1/8。
|
||||
- **upper** (Union[int, float]) - x<0时激活函数的斜率的上界,默认值:1/3。
|
||||
|
||||
返回:
|
||||
Tensor,数据类型和shape与 `x` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `lower` 不是浮点数或整数。
|
||||
- **TypeError** - `upper` 不是浮点数或整数。
|
||||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `x` 内的数据类型不是mindspore.float16或mindspore.float32。
|
||||
- **ValueError** - `lower` 大于 `upper`。
|
|
@ -91,7 +91,9 @@ Activation Functions
|
|||
mindspore.ops.hardshrink
|
||||
mindspore.ops.hardswish
|
||||
mindspore.ops.log_softmax
|
||||
mindspore.ops.logsigmoid
|
||||
mindspore.ops.mish
|
||||
mindspore.ops.rrelu
|
||||
mindspore.ops.selu
|
||||
mindspore.ops.sigmoid
|
||||
mindspore.ops.softsign
|
||||
|
|
|
@ -433,8 +433,10 @@ from .nn_func import (
|
|||
conv3d_transpose,
|
||||
conv2d,
|
||||
sigmoid,
|
||||
logsigmoid,
|
||||
relu,
|
||||
relu6,
|
||||
rrelu,
|
||||
conv3d,
|
||||
glu,
|
||||
multi_margin_loss,
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Defines nn operators with functional form."""
|
||||
from __future__ import absolute_import
|
||||
from math import pi, log
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -2252,6 +2253,46 @@ def sigmoid(input_x):
|
|||
return sigmoid_(input_x)
|
||||
|
||||
|
||||
def logsigmoid(x):
|
||||
r"""
|
||||
Logsigmoid activation function.
|
||||
|
||||
Applies logsigmoid activation element-wise. The input is a Tensor with any valid shape.
|
||||
|
||||
Logsigmoid is defined as:
|
||||
|
||||
.. math::
|
||||
\text{logsigmoid}(x_{i}) = log(\frac{1}{1 + \exp(-x_i)}),
|
||||
|
||||
where :math:`x_{i}` is the element of the input.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input of LogSigmoid with data type of float16 or float32.
|
||||
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
||||
|
||||
Returns:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
|
||||
>>> output = ops.logsigmoid(x)
|
||||
>>> print(output)
|
||||
[-0.31326166 -0.12692806 -0.04858734]
|
||||
"""
|
||||
output = _get_cache_prim(P.Mul)()(x, -1)
|
||||
output = _get_cache_prim(P.Exp)()(output)
|
||||
output = _get_cache_prim(P.Add)()(output, 1)
|
||||
output = _get_cache_prim(P.Reciprocal)()(output)
|
||||
ret = _get_cache_prim(P.Log)()(output)
|
||||
return ret
|
||||
|
||||
|
||||
def deformable_conv2d(x, weight, offsets, kernel_size, strides, padding, bias=None, dilations=(1, 1, 1, 1), groups=1,
|
||||
deformable_groups=1, modulated=True):
|
||||
r"""
|
||||
|
@ -2646,6 +2687,65 @@ def prelu(x, weight):
|
|||
return prelu_(x, weight)
|
||||
|
||||
|
||||
def rrelu(x, lower=1 / 8, upper=1 / 3):
|
||||
r"""
|
||||
|
||||
Randomized Leaky ReLU activation function.
|
||||
|
||||
The activation function is defined as:
|
||||
|
||||
.. math::
|
||||
\text{rrelu}(x_{ji}) = \begin{cases}x_{ji}, &\text{if } x_{ji} \geq 0; \cr
|
||||
{\alpha_{ji}} * x_{ji}, &\text{otherwise.}\end{cases}
|
||||
|
||||
where :math:`\alpha_{ji}` ~ :math:`U(l, u)`, :math:`l \le u`.
|
||||
|
||||
Applies the rrelu function elementally, as described in the paper:
|
||||
`Empirical Evaluation of Rectified Activations in Convolution Network <https://arxiv.org/pdf/1505.00853.pdf>`_ .
|
||||
|
||||
Args:
|
||||
x (Tensor): The input of rrelu is a Tensor of any dimension.
|
||||
lower (Union[int, float]): Slope of the activation function at x < 0. Default: 1/8.
|
||||
upper (Union[int, float]): Slope of the activation function at x < 0. Default: 1/3.
|
||||
|
||||
Returns:
|
||||
Tensor, after rrelu, has the same type and shape as the `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `lower` is not a float or an int.
|
||||
TypeError: If `upper` is not a float or an int.
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If `x` is not a Tensor of mindspore.float16 or mindpore.float32.
|
||||
ValueError: If `lower` is greater than upper.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[-1.0, 4.0], [2.0, 0]]), mindspore.float32)
|
||||
>>> output = ops.rrelu(x)
|
||||
>>> print(output)
|
||||
[[-0.31465699 4. ]
|
||||
[ 2. 0. ]]
|
||||
"""
|
||||
if not isinstance(upper, (float, int)):
|
||||
raise TypeError(f"For 'ops.rrelu', `upper` must be an int or a float, but got {type(upper)}")
|
||||
if not isinstance(lower, (float, int)):
|
||||
raise TypeError(f"For 'ops.rrelu', `lower` must be an int or a float, but got {type(lower)}")
|
||||
if lower > upper:
|
||||
raise ValueError(f"For 'ops.rrelu', the value of `upper` must be greater than `lower`, "
|
||||
f"but got upper: {upper}, lower: {lower}. ")
|
||||
size = x.shape
|
||||
sign_matrix = _get_cache_prim(P.Sign)()(x)
|
||||
negative_filter = sign_matrix.clip(None, 0)
|
||||
positive_filter = sign_matrix.clip(0, None)
|
||||
mask = _get_cache_prim(P.Cast)()(Tensor(np.random.uniform(lower, upper, size=size)), _get_cache_prim(P.DType)()(x))
|
||||
negative_mask = negative_filter * mask * -1
|
||||
total_mask = negative_mask + positive_filter
|
||||
out = total_mask * x
|
||||
return out
|
||||
|
||||
|
||||
def mirror_pad(input_x, paddings, mode):
|
||||
"""
|
||||
Pads the input tensor according to the paddings and mode.
|
||||
|
@ -5266,8 +5366,10 @@ __all__ = [
|
|||
'conv3d_transpose',
|
||||
'conv2d',
|
||||
'sigmoid',
|
||||
'logsigmoid',
|
||||
'relu',
|
||||
'relu6',
|
||||
'rrelu',
|
||||
'conv3d',
|
||||
'glu',
|
||||
'margin_ranking_loss',
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
output = ops.logsigmoid(x)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_logsigmoid(mode):
|
||||
"""
|
||||
Feature: logsigmoid
|
||||
Description: Verify the result of logsigmoid
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
|
||||
out = net(x)
|
||||
expect_out = np.array([-0.31326166, -0.12692806, -0.04858734])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
output = ops.rrelu(x)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_rrelu(mode):
|
||||
"""
|
||||
Feature: rrelu
|
||||
Description: Verify the result of rrelu
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[1.0, 4.0], [2.0, 0]], dtype=ms.float32)
|
||||
out = net(x)
|
||||
expect_out = np.array([[1., 4.],
|
||||
[2., 0.]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
Loading…
Reference in New Issue