From ce1f15666d689c72b4272b1e2c1a5ad3e8704afe Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 23 May 2022 14:51:53 +0800 Subject: [PATCH] add nn.Threshold --- docs/api/api_python/mindspore.nn.rst | 3 +- .../api_python/nn/mindspore.nn.Threshold.rst | 33 +++++++++++ docs/api/api_python_en/mindspore.nn.rst | 3 +- .../python/mindspore/nn/layer/activation.py | 57 +++++++++++++++++++ tests/ut/python/nn/test_activation.py | 21 +++++++ 5 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 docs/api/api_python/nn/mindspore.nn.Threshold.rst diff --git a/docs/api/api_python/mindspore.nn.rst b/docs/api/api_python/mindspore.nn.rst index e49f436b5a1..2c7c9503a2a 100644 --- a/docs/api/api_python/mindspore.nn.rst +++ b/docs/api/api_python/mindspore.nn.rst @@ -124,7 +124,8 @@ MindSpore中 `mindspore.nn` 接口与上一版本相比,新增、删除和支 mindspore.nn.Softmax mindspore.nn.SoftShrink mindspore.nn.Tanh - + mindspore.nn.Threshold + 线性层 ----------------- diff --git a/docs/api/api_python/nn/mindspore.nn.Threshold.rst b/docs/api/api_python/nn/mindspore.nn.Threshold.rst new file mode 100644 index 00000000000..205a9b14709 --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.Threshold.rst @@ -0,0 +1,33 @@ +mindspore.nn.Threshold +============================= + +.. py:class:: mindspore.nn.Threshold + + Threshold激活函数,按元素计算输出。 + + Threshold定义为: + + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} + + **参数:** + + **threshold** (`Union[int, float]`) – Threshold的阈值。 + **value** (`Union[int, float]`) – 输入Tensor中element小于阈值时的填充值。 + + **输入:** + + - **input_x** (Tensor) - Threshold的输入,数据类型为float16或float32。 + + **输出:** + + Tensor,数据类型和shape与 `input_x` 的相同。 + + **异常:** + + **TypeError** - `threshold` 不是浮点数或整数。 + **TypeError** - `value` 不是浮点数或整数。 diff --git a/docs/api/api_python_en/mindspore.nn.rst b/docs/api/api_python_en/mindspore.nn.rst index 6b3cda92a5d..2d46781a690 100644 --- a/docs/api/api_python_en/mindspore.nn.rst +++ b/docs/api/api_python_en/mindspore.nn.rst @@ -124,7 +124,8 @@ Nonlinear Activation Function Layer mindspore.nn.Softmax mindspore.nn.SoftShrink mindspore.nn.Tanh - + mindspore.nn.Threshold + Linear Layer ------------ diff --git a/mindspore/python/mindspore/nn/layer/activation.py b/mindspore/python/mindspore/nn/layer/activation.py index 46d911e5b6d..759e6df4da1 100644 --- a/mindspore/python/mindspore/nn/layer/activation.py +++ b/mindspore/python/mindspore/nn/layer/activation.py @@ -45,6 +45,7 @@ __all__ = ['Softmax', 'SoftShrink', 'HShrink', 'CELU', + 'Threshold' ] @@ -1070,6 +1071,61 @@ class HShrink(Cell): return self.hshrink(input_x) +class Threshold(Cell): + r"""Thresholds each element of the input Tensor. + + The formula is defined as follows: + + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} + + Args: + threshold: The value to threshold at. + value: The value to replace with when element is less than threshold. + + Inputs: + - **input_x** (Tensor) - The input of Threshold with data type of float16 or float32. + + Outputs: + Tensor, the same shape and data type as the input. + + Supported Platforms: + ``Ascend`` ``CPU`` ``GPU`` + + Raises: + TypeError: If `threshold` is not a float or an int. + TypeError: If `value` is not a float or an int. + + Examples: + >>> import mindspore + >>> import mindspore.nn as nn + >>> m = nn.Threshold(0.1, 20) + >>> inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32) + >>> outputs = m(inputs) + [ 20.0 0.2 0.3] + """ + + def __init__(self, threshold, value): + """Initialize Threshold.""" + super().__init__() + validator.check_value_type('threshold', threshold, [float, int], self.cls_name) + validator.check_value_type('value', value, [float, int], self.cls_name) + self.threshold = threshold + self.value = value + self.greater = P.Greater() + self.fill = P.Fill() + self.select = P.Select() + + def construct(self, input_x): + cond = self.greater(input_x, self.threshold) + value = self.fill(input_x.dtype, input_x.shape, self.value) + return self.select(cond, input_x, value) + + _activation = { 'softmax': Softmax, 'logsoftmax': LogSoftmax, @@ -1089,6 +1145,7 @@ _activation = { 'logsigmoid': LogSigmoid, 'softshrink': SoftShrink, 'hshrink': HShrink, + 'threshold': Threshold } diff --git a/tests/ut/python/nn/test_activation.py b/tests/ut/python/nn/test_activation.py index 76cd618a9ee..8123d43a85a 100755 --- a/tests/ut/python/nn/test_activation.py +++ b/tests/ut/python/nn/test_activation.py @@ -147,3 +147,24 @@ def test_hard_tanh(): net = Hardtanh(-1.0, 1.0) input_data = Tensor(np.array([[1.6, 0, 0.6], [6, 0, -6]], dtype=np.float32)) _cell_graph_executor.compile(net, input_data) + + +class NetThreshold(nn.Cell): + """Threshold.""" + def __init__(self, threshold, value): + super(NetThreshold, self).__init__() + self.threshold = nn.Threshold(threshold, value) + + def construct(self, x): + return self.threshold(x) + + +def test_compile_threshold(): + """ + Feature: Test Threshold. + Description: Test Threshold functional. + Expectation: Success. + """ + net = NetThreshold(threshold=0.1, value=1.0) + input_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.0, 0.1, 0.2]], dtype=np.float32)) + _cell_graph_executor.compile(net, input_data)