Feat: mse loss (#378)

This commit is contained in:
Yu Sun 2023-06-03 22:31:12 +08:00 committed by GitHub
parent 974fdfaba1
commit 498d163a7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 0 deletions

View File

@ -1,3 +1,6 @@
mod cross_entropy;
mod mse;
mod reduction;
pub use cross_entropy::*;
pub use mse::*;

View File

@ -0,0 +1,78 @@
use crate::nn::loss::reduction::Reduction;
use core::marker::PhantomData;
use burn_tensor::{backend::Backend, Tensor};
/// Calculate the mean squared error loss from the input logits and the targets.
#[derive(Clone, Debug)]
pub struct MSELoss<B: Backend> {
backend: PhantomData<B>,
}
impl<B: Backend> Default for MSELoss<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> MSELoss<B> {
/// Create the criterion.
pub fn new() -> Self {
Self {
backend: PhantomData::default(),
}
}
/// Compute the criterion on the input tensor.
///
/// # Shapes
///
/// - logits: [batch_size, num_targets]
/// - targets: [batch_size, num_targets]
pub fn forward<const D: usize>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
reduction: Reduction,
) -> Tensor<B, 1> {
let tensor = self.forward_no_reduction(logits, targets);
match reduction {
Reduction::Mean | Reduction::Auto => tensor.mean(),
Reduction::Sum => tensor.sum(),
}
}
pub fn forward_no_reduction<const D: usize>(
&self,
logits: Tensor<B, D>,
targets: Tensor<B, D>,
) -> Tensor<B, D> {
logits.sub(targets).powf(2.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn_tensor::Data;
#[test]
fn test_mse_loss() {
let logits = Tensor::<TestBackend, 2>::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]]));
let targets = Tensor::<TestBackend, 2>::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]]));
let mse = MSELoss::new();
let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
let loss_sum = mse.forward(logits, targets, Reduction::Sum);
assert_eq!(
loss_no_reduction.into_data(),
Data::from([[1.0, 1.0], [0.0, 4.0]])
);
assert_eq!(loss.into_data(), Data::from([1.5]));
assert_eq!(loss_sum.into_data(), Data::from([6.0]));
}
}

View File

@ -0,0 +1,5 @@
pub enum Reduction {
Mean,
Sum,
Auto,
}