From 4fd3433d23355ac1aae2f2113db9844cd55fc66e Mon Sep 17 00:00:00 2001
From: fujianzhao
Date: Wed, 11 May 2022 15:39:02 +0800
Subject: [PATCH] [feat] [assistant] [ops] [I51VSC] Add new operator
 nn.HuberLoss

---
 docs/api/api_python/mindspore.nn.rst           |   1 +
 .../api_python/nn/mindspore.nn.HuberLoss.rst   |  50 ++++++++
 docs/api/api_python_en/mindspore.nn.rst        |   1 +
 .../python/mindspore/nn/loss/__init__.py       |   6 +-
 mindspore/python/mindspore/nn/loss/loss.py     | 120 +++++++++++++++++-
 tests/ut/python/nn/test_loss.py                |  14 ++-
 6 files changed, 187 insertions(+), 5 deletions(-)
 create mode 100644 docs/api/api_python/nn/mindspore.nn.HuberLoss.rst

diff --git a/docs/api/api_python/mindspore.nn.rst b/docs/api/api_python/mindspore.nn.rst
index e3b902b6c9c..22de7c6d12c 100644
--- a/docs/api/api_python/mindspore.nn.rst
+++ b/docs/api/api_python/mindspore.nn.rst
@@ -196,6 +196,7 @@ Dropout Layer
     mindspore.nn.CosineEmbeddingLoss
     mindspore.nn.DiceLoss
     mindspore.nn.FocalLoss
+    mindspore.nn.HuberLoss
     mindspore.nn.L1Loss
     mindspore.nn.MSELoss
     mindspore.nn.MultiClassDiceLoss
diff --git a/docs/api/api_python/nn/mindspore.nn.HuberLoss.rst b/docs/api/api_python/nn/mindspore.nn.HuberLoss.rst
new file mode 100644
index 00000000000..bbd0f713b48
--- /dev/null
+++ b/docs/api/api_python/nn/mindspore.nn.HuberLoss.rst
@@ -0,0 +1,50 @@
+mindspore.nn.HuberLoss
+=============================
+
+.. py:class:: mindspore.nn.HuberLoss(reduction='mean', delta=1.0)
+
+    HuberLoss calculates the error between the predicted value and the target value. It has the advantages of both L1Loss and MSELoss.
+
+    Assuming that :math:`x` and :math:`y` are 1-D Tensors of length :math:`N`, the loss of :math:`x` and :math:`y` without dimensionality reduction (the `reduction` parameter set to "none") is computed as follows:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top
+
+    with
+
+    .. math::
+        l_n = \begin{cases}
+            0.5 * (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta; \\
+            delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise. }
+        \end{cases}
+
+    where :math:`N` is the batch size. If `reduction` is not "none", then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    **Parameters:**
+
+    - **reduction** (str) - Type of reduction to apply to the loss. The optional values are "mean", "sum", and "none". Default: "mean". If `reduction` is "mean" or "sum", a scalar Tensor is output; if `reduction` is "none", the shape of the output Tensor is the broadcasted shape.
+    - **delta** (Union[int, float]) - The threshold at which to change between the two types of loss. The value must be positive. Default: 1.0.
+
+    **Inputs:**
+
+    - **logits** (Tensor) - Predicted value, a Tensor of arbitrary dimensions. Its data type is float16 or float32.
+    - **labels** (Tensor) - Target value, which normally has the same shape and dtype as `logits`. If the shapes of `logits` and `labels` differ, they must be broadcastable to each other.
+
+    **Outputs:**
+
+    Tensor or Scalar. If `reduction` is "none", a Tensor with the same shape and data type as `logits` is returned. Otherwise, a scalar is returned.
+
+    **Raises:**
+
+    - **TypeError** - The data type of `logits` or `labels` is neither float16 nor float32.
+    - **TypeError** - `delta` is neither float nor int.
+    - **ValueError** - The value of `delta` is less than or equal to 0.
+    - **ValueError** - `reduction` is not "mean", "sum", or "none".
+    - **ValueError** - `logits` and `labels` have different shapes and cannot be broadcast to each other.
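
NOTE (illustrative, not part of the diff): a quick numeric check of the piecewise
definition above, taking delta = 1.0. An error of 0.5 falls in the quadratic branch,
giving 0.5 * 0.5^2 = 0.125; an error of 2.0 falls in the linear branch, giving
1.0 * (2.0 - 0.5 * 1.0) = 1.5. At |x_n - y_n| = delta both branches evaluate to
0.5 * delta^2, so the loss is continuous at the threshold.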
diff --git a/docs/api/api_python_en/mindspore.nn.rst b/docs/api/api_python_en/mindspore.nn.rst
index 219d98d1a2d..3e84c915591 100644
--- a/docs/api/api_python_en/mindspore.nn.rst
+++ b/docs/api/api_python_en/mindspore.nn.rst
@@ -196,6 +196,7 @@ Loss Function
     mindspore.nn.CosineEmbeddingLoss
     mindspore.nn.DiceLoss
     mindspore.nn.FocalLoss
+    mindspore.nn.HuberLoss
     mindspore.nn.L1Loss
     mindspore.nn.MSELoss
     mindspore.nn.MultiClassDiceLoss
diff --git a/mindspore/python/mindspore/nn/loss/__init__.py b/mindspore/python/mindspore/nn/loss/__init__.py
index 1bd4bc7714d..f7e68ada25b 100644
--- a/mindspore/python/mindspore/nn/loss/__init__.py
+++ b/mindspore/python/mindspore/nn/loss/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,10 +22,10 @@ It shows how well the model works on a dataset and the optimization target which
 from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
-    RMSELoss, MAELoss
+    RMSELoss, MAELoss, HuberLoss
 
 __all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
-           'RMSELoss', 'MAELoss']
+           'RMSELoss', 'MAELoss', 'HuberLoss']
 
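
NOTE (illustrative, not part of the diff): with the export above, the new cell is
reachable through the usual nn namespace, e.g.:

    from mindspore import nn

    # Both arguments are optional; the defaults are shown explicitly here.
    loss_fn = nn.HuberLoss(reduction='mean', delta=1.0)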
diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py
index ae526bde3bf..22d0feddd52 100644
--- a/mindspore/python/mindspore/nn/loss/loss.py
+++ b/mindspore/python/mindspore/nn/loss/loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1471,3 +1471,121 @@ class FocalLoss(LossBase):
         loss = (-1 * weight * labelss * log_probability).mean(axis=-1)
 
         return self.get_loss(loss)
+
+
+@constexpr
+def _dtype_check(logits_dtype, labels_dtype, prim_name):
+    """Check that the logits and labels dtypes are valid and consistent."""
+    if logits_dtype not in [mstype.float32, mstype.float16]:
+        raise TypeError("For {}, the logits_dtype must be float32 or float16, but got {}.".format(prim_name,
+                                                                                                  logits_dtype))
+    if logits_dtype != labels_dtype:
+        raise TypeError("For {}, the labels_dtype must be equal to logits_dtype {}, but got {}.".format(prim_name,
+                                                                                                        logits_dtype,
+                                                                                                        labels_dtype))
+
+
+class HuberLoss(LossBase):
+    r"""
+    HuberLoss calculates the error between the predicted value and the target value.
+    It has the advantages of both L1Loss and MSELoss.
+
+    Assuming that :math:`x` and :math:`y` are 1-D Tensors of length :math:`N`, the loss of :math:`x` and
+    :math:`y` without dimensionality reduction (the `reduction` parameter set to "none") is computed as follows:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top
+
+    with
+
+    .. math::
+        l_n = \begin{cases}
+            0.5 * (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta; \\
+            delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise. }
+        \end{cases}
+
+    where :math:`N` is the batch size. If `reduction` is not 'none', then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    Args:
+        reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum",
+            and "none". Default: "mean". If `reduction` is "mean" or "sum", a scalar Tensor is output;
+            if `reduction` is "none", the shape of the output Tensor is the broadcasted shape.
+        delta (Union[int, float]): The threshold at which to change between the two types of loss.
+            The value must be positive. Default: 1.0.
+
+    Inputs:
+        - **logits** (Tensor) - Input logits with shape :math:`(N, *)` where :math:`*` means any number
+          of additional dimensions. The data type must be float16 or float32.
+        - **labels** (Tensor) - Ground truth label with shape :math:`(N, *)`, same dtype as `logits`.
+          The shape of `labels` may differ from the shape of `logits` as long as the two can be
+          broadcast to each other.
+
+    Outputs:
+        Tensor or Scalar, if `reduction` is "none", its shape is the same as `logits`.
+        Otherwise, a scalar value will be returned.
+
+    Raises:
+        TypeError: If data type of `logits` or `labels` is neither float16 nor float32.
+        TypeError: If dtype of `delta` is neither float nor int.
+        ValueError: If `delta` is less than or equal to 0.
+        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+        ValueError: If `logits` and `labels` have different shapes and cannot be broadcast to each other.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> # Case 1: logits.shape = labels.shape = (3,)
+        >>> loss = nn.HuberLoss()
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        0.16666667
+        >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
+        >>> loss = nn.HuberLoss(reduction='none')
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        [[0.  0.5 1.5]
+         [0.  0.  0.5]]
+    """
+
+    def __init__(self, reduction='mean', delta=1.0):
+        """Initialize HuberLoss."""
+        super(HuberLoss, self).__init__(reduction=reduction)
+        validator.check_value_type('delta', delta, [float, int], self.cls_name)
+        validator.check_number("delta", delta, 0.0, Rel.GT, self.cls_name)
+        self.sub = P.Sub()
+        self.mul = P.Mul()
+        self.abs = P.Abs()
+        self.less = P.Less()
+        self.square = P.Square()
+        self.select = P.Select()
+        self.dtype = P.DType()
+        self.delta = delta
+        self.delta_half = 0.5 * self.delta
+
+    def construct(self, logits, labels):
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
+        logits_dtype = self.dtype(logits)
+        labels_dtype = self.dtype(labels)
+        _dtype_check(logits_dtype, labels_dtype, self.cls_name)
+        # The absolute error decides which branch of the piecewise loss applies.
+        z = self.abs(self.sub(logits, labels))
+        condition = self.less(z, self.delta)
+        # Quadratic branch below the threshold, linear branch at or above it.
+        l1 = self.mul(0.5, self.square(z))
+        l2 = self.mul(self.delta, self.sub(z, self.delta_half))
+        loss = self.select(condition, l1, l2)
+
+        return self.get_loss(loss)
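
NOTE (illustrative, not part of the diff): a minimal NumPy sketch of the piecewise
rule that construct() implements with P.Select; it reproduces the docstring's first
example:

    import numpy as np

    def huber_ref(logits, labels, delta=1.0):
        # np.where plays the role of P.Select: quadratic branch where the
        # absolute error is below delta, linear branch elsewhere.
        z = np.abs(logits - labels)
        quadratic = 0.5 * np.square(z)
        linear = delta * (z - 0.5 * delta)
        return np.where(z < delta, quadratic, linear).mean()  # 'mean' reduction

    print(huber_ref(np.array([1., 2., 3.]), np.array([1., 2., 2.])))  # 0.16666667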
diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py
index 0bf1e9ef6bf..940ab824574 100644
--- a/tests/ut/python/nn/test_loss.py
+++ b/tests/ut/python/nn/test_loss.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -218,3 +218,15 @@ def test_mae_loss():
     input_data = Tensor(np.array([[1, 2, 3], [2, 3, 2]]).astype(np.float32))
     target_data = Tensor(np.array([[0, 0, 5], [1, 2, 3]]).astype(np.float32))
     loss(input_data, target_data)
+
+
+def test_huber_loss():
+    """
+    Feature: Test HuberLoss.
+    Description: Test HuberLoss functional.
+    Expectation: Success.
+    """
+    loss = nn.HuberLoss()
+    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
+    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
+    loss(input_data, target_data)
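
NOTE (illustrative, not part of the diff): the new test only checks that the forward
pass runs. If a numeric assertion is wanted later, the expected value for the tensors
above can be computed by hand; a possible check, reusing the test's imports:

    # |input - target| = [[1, 0, 2], [1, 2, 3]], so with delta=1.0 the per-element
    # losses are [[0.5, 0, 1.5], [0.5, 1.5, 2.5]]; the 'mean' reduction gives 6.5 / 6.
    output = loss(input_data, target_data)
    assert np.allclose(output.asnumpy(), 6.5 / 6)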