Feat training events (#857)

This commit is contained in:
Nathaniel Simard 2023-10-10 13:27:03 -04:00 committed by GitHub
parent 097fd956d0
commit 620b86de98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 959 additions and 547 deletions

View File

@@ -111,10 +111,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
         .build(MNISTDataset::test());

     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_plot(AccuracyMetric::new())
-        .metric_valid_plot(AccuracyMetric::new())
-        .metric_train_plot(LossMetric::new())
-        .metric_valid_plot(LossMetric::new())
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
         .with_file_checkpointer(1, CompactRecorder::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)
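Note: the `metric_*_plot` builder methods are renamed to `metric_*_numeric` throughout this commit. For context, a hedged sketch of a full updated call site (the `model`, `optim`, and learning-rate arguments to `build`, and the `fit` call, come from the surrounding MNIST example and are assumptions, not shown in this hunk):

    // Sketch only: the renamed registration methods in context.
    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(1, CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(config.num_epochs)
        .build(model, optim, config.learning_rate);

    // Training runs as before; only the metric registration API changed.
    let _model_trained = learner.fit(dataloader_train, dataloader_test);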

View File

@@ -1,97 +0,0 @@
use super::{LearnerCallback, LearnerItem};
use std::{sync::mpsc, thread::JoinHandle};
enum Message<T, V> {
LogTrain(LearnerItem<T>),
LogValid(LearnerItem<V>),
ClearTrain(usize),
ClearValid(usize),
End,
}
/// Async trainer callback tracker.
pub struct AsyncTrainerCallback<T, V> {
sender: mpsc::Sender<Message<T, V>>,
handler: Option<JoinHandle<()>>,
}
#[derive(new)]
struct CallbackThread<C, T, V> {
callback: C,
receiver: mpsc::Receiver<Message<T, V>>,
}
impl<C, T, V> CallbackThread<C, T, V>
where
C: LearnerCallback<ItemTrain = T, ItemValid = V>,
{
fn run(mut self) {
for item in self.receiver.iter() {
match item {
Message::LogTrain(item) => {
self.callback.on_train_item(item);
}
Message::ClearTrain(epoch) => {
self.callback.on_train_end_epoch(epoch);
}
Message::LogValid(item) => {
self.callback.on_valid_item(item);
}
Message::ClearValid(epoch) => {
self.callback.on_valid_end_epoch(epoch);
}
Message::End => {
return;
}
}
}
}
}
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
/// Create a new async trainer callback.
pub fn new<C>(callback: C) -> Self
where
C: LearnerCallback<ItemTrain = T, ItemValid = V> + 'static,
{
let (sender, receiver) = mpsc::channel();
let thread = CallbackThread::new(callback, receiver);
let handler = std::thread::spawn(move || thread.run());
let handler = Some(handler);
Self { sender, handler }
}
}
impl<T: Send, V: Send> LearnerCallback for AsyncTrainerCallback<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn on_train_item(&mut self, item: LearnerItem<T>) {
self.sender.send(Message::LogTrain(item)).unwrap();
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
self.sender.send(Message::LogValid(item)).unwrap();
}
fn on_train_end_epoch(&mut self, epoch: usize) {
self.sender.send(Message::ClearTrain(epoch)).unwrap();
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
self.sender.send(Message::ClearValid(epoch)).unwrap();
}
}
impl<T, V> Drop for AsyncTrainerCallback<T, V> {
fn drop(&mut self) {
self.sender.send(Message::End).unwrap();
let handler = self.handler.take();
if let Some(handler) = handler {
handler.join().unwrap();
}
}
}

View File

@@ -1,43 +0,0 @@
use burn_core::{data::dataloader::Progress, LearningRate};
/// The base trait for trainer callbacks.
pub trait LearnerCallback: Send {
/// Training item.
type ItemTrain;
/// Validation item.
type ItemValid;
/// Called when a training item is logged.
fn on_train_item(&mut self, _item: LearnerItem<Self::ItemTrain>) {}
/// Called when a validation item is logged.
fn on_valid_item(&mut self, _item: LearnerItem<Self::ItemValid>) {}
/// Called when a training epoch is finished.
fn on_train_end_epoch(&mut self, _epoch: usize) {}
/// Called when a validation epoch is finished.
fn on_valid_end_epoch(&mut self, _epoch: usize) {}
}
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The learning rate.
pub lr: Option<LearningRate>,
}

View File

@@ -1,5 +0,0 @@
mod async_callback;
mod base;
pub use async_callback::*;
pub use base::*;

View File

@@ -0,0 +1,118 @@
use super::EventCollector;
use crate::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};
enum Message<T, V> {
OnEventTrain(Event<T>),
OnEventValid(Event<V>),
End,
FindEpoch(
String,
Aggregate,
Direction,
Split,
mpsc::SyncSender<Option<usize>>,
),
}
/// Async [event collector](EventCollector).
///
/// This creates a worker thread where all the computation is done, ensuring that the
/// training loop is never blocked by metric calculation.
pub struct AsyncEventCollector<T, V> {
sender: mpsc::Sender<Message<T, V>>,
handler: Option<JoinHandle<()>>,
}
#[derive(new)]
struct WorkerThread<C, T, V> {
collector: C,
receiver: mpsc::Receiver<Message<T, V>>,
}
impl<C, T, V> WorkerThread<C, T, V>
where
C: EventCollector<ItemTrain = T, ItemValid = V>,
{
fn run(mut self) {
for item in self.receiver.iter() {
match item {
Message::End => {
return;
}
Message::FindEpoch(name, aggregate, direction, split, sender) => {
let response = self
.collector
.find_epoch(&name, aggregate, direction, split);
sender.send(response).unwrap();
}
Message::OnEventTrain(event) => self.collector.on_event_train(event),
Message::OnEventValid(event) => self.collector.on_event_valid(event),
}
}
}
}
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
/// Create a new async [event collector](EventCollector).
pub fn new<C>(collector: C) -> Self
where
C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
{
let (sender, receiver) = mpsc::channel();
let thread = WorkerThread::new(collector, receiver);
let handler = std::thread::spawn(move || thread.run());
let handler = Some(handler);
Self { sender, handler }
}
}
impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
self.sender.send(Message::OnEventTrain(event)).unwrap();
}
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
self.sender.send(Message::OnEventValid(event)).unwrap();
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::FindEpoch(
name.to_string(),
aggregate,
direction,
split,
sender,
))
.unwrap();
match receiver.recv() {
Ok(value) => value,
Err(err) => panic!("Async server crashed: {:?}", err),
}
}
}
impl<T, V> Drop for AsyncEventCollector<T, V> {
fn drop(&mut self) {
self.sender.send(Message::End).unwrap();
let handler = self.handler.take();
if let Some(handler) = handler {
handler.join().unwrap();
}
}
}
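The `FindEpoch` message doubles as a blocking request/response call: the caller sends a one-shot `SyncSender` along with the query and parks on `recv()` until the worker thread replies. A standalone sketch of that handshake pattern (std-only; `Request` and `Query` are illustrative names, not part of the crate):

    use std::{sync::mpsc, thread};

    enum Request {
        // Carries a reply channel, like `Message::FindEpoch` above.
        Query(String, mpsc::SyncSender<Option<usize>>),
        End,
    }

    fn main() {
        let (sender, receiver) = mpsc::channel::<Request>();
        let worker = thread::spawn(move || {
            for msg in receiver.iter() {
                match msg {
                    Request::Query(name, reply) => {
                        // Stand-in for `collector.find_epoch(..)`.
                        let answer = if name == "Loss" { Some(2) } else { None };
                        reply.send(answer).unwrap();
                    }
                    Request::End => return,
                }
            }
        });

        // Caller side: send the request, then block until the worker answers.
        let (reply_sender, reply_receiver) = mpsc::sync_channel(1);
        sender
            .send(Request::Query("Loss".into(), reply_sender))
            .unwrap();
        assert_eq!(reply_receiver.recv().unwrap(), Some(2));

        sender.send(Request::End).unwrap();
        worker.join().unwrap();
    }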

View File

@@ -0,0 +1,78 @@
use burn_core::{data::dataloader::Progress, LearningRate};
/// Event happening during the training/validation process.
pub enum Event<T> {
/// Signal that an item has been processed.
ProcessedItem(LearnerItem<T>),
/// Signal the end of an epoch.
EndEpoch(usize),
}
/// Defines how training and validation events are collected.
///
/// This trait also exposes methods that use the collected data to compute useful information.
pub trait EventCollector: Send {
/// Training item.
type ItemTrain;
/// Validation item.
type ItemValid;
/// Collect the training event.
fn on_event_train(&mut self, event: Event<Self::ItemTrain>);
/// Collect the validation event.
fn on_event_valid(&mut self, event: Event<Self::ItemValid>);
/// Find the epoch matching the given criteria from the collected data.
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize>;
}
/// How to aggregate the metric.
pub enum Aggregate {
/// Compute the average.
Mean,
}
/// The split to use.
pub enum Split {
/// The training split.
Train,
/// The validation split.
Valid,
}
/// The direction of the query.
pub enum Direction {
/// Lower is better.
Lowest,
/// Higher is better.
Highest,
}
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The learning rate.
pub lr: Option<LearningRate>,
}
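As a quick illustration of the trait, a minimal collector that only counts processed items per split might look like the sketch below. It assumes these types are re-exported at the crate root (as the `pub use collector::*;` later in this commit suggests); `CountingCollector` itself is hypothetical:

    use burn_train::{Aggregate, Direction, Event, EventCollector, Split};

    struct CountingCollector {
        train_items: usize,
        valid_items: usize,
    }

    impl EventCollector for CountingCollector {
        type ItemTrain = ();
        type ItemValid = ();

        fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
            if let Event::ProcessedItem(_) = event {
                self.train_items += 1;
            }
        }

        fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
            if let Event::ProcessedItem(_) = event {
                self.valid_items += 1;
            }
        }

        fn find_epoch(
            &mut self,
            _name: &str,
            _aggregate: Aggregate,
            _direction: Direction,
            _split: Split,
        ) -> Option<usize> {
            None // This toy collector keeps no per-epoch aggregates.
        }
    }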

View File

@@ -0,0 +1,131 @@
use crate::{
info::MetricsInfo,
metric::MetricMetadata,
renderer::{MetricState, MetricsRenderer, TrainingProgress},
Aggregate, Direction, Event, EventCollector, LearnerItem, Split,
};
/// Collect training events in order to display metrics with a metrics renderer.
#[derive(new)]
pub(crate) struct RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
renderer: Box<dyn MetricsRenderer>,
info: MetricsInfo<T, V>,
}
impl<T, V> EventCollector for RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
type ItemTrain = T;
type ItemValid = V;
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
match event {
Event::ProcessedItem(item) => self.on_train_item(item),
Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch),
}
}
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
match event {
Event::ProcessedItem(item) => self.on_valid_item(item),
Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch),
}
}
fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
self.info.find_epoch(name, aggregate, direction, split)
}
}
impl<T, V> RenderedMetricsEventCollector<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
fn on_train_item(&mut self, item: LearnerItem<T>) {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.info.update_train(&item, &metadata);
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_train(MetricState::Numeric(entry, value))
});
self.renderer.render_train(progress);
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let progress = (&item).into();
let metadata = (&item).into();
let update = self.info.update_valid(&item, &metadata);
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|(entry, value)| {
self.renderer
.update_valid(MetricState::Numeric(entry, value))
});
self.renderer.render_valid(progress);
}
fn on_train_end_epoch(&mut self, epoch: usize) {
self.info.end_epoch_train(epoch);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
self.info.end_epoch_valid(epoch);
}
}
impl<T> From<&LearnerItem<T>> for TrainingProgress {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
}
}
}
impl<T> From<&LearnerItem<T>> for MetricMetadata {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
lr: item.lr,
}
}
}

View File

@@ -0,0 +1,3 @@
mod base;
pub(crate) use base::*;

View File

@@ -0,0 +1,8 @@
mod async_collector;
mod base;
pub use async_collector::*;
pub use base::*;
/// Metrics collector module.
pub mod metrics;

View File

@@ -1,4 +1,4 @@
-use crate::{checkpoint::Checkpointer, LearnerCallback};
+use crate::{checkpoint::Checkpointer, EventCollector};
 use burn_core::{
     lr_scheduler::LrScheduler,
     module::{ADModule, Module},
@@ -25,8 +25,8 @@ pub trait LearnerComponents {
     >;
     /// The checkpointer used for the scheduler.
     type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
-    /// Callback used for training tracking.
-    type Callback: LearnerCallback + 'static;
+    /// Training event collector used for training tracking.
+    type EventCollector: EventCollector + 'static;
 }

 /// Concrete type that implements [training components trait](TrainingComponents).
@@ -41,8 +41,8 @@ pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C> {
     _callback: PhantomData<C>,
 }

-impl<B, LR, M, O, CM, CO, CS, C> LearnerComponents
-    for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C>
+impl<B, LR, M, O, CM, CO, CS, EC> LearnerComponents
+    for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC>
 where
     B: ADBackend,
     LR: LrScheduler,
@@ -51,7 +51,7 @@ where
     CM: Checkpointer<M::Record>,
     CO: Checkpointer<O::Record>,
     CS: Checkpointer<LR::Record>,
-    C: LearnerCallback + 'static,
+    EC: EventCollector + 'static,
 {
     type Backend = B;
     type LrScheduler = LR;
@@ -60,5 +60,5 @@ where
     type CheckpointerModel = CM;
     type CheckpointerOptimizer = CO;
     type CheckpointerLrScheduler = CS;
-    type Callback = C;
+    type EventCollector = EC;
 }

View File

@@ -0,0 +1,163 @@
use crate::{logger::MetricLogger, Aggregate, Direction};
use std::collections::HashMap;
/// Type that can be used to fetch and use numeric metric aggregates.
#[derive(Default, Debug)]
pub(crate) struct NumericMetricsAggregate {
mean_for_each_epoch: HashMap<Key, f64>,
}
#[derive(new, Hash, PartialEq, Eq, Debug)]
struct Key {
name: String,
epoch: usize,
}
impl NumericMetricsAggregate {
pub(crate) fn mean(
&mut self,
name: &str,
epoch: usize,
loggers: &mut [Box<dyn MetricLogger>],
) -> Option<f64> {
let key = Key::new(name.to_string(), epoch);
if let Some(value) = self.mean_for_each_epoch.get(&key) {
return Some(*value);
}
let points = || {
let mut errors = Vec::new();
for logger in loggers {
match logger.read_numeric(name, epoch) {
Ok(points) => return Ok(points),
Err(err) => errors.push(err),
};
}
Err(errors.join(" "))
};
let points = points().expect("Can read values");
if points.is_empty() {
return None;
}
let num_points = points.len();
let mean = points.into_iter().sum::<f64>() / num_points as f64;
self.mean_for_each_epoch.insert(key, mean);
Some(mean)
}
pub(crate) fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
loggers: &mut [Box<dyn MetricLogger>],
) -> Option<usize> {
let mut data = Vec::new();
let mut current_epoch = 1;
loop {
match aggregate {
Aggregate::Mean => match self.mean(name, current_epoch, loggers) {
Some(value) => {
data.push(value);
}
None => break,
},
};
current_epoch += 1;
}
if data.is_empty() {
return None;
}
let mut current_value = match &direction {
Direction::Lowest => f64::MAX,
Direction::Highest => f64::MIN,
};
for (i, value) in data.into_iter().enumerate() {
match &direction {
Direction::Lowest => {
if value < current_value {
current_value = value;
current_epoch = i + 1;
}
}
Direction::Highest => {
if value > current_value {
current_value = value;
current_epoch = i + 1;
}
}
}
}
Some(current_epoch)
}
}
#[cfg(test)]
mod tests {
use crate::{logger::FileMetricLogger, metric::MetricEntry};
use super::*;
struct TestLogger {
logger: FileMetricLogger,
epoch: usize,
}
const NAME: &str = "test-logger";
impl TestLogger {
fn new() -> Self {
Self {
logger: FileMetricLogger::new("/tmp"),
epoch: 1,
}
}
fn log(&mut self, num: f64) {
self.logger.log(&MetricEntry::new(
NAME.into(),
num.to_string(),
num.to_string(),
));
}
fn new_epoch(&mut self) {
self.epoch += 1;
self.logger.epoch(self.epoch);
}
}
#[test]
fn should_find_epoch() {
let mut logger = TestLogger::new();
let mut aggregate = NumericMetricsAggregate::default();
logger.log(500.); // Epoch 1
logger.log(1000.); // Epoch 1
logger.new_epoch();
logger.log(200.); // Epoch 2
logger.log(1000.); // Epoch 2
logger.new_epoch();
logger.log(10000.); // Epoch 3
let value = aggregate
.find_epoch(
NAME,
Aggregate::Mean,
Direction::Lowest,
&mut [Box::new(logger.logger)],
)
.unwrap();
assert_eq!(value, 2);
}
}
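The test above shows the intended query shape: per-epoch means are compared and the best epoch index (1-based) is returned. At the `EventCollector` level, the same query could back checkpoint selection, for example (a hedged sketch; the helper is hypothetical and the metric name must match what the logger recorded):

    use burn_train::{Aggregate, Direction, EventCollector, Split};

    /// Hypothetical helper: epoch with the lowest mean validation loss, if any.
    fn best_checkpoint_epoch<C: EventCollector>(collector: &mut C) -> Option<usize> {
        collector.find_epoch("Loss", Aggregate::Mean, Direction::Lowest, Split::Valid)
    }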

View File

@@ -0,0 +1,253 @@
use super::NumericMetricsAggregate;
use crate::{
logger::MetricLogger,
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
Aggregate, Direction, LearnerItem, Split,
};
/// Metrics information collected during training.
pub struct MetricsInfo<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
train: Vec<Box<dyn MetricUpdater<T>>>,
valid: Vec<Box<dyn MetricUpdater<V>>>,
train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
loggers_train: Vec<Box<dyn MetricLogger>>,
loggers_valid: Vec<Box<dyn MetricLogger>>,
aggregate_train: NumericMetricsAggregate,
aggregate_valid: NumericMetricsAggregate,
}
#[derive(new)]
pub(crate) struct MetricsUpdate {
pub(crate) entries: Vec<MetricEntry>,
pub(crate) entries_numeric: Vec<(MetricEntry, f64)>,
}
impl<T, V> MetricsInfo<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub(crate) fn new() -> Self {
Self {
train: vec![],
valid: vec![],
train_numeric: vec![],
valid_numeric: vec![],
loggers_train: vec![],
loggers_valid: vec![],
aggregate_train: NumericMetricsAggregate::default(),
aggregate_valid: NumericMetricsAggregate::default(),
}
}
/// Signal the end of a training epoch.
pub(crate) fn end_epoch_train(&mut self, epoch: usize) {
for metric in self.train.iter_mut() {
metric.clear();
}
for metric in self.train_numeric.iter_mut() {
metric.clear();
}
for logger in self.loggers_train.iter_mut() {
logger.epoch(epoch + 1);
}
}
/// Signal the end of a validation epoch.
pub(crate) fn end_epoch_valid(&mut self, epoch: usize) {
for metric in self.valid.iter_mut() {
metric.clear();
}
for metric in self.valid_numeric.iter_mut() {
metric.clear();
}
for logger in self.loggers_valid.iter_mut() {
logger.epoch(epoch + 1);
}
}
/// Update the training information from the training item.
pub(crate) fn update_train(
&mut self,
item: &LearnerItem<T>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.train.len());
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
for metric in self.train.iter_mut() {
let state = metric.update(item, metadata);
for logger in self.loggers_train.iter_mut() {
logger.log(&state);
}
entries.push(state);
}
for metric in self.train_numeric.iter_mut() {
let (state, value) = metric.update(item, metadata);
for logger in self.loggers_train.iter_mut() {
logger.log(&state);
}
entries_numeric.push((state, value));
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Update the validation information from the validation item.
pub(crate) fn update_valid(
&mut self,
item: &LearnerItem<V>,
metadata: &MetricMetadata,
) -> MetricsUpdate {
let mut entries = Vec::with_capacity(self.valid.len());
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
for metric in self.valid.iter_mut() {
let state = metric.update(item, metadata);
for logger in self.loggers_valid.iter_mut() {
logger.log(&state);
}
entries.push(state);
}
for metric in self.valid_numeric.iter_mut() {
let (state, value) = metric.update(item, metadata);
for logger in self.loggers_valid.iter_mut() {
logger.log(&state);
}
entries_numeric.push((state, value));
}
MetricsUpdate::new(entries, entries_numeric)
}
/// Find the epoch corresponding to the given criteria.
pub(crate) fn find_epoch(
&mut self,
name: &str,
aggregate: Aggregate,
direction: Direction,
split: Split,
) -> Option<usize> {
match split {
Split::Train => {
self.aggregate_train
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
}
Split::Valid => {
self.aggregate_valid
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
}
}
}
/// Register a logger for training metrics.
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_train.push(Box::new(logger));
}
/// Register a logger for validation metrics.
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
self.loggers_valid.push(Box::new(logger));
}
/// Register a training metric.
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
where
T: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.train.push(Box::new(metric))
}
/// Register a validation metric.
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
where
V: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.valid.push(Box::new(metric))
}
/// Register a numeric training metric.
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
T: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.train_numeric.push(Box::new(metric))
}
/// Register a numeric validation metric.
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
&mut self,
metric: Me,
) where
V: Adaptor<Me::Input>,
{
let metric = MetricWrapper::new(metric);
self.valid_numeric.push(Box::new(metric))
}
}
trait NumericMetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
fn clear(&mut self);
}
trait MetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
fn clear(&mut self);
}
#[derive(new)]
struct MetricWrapper<M> {
metric: M,
}
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + Numeric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
let update = self.metric.update(&item.item.adapt(), metadata);
let numeric = self.metric.value();
(update, numeric)
}
fn clear(&mut self) {
self.metric.clear()
}
}
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
self.metric.update(&item.item.adapt(), metadata)
}
fn clear(&mut self) {
self.metric.clear()
}
}

View File

@@ -0,0 +1,5 @@
mod aggregates;
mod metrics;
pub(crate) use aggregates::*;
pub use metrics::*;

View File

@@ -19,7 +19,7 @@ pub struct Learner<LC: LearnerComponents> {
     pub(crate) grad_accumulation: Option<usize>,
     pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
     pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
-    pub(crate) callback: LC::Callback,
+    pub(crate) collector: LC::EventCollector,
     pub(crate) interrupter: TrainingInterrupter,
 }

View File

@@ -1,14 +1,14 @@
 use super::log::install_file_logger;
 use super::Learner;
 use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
+use crate::collector::metrics::RenderedMetricsEventCollector;
 use crate::components::LearnerComponentsMarker;
+use crate::info::MetricsInfo;
 use crate::learner::base::TrainingInterrupter;
 use crate::logger::{FileMetricLogger, MetricLogger};
-use crate::metric::callback::{
-    default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer,
-};
 use crate::metric::{Adaptor, Metric};
-use crate::{AsyncTrainerCallback, LearnerCheckpointer};
+use crate::renderer::{default_renderer, MetricsRenderer};
+use crate::{AsyncEventCollector, LearnerCheckpointer};
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::ADModule;
 use burn_core::optim::Optimizer;
@@ -39,12 +39,11 @@ where
     directory: String,
     grad_accumulation: Option<usize>,
     devices: Vec<B::Device>,
-    metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
-    metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
     renderer: Option<Box<dyn MetricsRenderer + 'static>>,
-    metrics: Metrics<T, V>,
+    info: MetricsInfo<T, V>,
     interrupter: TrainingInterrupter,
     log_to_file: bool,
+    num_loggers: usize,
 }

 impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@@ -69,12 +68,11 @@ where
             directory: directory.to_string(),
             grad_accumulation: None,
             devices: vec![B::Device::default()],
-            metric_logger_train: None,
-            metric_logger_valid: None,
-            metrics: Metrics::new(),
+            info: MetricsInfo::new(),
             renderer: None,
             interrupter: TrainingInterrupter::new(),
             log_to_file: true,
+            num_loggers: 0,
         }
     }
@@ -89,8 +87,9 @@ where
         MT: MetricLogger + 'static,
         MV: MetricLogger + 'static,
     {
-        self.metric_logger_train = Some(Box::new(logger_train));
-        self.metric_logger_valid = Some(Box::new(logger_valid));
+        self.info.register_logger_train(logger_train);
+        self.info.register_logger_valid(logger_valid);
+        self.num_loggers += 1;
         self
     }
@@ -112,9 +111,7 @@ where
     where
         T: Adaptor<Me::Input>,
     {
-        self.metrics
-            .train
-            .push(Box::new(MetricWrapper::new(metric)));
+        self.info.register_metric_train(metric);
         self
     }
@@ -123,9 +120,7 @@ where
     where
         V: Adaptor<Me::Input>,
     {
-        self.metrics
-            .valid
-            .push(Box::new(MetricWrapper::new(metric)));
+        self.info.register_valid_metric(metric);
         self
     }
@@ -144,41 +139,25 @@ where
         self
     }

-    /// Register a training metric and displays it on a plot.
-    ///
-    /// # Notes
-    ///
-    /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
-    /// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
-    /// the same graph will be used for both.
-    pub fn metric_train_plot<Me>(mut self, metric: Me) -> Self
+    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
+    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
     where
         Me: Metric + crate::metric::Numeric + 'static,
         T: Adaptor<Me::Input>,
     {
-        self.metrics
-            .train_numeric
-            .push(Box::new(MetricWrapper::new(metric)));
+        self.info.register_train_metric_numeric(metric);
         self
     }

-    /// Register a validation metric and displays it on a plot.
-    ///
-    /// # Notes
-    ///
-    /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
-    /// If the same metric is also registered for the [training split](Self::metric_train_plot),
-    /// the same graph will be used for both.
-    pub fn metric_valid_plot<Me: Metric + crate::metric::Numeric + 'static>(
+    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
+    pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
         mut self,
         metric: Me,
     ) -> Self
     where
         V: Adaptor<Me::Input>,
     {
-        self.metrics
-            .valid_numeric
-            .push(Box::new(MetricWrapper::new(metric)));
+        self.info.register_valid_metric_numeric(metric);
         self
     }
@@ -260,7 +239,7 @@ where
     #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
                                       // creates a clean learner.
     pub fn build(
-        self,
+        mut self,
         model: M,
         optim: O,
         lr_scheduler: S,
@@ -273,7 +252,7 @@ where
             AsyncCheckpointer<M::Record>,
             AsyncCheckpointer<O::Record>,
             AsyncCheckpointer<S::Record>,
-            AsyncTrainerCallback<T, V>,
+            AsyncEventCollector<T, V>,
         >,
     >
     where
@@ -288,18 +267,18 @@ where
             Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
         });
         let directory = &self.directory;
-        let logger_train = self.metric_logger_train.unwrap_or_else(|| {
-            Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
-        });
-        let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
-            Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
-        });
-        let callback = AsyncTrainerCallback::new(MetricsCallback::new(
-            renderer,
-            self.metrics,
-            logger_train,
-            logger_valid,
-        ));
+
+        if self.num_loggers == 0 {
+            self.info.register_logger_train(FileMetricLogger::new(
+                format!("{directory}/train").as_str(),
+            ));
+            self.info.register_logger_valid(FileMetricLogger::new(
+                format!("{directory}/valid").as_str(),
+            ));
+        }
+
+        let collector =
+            AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info));

         let checkpointer = self
             .checkpointers
@@ -311,7 +290,7 @@ where
             lr_scheduler,
             checkpointer,
             num_epochs: self.num_epochs,
-            callback,
+            collector,
             checkpoint: self.checkpoint,
             grad_accumulation: self.grad_accumulation,
             devices: self.devices,

View File

@@ -7,8 +7,8 @@ use burn_core::{
 };
 use std::sync::Arc;

-use crate::learner::base::TrainingInterrupter;
-use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
+use crate::{learner::base::TrainingInterrupter, Event};
+use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};

 /// A validation epoch.
 #[derive(new)]
@@ -37,7 +37,7 @@ impl<VI> ValidEpoch<VI> {
     pub fn run<B, M, TO, VO>(
         &self,
         model: &M,
-        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
+        callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
         interrupter: &TrainingInterrupter,
     ) where
         B: ADBackend,
@@ -64,13 +64,14 @@ impl<VI> ValidEpoch<VI> {
                 None,
             );

-            callback.on_valid_item(item);
+            callback.on_event_valid(Event::ProcessedItem(item));
+
             if interrupter.should_stop() {
                 log::info!("Training interrupted.");
                 break;
             }
         }
-        callback.on_valid_end_epoch(self.epoch);
+        callback.on_event_valid(Event::EndEpoch(self.epoch));
     }
 }
@@ -92,7 +93,7 @@ impl<TI> TrainEpoch<TI> {
         mut model: M,
         mut optim: O,
         scheduler: &mut LR,
-        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
+        callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
         interrupter: &TrainingInterrupter,
     ) -> (M, O)
     where
@@ -139,13 +140,13 @@ impl<TI> TrainEpoch<TI> {
                 Some(lr),
             );

-            callback.on_train_item(item);
+            callback.on_event_train(Event::ProcessedItem(item));
             if interrupter.should_stop() {
                 log::info!("Training interrupted.");
                 break;
             }
         }
-        callback.on_train_end_epoch(self.epoch);
+        callback.on_event_train(Event::EndEpoch(self.epoch));
         (model, optim)
     }
@@ -170,7 +171,7 @@ impl<TI> TrainEpoch<TI> {
         mut model: M,
         mut optim: O,
         lr_scheduler: &mut S,
-        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
+        callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
         devices: Vec<B::Device>,
         interrupter: &TrainingInterrupter,
     ) -> (M, O)
@@ -232,7 +233,8 @@ impl<TI> TrainEpoch<TI> {
                     Some(lr),
                 );

-                callback.on_train_item(item);
+                callback.on_event_train(Event::ProcessedItem(item));
+
                 if interrupter.should_stop() {
                     log::info!("Training interrupted.");
                     interrupted = true;
@@ -245,7 +247,7 @@ impl<TI> TrainEpoch<TI> {
             }
         }

-        callback.on_train_end_epoch(self.epoch);
+        callback.on_event_train(Event::EndEpoch(self.epoch));
         (model, optim)
     }

View File

@@ -1,5 +1,5 @@
 use crate::components::LearnerComponents;
-use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch};
+use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch};
 use burn_core::data::dataloader::DataLoader;
 use burn_core::module::{ADModule, Module};
 use burn_core::optim::{GradientsParams, Optimizer};
@@ -115,7 +115,7 @@ impl<LC: LearnerComponents> Learner<LC> {
         OutputValid: Send,
         LC::Model: TrainStep<InputTrain, OutputTrain>,
         <LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
-        LC::Callback: LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
+        LC::EventCollector: EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
     {
         log::info!("Fitting {}", self.model.to_string());
         // The reference model is always on the first device provided.
@@ -139,8 +139,8 @@ impl<LC: LearnerComponents> Learner<LC> {
         };

         let mut callback: Box<
-            dyn LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
-        > = Box::new(self.callback);
+            dyn EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
+        > = Box::new(self.collector);

         for epoch in starting_epoch..self.num_epochs + 1 {
             let epoch_train = TrainEpoch::new(

View File

@@ -10,16 +10,22 @@ pub mod checkpoint;
 pub(crate) mod components;

+/// Renderer modules to display metrics and training information.
+pub mod renderer;
+
 /// The logger module.
 pub mod logger;

 /// The metric module.
 pub mod metric;

-mod callback;
+/// All information collected during training.
+pub mod info;
+
+mod collector;
 mod learner;

-pub use callback::*;
+pub use collector::*;
 pub use learner::*;

 #[cfg(test)]

View File

@@ -4,6 +4,7 @@ use std::sync::mpsc;
 enum Message<T> {
     Log(T),
     End,
+    Sync(mpsc::Sender<()>),
 }

 /// Async logger.
 pub struct AsyncLogger<T> {
@@ -30,6 +31,9 @@ where
                 Message::End => {
                     return;
                 }
+                Message::Sync(callback) => {
+                    callback.send(()).unwrap();
+                }
             }
         }
     }
@@ -48,6 +52,17 @@ impl<T: Send + Sync + 'static> AsyncLogger<T> {
         Self { sender, handler }
     }

+    /// Sync the async logger.
+    pub(crate) fn sync(&self) {
+        let (sender, receiver) = mpsc::channel();
+
+        self.sender.send(Message::Sync(sender)).unwrap();
+        receiver
+            .recv()
+            .expect("Should sync, otherwise the thread is dead.");
+    }
 }

 impl<T: Send> Logger<T> for AsyncLogger<T> {
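The new `Sync` message is effectively a blocking flush: because the channel preserves send order, everything queued before the `Sync` is processed before the reply arrives. A standalone sketch of the idea (std-only; `Msg` and the worker are illustrative names, not the crate's types):

    use std::{sync::mpsc, thread};

    enum Msg {
        Log(String),
        // The reply channel lets the caller wait for all prior messages.
        Sync(mpsc::Sender<()>),
    }

    fn main() {
        let (sender, receiver) = mpsc::channel::<Msg>();
        let worker = thread::spawn(move || {
            for msg in receiver.iter() {
                match msg {
                    Msg::Log(line) => println!("logged: {line}"),
                    Msg::Sync(done) => done.send(()).unwrap(),
                }
            }
        });

        sender.send(Msg::Log("epoch 1".into())).unwrap();

        // Blocking flush: the log above is guaranteed to be handled first.
        let (done_sender, done_receiver) = mpsc::channel();
        sender.send(Msg::Sync(done_sender)).unwrap();
        done_receiver.recv().unwrap();

        drop(sender); // Closing the channel ends `receiver.iter()`.
        worker.join().unwrap();
    }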

View File

@@ -49,11 +49,14 @@ impl FileMetricLogger {
     fn file_path(&self, name: &str, epoch: usize) -> String {
         let directory = format!("{}/epoch-{}", self.directory, epoch);
-        std::fs::create_dir_all(&directory).ok();
         let name = name.replace(' ', "_");

         format!("{directory}/{name}.log")
     }
+
+    fn create_directory(&self, epoch: usize) {
+        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        std::fs::create_dir_all(directory).ok();
+    }
 }

 impl MetricLogger for FileMetricLogger {
@@ -64,6 +67,8 @@ impl MetricLogger for FileMetricLogger {
         let logger = match self.loggers.get_mut(key) {
             Some(val) => val,
             None => {
+                self.create_directory(self.epoch);
+
                 let file_path = self.file_path(key, self.epoch);
                 let logger = FileLogger::new(&file_path);
                 let logger = AsyncLogger::new(logger);
@@ -82,6 +87,10 @@ impl MetricLogger for FileMetricLogger {
     }

     fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+        if let Some(value) = self.loggers.get(name) {
+            value.sync()
+        }
+
         let file_path = self.file_path(name, epoch);

         let mut errors = false;

View File

@@ -1,285 +0,0 @@
use crate::{
logger::MetricLogger,
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
LearnerCallback, LearnerItem,
};
use burn_core::data::dataloader::Progress;
/// Holds all metrics, metric loggers, and a metrics renderer.
pub struct MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
renderer: Box<dyn MetricsRenderer>,
}
impl<T, V> MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Creates a new metrics callback.
///
/// # Arguments
///
/// * `renderer` - The metrics renderer.
/// * `metrics` - The metrics holder.
/// * `logger_train` - The training logger.
/// * `logger_valid` - The validation logger.
///
/// # Returns
///
/// A new metrics callback.
pub(crate) fn new(
renderer: Box<dyn MetricsRenderer>,
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
) -> Self {
Self {
metrics,
logger_train,
logger_valid,
renderer,
}
}
}
impl<T, V> LearnerCallback for MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
type ItemTrain = T;
type ItemValid = V;
fn on_train_item(&mut self, item: LearnerItem<T>) {
let metadata = (&item).into();
for metric in self.metrics.train.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer.update_train(MetricState::Generic(state));
}
for metric in self.metrics.train_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer
.update_train(MetricState::Numeric(state, value));
}
self.renderer.render_train(item.into());
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let metadata = (&item).into();
for metric in self.metrics.valid.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer.update_valid(MetricState::Generic(state));
}
for metric in self.metrics.valid_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer
.update_valid(MetricState::Numeric(state, value));
}
self.renderer.render_valid(item.into());
}
fn on_train_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.train.iter_mut() {
metric.clear();
}
for metric in self.metrics.train_numeric.iter_mut() {
metric.clear();
}
self.logger_train.epoch(epoch + 1);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.valid.iter_mut() {
metric.clear();
}
for metric in self.metrics.valid_numeric.iter_mut() {
metric.clear();
}
self.logger_valid.epoch(epoch + 1);
}
}
/// Training progress.
#[derive(Debug)]
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
}
impl TrainingProgress {
/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {
items_processed: 0,
items_total: 0,
},
epoch: 0,
epoch_total: 0,
iteration: 0,
}
}
}
/// The state of a metric.
#[derive(Debug)]
pub enum MetricState {
/// A generic metric.
Generic(MetricEntry),
/// A numeric metric.
Numeric(MetricEntry, f64),
}
/// Trait for rendering metrics.
pub trait MetricsRenderer: Send + Sync {
/// Updates the training metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_train(&mut self, state: MetricState);
/// Updates the validation metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_valid(&mut self, state: MetricState);
/// Renders the training progress.
///
/// # Arguments
///
/// * `item` - The training progress.
fn render_train(&mut self, item: TrainingProgress);
/// Renders the validation progress.
///
/// # Arguments
///
/// * `item` - The validation progress.
fn render_valid(&mut self, item: TrainingProgress);
}
/// A container for the metrics held by a metrics callback.
pub(crate) struct Metrics<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub(crate) train: Vec<Box<dyn MetricUpdater<T>>>,
pub(crate) valid: Vec<Box<dyn MetricUpdater<V>>>,
pub(crate) train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
pub(crate) valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
}
impl<T, V> Metrics<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
train: vec![],
valid: vec![],
train_numeric: vec![],
valid_numeric: vec![],
}
}
}
impl<T> From<LearnerItem<T>> for TrainingProgress {
fn from(item: LearnerItem<T>) -> Self {
Self {
progress: item.progress,
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
}
}
}
impl<T> From<&LearnerItem<T>> for MetricMetadata {
fn from(item: &LearnerItem<T>) -> Self {
Self {
progress: item.progress.clone(),
epoch: item.epoch,
epoch_total: item.epoch_total,
iteration: item.iteration,
lr: item.lr,
}
}
}
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
fn clear(&mut self);
}
pub(crate) trait MetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
fn clear(&mut self);
}
#[derive(new)]
pub(crate) struct MetricWrapper<M> {
metric: M,
}
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + Numeric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
let update = self.metric.update(&item.item.adapt(), metadata);
let numeric = self.metric.value();
(update, numeric)
}
fn clear(&mut self) {
self.metric.clear()
}
}
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + 'static,
T: Adaptor<M::Input>,
{
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
self.metric.update(&item.item.adapt(), metadata)
}
fn clear(&mut self) {
self.metric.clear()
}
}

View File

@@ -1,7 +1,4 @@
-/// Callback module for training progress.
-pub mod callback;
-
-/// State module for callback metrics.
+/// State module.
 pub mod state;

 mod acc;

View File

@@ -1,4 +1,4 @@
-use super::{format_float, MetricEntry, Numeric};
+use crate::metric::{format_float, MetricEntry, Numeric};

 /// Useful utility to implement numeric metrics.
 ///

View File

@@ -0,0 +1,75 @@
use burn_core::data::dataloader::Progress;
use crate::metric::MetricEntry;
/// Trait for rendering metrics.
pub trait MetricsRenderer: Send + Sync {
/// Updates the training metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_train(&mut self, state: MetricState);
/// Updates the validation metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_valid(&mut self, state: MetricState);
/// Renders the training progress.
///
/// # Arguments
///
/// * `item` - The training progress.
fn render_train(&mut self, item: TrainingProgress);
/// Renders the validation progress.
///
/// # Arguments
///
/// * `item` - The validation progress.
fn render_valid(&mut self, item: TrainingProgress);
}
/// The state of a metric.
#[derive(Debug)]
pub enum MetricState {
/// A generic metric.
Generic(MetricEntry),
/// A numeric metric.
Numeric(MetricEntry, f64),
}
/// Training progress.
#[derive(Debug)]
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
}
impl TrainingProgress {
/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {
items_processed: 0,
items_total: 0,
},
epoch: 0,
epoch_total: 0,
iteration: 0,
}
}
}
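With `MetricsRenderer` now public under the new `renderer` module (the custom-renderer example below imports it from `burn::train::renderer`), a custom renderer only has to implement these four methods. A minimal console sketch, assuming those re-exported paths (`PrintRenderer` is hypothetical):

    use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};

    struct PrintRenderer;

    impl MetricsRenderer for PrintRenderer {
        fn update_train(&mut self, state: MetricState) {
            println!("train metric: {state:?}");
        }

        fn update_valid(&mut self, state: MetricState) {
            println!("valid metric: {state:?}");
        }

        fn render_train(&mut self, item: TrainingProgress) {
            println!("train progress: {item:?}");
        }

        fn render_valid(&mut self, item: TrainingProgress) {
            println!("valid progress: {item:?}");
        }
    }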

View File

@@ -1,5 +1,4 @@
 mod base;
 pub use base::*;

 #[cfg(not(feature = "tui"))]

View File

@@ -1,4 +1,4 @@
-use crate::metric::callback::TrainingProgress;
+use crate::renderer::TrainingProgress;

 use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
 use crossterm::event::{Event, KeyCode};

View File

@@ -1,5 +1,6 @@
+use crate::renderer::TrainingProgress;
+
 use super::TerminalFrame;
-use crate::metric::callback::TrainingProgress;
 use ratatui::{
     prelude::{Alignment, Constraint, Direction, Layout, Rect},
     style::{Color, Style, Stylize},

View File

@@ -1,5 +1,5 @@
-use crate::metric::callback::tui::NumericMetricsState;
-use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
+use crate::renderer::{tui::NumericMetricsState, MetricsRenderer};
+use crate::renderer::{MetricState, TrainingProgress};
 use crate::TrainingInterrupter;
 use crossterm::{
     event::{self, Event, KeyCode},

View File

@@ -1,5 +1,5 @@
 use super::TerminalFrame;
-use crate::metric::callback::TrainingProgress;
+use crate::renderer::TrainingProgress;
 use ratatui::{
     prelude::{Alignment, Rect},
     style::{Color, Style, Stylize},

View File

@@ -1,5 +1,5 @@
 use burn::data::dataset::source::huggingface::MNISTDataset;
-use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
+use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
 use burn::train::LearnerBuilder;
 use burn::{
     config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,

View File

@@ -88,10 +88,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
         .build(MNISTDataset::test());

     let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_plot(AccuracyMetric::new())
-        .metric_valid_plot(AccuracyMetric::new())
-        .metric_train_plot(LossMetric::new())
-        .metric_valid_plot(LossMetric::new())
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
         .with_file_checkpointer(1, CompactRecorder::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)

View File

@@ -58,16 +58,16 @@ pub fn run<B: ADBackend>(device: B::Device) {
     // Model
     let learner = LearnerBuilder::new(ARTIFACT_DIR)
-        .metric_train_plot(AccuracyMetric::new())
-        .metric_valid_plot(AccuracyMetric::new())
-        .metric_train_plot(CpuUse::new())
-        .metric_valid_plot(CpuUse::new())
-        .metric_train_plot(CpuMemory::new())
-        .metric_valid_plot(CpuMemory::new())
-        .metric_train_plot(CpuTemperature::new())
-        .metric_valid_plot(CpuTemperature::new())
-        .metric_train_plot(LossMetric::new())
-        .metric_valid_plot(LossMetric::new())
+        .metric_train_numeric(AccuracyMetric::new())
+        .metric_valid_numeric(AccuracyMetric::new())
+        .metric_train_numeric(CpuUse::new())
+        .metric_valid_numeric(CpuUse::new())
+        .metric_train_numeric(CpuMemory::new())
+        .metric_valid_numeric(CpuMemory::new())
+        .metric_train_numeric(CpuTemperature::new())
+        .metric_valid_numeric(CpuTemperature::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
         .with_file_checkpointer(1, CompactRecorder::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)

View File

@@ -95,9 +95,9 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
         .metric_valid(CUDAMetric::new())
         .metric_train(AccuracyMetric::new())
         .metric_valid(AccuracyMetric::new())
-        .metric_train_plot(LossMetric::new())
-        .metric_valid_plot(LossMetric::new())
-        .metric_train_plot(LearningRateMetric::new())
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
+        .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(2, CompactRecorder::new())
         .devices(vec![device])
         .num_epochs(config.num_epochs)

View File

@@ -70,11 +70,11 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     let learner = LearnerBuilder::new(artifact_dir)
         .metric_train(CUDAMetric::new())
         .metric_valid(CUDAMetric::new())
-        .metric_train_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
-        .metric_valid_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
+        .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
+        .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
         .metric_train(LossMetric::new())
         .metric_valid(LossMetric::new())
-        .metric_train_plot(LearningRateMetric::new())
+        .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(2, CompactRecorder::new())
         .devices(vec![device])
         .grads_accumulation(accum)