mirror of https://github.com/tracel-ai/burn.git
Add learner training report summary (#1591)
* Add training report summary
* Fix LossMetric batch size state
* Add NumericEntry de/serialize
* Fix clippy suggestion
* Compact recorder does not use compression (anymore)
* Add learner summary expected results tests
* Add summary to learner builder and automatically display in fit
  - Add LearnerSummaryConfig
  - Keep track of summary metrics names
  - Add model field when displaying from learner.fit()
parent bdb62fbcd0
commit 0cbe9a927d
@@ -1,11 +1,11 @@
 # Training
 
-We are now ready to write the necessary code to train our model on the MNIST dataset.
-We shall define the code for this training section in the file: `src/training.rs`.
+We are now ready to write the necessary code to train our model on the MNIST dataset. We shall
+define the code for this training section in the file: `src/training.rs`.
 
-Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
-responsibility is to apply an optimizer to the model. The output struct is used for all metrics
-calculated during the training. Therefore it should include all the necessary information to
+Instead of a simple tensor, the model should output an item that can be understood by the learner, a
+struct whose responsibility is to apply an optimizer to the model. The output struct is used for all
+metrics calculated during the training. Therefore it should include all the necessary information to
 calculate any metric that you want for a task.
 
 Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement
@@ -110,8 +110,14 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts first to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
     config
         .save(format!("{artifact_dir}/config.json"))
         .expect("Config should be saved successfully");
@@ -141,6 +147,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             config.model.init::<B>(&device),
             config.optimizer.init(),
@@ -181,8 +188,8 @@ Once the learner is created, we can simply call `fit` and provide the training a
 dataloaders. For the sake of simplicity in this example, we employ the test set as the validation
 set; however, we do not recommend this practice for actual usage.
 
-Finally, the trained model is returned by the `fit` method, and the only remaining task is saving
-the trained weights using the `CompactRecorder`. This recorder employs the `MessagePack` format with
-`gzip` compression, `f16` for floats and `i16` for integers. Other recorders are available, offering
-support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend,
-regardless of precision, can load recorded data of any kind.
+Finally, the trained model is returned by the `fit` method. The trained weights are then saved using
+the `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for
+floats and `i16` for integers. Other recorders are available, offering support for various formats,
+such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can
+load recorded data of any kind.
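Editor's note: the saving step described in the last paragraph above corresponds to code along these lines in the book's surrounding example (a minimal sketch; the `model_trained` and `artifact_dir` bindings are assumed from that example, and `save_file` is the Module saving API used throughout the burn book):

```rust
// Save the trained weights with the CompactRecorder (MessagePack, f16 floats / i16 ints).
model_trained
    .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
    .expect("Trained model should be saved successfully");
```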
@@ -2,6 +2,7 @@ use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy
 use crate::components::LearnerComponents;
 use crate::learner::EarlyStoppingStrategy;
 use crate::metric::store::EventStoreClient;
+use crate::LearnerSummaryConfig;
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::Module;
 use burn_core::optim::Optimizer;
@@ -26,6 +27,7 @@ pub struct Learner<LC: LearnerComponents> {
     pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
     pub(crate) event_processor: LC::EventProcessor,
     pub(crate) event_store: Arc<EventStoreClient>,
+    pub(crate) summary: Option<LearnerSummaryConfig>,
 }
 
 #[derive(new)]
@@ -1,3 +1,4 @@
+use std::collections::HashSet;
 use std::sync::Arc;
 
 use super::log::install_file_logger;
@@ -14,7 +15,7 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
 use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
 use crate::metric::{Adaptor, LossMetric, Metric};
 use crate::renderer::{default_renderer, MetricsRenderer};
-use crate::LearnerCheckpointer;
+use crate::{LearnerCheckpointer, LearnerSummaryConfig};
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::AutodiffModule;
 use burn_core::optim::Optimizer;
@@ -53,6 +54,8 @@ where
     num_loggers: usize,
     checkpointer_strategy: Box<dyn CheckpointingStrategy>,
     early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
+    summary_metrics: HashSet<String>,
+    summary: bool,
 }
 
 impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@@ -94,6 +97,8 @@ where
                     .build(),
             ),
             early_stopping: None,
+            summary_metrics: HashSet::new(),
+            summary: false,
         }
     }
 
@@ -140,7 +145,7 @@ where
     where
         T: Adaptor<Me::Input>,
     {
-        self.metrics.register_metric_train(metric);
+        self.metrics.register_train_metric(metric);
         self
     }
 
@@ -174,6 +179,7 @@ where
         Me: Metric + crate::metric::Numeric + 'static,
         T: Adaptor<Me::Input>,
     {
+        self.summary_metrics.insert(Me::NAME.to_string());
         self.metrics.register_train_metric_numeric(metric);
         self
     }
@@ -186,6 +192,7 @@ where
     where
         V: Adaptor<Me::Input>,
     {
+        self.summary_metrics.insert(Me::NAME.to_string());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }
@@ -266,6 +273,14 @@ where
         self
     }
 
+    /// Enable the training summary report.
+    ///
+    /// The summary will be displayed at the end of `.fit()`.
+    pub fn summary(mut self) -> Self {
+        self.summary = true;
+        self
+    }
+
     /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
     /// The [learning rate scheduler](LrScheduler) can also be a simple
     /// [learning rate](burn_core::LearningRate).
@@ -320,6 +335,15 @@ where
             LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
         });
 
+        let summary = if self.summary {
+            Some(LearnerSummaryConfig {
+                directory: self.directory,
+                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
+            })
+        } else {
+            None
+        };
+
         Learner {
             model,
             optim,
@@ -333,6 +357,7 @@ where
             devices: self.devices,
             interrupter: self.interrupter,
             early_stopping: self.early_stopping,
+            summary,
         }
     }
 
@@ -5,6 +5,7 @@ mod early_stopping;
 mod epoch;
 mod regression;
 mod step;
+mod summary;
 mod train_val;
 
 pub(crate) mod log;
@@ -16,5 +17,6 @@ pub use early_stopping::*;
 pub use epoch::*;
 pub use regression::*;
 pub use step::*;
+pub use summary::*;
 pub use train::*;
 pub use train_val::*;
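Editor's note: taken together, the builder changes mean that any metric registered through the numeric registration methods is also tracked by name for the summary, and `.summary()` opts in to the report. A minimal sketch of the resulting API (assuming the usual `LearnerBuilder::new` / `LossMetric::new` setup from the burn book; `model`, `optim`, and `lr_scheduler` are placeholders):

```rust
// Numeric metrics are recorded in summary_metrics as a side effect of registration;
// .summary() enables printing the report at the end of fit().
let learner = LearnerBuilder::new(artifact_dir)
    .metric_train_numeric(LossMetric::new())
    .metric_valid_numeric(LossMetric::new())
    .num_epochs(10)
    .summary()
    .build(model, optim, lr_scheduler);
```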
@@ -0,0 +1,307 @@
+use core::cmp::Ordering;
+use std::{fmt::Display, path::Path};
+
+use crate::{
+    logger::FileMetricLogger,
+    metric::store::{Aggregate, EventStore, LogEventStore, Split},
+};
+
+/// Contains the metric value at a given time.
+pub struct MetricEntry {
+    /// The step at which the metric was recorded (i.e., epoch).
+    pub step: usize,
+    /// The metric value.
+    pub value: f64,
+}
+
+/// Contains the summary of recorded values for a given metric.
+pub struct MetricSummary {
+    /// The metric name.
+    pub name: String,
+    /// The metric entries.
+    pub entries: Vec<MetricEntry>,
+}
+
+impl MetricSummary {
+    fn new<E: EventStore>(
+        event_store: &mut E,
+        metric: &str,
+        split: Split,
+        num_epochs: usize,
+    ) -> Option<Self> {
+        let entries = (1..=num_epochs)
+            .filter_map(|epoch| {
+                event_store
+                    .find_metric(metric, epoch, Aggregate::Mean, split)
+                    .map(|value| MetricEntry { step: epoch, value })
+            })
+            .collect::<Vec<_>>();
+
+        if entries.is_empty() {
+            None
+        } else {
+            Some(Self {
+                name: metric.to_string(),
+                entries,
+            })
+        }
+    }
+}
+
+/// Contains the summary of recorded metrics for the training and validation steps.
+pub struct SummaryMetrics {
+    /// Training metrics summary.
+    pub train: Vec<MetricSummary>,
+    /// Validation metrics summary.
+    pub valid: Vec<MetricSummary>,
+}
+
+/// Detailed training summary.
+pub struct LearnerSummary {
+    /// The number of epochs completed.
+    pub epochs: usize,
+    /// The summary of recorded metrics during training.
+    pub metrics: SummaryMetrics,
+    /// The model name (only recorded within the learner).
+    pub(crate) model: Option<String>,
+}
+
+impl LearnerSummary {
+    /// Creates a new learner summary for the specified metrics.
+    ///
+    /// # Arguments
+    ///
+    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).
+    /// * `metrics` - The list of metrics to collect for the summary.
+    pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
+        let directory_path = Path::new(directory);
+        if !directory_path.exists() {
+            return Err(format!("Artifact directory does not exist at: {directory}"));
+        }
+        let train_dir = directory_path.join("train");
+        let valid_dir = directory_path.join("valid");
+        if !train_dir.exists() & !valid_dir.exists() {
+            return Err(format!(
+                "No training or validation artifacts found at: {directory}"
+            ));
+        }
+
+        let mut event_store = LogEventStore::default();
+
+        let train_logger = FileMetricLogger::new(train_dir.to_str().unwrap());
+        let valid_logger = FileMetricLogger::new(valid_dir.to_str().unwrap());
+
+        // Number of recorded epochs
+        let epochs = train_logger.epochs();
+
+        event_store.register_logger_train(train_logger);
+        event_store.register_logger_valid(valid_logger);
+
+        let train_summary = metrics
+            .iter()
+            .filter_map(|metric| {
+                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Train, epochs)
+            })
+            .collect::<Vec<_>>();
+
+        let valid_summary = metrics
+            .iter()
+            .filter_map(|metric| {
+                MetricSummary::new(&mut event_store, metric.as_ref(), Split::Valid, epochs)
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Self {
+            epochs,
+            metrics: SummaryMetrics {
+                train: train_summary,
+                valid: valid_summary,
+            },
+            model: None,
+        })
+    }
+
+    pub(crate) fn with_model(mut self, name: String) -> Self {
+        self.model = Some(name);
+        self
+    }
+}
+
+impl Display for LearnerSummary {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Compute the max length for each column
+        let split_train = "Train";
+        let split_valid = "Valid";
+        let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
+        let mut max_metric_len = "Metric".len();
+        for metric in self.metrics.train.iter() {
+            max_metric_len = max_metric_len.max(metric.name.len());
+        }
+        for metric in self.metrics.valid.iter() {
+            max_metric_len = max_metric_len.max(metric.name.len());
+        }
+
+        // Summary header
+        writeln!(
+            f,
+            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
+            "",
+            "",
+            width_symbol = 24,
+        )?;
+
+        if let Some(model) = &self.model {
+            writeln!(f, "Model: {model}")?;
+        }
+        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
+
+        // Metrics table header
+        writeln!(
+            f,
+            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
+            "Split", "Metric", "", "",
+            width_split = max_split_len,
+            width_metric = max_metric_len,
+        )?;
+
+        // Table entries
+        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
+            match (a.is_nan(), b.is_nan()) {
+                (true, true) => Ordering::Equal,
+                (true, false) => Ordering::Greater,
+                (false, true) => Ordering::Less,
+                _ => a.partial_cmp(b).unwrap(),
+            }
+        }
+
+        let mut write_metrics_summary = |metrics: &[MetricSummary],
+                                         split: &str|
+         -> std::fmt::Result {
+            for metric in metrics.iter() {
+                if metric.entries.is_empty() {
+                    continue; // skip metrics with no recorded values
+                }
+
+                // Compute the min & max for each metric
+                let metric_min = metric
+                    .entries
+                    .iter()
+                    .min_by(|a, b| cmp_f64(&a.value, &b.value))
+                    .unwrap();
+                let metric_max = metric
+                    .entries
+                    .iter()
+                    .max_by(|a, b| cmp_f64(&a.value, &b.value))
+                    .unwrap();
+
+                writeln!(
+                    f,
+                    "| {:<width_split$} | {:<width_metric$} | {:<9.3?}| {:<9?}| {:<9.3?}| {:<9.3?}|",
+                    split,
+                    metric.name,
+                    metric_min.value,
+                    metric_min.step,
+                    metric_max.value,
+                    metric_max.step,
+                    width_split = max_split_len,
+                    width_metric = max_metric_len,
+                )?;
+            }
+
+            Ok(())
+        };
+
+        write_metrics_summary(&self.metrics.train, split_train)?;
+        write_metrics_summary(&self.metrics.valid, split_valid)?;
+
+        Ok(())
+    }
+}
+
+pub(crate) struct LearnerSummaryConfig {
+    pub(crate) directory: String,
+    pub(crate) metrics: Vec<String>,
+}
+
+impl LearnerSummaryConfig {
+    pub fn init(&self) -> Result<LearnerSummary, String> {
+        LearnerSummary::new(&self.directory, &self.metrics[..])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    #[should_panic = "Summary artifacts should exist"]
+    fn test_artifact_dir_should_exist() {
+        let dir = "/tmp/learner-summary-not-found";
+        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
+    }
+
+    #[test]
+    #[should_panic = "Summary artifacts should exist"]
+    fn test_train_valid_artifacts_should_exist() {
+        let dir = "/tmp/test-learner-summary-empty";
+        std::fs::create_dir_all(dir).ok();
+        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
+    }
+
+    #[test]
+    fn test_summary_should_be_empty() {
+        let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
+        std::fs::create_dir_all(dir).unwrap();
+        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
+        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
+        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
+            .expect("Summary artifacts should exist");
+
+        assert_eq!(summary.epochs, 1);
+
+        assert_eq!(summary.metrics.train.len(), 0);
+        assert_eq!(summary.metrics.valid.len(), 0);
+
+        std::fs::remove_dir_all(dir).unwrap();
+    }
+
+    #[test]
+    fn test_summary_should_be_collected() {
+        let dir = Path::new("/tmp/test-learner-summary");
+        let train_dir = dir.join("train/epoch-1");
+        let valid_dir = dir.join("valid/epoch-1");
+        std::fs::create_dir_all(dir).unwrap();
+        std::fs::create_dir_all(&train_dir).unwrap();
+        std::fs::create_dir_all(&valid_dir).unwrap();
+
+        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
+        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");
+
+        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
+            .expect("Summary artifacts should exist");
+
+        assert_eq!(summary.epochs, 1);
+
+        // Only Loss metric
+        assert_eq!(summary.metrics.train.len(), 1);
+        assert_eq!(summary.metrics.valid.len(), 1);
+
+        // Aggregated train metric entries for 1 epoch
+        let train_metric = &summary.metrics.train[0];
+        assert_eq!(train_metric.name, "Loss");
+        assert_eq!(train_metric.entries.len(), 1);
+        let entry = &train_metric.entries[0];
+        assert_eq!(entry.step, 1); // epoch = 1
+        assert_eq!(entry.value, 1.5); // (1 + 2) / 2
+
+        // Aggregated valid metric entries for 1 epoch
+        let valid_metric = &summary.metrics.valid[0];
+        assert_eq!(valid_metric.name, "Loss");
+        assert_eq!(valid_metric.entries.len(), 1);
+        let entry = &valid_metric.entries[0];
+        assert_eq!(entry.step, 1); // epoch = 1
+        assert_eq!(entry.value, 1.0);
+
+        std::fs::remove_dir_all(dir).unwrap();
+    }
+}
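Editor's note: since `LearnerSummary::new` only needs the artifact directory and metric names, a summary can also be produced after the fact, as the tests above do. A short sketch (directory path and metric names hypothetical):

```rust
// Build a summary from previously recorded training artifacts and print it.
let summary = LearnerSummary::new("/tmp/guide", &["Loss", "Accuracy"])
    .expect("Summary artifacts should exist");
println!("{summary}");
```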
@@ -199,6 +199,19 @@ impl<LC: LearnerComponents> Learner<LC> {
             }
         }
 
+        // Display learner summary
+        if let Some(summary) = self.summary {
+            match summary.init() {
+                Ok(summary) => {
+                    // Drop event processor (includes renderer) so the summary is displayed
+                    // when switching back to "main" screen
+                    core::mem::drop(self.event_processor);
+                    println!("{}", summary.with_model(self.model.to_string()))
+                }
+                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
+            }
+        }
+
         self.model
     }
 }
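Editor's note: for reference, the `Display` implementation introduced above renders output along these lines (illustrative only; the model name and all values are hypothetical):

```
======================== Learner Summary ========================
Model: Cnn
Total Epochs: 10


| Split | Metric | Min.     | Epoch    | Max.     | Epoch    |
|-------|--------|----------|----------|----------|----------|
| Train | Loss   | 0.182    | 10       | 1.526    | 1        |
| Valid | Loss   | 0.203    | 10       | 1.411    | 1        |
```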
@@ -1,6 +1,8 @@
 use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
-use crate::metric::MetricEntry;
-use std::collections::HashMap;
+use crate::metric::{MetricEntry, NumericEntry};
+use std::{collections::HashMap, fs};
+
+const EPOCH_PREFIX: &str = "epoch-";
 
 /// Metric logger.
 pub trait MetricLogger: Send {
@@ -19,7 +21,7 @@ pub trait MetricLogger: Send {
     fn end_epoch(&mut self, epoch: usize);
 
     /// Read the logs for an epoch.
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
 }
 
 /// The file metric logger.
@@ -47,14 +49,44 @@ impl FileMetricLogger {
         }
     }
 
+    /// Number of epochs recorded.
+    pub(crate) fn epochs(&self) -> usize {
+        let mut max_epoch = 0;
+
+        for path in fs::read_dir(&self.directory).unwrap() {
+            let path = path.unwrap();
+
+            if fs::metadata(path.path()).unwrap().is_dir() {
+                let dir_name = path.file_name().into_string().unwrap();
+
+                if !dir_name.starts_with(EPOCH_PREFIX) {
+                    continue;
+                }
+
+                let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
+
+                if let Some(epoch) = epoch {
+                    if epoch > max_epoch {
+                        max_epoch = epoch;
+                    }
+                }
+            }
+        }
+
+        max_epoch
+    }
+
+    fn epoch_directory(&self, epoch: usize) -> String {
+        format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
+    }
+
     fn file_path(&self, name: &str, epoch: usize) -> String {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         let name = name.replace(' ', "_");
 
         format!("{directory}/{name}.log")
     }
     fn create_directory(&self, epoch: usize) {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         std::fs::create_dir_all(directory).ok();
     }
 }
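Editor's note: the `epochs()` helper above infers the epoch count from the on-disk layout of the artifact directory, which (per the tests elsewhere in this commit) looks like this:

```
<artifact_dir>/
├── train/
│   ├── epoch-1/
│   │   └── Loss.log
│   └── epoch-2/
│       └── Loss.log
└── valid/
    └── epoch-1/
        └── Loss.log
```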
@@ -88,7 +120,7 @@ impl MetricLogger for FileMetricLogger {
         self.epoch = epoch + 1;
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
         if let Some(value) = self.loggers.get(name) {
             value.sync()
         }
@@ -104,7 +136,7 @@ impl MetricLogger for FileMetricLogger {
                 if value.is_empty() {
                     None
                 } else {
-                    match value.parse::<f64>() {
+                    match NumericEntry::deserialize(value) {
                         Ok(value) => Some(value),
                         Err(err) => {
                             log::error!("{err}");
@@ -117,7 +149,7 @@ impl MetricLogger for FileMetricLogger {
             .collect();
 
         if errors {
-            Err("Parsing float errors".to_string())
+            Err("Parsing numeric entry errors".to_string())
         } else {
             Ok(data)
         }
@@ -154,7 +186,7 @@ impl MetricLogger for InMemoryMetricLogger {
         }
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
         let values = match self.values.get(name) {
             Some(values) => values,
             None => return Ok(Vec::new()),
@@ -164,7 +196,7 @@ impl MetricLogger for InMemoryMetricLogger {
             Some(logger) => Ok(logger
                 .values
                 .iter()
-                .filter_map(|value| value.parse::<f64>().ok())
+                .filter_map(|value| NumericEntry::deserialize(value).ok())
                 .collect()),
             None => Ok(Vec::new()),
         }
@@ -84,6 +84,49 @@ pub struct MetricEntry {
     pub serialize: String,
 }
 
+/// Numeric metric entry.
+pub enum NumericEntry {
+    /// Single numeric value.
+    Value(f64),
+    /// Aggregated numeric (value, number of elements).
+    Aggregated(f64, usize),
+}
+
+impl NumericEntry {
+    pub(crate) fn serialize(&self) -> String {
+        match self {
+            Self::Value(v) => v.to_string(),
+            Self::Aggregated(v, n) => format!("{v},{n}"),
+        }
+    }
+
+    pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
+        // Check for comma separated values
+        let values = entry.split(',').collect::<Vec<_>>();
+        let num_values = values.len();
+
+        if num_values == 1 {
+            // Numeric value
+            match values[0].parse::<f64>() {
+                Ok(value) => Ok(NumericEntry::Value(value)),
+                Err(err) => Err(err.to_string()),
+            }
+        } else if num_values == 2 {
+            // Aggregated numeric (value, number of elements)
+            let (value, numel) = (values[0], values[1]);
+            match value.parse::<f64>() {
+                Ok(value) => match numel.parse::<usize>() {
+                    Ok(numel) => Ok(NumericEntry::Aggregated(value, numel)),
+                    Err(err) => Err(err.to_string()),
+                },
+                Err(err) => Err(err.to_string()),
+            }
+        } else {
+            Err("Invalid number of values for numeric entry".to_string())
+        }
+    }
+}
+
 /// Format a float with the given precision. Will use scientific notation if necessary.
 pub fn format_float(float: f64, precision: usize) -> String {
     let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
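Editor's note: a quick illustration of the plain-text format `NumericEntry` defines above. Both methods are `pub(crate)`, so this sketch is in the spirit of the crate's own tests rather than external usage:

```rust
// Round-trip sketch for NumericEntry's log-line format.
// "1.5" parses as Value(1.5); "1.25,2" parses as Aggregated(1.25, 2).
let single = NumericEntry::deserialize("1.5").unwrap();
assert!(matches!(single, NumericEntry::Value(v) if v == 1.5));

let aggregated = NumericEntry::Aggregated(1.25, 2);
assert_eq!(aggregated.serialize(), "1.25,2");
```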
@@ -33,10 +33,14 @@ impl<B: Backend> Metric for LossMetric<B> {
     type Input = LossInput<B>;
 
     fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
+        let [batch_size] = loss.tensor.dims();
         let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]);
 
-        self.state
-            .update(loss, 1, FormatOptions::new(Self::NAME).precision(2))
+        self.state.update(
+            loss,
+            batch_size,
+            FormatOptions::new(Self::NAME).precision(2),
+        )
     }
 
     fn clear(&mut self) {
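Editor's note: the batch size matters here because the loss is already a mean over the batch; averaging per-batch means directly gives the wrong epoch mean whenever batches differ in size (e.g., a smaller final batch). A small self-contained sketch, using the same numbers as this commit's `should_aggregate_numeric_entry` test:

```rust
// Batch 1: 1 sample,  mean loss 0.5   -> contributes 0.5
// Batch 2: 2 samples, mean loss 1.25  -> contributes 1.25 * 2 = 2.5
// Unweighted mean of means: (0.5 + 1.25) / 2 = 0.875 (wrong)
// Weighted by batch size:   (0.5 + 2.5) / 3  = 1.0   (right)
let batches: [(f64, usize); 2] = [(0.5, 1), (1.25, 2)];
let (sum, count) = batches
    .iter()
    .fold((0.0, 0), |(s, n), (v, k)| (s + v * *k as f64, n + k));
assert_eq!(sum / count as f64, 1.0);
```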
@@ -24,7 +24,7 @@ impl<T, V> Default for Metrics<T, V> {
 
 impl<T, V> Metrics<T, V> {
     /// Register a training metric.
-    pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
+    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
     where
         T: Adaptor<Me::Input> + 'static,
     {
@@ -1,4 +1,4 @@
-use crate::metric::{format_float, MetricEntry, Numeric};
+use crate::metric::{format_float, MetricEntry, Numeric, NumericEntry};
 
 /// Useful utility to implement numeric metrics.
 ///
@@ -67,7 +67,8 @@ impl NumericMetricState {
 
         let value_current = value;
         let value_running = self.sum / self.count as f64;
-        let serialized = value_current.to_string();
+        // Numeric metric state is an aggregated value
+        let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();
 
         let (formatted_current, formatted_running) = match format.precision {
             Some(precision) => (
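Editor's note: concretely, each serialized metric entry now carries the batch size alongside the mean rather than a bare float (sketch; values hypothetical):

```rust
// Before this commit a log line was just "1.25";
// now it is "1.25,2": the mean plus the number of elements it covers.
let line = NumericEntry::Aggregated(1.25, 2).serialize();
assert_eq!(line, "1.25,2");
```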
@@ -1,4 +1,4 @@
-use crate::logger::MetricLogger;
+use crate::{logger::MetricLogger, metric::NumericEntry};
 use std::collections::HashMap;
 
 use super::{Aggregate, Direction};
@@ -48,8 +48,18 @@ impl NumericMetricsAggregate {
             return None;
         }
 
-        let num_points = points.len();
-        let sum = points.into_iter().sum::<f64>();
+        // Accurately compute the aggregated value based on the *actual* number of points
+        // since not all mini-batches are guaranteed to have the specified batch size
+        let (sum, num_points) = points
+            .into_iter()
+            .map(|entry| match entry {
+                NumericEntry::Value(v) => (v, 1),
+                // Right now the mean is the only aggregate available, so we can assume that the sum
+                // of an entry corresponds to (value * number of elements)
+                NumericEntry::Aggregated(v, n) => (v * n as f64, n),
+            })
+            .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
+            .unwrap();
         let value = match aggregate {
             Aggregate::Mean => sum / num_points as f64,
         };
@@ -105,7 +115,10 @@ impl NumericMetricsAggregate {
 
 #[cfg(test)]
 mod tests {
-    use crate::{logger::FileMetricLogger, metric::MetricEntry};
+    use crate::{
+        logger::{FileMetricLogger, InMemoryMetricLogger},
+        metric::MetricEntry,
+    };
 
     use super::*;
 
@@ -159,4 +172,34 @@ mod tests {
 
         assert_eq!(value, 2);
     }
+
+    #[test]
+    fn should_aggregate_numeric_entry() {
+        let mut logger = InMemoryMetricLogger::default();
+        let mut aggregate = NumericMetricsAggregate::default();
+        let metric_name = "Loss";
+
+        // Epoch 1
+        let loss_1 = 0.5;
+        let loss_2 = 1.25; // mean over a batch of 2: (1.5 + 1.0) / 2 = 2.5 / 2
+        let entry = MetricEntry::new(
+            metric_name.to_string(),
+            loss_1.to_string(),
+            NumericEntry::Value(loss_1).serialize(),
+        );
+        logger.log(&entry);
+        let entry = MetricEntry::new(
+            metric_name.to_string(),
+            loss_2.to_string(),
+            NumericEntry::Aggregated(loss_2, 2).serialize(),
+        );
+        logger.log(&entry);
+
+        let value = aggregate
+            .aggregate(metric_name, 1, Aggregate::Mean, &mut [Box::new(logger)])
+            .unwrap();
+
+        // Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875
+        assert_eq!(value, 1.0);
+    }
 }
@@ -65,8 +65,15 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts first to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
 pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
-    std::fs::create_dir_all(ARTIFACT_DIR).ok();
+    create_artifact_dir(ARTIFACT_DIR);
 
     config
         .save(format!("{ARTIFACT_DIR}/config.json"))
         .expect("Config should be saved successfully");
@@ -98,6 +105,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             Cnn::new(NUM_CLASSES.into(), &device),
             config.optimizer.init(),
@@ -60,8 +60,14 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts first to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
     config
         .save(format!("{artifact_dir}/config.json"))
         .expect("Config should be saved successfully");
@@ -91,6 +97,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
         .with_file_checkpointer(CompactRecorder::new())
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(
             config.model.init::<B>(&device),
             config.optimizer.init(),
@@ -34,7 +34,14 @@ pub struct MnistTrainingConfig {
     pub optimizer: AdamConfig,
 }
 
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts first to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
 pub fn run<B: AutodiffBackend>(device: B::Device) {
+    create_artifact_dir(ARTIFACT_DIR);
     // Config
     let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));
     let config = MnistTrainingConfig::new(config_optimizer);
@@ -76,6 +83,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
         ))
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(Model::new(&device), config.optimizer.init(), 1e-4);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
@@ -81,6 +81,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
         ))
         .devices(vec![device.clone()])
         .num_epochs(config.num_epochs)
+        .summary()
         .build(model, config.optimizer.init(), 5e-3);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);
@@ -102,6 +102,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .with_file_checkpointer(CompactRecorder::new())
        .devices(devices)
        .num_epochs(config.num_epochs)
+        .summary()
         .build(model, optim, lr_scheduler);
 
     // Train the model
@@ -80,6 +80,7 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
         .devices(vec![device])
         .grads_accumulation(accum)
         .num_epochs(config.num_epochs)
+        .summary()
         .build(model, optim, lr_scheduler);
 
     let model_trained = learner.fit(dataloader_train, dataloader_test);