From 498d163a7b7cbdf2e83d0b7b9bbfa8d141773ade Mon Sep 17 00:00:00 2001 From: Yu Sun Date: Sat, 3 Jun 2023 22:31:12 +0800 Subject: [PATCH] Feat: mse loss (#378) --- burn-core/src/nn/loss/mod.rs | 3 ++ burn-core/src/nn/loss/mse.rs | 78 ++++++++++++++++++++++++++++++ burn-core/src/nn/loss/reduction.rs | 5 ++ 3 files changed, 86 insertions(+) create mode 100644 burn-core/src/nn/loss/mse.rs create mode 100644 burn-core/src/nn/loss/reduction.rs diff --git a/burn-core/src/nn/loss/mod.rs b/burn-core/src/nn/loss/mod.rs index 33ffad4f1..20ddbf76c 100644 --- a/burn-core/src/nn/loss/mod.rs +++ b/burn-core/src/nn/loss/mod.rs @@ -1,3 +1,6 @@ mod cross_entropy; +mod mse; +mod reduction; pub use cross_entropy::*; +pub use mse::*; diff --git a/burn-core/src/nn/loss/mse.rs b/burn-core/src/nn/loss/mse.rs new file mode 100644 index 000000000..95342b952 --- /dev/null +++ b/burn-core/src/nn/loss/mse.rs @@ -0,0 +1,78 @@ +use crate::nn::loss::reduction::Reduction; +use core::marker::PhantomData; + +use burn_tensor::{backend::Backend, Tensor}; + +/// Calculate the mean squared error loss from the input logits and the targets. +#[derive(Clone, Debug)] +pub struct MSELoss { + backend: PhantomData, +} + +impl Default for MSELoss { + fn default() -> Self { + Self::new() + } +} + +impl MSELoss { + /// Create the criterion. + pub fn new() -> Self { + Self { + backend: PhantomData::default(), + } + } + + /// Compute the criterion on the input tensor. + /// + /// # Shapes + /// + /// - logits: [batch_size, num_targets] + /// - targets: [batch_size, num_targets] + pub fn forward( + &self, + logits: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let tensor = self.forward_no_reduction(logits, targets); + match reduction { + Reduction::Mean | Reduction::Auto => tensor.mean(), + Reduction::Sum => tensor.sum(), + } + } + + pub fn forward_no_reduction( + &self, + logits: Tensor, + targets: Tensor, + ) -> Tensor { + logits.sub(targets).powf(2.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + use burn_tensor::Data; + + #[test] + fn test_mse_loss() { + let logits = Tensor::::from_data(Data::from([[1.0, 2.0], [3.0, 4.0]])); + + let targets = Tensor::::from_data(Data::from([[2.0, 1.0], [3.0, 2.0]])); + + let mse = MSELoss::new(); + let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); + let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); + let loss_sum = mse.forward(logits, targets, Reduction::Sum); + + assert_eq!( + loss_no_reduction.into_data(), + Data::from([[1.0, 1.0], [0.0, 4.0]]) + ); + assert_eq!(loss.into_data(), Data::from([1.5])); + assert_eq!(loss_sum.into_data(), Data::from([6.0])); + } +} diff --git a/burn-core/src/nn/loss/reduction.rs b/burn-core/src/nn/loss/reduction.rs new file mode 100644 index 000000000..f0c42d543 --- /dev/null +++ b/burn-core/src/nn/loss/reduction.rs @@ -0,0 +1,5 @@ +pub enum Reduction { + Mean, + Sum, + Auto, +}