!34842 add nn.Threshold

Merge pull request !34842 from 吕昱峰(Nate.River)/threshold
This commit is contained in:
i-robot 2022-05-25 09:35:33 +00:00 committed by Gitee
commit f3148dbe6e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 115 additions and 2 deletions

View File

@ -124,7 +124,8 @@ MindSpore中 `mindspore.nn` 接口与上一版本相比,新增、删除和支
mindspore.nn.Softmax
mindspore.nn.SoftShrink
mindspore.nn.Tanh
mindspore.nn.Threshold
线性层
-----------------

View File

@ -0,0 +1,33 @@
mindspore.nn.Threshold
=============================
.. py:class:: mindspore.nn.Threshold
Threshold激活函数按元素计算输出。
Threshold定义为
.. math::
y =
\begin{cases}
x, &\text{ if } x > \text{threshold} \\
\text{value}, &\text{ otherwise }
\end{cases}
**参数:**
**threshold** (`Union[int, float]`) Threshold的阈值。
**value** (`Union[int, float]`) 输入Tensor中element小于阈值时的填充值。
**输入:**
- **input_x** (Tensor) - Threshold的输入数据类型为float16或float32。
**输出:**
Tensor数据类型和shape与 `input_x` 的相同。
**异常:**
**TypeError** - `threshold` 不是浮点数或整数。
**TypeError** - `value` 不是浮点数或整数。

View File

@ -124,7 +124,8 @@ Nonlinear Activation Function Layer
mindspore.nn.Softmax
mindspore.nn.SoftShrink
mindspore.nn.Tanh
mindspore.nn.Threshold
Linear Layer
------------

View File

@ -45,6 +45,7 @@ __all__ = ['Softmax',
'SoftShrink',
'HShrink',
'CELU',
'Threshold'
]
@ -1070,6 +1071,61 @@ class HShrink(Cell):
return self.hshrink(input_x)
class Threshold(Cell):
r"""Thresholds each element of the input Tensor.
The formula is defined as follows:
.. math::
y =
\begin{cases}
x, &\text{ if } x > \text{threshold} \\
\text{value}, &\text{ otherwise }
\end{cases}
Args:
threshold: The value to threshold at.
value: The value to replace with when element is less than threshold.
Inputs:
- **input_x** (Tensor) - The input of Threshold with data type of float16 or float32.
Outputs:
Tensor, the same shape and data type as the input.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Raises:
TypeError: If `threshold` is not a float or an int.
TypeError: If `value` is not a float or an int.
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> m = nn.Threshold(0.1, 20)
>>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
>>> outputs = m(inputs)
[ 20.0 0.2 0.3]
"""
def __init__(self, threshold, value):
"""Initialize Threshold."""
super().__init__()
validator.check_value_type('threshold', threshold, [float, int], self.cls_name)
validator.check_value_type('value', value, [float, int], self.cls_name)
self.threshold = threshold
self.value = value
self.greater = P.Greater()
self.fill = P.Fill()
self.select = P.Select()
def construct(self, input_x):
cond = self.greater(input_x, self.threshold)
value = self.fill(input_x.dtype, input_x.shape, self.value)
return self.select(cond, input_x, value)
_activation = {
'softmax': Softmax,
'logsoftmax': LogSoftmax,
@ -1089,6 +1145,7 @@ _activation = {
'logsigmoid': LogSigmoid,
'softshrink': SoftShrink,
'hshrink': HShrink,
'threshold': Threshold
}

View File

@ -147,3 +147,24 @@ def test_hard_tanh():
net = Hardtanh(-1.0, 1.0)
input_data = Tensor(np.array([[1.6, 0, 0.6], [6, 0, -6]], dtype=np.float32))
_cell_graph_executor.compile(net, input_data)
class NetThreshold(nn.Cell):
"""Threshold."""
def __init__(self, threshold, value):
super(NetThreshold, self).__init__()
self.threshold = nn.Threshold(threshold, value)
def construct(self, x):
return self.threshold(x)
def test_compile_threshold():
"""
Feature: Test Threshold.
Description: Test Threshold functional.
Expectation: Success.
"""
net = NetThreshold(threshold=0.1, value=1.0)
input_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.0, 0.1, 0.2]], dtype=np.float32))
_cell_graph_executor.compile(net, input_data)