!47237 for nn.function: rrelu and logsigmoid

Merge pull request !47237 from 于振华/nn_api_rrelu_logsigmoid_1226
i-robot 2022-12-29 07:00:32 +00:00 committed by Gitee
commit 9fb72f3b68
8 changed files with 260 additions and 0 deletions


@@ -90,7 +90,9 @@ mindspore.ops
mindspore.ops.hardshrink
mindspore.ops.hardswish
mindspore.ops.log_softmax
mindspore.ops.logsigmoid
mindspore.ops.mish
mindspore.ops.rrelu
mindspore.ops.selu
mindspore.ops.sigmoid
mindspore.ops.soft_shrink


@@ -0,0 +1,24 @@
mindspore.ops.logsigmoid
=============================

.. py:function:: mindspore.ops.logsigmoid(x)

    Logsigmoid activation function.

    Computes the logsigmoid activation element-wise. The input is a Tensor of any valid shape.

    Logsigmoid is defined as:

    .. math::
        \text{logsigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),

    where :math:`x_i` is an element of the input Tensor.

    Parameters:
        - **x** (Tensor) - The input of logsigmoid, with data type of float16 or float32. The shape is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.

    Returns:
        Tensor, with the same data type and shape as `x`.

    Raises:
        - **TypeError** - If the data type of `x` is neither float16 nor float32.
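
A minimal usage sketch, mirroring the Examples block added to the English docstring later in this diff:

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
output = ops.logsigmoid(x)
print(output)
# Expected, per the docstring example: [-0.31326166 -0.12692806 -0.04858734]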


@@ -0,0 +1,31 @@
mindspore.ops.rrelu
===================

.. py:function:: mindspore.ops.rrelu(x, lower=1 / 8, upper=1 / 3)

    Randomized Leaky ReLU activation function.

    The activation function is defined as:

    .. math::
        \text{rrelu}(x_{ji}) = \begin{cases}x_{ji}, &\text{if } x_{ji} \geq 0; \cr
        {\alpha_{ji}} * x_{ji}, &\text{otherwise.}\end{cases}

    where :math:`\alpha_{ji}` ~ :math:`U(l, u)`, :math:`l \le u`.

    For more details, see `Empirical Evaluation of Rectified Activations in Convolution Network <https://arxiv.org/pdf/1505.00853.pdf>`_ .

    Parameters:
        - **x** (Tensor) - The input of rrelu, a Tensor of any dimension.
        - **lower** (Union[int, float]) - Lower bound of the slope of the activation function at x < 0. Default: 1/8.
        - **upper** (Union[int, float]) - Upper bound of the slope of the activation function at x < 0. Default: 1/3.

    Returns:
        Tensor, with the same data type and shape as `x`.

    Raises:
        - **TypeError** - If `lower` is not a float or an int.
        - **TypeError** - If `upper` is not a float or an int.
        - **TypeError** - If `x` is not a Tensor.
        - **TypeError** - If the data type of `x` is not mindspore.float16 or mindspore.float32.
        - **ValueError** - If `lower` is greater than `upper`.
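
A minimal usage sketch, based on the Examples block in the English docstring below; the slope applied to negative elements is drawn from U(lower, upper), so their outputs vary between runs:

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[-1.0, 4.0], [2.0, 0.0]]), mindspore.float32)
output = ops.rrelu(x)  # default lower=1/8, upper=1/3
print(output)
# Non-negative entries pass through unchanged; the -1.0 entry is scaled by a
# random slope in [1/8, 1/3] (about -0.31 in the docstring example).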


@@ -91,7 +91,9 @@ Activation Functions
mindspore.ops.hardshrink
mindspore.ops.hardswish
mindspore.ops.log_softmax
mindspore.ops.logsigmoid
mindspore.ops.mish
mindspore.ops.rrelu
mindspore.ops.selu
mindspore.ops.sigmoid
mindspore.ops.softsign


@@ -433,8 +433,10 @@ from .nn_func import (
    conv3d_transpose,
    conv2d,
    sigmoid,
    logsigmoid,
    relu,
    relu6,
    rrelu,
    conv3d,
    glu,
    multi_margin_loss,


@@ -16,6 +16,7 @@
"""Defines nn operators with functional form."""
from __future__ import absolute_import
from math import pi, log
import numpy as np
import mindspore.ops as ops
from mindspore.ops.primitive import constexpr
@@ -2252,6 +2253,46 @@ def sigmoid(input_x):
    return sigmoid_(input_x)


def logsigmoid(x):
    r"""
    Logsigmoid activation function.

    Applies logsigmoid activation element-wise. The input is a Tensor with any valid shape.

    Logsigmoid is defined as:

    .. math::
        \text{logsigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),

    where :math:`x_{i}` is an element of the input.

    Args:
        x (Tensor): The input of LogSigmoid with data type of float16 or float32.
            The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Returns:
        Tensor, with the same type and shape as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
        >>> output = ops.logsigmoid(x)
        >>> print(output)
        [-0.31326166 -0.12692806 -0.04858734]
    """
    # logsigmoid(x) = log(1 / (1 + exp(-x))), composed from cached primitive ops.
    output = _get_cache_prim(P.Mul)()(x, -1)
    output = _get_cache_prim(P.Exp)()(output)
    output = _get_cache_prim(P.Add)()(output, 1)
    output = _get_cache_prim(P.Reciprocal)()(output)
    ret = _get_cache_prim(P.Log)()(output)
    return ret
def deformable_conv2d(x, weight, offsets, kernel_size, strides, padding, bias=None, dilations=(1, 1, 1, 1), groups=1,
                      deformable_groups=1, modulated=True):
    r"""
@@ -2646,6 +2687,65 @@ def prelu(x, weight):
    return prelu_(x, weight)


def rrelu(x, lower=1 / 8, upper=1 / 3):
    r"""
    Randomized Leaky ReLU activation function.

    The activation function is defined as:

    .. math::
        \text{rrelu}(x_{ji}) = \begin{cases}x_{ji}, &\text{if } x_{ji} \geq 0; \cr
        {\alpha_{ji}} * x_{ji}, &\text{otherwise.}\end{cases}

    where :math:`\alpha_{ji}` ~ :math:`U(l, u)`, :math:`l \le u`.

    Applies the rrelu function element-wise, as described in the paper:
    `Empirical Evaluation of Rectified Activations in Convolution Network <https://arxiv.org/pdf/1505.00853.pdf>`_ .

    Args:
        x (Tensor): The input of rrelu is a Tensor of any dimension.
        lower (Union[int, float]): Lower bound of the slope of the activation function at x < 0. Default: 1/8.
        upper (Union[int, float]): Upper bound of the slope of the activation function at x < 0. Default: 1/3.

    Returns:
        Tensor, after rrelu, has the same type and shape as the `x`.

    Raises:
        TypeError: If `lower` is not a float or an int.
        TypeError: If `upper` is not a float or an int.
        TypeError: If `x` is not a Tensor.
        TypeError: If `x` is not a Tensor of mindspore.float16 or mindspore.float32.
        ValueError: If `lower` is greater than `upper`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[-1.0, 4.0], [2.0, 0]]), mindspore.float32)
        >>> output = ops.rrelu(x)
        >>> print(output)
        [[-0.31465699  4.        ]
         [ 2.          0.        ]]
    """
    if not isinstance(upper, (float, int)):
        raise TypeError(f"For 'ops.rrelu', `upper` must be an int or a float, but got {type(upper)}")
    if not isinstance(lower, (float, int)):
        raise TypeError(f"For 'ops.rrelu', `lower` must be an int or a float, but got {type(lower)}")
    if lower > upper:
        raise ValueError(f"For 'ops.rrelu', the value of `upper` must not be less than `lower`, "
                         f"but got upper: {upper}, lower: {lower}. ")
    size = x.shape
    # sign(x) is -1 for negative elements, 0 for zeros and 1 for positive elements.
    sign_matrix = _get_cache_prim(P.Sign)()(x)
    negative_filter = sign_matrix.clip(None, 0)
    positive_filter = sign_matrix.clip(0, None)
    # Sample one slope per element from U(lower, upper); it is applied only to negative elements.
    mask = _get_cache_prim(P.Cast)()(Tensor(np.random.uniform(lower, upper, size=size)), _get_cache_prim(P.DType)()(x))
    negative_mask = negative_filter * mask * -1
    total_mask = negative_mask + positive_filter
    out = total_mask * x
    return out


def mirror_pad(input_x, paddings, mode):
"""
Pads the input tensor according to the paddings and mode.
@@ -5266,8 +5366,10 @@ __all__ = [
    'conv3d_transpose',
    'conv2d',
    'sigmoid',
    'logsigmoid',
    'relu',
    'relu6',
    'rrelu',
    'conv3d',
    'glu',
    'margin_ranking_loss',
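
For intuition, here is a NumPy-only sketch of the two constructions used in the implementations above — the log(1/(1+exp(-x))) composition behind `logsigmoid` and the sign-based random mask behind `rrelu`. This is an illustrative re-derivation, not code from the diff:

import numpy as np


def logsigmoid_ref(x):
    # Same composition as the primitives above: log(1 / (1 + exp(-x))).
    return np.log(1.0 / (1.0 + np.exp(-x)))


def rrelu_ref(x, lower=1 / 8, upper=1 / 3):
    sign = np.sign(x)                           # -1, 0 or 1 per element
    negative_filter = np.clip(sign, None, 0)    # -1 where x < 0, else 0
    positive_filter = np.clip(sign, 0, None)    #  1 where x > 0, else 0
    mask = np.random.uniform(lower, upper, size=x.shape).astype(x.dtype)
    # Negative elements get a random slope from U(lower, upper); positive ones keep slope 1.
    total_mask = negative_filter * mask * -1 + positive_filter
    return total_mask * x


print(logsigmoid_ref(np.array([1.0, 2.0, 3.0], dtype=np.float32)))
# ~[-0.3133 -0.1269 -0.0486], matching the docstring example
print(rrelu_ref(np.array([[-1.0, 4.0], [2.0, 0.0]], dtype=np.float32)))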


@@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class Net(nn.Cell):
    def construct(self, x):
        output = ops.logsigmoid(x)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_logsigmoid(mode):
    """
    Feature: logsigmoid
    Description: Verify the result of logsigmoid
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
    out = net(x)
    expect_out = np.array([-0.31326166, -0.12692806, -0.04858734])
    assert np.allclose(out.asnumpy(), expect_out)


@@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class Net(nn.Cell):
    def construct(self, x):
        output = ops.rrelu(x)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_rrelu(mode):
    """
    Feature: rrelu
    Description: Verify the result of rrelu
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    # All inputs are non-negative, so rrelu leaves them unchanged and the
    # expected output is deterministic despite the random negative slope.
    x = ms.Tensor([[1.0, 4.0], [2.0, 0]], dtype=ms.float32)
    out = net(x)
    expect_out = np.array([[1., 4.],
                           [2., 0.]])
    assert np.allclose(out.asnumpy(), expect_out)