mirror of https://github.com/tracel-ai/burn.git
Support multilabel binary cross entropy (#1571)
* Support multilabel binary cross entropy * Add missing alloc Vec
This commit is contained in:
parent
93fac73e6d
commit
0978c8a586
|
@ -1,30 +1,26 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::{config::Config, module::Module};
|
||||
use burn_tensor::activation::sigmoid;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::activation::log_sigmoid;
|
||||
use burn_tensor::{backend::Backend, Int, Tensor};
|
||||
|
||||
/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct BinaryCrossEntropyLossConfig {
|
||||
/// Create weighted binary cross-entropy.
|
||||
/// Create weighted binary cross-entropy with a weight for each class.
|
||||
///
|
||||
/// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
|
||||
///
|
||||
/// # Pre-conditions
|
||||
/// - The order of the weight vector should correspond to the label integer assignment.
|
||||
/// - Targets assigned negative Int's will not be allowed.
|
||||
pub weights: Option<[f32; 2]>,
|
||||
/// The loss of a specific sample will simply be multiplied by its label weight.
|
||||
pub weights: Option<Vec<f32>>,
|
||||
|
||||
/// Create binary cross-entropy with label smoothing.
|
||||
/// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
|
||||
///
|
||||
/// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
|
||||
/// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
|
||||
/// Alpha = 0 would be the same as default.
|
||||
smoothing: Option<f32>,
|
||||
|
||||
/// Create binary cross-entropy with probabilities as input instead of logits.
|
||||
///
|
||||
#[config(default = true)]
|
||||
/// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
|
||||
#[config(default = false)]
|
||||
logits: bool,
|
||||
}
|
||||
|
||||
|
@ -59,7 +55,7 @@ impl BinaryCrossEntropyLossConfig {
|
|||
}
|
||||
}
|
||||
|
||||
/// Calculate the cross entropy loss from the input logits and the targets.
|
||||
/// Calculate the binary cross entropy loss from the input logits and the targets.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct BinaryCrossEntropyLoss<B: Backend> {
|
||||
/// Weights for cross-entropy.
|
||||
|
@ -73,37 +69,77 @@ impl<B: Backend> BinaryCrossEntropyLoss<B> {
|
|||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// Binary:
|
||||
/// - logits: `[batch_size]`
|
||||
/// - targets: `[batch_size]`
|
||||
pub fn forward(&self, logits: Tensor<B, 1>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
|
||||
Self::assertions(logits.clone(), targets.clone());
|
||||
let mut targets_float = targets.clone().float();
|
||||
if let Some(alpha) = self.smoothing {
|
||||
targets_float = targets_float * (1. - alpha) + alpha / 2.;
|
||||
}
|
||||
let logits = if self.logits { sigmoid(logits) } else { logits };
|
||||
let loss = targets_float.clone() * logits.clone().log()
|
||||
+ (targets_float.clone().neg() + 1.) * (logits.neg() + 1.).log();
|
||||
///
|
||||
/// Multi-label:
|
||||
/// - logits: `[batch_size, num_classes]`
|
||||
/// - targets: `[batch_size, num_classes]`
|
||||
pub fn forward<const D: usize>(
|
||||
&self,
|
||||
logits: Tensor<B, D>,
|
||||
targets: Tensor<B, D, Int>,
|
||||
) -> Tensor<B, 1> {
|
||||
self.assertions(&logits, &targets);
|
||||
|
||||
match &self.weights {
|
||||
Some(weights) => {
|
||||
let weights = weights.clone().gather(0, targets);
|
||||
let loss = loss * weights.clone();
|
||||
loss.neg().sum() / weights.sum()
|
||||
}
|
||||
None => loss.mean().neg(),
|
||||
let mut targets_float = targets.clone().float();
|
||||
let shape = targets.dims();
|
||||
|
||||
if let Some(alpha) = self.smoothing {
|
||||
let num_classes = if D > 1 { shape[D - 1] } else { 2 };
|
||||
targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
|
||||
}
|
||||
|
||||
let mut loss = if self.logits {
|
||||
// Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
|
||||
(targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
|
||||
} else {
|
||||
// - (target * log(input) + (1 - target) * log(1 - input))
|
||||
(targets_float.clone() * logits.clone().log()
|
||||
+ (targets_float.neg() + 1.) * (logits.neg() + 1.).log())
|
||||
.neg()
|
||||
};
|
||||
|
||||
if let Some(weights) = &self.weights {
|
||||
let weights = if D > 1 {
|
||||
weights.clone().expand(shape)
|
||||
} else {
|
||||
// Flatten targets and expand resulting weights to make it compatible with
|
||||
// Tensor<B, D> for binary 1-D case
|
||||
weights
|
||||
.clone()
|
||||
.gather(0, targets.flatten(0, 0))
|
||||
.expand(shape)
|
||||
};
|
||||
loss = loss * weights;
|
||||
}
|
||||
|
||||
loss.mean()
|
||||
}
|
||||
|
||||
fn assertions(logits: Tensor<B, 1>, targets: Tensor<B, 1, Int>) {
|
||||
let [logits_height] = logits.dims();
|
||||
let [targets_height] = targets.dims();
|
||||
fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
|
||||
let logits_dims = logits.dims();
|
||||
let targets_dims = targets.dims();
|
||||
assert!(
|
||||
logits_height == targets_height,
|
||||
"Shape of targets ({}) should correspond to outer shape of logits ({}).",
|
||||
targets_height,
|
||||
logits_height
|
||||
logits_dims == targets_dims,
|
||||
"Shape of targets ({:?}) should correspond to outer shape of logits ({:?}).",
|
||||
targets_dims,
|
||||
logits_dims
|
||||
);
|
||||
|
||||
if let Some(weights) = &self.weights {
|
||||
if D > 1 {
|
||||
let targets_classes = targets_dims[D - 1];
|
||||
let weights_classes = weights.dims()[0];
|
||||
assert!(
|
||||
weights_classes == targets_classes,
|
||||
"The number of classes ({}) does not match the weights provided ({}).",
|
||||
weights_classes,
|
||||
targets_classes
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,66 +147,211 @@ impl<B: Backend> BinaryCrossEntropyLoss<B> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
use burn_tensor::{activation::sigmoid, Data, Distribution};
|
||||
use burn_tensor::{activation::sigmoid, Data};
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy() {
|
||||
let [batch_size] = [4];
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
|
||||
// target = torch.tensor([0., 1., 0., 1.])
|
||||
// loss = nn.BCELoss()
|
||||
// sigmoid = nn.Sigmoid()
|
||||
// out = loss(sigmoid(input), target) # tensor(0.7491)
|
||||
|
||||
let device = Default::default();
|
||||
let logits =
|
||||
Tensor::<TestBackend, 1>::random([batch_size], Distribution::Normal(0., 1.0), &device);
|
||||
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
|
||||
let targets = Tensor::<TestBackend, 1, Int>::from_data(Data::from([0, 1, 0, 1]), &device);
|
||||
|
||||
let loss_1 = BinaryCrossEntropyLossConfig::new()
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.init(&device)
|
||||
.forward(logits.clone(), targets.clone());
|
||||
let logits = sigmoid(logits);
|
||||
let loss_2 = targets.clone().float() * logits.clone().log()
|
||||
+ (-targets.float() + 1) * (-logits + 1).log();
|
||||
let loss_2 = loss_2.mean().neg();
|
||||
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
|
||||
.forward(sigmoid(logits), targets)
|
||||
.into_data();
|
||||
|
||||
let loss_expected = Data::from([0.7491]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_with_logits() {
|
||||
let device = Default::default();
|
||||
let logits =
|
||||
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
|
||||
let targets = Tensor::<TestBackend, 1, Int>::from_data(Data::from([0, 1, 0, 1]), &device);
|
||||
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_logits(true)
|
||||
.init(&device)
|
||||
.forward(logits, targets)
|
||||
.into_data();
|
||||
|
||||
let loss_expected = Data::from([0.7491]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_with_weights() {
|
||||
let [batch_size] = [4];
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
|
||||
// target = torch.tensor([0, 1, 0, 1])
|
||||
// weights = torch.tensor([3., 7.]).gather(0, target)
|
||||
// loss = nn.BCELoss(weights)
|
||||
// sigmoid = nn.Sigmoid()
|
||||
// out = loss(sigmoid(input), target.float()) # tensor(3.1531)
|
||||
|
||||
let device = Default::default();
|
||||
let logits =
|
||||
Tensor::<TestBackend, 1>::random([batch_size], Distribution::Normal(0., 1.0), &device);
|
||||
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
|
||||
let targets = Tensor::<TestBackend, 1, Int>::from_data(Data::from([0, 1, 0, 1]), &device);
|
||||
let weights = [3., 7.];
|
||||
|
||||
let loss_1 = BinaryCrossEntropyLossConfig::new()
|
||||
.with_weights(Some(weights))
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_weights(Some(weights.to_vec()))
|
||||
.init(&device)
|
||||
.forward(logits.clone(), targets.clone());
|
||||
let logits = sigmoid(logits);
|
||||
let loss_2 = targets.clone().float() * logits.clone().log()
|
||||
+ (-targets.float() + 1) * (-logits + 1).log();
|
||||
.forward(sigmoid(logits), targets)
|
||||
.into_data();
|
||||
|
||||
let loss_2 = loss_2 * Tensor::from_floats([3., 7., 3., 7.], &device);
|
||||
let loss_2 = loss_2.neg().sum() / (3. + 3. + 7. + 7.);
|
||||
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
|
||||
let loss_expected = Data::from([3.1531]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_with_smoothing() {
|
||||
let [batch_size] = [4];
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
|
||||
// target = torch.tensor([0., 1., 0., 1.])
|
||||
// target_smooth = target * (1 - 0.1) + (0.1 / 2)
|
||||
// loss = nn.BCELoss()
|
||||
// sigmoid = nn.Sigmoid()
|
||||
// out = loss(sigmoid(input), target_smooth) # tensor(0.7490)
|
||||
|
||||
let device = Default::default();
|
||||
let logits =
|
||||
Tensor::<TestBackend, 1>::random([batch_size], Distribution::Normal(0., 1.0), &device);
|
||||
Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
|
||||
let targets = Tensor::<TestBackend, 1, Int>::from_data(Data::from([0, 1, 0, 1]), &device);
|
||||
|
||||
let loss_1 = BinaryCrossEntropyLossConfig::new()
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_smoothing(Some(0.1))
|
||||
.init(&device)
|
||||
.forward(logits.clone(), targets.clone());
|
||||
.forward(sigmoid(logits), targets)
|
||||
.into_data();
|
||||
|
||||
let logits = sigmoid(logits);
|
||||
let targets = targets.float() * (1. - 0.1) + 0.1 / 2.;
|
||||
let loss_2 = targets.clone() * logits.clone().log() + (-targets + 1) * (-logits + 1).log();
|
||||
let loss_2 = loss_2.mean().neg();
|
||||
let loss_expected = Data::from([0.7490]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_multilabel() {
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
|
||||
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
|
||||
// weights = torch.tensor([3., 7., 0.9])
|
||||
// loss = nn.BCEWithLogitsLoss()
|
||||
// out = loss(input, target) # tensor(0.7112)
|
||||
|
||||
let device = Default::default();
|
||||
let logits = Tensor::<TestBackend, 2>::from_floats(
|
||||
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
|
||||
&device,
|
||||
);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2, Int>::from_data(Data::from([[1, 0, 1], [1, 0, 0]]), &device);
|
||||
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_logits(true)
|
||||
.init(&device)
|
||||
.forward(logits, targets)
|
||||
.into_data();
|
||||
|
||||
let loss_expected = Data::from([0.7112]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_multilabel_with_weights() {
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
|
||||
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
|
||||
// loss = nn.BCEWithLogitsLoss()
|
||||
// out = loss(input, target) # tensor(3.1708)
|
||||
|
||||
let device = Default::default();
|
||||
let logits = Tensor::<TestBackend, 2>::from_floats(
|
||||
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
|
||||
&device,
|
||||
);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2, Int>::from_data(Data::from([[1, 0, 1], [1, 0, 0]]), &device);
|
||||
let weights = [3., 7., 0.9];
|
||||
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_logits(true)
|
||||
.with_weights(Some(weights.to_vec()))
|
||||
.init(&device)
|
||||
.forward(logits, targets)
|
||||
.into_data();
|
||||
|
||||
let loss_expected = Data::from([3.1708]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_cross_entropy_multilabel_with_smoothing() {
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
|
||||
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
|
||||
// target_smooth = target * (1 - 0.1) + (0.1 / 3)
|
||||
// loss = nn.BCELoss()
|
||||
// sigmoid = nn.Sigmoid()
|
||||
// out = loss(sigmoid(input), target_smooth) # tensor(0.7228)
|
||||
|
||||
let device = Default::default();
|
||||
let logits = Tensor::<TestBackend, 2>::from_floats(
|
||||
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
|
||||
&device,
|
||||
);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2, Int>::from_data(Data::from([[1, 0, 1], [1, 0, 0]]), &device);
|
||||
|
||||
let loss_actual = BinaryCrossEntropyLossConfig::new()
|
||||
.with_smoothing(Some(0.1))
|
||||
.init(&device)
|
||||
.forward(sigmoid(logits), targets)
|
||||
.into_data();
|
||||
|
||||
let loss_expected = Data::from([0.7228]);
|
||||
loss_actual.assert_approx_eq(&loss_expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "The number of classes"]
|
||||
fn multilabel_weights_should_match_target() {
|
||||
// import torch
|
||||
// from torch import nn
|
||||
// input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
|
||||
// target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
|
||||
// loss = nn.BCEWithLogitsLoss()
|
||||
// out = loss(input, target) # tensor(3.1708)
|
||||
|
||||
let device = Default::default();
|
||||
let logits = Tensor::<TestBackend, 2>::from_floats(
|
||||
[[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
|
||||
&device,
|
||||
);
|
||||
let targets =
|
||||
Tensor::<TestBackend, 2, Int>::from_data(Data::from([[1, 0, 1], [1, 0, 0]]), &device);
|
||||
let weights = [3., 7.];
|
||||
|
||||
let _loss = BinaryCrossEntropyLossConfig::new()
|
||||
.with_logits(true)
|
||||
.with_weights(Some(weights.to_vec()))
|
||||
.init(&device)
|
||||
.forward(logits, targets);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue