Implement Huber loss (#1444)

* Implement Huber loss

Instead of using a sign or abs function, this uses clamping to compute the
loss outside the bounds, which is better for the autodiff backend (a short
standalone sketch of the idea follows below).

* mention Huber loss in the book

* unify naming of residuals in comments
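
A standalone sketch (illustration only, not part of the diff) of why the clamping trick works: `clamp(r, -d, d) * r` equals `d * |r|` whenever `|r| > d`, so the linear branch needs neither `sign()` nor `abs()`. Plain Rust, with `delta = 0.5` to match the tests below:

```rust
/// Piecewise Huber definition, as written in the module docs.
fn huber_piecewise(r: f64, d: f64) -> f64 {
    if r.abs() <= d {
        0.5 * r * r
    } else {
        0.5 * d * d + d * (r.abs() - d)
    }
}

/// Clamp-based formulation used by the new `HuberLoss` module.
fn huber_clamped(r: f64, d: f64) -> f64 {
    let lin_bias = 0.5 * d * d; // precomputed, as in `HuberLoss::lin_bias`
    if r.abs() <= d {
        0.5 * r * r
    } else {
        // clamp(r, -d, d) * r == d * |r| outside the bound, no sign() needed
        r.clamp(-d, d) * r - lin_bias
    }
}

fn main() {
    for &r in &[-2.0, -0.5, 0.0, 0.3, 1.0] {
        assert!((huber_piecewise(r, 0.5) - huber_clamped(r, 0.5)).abs() < 1e-12);
    }
    println!("clamp-based Huber matches the piecewise definition");
}
```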
WorldSEnder 2024-03-13 17:55:46 +00:00 committed by GitHub
parent 7a98b2f663
commit 53eb3ecfa9
3 changed files with 181 additions and 0 deletions


@@ -162,3 +162,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
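
For illustration (not part of the book page), a hypothetical usage sketch based on the API introduced in this commit; the `NdArray` backend, the `from_floats`/`zeros` constructors, and the import paths are assumptions here, not taken from the diff:

```rust
use burn::backend::NdArray;
use burn::nn::loss::{HuberLoss, HuberLossConfig, Reduction};
use burn::tensor::Tensor;

fn main() {
    type B = NdArray;
    let device = Default::default();

    // Same values as the unit tests in this commit.
    let predictions = Tensor::<B, 1>::from_floats([-2.0, -0.5, 0.0, 0.3, 1.0], &device);
    let targets = Tensor::<B, 1>::zeros([5], &device);

    // delta = 0.5: quadratic for |residual| <= 0.5, linear beyond that.
    let huber: HuberLoss<B> = HuberLossConfig::new(0.5).init(&device);
    let loss = huber.forward(predictions, targets, Reduction::Mean);
    println!("mean Huber loss: {}", loss.into_scalar()); // expected 0.284
}
```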


@@ -0,0 +1,178 @@
use crate as burn;
use crate::{config::Config, module::Module};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use core::marker::PhantomData;
use super::Reduction;
/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
/// The bound where the Huber loss function changes from quadratic to linear behaviour.
pub delta: f32,
}
impl HuberLossConfig {
/// Initialize [Huber loss](HuberLoss).
pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
// device is not needed as of now, but we might want to prepare some data on it
// and it's consistent with the other loss functions
let _ = device;
self.assertions();
HuberLoss {
delta: self.delta,
lin_bias: self.delta * self.delta * 0.5,
_backend: PhantomData,
}
}
fn assertions(&self) {
assert!(
self.delta >= 0., // This also rejects NaN, since comparisons with NaN are false
"Delta for Huber loss must be a non-negative number."
);
}
}
/// Calculate the Huber loss between the inputs and the targets.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
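///
/// For example, with `delta = 1` a residual of `0.5` lies on the quadratic branch
/// and costs `0.5 * 0.5^2 = 0.125`, while a residual of `2` lies on the linear branch
/// and costs `0.5 * 1^2 + 1 * (2 - 1) = 1.5`.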
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
delta: f32,
lin_bias: f32, // delta * delta * 0.5 precomputed
_backend: PhantomData<B>,
}
impl<B: Backend> HuberLoss<B> {
/// Compute the loss element-wise for the predictions and targets, then reduce
/// to a single loss value.
///
/// `Reduction::Auto` behaves as `Reduction::Mean`.
///
/// # Shapes
///
/// - predictions: \[...dims\]
/// - targets: \[...dims\]
/// - output: \[1\]
pub fn forward<const D: usize>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let loss = self.forward_no_reduction(predictions, targets);
match reduction {
Reduction::Mean | Reduction::Auto => loss.mean(),
Reduction::Sum => loss.sum(),
}
}
/// Compute the loss element-wise for the predictions and targets.
///
/// # Shapes
///
/// - predictions: \[...dims\]
/// - targets: \[...dims\]
/// - output: \[...dims\]
pub fn forward_no_reduction<const D: usize>(
&self,
predictions: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
let residuals = targets - predictions;
self.forward_residuals(residuals)
}
/// Compute the loss element-wise for the given residuals.
///
/// # Shapes
///
/// - residuals: \[...dims\]
/// - output: \[...dims\]
pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
let is_large = residuals.clone().abs().greater_elem(self.delta);
// We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
// `sign()` function, in general, suffers from a jump at 0.
// Instead the following tensor implements `delta * sign(r)` for values outside
// the bound:
let softsign = residuals.clone().clamp(-self.delta, self.delta);
// 0.5 * d^2 + d * (|r| - d) =
// d * |r| - 0.5 * d^2
// Moreover |r| = sign(r) * r
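// so for |r| > d: softsign * r = d * sign(r) * r = d * |r|, which is what the
// next line computes.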
let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
inside.mask_where(is_large, outside)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn_tensor::Data;
type TestTensor<const D: usize> = Tensor<TestBackend, D>;
#[test]
fn test_huber_loss() {
let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
let targets = Data::from([0., 0., 0., 0., 0.]);
let device = Default::default();
let predict = TestTensor::<1>::from_data(predict, &device);
let targets = TestTensor::<1>::from_data(targets, &device);
let huber = HuberLossConfig::new(0.5).init(&device);
let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
let loss_no_reduction = huber.forward_no_reduction(predict, targets);
loss_no_reduction
.into_data()
.assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
loss_sum
.into_data()
.assert_approx_eq(&Data::from([1.42]), 7);
}
#[cfg(feature = "std")]
#[test]
fn test_huber_ad_loss() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
let targets = Data::from([0., 0., 0., 0., 0.]);
let device = Default::default();
let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
let targets = TestAutodiffTensor::from_data(targets, &device);
let loss = HuberLossConfig::new(0.5).init(&device);
let loss = loss.forward_no_reduction(predict.clone(), targets);
let grads = loss.backward();
let grads_predict = predict.grad(&grads).unwrap();
grads_predict
.to_data()
.assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
}
}


@@ -1,9 +1,11 @@
mod binary_cross_entropy;
mod cross_entropy;
mod huber;
mod mse;
mod reduction;
pub use binary_cross_entropy::*;
pub use cross_entropy::*;
pub use huber::*;
pub use mse::*;
pub use reduction::*;