Add learner training report summary (#1591)

* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
This commit is contained in:
Guillaume Lagrange 2024-04-11 12:32:25 -04:00 committed by GitHub
parent bdb62fbcd0
commit 0cbe9a927d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 539 additions and 34 deletions

View File

@ -1,11 +1,11 @@
# Training
We are now ready to write the necessary code to train our model on the MNIST dataset.
We shall define the code for this training section in the file: `src/training.rs`.
We are now ready to write the necessary code to train our model on the MNIST dataset. We shall
define the code for this training section in the file: `src/training.rs`.
Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
responsibility is to apply an optimizer to the model. The output struct is used for all metrics
calculated during the training. Therefore it should include all the necessary information to
Instead of a simple tensor, the model should output an item that can be understood by the learner, a
struct whose responsibility is to apply an optimizer to the model. The output struct is used for all
metrics calculated during the training. Therefore it should include all the necessary information to
calculate any metric that you want for a task.
Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement
@ -110,8 +110,14 @@ pub struct TrainingConfig {
pub learning_rate: f64,
}
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
fn create_artifact_dir(artifact_dir: &str) {
// Remove existing artifacts before to get an accurate learner summary
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
create_artifact_dir(artifact_dir);
config
.save(format!("{artifact_dir}/config.json"))
.expect("Config should be saved successfully");
@ -141,6 +147,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(
config.model.init::<B>(&device),
config.optimizer.init(),
@ -181,8 +188,8 @@ Once the learner is created, we can simply call `fit` and provide the training a
dataloaders. For the sake of simplicity in this example, we employ the test set as the validation
set; however, we do not recommend this practice for actual usage.
Finally, the trained model is returned by the `fit` method, and the only remaining task is saving
the trained weights using the `CompactRecorder`. This recorder employs the `MessagePack` format with
`gzip` compression, `f16` for floats and `i16` for integers. Other recorders are available, offering
support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend,
regardless of precision, can load recorded data of any kind.
Finally, the trained model is returned by the `fit` method. The trained weights are then saved using
the `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for
floats and `i16` for integers. Other recorders are available, offering support for various formats,
such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can
load recorded data of any kind.

View File

@ -2,6 +2,7 @@ use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy
use crate::components::LearnerComponents;
use crate::learner::EarlyStoppingStrategy;
use crate::metric::store::EventStoreClient;
use crate::LearnerSummaryConfig;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
@ -26,6 +27,7 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Arc<EventStoreClient>,
pub(crate) summary: Option<LearnerSummaryConfig>,
}
#[derive(new)]

View File

@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::sync::Arc;
use super::log::install_file_logger;
@ -14,7 +15,7 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::LearnerCheckpointer;
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
@ -53,6 +54,8 @@ where
num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
summary_metrics: HashSet<String>,
summary: bool,
}
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@ -94,6 +97,8 @@ where
.build(),
),
early_stopping: None,
summary_metrics: HashSet::new(),
summary: false,
}
}
@ -140,7 +145,7 @@ where
where
T: Adaptor<Me::Input>,
{
self.metrics.register_metric_train(metric);
self.metrics.register_train_metric(metric);
self
}
@ -174,6 +179,7 @@ where
Me: Metric + crate::metric::Numeric + 'static,
T: Adaptor<Me::Input>,
{
self.summary_metrics.insert(Me::NAME.to_string());
self.metrics.register_train_metric_numeric(metric);
self
}
@ -186,6 +192,7 @@ where
where
V: Adaptor<Me::Input>,
{
self.summary_metrics.insert(Me::NAME.to_string());
self.metrics.register_valid_metric_numeric(metric);
self
}
@ -266,6 +273,14 @@ where
self
}
/// Enable the training summary report.
///
/// The summary will be displayed at the end of `.fit()`.
pub fn summary(mut self) -> Self {
self.summary = true;
self
}
/// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
/// The [learning rate scheduler](LrScheduler) can also be a simple
/// [learning rate](burn_core::LearningRate).
@ -320,6 +335,15 @@ where
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
});
let summary = if self.summary {
Some(LearnerSummaryConfig {
directory: self.directory,
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
})
} else {
None
};
Learner {
model,
optim,
@ -333,6 +357,7 @@ where
devices: self.devices,
interrupter: self.interrupter,
early_stopping: self.early_stopping,
summary,
}
}

View File

@ -5,6 +5,7 @@ mod early_stopping;
mod epoch;
mod regression;
mod step;
mod summary;
mod train_val;
pub(crate) mod log;
@ -16,5 +17,6 @@ pub use early_stopping::*;
pub use epoch::*;
pub use regression::*;
pub use step::*;
pub use summary::*;
pub use train::*;
pub use train_val::*;

View File

@ -0,0 +1,307 @@
use core::cmp::Ordering;
use std::{fmt::Display, path::Path};
use crate::{
logger::FileMetricLogger,
metric::store::{Aggregate, EventStore, LogEventStore, Split},
};
/// Contains the metric value at a given time.
pub struct MetricEntry {
/// The step at which the metric was recorded (i.e., epoch).
pub step: usize,
/// The metric value.
pub value: f64,
}
/// Contains the summary of recorded values for a given metric.
pub struct MetricSummary {
/// The metric name.
pub name: String,
/// The metric entries.
pub entries: Vec<MetricEntry>,
}
impl MetricSummary {
fn new<E: EventStore>(
event_store: &mut E,
metric: &str,
split: Split,
num_epochs: usize,
) -> Option<Self> {
let entries = (1..=num_epochs)
.filter_map(|epoch| {
event_store
.find_metric(metric, epoch, Aggregate::Mean, split)
.map(|value| MetricEntry { step: epoch, value })
})
.collect::<Vec<_>>();
if entries.is_empty() {
None
} else {
Some(Self {
name: metric.to_string(),
entries,
})
}
}
}
/// Contains the summary of recorded metrics for the training and validation steps.
pub struct SummaryMetrics {
/// Training metrics summary.
pub train: Vec<MetricSummary>,
/// Validation metrics summary.
pub valid: Vec<MetricSummary>,
}
/// Detailed training summary.
pub struct LearnerSummary {
/// The number of epochs completed.
pub epochs: usize,
/// The summary of recorded metrics during training.
pub metrics: SummaryMetrics,
/// The model name (only recorded within the learner).
pub(crate) model: Option<String>,
}
impl LearnerSummary {
/// Creates a new learner summary for the specified metrics.
///
/// # Arguments
///
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
/// * `metrics` - The list of metrics to collect for the summary.
pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
let directory_path = Path::new(directory);
if !directory_path.exists() {
return Err(format!("Artifact directory does not exist at: {directory}"));
}
let train_dir = directory_path.join("train");
let valid_dir = directory_path.join("valid");
if !train_dir.exists() & !valid_dir.exists() {
return Err(format!(
"No training or validation artifacts found at: {directory}"
));
}
let mut event_store = LogEventStore::default();
let train_logger = FileMetricLogger::new(train_dir.to_str().unwrap());
let valid_logger = FileMetricLogger::new(valid_dir.to_str().unwrap());
// Number of recorded epochs
let epochs = train_logger.epochs();
event_store.register_logger_train(train_logger);
event_store.register_logger_valid(valid_logger);
let train_summary = metrics
.iter()
.filter_map(|metric| {
MetricSummary::new(&mut event_store, metric.as_ref(), Split::Train, epochs)
})
.collect::<Vec<_>>();
let valid_summary = metrics
.iter()
.filter_map(|metric| {
MetricSummary::new(&mut event_store, metric.as_ref(), Split::Valid, epochs)
})
.collect::<Vec<_>>();
Ok(Self {
epochs,
metrics: SummaryMetrics {
train: train_summary,
valid: valid_summary,
},
model: None,
})
}
pub(crate) fn with_model(mut self, name: String) -> Self {
self.model = Some(name);
self
}
}
impl Display for LearnerSummary {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Compute the max length for each column
let split_train = "Train";
let split_valid = "Valid";
let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
let mut max_metric_len = "Metric".len();
for metric in self.metrics.train.iter() {
max_metric_len = max_metric_len.max(metric.name.len());
}
for metric in self.metrics.valid.iter() {
max_metric_len = max_metric_len.max(metric.name.len());
}
// Summary header
writeln!(
f,
"{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
"",
"",
width_symbol = 24,
)?;
if let Some(model) = &self.model {
writeln!(f, "Model: {model}")?;
}
writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;
// Metrics table header
writeln!(
f,
"| {:<width_split$} | {:<width_metric$} | Min. | Epoch | Max. | Epoch |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
"Split", "Metric", "", "",
width_split = max_split_len,
width_metric = max_metric_len,
)?;
// Table entries
fn cmp_f64(a: &f64, b: &f64) -> Ordering {
match (a.is_nan(), b.is_nan()) {
(true, true) => Ordering::Equal,
(true, false) => Ordering::Greater,
(false, true) => Ordering::Less,
_ => a.partial_cmp(b).unwrap(),
}
}
let mut write_metrics_summary = |metrics: &[MetricSummary],
split: &str|
-> std::fmt::Result {
for metric in metrics.iter() {
if metric.entries.is_empty() {
continue; // skip metrics with no recorded values
}
// Compute the min & max for each metric
let metric_min = metric
.entries
.iter()
.min_by(|a, b| cmp_f64(&a.value, &b.value))
.unwrap();
let metric_max = metric
.entries
.iter()
.max_by(|a, b| cmp_f64(&a.value, &b.value))
.unwrap();
writeln!(
f,
"| {:<width_split$} | {:<width_metric$} | {:<9.3?}| {:<9?}| {:<9.3?}| {:<9.3?}|",
split,
metric.name,
metric_min.value,
metric_min.step,
metric_max.value,
metric_max.step,
width_split = max_split_len,
width_metric = max_metric_len,
)?;
}
Ok(())
};
write_metrics_summary(&self.metrics.train, split_train)?;
write_metrics_summary(&self.metrics.valid, split_valid)?;
Ok(())
}
}
pub(crate) struct LearnerSummaryConfig {
pub(crate) directory: String,
pub(crate) metrics: Vec<String>,
}
impl LearnerSummaryConfig {
pub fn init(&self) -> Result<LearnerSummary, String> {
LearnerSummary::new(&self.directory, &self.metrics[..])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[should_panic = "Summary artifacts should exist"]
fn test_artifact_dir_should_exist() {
let dir = "/tmp/learner-summary-not-found";
let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
}
#[test]
#[should_panic = "Summary artifacts should exist"]
fn test_train_valid_artifacts_should_exist() {
let dir = "/tmp/test-learner-summary-empty";
std::fs::create_dir_all(dir).ok();
let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
}
#[test]
fn test_summary_should_be_empty() {
let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
std::fs::create_dir_all(dir).unwrap();
std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
.expect("Summary artifacts should exist");
assert_eq!(summary.epochs, 1);
assert_eq!(summary.metrics.train.len(), 0);
assert_eq!(summary.metrics.valid.len(), 0);
std::fs::remove_dir_all(dir).unwrap();
}
#[test]
fn test_summary_should_be_collected() {
let dir = Path::new("/tmp/test-learner-summary");
let train_dir = dir.join("train/epoch-1");
let valid_dir = dir.join("valid/epoch-1");
std::fs::create_dir_all(dir).unwrap();
std::fs::create_dir_all(&train_dir).unwrap();
std::fs::create_dir_all(&valid_dir).unwrap();
std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");
let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
.expect("Summary artifacts should exist");
assert_eq!(summary.epochs, 1);
// Only Loss metric
assert_eq!(summary.metrics.train.len(), 1);
assert_eq!(summary.metrics.valid.len(), 1);
// Aggregated train metric entries for 1 epoch
let train_metric = &summary.metrics.train[0];
assert_eq!(train_metric.name, "Loss");
assert_eq!(train_metric.entries.len(), 1);
let entry = &train_metric.entries[0];
assert_eq!(entry.step, 1); // epoch = 1
assert_eq!(entry.value, 1.5); // (1 + 2) / 2
// Aggregated valid metric entries for 1 epoch
let valid_metric = &summary.metrics.valid[0];
assert_eq!(valid_metric.name, "Loss");
assert_eq!(valid_metric.entries.len(), 1);
let entry = &valid_metric.entries[0];
assert_eq!(entry.step, 1); // epoch = 1
assert_eq!(entry.value, 1.0);
std::fs::remove_dir_all(dir).unwrap();
}
}

View File

@ -199,6 +199,19 @@ impl<LC: LearnerComponents> Learner<LC> {
}
}
// Display learner summary
if let Some(summary) = self.summary {
match summary.init() {
Ok(summary) => {
// Drop event processor (includes renderer) so the summary is displayed
// when switching back to "main" screen
core::mem::drop(self.event_processor);
println!("{}", summary.with_model(self.model.to_string()))
}
Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
}
}
self.model
}
}

View File

@ -1,6 +1,8 @@
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::MetricEntry;
use std::collections::HashMap;
use crate::metric::{MetricEntry, NumericEntry};
use std::{collections::HashMap, fs};
const EPOCH_PREFIX: &str = "epoch-";
/// Metric logger.
pub trait MetricLogger: Send {
@ -19,7 +21,7 @@ pub trait MetricLogger: Send {
fn end_epoch(&mut self, epoch: usize);
/// Read the logs for an epoch.
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
}
/// The file metric logger.
@ -47,14 +49,44 @@ impl FileMetricLogger {
}
}
/// Number of epochs recorded.
pub(crate) fn epochs(&self) -> usize {
let mut max_epoch = 0;
for path in fs::read_dir(&self.directory).unwrap() {
let path = path.unwrap();
if fs::metadata(path.path()).unwrap().is_dir() {
let dir_name = path.file_name().into_string().unwrap();
if !dir_name.starts_with(EPOCH_PREFIX) {
continue;
}
let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
if let Some(epoch) = epoch {
if epoch > max_epoch {
max_epoch = epoch;
}
}
}
}
max_epoch
}
fn epoch_directory(&self, epoch: usize) -> String {
format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
}
fn file_path(&self, name: &str, epoch: usize) -> String {
let directory = format!("{}/epoch-{}", self.directory, epoch);
let directory = self.epoch_directory(epoch);
let name = name.replace(' ', "_");
format!("{directory}/{name}.log")
}
fn create_directory(&self, epoch: usize) {
let directory = format!("{}/epoch-{}", self.directory, epoch);
let directory = self.epoch_directory(epoch);
std::fs::create_dir_all(directory).ok();
}
}
@ -88,7 +120,7 @@ impl MetricLogger for FileMetricLogger {
self.epoch = epoch + 1;
}
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
if let Some(value) = self.loggers.get(name) {
value.sync()
}
@ -104,7 +136,7 @@ impl MetricLogger for FileMetricLogger {
if value.is_empty() {
None
} else {
match value.parse::<f64>() {
match NumericEntry::deserialize(value) {
Ok(value) => Some(value),
Err(err) => {
log::error!("{err}");
@ -117,7 +149,7 @@ impl MetricLogger for FileMetricLogger {
.collect();
if errors {
Err("Parsing float errors".to_string())
Err("Parsing numeric entry errors".to_string())
} else {
Ok(data)
}
@ -154,7 +186,7 @@ impl MetricLogger for InMemoryMetricLogger {
}
}
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
let values = match self.values.get(name) {
Some(values) => values,
None => return Ok(Vec::new()),
@ -164,7 +196,7 @@ impl MetricLogger for InMemoryMetricLogger {
Some(logger) => Ok(logger
.values
.iter()
.filter_map(|value| value.parse::<f64>().ok())
.filter_map(|value| NumericEntry::deserialize(value).ok())
.collect()),
None => Ok(Vec::new()),
}

View File

@ -84,6 +84,49 @@ pub struct MetricEntry {
pub serialize: String,
}
/// Numeric metric entry.
pub enum NumericEntry {
/// Single numeric value.
Value(f64),
/// Aggregated numeric (value, number of elements).
Aggregated(f64, usize),
}
impl NumericEntry {
pub(crate) fn serialize(&self) -> String {
match self {
Self::Value(v) => v.to_string(),
Self::Aggregated(v, n) => format!("{v},{n}"),
}
}
pub(crate) fn deserialize(entry: &str) -> Result<Self, String> {
// Check for comma separated values
let values = entry.split(',').collect::<Vec<_>>();
let num_values = values.len();
if num_values == 1 {
// Numeric value
match values[0].parse::<f64>() {
Ok(value) => Ok(NumericEntry::Value(value)),
Err(err) => Err(err.to_string()),
}
} else if num_values == 2 {
// Aggregated numeric (value, number of elements)
let (value, numel) = (values[0], values[1]);
match value.parse::<f64>() {
Ok(value) => match numel.parse::<usize>() {
Ok(numel) => Ok(NumericEntry::Aggregated(value, numel)),
Err(err) => Err(err.to_string()),
},
Err(err) => Err(err.to_string()),
}
} else {
Err("Invalid number of values for numeric entry".to_string())
}
}
}
/// Format a float with the given precision. Will use scientific notation if necessary.
pub fn format_float(float: f64, precision: usize) -> String {
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

View File

@ -33,10 +33,14 @@ impl<B: Backend> Metric for LossMetric<B> {
type Input = LossInput<B>;
fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
let [batch_size] = loss.tensor.dims();
let loss = f64::from_elem(loss.tensor.clone().mean().into_data().value[0]);
self.state
.update(loss, 1, FormatOptions::new(Self::NAME).precision(2))
self.state.update(
loss,
batch_size,
FormatOptions::new(Self::NAME).precision(2),
)
}
fn clear(&mut self) {

View File

@ -24,7 +24,7 @@ impl<T, V> Default for Metrics<T, V> {
impl<T, V> Metrics<T, V> {
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
where
T: Adaptor<Me::Input> + 'static,
{

View File

@ -1,4 +1,4 @@
use crate::metric::{format_float, MetricEntry, Numeric};
use crate::metric::{format_float, MetricEntry, Numeric, NumericEntry};
/// Useful utility to implement numeric metrics.
///
@ -67,7 +67,8 @@ impl NumericMetricState {
let value_current = value;
let value_running = self.sum / self.count as f64;
let serialized = value_current.to_string();
// Numeric metric state is an aggregated value
let serialized = NumericEntry::Aggregated(value_current, batch_size).serialize();
let (formatted_current, formatted_running) = match format.precision {
Some(precision) => (

View File

@ -1,4 +1,4 @@
use crate::logger::MetricLogger;
use crate::{logger::MetricLogger, metric::NumericEntry};
use std::collections::HashMap;
use super::{Aggregate, Direction};
@ -48,8 +48,18 @@ impl NumericMetricsAggregate {
return None;
}
let num_points = points.len();
let sum = points.into_iter().sum::<f64>();
// Accurately compute the aggregated value based on the *actual* number of points
// since not all mini-batches are guaranteed to have the specified batch size
let (sum, num_points) = points
.into_iter()
.map(|entry| match entry {
NumericEntry::Value(v) => (v, 1),
// Right now the mean is the only aggregate available, so we can assume that the sum
// of an entry corresponds to (value * number of elements)
NumericEntry::Aggregated(v, n) => (v * n as f64, n),
})
.reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))
.unwrap();
let value = match aggregate {
Aggregate::Mean => sum / num_points as f64,
};
@ -105,7 +115,10 @@ impl NumericMetricsAggregate {
#[cfg(test)]
mod tests {
use crate::{logger::FileMetricLogger, metric::MetricEntry};
use crate::{
logger::{FileMetricLogger, InMemoryMetricLogger},
metric::MetricEntry,
};
use super::*;
@ -159,4 +172,34 @@ mod tests {
assert_eq!(value, 2);
}
#[test]
fn should_aggregate_numeric_entry() {
let mut logger = InMemoryMetricLogger::default();
let mut aggregate = NumericMetricsAggregate::default();
let metric_name = "Loss";
// Epoch 1
let loss_1 = 0.5;
let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2
let entry = MetricEntry::new(
metric_name.to_string(),
loss_1.to_string(),
NumericEntry::Value(loss_1).serialize(),
);
logger.log(&entry);
let entry = MetricEntry::new(
metric_name.to_string(),
loss_2.to_string(),
NumericEntry::Aggregated(loss_2, 2).serialize(),
);
logger.log(&entry);
let value = aggregate
.aggregate(metric_name, 1, Aggregate::Mean, &mut [Box::new(logger)])
.unwrap();
// Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875
assert_eq!(value, 1.0);
}
}

View File

@ -65,8 +65,15 @@ pub struct TrainingConfig {
pub learning_rate: f64,
}
fn create_artifact_dir(artifact_dir: &str) {
// Remove existing artifacts before to get an accurate learner summary
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}
pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
std::fs::create_dir_all(ARTIFACT_DIR).ok();
create_artifact_dir(ARTIFACT_DIR);
config
.save(format!("{ARTIFACT_DIR}/config.json"))
.expect("Config should be saved successfully");
@ -98,6 +105,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(
Cnn::new(NUM_CLASSES.into(), &device),
config.optimizer.init(),

View File

@ -60,8 +60,14 @@ pub struct TrainingConfig {
pub learning_rate: f64,
}
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
fn create_artifact_dir(artifact_dir: &str) {
// Remove existing artifacts before to get an accurate learner summary
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
create_artifact_dir(artifact_dir);
config
.save(format!("{artifact_dir}/config.json"))
.expect("Config should be saved successfully");
@ -91,6 +97,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(
config.model.init::<B>(&device),
config.optimizer.init(),

View File

@ -34,7 +34,14 @@ pub struct MnistTrainingConfig {
pub optimizer: AdamConfig,
}
fn create_artifact_dir(artifact_dir: &str) {
// Remove existing artifacts before to get an accurate learner summary
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}
pub fn run<B: AutodiffBackend>(device: B::Device) {
create_artifact_dir(ARTIFACT_DIR);
// Config
let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5)));
let config = MnistTrainingConfig::new(config_optimizer);
@ -76,6 +83,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
))
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(Model::new(&device), config.optimizer.init(), 1e-4);
let model_trained = learner.fit(dataloader_train, dataloader_test);

View File

@ -81,6 +81,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
))
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(model, config.optimizer.init(), 5e-3);
let model_trained = learner.fit(dataloader_train, dataloader_test);

View File

@ -102,6 +102,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
.with_file_checkpointer(CompactRecorder::new())
.devices(devices)
.num_epochs(config.num_epochs)
.summary()
.build(model, optim, lr_scheduler);
// Train the model

View File

@ -80,6 +80,7 @@ pub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(
.devices(vec![device])
.grads_accumulation(accum)
.num_epochs(config.num_epochs)
.summary()
.build(model, optim, lr_scheduler);
let model_trained = learner.fit(dataloader_train, dataloader_test);