From 0cbe9a927d0c4e522021becaa705623aa205dec7 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Thu, 11 Apr 2024 12:32:25 -0400
Subject: [PATCH] Add learner training report summary (#1591)

* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
---
 burn-book/src/basic-workflow/training.md      |  29 +-
 crates/burn-train/src/learner/base.rs         |   2 +
 crates/burn-train/src/learner/builder.rs      |  29 +-
 crates/burn-train/src/learner/mod.rs          |   2 +
 crates/burn-train/src/learner/summary.rs      | 307 ++++++++++++++++++
 crates/burn-train/src/learner/train_val.rs    |  13 +
 crates/burn-train/src/logger/metric.rs        |  52 ++-
 crates/burn-train/src/metric/base.rs          |  43 +++
 crates/burn-train/src/metric/loss.rs          |   8 +-
 .../src/metric/processor/metrics.rs           |   2 +-
 crates/burn-train/src/metric/state.rs         |   5 +-
 .../burn-train/src/metric/store/aggregate.rs  |  51 ++-
 examples/custom-image-dataset/src/training.rs |  10 +-
 examples/guide/src/training.rs                |   9 +-
 examples/mnist/src/training.rs                |   8 +
 examples/simple-regression/src/training.rs    |   1 +
 examples/text-classification/src/training.rs  |   1 +
 examples/text-generation/src/training.rs      |   1 +
 18 files changed, 539 insertions(+), 34 deletions(-)
 create mode 100644 crates/burn-train/src/learner/summary.rs

diff --git a/burn-book/src/basic-workflow/training.md b/burn-book/src/basic-workflow/training.md
index a015fb2b8..08473f06e 100644
--- a/burn-book/src/basic-workflow/training.md
+++ b/burn-book/src/basic-workflow/training.md
@@ -1,11 +1,11 @@
 # Training
 
-We are now ready to write the necessary code to train our model on the MNIST dataset.
-We shall define the code for this training section in the file: `src/training.rs`.
+We are now ready to write the necessary code to train our model on the MNIST dataset. We shall
+define the code for this training section in the file: `src/training.rs`.
 
-Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
-responsibility is to apply an optimizer to the model. The output struct is used for all metrics
-calculated during the training. Therefore it should include all the necessary information to
+Instead of a simple tensor, the model should output an item that can be understood by the learner, a
+struct whose responsibility is to apply an optimizer to the model. The output struct is used for all
+metrics calculated during the training. Therefore it should include all the necessary information to
 calculate any metric that you want for a task.
 
 Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement
@@ -110,8 +110,14 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
     std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
     config
         .save(format!("{artifact_dir}/config.json"))
         .expect("Config should be saved successfully");
@@ -141,6 +147,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             config.model.init::<B>(&device),
             config.optimizer.init(),
@@ -181,8 +188,8 @@ Once the learner is created, we can simply call `fit` and provide the training a
 dataloaders. For the sake of simplicity in this example, we employ the test set as the validation
 set; however, we do not recommend this practice for actual usage.
 
-Finally, the trained model is returned by the `fit` method, and the only remaining task is saving
-the trained weights using the `CompactRecorder`. This recorder employs the `MessagePack` format with
-`gzip` compression, `f16` for floats and `i16` for integers. Other recorders are available, offering
-support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend,
-regardless of precision, can load recorded data of any kind.
+Finally, the trained model is returned by the `fit` method. The trained weights are then saved using
+the `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for
+floats and `i16` for integers. Other recorders are available, offering support for various formats,
+such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can
+load recorded data of any kind.
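To complete the picture, the saving step that the paragraph above describes is a one-liner. A minimal sketch, reusing the `model_trained` binding and `artifact_dir` argument from the book's `train` function and Burn's `Module::save_file` API:

```rust
// Persist the weights returned by `fit` with the same recorder used for checkpoints.
model_trained
    .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
    .expect("Trained model should be saved successfully");
```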
diff --git a/crates/burn-train/src/learner/base.rs b/crates/burn-train/src/learner/base.rs
index c49e10375..0534b0f4d 100644
--- a/crates/burn-train/src/learner/base.rs
+++ b/crates/burn-train/src/learner/base.rs
@@ -2,6 +2,7 @@ use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy
 use crate::components::LearnerComponents;
 use crate::learner::EarlyStoppingStrategy;
 use crate::metric::store::EventStoreClient;
+use crate::LearnerSummaryConfig;
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::Module;
 use burn_core::optim::Optimizer;
@@ -26,6 +27,7 @@ pub struct Learner<LC: LearnerComponents> {
     pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
     pub(crate) event_processor: LC::EventProcessor,
     pub(crate) event_store: Arc<EventStoreClient>,
+    pub(crate) summary: Option<LearnerSummaryConfig>,
 }
 
 #[derive(new)]
diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs
index 50b9cf41c..394a842ab 100644
--- a/crates/burn-train/src/learner/builder.rs
+++ b/crates/burn-train/src/learner/builder.rs
@@ -1,3 +1,4 @@
+use std::collections::HashSet;
 use std::sync::Arc;
 
 use super::log::install_file_logger;
@@ -14,7 +15,7 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
 use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
 use crate::metric::{Adaptor, LossMetric, Metric};
 use crate::renderer::{default_renderer, MetricsRenderer};
-use crate::LearnerCheckpointer;
+use crate::{LearnerCheckpointer, LearnerSummaryConfig};
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::AutodiffModule;
 use burn_core::optim::Optimizer;
@@ -53,6 +54,8 @@ where
     num_loggers: usize,
     checkpointer_strategy: Box<dyn CheckpointingStrategy>,
     early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
+    summary_metrics: HashSet<String>,
+    summary: bool,
 }
 
 impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@@ -94,6 +97,8 @@ where
                     .build(),
             ),
             early_stopping: None,
+            summary_metrics: HashSet::new(),
+            summary: false,
         }
     }
@@ -140,7 +145,7 @@ where
     where
         T: Adaptor<Me::Input>,
     {
-        self.metrics.register_metric_train(metric);
+        self.metrics.register_train_metric(metric);
         self
     }
@@ -174,6 +179,7 @@ where
         Me: Metric + crate::metric::Numeric + 'static,
         T: Adaptor<Me::Input>,
     {
+        self.summary_metrics.insert(Me::NAME.to_string());
         self.metrics.register_train_metric_numeric(metric);
         self
     }
@@ -186,6 +192,7 @@ where
     where
         V: Adaptor<Me::Input>,
     {
+        self.summary_metrics.insert(Me::NAME.to_string());
         self.metrics.register_valid_metric_numeric(metric);
         self
     }
@@ -266,6 +273,14 @@ where
         self
     }
 
+    /// Enable the training summary report.
+    ///
+    /// The summary will be displayed at the end of `.fit()`.
+    pub fn summary(mut self) -> Self {
+        self.summary = true;
+        self
+    }
+
     /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
     /// The [learning rate scheduler](LrScheduler) can also be a simple
     /// [learning rate](burn_core::LearningRate).
@@ -320,6 +335,15 @@ where
             LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
         });
 
+        let summary = if self.summary {
+            Some(LearnerSummaryConfig {
+                directory: self.directory,
+                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
+            })
+        } else {
+            None
+        };
+
         Learner {
             model,
             optim,
@@ -333,6 +357,7 @@ where
             devices: self.devices,
             interrupter: self.interrupter,
             early_stopping: self.early_stopping,
+            summary,
         }
     }
diff --git a/crates/burn-train/src/learner/mod.rs b/crates/burn-train/src/learner/mod.rs
index e01080475..be4cc258b 100644
--- a/crates/burn-train/src/learner/mod.rs
+++ b/crates/burn-train/src/learner/mod.rs
@@ -5,6 +5,7 @@ mod early_stopping;
 mod epoch;
 mod regression;
 mod step;
+mod summary;
 mod train_val;
 
 pub(crate) mod log;
@@ -16,5 +17,6 @@ pub use early_stopping::*;
 pub use epoch::*;
 pub use regression::*;
 pub use step::*;
+pub use summary::*;
 pub use train::*;
 pub use train_val::*;
diff --git a/crates/burn-train/src/learner/summary.rs b/crates/burn-train/src/learner/summary.rs
new file mode 100644
index 000000000..7fee928bd
--- /dev/null
+++ b/crates/burn-train/src/learner/summary.rs
@@ -0,0 +1,307 @@
+use core::cmp::Ordering;
+use std::{fmt::Display, path::Path};
+
+use crate::{
+    logger::FileMetricLogger,
+    metric::store::{Aggregate, EventStore, LogEventStore, Split},
+};
+
+/// Contains the metric value at a given time.
+pub struct MetricEntry {
+    /// The step at which the metric was recorded (i.e., epoch).
+    pub step: usize,
+    /// The metric value.
+    pub value: f64,
+}
+
+/// Contains the summary of recorded values for a given metric.
+pub struct MetricSummary {
+    /// The metric name.
+    pub name: String,
+    /// The metric entries.
+    pub entries: Vec<MetricEntry>,
+}
+
+impl MetricSummary {
+    fn new<E: EventStore>(
+        event_store: &mut E,
+        metric: &str,
+        split: Split,
+        num_epochs: usize,
+    ) -> Option<Self> {
+        let entries = (1..=num_epochs)
+            .filter_map(|epoch| {
+                event_store
+                    .find_metric(metric, epoch, Aggregate::Mean, split)
+                    .map(|value| MetricEntry { step: epoch, value })
+            })
+            .collect::<Vec<_>>();
+
+        if entries.is_empty() {
+            None
+        } else {
+            Some(Self {
+                name: metric.to_string(),
+                entries,
+            })
+        }
+    }
+}
+
+/// Contains the summary of recorded metrics for the training and validation steps.
+pub struct SummaryMetrics {
+    /// Training metrics summary.
+    pub train: Vec<MetricSummary>,
+    /// Validation metrics summary.
+    pub valid: Vec<MetricSummary>,
+}
+
+/// Detailed training summary.
+pub struct LearnerSummary {
+    /// The number of epochs completed.
+    pub epochs: usize,
+    /// The summary of recorded metrics during training.
+    pub metrics: SummaryMetrics,
+    /// The model name (only recorded within the learner).
+    pub(crate) model: Option<String>,
+}
+
+impl LearnerSummary {
+    /// Creates a new learner summary for the specified metrics.
+    ///
+    /// # Arguments
+    ///
+    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).
+    /// * `metrics` - The list of metrics to collect for the summary.
+    pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
+        let directory_path = Path::new(directory);
+        if !directory_path.exists() {
+            return Err(format!("Artifact directory does not exist at: {directory}"));
+        }
+        let train_dir = directory_path.join("train");
+        let valid_dir = directory_path.join("valid");
+        if !train_dir.exists() && !valid_dir.exists() {
+            return Err(format!(
+                "No training or validation artifacts found at: {directory}"
+            ));
+        }
+
+        let mut event_store = LogEventStore::default();
+
+        let train_logger = FileMetricLogger::new(train_dir.to_str().unwrap());
+        let valid_logger = FileMetricLogger::new(valid_dir.to_str().unwrap());
+
+        // Number of recorded epochs
+        let epochs = train_logger.epochs();
+
+        event_store.register_logger_train(train_logger);
+        event_store.register_logger_valid(valid_logger);
+
+        let train_summary = metrics
+            .iter()
+            .filter_map(|metric| {
+                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Train, epochs)
+            })
+            .collect::<Vec<_>>();
+
+        let valid_summary = metrics
+            .iter()
+            .filter_map(|metric| {
+                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Valid, epochs)
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Self {
+            epochs,
+            metrics: SummaryMetrics {
+                train: train_summary,
+                valid: valid_summary,
+            },
+            model: None,
+        })
+    }
+
+    pub(crate) fn with_model(mut self, name: String) -> Self {
+        self.model = Some(name);
+        self
+    }
+}
+
+impl Display for LearnerSummary {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Compute the max length for each column
+        let split_train = "Train";
+        let split_valid = "Valid";
+        let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
+        let mut max_metric_len = "Metric".len();
+        for metric in self.metrics.train.iter() {
+            max_metric_len = max_metric_len.max(metric.name.len());
+        }
+        for metric in self.metrics.valid.iter() {
+            max_metric_len = max_metric_len.max(metric.name.len());
+        }
+
+        // Summary header
+        writeln!(
+            f,
+            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
+            "",
+            "",
+            width_symbol = 24,
+        )?;
+
+        if let Some(model) = &self.model {
+            writeln!(f, "Model: {model}")?;
+        }
+        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
+
+        // Metrics table header
+        writeln!(
+            f,
+            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
+            "Split", "Metric", "", "",
+            width_split = max_split_len,
+            width_metric = max_metric_len,
+        )?;
+
+        // Table entries
+        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
+            match (a.is_nan(), b.is_nan()) {
+                (true, true) => Ordering::Equal,
+                (true, false) => Ordering::Greater,
+                (false, true) => Ordering::Less,
+                _ => a.partial_cmp(b).unwrap(),
+            }
+        }
+
+        let mut write_metrics_summary = |metrics: &[MetricSummary],
+                                         split: &str|
+         -> std::fmt::Result {
+            for metric in metrics.iter() {
+                if metric.entries.is_empty() {
+                    continue; // skip metrics with no recorded values
+                }
+
+                // Compute the min & max for each metric
+                let metric_min = metric
+                    .entries
+                    .iter()
+                    .min_by(|a, b| cmp_f64(&a.value, &b.value))
+                    .unwrap();
+                let metric_max = metric
+                    .entries
+                    .iter()
+                    .max_by(|a, b| cmp_f64(&a.value, &b.value))
+                    .unwrap();
+
+                writeln!(
+                    f,
+                    "| {:<width_split$} | {:<width_metric$} | {:<9.3}| {:<9}| {:<9.3}| {:<9}|",
+                    split,
+                    metric.name,
+                    metric_min.value,
+                    metric_min.step,
+                    metric_max.value,
+                    metric_max.step,
+                    width_split = max_split_len,
+                    width_metric = max_metric_len,
+                )?;
+            }
+
+            Ok(())
+        };
+
+        write_metrics_summary(&self.metrics.train, split_train)?;
+        write_metrics_summary(&self.metrics.valid, split_valid)?;
+
+        Ok(())
+    }
+}
+
+pub(crate) struct LearnerSummaryConfig {
+    pub(crate) directory: String,
+    pub(crate) metrics: Vec<String>,
+}
+
+impl LearnerSummaryConfig {
+    pub fn init(&self) -> Result<LearnerSummary, String> {
+        LearnerSummary::new(&self.directory, &self.metrics[..])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    #[should_panic = "Summary artifacts should exist"]
+    fn test_artifact_dir_should_exist() {
+        let dir = "/tmp/learner-summary-not-found";
+        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
+    }
+
+    #[test]
+    #[should_panic = "Summary artifacts should exist"]
+    fn test_train_valid_artifacts_should_exist() {
+        let dir = "/tmp/test-learner-summary-empty";
+        std::fs::create_dir_all(dir).ok();
+        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
+    }
+
+    #[test]
+    fn test_summary_should_be_empty() {
+        let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
+        std::fs::create_dir_all(dir).unwrap();
+        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
+        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
+        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
+            .expect("Summary artifacts should exist");
+
+        assert_eq!(summary.epochs, 1);
+
+        assert_eq!(summary.metrics.train.len(), 0);
+        assert_eq!(summary.metrics.valid.len(), 0);
+
+        std::fs::remove_dir_all(dir).unwrap();
+    }
+
+    #[test]
+    fn test_summary_should_be_collected() {
+        let dir = Path::new("/tmp/test-learner-summary");
+        let train_dir = dir.join("train/epoch-1");
+        let valid_dir = dir.join("valid/epoch-1");
+        std::fs::create_dir_all(dir).unwrap();
+        std::fs::create_dir_all(&train_dir).unwrap();
+        std::fs::create_dir_all(&valid_dir).unwrap();
+
+        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
+        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");
+
+        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
+            .expect("Summary artifacts should exist");
+
+        assert_eq!(summary.epochs, 1);
+
+        // Only Loss metric
+        assert_eq!(summary.metrics.train.len(), 1);
+        assert_eq!(summary.metrics.valid.len(), 1);
+
+        // Aggregated train metric entries for 1 epoch
+        let train_metric = &summary.metrics.train[0];
+        assert_eq!(train_metric.name, "Loss");
+        assert_eq!(train_metric.entries.len(), 1);
+        let entry = &train_metric.entries[0];
+        assert_eq!(entry.step, 1); // epoch = 1
+        assert_eq!(entry.value, 1.5); // (1 + 2) / 2
+
+        // Aggregated valid metric entries for 1 epoch
+        let valid_metric = &summary.metrics.valid[0];
+        assert_eq!(valid_metric.name, "Loss");
+        assert_eq!(valid_metric.entries.len(), 1);
+        let entry = &valid_metric.entries[0];
+        assert_eq!(entry.step, 1); // epoch = 1
+        assert_eq!(entry.value, 1.0);
+
+        std::fs::remove_dir_all(dir).unwrap();
+    }
+}
diff --git a/crates/burn-train/src/learner/train_val.rs b/crates/burn-train/src/learner/train_val.rs
index ad7721d17..cbca2895a 100644
--- a/crates/burn-train/src/learner/train_val.rs
+++ b/crates/burn-train/src/learner/train_val.rs
@@ -199,6 +199,19 @@ impl<LC: LearnerComponents> Learner<LC> {
             }
         }
 
+        // Display learner summary
+        if let Some(summary) = self.summary {
+            match summary.init() {
+                Ok(summary) => {
+                    // Drop event processor (includes renderer) so the summary is displayed
+                    // when switching back to "main" screen
+                    core::mem::drop(self.event_processor);
+                    println!("{}", summary.with_model(self.model.to_string()))
+                }
+                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
+            }
+        }
+
         self.model
     }
 }
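With `.summary()` enabled, the report printed at the end of `fit` looks roughly like the following. The layout comes from the `Display` implementation above; the model line and metric values here are illustrative placeholders only:

```text
======================== Learner Summary ========================
Model: Model { ... }
Total Epochs: 2


| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |
|-------|----------|----------|----------|----------|----------|
| Train | Loss     | 0.046    | 2        | 0.389    | 1        |
| Train | Accuracy | 88.372   | 1        | 98.654   | 2        |
| Valid | Loss     | 0.054    | 2        | 0.372    | 1        |
| Valid | Accuracy | 89.103   | 1        | 98.240   | 2        |
```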
diff --git a/crates/burn-train/src/logger/metric.rs b/crates/burn-train/src/logger/metric.rs
index 5751eff92..244485cbd 100644
--- a/crates/burn-train/src/logger/metric.rs
+++ b/crates/burn-train/src/logger/metric.rs
@@ -1,6 +1,8 @@
 use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
-use crate::metric::MetricEntry;
-use std::collections::HashMap;
+use crate::metric::{MetricEntry, NumericEntry};
+use std::{collections::HashMap, fs};
+
+const EPOCH_PREFIX: &str = "epoch-";
 
 /// Metric logger.
 pub trait MetricLogger: Send {
@@ -19,7 +21,7 @@ pub trait MetricLogger: Send {
     fn end_epoch(&mut self, epoch: usize);
 
     /// Read the logs for an epoch.
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
 }
 
 /// The file metric logger.
@@ -47,14 +49,44 @@ impl FileMetricLogger {
         }
     }
 
+    /// Number of epochs recorded.
+    pub(crate) fn epochs(&self) -> usize {
+        let mut max_epoch = 0;
+
+        for path in fs::read_dir(&self.directory).unwrap() {
+            let path = path.unwrap();
+
+            if fs::metadata(path.path()).unwrap().is_dir() {
+                let dir_name = path.file_name().into_string().unwrap();
+
+                if !dir_name.starts_with(EPOCH_PREFIX) {
+                    continue;
+                }
+
+                let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
+
+                if let Some(epoch) = epoch {
+                    if epoch > max_epoch {
+                        max_epoch = epoch;
+                    }
+                }
+            }
+        }
+
+        max_epoch
+    }
+
+    fn epoch_directory(&self, epoch: usize) -> String {
+        format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
+    }
+
     fn file_path(&self, name: &str, epoch: usize) -> String {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         let name = name.replace(' ', "_");
         format!("{directory}/{name}.log")
     }
 
     fn create_directory(&self, epoch: usize) {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         std::fs::create_dir_all(directory).ok();
     }
 }
@@ -88,7 +120,7 @@ impl MetricLogger for FileMetricLogger {
         self.epoch = epoch + 1;
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
         if let Some(value) = self.loggers.get(name) {
             value.sync()
         }
@@ -104,7 +136,7 @@ impl MetricLogger for FileMetricLogger {
                     if value.is_empty() {
                         None
                     } else {
-                        match value.parse::<f64>() {
+                        match NumericEntry::deserialize(value) {
                             Ok(value) => Some(value),
                             Err(err) => {
                                 log::error!("{err}");
@@ -117,7 +149,7 @@ impl MetricLogger for FileMetricLogger {
             .collect();
 
         if errors {
-            Err("Parsing float errors".to_string())
+            Err("Parsing numeric entry errors".to_string())
         } else {
             Ok(data)
         }
@@ -154,7 +186,7 @@ impl MetricLogger for InMemoryMetricLogger {
         }
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
        let values = match self.values.get(name) {
            Some(values) => values,
            None => return Ok(Vec::new()),
@@ -164,7 +196,7 @@ impl MetricLogger for InMemoryMetricLogger {
             Some(logger) => Ok(logger
                 .values
                 .iter()
-                .filter_map(|value| value.parse::<f64>().ok())
+                .filter_map(|value| NumericEntry::deserialize(value).ok())
                 .collect()),
             None => Ok(Vec::new()),
         }
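For reference, the artifact layout that `epochs()` scans and `read_numeric` parses looks like this (a hypothetical two-epoch run with a single `Loss` metric; metric names containing spaces are stored with underscores by `file_path`):

```text
<artifact-dir>/
├── train/
│   ├── epoch-1/
│   │   └── Loss.log
│   └── epoch-2/
│       └── Loss.log
└── valid/
    ├── epoch-1/
    │   └── Loss.log
    └── epoch-2/
        └── Loss.log
```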
diff --git a/crates/burn-train/src/metric/base.rs b/crates/burn-train/src/metric/base.rs
index 1d0f2ca49..e0eafe649 100644
--- a/crates/burn-train/src/metric/base.rs
+++ b/crates/burn-train/src/metric/base.rs
@@ -84,6 +84,49 @@ pub struct MetricEntry {
     pub serialize: String,
 }
 
+/// Numeric metric entry.
+pub enum NumericEntry {
+    /// Single numeric value.
+    Value(f64),
+    /// Aggregated numeric (value, number of elements).
+    Aggregated(f64, usize),
+}
+
+impl NumericEntry {
+    pub(crate) fn serialize(&self) -> String {
+        match self {
+            Self::Value(v) => v.to_string(),
+            Self::Aggregated(v, n) => format!("{v},{n}"),
+        }
+    }
+
+    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
+        // Check for comma separated values
+        let values = entry.split(',').collect::<Vec<_>>();
+        let num_values = values.len();
+
+        if num_values == 1 {
+            // Numeric value
+            match values[0].parse::<f64>() {
+                Ok(value) => Ok(NumericEntry::Value(value)),
+                Err(err) => Err(err.to_string()),
+            }
+        } else if num_values == 2 {
+            // Aggregated numeric (value, number of elements)
+            let (value, numel) = (values[0], values[1]);
+            match value.parse::<f64>() {
+                Ok(value) => match numel.parse::<usize>() {
+                    Ok(numel) => Ok(NumericEntry::Aggregated(value, numel)),
+                    Err(err) => Err(err.to_string()),
+                },
+                Err(err) => Err(err.to_string()),
+            }
+        } else {
+            Err("Invalid number of values for numeric entry".to_string())
+        }
+    }
+}
+
 /// Format a float with the given precision. Will use scientific notation if necessary.
 pub fn format_float(float: f64, precision: usize) -> String {
     let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
diff --git a/crates/burn-train/src/metric/loss.rs b/crates/burn-train/src/metric/loss.rs
index 62ed71d81..347c57a2e 100644
--- a/crates/burn-train/src/metric/loss.rs
+++ b/crates/burn-train/src/metric/loss.rs
@@ -33,10 +33,14 @@ impl<B: Backend> Metric for LossMetric<B> {
     type Input = LossInput<B>;
 
     fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
+        let [batch_size] = loss.tensor.dims();
         let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]);
 
-        self.state
-            .update(loss, 1, FormatOptions::new(Self::NAME).precision(2))
+        self.state.update(
+            loss,
+            batch_size,
+            FormatOptions::new(Self::NAME).precision(2),
+        )
     }
 
     fn clear(&mut self) {
diff --git a/crates/burn-train/src/metric/processor/metrics.rs b/crates/burn-train/src/metric/processor/metrics.rs
index e2992f12b..86d65fd2e 100644
--- a/crates/burn-train/src/metric/processor/metrics.rs
+++ b/crates/burn-train/src/metric/processor/metrics.rs
@@ -24,7 +24,7 @@ impl<T, V> Default for Metrics<T, V> {
 
 impl<T, V> Metrics<T, V> {
     /// Register a training metric.
-    pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
+    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
     where
         T: Adaptor<Me::Input> + 'static,
     {
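A quick round-trip sketch of the `NumericEntry` format introduced above — plain values serialize unchanged, aggregated values as `value,num_elements` (illustrative only; both methods are `pub(crate)` within `burn-train`):

```rust
// A batch mean of 1.5 computed over 32 elements.
let entry = NumericEntry::Aggregated(1.5, 32);
let serialized = entry.serialize();
assert_eq!(serialized, "1.5,32");

// A two-value string deserializes back to the aggregated variant.
match NumericEntry::deserialize(&serialized) {
    Ok(NumericEntry::Aggregated(value, numel)) => assert_eq!((value, numel), (1.5, 32)),
    _ => panic!("expected an aggregated entry"),
}
```

Persisting the element count alongside each logged mean is what lets `NumericMetricsAggregate` (below) weight every value by its actual batch size instead of treating all mini-batches equally.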
diff --git a/crates/burn-train/src/metric/state.rs b/crates/burn-train/src/metric/state.rs
index 9a188198d..5b21f510c 100644
--- a/crates/burn-train/src/metric/state.rs
+++ b/crates/burn-train/src/metric/state.rs
@@ -1,4 +1,4 @@
-use crate::metric::{format_float, MetricEntry, Numeric};
+use crate::metric::{format_float, MetricEntry, Numeric, NumericEntry};
 
 /// Useful utility to implement numeric metrics.
 ///
@@ -67,7 +67,8 @@ impl NumericMetricState {
         let value_current = value;
         let value_running = self.sum / self.count as f64;
 
-        let serialized = value_current.to_string();
+        // Numeric metric state is an aggregated value
+        let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();
 
         let (formatted_current, formatted_running) = match format.precision {
             Some(precision) => (
diff --git a/crates/burn-train/src/metric/store/aggregate.rs b/crates/burn-train/src/metric/store/aggregate.rs
index 679f6fa22..c96fa2c14 100644
--- a/crates/burn-train/src/metric/store/aggregate.rs
+++ b/crates/burn-train/src/metric/store/aggregate.rs
@@ -1,4 +1,4 @@
-use crate::logger::MetricLogger;
+use crate::{logger::MetricLogger, metric::NumericEntry};
 use std::collections::HashMap;
 
 use super::{Aggregate, Direction};
@@ -48,8 +48,18 @@ impl NumericMetricsAggregate {
             return None;
         }
 
-        let num_points = points.len();
-        let sum = points.into_iter().sum::<f64>();
+        // Accurately compute the aggregated value based on the *actual* number of points
+        // since not all mini-batches are guaranteed to have the specified batch size
+        let (sum, num_points) = points
+            .into_iter()
+            .map(|entry| match entry {
+                NumericEntry::Value(v) => (v, 1),
+                // Right now the mean is the only aggregate available, so we can assume that the sum
+                // of an entry corresponds to (value * number of elements)
+                NumericEntry::Aggregated(v, n) => (v * n as f64, n),
+            })
+            .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
+            .unwrap();
+
         let value = match aggregate {
             Aggregate::Mean => sum / num_points as f64,
         };
@@ -105,7 +115,10 @@ impl NumericMetricsAggregate {
 
 #[cfg(test)]
 mod tests {
-    use crate::{logger::FileMetricLogger, metric::MetricEntry};
+    use crate::{
+        logger::{FileMetricLogger, InMemoryMetricLogger},
+        metric::MetricEntry,
+    };
 
     use super::*;
 
@@ -159,4 +172,34 @@ mod tests {
 
         assert_eq!(value, 2);
     }
+
+    #[test]
+    fn should_aggregate_numeric_entry() {
+        let mut logger = InMemoryMetricLogger::default();
+        let mut aggregate = NumericMetricsAggregate::default();
+        let metric_name = "Loss";
+
+        // Epoch 1
+        let loss_1 = 0.5;
+        let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2
+        let entry = MetricEntry::new(
+            metric_name.to_string(),
+            loss_1.to_string(),
+            NumericEntry::Value(loss_1).serialize(),
+        );
+        logger.log(&entry);
+        let entry = MetricEntry::new(
+            metric_name.to_string(),
+            loss_2.to_string(),
+            NumericEntry::Aggregated(loss_2, 2).serialize(),
+        );
+        logger.log(&entry);
+
+        let value = aggregate
+            .aggregate(metric_name, 1, Aggregate::Mean, &mut [Box::new(logger)])
+            .unwrap();
+
+        // Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875
+        assert_eq!(value, 1.0);
+    }
 }
diff --git a/examples/custom-image-dataset/src/training.rs b/examples/custom-image-dataset/src/training.rs
index 38d69572d..6bb8651d3 100644
--- a/examples/custom-image-dataset/src/training.rs
+++ b/examples/custom-image-dataset/src/training.rs
@@ -65,8 +65,15 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
 pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
-    std::fs::create_dir_all(ARTIFACT_DIR).ok();
+    create_artifact_dir(ARTIFACT_DIR);
+
     config
         .save(format!("{ARTIFACT_DIR}/config.json"))
         .expect("Config should be saved successfully");
@@ -98,6 +105,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             Cnn::new(NUM_CLASSES.into(), &device),
             config.optimizer.init(),
diff --git a/examples/guide/src/training.rs b/examples/guide/src/training.rs
index 7f1bb75d0..682e9bb56 100644
--- a/examples/guide/src/training.rs
+++ b/examples/guide/src/training.rs
@@ -60,8 +60,14 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
     std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
     config
         .save(format!("{artifact_dir}/config.json"))
         .expect("Config should be saved successfully");
@@ -91,6 +97,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             config.model.init::<B>(&device),
             config.optimizer.init(),
diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs
index cfaf15d8d..5ba284137 100644
--- a/examples/mnist/src/training.rs
+++ b/examples/mnist/src/training.rs
@@ -34,7 +34,14 @@ pub struct MnistTrainingConfig {
     pub optimizer: AdamConfig,
 }
 
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
 pub fn run<B: AutodiffBackend>(device: B::Device) {
+    create_artifact_dir(ARTIFACT_DIR);
     // Config
     let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));
     let config = MnistTrainingConfig::new(config_optimizer);
@@ -76,6 +83,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
         ))
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(Model::new(&device), config.optimizer.init(), 1e-4);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
diff --git a/examples/simple-regression/src/training.rs b/examples/simple-regression/src/training.rs
index 3cb1bbe26..3467082a6 100644
--- a/examples/simple-regression/src/training.rs
+++ b/examples/simple-regression/src/training.rs
@@ -81,6 +81,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
         ))
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(model, config.optimizer.init(), 5e-3);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs
index 203b546dd..05b2da12c 100644
--- a/examples/text-classification/src/training.rs
+++ b/examples/text-classification/src/training.rs
@@ -102,6 +102,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .with_file_checkpointer(CompactRecorder::new())
         .devices(devices)
         .num_epochs(config.num_epochs)
+        .summary()
         .build(model, optim, lr_scheduler);
 
     // Train the model
diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs
index 55b4cbcc8..1860e4186 100644
--- a/examples/text-generation/src/training.rs
+++ b/examples/text-generation/src/training.rs
@@ -80,6 +80,7 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
         .devices(vec![device])
         .grads_accumulation(accum)
         .num_epochs(config.num_epochs)
+        .summary()
         .build(model, optim, lr_scheduler);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
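Since `LearnerSummary::new` is public (and exported via the new `pub use summary::*;`), a report can also be produced after the fact from an existing artifact directory. A sketch, assuming the crate-root re-export mirrors the other learner types and using a placeholder directory:

```rust
use burn_train::LearnerSummary;

fn main() {
    // "/tmp/guide" stands in for any artifact directory produced by a previous `fit` run.
    let summary = LearnerSummary::new("/tmp/guide", &["Loss", "Accuracy"])
        .expect("Summary artifacts should exist");
    println!("{summary}");
}
```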