mirror of https://github.com/tracel-ai/burn.git
Feat/early stopping + burn train refactor (#878)
This commit is contained in:
parent
3eb7f380f3
commit
af813d09ed
|
@ -1,6 +1,7 @@
|
|||
use crate::EventCollector;
|
||||
use std::ops::DerefMut;
|
||||
|
||||
use crate::metric::store::EventStoreClient;
|
||||
|
||||
/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum CheckpointingAction {
|
||||
|
@ -11,15 +12,23 @@ pub enum CheckpointingAction {
|
|||
}
|
||||
|
||||
/// Define when checkpoint should be saved and deleted.
|
||||
pub trait CheckpointingStrategy<E: EventCollector> {
|
||||
pub trait CheckpointingStrategy {
|
||||
/// Based on the epoch, determine if the checkpoint should be saved.
|
||||
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction>;
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction>;
|
||||
}
|
||||
|
||||
// We make dyn box implement the checkpointing strategy so that it can be used with generic, but
|
||||
// still be dynamic.
|
||||
impl<E: EventCollector> CheckpointingStrategy<E> for Box<dyn CheckpointingStrategy<E>> {
|
||||
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
|
||||
impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
self.deref_mut().checkpointing(epoch, collector)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,45 +1,40 @@
|
|||
use crate::metric::store::EventStoreClient;
|
||||
|
||||
use super::{CheckpointingAction, CheckpointingStrategy};
|
||||
use crate::EventCollector;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an
|
||||
/// epoch to be deleted.
|
||||
pub struct ComposedCheckpointingStrategy<E: EventCollector> {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
|
||||
pub struct ComposedCheckpointingStrategy {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy>>,
|
||||
deleted: Vec<HashSet<usize>>,
|
||||
}
|
||||
|
||||
/// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
|
||||
pub struct ComposedCheckpointingStrategyBuilder<E: EventCollector> {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
|
||||
#[derive(Default)]
|
||||
pub struct ComposedCheckpointingStrategyBuilder {
|
||||
strategies: Vec<Box<dyn CheckpointingStrategy>>,
|
||||
}
|
||||
|
||||
impl<E: EventCollector> Default for ComposedCheckpointingStrategyBuilder<E> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategies: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: EventCollector> ComposedCheckpointingStrategyBuilder<E> {
|
||||
impl ComposedCheckpointingStrategyBuilder {
|
||||
/// Add a new [checkpointing strategy](CheckpointingStrategy).
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn add<S>(mut self, strategy: S) -> Self
|
||||
where
|
||||
S: CheckpointingStrategy<E> + 'static,
|
||||
S: CheckpointingStrategy + 'static,
|
||||
{
|
||||
self.strategies.push(Box::new(strategy));
|
||||
self
|
||||
}
|
||||
|
||||
/// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).
|
||||
pub fn build(self) -> ComposedCheckpointingStrategy<E> {
|
||||
pub fn build(self) -> ComposedCheckpointingStrategy {
|
||||
ComposedCheckpointingStrategy::new(self.strategies)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: EventCollector> ComposedCheckpointingStrategy<E> {
|
||||
fn new(strategies: Vec<Box<dyn CheckpointingStrategy<E>>>) -> Self {
|
||||
impl ComposedCheckpointingStrategy {
|
||||
fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
|
||||
Self {
|
||||
deleted: strategies.iter().map(|_| HashSet::new()).collect(),
|
||||
strategies,
|
||||
|
@ -47,13 +42,17 @@ impl<E: EventCollector> ComposedCheckpointingStrategy<E> {
|
|||
}
|
||||
/// Create a new builder which help compose multiple
|
||||
/// [checkpointing strategies](CheckpointingStrategy).
|
||||
pub fn builder() -> ComposedCheckpointingStrategyBuilder<E> {
|
||||
pub fn builder() -> ComposedCheckpointingStrategyBuilder {
|
||||
ComposedCheckpointingStrategyBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrategy<E> {
|
||||
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
|
||||
impl CheckpointingStrategy for ComposedCheckpointingStrategy {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
collector: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let mut saved = false;
|
||||
let mut actions = Vec::new();
|
||||
let mut epochs_to_check = Vec::new();
|
||||
|
@ -104,15 +103,12 @@ impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrate
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
checkpoint::KeepLastNCheckpoints, info::MetricsInfo, test_utils::TestEventCollector,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};
|
||||
|
||||
#[test]
|
||||
fn should_delete_when_both_deletes() {
|
||||
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
|
||||
let store = EventStoreClient::new(LogEventStore::default());
|
||||
let mut strategy = ComposedCheckpointingStrategy::builder()
|
||||
.add(KeepLastNCheckpoints::new(1))
|
||||
.add(KeepLastNCheckpoints::new(2))
|
||||
|
@ -120,17 +116,17 @@ mod tests {
|
|||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(1, &mut collector)
|
||||
strategy.checkpointing(1, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(2, &mut collector)
|
||||
strategy.checkpointing(2, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
|
||||
strategy.checkpointing(3, &mut collector)
|
||||
strategy.checkpointing(3, &store)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::CheckpointingStrategy;
|
||||
use crate::{checkpoint::CheckpointingAction, EventCollector};
|
||||
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};
|
||||
|
||||
/// Keep the last N checkpoints.
|
||||
///
|
||||
|
@ -10,8 +10,12 @@ pub struct KeepLastNCheckpoints {
|
|||
num_keep: usize,
|
||||
}
|
||||
|
||||
impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {
|
||||
fn checkpointing(&mut self, epoch: usize, _collector: &mut E) -> Vec<CheckpointingAction> {
|
||||
impl CheckpointingStrategy for KeepLastNCheckpoints {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
_store: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let mut actions = vec![CheckpointingAction::Save];
|
||||
|
||||
if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) {
|
||||
|
@ -26,28 +30,27 @@ impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{info::MetricsInfo, test_utils::TestEventCollector};
|
||||
|
||||
use super::*;
|
||||
use crate::metric::store::LogEventStore;
|
||||
|
||||
#[test]
|
||||
fn should_always_delete_lastn_epoch_if_higher_than_one() {
|
||||
let mut strategy = KeepLastNCheckpoints::new(2);
|
||||
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
|
||||
let store = EventStoreClient::new(LogEventStore::default());
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(1, &mut collector)
|
||||
strategy.checkpointing(1, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(2, &mut collector)
|
||||
strategy.checkpointing(2, &store)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
|
||||
strategy.checkpointing(3, &mut collector)
|
||||
strategy.checkpointing(3, &store)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
use super::CheckpointingStrategy;
|
||||
use crate::{
|
||||
checkpoint::CheckpointingAction, metric::Metric, Aggregate, Direction, EventCollector, Split,
|
||||
checkpoint::CheckpointingAction,
|
||||
metric::{
|
||||
store::{Aggregate, Direction, EventStoreClient, Split},
|
||||
Metric,
|
||||
},
|
||||
};
|
||||
|
||||
/// Keep the best checkpoint based on a metric.
|
||||
|
@ -28,10 +32,14 @@ impl MetricCheckpointingStrategy {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy {
|
||||
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
|
||||
impl CheckpointingStrategy for MetricCheckpointingStrategy {
|
||||
fn checkpointing(
|
||||
&mut self,
|
||||
epoch: usize,
|
||||
store: &EventStoreClient,
|
||||
) -> Vec<CheckpointingAction> {
|
||||
let best_epoch =
|
||||
match collector.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
|
||||
match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
|
||||
Some(epoch_best) => epoch_best,
|
||||
None => epoch,
|
||||
};
|
||||
|
@ -56,93 +64,70 @@ impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn_core::tensor::{backend::Backend, ElementConversion, Tensor};
|
||||
use crate::{
|
||||
logger::InMemoryMetricLogger,
|
||||
metric::{
|
||||
processor::{
|
||||
test_utils::{end_epoch, process_train},
|
||||
Metrics, MinimalEventProcessor,
|
||||
},
|
||||
store::LogEventStore,
|
||||
LossMetric,
|
||||
},
|
||||
TestBackend,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
info::MetricsInfo,
|
||||
logger::InMemoryMetricLogger,
|
||||
metric::{Adaptor, LossInput, LossMetric},
|
||||
test_utils::TestEventCollector,
|
||||
Event, LearnerItem, TestBackend,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn always_keep_the_best_epoch() {
|
||||
let mut store = LogEventStore::default();
|
||||
let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Train,
|
||||
);
|
||||
let mut info = MetricsInfo::new();
|
||||
let mut metrics = Metrics::<f64, f64>::default();
|
||||
// Register an in memory logger.
|
||||
info.register_logger_train(InMemoryMetricLogger::default());
|
||||
store.register_logger_train(InMemoryMetricLogger::default());
|
||||
// Register the loss metric.
|
||||
info.register_train_metric_numeric(LossMetric::<TestBackend>::new());
|
||||
|
||||
let mut collector = TestEventCollector::<f64, f64>::new(info);
|
||||
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
|
||||
let store = Arc::new(EventStoreClient::new(store));
|
||||
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
|
||||
|
||||
// Two points for the first epoch. Mean 0.75
|
||||
let mut epoch = 1;
|
||||
item(&mut collector, 1.0, epoch);
|
||||
item(&mut collector, 0.5, epoch);
|
||||
end_epoch(&mut collector, epoch);
|
||||
process_train(&mut processor, 1.0, epoch);
|
||||
process_train(&mut processor, 0.5, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should save the current record.
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Save],
|
||||
strategy.checkpointing(epoch, &mut collector)
|
||||
strategy.checkpointing(epoch, &store)
|
||||
);
|
||||
|
||||
// Two points for the second epoch. Mean 0.4
|
||||
epoch += 1;
|
||||
item(&mut collector, 0.5, epoch);
|
||||
item(&mut collector, 0.3, epoch);
|
||||
end_epoch(&mut collector, epoch);
|
||||
process_train(&mut processor, 0.5, epoch);
|
||||
process_train(&mut processor, 0.3, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should save the current record and delete the pervious one.
|
||||
assert_eq!(
|
||||
vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
|
||||
strategy.checkpointing(epoch, &mut collector)
|
||||
strategy.checkpointing(epoch, &store)
|
||||
);
|
||||
|
||||
// Two points for the last epoch. Mean 2.0
|
||||
epoch += 1;
|
||||
item(&mut collector, 1.0, epoch);
|
||||
item(&mut collector, 3.0, epoch);
|
||||
end_epoch(&mut collector, epoch);
|
||||
process_train(&mut processor, 1.0, epoch);
|
||||
process_train(&mut processor, 3.0, epoch);
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
// Should not delete the previous record, since it's the best one, and should not save a
|
||||
// new one.
|
||||
assert!(strategy.checkpointing(epoch, &mut collector).is_empty());
|
||||
}
|
||||
|
||||
fn item(collector: &mut TestEventCollector<f64, f64>, value: f64, epoch: usize) {
|
||||
let dummy_progress = burn_core::data::dataloader::Progress {
|
||||
items_processed: 1,
|
||||
items_total: 10,
|
||||
};
|
||||
let num_epochs = 3;
|
||||
let dummy_iteration = 1;
|
||||
|
||||
collector.on_event_train(Event::ProcessedItem(LearnerItem::new(
|
||||
value,
|
||||
dummy_progress,
|
||||
epoch,
|
||||
num_epochs,
|
||||
dummy_iteration,
|
||||
None,
|
||||
)));
|
||||
}
|
||||
|
||||
fn end_epoch(collector: &mut TestEventCollector<f64, f64>, epoch: usize) {
|
||||
collector.on_event_train(Event::EndEpoch(epoch));
|
||||
collector.on_event_valid(Event::EndEpoch(epoch));
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for f64 {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(Tensor::from_data([self.elem()]))
|
||||
}
|
||||
assert!(strategy.checkpointing(epoch, &store).is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,118 +0,0 @@
|
|||
use super::EventCollector;
|
||||
use crate::{Aggregate, Direction, Event, Split};
|
||||
use std::{sync::mpsc, thread::JoinHandle};
|
||||
|
||||
enum Message<T, V> {
|
||||
OnEventTrain(Event<T>),
|
||||
OnEventValid(Event<V>),
|
||||
End,
|
||||
FindEpoch(
|
||||
String,
|
||||
Aggregate,
|
||||
Direction,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<usize>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// Async [event collector](EventCollector).
|
||||
///
|
||||
/// This will create a worker thread where all the computation is done ensuring that the training loop is
|
||||
/// never blocked by metric calculation.
|
||||
pub struct AsyncEventCollector<T, V> {
|
||||
sender: mpsc::Sender<Message<T, V>>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct WorkerThread<C, T, V> {
|
||||
collector: C,
|
||||
receiver: mpsc::Receiver<Message<T, V>>,
|
||||
}
|
||||
|
||||
impl<C, T, V> WorkerThread<C, T, V>
|
||||
where
|
||||
C: EventCollector<ItemTrain = T, ItemValid = V>,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::FindEpoch(name, aggregate, direction, split, sender) => {
|
||||
let response = self
|
||||
.collector
|
||||
.find_epoch(&name, aggregate, direction, split);
|
||||
sender.send(response).unwrap();
|
||||
}
|
||||
Message::OnEventTrain(event) => self.collector.on_event_train(event),
|
||||
Message::OnEventValid(event) => self.collector.on_event_valid(event),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
|
||||
/// Create a new async [event collector](EventCollector).
|
||||
pub fn new<C>(collector: C) -> Self
|
||||
where
|
||||
C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = WorkerThread::new(collector, receiver);
|
||||
|
||||
let handler = std::thread::spawn(move || thread.run());
|
||||
let handler = Some(handler);
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
self.sender.send(Message::OnEventTrain(event)).unwrap();
|
||||
}
|
||||
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
self.sender.send(Message::OnEventValid(event)).unwrap();
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindEpoch(
|
||||
name.to_string(),
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
sender,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Async server crashed: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> Drop for AsyncEventCollector<T, V> {
|
||||
fn drop(&mut self) {
|
||||
self.sender.send(Message::End).unwrap();
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,134 +0,0 @@
|
|||
use burn_core::{data::dataloader::Progress, LearningRate};
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum Event<T> {
|
||||
/// Signal that an item have been processed.
|
||||
ProcessedItem(LearnerItem<T>),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
}
|
||||
|
||||
/// Defines how training and validation events are collected.
|
||||
///
|
||||
/// This trait also exposes methods that uses the collected data to compute useful information.
|
||||
pub trait EventCollector: Send {
|
||||
/// Training item.
|
||||
type ItemTrain;
|
||||
/// Validation item.
|
||||
type ItemValid;
|
||||
|
||||
/// Collect the training event.
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>);
|
||||
|
||||
/// Collect the validaion event.
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>);
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize>;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// How to aggregate the metric.
|
||||
pub enum Aggregate {
|
||||
/// Compute the average.
|
||||
Mean,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// The split to use.
|
||||
pub enum Split {
|
||||
/// The training split.
|
||||
Train,
|
||||
/// The validation split.
|
||||
Valid,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// The direction of the query.
|
||||
pub enum Direction {
|
||||
/// Lower is better.
|
||||
Lowest,
|
||||
/// Higher is better.
|
||||
Highest,
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use crate::{info::MetricsInfo, Aggregate, Direction, Event, EventCollector, Split};
|
||||
|
||||
#[derive(new)]
|
||||
pub struct TestEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
info: MetricsInfo<T, V>,
|
||||
}
|
||||
|
||||
impl<T, V> EventCollector for TestEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let metadata = (&item).into();
|
||||
self.info.update_train(&item, &metadata);
|
||||
}
|
||||
Event::EndEpoch(epoch) => self.info.end_epoch_train(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let metadata = (&item).into();
|
||||
self.info.update_valid(&item, &metadata);
|
||||
}
|
||||
Event::EndEpoch(epoch) => self.info.end_epoch_valid(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
self.info.find_epoch(name, aggregate, direction, split)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,131 +0,0 @@
|
|||
use crate::{
|
||||
info::MetricsInfo,
|
||||
metric::MetricMetadata,
|
||||
renderer::{MetricState, MetricsRenderer, TrainingProgress},
|
||||
Aggregate, Direction, Event, EventCollector, LearnerItem, Split,
|
||||
};
|
||||
|
||||
/// Collect training events in order to display metrics with a metrics renderer.
|
||||
#[derive(new)]
|
||||
pub(crate) struct RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
info: MetricsInfo<T, V>,
|
||||
}
|
||||
|
||||
impl<T, V> EventCollector for RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => self.on_train_item(item),
|
||||
Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => self.on_valid_item(item),
|
||||
Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
self.info.find_epoch(name, aggregate, direction, split)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.info.update_train(&item, &metadata);
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_train(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
}
|
||||
|
||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.info.update_valid(&item, &metadata);
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_valid(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
}
|
||||
|
||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||
self.info.end_epoch_train(epoch);
|
||||
}
|
||||
|
||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||
self.info.end_epoch_valid(epoch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for TrainingProgress {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +0,0 @@
|
|||
mod base;
|
||||
|
||||
pub(crate) use base::*;
|
|
@ -1,8 +0,0 @@
|
|||
mod async_collector;
|
||||
mod base;
|
||||
|
||||
pub use async_collector::*;
|
||||
pub use base::*;
|
||||
|
||||
/// Metrics collector module.
|
||||
pub mod metrics;
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
checkpoint::{Checkpointer, CheckpointingStrategy},
|
||||
EventCollector,
|
||||
metric::processor::EventProcessor,
|
||||
};
|
||||
use burn_core::{
|
||||
lr_scheduler::LrScheduler,
|
||||
|
@ -28,14 +28,13 @@ pub trait LearnerComponents {
|
|||
>;
|
||||
/// The checkpointer used for the scheduler.
|
||||
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
|
||||
/// Training event collector used for training tracking.
|
||||
type EventCollector: EventCollector + 'static;
|
||||
type EventProcessor: EventProcessor + 'static;
|
||||
/// The strategy to save and delete checkpoints.
|
||||
type CheckpointerStrategy: CheckpointingStrategy<Self::EventCollector>;
|
||||
type CheckpointerStrategy: CheckpointingStrategy;
|
||||
}
|
||||
|
||||
/// Concrete type that implements [training components trait](TrainingComponents).
|
||||
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S> {
|
||||
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S> {
|
||||
_backend: PhantomData<B>,
|
||||
_lr_scheduler: PhantomData<LR>,
|
||||
_model: PhantomData<M>,
|
||||
|
@ -43,12 +42,12 @@ pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S> {
|
|||
_checkpointer_model: PhantomData<CM>,
|
||||
_checkpointer_optim: PhantomData<CO>,
|
||||
_checkpointer_scheduler: PhantomData<CS>,
|
||||
_collector: PhantomData<EC>,
|
||||
_event_processor: PhantomData<EP>,
|
||||
_strategy: S,
|
||||
}
|
||||
|
||||
impl<B, LR, M, O, CM, CO, CS, EC, S> LearnerComponents
|
||||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC, S>
|
||||
impl<B, LR, M, O, CM, CO, CS, EP, S> LearnerComponents
|
||||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EP, S>
|
||||
where
|
||||
B: ADBackend,
|
||||
LR: LrScheduler,
|
||||
|
@ -57,8 +56,8 @@ where
|
|||
CM: Checkpointer<M::Record>,
|
||||
CO: Checkpointer<O::Record>,
|
||||
CS: Checkpointer<LR::Record>,
|
||||
EC: EventCollector + 'static,
|
||||
S: CheckpointingStrategy<EC>,
|
||||
EP: EventProcessor + 'static,
|
||||
S: CheckpointingStrategy,
|
||||
{
|
||||
type Backend = B;
|
||||
type LrScheduler = LR;
|
||||
|
@ -67,6 +66,6 @@ where
|
|||
type CheckpointerModel = CM;
|
||||
type CheckpointerOptimizer = CO;
|
||||
type CheckpointerLrScheduler = CS;
|
||||
type EventCollector = EC;
|
||||
type EventProcessor = EP;
|
||||
type CheckpointerStrategy = S;
|
||||
}
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
mod aggregates;
|
||||
mod metrics;
|
||||
|
||||
pub(crate) use aggregates::*;
|
||||
pub use metrics::*;
|
|
@ -1,5 +1,7 @@
|
|||
use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
|
||||
use crate::components::LearnerComponents;
|
||||
use crate::learner::EarlyStoppingStrategy;
|
||||
use crate::metric::store::EventStoreClient;
|
||||
use burn_core::lr_scheduler::LrScheduler;
|
||||
use burn_core::module::Module;
|
||||
use burn_core::optim::Optimizer;
|
||||
|
@ -19,8 +21,10 @@ pub struct Learner<LC: LearnerComponents> {
|
|||
pub(crate) grad_accumulation: Option<usize>,
|
||||
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
|
||||
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
|
||||
pub(crate) collector: LC::EventCollector,
|
||||
pub(crate) interrupter: TrainingInterrupter,
|
||||
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
|
||||
pub(crate) event_processor: LC::EventProcessor,
|
||||
pub(crate) event_store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
|
@ -38,9 +42,9 @@ impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
|
|||
optim: &LC::Optimizer,
|
||||
scheduler: &LC::LrScheduler,
|
||||
epoch: usize,
|
||||
collector: &mut LC::EventCollector,
|
||||
store: &EventStoreClient,
|
||||
) {
|
||||
let actions = self.strategy.checkpointing(epoch, collector);
|
||||
let actions = self.strategy.checkpointing(epoch, store);
|
||||
|
||||
for action in actions {
|
||||
match action {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use super::log::install_file_logger;
|
||||
use super::Learner;
|
||||
use crate::checkpoint::{
|
||||
|
@ -5,13 +7,14 @@ use crate::checkpoint::{
|
|||
KeepLastNCheckpoints, MetricCheckpointingStrategy,
|
||||
};
|
||||
use crate::components::LearnerComponentsMarker;
|
||||
use crate::info::MetricsInfo;
|
||||
use crate::learner::base::TrainingInterrupter;
|
||||
use crate::learner::EarlyStoppingStrategy;
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::processor::{FullEventProcessor, Metrics};
|
||||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||
use crate::metric::{Adaptor, LossMetric, Metric};
|
||||
use crate::renderer::{default_renderer, MetricsRenderer};
|
||||
use crate::{collector::metrics::RenderedMetricsEventCollector, Aggregate, Direction, Split};
|
||||
use crate::{AsyncEventCollector, LearnerCheckpointer};
|
||||
use crate::LearnerCheckpointer;
|
||||
use burn_core::lr_scheduler::LrScheduler;
|
||||
use burn_core::module::ADModule;
|
||||
use burn_core::optim::Optimizer;
|
||||
|
@ -43,11 +46,13 @@ where
|
|||
grad_accumulation: Option<usize>,
|
||||
devices: Vec<B::Device>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
info: MetricsInfo<T, V>,
|
||||
metrics: Metrics<T, V>,
|
||||
event_store: LogEventStore,
|
||||
interrupter: TrainingInterrupter,
|
||||
log_to_file: bool,
|
||||
num_loggers: usize,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy<AsyncEventCollector<T, V>>>,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
|
||||
}
|
||||
|
||||
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
|
||||
|
@ -72,7 +77,8 @@ where
|
|||
directory: directory.to_string(),
|
||||
grad_accumulation: None,
|
||||
devices: vec![B::Device::default()],
|
||||
info: MetricsInfo::new(),
|
||||
metrics: Metrics::default(),
|
||||
event_store: LogEventStore::default(),
|
||||
renderer: None,
|
||||
interrupter: TrainingInterrupter::new(),
|
||||
log_to_file: true,
|
||||
|
@ -87,6 +93,7 @@ where
|
|||
))
|
||||
.build(),
|
||||
),
|
||||
early_stopping: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,8 +108,8 @@ where
|
|||
MT: MetricLogger + 'static,
|
||||
MV: MetricLogger + 'static,
|
||||
{
|
||||
self.info.register_logger_train(logger_train);
|
||||
self.info.register_logger_valid(logger_valid);
|
||||
self.event_store.register_logger_train(logger_train);
|
||||
self.event_store.register_logger_valid(logger_valid);
|
||||
self.num_loggers += 1;
|
||||
self
|
||||
}
|
||||
|
@ -110,7 +117,7 @@ where
|
|||
/// Update the checkpointing_strategy.
|
||||
pub fn with_checkpointing_strategy<CS>(&mut self, strategy: CS)
|
||||
where
|
||||
CS: CheckpointingStrategy<AsyncEventCollector<T, V>> + 'static,
|
||||
CS: CheckpointingStrategy + 'static,
|
||||
{
|
||||
self.checkpointer_strategy = Box::new(strategy);
|
||||
}
|
||||
|
@ -133,7 +140,7 @@ where
|
|||
where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
self.info.register_metric_train(metric);
|
||||
self.metrics.register_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -142,7 +149,7 @@ where
|
|||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
self.info.register_valid_metric(metric);
|
||||
self.metrics.register_valid_metric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -167,7 +174,7 @@ where
|
|||
Me: Metric + crate::metric::Numeric + 'static,
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
self.info.register_train_metric_numeric(metric);
|
||||
self.metrics.register_train_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -179,7 +186,7 @@ where
|
|||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
self.info.register_valid_metric_numeric(metric);
|
||||
self.metrics.register_valid_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -206,6 +213,16 @@ where
|
|||
self.interrupter.clone()
|
||||
}
|
||||
|
||||
/// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
|
||||
/// conditions are meet.
|
||||
pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
|
||||
where
|
||||
Strategy: EarlyStoppingStrategy + 'static,
|
||||
{
|
||||
self.early_stopping = Some(Box::new(strategy));
|
||||
self
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
|
@ -267,8 +284,8 @@ where
|
|||
AsyncCheckpointer<M::Record>,
|
||||
AsyncCheckpointer<O::Record>,
|
||||
AsyncCheckpointer<S::Record>,
|
||||
AsyncEventCollector<T, V>,
|
||||
Box<dyn CheckpointingStrategy<AsyncEventCollector<T, V>>>,
|
||||
FullEventProcessor<T, V>,
|
||||
Box<dyn CheckpointingStrategy>,
|
||||
>,
|
||||
>
|
||||
where
|
||||
|
@ -285,16 +302,18 @@ where
|
|||
let directory = &self.directory;
|
||||
|
||||
if self.num_loggers == 0 {
|
||||
self.info.register_logger_train(FileMetricLogger::new(
|
||||
format!("{directory}/train").as_str(),
|
||||
));
|
||||
self.info.register_logger_valid(FileMetricLogger::new(
|
||||
format!("{directory}/valid").as_str(),
|
||||
));
|
||||
self.event_store
|
||||
.register_logger_train(FileMetricLogger::new(
|
||||
format!("{directory}/train").as_str(),
|
||||
));
|
||||
self.event_store
|
||||
.register_logger_valid(FileMetricLogger::new(
|
||||
format!("{directory}/valid").as_str(),
|
||||
));
|
||||
}
|
||||
|
||||
let collector =
|
||||
AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info));
|
||||
let event_store = Arc::new(EventStoreClient::new(self.event_store));
|
||||
let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());
|
||||
|
||||
let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
|
||||
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
|
||||
|
@ -306,11 +325,13 @@ where
|
|||
lr_scheduler,
|
||||
checkpointer,
|
||||
num_epochs: self.num_epochs,
|
||||
collector,
|
||||
event_processor,
|
||||
event_store,
|
||||
checkpoint: self.checkpoint,
|
||||
grad_accumulation: self.grad_accumulation,
|
||||
devices: self.devices,
|
||||
interrupter: self.interrupter,
|
||||
early_stopping: self.early_stopping,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,209 @@
|
|||
use crate::metric::{
|
||||
store::{Aggregate, Direction, EventStoreClient, Split},
|
||||
Metric,
|
||||
};
|
||||
|
||||
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
|
||||
pub enum StoppingCondition {
|
||||
/// When no improvement has happened since the given number of epochs.
|
||||
NoImprovementSince {
|
||||
/// The number of epochs allowed to worsen before it gets better.
|
||||
n_epochs: usize,
|
||||
},
|
||||
}
|
||||
|
||||
/// A strategy that checks if the training should be stopped.
|
||||
pub trait EarlyStoppingStrategy {
|
||||
/// Update its current state and returns if the training should be stopped.
|
||||
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
|
||||
}
|
||||
|
||||
/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
|
||||
/// during training or validation.
|
||||
pub struct MetricEarlyStoppingStrategy {
|
||||
condition: StoppingCondition,
|
||||
metric_name: String,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
best_epoch: usize,
|
||||
best_value: f64,
|
||||
}
|
||||
|
||||
impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
|
||||
fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
|
||||
let current_value =
|
||||
match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
|
||||
Some(value) => value,
|
||||
None => {
|
||||
log::warn!("Can't find metric for early stopping.");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let is_best = match self.direction {
|
||||
Direction::Lowest => current_value < self.best_value,
|
||||
Direction::Highest => current_value > self.best_value,
|
||||
};
|
||||
|
||||
if is_best {
|
||||
log::info!(
|
||||
"New best epoch found {} {}: {}",
|
||||
epoch,
|
||||
self.metric_name,
|
||||
current_value
|
||||
);
|
||||
self.best_value = current_value;
|
||||
self.best_epoch = epoch;
|
||||
return false;
|
||||
}
|
||||
|
||||
match self.condition {
|
||||
StoppingCondition::NoImprovementSince { n_epochs } => {
|
||||
let should_stop = epoch - self.best_epoch >= n_epochs;
|
||||
|
||||
if should_stop {
|
||||
log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value);
|
||||
}
|
||||
|
||||
should_stop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricEarlyStoppingStrategy {
|
||||
/// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
|
||||
/// during training or validation.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The metric should be registered for early stopping to work, otherwise no data is collected.
|
||||
pub fn new<Me: Metric>(
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
condition: StoppingCondition,
|
||||
) -> Self {
|
||||
let init_value = match direction {
|
||||
Direction::Lowest => f64::MAX,
|
||||
Direction::Highest => f64::MIN,
|
||||
};
|
||||
|
||||
Self {
|
||||
metric_name: Me::NAME.to_string(),
|
||||
condition,
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
best_epoch: 1,
|
||||
best_value: init_value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
logger::InMemoryMetricLogger,
|
||||
metric::{
|
||||
processor::{
|
||||
test_utils::{end_epoch, process_train},
|
||||
Metrics, MinimalEventProcessor,
|
||||
},
|
||||
store::LogEventStore,
|
||||
LossMetric,
|
||||
},
|
||||
TestBackend,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn never_early_stop_while_it_is_improving() {
|
||||
test_early_stopping(
|
||||
1,
|
||||
&[
|
||||
(&[0.5, 0.3], false, "Should not stop first epoch"),
|
||||
(&[0.4, 0.3], false, "Should not stop when improving"),
|
||||
(&[0.3, 0.3], false, "Should not stop when improving"),
|
||||
(&[0.2, 0.3], false, "Should not stop when improving"),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop_when_no_improvement_since_two_epochs() {
|
||||
test_early_stopping(
|
||||
2,
|
||||
&[
|
||||
(&[1.0, 0.5], false, "Should not stop first epoch"),
|
||||
(&[0.5, 0.3], false, "Should not stop when improving"),
|
||||
(
|
||||
&[1.0, 3.0],
|
||||
false,
|
||||
"Should not stop first time it gets worse",
|
||||
),
|
||||
(
|
||||
&[1.0, 2.0],
|
||||
true,
|
||||
"Should stop since two following epochs didn't improve",
|
||||
),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn early_stop_when_stays_equal() {
|
||||
test_early_stopping(
|
||||
2,
|
||||
&[
|
||||
(&[0.5, 0.3], false, "Should not stop first epoch"),
|
||||
(
|
||||
&[0.5, 0.3],
|
||||
false,
|
||||
"Should not stop first time it stars the same",
|
||||
),
|
||||
(
|
||||
&[0.5, 0.3],
|
||||
true,
|
||||
"Should stop since two following epochs didn't improve",
|
||||
),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
|
||||
let mut early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<TestBackend>>(
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Train,
|
||||
StoppingCondition::NoImprovementSince { n_epochs },
|
||||
);
|
||||
let mut store = LogEventStore::default();
|
||||
let mut metrics = Metrics::<f64, f64>::default();
|
||||
|
||||
store.register_logger_train(InMemoryMetricLogger::default());
|
||||
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
|
||||
|
||||
let store = Arc::new(EventStoreClient::new(store));
|
||||
let mut processor = MinimalEventProcessor::new(metrics, store.clone());
|
||||
|
||||
let mut epoch = 1;
|
||||
for (points, should_start, comment) in data {
|
||||
for point in points.iter() {
|
||||
process_train(&mut processor, *point, epoch);
|
||||
}
|
||||
end_epoch(&mut processor, epoch);
|
||||
|
||||
assert_eq!(
|
||||
*should_start,
|
||||
early_stopping.should_stop(epoch, &store),
|
||||
"{comment}"
|
||||
);
|
||||
epoch += 1;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,8 +4,9 @@ use burn_core::{
|
|||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter, Event};
|
||||
use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||
use crate::metric::processor::{Event, EventProcessor, LearnerItem};
|
||||
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};
|
||||
use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||
|
||||
/// A validation epoch.
|
||||
#[derive(new)]
|
||||
|
@ -30,14 +31,14 @@ impl<VI> ValidEpoch<VI> {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to validate.
|
||||
/// * `callback` - The callback to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
pub fn run<LC: LearnerComponents, VO>(
|
||||
&self,
|
||||
model: &LC::Model,
|
||||
callback: &mut LC::EventCollector,
|
||||
processor: &mut LC::EventProcessor,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) where
|
||||
LC::EventCollector: EventCollector<ItemValid = VO>,
|
||||
LC::EventProcessor: EventProcessor<ItemValid = VO>,
|
||||
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
|
||||
{
|
||||
log::info!("Executing validation step for epoch {}", self.epoch);
|
||||
|
@ -60,14 +61,14 @@ impl<VI> ValidEpoch<VI> {
|
|||
None,
|
||||
);
|
||||
|
||||
callback.on_event_valid(Event::ProcessedItem(item));
|
||||
processor.process_valid(Event::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
callback.on_event_valid(Event::EndEpoch(self.epoch));
|
||||
processor.process_valid(Event::EndEpoch(self.epoch));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +80,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `scheduler` - The learning rate scheduler to use.
|
||||
/// * `callback` - The callback to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -89,11 +90,11 @@ impl<TI> TrainEpoch<TI> {
|
|||
mut model: LC::Model,
|
||||
mut optim: LC::Optimizer,
|
||||
scheduler: &mut LC::LrScheduler,
|
||||
callback: &mut LC::EventCollector,
|
||||
processor: &mut LC::EventProcessor,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) -> (LC::Model, LC::Optimizer)
|
||||
where
|
||||
LC::EventCollector: EventCollector<ItemTrain = TO>,
|
||||
LC::EventProcessor: EventProcessor<ItemTrain = TO>,
|
||||
LC::Model: TrainStep<TI, TO>,
|
||||
{
|
||||
log::info!("Executing training step for epoch {}", self.epoch,);
|
||||
|
@ -134,13 +135,14 @@ impl<TI> TrainEpoch<TI> {
|
|||
Some(lr),
|
||||
);
|
||||
|
||||
callback.on_event_train(Event::ProcessedItem(item));
|
||||
processor.process_train(Event::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||
processor.process_train(Event::EndEpoch(self.epoch));
|
||||
|
||||
(model, optim)
|
||||
}
|
||||
|
@ -154,7 +156,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `lr_scheduler` - The learning rate scheduler to use.
|
||||
/// * `callback` - The callback to use.
|
||||
/// * `processor` - The event processor to use.
|
||||
/// * `devices` - The devices to use.
|
||||
///
|
||||
/// # Returns
|
||||
|
@ -165,12 +167,12 @@ impl<TI> TrainEpoch<TI> {
|
|||
mut model: LC::Model,
|
||||
mut optim: LC::Optimizer,
|
||||
lr_scheduler: &mut LC::LrScheduler,
|
||||
callback: &mut LC::EventCollector,
|
||||
processor: &mut LC::EventProcessor,
|
||||
devices: Vec<<LC::Backend as Backend>::Device>,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) -> (LC::Model, LC::Optimizer)
|
||||
where
|
||||
LC::EventCollector: EventCollector<ItemTrain = TO>,
|
||||
LC::EventProcessor: EventProcessor<ItemTrain = TO>,
|
||||
LC::Model: TrainStep<TI, TO>,
|
||||
TO: Send + 'static,
|
||||
TI: Send + 'static,
|
||||
|
@ -224,7 +226,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
Some(lr),
|
||||
);
|
||||
|
||||
callback.on_event_train(Event::ProcessedItem(item));
|
||||
processor.process_train(Event::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
|
@ -238,7 +240,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
}
|
||||
}
|
||||
|
||||
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||
processor.process_train(Event::EndEpoch(self.epoch));
|
||||
|
||||
(model, optim)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
mod base;
|
||||
mod builder;
|
||||
mod classification;
|
||||
mod early_stopping;
|
||||
mod epoch;
|
||||
mod regression;
|
||||
mod step;
|
||||
|
@ -11,6 +12,7 @@ pub(crate) mod log;
|
|||
pub use base::*;
|
||||
pub use builder::*;
|
||||
pub use classification::*;
|
||||
pub use early_stopping::*;
|
||||
pub use epoch::*;
|
||||
pub use regression::*;
|
||||
pub use step::*;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::components::LearnerComponents;
|
||||
use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch};
|
||||
use crate::metric::processor::EventProcessor;
|
||||
use crate::{Learner, TrainEpoch, ValidEpoch};
|
||||
use burn_core::data::dataloader::DataLoader;
|
||||
use burn_core::module::{ADModule, Module};
|
||||
use burn_core::optim::{GradientsParams, Optimizer};
|
||||
|
@ -115,7 +116,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
OutputValid: Send,
|
||||
LC::Model: TrainStep<InputTrain, OutputTrain>,
|
||||
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
|
||||
LC::EventCollector: EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
{
|
||||
log::info!("Fitting {}", self.model.to_string());
|
||||
// The reference model is always on the first device provided.
|
||||
|
@ -151,7 +152,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
self.model,
|
||||
self.optim,
|
||||
&mut self.lr_scheduler,
|
||||
&mut self.collector,
|
||||
&mut self.event_processor,
|
||||
self.devices.clone(),
|
||||
&self.interrupter,
|
||||
)
|
||||
|
@ -160,7 +161,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
self.model,
|
||||
self.optim,
|
||||
&mut self.lr_scheduler,
|
||||
&mut self.collector,
|
||||
&mut self.event_processor,
|
||||
&self.interrupter,
|
||||
);
|
||||
}
|
||||
|
@ -170,7 +171,11 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
}
|
||||
|
||||
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
|
||||
epoch_valid.run::<LC, OutputValid>(&self.model, &mut self.collector, &self.interrupter);
|
||||
epoch_valid.run::<LC, OutputValid>(
|
||||
&self.model,
|
||||
&mut self.event_processor,
|
||||
&self.interrupter,
|
||||
);
|
||||
|
||||
if let Some(checkpointer) = &mut self.checkpointer {
|
||||
checkpointer.checkpoint(
|
||||
|
@ -178,9 +183,15 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
&self.optim,
|
||||
&self.lr_scheduler,
|
||||
epoch,
|
||||
&mut self.collector,
|
||||
&self.event_store,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(early_stopping) = &mut self.early_stopping {
|
||||
if early_stopping.should_stop(epoch, &self.event_store) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.model
|
||||
|
|
|
@ -19,13 +19,8 @@ pub mod logger;
|
|||
/// The metric module.
|
||||
pub mod metric;
|
||||
|
||||
/// All information collected during training.
|
||||
pub mod info;
|
||||
|
||||
mod collector;
|
||||
mod learner;
|
||||
|
||||
pub use collector::*;
|
||||
pub use learner::*;
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
use super::Logger;
|
||||
|
||||
/// In memory logger.
|
||||
#[derive(Default)]
|
||||
pub struct InMemoryLogger {
|
||||
pub(crate) values: Vec<String>,
|
||||
}
|
||||
|
||||
impl<T> Logger<T> for InMemoryLogger
|
||||
where
|
||||
T: std::fmt::Display,
|
||||
{
|
||||
fn log(&mut self, item: T) {
|
||||
self.values.push(item.to_string());
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use super::{AsyncLogger, FileLogger, Logger};
|
||||
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
|
||||
use crate::metric::MetricEntry;
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
@ -16,7 +16,7 @@ pub trait MetricLogger: Send {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
fn epoch(&mut self, epoch: usize);
|
||||
fn end_epoch(&mut self, epoch: usize);
|
||||
|
||||
/// Read the logs for an epoch.
|
||||
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
|
||||
|
@ -81,9 +81,9 @@ impl MetricLogger for FileMetricLogger {
|
|||
logger.log(value.clone());
|
||||
}
|
||||
|
||||
fn epoch(&mut self, epoch: usize) {
|
||||
fn end_epoch(&mut self, epoch: usize) {
|
||||
self.loggers.clear();
|
||||
self.epoch = epoch;
|
||||
self.epoch = epoch + 1;
|
||||
}
|
||||
|
||||
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
|
||||
|
@ -125,23 +125,24 @@ impl MetricLogger for FileMetricLogger {
|
|||
/// In memory metric logger, useful when testing and debugging.
|
||||
#[derive(Default)]
|
||||
pub struct InMemoryMetricLogger {
|
||||
values: HashMap<String, Vec<Vec<String>>>,
|
||||
values: HashMap<String, Vec<InMemoryLogger>>,
|
||||
}
|
||||
|
||||
impl MetricLogger for InMemoryMetricLogger {
|
||||
fn log(&mut self, item: &MetricEntry) {
|
||||
if !self.values.contains_key(&item.name) {
|
||||
self.values.insert(item.name.clone(), vec![vec![]]);
|
||||
self.values
|
||||
.insert(item.name.clone(), vec![InMemoryLogger::default()]);
|
||||
}
|
||||
|
||||
let values = self.values.get_mut(&item.name).unwrap();
|
||||
|
||||
values.last_mut().unwrap().push(item.serialize.clone());
|
||||
values.last_mut().unwrap().log(item.serialize.clone());
|
||||
}
|
||||
|
||||
fn epoch(&mut self, _epoch: usize) {
|
||||
fn end_epoch(&mut self, _epoch: usize) {
|
||||
for (_, values) in self.values.iter_mut() {
|
||||
values.push(Vec::new());
|
||||
values.push(InMemoryLogger::default());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,7 +153,8 @@ impl MetricLogger for InMemoryMetricLogger {
|
|||
};
|
||||
|
||||
match values.get(epoch - 1) {
|
||||
Some(values) => Ok(values
|
||||
Some(logger) => Ok(logger
|
||||
.values
|
||||
.iter()
|
||||
.filter_map(|value| value.parse::<f64>().ok())
|
||||
.collect()),
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
mod async_logger;
|
||||
mod base;
|
||||
mod file;
|
||||
mod in_memory;
|
||||
mod metric;
|
||||
|
||||
pub use async_logger::*;
|
||||
pub use base::*;
|
||||
pub use file::*;
|
||||
pub use in_memory::*;
|
||||
pub use metric::*;
|
||||
|
|
|
@ -74,7 +74,7 @@ pub trait Numeric {
|
|||
}
|
||||
|
||||
/// Data type that contains the current state of a metric at a given time.
|
||||
#[derive(new, Debug)]
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct MetricEntry {
|
||||
/// The name of the metric.
|
||||
pub name: String,
|
||||
|
|
|
@ -26,3 +26,7 @@ pub use learning_rate::*;
|
|||
pub use loss::*;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub use memory_use::*;
|
||||
|
||||
pub(crate) mod processor;
|
||||
/// Module responsible to save and exposes data collected during training.
|
||||
pub mod store;
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
use burn_core::data::dataloader::Progress;
|
||||
use burn_core::LearningRate;
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum Event<T> {
|
||||
/// Signal that an item have been processed.
|
||||
ProcessedItem(LearnerItem<T>),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
}
|
||||
|
||||
/// Process events happening during training and validation.
|
||||
pub trait EventProcessor {
|
||||
/// The training item.
|
||||
type ItemTrain;
|
||||
/// The validation item.
|
||||
type ItemValid;
|
||||
|
||||
/// Collect a training event.
|
||||
fn process_train(&mut self, event: Event<Self::ItemTrain>);
|
||||
/// Collect a validation event.
|
||||
fn process_valid(&mut self, event: Event<Self::ItemValid>);
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
use super::{Event, EventProcessor, Metrics};
|
||||
use crate::metric::store::EventStoreClient;
|
||||
use crate::renderer::{MetricState, MetricsRenderer};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// An [event processor](EventProcessor) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
/// - Render metrics using a [metrics renderer](MetricsRenderer).
|
||||
pub struct FullEventProcessor<T, V> {
|
||||
metrics: Metrics<T, V>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<T, V> FullEventProcessor<T, V> {
|
||||
pub(crate) fn new(
|
||||
metrics: Metrics<T, V>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
store: Arc<EventStoreClient>,
|
||||
) -> Self {
|
||||
Self {
|
||||
metrics,
|
||||
renderer,
|
||||
store,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> EventProcessor for FullEventProcessor<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn process_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_train(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
}
|
||||
Event::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_train();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_valid(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_valid(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_valid(progress);
|
||||
}
|
||||
Event::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_valid();
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,74 +1,66 @@
|
|||
use super::NumericMetricsAggregate;
|
||||
use super::LearnerItem;
|
||||
use crate::{
|
||||
logger::MetricLogger,
|
||||
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
||||
Aggregate, Direction, LearnerItem, Split,
|
||||
metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
||||
renderer::TrainingProgress,
|
||||
};
|
||||
|
||||
/// Metrics information collected during training.
|
||||
pub struct MetricsInfo<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub(crate) struct Metrics<T, V> {
|
||||
train: Vec<Box<dyn MetricUpdater<T>>>,
|
||||
valid: Vec<Box<dyn MetricUpdater<V>>>,
|
||||
train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
|
||||
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
|
||||
loggers_train: Vec<Box<dyn MetricLogger>>,
|
||||
loggers_valid: Vec<Box<dyn MetricLogger>>,
|
||||
aggregate_train: NumericMetricsAggregate,
|
||||
aggregate_valid: NumericMetricsAggregate,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MetricsUpdate {
|
||||
pub(crate) entries: Vec<MetricEntry>,
|
||||
pub(crate) entries_numeric: Vec<(MetricEntry, f64)>,
|
||||
}
|
||||
|
||||
impl<T, V> MetricsInfo<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub(crate) fn new() -> Self {
|
||||
impl<T, V> Default for Metrics<T, V> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
train: vec![],
|
||||
valid: vec![],
|
||||
train_numeric: vec![],
|
||||
valid_numeric: vec![],
|
||||
loggers_train: vec![],
|
||||
loggers_valid: vec![],
|
||||
aggregate_train: NumericMetricsAggregate::default(),
|
||||
aggregate_valid: NumericMetricsAggregate::default(),
|
||||
train: Vec::default(),
|
||||
valid: Vec::default(),
|
||||
train_numeric: Vec::default(),
|
||||
valid_numeric: Vec::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal the end of a training epoch.
|
||||
pub(crate) fn end_epoch_train(&mut self, epoch: usize) {
|
||||
for metric in self.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.epoch(epoch + 1);
|
||||
}
|
||||
impl<T, V> Metrics<T, V> {
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
T: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Signal the end of a validation epoch.
|
||||
pub(crate) fn end_epoch_valid(&mut self, epoch: usize) {
|
||||
for metric in self.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.epoch(epoch + 1);
|
||||
}
|
||||
/// Register a validation metric.
|
||||
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
V: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric training metric.
|
||||
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
T: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric validation metric.
|
||||
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
V: Adaptor<Me::Input> + 'static,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Update the training information from the training item.
|
||||
|
@ -82,20 +74,11 @@ where
|
|||
|
||||
for metric in self.train.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(item, metadata);
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries_numeric.push((state, value));
|
||||
}
|
||||
|
||||
|
@ -113,94 +96,58 @@ where
|
|||
|
||||
for metric in self.valid.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(item, metadata);
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries_numeric.push((state, value));
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Find the epoch corresponding to the given criteria.
|
||||
pub(crate) fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
match split {
|
||||
Split::Train => {
|
||||
self.aggregate_train
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
|
||||
}
|
||||
Split::Valid => {
|
||||
self.aggregate_valid
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
|
||||
}
|
||||
/// Signal the end of a training epoch.
|
||||
pub(crate) fn end_epoch_train(&mut self) {
|
||||
for metric in self.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a logger for training metrics.
|
||||
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_train.push(Box::new(logger));
|
||||
/// Signal the end of a validation epoch.
|
||||
pub(crate) fn end_epoch_valid(&mut self) {
|
||||
for metric in self.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a logger for validation metrics.
|
||||
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_valid.push(Box::new(logger));
|
||||
impl<T> From<&LearnerItem<T>> for TrainingProgress {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation metric.
|
||||
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric training metric.
|
||||
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric validation metric.
|
||||
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid_numeric.push(Box::new(metric))
|
||||
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
use super::{Event, EventProcessor, Metrics};
|
||||
use crate::metric::store::EventStoreClient;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// An [event processor](EventProcessor) that handles:
|
||||
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
|
||||
#[derive(new)]
|
||||
pub(crate) struct MinimalEventProcessor<T, V> {
|
||||
metrics: Metrics<T, V>,
|
||||
store: Arc<EventStoreClient>,
|
||||
}
|
||||
|
||||
impl<T, V> EventProcessor for MinimalEventProcessor<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn process_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_train(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::MetricsUpdate(update));
|
||||
}
|
||||
Event::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_train();
|
||||
self.store
|
||||
.add_event_train(crate::metric::store::Event::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => {
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.metrics.update_valid(&item, &metadata);
|
||||
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update));
|
||||
}
|
||||
Event::EndEpoch(epoch) => {
|
||||
self.metrics.end_epoch_valid();
|
||||
self.store
|
||||
.add_event_valid(crate::metric::store::Event::EndEpoch(epoch));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
mod base;
|
||||
mod full;
|
||||
mod metrics;
|
||||
mod minimal;
|
||||
|
||||
pub use base::*;
|
||||
pub(crate) use full::*;
|
||||
pub(crate) use metrics::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use minimal::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test_utils {
|
||||
use crate::metric::{
|
||||
processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor},
|
||||
Adaptor, LossInput,
|
||||
};
|
||||
use burn_core::tensor::{backend::Backend, ElementConversion, Tensor};
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for f64 {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(Tensor::from_data([self.elem()]))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn process_train(
|
||||
processor: &mut MinimalEventProcessor<f64, f64>,
|
||||
value: f64,
|
||||
epoch: usize,
|
||||
) {
|
||||
let dummy_progress = burn_core::data::dataloader::Progress {
|
||||
items_processed: 1,
|
||||
items_total: 10,
|
||||
};
|
||||
let num_epochs = 3;
|
||||
let dummy_iteration = 1;
|
||||
|
||||
processor.process_train(Event::ProcessedItem(LearnerItem::new(
|
||||
value,
|
||||
dummy_progress,
|
||||
epoch,
|
||||
num_epochs,
|
||||
dummy_iteration,
|
||||
None,
|
||||
)));
|
||||
}
|
||||
|
||||
pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor<f64, f64>, epoch: usize) {
|
||||
processor.process_train(Event::EndEpoch(epoch));
|
||||
processor.process_valid(Event::EndEpoch(epoch));
|
||||
}
|
||||
}
|
|
@ -1,28 +1,32 @@
|
|||
use crate::{logger::MetricLogger, Aggregate, Direction};
|
||||
use crate::logger::MetricLogger;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{Aggregate, Direction};
|
||||
|
||||
/// Type that can be used to fetch and use numeric metric aggregates.
|
||||
#[derive(Default, Debug)]
|
||||
pub(crate) struct NumericMetricsAggregate {
|
||||
mean_for_each_epoch: HashMap<Key, f64>,
|
||||
value_for_each_epoch: HashMap<Key, f64>,
|
||||
}
|
||||
|
||||
#[derive(new, Hash, PartialEq, Eq, Debug)]
|
||||
struct Key {
|
||||
name: String,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
}
|
||||
|
||||
impl NumericMetricsAggregate {
|
||||
pub(crate) fn mean(
|
||||
pub(crate) fn aggregate(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
loggers: &mut [Box<dyn MetricLogger>],
|
||||
) -> Option<f64> {
|
||||
let key = Key::new(name.to_string(), epoch);
|
||||
let key = Key::new(name.to_string(), epoch, aggregate);
|
||||
|
||||
if let Some(value) = self.mean_for_each_epoch.get(&key) {
|
||||
if let Some(value) = self.value_for_each_epoch.get(&key) {
|
||||
return Some(*value);
|
||||
}
|
||||
|
||||
|
@ -45,10 +49,13 @@ impl NumericMetricsAggregate {
|
|||
}
|
||||
|
||||
let num_points = points.len();
|
||||
let mean = points.into_iter().sum::<f64>() / num_points as f64;
|
||||
let sum = points.into_iter().sum::<f64>();
|
||||
let value = match aggregate {
|
||||
Aggregate::Mean => sum / num_points as f64,
|
||||
};
|
||||
|
||||
self.mean_for_each_epoch.insert(key, mean);
|
||||
Some(mean)
|
||||
self.value_for_each_epoch.insert(key, value);
|
||||
Some(value)
|
||||
}
|
||||
|
||||
pub(crate) fn find_epoch(
|
||||
|
@ -61,16 +68,8 @@ impl NumericMetricsAggregate {
|
|||
let mut data = Vec::new();
|
||||
let mut current_epoch = 1;
|
||||
|
||||
loop {
|
||||
match aggregate {
|
||||
Aggregate::Mean => match self.mean(name, current_epoch, loggers) {
|
||||
Some(value) => {
|
||||
data.push(value);
|
||||
}
|
||||
None => break,
|
||||
},
|
||||
};
|
||||
|
||||
while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) {
|
||||
data.push(value);
|
||||
current_epoch += 1;
|
||||
}
|
||||
|
||||
|
@ -131,8 +130,8 @@ mod tests {
|
|||
));
|
||||
}
|
||||
fn new_epoch(&mut self) {
|
||||
self.logger.end_epoch(self.epoch);
|
||||
self.epoch += 1;
|
||||
self.logger.epoch(self.epoch);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
use crate::metric::MetricEntry;
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum Event {
|
||||
/// Signal that metrics have been updated.
|
||||
MetricsUpdate(MetricsUpdate),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
}
|
||||
|
||||
/// Contains all metric information.
|
||||
#[derive(new, Clone)]
|
||||
pub struct MetricsUpdate {
|
||||
/// Metrics information related to non-numeric metrics.
|
||||
pub entries: Vec<MetricEntry>,
|
||||
/// Metrics information related to numeric metrics.
|
||||
pub entries_numeric: Vec<(MetricEntry, f64)>,
|
||||
}
|
||||
|
||||
/// Defines how training and validation events are collected and searched.
|
||||
///
|
||||
/// This trait also exposes methods that uses the collected data to compute useful information.
|
||||
pub trait EventStore: Send {
|
||||
/// Collect a training/validation event.
|
||||
fn add_event(&mut self, event: Event, split: Split);
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize>;
|
||||
|
||||
/// Find the metric value for the current epoch following the given criteria.
|
||||
fn find_metric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: Split,
|
||||
) -> Option<f64>;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
|
||||
/// How to aggregate the metric.
|
||||
pub enum Aggregate {
|
||||
/// Compute the average.
|
||||
Mean,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// The split to use.
|
||||
pub enum Split {
|
||||
/// The training split.
|
||||
Train,
|
||||
/// The validation split.
|
||||
Valid,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
/// The direction of the query.
|
||||
pub enum Direction {
|
||||
/// Lower is better.
|
||||
Lowest,
|
||||
/// Higher is better.
|
||||
Highest,
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
use super::EventStore;
|
||||
use super::{Aggregate, Direction, Event, Split};
|
||||
use std::{sync::mpsc, thread::JoinHandle};
|
||||
|
||||
/// Type that allows to communicate with an [event store](EventStore).
|
||||
pub struct EventStoreClient {
|
||||
sender: mpsc::Sender<Message>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl EventStoreClient {
|
||||
/// Create a new [event store](EventStore) client.
|
||||
pub(crate) fn new<C>(store: C) -> Self
|
||||
where
|
||||
C: EventStore + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = WorkerThread::new(store, receiver);
|
||||
|
||||
let handler = std::thread::spawn(move || thread.run());
|
||||
let handler = Some(handler);
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl EventStoreClient {
|
||||
/// Add a training event to the [event store](EventStore).
|
||||
pub(crate) fn add_event_train(&self, event: Event) {
|
||||
self.sender.send(Message::OnEventTrain(event)).unwrap();
|
||||
}
|
||||
|
||||
/// Add a validation event to the [event store](EventStore).
|
||||
pub(crate) fn add_event_valid(&self, event: Event) {
|
||||
self.sender.send(Message::OnEventValid(event)).unwrap();
|
||||
}
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
pub fn find_epoch(
|
||||
&self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindEpoch(
|
||||
name.to_string(),
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
sender,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Event store thread crashed: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the metric value for the current epoch following the given criteria.
|
||||
pub fn find_metric(
|
||||
&self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: Split,
|
||||
) -> Option<f64> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindMetric(
|
||||
name.to_string(),
|
||||
epoch,
|
||||
aggregate,
|
||||
split,
|
||||
sender,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Event store thread crashed: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct WorkerThread<S> {
|
||||
store: S,
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
}
|
||||
|
||||
impl<C> WorkerThread<C>
|
||||
where
|
||||
C: EventStore,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::FindEpoch(name, aggregate, direction, split, sender) => {
|
||||
let response = self.store.find_epoch(&name, aggregate, direction, split);
|
||||
sender.send(response).unwrap();
|
||||
}
|
||||
Message::FindMetric(name, epoch, aggregate, split, sender) => {
|
||||
let response = self.store.find_metric(&name, epoch, aggregate, split);
|
||||
sender.send(response).unwrap();
|
||||
}
|
||||
Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),
|
||||
Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Message {
|
||||
OnEventTrain(Event),
|
||||
OnEventValid(Event),
|
||||
End,
|
||||
FindEpoch(
|
||||
String,
|
||||
Aggregate,
|
||||
Direction,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<usize>>,
|
||||
),
|
||||
FindMetric(
|
||||
String,
|
||||
usize,
|
||||
Aggregate,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<f64>>,
|
||||
),
|
||||
}
|
||||
|
||||
impl Drop for EventStoreClient {
|
||||
fn drop(&mut self) {
|
||||
self.sender.send(Message::End).unwrap();
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
use super::{aggregate::NumericMetricsAggregate, Aggregate, Direction, Event, EventStore, Split};
|
||||
use crate::logger::MetricLogger;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct LogEventStore {
|
||||
loggers_train: Vec<Box<dyn MetricLogger>>,
|
||||
loggers_valid: Vec<Box<dyn MetricLogger>>,
|
||||
aggregate_train: NumericMetricsAggregate,
|
||||
aggregate_valid: NumericMetricsAggregate,
|
||||
}
|
||||
|
||||
impl EventStore for LogEventStore {
|
||||
fn add_event(&mut self, event: Event, split: Split) {
|
||||
match event {
|
||||
Event::MetricsUpdate(update) => match split {
|
||||
Split::Train => {
|
||||
update
|
||||
.entries
|
||||
.iter()
|
||||
.chain(update.entries_numeric.iter().map(|(entry, _value)| entry))
|
||||
.for_each(|entry| {
|
||||
self.loggers_train
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.log(entry));
|
||||
});
|
||||
}
|
||||
Split::Valid => {
|
||||
update
|
||||
.entries
|
||||
.iter()
|
||||
.chain(update.entries_numeric.iter().map(|(entry, _value)| entry))
|
||||
.for_each(|entry| {
|
||||
self.loggers_valid
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.log(entry));
|
||||
});
|
||||
}
|
||||
},
|
||||
Event::EndEpoch(epoch) => match split {
|
||||
Split::Train => self
|
||||
.loggers_train
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.end_epoch(epoch)),
|
||||
Split::Valid => self
|
||||
.loggers_valid
|
||||
.iter_mut()
|
||||
.for_each(|logger| logger.end_epoch(epoch + 1)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
match split {
|
||||
Split::Train => {
|
||||
self.aggregate_train
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
|
||||
}
|
||||
Split::Valid => {
|
||||
self.aggregate_valid
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_metric(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
aggregate: Aggregate,
|
||||
split: Split,
|
||||
) -> Option<f64> {
|
||||
match split {
|
||||
Split::Train => {
|
||||
self.aggregate_train
|
||||
.aggregate(name, epoch, aggregate, &mut self.loggers_train)
|
||||
}
|
||||
Split::Valid => {
|
||||
self.aggregate_valid
|
||||
.aggregate(name, epoch, aggregate, &mut self.loggers_valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LogEventStore {
|
||||
/// Register a logger for training metrics.
|
||||
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_train.push(Box::new(logger));
|
||||
}
|
||||
|
||||
/// Register a logger for validation metrics.
|
||||
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_valid.push(Box::new(logger));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
pub(crate) mod aggregate;
|
||||
|
||||
mod base;
|
||||
mod client;
|
||||
mod log;
|
||||
|
||||
pub(crate) use self::log::*;
|
||||
pub use base::*;
|
||||
pub use client::*;
|
|
@ -1,4 +1,4 @@
|
|||
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use crate::metric::renderer::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
|
||||
/// A simple renderer for when the cli feature is not enabled.
|
||||
pub struct CliMetricsRenderer;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::data::MNISTBatch;
|
||||
|
||||
use burn::{
|
||||
module::Module,
|
||||
nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d},
|
||||
|
|
|
@ -5,7 +5,9 @@ use burn::module::Module;
|
|||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
|
||||
use burn::train::metric::store::{Aggregate, Direction, Split};
|
||||
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
|
||||
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
|
||||
|
@ -69,6 +71,12 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Valid,
|
||||
StoppingCondition::NoImprovementSince { n_epochs: 1 },
|
||||
))
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.build(Model::new(), config.optimizer.init(), 1e-4);
|
||||
|
|
Loading…
Reference in New Issue