Feat/early stopping + burn train refactor (#878)

This commit is contained in:
Nathaniel Simard 2023-10-20 11:47:31 -04:00 committed by GitHub
parent 3eb7f380f3
commit af813d09ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 1125 additions and 733 deletions

View File

@ -1,6 +1,7 @@
use crate::EventCollector;
use std::ops::DerefMut;
use crate::metric::store::EventStoreClient;
/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
#[derive(Clone, PartialEq, Debug)]
pub enum CheckpointingAction {
@ -11,15 +12,23 @@ pub enum CheckpointingAction {
}
/// Define when checkpoint should be saved and deleted.
pub trait CheckpointingStrategy<E: EventCollector> {
pub trait CheckpointingStrategy {
/// Based on the epoch, determine if the checkpoint should be saved.
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction>;
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction>;
}
// We make dyn box implement the checkpointing strategy so that it can be used with generic, but
// still be dynamic.
impl<E: EventCollector> CheckpointingStrategy<E> for Box<dyn CheckpointingStrategy<E>> {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction> {
self.deref_mut().checkpointing(epoch, collector)
}
}

View File

@ -1,45 +1,40 @@
use crate::metric::store::EventStoreClient;
use super::{CheckpointingAction, CheckpointingStrategy};
use crate::EventCollector;
use std::collections::HashSet;
/// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an
/// epoch to be deleted.
pub struct ComposedCheckpointingStrategy<E: EventCollector> {
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
pub struct ComposedCheckpointingStrategy {
strategies: Vec<Box<dyn CheckpointingStrategy>>,
deleted: Vec<HashSet<usize>>,
}
/// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
pub struct ComposedCheckpointingStrategyBuilder<E: EventCollector> {
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
#[derive(Default)]
pub struct ComposedCheckpointingStrategyBuilder {
strategies: Vec<Box<dyn CheckpointingStrategy>>,
}
impl<E: EventCollector> Default for ComposedCheckpointingStrategyBuilder<E> {
fn default() -> Self {
Self {
strategies: Vec::new(),
}
}
}
impl<E: EventCollector> ComposedCheckpointingStrategyBuilder<E> {
impl ComposedCheckpointingStrategyBuilder {
/// Add a new [checkpointing strategy](CheckpointingStrategy).
#[allow(clippy::should_implement_trait)]
pub fn add<S>(mut self, strategy: S) -> Self
where
S: CheckpointingStrategy<E> + 'static,
S: CheckpointingStrategy + 'static,
{
self.strategies.push(Box::new(strategy));
self
}
/// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).
pub fn build(self) -> ComposedCheckpointingStrategy<E> {
pub fn build(self) -> ComposedCheckpointingStrategy {
ComposedCheckpointingStrategy::new(self.strategies)
}
}
impl<E: EventCollector> ComposedCheckpointingStrategy<E> {
fn new(strategies: Vec<Box<dyn CheckpointingStrategy<E>>>) -> Self {
impl ComposedCheckpointingStrategy {
fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
Self {
deleted: strategies.iter().map(|_| HashSet::new()).collect(),
strategies,
@ -47,13 +42,17 @@ impl<E: EventCollector> ComposedCheckpointingStrategy<E> {
}
/// Create a new builder which help compose multiple
/// [checkpointing strategies](CheckpointingStrategy).
pub fn builder() -> ComposedCheckpointingStrategyBuilder<E> {
pub fn builder() -> ComposedCheckpointingStrategyBuilder {
ComposedCheckpointingStrategyBuilder::default()
}
}
impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrategy<E> {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for ComposedCheckpointingStrategy {
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let mut saved = false;
let mut actions = Vec::new();
let mut epochs_to_check = Vec::new();
@ -104,15 +103,12 @@ impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrate
#[cfg(test)]
mod tests {
use crate::{
checkpoint::KeepLastNCheckpoints, info::MetricsInfo, test_utils::TestEventCollector,
};
use super::*;
use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};
#[test]
fn should_delete_when_both_deletes() {
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
let store = EventStoreClient::new(LogEventStore::default());
let mut strategy = ComposedCheckpointingStrategy::builder()
.add(KeepLastNCheckpoints::new(1))
.add(KeepLastNCheckpoints::new(2))
@ -120,17 +116,17 @@ mod tests {
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &mut collector)
strategy.checkpointing(1, &store)
);
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &mut collector)
strategy.checkpointing(2, &store)
);
assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &mut collector)
strategy.checkpointing(3, &store)
);
}
}

View File

@ -1,5 +1,5 @@
use super::CheckpointingStrategy;
use crate::{checkpoint::CheckpointingAction, EventCollector};
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
/// Keep the last N checkpoints.
///
@ -10,8 +10,12 @@ pub struct KeepLastNCheckpoints {
num_keep: usize,
}
impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {
fn checkpointing(&mut self, epoch: usize, _collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for KeepLastNCheckpoints {
fn checkpointing(
&mut self,
epoch: usize,
_store: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let mut actions = vec![CheckpointingAction::Save];
if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) {
@ -26,28 +30,27 @@ impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {
#[cfg(test)]
mod tests {
use crate::{info::MetricsInfo, test_utils::TestEventCollector};
use super::*;
use crate::metric::store::LogEventStore;
#[test]
fn should_always_delete_lastn_epoch_if_higher_than_one() {
let mut strategy = KeepLastNCheckpoints::new(2);
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
let store = EventStoreClient::new(LogEventStore::default());
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &mut collector)
strategy.checkpointing(1, &store)
);
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &mut collector)
strategy.checkpointing(2, &store)
);
assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &mut collector)
strategy.checkpointing(3, &store)
);
}
}

View File

@ -1,6 +1,10 @@
use super::CheckpointingStrategy;
use crate::{
checkpoint::CheckpointingAction, metric::Metric, Aggregate, Direction, EventCollector, Split,
checkpoint::CheckpointingAction,
metric::{
store::{Aggregate, Direction, EventStoreClient, Split},
Metric,
},
};
/// Keep the best checkpoint based on a metric.
@ -28,10 +32,14 @@ impl MetricCheckpointingStrategy {
}
}
impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for MetricCheckpointingStrategy {
fn checkpointing(
&mut self,
epoch: usize,
store: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let best_epoch =
match collector.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
Some(epoch_best) => epoch_best,
None => epoch,
};
@ -56,93 +64,70 @@ impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy
#[cfg(test)]
mod tests {
use burn_core::tensor::{backend::Backend, ElementConversion, Tensor};
use crate::{
logger::InMemoryMetricLogger,
metric::{
processor::{
test_utils::{end_epoch, process_train},
Metrics, MinimalEventProcessor,
},
store::LogEventStore,
LossMetric,
},
TestBackend,
};
use std::sync::Arc;
use super::*;
use crate::{
info::MetricsInfo,
logger::InMemoryMetricLogger,
metric::{Adaptor, LossInput, LossMetric},
test_utils::TestEventCollector,
Event, LearnerItem, TestBackend,
};
#[test]
fn always_keep_the_best_epoch() {
let mut store = LogEventStore::default();
let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
Aggregate::Mean,
Direction::Lowest,
Split::Train,
);
let mut info = MetricsInfo::new();
let mut metrics = Metrics::<f64, f64>::default();
// Register an in memory logger.
info.register_logger_train(InMemoryMetricLogger::default());
store.register_logger_train(InMemoryMetricLogger::default());
// Register the loss metric.
info.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let mut collector = TestEventCollector::<f64, f64>::new(info);
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
// Two points for the first epoch. Mean 0.75
let mut epoch = 1;
item(&mut collector, 1.0, epoch);
item(&mut collector, 0.5, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 0.5, epoch);
end_epoch(&mut processor, epoch);
// Should save the current record.
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(epoch, &mut collector)
strategy.checkpointing(epoch, &store)
);
// Two points for the second epoch. Mean 0.4
epoch += 1;
item(&mut collector, 0.5, epoch);
item(&mut collector, 0.3, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 0.5, epoch);
process_train(&mut processor, 0.3, epoch);
end_epoch(&mut processor, epoch);
// Should save the current record and delete the pervious one.
assert_eq!(
vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
strategy.checkpointing(epoch, &mut collector)
strategy.checkpointing(epoch, &store)
);
// Two points for the last epoch. Mean 2.0
epoch += 1;
item(&mut collector, 1.0, epoch);
item(&mut collector, 3.0, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 3.0, epoch);
end_epoch(&mut processor, epoch);
// Should not delete the previous record, since it's the best one, and should not save a
// new one.
assert!(strategy.checkpointing(epoch, &mut collector).is_empty());
}
fn item(collector: &mut TestEventCollector<f64, f64>, value: f64, epoch: usize) {
let dummy_progress = burn_core::data::dataloader::Progress {
items_processed: 1,
items_total: 10,
};
let num_epochs = 3;
let dummy_iteration = 1;
collector.on_event_train(Event::ProcessedItem(LearnerItem::new(
value,
dummy_progress,
epoch,
num_epochs,
dummy_iteration,
None,
)));
}
fn end_epoch(collector: &mut TestEventCollector<f64, f64>, epoch: usize) {
collector.on_event_train(Event::EndEpoch(epoch));
collector.on_event_valid(Event::EndEpoch(epoch));
}
impl<B: Backend> Adaptor<LossInput<B>> for f64 {
fn adapt(&self) -> LossInput<B> {
LossInput::new(Tensor::from_data([self.elem()]))
}
assert!(strategy.checkpointing(epoch, &store).is_empty());
}
}

View File

@ -1,118 +0,0 @@
use super::EventCollector;
use crate::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};
enum Message<T, V> {
OnEventTrain(Event<T>),
OnEventValid(Event<V>),
End,
FindEpoch(
String,
Aggregate,
Direction,
Split,
mpsc::SyncSender<Option<usize>>,
),
}
/// Async [event collector](EventCollector).
///
/// This will create a worker thread where all the computation is done ensuring that the training loop is
/// never blocked by metric calculation.
pub struct AsyncEventCollector<T, V> {
sender: mpsc::Sender<Message<T, V>>,
handler: Option<JoinHandle<()>>,
}
#[derive(new)]
struct WorkerThread<C, T, V> {
collector: C,
receiver: mpsc::Receiver<Message<T, V>>,
}
impl<C, T, V> WorkerThread<C, T, V>
where
C: EventCollector<ItemTrain = T, ItemValid = V>,
{
fn run(mut self) {
for item in self.receiver.iter() {
match item {
Message::End => {
return;
}
Message::FindEpoch(name, aggregate, direction, split, sender) => {
let response = self
.collector
.find_epoch(&name, aggregate, direction, split);
sender.send(response).unwrap();
}
Message::OnEventTrain(event) => self.collector.on_event_train(event),
Message::OnEventValid(event) => self.collector.on_event_valid(event),
}
}
}
}
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
/// Create a new async [event collector](EventCollector).
pub fn new<C>(collector: C) -> Self
where
C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
{
let (sender, receiver) = mpsc::channel();
let thread = WorkerThread::new(collector, receiver);
let handler = std::thread::spawn(move || thread.run());
let handler = Some(handler);
Self { sender, handler }
}
}
impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
self.sender.send(Message::OnEventTrain(event)).unwrap();
}
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
self.sender.send(Message::OnEventValid(event)).unwrap();
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::FindEpoch(
name.to_string(),
aggregate,
direction,
split,
sender,
))
.unwrap();
match receiver.recv() {
Ok(value) => value,
Err(err) => panic!("Async server crashed: {:?}", err),
}
}
}
impl<T, V> Drop for AsyncEventCollector<T, V> {
fn drop(&mut self) {
self.sender.send(Message::End).unwrap();
let handler = self.handler.take();
if let Some(handler) = handler {
handler.join().unwrap();
}
}
}

View File

@ -1,134 +0,0 @@
use burn_core::{data::dataloader::Progress, LearningRate};
/// Event happening during the training/validation process.
pub enum Event<T> {
/// Signal that an item have been processed.
ProcessedItem(LearnerItem<T>),
/// Signal the end of an epoch.
EndEpoch(usize),
}
/// Defines how training and validation events are collected.
///
/// This trait also exposes methods that uses the collected data to compute useful information.
pub trait EventCollector: Send {
/// Training item.
type ItemTrain;
/// Validation item.
type ItemValid;
/// Collect the training event.
fn on_event_train(&mut self, event: Event<Self::ItemTrain>);
/// Collect the validaion event.
fn on_event_valid(&mut self, event: Event<Self::ItemValid>);
/// Find the epoch following the given criteria from the collected data.
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize>;
}
#[derive(Copy, Clone)]
/// How to aggregate the metric.
pub enum Aggregate {
/// Compute the average.
Mean,
}
#[derive(Copy, Clone)]
/// The split to use.
pub enum Split {
/// The training split.
Train,
/// The validation split.
Valid,
}
#[derive(Copy, Clone)]
/// The direction of the query.
pub enum Direction {
/// Lower is better.
Lowest,
/// Higher is better.
Highest,
}
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The learning rate.
pub lr: Option<LearningRate>,
}
#[cfg(test)]
pub mod test_utils {
use crate::{info::MetricsInfo, Aggregate, Direction, Event, EventCollector, Split};
#[derive(new)]
pub struct TestEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
info: MetricsInfo<T, V>,
}
impl<T, V> EventCollector for TestEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
type ItemTrain = T;
type ItemValid = V;
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
match event {
Event::ProcessedItem(item) => {
let metadata = (&item).into();
self.info.update_train(&item, &metadata);
}
Event::EndEpoch(epoch) => self.info.end_epoch_train(epoch),
}
}
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
match event {
Event::ProcessedItem(item) => {
let metadata = (&item).into();
self.info.update_valid(&item, &metadata);
}
Event::EndEpoch(epoch) => self.info.end_epoch_valid(epoch),
}
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
self.info.find_epoch(name, aggregate, direction, split)
}
}
}

View File

@ -1,131 +0,0 @@
use crate::{
info::MetricsInfo,
metric::MetricMetadata,
renderer::{MetricState, MetricsRenderer, TrainingProgress},
Aggregate, Direction, Event, EventCollector, LearnerItem, Split,
};
/// Collect training events in order to display metrics with a metrics renderer.
#[derive(new)]
pub(crate) struct RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
renderer: Box<dyn MetricsRenderer>,
info: MetricsInfo<T, V>,
}
impl<T, V> EventCollector for RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
type ItemTrain = T;
type ItemValid = V;
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
match event {
Event::ProcessedItem(item) => self.on_train_item(item),
Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch),
}
}
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
match event {
Event::ProcessedItem(item) => self.on_valid_item(item),
Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch),
}
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
self.info.find_epoch(name, aggregate, direction, split)
}
}
impl<T, V> RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
fn on_train_item(&mut self, item: LearnerItem<T>) {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.info.update_train(&item, &metadata);
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_train(MetricState::Numeric(entry, value))
});
self.renderer.render_train(progress);
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.info.update_valid(&item, &metadata);
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_valid(MetricState::Numeric(entry, value))
});
self.renderer.render_train(progress);
}
fn on_train_end_epoch(&mut self, epoch: usize) {
self.info.end_epoch_train(epoch);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
self.info.end_epoch_valid(epoch);
}
}
impl<T> From<&LearnerItem<T>> for TrainingProgress {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
}
}
}
impl<T> From<&LearnerItem<T>> for MetricMetadata {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
lr: item.lr,
}
}
}

View File

@ -1,3 +0,0 @@
mod base;
pub(crate) use base::*;

View File

@ -1,8 +0,0 @@
mod async_collector;
mod base;
pub use async_collector::*;
pub use base::*;
/// Metrics collector module.
pub mod metrics;

View File

@ -1,6 +1,6 @@
use crate::{
checkpoint::{Checkpointer, CheckpointingStrategy},
EventCollector,
metric::processor::EventProcessor,
};
use burn_core::{
lr_scheduler::LrScheduler,
@ -28,14 +28,13 @@ pub trait LearnerComponents {
>;
/// The checkpointer used for the scheduler.
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
/// Training event collector used for training tracking.
type EventCollector: EventCollector + 'static;
type EventProcessor: EventProcessor + 'static;
/// The strategy to save and delete checkpoints.
type CheckpointerStrategy: CheckpointingStrategy<Self::EventCollector>;
type CheckpointerStrategy: CheckpointingStrategy;
}
/// Concrete type that implements [training components trait](TrainingComponents).
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S> {
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S> {
_backend: PhantomData<B>,
_lr_scheduler: PhantomData<LR>,
_model: PhantomData<M>,
@ -43,12 +42,12 @@ pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S> {
_checkpointer_model: PhantomData<CM>,
_checkpointer_optim: PhantomData<CO>,
_checkpointer_scheduler: PhantomData<CS>,
_collector: PhantomData<EC>,
_event_processor: PhantomData<EP>,
_strategy: S,
}
impl<B, LR, M, O, CM, CO, CS, EC, S> LearnerComponents
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S>
impl<B, LR, M, O, CM, CO, CS, EP, S> LearnerComponents
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S>
where
B: ADBackend,
LR: LrScheduler,
@ -57,8 +56,8 @@ where
CM: Checkpointer<M::Record>,
CO: Checkpointer<O::Record>,
CS: Checkpointer<LR::Record>,
EC: EventCollector + 'static,
S: CheckpointingStrategy<EC>,
EP: EventProcessor + 'static,
S: CheckpointingStrategy,
{
type Backend = B;
type LrScheduler = LR;
@ -67,6 +66,6 @@ where
type CheckpointerModel = CM;
type CheckpointerOptimizer = CO;
type CheckpointerLrScheduler = CS;
type EventCollector = EC;
type EventProcessor = EP;
type CheckpointerStrategy = S;
}

View File

@ -1,5 +0,0 @@
mod aggregates;
mod metrics;
pub(crate) use aggregates::*;
pub use metrics::*;

View File

@ -1,5 +1,7 @@
use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
use crate::components::LearnerComponents;
use crate::learner::EarlyStoppingStrategy;
use crate::metric::store::EventStoreClient;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
@ -19,8 +21,10 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) grad_accumulation: Option<usize>,
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
pub(crate) collector: LC::EventCollector,
pub(crate) interrupter: TrainingInterrupter,
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Arc<EventStoreClient>,
}
#[derive(new)]
@ -38,9 +42,9 @@ impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
optim: &LC::Optimizer,
scheduler: &LC::LrScheduler,
epoch: usize,
collector: &mut LC::EventCollector,
store: &EventStoreClient,
) {
let actions = self.strategy.checkpointing(epoch, collector);
let actions = self.strategy.checkpointing(epoch, store);
for action in actions {
match action {

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{
@ -5,13 +7,14 @@ use crate::checkpoint::{
KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::LearnerComponentsMarker;
use crate::info::MetricsInfo;
use crate::learner::base::TrainingInterrupter;
use crate::learner::EarlyStoppingStrategy;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{collector::metrics::RenderedMetricsEventCollector, Aggregate, Direction, Split};
use crate::{AsyncEventCollector, LearnerCheckpointer};
use crate::LearnerCheckpointer;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
@ -43,11 +46,13 @@ where
grad_accumulation: Option<usize>,
devices: Vec<B::Device>,
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
info: MetricsInfo<T, V>,
metrics: Metrics<T, V>,
event_store: LogEventStore,
interrupter: TrainingInterrupter,
log_to_file: bool,
num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy<AsyncEventCollector<T, V>>>,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
}
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@ -72,7 +77,8 @@ where
directory: directory.to_string(),
grad_accumulation: None,
devices: vec![B::Device::default()],
info: MetricsInfo::new(),
metrics: Metrics::default(),
event_store: LogEventStore::default(),
renderer: None,
interrupter: TrainingInterrupter::new(),
log_to_file: true,
@ -87,6 +93,7 @@ where
))
.build(),
),
early_stopping: None,
}
}
@ -101,8 +108,8 @@ where
MT: MetricLogger + 'static,
MV: MetricLogger + 'static,
{
self.info.register_logger_train(logger_train);
self.info.register_logger_valid(logger_valid);
self.event_store.register_logger_train(logger_train);
self.event_store.register_logger_valid(logger_valid);
self.num_loggers += 1;
self
}
@ -110,7 +117,7 @@ where
/// Update the checkpointing_strategy.
pub fn with_checkpointing_strategy<CS>(&mut self, strategy: CS)
where
CS: CheckpointingStrategy<AsyncEventCollector<T, V>> + 'static,
CS: CheckpointingStrategy + 'static,
{
self.checkpointer_strategy = Box::new(strategy);
}
@ -133,7 +140,7 @@ where
where
T: Adaptor<Me::Input>,
{
self.info.register_metric_train(metric);
self.metrics.register_metric_train(metric);
self
}
@ -142,7 +149,7 @@ where
where
V: Adaptor<Me::Input>,
{
self.info.register_valid_metric(metric);
self.metrics.register_valid_metric(metric);
self
}
@ -167,7 +174,7 @@ where
Me: Metric + crate::metric::Numeric + 'static,
T: Adaptor<Me::Input>,
{
self.info.register_train_metric_numeric(metric);
self.metrics.register_train_metric_numeric(metric);
self
}
@ -179,7 +186,7 @@ where
where
V: Adaptor<Me::Input>,
{
self.info.register_valid_metric_numeric(metric);
self.metrics.register_valid_metric_numeric(metric);
self
}
@ -206,6 +213,16 @@ where
self.interrupter.clone()
}
/// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
/// conditions are meet.
pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
where
Strategy: EarlyStoppingStrategy + 'static,
{
self.early_stopping = Some(Box::new(strategy));
self
}
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
@ -267,8 +284,8 @@ where
AsyncCheckpointer<M::Record>,
AsyncCheckpointer<O::Record>,
AsyncCheckpointer<S::Record>,
AsyncEventCollector<T, V>,
Box<dyn CheckpointingStrategy<AsyncEventCollector<T, V>>>,
FullEventProcessor<T, V>,
Box<dyn CheckpointingStrategy>,
>,
>
where
@ -285,16 +302,18 @@ where
let directory = &self.directory;
if self.num_loggers == 0 {
self.info.register_logger_train(FileMetricLogger::new(
format!("{directory}/train").as_str(),
));
self.info.register_logger_valid(FileMetricLogger::new(
format!("{directory}/valid").as_str(),
));
self.event_store
.register_logger_train(FileMetricLogger::new(
format!("{directory}/train").as_str(),
));
self.event_store
.register_logger_valid(FileMetricLogger::new(
format!("{directory}/valid").as_str(),
));
}
let collector =
AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info));
let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());
let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
@ -306,11 +325,13 @@ where
lr_scheduler,
checkpointer,
num_epochs: self.num_epochs,
collector,
event_processor,
event_store,
checkpoint: self.checkpoint,
grad_accumulation: self.grad_accumulation,
devices: self.devices,
interrupter: self.interrupter,
early_stopping: self.early_stopping,
}
}

View File

@ -0,0 +1,209 @@
use crate::metric::{
store::{Aggregate, Direction, EventStoreClient, Split},
Metric,
};
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
pub enum StoppingCondition {
/// When no improvement has happened since the given number of epochs.
NoImprovementSince {
/// The number of epochs allowed to worsen before it gets better.
n_epochs: usize,
},
}
/// A strategy that checks if the training should be stopped.
pub trait EarlyStoppingStrategy {
/// Update its current state and returns if the training should be stopped.
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
/// during training or validation.
pub struct MetricEarlyStoppingStrategy {
condition: StoppingCondition,
metric_name: String,
aggregate: Aggregate,
direction: Direction,
split: Split,
best_epoch: usize,
best_value: f64,
}
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
let current_value =
match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
Some(value) => value,
None => {
log::warn!("Can't find metric for early stopping.");
return false;
}
};
let is_best = match self.direction {
Direction::Lowest => current_value < self.best_value,
Direction::Highest => current_value > self.best_value,
};
if is_best {
log::info!(
"New best epoch found {} {}: {}",
epoch,
self.metric_name,
current_value
);
self.best_value = current_value;
self.best_epoch = epoch;
return false;
}
match self.condition {
StoppingCondition::NoImprovementSince { n_epochs } => {
let should_stop = epoch - self.best_epoch >= n_epochs;
if should_stop {
log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value);
}
should_stop
}
}
}
}
impl MetricEarlyStoppingStrategy {
/// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
/// during training or validation.
///
/// # Notes
///
/// The metric should be registered for early stopping to work, otherwise no data is collected.
pub fn new<Me: Metric>(
aggregate: Aggregate,
direction: Direction,
split: Split,
condition: StoppingCondition,
) -> Self {
let init_value = match direction {
Direction::Lowest => f64::MAX,
Direction::Highest => f64::MIN,
};
Self {
metric_name: Me::NAME.to_string(),
condition,
aggregate,
direction,
split,
best_epoch: 1,
best_value: init_value,
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::{
logger::InMemoryMetricLogger,
metric::{
processor::{
test_utils::{end_epoch, process_train},
Metrics, MinimalEventProcessor,
},
store::LogEventStore,
LossMetric,
},
TestBackend,
};
use super::*;
#[test]
fn never_early_stop_while_it_is_improving() {
test_early_stopping(
1,
&[
(&[0.5, 0.3], false, "Should not stop first epoch"),
(&[0.4, 0.3], false, "Should not stop when improving"),
(&[0.3, 0.3], false, "Should not stop when improving"),
(&[0.2, 0.3], false, "Should not stop when improving"),
],
);
}
#[test]
fn early_stop_when_no_improvement_since_two_epochs() {
test_early_stopping(
2,
&[
(&[1.0, 0.5], false, "Should not stop first epoch"),
(&[0.5, 0.3], false, "Should not stop when improving"),
(
&[1.0, 3.0],
false,
"Should not stop first time it gets worse",
),
(
&[1.0, 2.0],
true,
"Should stop since two following epochs didn't improve",
),
],
);
}
#[test]
fn early_stop_when_stays_equal() {
test_early_stopping(
2,
&[
(&[0.5, 0.3], false, "Should not stop first epoch"),
(
&[0.5, 0.3],
false,
"Should not stop first time it stars the same",
),
(
&[0.5, 0.3],
true,
"Should stop since two following epochs didn't improve",
),
],
);
}
fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
let mut early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<TestBackend>>(
Aggregate::Mean,
Direction::Lowest,
Split::Train,
StoppingCondition::NoImprovementSince { n_epochs },
);
let mut store = LogEventStore::default();
let mut metrics = Metrics::<f64, f64>::default();
store.register_logger_train(InMemoryMetricLogger::default());
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
let mut epoch = 1;
for (points, should_start, comment) in data {
for point in points.iter() {
process_train(&mut processor, *point, epoch);
}
end_epoch(&mut processor, epoch);
assert_eq!(
*should_start,
early_stopping.should_stop(epoch, &store),
"{comment}"
);
epoch += 1;
}
}
}

View File

@ -4,8 +4,9 @@ use burn_core::{
};
use std::sync::Arc;
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter, Event};
use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
use crate::metric::processor::{Event, EventProcessor, LearnerItem};
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};
use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};
/// A validation epoch.
#[derive(new)]
@ -30,14 +31,14 @@ impl<VI> ValidEpoch<VI> {
/// # Arguments
///
/// * `model` - The model to validate.
/// * `callback` - The callback to use.
/// * `processor` - The event processor to use.
pub fn run<LC: LearnerComponents, VO>(
&self,
model: &LC::Model,
callback: &mut LC::EventCollector,
processor: &mut LC::EventProcessor,
interrupter: &TrainingInterrupter,
) where
LC::EventCollector: EventCollector<ItemValid = VO>,
LC::EventProcessor: EventProcessor<ItemValid = VO>,
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
{
log::info!("Executing validation step for epoch {}", self.epoch);
@ -60,14 +61,14 @@ impl<VI> ValidEpoch<VI> {
None,
);
callback.on_event_valid(Event::ProcessedItem(item));
processor.process_valid(Event::ProcessedItem(item));
if interrupter.should_stop() {
log::info!("Training interrupted.");
break;
}
}
callback.on_event_valid(Event::EndEpoch(self.epoch));
processor.process_valid(Event::EndEpoch(self.epoch));
}
}
@ -79,7 +80,7 @@ impl<TI> TrainEpoch<TI> {
/// * `model` - The model to train.
/// * `optim` - The optimizer to use.
/// * `scheduler` - The learning rate scheduler to use.
/// * `callback` - The callback to use.
/// * `processor` - The event processor to use.
///
/// # Returns
///
@ -89,11 +90,11 @@ impl<TI> TrainEpoch<TI> {
mut model: LC::Model,
mut optim: LC::Optimizer,
scheduler: &mut LC::LrScheduler,
callback: &mut LC::EventCollector,
processor: &mut LC::EventProcessor,
interrupter: &TrainingInterrupter,
) -> (LC::Model, LC::Optimizer)
where
LC::EventCollector: EventCollector<ItemTrain = TO>,
LC::EventProcessor: EventProcessor<ItemTrain = TO>,
LC::Model: TrainStep<TI, TO>,
{
log::info!("Executing training step for epoch {}", self.epoch,);
@ -134,13 +135,14 @@ impl<TI> TrainEpoch<TI> {
Some(lr),
);
callback.on_event_train(Event::ProcessedItem(item));
processor.process_train(Event::ProcessedItem(item));
if interrupter.should_stop() {
log::info!("Training interrupted.");
break;
}
}
callback.on_event_train(Event::EndEpoch(self.epoch));
processor.process_train(Event::EndEpoch(self.epoch));
(model, optim)
}
@ -154,7 +156,7 @@ impl<TI> TrainEpoch<TI> {
/// * `model` - The model to train.
/// * `optim` - The optimizer to use.
/// * `lr_scheduler` - The learning rate scheduler to use.
/// * `callback` - The callback to use.
/// * `processor` - The event processor to use.
/// * `devices` - The devices to use.
///
/// # Returns
@ -165,12 +167,12 @@ impl<TI> TrainEpoch<TI> {
mut model: LC::Model,
mut optim: LC::Optimizer,
lr_scheduler: &mut LC::LrScheduler,
callback: &mut LC::EventCollector,
processor: &mut LC::EventProcessor,
devices: Vec<<LC::Backend as Backend>::Device>,
interrupter: &TrainingInterrupter,
) -> (LC::Model, LC::Optimizer)
where
LC::EventCollector: EventCollector<ItemTrain = TO>,
LC::EventProcessor: EventProcessor<ItemTrain = TO>,
LC::Model: TrainStep<TI, TO>,
TO: Send + 'static,
TI: Send + 'static,
@ -224,7 +226,7 @@ impl<TI> TrainEpoch<TI> {
Some(lr),
);
callback.on_event_train(Event::ProcessedItem(item));
processor.process_train(Event::ProcessedItem(item));
if interrupter.should_stop() {
log::info!("Training interrupted.");
@ -238,7 +240,7 @@ impl<TI> TrainEpoch<TI> {
}
}
callback.on_event_train(Event::EndEpoch(self.epoch));
processor.process_train(Event::EndEpoch(self.epoch));
(model, optim)
}

View File

@ -1,6 +1,7 @@
mod base;
mod builder;
mod classification;
mod early_stopping;
mod epoch;
mod regression;
mod step;
@ -11,6 +12,7 @@ pub(crate) mod log;
pub use base::*;
pub use builder::*;
pub use classification::*;
pub use early_stopping::*;
pub use epoch::*;
pub use regression::*;
pub use step::*;

View File

@ -1,5 +1,6 @@
use crate::components::LearnerComponents;
use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch};
use crate::metric::processor::EventProcessor;
use crate::{Learner, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::{ADModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
@ -115,7 +116,7 @@ impl<LC: LearnerComponents> Learner<LC> {
OutputValid: Send,
LC::Model: TrainStep<InputTrain, OutputTrain>,
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
LC::EventCollector: EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
{
log::info!("Fitting {}", self.model.to_string());
// The reference model is always on the first device provided.
@ -151,7 +152,7 @@ impl<LC: LearnerComponents> Learner<LC> {
self.model,
self.optim,
&mut self.lr_scheduler,
&mut self.collector,
&mut self.event_processor,
self.devices.clone(),
&self.interrupter,
)
@ -160,7 +161,7 @@ impl<LC: LearnerComponents> Learner<LC> {
self.model,
self.optim,
&mut self.lr_scheduler,
&mut self.collector,
&mut self.event_processor,
&self.interrupter,
);
}
@ -170,7 +171,11 @@ impl<LC: LearnerComponents> Learner<LC> {
}
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
epoch_valid.run::<LC, OutputValid>(&self.model, &mut self.collector, &self.interrupter);
epoch_valid.run::<LC, OutputValid>(
&self.model,
&mut self.event_processor,
&self.interrupter,
);
if let Some(checkpointer) = &mut self.checkpointer {
checkpointer.checkpoint(
@ -178,9 +183,15 @@ impl<LC: LearnerComponents> Learner<LC> {
&self.optim,
&self.lr_scheduler,
epoch,
&mut self.collector,
&self.event_store,
);
}
if let Some(early_stopping) = &mut self.early_stopping {
if early_stopping.should_stop(epoch, &self.event_store) {
break;
}
}
}
self.model

View File

@ -19,13 +19,8 @@ pub mod logger;
/// The metric module.
pub mod metric;
/// All information collected during training.
pub mod info;
mod collector;
mod learner;
pub use collector::*;
pub use learner::*;
#[cfg(test)]

View File

@ -0,0 +1,16 @@
use super::Logger;
/// In memory logger.
#[derive(Default)]
pub struct InMemoryLogger {
pub(crate) values: Vec<String>,
}
impl<T> Logger<T> for InMemoryLogger
where
T: std::fmt::Display,
{
fn log(&mut self, item: T) {
self.values.push(item.to_string());
}
}

View File

@ -1,4 +1,4 @@
use super::{AsyncLogger, FileLogger, Logger};
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::MetricEntry;
use std::collections::HashMap;
@ -16,7 +16,7 @@ pub trait MetricLogger: Send {
/// # Arguments
///
/// * `epoch` - The epoch.
fn epoch(&mut self, epoch: usize);
fn end_epoch(&mut self, epoch: usize);
/// Read the logs for an epoch.
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
@ -81,9 +81,9 @@ impl MetricLogger for FileMetricLogger {
logger.log(value.clone());
}
fn epoch(&mut self, epoch: usize) {
fn end_epoch(&mut self, epoch: usize) {
self.loggers.clear();
self.epoch = epoch;
self.epoch = epoch + 1;
}
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
@ -125,23 +125,24 @@ impl MetricLogger for FileMetricLogger {
/// In memory metric logger, useful when testing and debugging.
#[derive(Default)]
pub struct InMemoryMetricLogger {
values: HashMap<String, Vec<Vec<String>>>,
values: HashMap<String, Vec<InMemoryLogger>>,
}
impl MetricLogger for InMemoryMetricLogger {
fn log(&mut self, item: &MetricEntry) {
if !self.values.contains_key(&item.name) {
self.values.insert(item.name.clone(), vec![vec![]]);
self.values
.insert(item.name.clone(), vec![InMemoryLogger::default()]);
}
let values = self.values.get_mut(&item.name).unwrap();
values.last_mut().unwrap().push(item.serialize.clone());
values.last_mut().unwrap().log(item.serialize.clone());
}
fn epoch(&mut self, _epoch: usize) {
fn end_epoch(&mut self, _epoch: usize) {
for (_, values) in self.values.iter_mut() {
values.push(Vec::new());
values.push(InMemoryLogger::default());
}
}
@ -152,7 +153,8 @@ impl MetricLogger for InMemoryMetricLogger {
};
match values.get(epoch - 1) {
Some(values) => Ok(values
Some(logger) => Ok(logger
.values
.iter()
.filter_map(|value| value.parse::<f64>().ok())
.collect()),

View File

@ -1,9 +1,11 @@
mod async_logger;
mod base;
mod file;
mod in_memory;
mod metric;
pub use async_logger::*;
pub use base::*;
pub use file::*;
pub use in_memory::*;
pub use metric::*;

View File

@ -74,7 +74,7 @@ pub trait Numeric {
}
/// Data type that contains the current state of a metric at a given time.
#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
pub struct MetricEntry {
/// The name of the metric.
pub name: String,

View File

@ -26,3 +26,7 @@ pub use learning_rate::*;
pub use loss::*;
#[cfg(feature = "metrics")]
pub use memory_use::*;
pub(crate) mod processor;
/// Module responsible to save and exposes data collected during training.
pub mod store;

View File

@ -0,0 +1,45 @@
use burn_core::data::dataloader::Progress;
use burn_core::LearningRate;
/// Event happening during the training/validation process.
pub enum Event<T> {
/// Signal that an item have been processed.
ProcessedItem(LearnerItem<T>),
/// Signal the end of an epoch.
EndEpoch(usize),
}
/// Process events happening during training and validation.
pub trait EventProcessor {
/// The training item.
type ItemTrain;
/// The validation item.
type ItemValid;
/// Collect a training event.
fn process_train(&mut self, event: Event<Self::ItemTrain>);
/// Collect a validation event.
fn process_valid(&mut self, event: Event<Self::ItemValid>);
}
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The learning rate.
pub lr: Option<LearningRate>,
}

View File

@ -0,0 +1,100 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use crate::renderer::{MetricState, MetricsRenderer};
use std::sync::Arc;
/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Render metrics using a [metrics renderer](MetricsRenderer).
pub struct FullEventProcessor<T, V> {
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
}
impl<T, V> FullEventProcessor<T, V> {
pub(crate) fn new(
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
) -> Self {
Self {
metrics,
renderer,
store,
}
}
}
impl<T, V> EventProcessor for FullEventProcessor<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn process_train(&mut self, event: Event<Self::ItemTrain>) {
match event {
Event::ProcessedItem(item) => {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_train(&item, &metadata);
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_train(MetricState::Numeric(entry, value))
});
self.renderer.render_train(progress);
}
Event::EndEpoch(epoch) => {
self.metrics.end_epoch_train();
self.store
.add_event_train(crate::metric::store::Event::EndEpoch(epoch));
}
}
}
fn process_valid(&mut self, event: Event<Self::ItemValid>) {
match event {
Event::ProcessedItem(item) => {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_valid(&item, &metadata);
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_valid(MetricState::Numeric(entry, value))
});
self.renderer.render_valid(progress);
}
Event::EndEpoch(epoch) => {
self.metrics.end_epoch_valid();
self.store
.add_event_valid(crate::metric::store::Event::EndEpoch(epoch));
}
}
}
}

View File

@ -1,74 +1,66 @@
use super::NumericMetricsAggregate;
use super::LearnerItem;
use crate::{
logger::MetricLogger,
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
Aggregate, Direction, LearnerItem, Split,
metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
renderer::TrainingProgress,
};
/// Metrics information collected during training.
pub struct MetricsInfo<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub(crate) struct Metrics<T, V> {
train: Vec<Box<dyn MetricUpdater<T>>>,
valid: Vec<Box<dyn MetricUpdater<V>>>,
train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
loggers_train: Vec<Box<dyn MetricLogger>>,
loggers_valid: Vec<Box<dyn MetricLogger>>,
aggregate_train: NumericMetricsAggregate,
aggregate_valid: NumericMetricsAggregate,
}
#[derive(new)]
pub(crate) struct MetricsUpdate {
pub(crate) entries: Vec<MetricEntry>,
pub(crate) entries_numeric: Vec<(MetricEntry, f64)>,
}
impl<T, V> MetricsInfo<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub(crate) fn new() -> Self {
impl<T, V> Default for Metrics<T, V> {
fn default() -> Self {
Self {
train: vec![],
valid: vec![],
train_numeric: vec![],
valid_numeric: vec![],
loggers_train: vec![],
loggers_valid: vec![],
aggregate_train: NumericMetricsAggregate::default(),
aggregate_valid: NumericMetricsAggregate::default(),
train: Vec::default(),
valid: Vec::default(),
train_numeric: Vec::default(),
valid_numeric: Vec::default(),
}
}
}
/// Signal the end of a training epoch.
pub(crate) fn end_epoch_train(&mut self, epoch: usize) {
for metric in self.train.iter_mut() {
metric.clear();
}
for metric in self.train_numeric.iter_mut() {
metric.clear();
}
for logger in self.loggers_train.iter_mut() {
logger.epoch(epoch + 1);
}
impl<T, V> Metrics<T, V> {
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
where
T: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.train.push(Box::new(metric))
}
/// Signal the end of a validation epoch.
pub(crate) fn end_epoch_valid(&mut self, epoch: usize) {
for metric in self.valid.iter_mut() {
metric.clear();
}
for metric in self.valid_numeric.iter_mut() {
metric.clear();
}
for logger in self.loggers_valid.iter_mut() {
logger.epoch(epoch + 1);
}
/// Register a validation metric.
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
where
V: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.valid.push(Box::new(metric))
}
/// Register a numeric training metric.
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
T: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.train_numeric.push(Box::new(metric))
}
/// Register a numeric validation metric.
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
V: Adaptor<Me::Input> + 'static,
{
let metric = MetricWrapper::new(metric);
self.valid_numeric.push(Box::new(metric))
}
/// Update the training information from the training item.
@ -82,20 +74,11 @@ where
for metric in self.train.iter_mut() {
let state = metric.update(item, metadata);
for logger in self.loggers_train.iter_mut() {
logger.log(&state);
}
entries.push(state);
}
for metric in self.train_numeric.iter_mut() {
let (state, value) = metric.update(item, metadata);
for logger in self.loggers_train.iter_mut() {
logger.log(&state);
}
entries_numeric.push((state, value));
}
@ -113,94 +96,58 @@ where
for metric in self.valid.iter_mut() {
let state = metric.update(item, metadata);
for logger in self.loggers_valid.iter_mut() {
logger.log(&state);
}
entries.push(state);
}
for metric in self.valid_numeric.iter_mut() {
let (state, value) = metric.update(item, metadata);
for logger in self.loggers_valid.iter_mut() {
logger.log(&state);
}
entries_numeric.push((state, value));
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Find the epoch corresponding to the given criteria.
pub(crate) fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
match split {
Split::Train => {
self.aggregate_train
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
}
Split::Valid => {
self.aggregate_valid
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
}
/// Signal the end of a training epoch.
pub(crate) fn end_epoch_train(&mut self) {
for metric in self.train.iter_mut() {
metric.clear();
}
for metric in self.train_numeric.iter_mut() {
metric.clear();
}
}
/// Register a logger for training metrics.
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_train.push(Box::new(logger));
/// Signal the end of a validation epoch.
pub(crate) fn end_epoch_valid(&mut self) {
for metric in self.valid.iter_mut() {
metric.clear();
}
for metric in self.valid_numeric.iter_mut() {
metric.clear();
}
}
}
/// Register a logger for validation metrics.
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_valid.push(Box::new(logger));
impl<T> From<&LearnerItem<T>> for TrainingProgress {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
}
}
}
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
where
T: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.train.push(Box::new(metric))
}
/// Register a validation metric.
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
where
V: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.valid.push(Box::new(metric))
}
/// Register a numeric training metric.
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
T: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.train_numeric.push(Box::new(metric))
}
/// Register a numeric validation metric.
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
V: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.valid_numeric.push(Box::new(metric))
impl<T> From<&LearnerItem<T>> for MetricMetadata {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
lr: item.lr,
}
}
}

View File

@ -0,0 +1,52 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use std::sync::Arc;
/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
#[derive(new)]
pub(crate) struct MinimalEventProcessor<T, V> {
metrics: Metrics<T, V>,
store: Arc<EventStoreClient>,
}
impl<T, V> EventProcessor for MinimalEventProcessor<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn process_train(&mut self, event: Event<Self::ItemTrain>) {
match event {
Event::ProcessedItem(item) => {
let metadata = (&item).into();
let update = self.metrics.update_train(&item, &metadata);
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update));
}
Event::EndEpoch(epoch) => {
self.metrics.end_epoch_train();
self.store
.add_event_train(crate::metric::store::Event::EndEpoch(epoch));
}
}
}
fn process_valid(&mut self, event: Event<Self::ItemValid>) {
match event {
Event::ProcessedItem(item) => {
let metadata = (&item).into();
let update = self.metrics.update_valid(&item, &metadata);
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update));
}
Event::EndEpoch(epoch) => {
self.metrics.end_epoch_valid();
self.store
.add_event_valid(crate::metric::store::Event::EndEpoch(epoch));
}
}
}
}

View File

@ -0,0 +1,53 @@
mod base;
mod full;
mod metrics;
mod minimal;
pub use base::*;
pub(crate) use full::*;
pub(crate) use metrics::*;
#[cfg(test)]
pub(crate) use minimal::*;
#[cfg(test)]
pub(crate) mod test_utils {
use crate::metric::{
processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor},
Adaptor, LossInput,
};
use burn_core::tensor::{backend::Backend, ElementConversion, Tensor};
impl<B: Backend> Adaptor<LossInput<B>> for f64 {
fn adapt(&self) -> LossInput<B> {
LossInput::new(Tensor::from_data([self.elem()]))
}
}
pub(crate) fn process_train(
processor: &mut MinimalEventProcessor<f64, f64>,
value: f64,
epoch: usize,
) {
let dummy_progress = burn_core::data::dataloader::Progress {
items_processed: 1,
items_total: 10,
};
let num_epochs = 3;
let dummy_iteration = 1;
processor.process_train(Event::ProcessedItem(LearnerItem::new(
value,
dummy_progress,
epoch,
num_epochs,
dummy_iteration,
None,
)));
}
pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor<f64, f64>, epoch: usize) {
processor.process_train(Event::EndEpoch(epoch));
processor.process_valid(Event::EndEpoch(epoch));
}
}

View File

@ -1,28 +1,32 @@
use crate::{logger::MetricLogger, Aggregate, Direction};
use crate::logger::MetricLogger;
use std::collections::HashMap;
use super::{Aggregate, Direction};
/// Type that can be used to fetch and use numeric metric aggregates.
#[derive(Default, Debug)]
pub(crate) struct NumericMetricsAggregate {
mean_for_each_epoch: HashMap<Key, f64>,
value_for_each_epoch: HashMap<Key, f64>,
}
#[derive(new, Hash, PartialEq, Eq, Debug)]
struct Key {
name: String,
epoch: usize,
aggregate: Aggregate,
}
impl NumericMetricsAggregate {
pub(crate) fn mean(
pub(crate) fn aggregate(
&mut self,
name: &str,
epoch: usize,
aggregate: Aggregate,
loggers: &mut [Box<dyn MetricLogger>],
) -> Option<f64> {
let key = Key::new(name.to_string(), epoch);
let key = Key::new(name.to_string(), epoch, aggregate);
if let Some(value) = self.mean_for_each_epoch.get(&key) {
if let Some(value) = self.value_for_each_epoch.get(&key) {
return Some(*value);
}
@ -45,10 +49,13 @@ impl NumericMetricsAggregate {
}
let num_points = points.len();
let mean = points.into_iter().sum::<f64>() / num_points as f64;
let sum = points.into_iter().sum::<f64>();
let value = match aggregate {
Aggregate::Mean => sum / num_points as f64,
};
self.mean_for_each_epoch.insert(key, mean);
Some(mean)
self.value_for_each_epoch.insert(key, value);
Some(value)
}
pub(crate) fn find_epoch(
@ -61,16 +68,8 @@ impl NumericMetricsAggregate {
let mut data = Vec::new();
let mut current_epoch = 1;
loop {
match aggregate {
Aggregate::Mean => match self.mean(name, current_epoch, loggers) {
Some(value) => {
data.push(value);
}
None => break,
},
};
while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) {
data.push(value);
current_epoch += 1;
}
@ -131,8 +130,8 @@ mod tests {
));
}
fn new_epoch(&mut self) {
self.logger.end_epoch(self.epoch);
self.epoch += 1;
self.logger.epoch(self.epoch);
}
}

View File

@ -0,0 +1,69 @@
use crate::metric::MetricEntry;
/// Event happening during the training/validation process.
pub enum Event {
/// Signal that metrics have been updated.
MetricsUpdate(MetricsUpdate),
/// Signal the end of an epoch.
EndEpoch(usize),
}
/// Contains all metric information.
#[derive(new, Clone)]
pub struct MetricsUpdate {
/// Metrics information related to non-numeric metrics.
pub entries: Vec<MetricEntry>,
/// Metrics information related to numeric metrics.
pub entries_numeric: Vec<(MetricEntry, f64)>,
}
/// Defines how training and validation events are collected and searched.
///
/// This trait also exposes methods that uses the collected data to compute useful information.
pub trait EventStore: Send {
/// Collect a training/validation event.
fn add_event(&mut self, event: Event, split: Split);
/// Find the epoch following the given criteria from the collected data.
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize>;
/// Find the metric value for the current epoch following the given criteria.
fn find_metric(
&mut self,
name: &str,
epoch: usize,
aggregate: Aggregate,
split: Split,
) -> Option<f64>;
}
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
/// How to aggregate the metric.
pub enum Aggregate {
/// Compute the average.
Mean,
}
#[derive(Copy, Clone)]
/// The split to use.
pub enum Split {
/// The training split.
Train,
/// The validation split.
Valid,
}
#[derive(Copy, Clone)]
/// The direction of the query.
pub enum Direction {
/// Lower is better.
Lowest,
/// Higher is better.
Highest,
}

View File

@ -0,0 +1,149 @@
use super::EventStore;
use super::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};
/// Type that allows to communicate with an [event store](EventStore).
pub struct EventStoreClient {
sender: mpsc::Sender<Message>,
handler: Option<JoinHandle<()>>,
}
impl EventStoreClient {
/// Create a new [event store](EventStore) client.
pub(crate) fn new<C>(store: C) -> Self
where
C: EventStore + 'static,
{
let (sender, receiver) = mpsc::channel();
let thread = WorkerThread::new(store, receiver);
let handler = std::thread::spawn(move || thread.run());
let handler = Some(handler);
Self { sender, handler }
}
}
impl EventStoreClient {
/// Add a training event to the [event store](EventStore).
pub(crate) fn add_event_train(&self, event: Event) {
self.sender.send(Message::OnEventTrain(event)).unwrap();
}
/// Add a validation event to the [event store](EventStore).
pub(crate) fn add_event_valid(&self, event: Event) {
self.sender.send(Message::OnEventValid(event)).unwrap();
}
/// Find the epoch following the given criteria from the collected data.
pub fn find_epoch(
&self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::FindEpoch(
name.to_string(),
aggregate,
direction,
split,
sender,
))
.unwrap();
match receiver.recv() {
Ok(value) => value,
Err(err) => panic!("Event store thread crashed: {:?}", err),
}
}
/// Find the metric value for the current epoch following the given criteria.
pub fn find_metric(
&self,
name: &str,
epoch: usize,
aggregate: Aggregate,
split: Split,
) -> Option<f64> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::FindMetric(
name.to_string(),
epoch,
aggregate,
split,
sender,
))
.unwrap();
match receiver.recv() {
Ok(value) => value,
Err(err) => panic!("Event store thread crashed: {:?}", err),
}
}
}
#[derive(new)]
struct WorkerThread<S> {
store: S,
receiver: mpsc::Receiver<Message>,
}
impl<C> WorkerThread<C>
where
C: EventStore,
{
fn run(mut self) {
for item in self.receiver.iter() {
match item {
Message::End => {
return;
}
Message::FindEpoch(name, aggregate, direction, split, sender) => {
let response = self.store.find_epoch(&name, aggregate, direction, split);
sender.send(response).unwrap();
}
Message::FindMetric(name, epoch, aggregate, split, sender) => {
let response = self.store.find_metric(&name, epoch, aggregate, split);
sender.send(response).unwrap();
}
Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
}
}
}
}
enum Message {
OnEventTrain(Event),
OnEventValid(Event),
End,
FindEpoch(
String,
Aggregate,
Direction,
Split,
mpsc::SyncSender<Option<usize>>,
),
FindMetric(
String,
usize,
Aggregate,
Split,
mpsc::SyncSender<Option<f64>>,
),
}
impl Drop for EventStoreClient {
fn drop(&mut self) {
self.sender.send(Message::End).unwrap();
let handler = self.handler.take();
if let Some(handler) = handler {
handler.join().unwrap();
}
}
}

View File

@ -0,0 +1,101 @@
use super::{aggregate::NumericMetricsAggregate, Aggregate, Direction, Event, EventStore, Split};
use crate::logger::MetricLogger;
#[derive(Default)]
pub(crate) struct LogEventStore {
loggers_train: Vec<Box<dyn MetricLogger>>,
loggers_valid: Vec<Box<dyn MetricLogger>>,
aggregate_train: NumericMetricsAggregate,
aggregate_valid: NumericMetricsAggregate,
}
impl EventStore for LogEventStore {
fn add_event(&mut self, event: Event, split: Split) {
match event {
Event::MetricsUpdate(update) => match split {
Split::Train => {
update
.entries
.iter()
.chain(update.entries_numeric.iter().map(|(entry, _value)| entry))
.for_each(|entry| {
self.loggers_train
.iter_mut()
.for_each(|logger| logger.log(entry));
});
}
Split::Valid => {
update
.entries
.iter()
.chain(update.entries_numeric.iter().map(|(entry, _value)| entry))
.for_each(|entry| {
self.loggers_valid
.iter_mut()
.for_each(|logger| logger.log(entry));
});
}
},
Event::EndEpoch(epoch) => match split {
Split::Train => self
.loggers_train
.iter_mut()
.for_each(|logger| logger.end_epoch(epoch)),
Split::Valid => self
.loggers_valid
.iter_mut()
.for_each(|logger| logger.end_epoch(epoch + 1)),
},
}
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
match split {
Split::Train => {
self.aggregate_train
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
}
Split::Valid => {
self.aggregate_valid
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
}
}
}
fn find_metric(
&mut self,
name: &str,
epoch: usize,
aggregate: Aggregate,
split: Split,
) -> Option<f64> {
match split {
Split::Train => {
self.aggregate_train
.aggregate(name, epoch, aggregate, &mut self.loggers_train)
}
Split::Valid => {
self.aggregate_valid
.aggregate(name, epoch, aggregate, &mut self.loggers_valid)
}
}
}
}
impl LogEventStore {
/// Register a logger for training metrics.
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_train.push(Box::new(logger));
}
/// Register a logger for validation metrics.
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_valid.push(Box::new(logger));
}
}

View File

@ -0,0 +1,9 @@
pub(crate) mod aggregate;
mod base;
mod client;
mod log;
pub(crate) use self::log::*;
pub use base::*;
pub use client::*;

View File

@ -1,4 +1,4 @@
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
use crate::metric::renderer::{MetricState, MetricsRenderer, TrainingProgress};
/// A simple renderer for when the cli feature is not enabled.
pub struct CliMetricsRenderer;

View File

@ -1,5 +1,4 @@
use crate::data::MNISTBatch;
use burn::{
module::Module,
nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d},

View File

@ -5,7 +5,9 @@ use burn::module::Module;
use burn::optim::decay::WeightDecayConfig;
use burn::optim::AdamConfig;
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
use burn::train::metric::store::{Aggregate, Direction, Split};
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
@ -69,6 +71,12 @@ pub fn run<B: ADBackend>(device: B::Device) {
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
Aggregate::Mean,
Direction::Lowest,
Split::Valid,
StoppingCondition::NoImprovementSince { n_epochs: 1 },
))
.devices(vec![device])
.num_epochs(config.num_epochs)
.build(Model::new(), config.optimizer.init(), 1e-4);