mirror of https://github.com/tracel-ai/burn.git
Implement Huber loss (#1444)
* Implement Huber loss

  Instead of using a sign or abs function, use clamping to compute the loss
  outside the bounds. This is better for the autodiff backend.

* mention Huber loss in the book
* unify naming of residuals in comments
parent 7a98b2f663
commit 53eb3ecfa9
@@ -162,3 +162,4 @@ Burn comes with built-in modules that you can use to build your own modules.

| Burn API           | PyTorch Equivalent    |
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss`          | `nn.MSELoss`          |
| `HuberLoss`        | `nn.HuberLoss`        |
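For context, here is a minimal, hypothetical usage sketch based on the API introduced below. It assumes the standard `burn::nn::loss` and `burn::backend::NdArray` re-exports (the latter behind the `ndarray` feature); the tensor values and `delta` are arbitrary.

```rust
use burn::backend::NdArray;
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::tensor::Tensor;

fn main() {
    // Any backend works; NdArray is just an example here.
    type B = NdArray;
    let device = Default::default();

    // delta = 0.5: quadratic for |target - prediction| <= 0.5, linear beyond that.
    let huber = HuberLossConfig::new(0.5).init::<B>(&device);

    let predictions = Tensor::<B, 1>::from_floats([-2.0, -0.5, 0.0, 0.3, 1.0], &device);
    let targets = Tensor::<B, 1>::zeros([5], &device);

    // `Reduction::Auto` averages over all elements, like `Reduction::Mean`.
    let loss = huber.forward(predictions, targets, Reduction::Auto);
    println!("huber loss = {}", loss.into_scalar());
}
```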
@@ -0,0 +1,178 @@
use crate as burn;

use crate::{config::Config, module::Module};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use core::marker::PhantomData;

use super::Reduction;

/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
    /// Initialize [Huber loss](HuberLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
        // The device is not needed as of now, but we might want to prepare some data on it,
        // and it's consistent with the other loss functions.
        let _ = device;
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
            _backend: PhantomData,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0., // This also rejects NaN, since any comparison with NaN is false.
            "Delta for Huber loss must be a non-negative number."
        );
    }
}
/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d
/// ```
///
/// where `d` is the configured `delta`. In particular, this matches the
/// [L2 Loss](super::MseLoss) (up to a factor of 0.5) for residuals with magnitude
/// smaller than `delta`, but grows linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
    delta: f32,
    lin_bias: f32, // delta * delta * 0.5 precomputed
    _backend: PhantomData<B>,
}
impl<B: Backend> HuberLoss<B> {
    /// Compute the loss element-wise for the predictions and targets, then reduce
    /// to a single loss value.
    ///
    /// `Reduction::Auto` behaves as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[1\]
    pub fn forward<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

    /// Compute the loss element-wise for the predictions and targets.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_no_reduction<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }

    /// Compute the loss element-wise for the given residuals.
    ///
    /// # Shapes
    ///
    /// - residuals: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
        // `sign()` function, in general, suffers from a jump at 0.
        // Instead the following tensor implements `delta * sign(r)` for values outside
        // the bound:
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // 0.5 * d^2 + d * (|r| - d) =
        // d * |r| - 0.5 * d^2
        // Moreover |r| = sign(r) * r
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    #[test]
    fn test_huber_loss() {
        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init(&device);

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        loss_no_reduction
            .into_data()
            .assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
        loss_sum
            .into_data()
            .assert_approx_eq(&Data::from([1.42]), 7);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init(&device);
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        grads_predict
            .to_data()
            .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
    }
}
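As a quick sanity check on the clamp-based formulation in `forward_residuals`, here is a standalone sketch (plain Rust, no Burn types; the helper names are ours) that evaluates both the textbook piecewise definition and the clamp-based variant on the same residuals and `delta = 0.5` as `test_huber_loss`:

```rust
/// Textbook piecewise Huber loss for a single residual.
fn huber_piecewise(r: f64, delta: f64) -> f64 {
    if r.abs() <= delta {
        0.5 * r * r
    } else {
        delta * r.abs() - 0.5 * delta * delta
    }
}

/// Clamp-based variant mirroring `forward_residuals`: outside the bound,
/// `clamp(r, -delta, delta)` equals `delta * sign(r)`, so the linear branch
/// becomes `softsign * r - lin_bias` with no discontinuous `sign()` call.
fn huber_clamped(r: f64, delta: f64) -> f64 {
    let lin_bias = 0.5 * delta * delta;
    let softsign = r.clamp(-delta, delta);
    let outside = softsign * r - lin_bias;
    let inside = 0.5 * r * r;
    if r.abs() > delta { outside } else { inside }
}

fn main() {
    let delta = 0.5;
    // residuals = targets - predictions for the test case above
    let residuals = [2.0, 0.5, 0.0, -0.3, -1.0];
    for r in residuals {
        let (a, b) = (huber_piecewise(r, delta), huber_clamped(r, delta));
        assert!((a - b).abs() < 1e-12);
        println!("r = {r:>5}: loss = {a}");
    }
}
```

Both formulations agree element-wise, and the printed values reproduce the expected `[0.875, 0.125, 0.0, 0.045, 0.375]` from the test (sum 1.42, mean 0.284).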
@@ -1,9 +1,11 @@
mod binary_cross_entropy;
mod cross_entropy;
mod huber;
mod mse;
mod reduction;

pub use binary_cross_entropy::*;
pub use cross_entropy::*;
pub use huber::*;
pub use mse::*;
pub use reduction::*;
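With `mod huber;` declared and re-exported via `pub use huber::*;`, the new types sit alongside the existing losses. Assuming the loss module keeps its usual public path (as for `MseLoss` and `CrossEntropyLoss`), downstream code would import them roughly as:

```rust
// Assumed public path; mirrors how the other losses in this module are exposed.
use burn::nn::loss::{HuberLoss, HuberLossConfig, Reduction};
```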