!34842 add nn.Threshold
Merge pull request !34842 from 吕昱峰(Nate.River)/threshold
This commit is contained in: commit f3148dbe6e
@ -124,7 +124,8 @@ MindSpore中 `mindspore.nn` 接口与上一版本相比,新增、删除和支
mindspore.nn.Softmax
mindspore.nn.SoftShrink
mindspore.nn.Tanh
mindspore.nn.Threshold

Linear Layer
-----------------
@ -0,0 +1,33 @@
mindspore.nn.Threshold
=============================

.. py:class:: mindspore.nn.Threshold(threshold, value)

    Threshold activation function. Computes the output element-wise.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    **Parameters:**

    - **threshold** (`Union[int, float]`) – The threshold value.
    - **value** (`Union[int, float]`) – The value to fill in when an element of the input Tensor is not greater than the threshold.

    **Inputs:**

    - **input_x** (Tensor) - The input of Threshold, with data type of float16 or float32.

    **Outputs:**

    Tensor, with the same data type and shape as `input_x`.

    **Raises:**

    - **TypeError** - `threshold` is neither a float nor an int.
    - **TypeError** - `value` is neither a float nor an int.
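A minimal usage sketch of the interface documented above (it mirrors the Examples block added to the Python docstring later in this diff; the standard `mindspore` import names are assumed):

    import mindspore
    import mindspore.nn as nn

    # threshold=0.1, value=20: elements not greater than 0.1 are replaced by 20.
    m = nn.Threshold(0.1, 20)
    x = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
    y = m(x)
    print(y)  # values: [20.0, 0.2, 0.3]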
@ -124,7 +124,8 @@ Nonlinear Activation Function Layer
mindspore.nn.Softmax
mindspore.nn.SoftShrink
mindspore.nn.Tanh
mindspore.nn.Threshold

Linear Layer
------------
@ -45,6 +45,7 @@ __all__ = ['Softmax',
           'SoftShrink',
           'HShrink',
           'CELU',
           'Threshold'
           ]
@ -1070,6 +1071,61 @@ class HShrink(Cell):
        return self.hshrink(input_x)


class Threshold(Cell):
    r"""Thresholds each element of the input Tensor.

    The formula is defined as follows:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold (Union[int, float]): The value to threshold at.
        value (Union[int, float]): The value to replace with when an element is not greater than the threshold.

    Inputs:
        - **input_x** (Tensor) - The input of Threshold with data type of float16 or float32.

    Outputs:
        Tensor, the same shape and data type as the input.

    Supported Platforms:
        ``Ascend`` ``CPU`` ``GPU``

    Raises:
        TypeError: If `threshold` is not a float or an int.
        TypeError: If `value` is not a float or an int.

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> m = nn.Threshold(0.1, 20)
        >>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
        >>> outputs = m(inputs)
        >>> print(outputs)
        [ 20.0 0.2 0.3]
    """

    def __init__(self, threshold, value):
        """Initialize Threshold."""
        super().__init__()
        validator.check_value_type('threshold', threshold, [float, int], self.cls_name)
        validator.check_value_type('value', value, [float, int], self.cls_name)
        self.threshold = threshold
        self.value = value
        self.greater = P.Greater()
        self.fill = P.Fill()
        self.select = P.Select()

    def construct(self, input_x):
        # Boolean mask: True where the element is strictly greater than the threshold.
        cond = self.greater(input_x, self.threshold)
        # A tensor filled with `value`, matching the input's dtype and shape.
        value = self.fill(input_x.dtype, input_x.shape, self.value)
        # Keep the original element where the mask is True, otherwise take `value`.
        return self.select(cond, input_x, value)
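
    # Illustrative reference (not part of this change): the construct above is an
    # element-wise "x if x > threshold else value". A NumPy sketch of the same
    # computation, with a hypothetical helper name:
    #
    #     import numpy as np
    #
    #     def threshold_reference(x, threshold, value):
    #         return np.where(x > threshold, x, value)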


_activation = {
    'softmax': Softmax,
    'logsoftmax': LogSoftmax,
@ -1089,6 +1145,7 @@ _activation = {
    'logsigmoid': LogSigmoid,
    'softshrink': SoftShrink,
    'hshrink': HShrink,
    'threshold': Threshold
}
@ -147,3 +147,24 @@ def test_hard_tanh():
    net = Hardtanh(-1.0, 1.0)
    input_data = Tensor(np.array([[1.6, 0, 0.6], [6, 0, -6]], dtype=np.float32))
    _cell_graph_executor.compile(net, input_data)


class NetThreshold(nn.Cell):
    """Threshold."""
    def __init__(self, threshold, value):
        super(NetThreshold, self).__init__()
        self.threshold = nn.Threshold(threshold, value)

    def construct(self, x):
        return self.threshold(x)


def test_compile_threshold():
    """
    Feature: Test Threshold.
    Description: Test Threshold functional.
    Expectation: Success.
    """
    net = NetThreshold(threshold=0.1, value=1.0)
    input_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.0, 0.1, 0.2]], dtype=np.float32))
    _cell_graph_executor.compile(net, input_data)
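
The compile-only test above does not check the computed values. A hedged sketch of an additional value check (illustrative only; the helper name `check_threshold_values` is hypothetical, and it reuses the test file's existing `np`, `Tensor`, and `nn` imports):

    def check_threshold_values():
        """Illustrative value check: elements not greater than 0.1 become 1.0."""
        net = NetThreshold(threshold=0.1, value=1.0)
        x = Tensor(np.array([[0.1, 0.2, 0.3], [0.0, 0.1, 0.2]], dtype=np.float32))
        out = net(x)
        # threshold=0.1, value=1.0: 0.1 and 0.0 are replaced, the rest pass through.
        expected = np.array([[1.0, 0.2, 0.3], [1.0, 1.0, 0.2]], dtype=np.float32)
        assert np.allclose(out.asnumpy(), expected)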