!34221 [feat] [assistant] [ops] [I51VSC] Add new operator nn.HuberLoss

Merge pull request !34221 from fujianzhao/huberloss
i-robot 2022-05-13 02:57:19 +00:00 committed by Gitee
commit c07819796c
6 changed files with 185 additions and 5 deletions

View File

@ -202,6 +202,7 @@ Dropout layer
mindspore.nn.CosineEmbeddingLoss
mindspore.nn.DiceLoss
mindspore.nn.FocalLoss
mindspore.nn.HuberLoss
mindspore.nn.L1Loss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss

View File

@ -0,0 +1,50 @@
mindspore.nn.HuberLoss
=============================

.. py:class:: mindspore.nn.HuberLoss(reduction='mean', delta=1.0)

    HuberLoss computes the error between the predicted value and the target value. It combines the advantages of L1Loss and MSELoss.

    Assuming that :math:`x` and :math:`y` are 1-D Tensors of length :math:`N`, the loss of :math:`x` and :math:`y` without dimensionality reduction (the `reduction` parameter set to "none") is computed as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top

    with

    .. math::
        l_n = \begin{cases}
            0.5 * (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta; \\
            delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise. }
        \end{cases}

    where :math:`N` is the batch size. If `reduction` is not "none", then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
        \end{cases}

    **Parameters:**

    - **reduction** (str) - Type of reduction to apply to the loss. One of "mean", "sum", or "none". Default: "mean". If `reduction` is "mean" or "sum", a scalar Tensor is output; if `reduction` is "none", the shape of the output Tensor is the broadcasted shape.
    - **delta** (Union[int, float]) - The threshold at which the loss switches between the two branches. The value must be positive. Default: 1.0.

    **Inputs:**

    - **logits** (Tensor) - Predicted values, a Tensor of any dimension. The data type must be float16 or float32.
    - **labels** (Tensor) - Target values, usually with the same shape and dtype as `logits`. If the shapes of `logits` and `labels` differ, they must be broadcastable to each other.

    **Outputs:**

    Tensor or Scalar. If `reduction` is "none", a Tensor with the same data type as `logits` and the broadcasted shape. Otherwise, a Scalar.

    **Raises:**

    - **TypeError** - The data type of `logits` or `labels` is neither float16 nor float32.
    - **TypeError** - `delta` is neither float nor int.
    - **ValueError** - The value of `delta` is less than or equal to 0.
    - **ValueError** - `reduction` is not "mean", "sum", or "none".
    - **ValueError** - `logits` and `labels` have different shapes and cannot be broadcast to each other.
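    As a worked instance of the formulas above (values taken from Case 1 of the Examples in the English docstring later in this diff): with :math:`x = [1, 2, 3]`, :math:`y = [1, 2, 2]`, and `delta` = 1.0, the residuals :math:`|x_n - y_n|` are :math:`[0, 0, 1]`. The first two entries fall in the quadratic branch and contribute 0; the third falls in the linear branch and contributes :math:`1.0 * (1 - 0.5 * 1.0) = 0.5`. The "mean" reduction therefore returns :math:`0.5 / 3 \approx 0.16666667`.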

View File

@ -202,6 +202,7 @@ Loss Function
mindspore.nn.CosineEmbeddingLoss
mindspore.nn.DiceLoss
mindspore.nn.FocalLoss
mindspore.nn.HuberLoss
mindspore.nn.L1Loss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,10 +22,10 @@ It shows how well the model works on a dataset and the optimization target which
from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
    RMSELoss, MAELoss
    RMSELoss, MAELoss, HuberLoss

__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
           'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
           'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
           'RMSELoss', 'MAELoss']
           'RMSELoss', 'MAELoss', 'HuberLoss']
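With the symbol re-exported here, the new class is reachable as `nn.HuberLoss`. A minimal usage sketch (not part of this diff; values taken from Case 1 of the docstring Examples):

import numpy as np
import mindspore
from mindspore import Tensor, nn

# nn.HuberLoss resolves through the export added above.
loss = nn.HuberLoss()
logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
print(loss(logits, labels))  # 0.16666667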

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1471,3 +1471,119 @@ class FocalLoss(LossBase):
        loss = (-1 * weight * labelss * log_probability).mean(axis=-1)
        return self.get_loss(loss)


@constexpr
def _dtype_check(logits_dtype, labels_dtype, prim_name):
    """Check that both dtypes are float16 or float32 and that they match."""
    if logits_dtype not in [mstype.float32, mstype.float16]:
        raise TypeError("For {}, the logits_dtype must be float32 or float16, but got {}.".format(prim_name,
                                                                                                  logits_dtype))
    if logits_dtype != labels_dtype:
        raise TypeError("For {}, the labels_dtype must be equal to logits_dtype {}, but got {}.".format(prim_name,
                                                                                                        logits_dtype,
                                                                                                        labels_dtype))


class HuberLoss(LossBase):
    r"""
    HuberLoss calculates the error between the predicted value and the target value.
    It combines the advantages of both L1Loss and MSELoss.

    Assuming that :math:`x` and :math:`y` are 1-D Tensors of length :math:`N`, the loss of :math:`x` and
    :math:`y` without dimensionality reduction (the reduction parameter set to "none") is computed as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top

    with

    .. math::
        l_n = \begin{cases}
            0.5 * (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta; \\
            delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise. }
        \end{cases}

    where :math:`N` is the batch size. If `reduction` is not 'none', then:

    .. math::
        \ell(x, y) =
        \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
        \end{cases}

    Args:
        reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum",
            and "none". Default: "mean". If `reduction` is "mean" or "sum", the output is a scalar Tensor;
            if `reduction` is "none", the shape of the output Tensor is the broadcasted shape.
        delta (Union[int, float]): The threshold at which the loss switches between the two types of loss.
            The value must be positive. Default: 1.0.

    Inputs:
        - **logits** (Tensor) - Input logits with shape :math:`(N, *)` where :math:`*` means any number
          of additional dimensions. The data type must be float16 or float32.
        - **labels** (Tensor) - Ground truth label with shape :math:`(N, *)`, same dtype as `logits`.
          The shape of `labels` may differ from the shape of `logits`, as long as the two can be
          broadcast to each other.

    Outputs:
        Tensor or Scalar. If `reduction` is "none", its shape is the broadcasted shape of `logits`
        and `labels`. Otherwise, a scalar value is returned.

    Raises:
        TypeError: If the data type of `logits` or `labels` is neither float16 nor float32.
        TypeError: If the dtype of `delta` is neither float nor int.
        ValueError: If `delta` is less than or equal to 0.
        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
        ValueError: If `logits` and `labels` have different shapes and cannot be broadcast to each other.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # Case 1: logits.shape = labels.shape = (3,)
        >>> loss = nn.HuberLoss()
        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
        >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
        >>> output = loss(logits, labels)
        >>> print(output)
        0.16666667
        >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
        >>> loss = nn.HuberLoss(reduction='none')
        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
        >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
        >>> output = loss(logits, labels)
        >>> print(output)
        [[0.  0.5 1.5]
         [0.  0.  0.5]]
    """

    def __init__(self, reduction='mean', delta=1.0):
        """Initialize HuberLoss."""
        super(HuberLoss, self).__init__(reduction=reduction)
        validator.check_value_type('delta', delta, [float, int], self.cls_name)
        validator.check_number("delta", delta, 0.0, Rel.GT, self.cls_name)
        self.sub = P.Sub()
        self.mul = P.Mul()
        self.abs = P.Abs()
        self.less = P.Less()
        self.square = P.Square()
        self.select = P.Select()
        self.dtype = P.DType()
        self.delta = delta
        self.delta_half = 0.5 * self.delta

    def construct(self, logits, labels):
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('labels', labels, self.cls_name)
        logits_dtype = self.dtype(logits)
        labels_dtype = self.dtype(labels)
        _dtype_check(logits_dtype, labels_dtype, self.cls_name)
        # |x - y|, then pick the quadratic branch below delta and the linear branch above it.
        z = self.abs(self.sub(logits, labels))
        condition = self.less(z, self.delta)
        l1 = self.mul(0.5, self.square(z))
        l2 = self.mul(self.delta, self.sub(z, self.delta_half))
        loss = self.select(condition, l1, l2)
        return self.get_loss(loss)
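For reference, a standalone NumPy sketch of the same piecewise computation (not part of this diff; `huber_np` is a hypothetical helper name), checked against Case 1 of the docstring above:

import numpy as np

def huber_np(x, y, delta=1.0, reduction="mean"):
    # Piecewise Huber: quadratic below delta, linear above, as in the docstring formula.
    z = np.abs(x - y)
    loss = np.where(z < delta, 0.5 * z ** 2, delta * (z - 0.5 * delta))
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

print(huber_np(np.array([1., 2., 3.], np.float32),
               np.array([1., 2., 2.], np.float32)))  # 0.16666667, matching Case 1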

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -218,3 +218,15 @@ def test_mae_loss():
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 2]]).astype(np.float32))
    target_data = Tensor(np.array([[0, 0, 5], [1, 2, 3]]).astype(np.float32))
    loss(input_data, target_data)


def test_huber_loss():
    """
    Feature: Test HuberLoss.
    Description: Test HuberLoss functional.
    Expectation: Success.
    """
    loss = nn.HuberLoss()
    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
    loss(input_data, target_data)
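A possible follow-up value check (a sketch, not part of the merged test; the test name and use of np.allclose are assumptions) comparing HuberLoss against a NumPy reference of the piecewise formula:

def test_huber_loss_value():
    """
    Feature: Test HuberLoss.
    Description: Sketch of a value check against a NumPy reference (assumed test, not in this diff).
    Expectation: Output matches the piecewise formula with delta=1.0.
    """
    input_np = np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)
    target_np = np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)
    z = np.abs(input_np - target_np)
    expected = np.where(z < 1.0, 0.5 * z ** 2, z - 0.5)  # delta = 1.0
    loss = nn.HuberLoss(reduction='none')
    output = loss(Tensor(input_np), Tensor(target_np))
    assert np.allclose(output.asnumpy(), expected)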