mirror of https://github.com/tracel-ai/burn.git
Feat training events (#857)
This commit is contained in:
parent
097fd956d0
commit
620b86de98
|
@ -111,10 +111,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
||||||
.build(MNISTDataset::test());
|
.build(MNISTDataset::test());
|
||||||
|
|
||||||
let learner = LearnerBuilder::new(artifact_dir)
|
let learner = LearnerBuilder::new(artifact_dir)
|
||||||
.metric_train_plot(AccuracyMetric::new())
|
.metric_train_numeric(AccuracyMetric::new())
|
||||||
.metric_valid_plot(AccuracyMetric::new())
|
.metric_valid_numeric(AccuracyMetric::new())
|
||||||
.metric_train_plot(LossMetric::new())
|
.metric_train_numeric(LossMetric::new())
|
||||||
.metric_valid_plot(LossMetric::new())
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.with_file_checkpointer(1, CompactRecorder::new())
|
.with_file_checkpointer(1, CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
|
|
|
@ -1,97 +0,0 @@
|
||||||
use super::{LearnerCallback, LearnerItem};
|
|
||||||
use std::{sync::mpsc, thread::JoinHandle};
|
|
||||||
|
|
||||||
enum Message<T, V> {
|
|
||||||
LogTrain(LearnerItem<T>),
|
|
||||||
LogValid(LearnerItem<V>),
|
|
||||||
ClearTrain(usize),
|
|
||||||
ClearValid(usize),
|
|
||||||
End,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Async trainer callback tracker.
|
|
||||||
pub struct AsyncTrainerCallback<T, V> {
|
|
||||||
sender: mpsc::Sender<Message<T, V>>,
|
|
||||||
handler: Option<JoinHandle<()>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(new)]
|
|
||||||
struct CallbackThread<C, T, V> {
|
|
||||||
callback: C,
|
|
||||||
receiver: mpsc::Receiver<Message<T, V>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C, T, V> CallbackThread<C, T, V>
|
|
||||||
where
|
|
||||||
C: LearnerCallback<ItemTrain = T, ItemValid = V>,
|
|
||||||
{
|
|
||||||
fn run(mut self) {
|
|
||||||
for item in self.receiver.iter() {
|
|
||||||
match item {
|
|
||||||
Message::LogTrain(item) => {
|
|
||||||
self.callback.on_train_item(item);
|
|
||||||
}
|
|
||||||
Message::ClearTrain(epoch) => {
|
|
||||||
self.callback.on_train_end_epoch(epoch);
|
|
||||||
}
|
|
||||||
Message::LogValid(item) => {
|
|
||||||
self.callback.on_valid_item(item);
|
|
||||||
}
|
|
||||||
Message::ClearValid(epoch) => {
|
|
||||||
self.callback.on_valid_end_epoch(epoch);
|
|
||||||
}
|
|
||||||
Message::End => {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
|
|
||||||
/// Create a new async trainer callback.
|
|
||||||
pub fn new<C>(callback: C) -> Self
|
|
||||||
where
|
|
||||||
C: LearnerCallback<ItemTrain = T, ItemValid = V> + 'static,
|
|
||||||
{
|
|
||||||
let (sender, receiver) = mpsc::channel();
|
|
||||||
let thread = CallbackThread::new(callback, receiver);
|
|
||||||
|
|
||||||
let handler = std::thread::spawn(move || thread.run());
|
|
||||||
let handler = Some(handler);
|
|
||||||
|
|
||||||
Self { sender, handler }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Send, V: Send> LearnerCallback for AsyncTrainerCallback<T, V> {
|
|
||||||
type ItemTrain = T;
|
|
||||||
type ItemValid = V;
|
|
||||||
|
|
||||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
|
||||||
self.sender.send(Message::LogTrain(item)).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
|
||||||
self.sender.send(Message::LogValid(item)).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
|
||||||
self.sender.send(Message::ClearTrain(epoch)).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
|
||||||
self.sender.send(Message::ClearValid(epoch)).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, V> Drop for AsyncTrainerCallback<T, V> {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.sender.send(Message::End).unwrap();
|
|
||||||
let handler = self.handler.take();
|
|
||||||
|
|
||||||
if let Some(handler) = handler {
|
|
||||||
handler.join().unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
use burn_core::{data::dataloader::Progress, LearningRate};
|
|
||||||
|
|
||||||
/// The base trait for trainer callbacks.
|
|
||||||
pub trait LearnerCallback: Send {
|
|
||||||
/// Training item.
|
|
||||||
type ItemTrain;
|
|
||||||
/// Validation item.
|
|
||||||
type ItemValid;
|
|
||||||
|
|
||||||
/// Called when a training item is logged.
|
|
||||||
fn on_train_item(&mut self, _item: LearnerItem<Self::ItemTrain>) {}
|
|
||||||
|
|
||||||
/// Called when a validation item is logged.
|
|
||||||
fn on_valid_item(&mut self, _item: LearnerItem<Self::ItemValid>) {}
|
|
||||||
|
|
||||||
/// Called when a training epoch is finished.
|
|
||||||
fn on_train_end_epoch(&mut self, _epoch: usize) {}
|
|
||||||
|
|
||||||
/// Called when a validation epoch is finished.
|
|
||||||
fn on_valid_end_epoch(&mut self, _epoch: usize) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A learner item.
|
|
||||||
#[derive(new)]
|
|
||||||
pub struct LearnerItem<T> {
|
|
||||||
/// The item.
|
|
||||||
pub item: T,
|
|
||||||
|
|
||||||
/// The progress.
|
|
||||||
pub progress: Progress,
|
|
||||||
|
|
||||||
/// The epoch.
|
|
||||||
pub epoch: usize,
|
|
||||||
|
|
||||||
/// The total number of epochs.
|
|
||||||
pub epoch_total: usize,
|
|
||||||
|
|
||||||
/// The iteration.
|
|
||||||
pub iteration: usize,
|
|
||||||
|
|
||||||
/// The learning rate.
|
|
||||||
pub lr: Option<LearningRate>,
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
mod async_callback;
|
|
||||||
mod base;
|
|
||||||
|
|
||||||
pub use async_callback::*;
|
|
||||||
pub use base::*;
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
use super::EventCollector;
|
||||||
|
use crate::{Aggregate, Direction, Event, Split};
|
||||||
|
use std::{sync::mpsc, thread::JoinHandle};
|
||||||
|
|
||||||
|
enum Message<T, V> {
|
||||||
|
OnEventTrain(Event<T>),
|
||||||
|
OnEventValid(Event<V>),
|
||||||
|
End,
|
||||||
|
FindEpoch(
|
||||||
|
String,
|
||||||
|
Aggregate,
|
||||||
|
Direction,
|
||||||
|
Split,
|
||||||
|
mpsc::SyncSender<Option<usize>>,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Async [event collector](EventCollector).
|
||||||
|
///
|
||||||
|
/// This will create a worker thread where all the computation is done ensuring that the training loop is
|
||||||
|
/// never blocked by metric calculation.
|
||||||
|
pub struct AsyncEventCollector<T, V> {
|
||||||
|
sender: mpsc::Sender<Message<T, V>>,
|
||||||
|
handler: Option<JoinHandle<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new)]
|
||||||
|
struct WorkerThread<C, T, V> {
|
||||||
|
collector: C,
|
||||||
|
receiver: mpsc::Receiver<Message<T, V>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, T, V> WorkerThread<C, T, V>
|
||||||
|
where
|
||||||
|
C: EventCollector<ItemTrain = T, ItemValid = V>,
|
||||||
|
{
|
||||||
|
fn run(mut self) {
|
||||||
|
for item in self.receiver.iter() {
|
||||||
|
match item {
|
||||||
|
Message::End => {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Message::FindEpoch(name, aggregate, direction, split, sender) => {
|
||||||
|
let response = self
|
||||||
|
.collector
|
||||||
|
.find_epoch(&name, aggregate, direction, split);
|
||||||
|
sender.send(response).unwrap();
|
||||||
|
}
|
||||||
|
Message::OnEventTrain(event) => self.collector.on_event_train(event),
|
||||||
|
Message::OnEventValid(event) => self.collector.on_event_valid(event),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
|
||||||
|
/// Create a new async [event collector](EventCollector).
|
||||||
|
pub fn new<C>(collector: C) -> Self
|
||||||
|
where
|
||||||
|
C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
|
||||||
|
{
|
||||||
|
let (sender, receiver) = mpsc::channel();
|
||||||
|
let thread = WorkerThread::new(collector, receiver);
|
||||||
|
|
||||||
|
let handler = std::thread::spawn(move || thread.run());
|
||||||
|
let handler = Some(handler);
|
||||||
|
|
||||||
|
Self { sender, handler }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
|
||||||
|
type ItemTrain = T;
|
||||||
|
type ItemValid = V;
|
||||||
|
|
||||||
|
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||||
|
self.sender.send(Message::OnEventTrain(event)).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||||
|
self.sender.send(Message::OnEventValid(event)).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_epoch(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
aggregate: Aggregate,
|
||||||
|
direction: Direction,
|
||||||
|
split: Split,
|
||||||
|
) -> Option<usize> {
|
||||||
|
let (sender, receiver) = mpsc::sync_channel(1);
|
||||||
|
self.sender
|
||||||
|
.send(Message::FindEpoch(
|
||||||
|
name.to_string(),
|
||||||
|
aggregate,
|
||||||
|
direction,
|
||||||
|
split,
|
||||||
|
sender,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
match receiver.recv() {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => panic!("Async server crashed: {:?}", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, V> Drop for AsyncEventCollector<T, V> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.sender.send(Message::End).unwrap();
|
||||||
|
let handler = self.handler.take();
|
||||||
|
|
||||||
|
if let Some(handler) = handler {
|
||||||
|
handler.join().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
use burn_core::{data::dataloader::Progress, LearningRate};
|
||||||
|
|
||||||
|
/// Event happening during the training/validation process.
|
||||||
|
pub enum Event<T> {
|
||||||
|
/// Signal that an item have been processed.
|
||||||
|
ProcessedItem(LearnerItem<T>),
|
||||||
|
/// Signal the end of an epoch.
|
||||||
|
EndEpoch(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Defines how training and validation events are collected.
|
||||||
|
///
|
||||||
|
/// This trait also exposes methods that uses the collected data to compute useful information.
|
||||||
|
pub trait EventCollector: Send {
|
||||||
|
/// Training item.
|
||||||
|
type ItemTrain;
|
||||||
|
/// Validation item.
|
||||||
|
type ItemValid;
|
||||||
|
|
||||||
|
/// Collect the training event.
|
||||||
|
fn on_event_train(&mut self, event: Event<Self::ItemTrain>);
|
||||||
|
|
||||||
|
/// Collect the validaion event.
|
||||||
|
fn on_event_valid(&mut self, event: Event<Self::ItemValid>);
|
||||||
|
|
||||||
|
/// Find the epoch following the given criteria from the collected data.
|
||||||
|
fn find_epoch(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
aggregate: Aggregate,
|
||||||
|
direction: Direction,
|
||||||
|
split: Split,
|
||||||
|
) -> Option<usize>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// How to aggregate the metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Aggregate {
    /// Compute the average.
    Mean,
}
|
||||||
|
|
||||||
|
/// The split to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Split {
    /// The training split.
    Train,
    /// The validation split.
    Valid,
}
|
||||||
|
|
||||||
|
/// The direction of the query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Lower is better.
    Lowest,
    /// Higher is better.
    Highest,
}
|
||||||
|
|
||||||
|
/// A learner item.
|
||||||
|
#[derive(new)]
|
||||||
|
pub struct LearnerItem<T> {
|
||||||
|
/// The item.
|
||||||
|
pub item: T,
|
||||||
|
|
||||||
|
/// The progress.
|
||||||
|
pub progress: Progress,
|
||||||
|
|
||||||
|
/// The epoch.
|
||||||
|
pub epoch: usize,
|
||||||
|
|
||||||
|
/// The total number of epochs.
|
||||||
|
pub epoch_total: usize,
|
||||||
|
|
||||||
|
/// The iteration.
|
||||||
|
pub iteration: usize,
|
||||||
|
|
||||||
|
/// The learning rate.
|
||||||
|
pub lr: Option<LearningRate>,
|
||||||
|
}
|
|
@ -0,0 +1,131 @@
|
||||||
|
use crate::{
|
||||||
|
info::MetricsInfo,
|
||||||
|
metric::MetricMetadata,
|
||||||
|
renderer::{MetricState, MetricsRenderer, TrainingProgress},
|
||||||
|
Aggregate, Direction, Event, EventCollector, LearnerItem, Split,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Collect training events in order to display metrics with a metrics renderer.
|
||||||
|
#[derive(new)]
|
||||||
|
pub(crate) struct RenderedMetricsEventCollector<T, V>
|
||||||
|
where
|
||||||
|
T: Send + Sync + 'static,
|
||||||
|
V: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
renderer: Box<dyn MetricsRenderer>,
|
||||||
|
info: MetricsInfo<T, V>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, V> EventCollector for RenderedMetricsEventCollector<T, V>
|
||||||
|
where
|
||||||
|
T: Send + Sync + 'static,
|
||||||
|
V: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type ItemTrain = T;
|
||||||
|
type ItemValid = V;
|
||||||
|
|
||||||
|
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||||
|
match event {
|
||||||
|
Event::ProcessedItem(item) => self.on_train_item(item),
|
||||||
|
Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||||
|
match event {
|
||||||
|
Event::ProcessedItem(item) => self.on_valid_item(item),
|
||||||
|
Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_epoch(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
aggregate: Aggregate,
|
||||||
|
direction: Direction,
|
||||||
|
split: Split,
|
||||||
|
) -> Option<usize> {
|
||||||
|
self.info.find_epoch(name, aggregate, direction, split)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, V> RenderedMetricsEventCollector<T, V>
|
||||||
|
where
|
||||||
|
T: Send + Sync + 'static,
|
||||||
|
V: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||||
|
let progress = (&item).into();
|
||||||
|
let metadata = (&item).into();
|
||||||
|
|
||||||
|
let update = self.info.update_train(&item, &metadata);
|
||||||
|
|
||||||
|
update
|
||||||
|
.entries
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||||
|
|
||||||
|
update
|
||||||
|
.entries_numeric
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|(entry, value)| {
|
||||||
|
self.renderer
|
||||||
|
.update_train(MetricState::Numeric(entry, value))
|
||||||
|
});
|
||||||
|
|
||||||
|
self.renderer.render_train(progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||||
|
let progress = (&item).into();
|
||||||
|
let metadata = (&item).into();
|
||||||
|
|
||||||
|
let update = self.info.update_valid(&item, &metadata);
|
||||||
|
|
||||||
|
update
|
||||||
|
.entries
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||||
|
|
||||||
|
update
|
||||||
|
.entries_numeric
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|(entry, value)| {
|
||||||
|
self.renderer
|
||||||
|
.update_valid(MetricState::Numeric(entry, value))
|
||||||
|
});
|
||||||
|
|
||||||
|
self.renderer.render_train(progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||||
|
self.info.end_epoch_train(epoch);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||||
|
self.info.end_epoch_valid(epoch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<&LearnerItem<T>> for TrainingProgress {
|
||||||
|
fn from(item: &LearnerItem<T>) -> Self {
|
||||||
|
Self {
|
||||||
|
progress: item.progress.clone(),
|
||||||
|
epoch: item.epoch,
|
||||||
|
epoch_total: item.epoch_total,
|
||||||
|
iteration: item.iteration,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||||
|
fn from(item: &LearnerItem<T>) -> Self {
|
||||||
|
Self {
|
||||||
|
progress: item.progress.clone(),
|
||||||
|
epoch: item.epoch,
|
||||||
|
epoch_total: item.epoch_total,
|
||||||
|
iteration: item.iteration,
|
||||||
|
lr: item.lr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
mod base;
|
||||||
|
|
||||||
|
pub(crate) use base::*;
|
|
@ -0,0 +1,8 @@
|
||||||
|
mod async_collector;
|
||||||
|
mod base;
|
||||||
|
|
||||||
|
pub use async_collector::*;
|
||||||
|
pub use base::*;
|
||||||
|
|
||||||
|
/// Metrics collector module.
|
||||||
|
pub mod metrics;
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::{checkpoint::Checkpointer, LearnerCallback};
|
use crate::{checkpoint::Checkpointer, EventCollector};
|
||||||
use burn_core::{
|
use burn_core::{
|
||||||
lr_scheduler::LrScheduler,
|
lr_scheduler::LrScheduler,
|
||||||
module::{ADModule, Module},
|
module::{ADModule, Module},
|
||||||
|
@ -25,8 +25,8 @@ pub trait LearnerComponents {
|
||||||
>;
|
>;
|
||||||
/// The checkpointer used for the scheduler.
|
/// The checkpointer used for the scheduler.
|
||||||
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
|
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
|
||||||
/// Callback used for training tracking.
|
/// Training event collector used for training tracking.
|
||||||
type Callback: LearnerCallback + 'static;
|
type EventCollector: EventCollector + 'static;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concrete type that implements [training components trait](TrainingComponents).
|
/// Concrete type that implements [training components trait](TrainingComponents).
|
||||||
|
@ -41,8 +41,8 @@ pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C> {
|
||||||
_callback: PhantomData<C>,
|
_callback: PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B, LR, M, O, CM, CO, CS, C> LearnerComponents
|
impl<B, LR, M, O, CM, CO, CS, EC> LearnerComponents
|
||||||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C>
|
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC>
|
||||||
where
|
where
|
||||||
B: ADBackend,
|
B: ADBackend,
|
||||||
LR: LrScheduler,
|
LR: LrScheduler,
|
||||||
|
@ -51,7 +51,7 @@ where
|
||||||
CM: Checkpointer<M::Record>,
|
CM: Checkpointer<M::Record>,
|
||||||
CO: Checkpointer<O::Record>,
|
CO: Checkpointer<O::Record>,
|
||||||
CS: Checkpointer<LR::Record>,
|
CS: Checkpointer<LR::Record>,
|
||||||
C: LearnerCallback + 'static,
|
EC: EventCollector + 'static,
|
||||||
{
|
{
|
||||||
type Backend = B;
|
type Backend = B;
|
||||||
type LrScheduler = LR;
|
type LrScheduler = LR;
|
||||||
|
@ -60,5 +60,5 @@ where
|
||||||
type CheckpointerModel = CM;
|
type CheckpointerModel = CM;
|
||||||
type CheckpointerOptimizer = CO;
|
type CheckpointerOptimizer = CO;
|
||||||
type CheckpointerLrScheduler = CS;
|
type CheckpointerLrScheduler = CS;
|
||||||
type Callback = C;
|
type EventCollector = EC;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,163 @@
|
||||||
|
use crate::{logger::MetricLogger, Aggregate, Direction};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Type that can be used to fetch and use numeric metric aggregates.
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
pub(crate) struct NumericMetricsAggregate {
|
||||||
|
mean_for_each_epoch: HashMap<Key, f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new, Hash, PartialEq, Eq, Debug)]
|
||||||
|
struct Key {
|
||||||
|
name: String,
|
||||||
|
epoch: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NumericMetricsAggregate {
|
||||||
|
pub(crate) fn mean(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
epoch: usize,
|
||||||
|
loggers: &mut [Box<dyn MetricLogger>],
|
||||||
|
) -> Option<f64> {
|
||||||
|
let key = Key::new(name.to_string(), epoch);
|
||||||
|
|
||||||
|
if let Some(value) = self.mean_for_each_epoch.get(&key) {
|
||||||
|
return Some(*value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let points = || {
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
for logger in loggers {
|
||||||
|
match logger.read_numeric(name, epoch) {
|
||||||
|
Ok(points) => return Ok(points),
|
||||||
|
Err(err) => errors.push(err),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(errors.join(" "))
|
||||||
|
};
|
||||||
|
|
||||||
|
let points = points().expect("Can read values");
|
||||||
|
|
||||||
|
if points.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let num_points = points.len();
|
||||||
|
let mean = points.into_iter().sum::<f64>() / num_points as f64;
|
||||||
|
|
||||||
|
self.mean_for_each_epoch.insert(key, mean);
|
||||||
|
Some(mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn find_epoch(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
aggregate: Aggregate,
|
||||||
|
direction: Direction,
|
||||||
|
loggers: &mut [Box<dyn MetricLogger>],
|
||||||
|
) -> Option<usize> {
|
||||||
|
let mut data = Vec::new();
|
||||||
|
let mut current_epoch = 1;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match aggregate {
|
||||||
|
Aggregate::Mean => match self.mean(name, current_epoch, loggers) {
|
||||||
|
Some(value) => {
|
||||||
|
data.push(value);
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
current_epoch += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut current_value = match &direction {
|
||||||
|
Direction::Lowest => f64::MAX,
|
||||||
|
Direction::Highest => f64::MIN,
|
||||||
|
};
|
||||||
|
|
||||||
|
for (i, value) in data.into_iter().enumerate() {
|
||||||
|
match &direction {
|
||||||
|
Direction::Lowest => {
|
||||||
|
if value < current_value {
|
||||||
|
current_value = value;
|
||||||
|
current_epoch = i + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Direction::Highest => {
|
||||||
|
if value > current_value {
|
||||||
|
current_value = value;
|
||||||
|
current_epoch = i + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(current_epoch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{logger::FileMetricLogger, metric::MetricEntry};

    /// Name of the metric logged by the test.
    const NAME: &str = "test-logger";

    /// File logger paired with the epoch it is currently logging to.
    struct TestLogger {
        logger: FileMetricLogger,
        epoch: usize,
    }

    impl TestLogger {
        fn new() -> Self {
            let logger = FileMetricLogger::new("/tmp");

            Self { logger, epoch: 1 }
        }

        /// Log `num` under [NAME], using its string form for both text fields of the entry.
        fn log(&mut self, num: f64) {
            let entry = MetricEntry::new(NAME.into(), num.to_string(), num.to_string());

            self.logger.log(&entry);
        }

        /// Move the logger to the next epoch.
        fn new_epoch(&mut self) {
            self.epoch += 1;
            self.logger.epoch(self.epoch);
        }
    }

    #[test]
    fn should_find_epoch() {
        let mut logger = TestLogger::new();
        let mut aggregate = NumericMetricsAggregate::default();

        // Epoch 1
        logger.log(500.);
        logger.log(1000.);
        logger.new_epoch();
        // Epoch 2
        logger.log(200.);
        logger.log(1000.);
        logger.new_epoch();
        // Epoch 3
        logger.log(10000.);

        let epoch = aggregate.find_epoch(
            NAME,
            Aggregate::Mean,
            Direction::Lowest,
            &mut [Box::new(logger.logger)],
        );

        assert_eq!(epoch.unwrap(), 2);
    }
}
|
|
@ -0,0 +1,253 @@
|
||||||
|
use super::NumericMetricsAggregate;
|
||||||
|
use crate::{
|
||||||
|
logger::MetricLogger,
|
||||||
|
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
||||||
|
Aggregate, Direction, LearnerItem, Split,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Metrics information collected during training.
|
||||||
|
pub struct MetricsInfo<T, V>
|
||||||
|
where
|
||||||
|
T: Send + Sync + 'static,
|
||||||
|
V: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
train: Vec<Box<dyn MetricUpdater<T>>>,
|
||||||
|
valid: Vec<Box<dyn MetricUpdater<V>>>,
|
||||||
|
train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
|
||||||
|
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
|
||||||
|
loggers_train: Vec<Box<dyn MetricLogger>>,
|
||||||
|
loggers_valid: Vec<Box<dyn MetricLogger>>,
|
||||||
|
aggregate_train: NumericMetricsAggregate,
|
||||||
|
aggregate_valid: NumericMetricsAggregate,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new)]
|
||||||
|
pub(crate) struct MetricsUpdate {
|
||||||
|
pub(crate) entries: Vec<MetricEntry>,
|
||||||
|
pub(crate) entries_numeric: Vec<(MetricEntry, f64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, V> MetricsInfo<T, V>
|
||||||
|
where
|
||||||
|
T: Send + Sync + 'static,
|
||||||
|
V: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
train: vec![],
|
||||||
|
valid: vec![],
|
||||||
|
train_numeric: vec![],
|
||||||
|
valid_numeric: vec![],
|
||||||
|
loggers_train: vec![],
|
||||||
|
loggers_valid: vec![],
|
||||||
|
aggregate_train: NumericMetricsAggregate::default(),
|
||||||
|
aggregate_valid: NumericMetricsAggregate::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Signal the end of a training epoch.
|
||||||
|
pub(crate) fn end_epoch_train(&mut self, epoch: usize) {
|
||||||
|
for metric in self.train.iter_mut() {
|
||||||
|
metric.clear();
|
||||||
|
}
|
||||||
|
for metric in self.train_numeric.iter_mut() {
|
||||||
|
metric.clear();
|
||||||
|
}
|
||||||
|
for logger in self.loggers_train.iter_mut() {
|
||||||
|
logger.epoch(epoch + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Signal the end of a validation epoch.
|
||||||
|
pub(crate) fn end_epoch_valid(&mut self, epoch: usize) {
|
||||||
|
for metric in self.valid.iter_mut() {
|
||||||
|
metric.clear();
|
||||||
|
}
|
||||||
|
for metric in self.valid_numeric.iter_mut() {
|
||||||
|
metric.clear();
|
||||||
|
}
|
||||||
|
for logger in self.loggers_valid.iter_mut() {
|
||||||
|
logger.epoch(epoch + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the training information from the training item.
|
||||||
|
pub(crate) fn update_train(
|
||||||
|
&mut self,
|
||||||
|
item: &LearnerItem<T>,
|
||||||
|
metadata: &MetricMetadata,
|
||||||
|
) -> MetricsUpdate {
|
||||||
|
let mut entries = Vec::with_capacity(self.train.len());
|
||||||
|
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
|
||||||
|
|
||||||
|
for metric in self.train.iter_mut() {
|
||||||
|
let state = metric.update(item, metadata);
|
||||||
|
|
||||||
|
for logger in self.loggers_train.iter_mut() {
|
||||||
|
logger.log(&state);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.push(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
for metric in self.train_numeric.iter_mut() {
|
||||||
|
let (state, value) = metric.update(item, metadata);
|
||||||
|
for logger in self.loggers_train.iter_mut() {
|
||||||
|
logger.log(&state);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries_numeric.push((state, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
MetricsUpdate::new(entries, entries_numeric)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the training information from the validation item.
|
||||||
|
pub(crate) fn update_valid(
|
||||||
|
&mut self,
|
||||||
|
item: &LearnerItem<V>,
|
||||||
|
metadata: &MetricMetadata,
|
||||||
|
) -> MetricsUpdate {
|
||||||
|
let mut entries = Vec::with_capacity(self.valid.len());
|
||||||
|
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
|
||||||
|
|
||||||
|
for metric in self.valid.iter_mut() {
|
||||||
|
let state = metric.update(item, metadata);
|
||||||
|
|
||||||
|
for logger in self.loggers_valid.iter_mut() {
|
||||||
|
logger.log(&state);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.push(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
for metric in self.valid_numeric.iter_mut() {
|
||||||
|
let (state, value) = metric.update(item, metadata);
|
||||||
|
for logger in self.loggers_valid.iter_mut() {
|
||||||
|
logger.log(&state);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries_numeric.push((state, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
MetricsUpdate::new(entries, entries_numeric)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the epoch corresponding to the given criteria.
|
||||||
|
pub(crate) fn find_epoch(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
aggregate: Aggregate,
|
||||||
|
direction: Direction,
|
||||||
|
split: Split,
|
||||||
|
) -> Option<usize> {
|
||||||
|
match split {
|
||||||
|
Split::Train => {
|
||||||
|
self.aggregate_train
|
||||||
|
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
|
||||||
|
}
|
||||||
|
Split::Valid => {
|
||||||
|
self.aggregate_valid
|
||||||
|
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a logger for training metrics.
|
||||||
|
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||||
|
self.loggers_train.push(Box::new(logger));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a logger for validation metrics.
|
||||||
|
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||||
|
self.loggers_valid.push(Box::new(logger));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a training metric.
|
||||||
|
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||||
|
where
|
||||||
|
T: Adaptor<Me::Input>,
|
||||||
|
{
|
||||||
|
let metric = MetricWrapper::new(metric);
|
||||||
|
self.train.push(Box::new(metric))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a validation metric.
|
||||||
|
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||||
|
where
|
||||||
|
V: Adaptor<Me::Input>,
|
||||||
|
{
|
||||||
|
let metric = MetricWrapper::new(metric);
|
||||||
|
self.valid.push(Box::new(metric))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a numeric training metric.
|
||||||
|
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||||
|
&mut self,
|
||||||
|
metric: Me,
|
||||||
|
) where
|
||||||
|
T: Adaptor<Me::Input>,
|
||||||
|
{
|
||||||
|
let metric = MetricWrapper::new(metric);
|
||||||
|
self.train_numeric.push(Box::new(metric))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a numeric validation metric.
|
||||||
|
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||||
|
&mut self,
|
||||||
|
metric: Me,
|
||||||
|
) where
|
||||||
|
V: Adaptor<Me::Input>,
|
||||||
|
{
|
||||||
|
let metric = MetricWrapper::new(metric);
|
||||||
|
self.valid_numeric.push(Box::new(metric))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait NumericMetricUpdater<T>: Send + Sync {
|
||||||
|
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
|
||||||
|
fn clear(&mut self);
|
||||||
|
}
|
||||||
|
|
||||||
|
trait MetricUpdater<T>: Send + Sync {
|
||||||
|
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
||||||
|
fn clear(&mut self);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(new)]
|
||||||
|
struct MetricWrapper<M> {
|
||||||
|
metric: M,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
M: Metric + Numeric + 'static,
|
||||||
|
T: Adaptor<M::Input>,
|
||||||
|
{
|
||||||
|
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
|
||||||
|
let update = self.metric.update(&item.item.adapt(), metadata);
|
||||||
|
let numeric = self.metric.value();
|
||||||
|
|
||||||
|
(update, numeric)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear(&mut self) {
|
||||||
|
self.metric.clear()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
M: Metric + 'static,
|
||||||
|
T: Adaptor<M::Input>,
|
||||||
|
{
|
||||||
|
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
|
||||||
|
self.metric.update(&item.item.adapt(), metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear(&mut self) {
|
||||||
|
self.metric.clear()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
mod aggregates;
|
||||||
|
mod metrics;
|
||||||
|
|
||||||
|
pub(crate) use aggregates::*;
|
||||||
|
pub use metrics::*;
|
|
@ -19,7 +19,7 @@ pub struct Learner<LC: LearnerComponents> {
|
||||||
pub(crate) grad_accumulation: Option<usize>,
|
pub(crate) grad_accumulation: Option<usize>,
|
||||||
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
|
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
|
||||||
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
|
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
|
||||||
pub(crate) callback: LC::Callback,
|
pub(crate) collector: LC::EventCollector,
|
||||||
pub(crate) interrupter: TrainingInterrupter,
|
pub(crate) interrupter: TrainingInterrupter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
use super::log::install_file_logger;
|
use super::log::install_file_logger;
|
||||||
use super::Learner;
|
use super::Learner;
|
||||||
use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
|
use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
|
||||||
|
use crate::collector::metrics::RenderedMetricsEventCollector;
|
||||||
use crate::components::LearnerComponentsMarker;
|
use crate::components::LearnerComponentsMarker;
|
||||||
|
use crate::info::MetricsInfo;
|
||||||
use crate::learner::base::TrainingInterrupter;
|
use crate::learner::base::TrainingInterrupter;
|
||||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||||
use crate::metric::callback::{
|
|
||||||
default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer,
|
|
||||||
};
|
|
||||||
use crate::metric::{Adaptor, Metric};
|
use crate::metric::{Adaptor, Metric};
|
||||||
use crate::{AsyncTrainerCallback, LearnerCheckpointer};
|
use crate::renderer::{default_renderer, MetricsRenderer};
|
||||||
|
use crate::{AsyncEventCollector, LearnerCheckpointer};
|
||||||
use burn_core::lr_scheduler::LrScheduler;
|
use burn_core::lr_scheduler::LrScheduler;
|
||||||
use burn_core::module::ADModule;
|
use burn_core::module::ADModule;
|
||||||
use burn_core::optim::Optimizer;
|
use burn_core::optim::Optimizer;
|
||||||
|
@ -39,12 +39,11 @@ where
|
||||||
directory: String,
|
directory: String,
|
||||||
grad_accumulation: Option<usize>,
|
grad_accumulation: Option<usize>,
|
||||||
devices: Vec<B::Device>,
|
devices: Vec<B::Device>,
|
||||||
metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
|
|
||||||
metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
|
|
||||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||||
metrics: Metrics<T, V>,
|
info: MetricsInfo<T, V>,
|
||||||
interrupter: TrainingInterrupter,
|
interrupter: TrainingInterrupter,
|
||||||
log_to_file: bool,
|
log_to_file: bool,
|
||||||
|
num_loggers: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
|
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
|
||||||
|
@ -69,12 +68,11 @@ where
|
||||||
directory: directory.to_string(),
|
directory: directory.to_string(),
|
||||||
grad_accumulation: None,
|
grad_accumulation: None,
|
||||||
devices: vec![B::Device::default()],
|
devices: vec![B::Device::default()],
|
||||||
metric_logger_train: None,
|
info: MetricsInfo::new(),
|
||||||
metric_logger_valid: None,
|
|
||||||
metrics: Metrics::new(),
|
|
||||||
renderer: None,
|
renderer: None,
|
||||||
interrupter: TrainingInterrupter::new(),
|
interrupter: TrainingInterrupter::new(),
|
||||||
log_to_file: true,
|
log_to_file: true,
|
||||||
|
num_loggers: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,8 +87,9 @@ where
|
||||||
MT: MetricLogger + 'static,
|
MT: MetricLogger + 'static,
|
||||||
MV: MetricLogger + 'static,
|
MV: MetricLogger + 'static,
|
||||||
{
|
{
|
||||||
self.metric_logger_train = Some(Box::new(logger_train));
|
self.info.register_logger_train(logger_train);
|
||||||
self.metric_logger_valid = Some(Box::new(logger_valid));
|
self.info.register_logger_valid(logger_valid);
|
||||||
|
self.num_loggers += 1;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,9 +111,7 @@ where
|
||||||
where
|
where
|
||||||
T: Adaptor<Me::Input>,
|
T: Adaptor<Me::Input>,
|
||||||
{
|
{
|
||||||
self.metrics
|
self.info.register_metric_train(metric);
|
||||||
.train
|
|
||||||
.push(Box::new(MetricWrapper::new(metric)));
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,9 +120,7 @@ where
|
||||||
where
|
where
|
||||||
V: Adaptor<Me::Input>,
|
V: Adaptor<Me::Input>,
|
||||||
{
|
{
|
||||||
self.metrics
|
self.info.register_valid_metric(metric);
|
||||||
.valid
|
|
||||||
.push(Box::new(MetricWrapper::new(metric)));
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,41 +139,25 @@ where
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a training metric and displays it on a plot.
|
/// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
|
||||||
///
|
pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
|
||||||
/// # Notes
|
|
||||||
///
|
|
||||||
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
|
|
||||||
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
|
|
||||||
/// the same graph will be used for both.
|
|
||||||
pub fn metric_train_plot<Me>(mut self, metric: Me) -> Self
|
|
||||||
where
|
where
|
||||||
Me: Metric + crate::metric::Numeric + 'static,
|
Me: Metric + crate::metric::Numeric + 'static,
|
||||||
T: Adaptor<Me::Input>,
|
T: Adaptor<Me::Input>,
|
||||||
{
|
{
|
||||||
self.metrics
|
self.info.register_train_metric_numeric(metric);
|
||||||
.train_numeric
|
|
||||||
.push(Box::new(MetricWrapper::new(metric)));
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a validation metric and displays it on a plot.
|
/// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
|
||||||
///
|
pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
|
||||||
/// # Notes
|
|
||||||
///
|
|
||||||
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
|
|
||||||
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
|
|
||||||
/// the same graph will be used for both.
|
|
||||||
pub fn metric_valid_plot<Me: Metric + crate::metric::Numeric + 'static>(
|
|
||||||
mut self,
|
mut self,
|
||||||
metric: Me,
|
metric: Me,
|
||||||
) -> Self
|
) -> Self
|
||||||
where
|
where
|
||||||
V: Adaptor<Me::Input>,
|
V: Adaptor<Me::Input>,
|
||||||
{
|
{
|
||||||
self.metrics
|
self.info.register_valid_metric_numeric(metric);
|
||||||
.valid_numeric
|
|
||||||
.push(Box::new(MetricWrapper::new(metric)));
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,7 +239,7 @@ where
|
||||||
#[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
|
#[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
|
||||||
// creates a clean learner.
|
// creates a clean learner.
|
||||||
pub fn build(
|
pub fn build(
|
||||||
self,
|
mut self,
|
||||||
model: M,
|
model: M,
|
||||||
optim: O,
|
optim: O,
|
||||||
lr_scheduler: S,
|
lr_scheduler: S,
|
||||||
|
@ -273,7 +252,7 @@ where
|
||||||
AsyncCheckpointer<M::Record>,
|
AsyncCheckpointer<M::Record>,
|
||||||
AsyncCheckpointer<O::Record>,
|
AsyncCheckpointer<O::Record>,
|
||||||
AsyncCheckpointer<S::Record>,
|
AsyncCheckpointer<S::Record>,
|
||||||
AsyncTrainerCallback<T, V>,
|
AsyncEventCollector<T, V>,
|
||||||
>,
|
>,
|
||||||
>
|
>
|
||||||
where
|
where
|
||||||
|
@ -288,18 +267,18 @@ where
|
||||||
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
||||||
});
|
});
|
||||||
let directory = &self.directory;
|
let directory = &self.directory;
|
||||||
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
|
|
||||||
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
|
if self.num_loggers == 0 {
|
||||||
});
|
self.info.register_logger_train(FileMetricLogger::new(
|
||||||
let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
|
format!("{directory}/train").as_str(),
|
||||||
Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
|
));
|
||||||
});
|
self.info.register_logger_valid(FileMetricLogger::new(
|
||||||
let callback = AsyncTrainerCallback::new(MetricsCallback::new(
|
format!("{directory}/valid").as_str(),
|
||||||
renderer,
|
));
|
||||||
self.metrics,
|
}
|
||||||
logger_train,
|
|
||||||
logger_valid,
|
let collector =
|
||||||
));
|
AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info));
|
||||||
|
|
||||||
let checkpointer = self
|
let checkpointer = self
|
||||||
.checkpointers
|
.checkpointers
|
||||||
|
@ -311,7 +290,7 @@ where
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
checkpointer,
|
checkpointer,
|
||||||
num_epochs: self.num_epochs,
|
num_epochs: self.num_epochs,
|
||||||
callback,
|
collector,
|
||||||
checkpoint: self.checkpoint,
|
checkpoint: self.checkpoint,
|
||||||
grad_accumulation: self.grad_accumulation,
|
grad_accumulation: self.grad_accumulation,
|
||||||
devices: self.devices,
|
devices: self.devices,
|
||||||
|
|
|
@ -7,8 +7,8 @@ use burn_core::{
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::learner::base::TrainingInterrupter;
|
use crate::{learner::base::TrainingInterrupter, Event};
|
||||||
use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||||
|
|
||||||
/// A validation epoch.
|
/// A validation epoch.
|
||||||
#[derive(new)]
|
#[derive(new)]
|
||||||
|
@ -37,7 +37,7 @@ impl<VI> ValidEpoch<VI> {
|
||||||
pub fn run<B, M, TO, VO>(
|
pub fn run<B, M, TO, VO>(
|
||||||
&self,
|
&self,
|
||||||
model: &M,
|
model: &M,
|
||||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||||
interrupter: &TrainingInterrupter,
|
interrupter: &TrainingInterrupter,
|
||||||
) where
|
) where
|
||||||
B: ADBackend,
|
B: ADBackend,
|
||||||
|
@ -64,13 +64,14 @@ impl<VI> ValidEpoch<VI> {
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
callback.on_valid_item(item);
|
callback.on_event_valid(Event::ProcessedItem(item));
|
||||||
|
|
||||||
if interrupter.should_stop() {
|
if interrupter.should_stop() {
|
||||||
log::info!("Training interrupted.");
|
log::info!("Training interrupted.");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
callback.on_valid_end_epoch(self.epoch);
|
callback.on_event_valid(Event::EndEpoch(self.epoch));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +93,7 @@ impl<TI> TrainEpoch<TI> {
|
||||||
mut model: M,
|
mut model: M,
|
||||||
mut optim: O,
|
mut optim: O,
|
||||||
scheduler: &mut LR,
|
scheduler: &mut LR,
|
||||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||||
interrupter: &TrainingInterrupter,
|
interrupter: &TrainingInterrupter,
|
||||||
) -> (M, O)
|
) -> (M, O)
|
||||||
where
|
where
|
||||||
|
@ -139,13 +140,13 @@ impl<TI> TrainEpoch<TI> {
|
||||||
Some(lr),
|
Some(lr),
|
||||||
);
|
);
|
||||||
|
|
||||||
callback.on_train_item(item);
|
callback.on_event_train(Event::ProcessedItem(item));
|
||||||
if interrupter.should_stop() {
|
if interrupter.should_stop() {
|
||||||
log::info!("Training interrupted.");
|
log::info!("Training interrupted.");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
callback.on_train_end_epoch(self.epoch);
|
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||||
|
|
||||||
(model, optim)
|
(model, optim)
|
||||||
}
|
}
|
||||||
|
@ -170,7 +171,7 @@ impl<TI> TrainEpoch<TI> {
|
||||||
mut model: M,
|
mut model: M,
|
||||||
mut optim: O,
|
mut optim: O,
|
||||||
lr_scheduler: &mut S,
|
lr_scheduler: &mut S,
|
||||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||||
devices: Vec<B::Device>,
|
devices: Vec<B::Device>,
|
||||||
interrupter: &TrainingInterrupter,
|
interrupter: &TrainingInterrupter,
|
||||||
) -> (M, O)
|
) -> (M, O)
|
||||||
|
@ -232,7 +233,8 @@ impl<TI> TrainEpoch<TI> {
|
||||||
Some(lr),
|
Some(lr),
|
||||||
);
|
);
|
||||||
|
|
||||||
callback.on_train_item(item);
|
callback.on_event_train(Event::ProcessedItem(item));
|
||||||
|
|
||||||
if interrupter.should_stop() {
|
if interrupter.should_stop() {
|
||||||
log::info!("Training interrupted.");
|
log::info!("Training interrupted.");
|
||||||
interrupted = true;
|
interrupted = true;
|
||||||
|
@ -245,7 +247,7 @@ impl<TI> TrainEpoch<TI> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
callback.on_train_end_epoch(self.epoch);
|
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||||
|
|
||||||
(model, optim)
|
(model, optim)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::components::LearnerComponents;
|
use crate::components::LearnerComponents;
|
||||||
use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch};
|
use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch};
|
||||||
use burn_core::data::dataloader::DataLoader;
|
use burn_core::data::dataloader::DataLoader;
|
||||||
use burn_core::module::{ADModule, Module};
|
use burn_core::module::{ADModule, Module};
|
||||||
use burn_core::optim::{GradientsParams, Optimizer};
|
use burn_core::optim::{GradientsParams, Optimizer};
|
||||||
|
@ -115,7 +115,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
||||||
OutputValid: Send,
|
OutputValid: Send,
|
||||||
LC::Model: TrainStep<InputTrain, OutputTrain>,
|
LC::Model: TrainStep<InputTrain, OutputTrain>,
|
||||||
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
|
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
|
||||||
LC::Callback: LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
LC::EventCollector: EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||||
{
|
{
|
||||||
log::info!("Fitting {}", self.model.to_string());
|
log::info!("Fitting {}", self.model.to_string());
|
||||||
// The reference model is always on the first device provided.
|
// The reference model is always on the first device provided.
|
||||||
|
@ -139,8 +139,8 @@ impl<LC: LearnerComponents> Learner<LC> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut callback: Box<
|
let mut callback: Box<
|
||||||
dyn LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
dyn EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||||
> = Box::new(self.callback);
|
> = Box::new(self.collector);
|
||||||
|
|
||||||
for epoch in starting_epoch..self.num_epochs + 1 {
|
for epoch in starting_epoch..self.num_epochs + 1 {
|
||||||
let epoch_train = TrainEpoch::new(
|
let epoch_train = TrainEpoch::new(
|
||||||
|
|
|
@ -10,16 +10,22 @@ pub mod checkpoint;
|
||||||
|
|
||||||
pub(crate) mod components;
|
pub(crate) mod components;
|
||||||
|
|
||||||
|
/// Renderer modules to display metrics and training information.
|
||||||
|
pub mod renderer;
|
||||||
|
|
||||||
/// The logger module.
|
/// The logger module.
|
||||||
pub mod logger;
|
pub mod logger;
|
||||||
|
|
||||||
/// The metric module.
|
/// The metric module.
|
||||||
pub mod metric;
|
pub mod metric;
|
||||||
|
|
||||||
mod callback;
|
/// All information collected during training.
|
||||||
|
pub mod info;
|
||||||
|
|
||||||
|
mod collector;
|
||||||
mod learner;
|
mod learner;
|
||||||
|
|
||||||
pub use callback::*;
|
pub use collector::*;
|
||||||
pub use learner::*;
|
pub use learner::*;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -4,6 +4,7 @@ use std::sync::mpsc;
|
||||||
enum Message<T> {
|
enum Message<T> {
|
||||||
Log(T),
|
Log(T),
|
||||||
End,
|
End,
|
||||||
|
Sync(mpsc::Sender<()>),
|
||||||
}
|
}
|
||||||
/// Async logger.
|
/// Async logger.
|
||||||
pub struct AsyncLogger<T> {
|
pub struct AsyncLogger<T> {
|
||||||
|
@ -30,6 +31,9 @@ where
|
||||||
Message::End => {
|
Message::End => {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
Message::Sync(callback) => {
|
||||||
|
callback.send(()).unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -48,6 +52,17 @@ impl<T: Send + Sync + 'static> AsyncLogger<T> {
|
||||||
|
|
||||||
Self { sender, handler }
|
Self { sender, handler }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sync the async logger.
|
||||||
|
pub(crate) fn sync(&self) {
|
||||||
|
let (sender, receiver) = mpsc::channel();
|
||||||
|
|
||||||
|
self.sender.send(Message::Sync(sender)).unwrap();
|
||||||
|
|
||||||
|
receiver
|
||||||
|
.recv()
|
||||||
|
.expect("Should sync, otherwise the thread is dead.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Send> Logger<T> for AsyncLogger<T> {
|
impl<T: Send> Logger<T> for AsyncLogger<T> {
|
||||||
|
|
|
@ -49,11 +49,14 @@ impl FileMetricLogger {
|
||||||
|
|
||||||
fn file_path(&self, name: &str, epoch: usize) -> String {
|
fn file_path(&self, name: &str, epoch: usize) -> String {
|
||||||
let directory = format!("{}/epoch-{}", self.directory, epoch);
|
let directory = format!("{}/epoch-{}", self.directory, epoch);
|
||||||
std::fs::create_dir_all(&directory).ok();
|
|
||||||
let name = name.replace(' ', "_");
|
let name = name.replace(' ', "_");
|
||||||
|
|
||||||
format!("{directory}/{name}.log")
|
format!("{directory}/{name}.log")
|
||||||
}
|
}
|
||||||
|
fn create_directory(&self, epoch: usize) {
|
||||||
|
let directory = format!("{}/epoch-{}", self.directory, epoch);
|
||||||
|
std::fs::create_dir_all(directory).ok();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MetricLogger for FileMetricLogger {
|
impl MetricLogger for FileMetricLogger {
|
||||||
|
@ -64,6 +67,8 @@ impl MetricLogger for FileMetricLogger {
|
||||||
let logger = match self.loggers.get_mut(key) {
|
let logger = match self.loggers.get_mut(key) {
|
||||||
Some(val) => val,
|
Some(val) => val,
|
||||||
None => {
|
None => {
|
||||||
|
self.create_directory(self.epoch);
|
||||||
|
|
||||||
let file_path = self.file_path(key, self.epoch);
|
let file_path = self.file_path(key, self.epoch);
|
||||||
let logger = FileLogger::new(&file_path);
|
let logger = FileLogger::new(&file_path);
|
||||||
let logger = AsyncLogger::new(logger);
|
let logger = AsyncLogger::new(logger);
|
||||||
|
@ -82,6 +87,10 @@ impl MetricLogger for FileMetricLogger {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
|
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
|
||||||
|
if let Some(value) = self.loggers.get(name) {
|
||||||
|
value.sync()
|
||||||
|
}
|
||||||
|
|
||||||
let file_path = self.file_path(name, epoch);
|
let file_path = self.file_path(name, epoch);
|
||||||
|
|
||||||
let mut errors = false;
|
let mut errors = false;
|
||||||
|
|
|
@ -1,285 +0,0 @@
|
||||||
use crate::{
|
|
||||||
logger::MetricLogger,
|
|
||||||
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
|
||||||
LearnerCallback, LearnerItem,
|
|
||||||
};
|
|
||||||
use burn_core::data::dataloader::Progress;
|
|
||||||
|
|
||||||
/// Holds all metrics, metric loggers, and a metrics renderer.
|
|
||||||
pub struct MetricsCallback<T, V>
|
|
||||||
where
|
|
||||||
T: Send + Sync + 'static,
|
|
||||||
V: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
metrics: Metrics<T, V>,
|
|
||||||
logger_train: Box<dyn MetricLogger>,
|
|
||||||
logger_valid: Box<dyn MetricLogger>,
|
|
||||||
renderer: Box<dyn MetricsRenderer>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, V> MetricsCallback<T, V>
|
|
||||||
where
|
|
||||||
T: Send + Sync + 'static,
|
|
||||||
V: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
/// Creates a new metrics callback.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `renderer` - The metrics renderer.
|
|
||||||
/// * `metrics` - The metrics holder.
|
|
||||||
/// * `logger_train` - The training logger.
|
|
||||||
/// * `logger_valid` - The validation logger.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A new metrics callback.
|
|
||||||
pub(crate) fn new(
|
|
||||||
renderer: Box<dyn MetricsRenderer>,
|
|
||||||
metrics: Metrics<T, V>,
|
|
||||||
logger_train: Box<dyn MetricLogger>,
|
|
||||||
logger_valid: Box<dyn MetricLogger>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
metrics,
|
|
||||||
logger_train,
|
|
||||||
logger_valid,
|
|
||||||
renderer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, V> LearnerCallback for MetricsCallback<T, V>
|
|
||||||
where
|
|
||||||
T: Send + Sync + 'static,
|
|
||||||
V: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
type ItemTrain = T;
|
|
||||||
type ItemValid = V;
|
|
||||||
|
|
||||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
|
||||||
let metadata = (&item).into();
|
|
||||||
for metric in self.metrics.train.iter_mut() {
|
|
||||||
let state = metric.update(&item, &metadata);
|
|
||||||
self.logger_train.log(&state);
|
|
||||||
|
|
||||||
self.renderer.update_train(MetricState::Generic(state));
|
|
||||||
}
|
|
||||||
for metric in self.metrics.train_numeric.iter_mut() {
|
|
||||||
let (state, value) = metric.update(&item, &metadata);
|
|
||||||
self.logger_train.log(&state);
|
|
||||||
|
|
||||||
self.renderer
|
|
||||||
.update_train(MetricState::Numeric(state, value));
|
|
||||||
}
|
|
||||||
self.renderer.render_train(item.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
|
||||||
let metadata = (&item).into();
|
|
||||||
for metric in self.metrics.valid.iter_mut() {
|
|
||||||
let state = metric.update(&item, &metadata);
|
|
||||||
self.logger_valid.log(&state);
|
|
||||||
|
|
||||||
self.renderer.update_valid(MetricState::Generic(state));
|
|
||||||
}
|
|
||||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
|
||||||
let (state, value) = metric.update(&item, &metadata);
|
|
||||||
self.logger_valid.log(&state);
|
|
||||||
|
|
||||||
self.renderer
|
|
||||||
.update_valid(MetricState::Numeric(state, value));
|
|
||||||
}
|
|
||||||
self.renderer.render_valid(item.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
|
||||||
for metric in self.metrics.train.iter_mut() {
|
|
||||||
metric.clear();
|
|
||||||
}
|
|
||||||
for metric in self.metrics.train_numeric.iter_mut() {
|
|
||||||
metric.clear();
|
|
||||||
}
|
|
||||||
self.logger_train.epoch(epoch + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
|
||||||
for metric in self.metrics.valid.iter_mut() {
|
|
||||||
metric.clear();
|
|
||||||
}
|
|
||||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
|
||||||
metric.clear();
|
|
||||||
}
|
|
||||||
self.logger_valid.epoch(epoch + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Training progress.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TrainingProgress {
|
|
||||||
/// The progress.
|
|
||||||
pub progress: Progress,
|
|
||||||
|
|
||||||
/// The epoch.
|
|
||||||
pub epoch: usize,
|
|
||||||
|
|
||||||
/// The total number of epochs.
|
|
||||||
pub epoch_total: usize,
|
|
||||||
|
|
||||||
/// The iteration.
|
|
||||||
pub iteration: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TrainingProgress {
|
|
||||||
/// Creates a new empty training progress.
|
|
||||||
pub fn none() -> Self {
|
|
||||||
Self {
|
|
||||||
progress: Progress {
|
|
||||||
items_processed: 0,
|
|
||||||
items_total: 0,
|
|
||||||
},
|
|
||||||
epoch: 0,
|
|
||||||
epoch_total: 0,
|
|
||||||
iteration: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The state of a metric.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum MetricState {
|
|
||||||
/// A generic metric.
|
|
||||||
Generic(MetricEntry),
|
|
||||||
|
|
||||||
/// A numeric metric.
|
|
||||||
Numeric(MetricEntry, f64),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for rendering metrics.
|
|
||||||
pub trait MetricsRenderer: Send + Sync {
|
|
||||||
/// Updates the training metric state.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `state` - The metric state.
|
|
||||||
fn update_train(&mut self, state: MetricState);
|
|
||||||
|
|
||||||
/// Updates the validation metric state.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `state` - The metric state.
|
|
||||||
fn update_valid(&mut self, state: MetricState);
|
|
||||||
|
|
||||||
/// Renders the training progress.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `item` - The training progress.
|
|
||||||
fn render_train(&mut self, item: TrainingProgress);
|
|
||||||
|
|
||||||
/// Renders the validation progress.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `item` - The validation progress.
|
|
||||||
fn render_valid(&mut self, item: TrainingProgress);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A container for the metrics held by a metrics callback.
|
|
||||||
pub(crate) struct Metrics<T, V>
|
|
||||||
where
|
|
||||||
T: Send + Sync + 'static,
|
|
||||||
V: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
pub(crate) train: Vec<Box<dyn MetricUpdater<T>>>,
|
|
||||||
pub(crate) valid: Vec<Box<dyn MetricUpdater<V>>>,
|
|
||||||
pub(crate) train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
|
|
||||||
pub(crate) valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, V> Metrics<T, V>
|
|
||||||
where
|
|
||||||
T: Send + Sync + 'static,
|
|
||||||
V: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
train: vec![],
|
|
||||||
valid: vec![],
|
|
||||||
train_numeric: vec![],
|
|
||||||
valid_numeric: vec![],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> From<LearnerItem<T>> for TrainingProgress {
|
|
||||||
fn from(item: LearnerItem<T>) -> Self {
|
|
||||||
Self {
|
|
||||||
progress: item.progress,
|
|
||||||
epoch: item.epoch,
|
|
||||||
epoch_total: item.epoch_total,
|
|
||||||
iteration: item.iteration,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Builds the metric metadata from a borrowed learner item; the progress is cloned and
/// the remaining fields are copied.
impl<T> From<&LearnerItem<T>> for MetricMetadata {
    fn from(item: &LearnerItem<T>) -> Self {
        let progress = item.progress.clone();

        Self {
            progress,
            epoch: item.epoch,
            epoch_total: item.epoch_total,
            iteration: item.iteration,
            lr: item.lr,
        }
    }
}
|
|
||||||
|
|
||||||
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
|
|
||||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
|
|
||||||
fn clear(&mut self);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) trait MetricUpdater<T>: Send + Sync {
|
|
||||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
|
||||||
fn clear(&mut self);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(new)]
|
|
||||||
pub(crate) struct MetricWrapper<M> {
|
|
||||||
metric: M,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
|
|
||||||
where
|
|
||||||
T: 'static,
|
|
||||||
M: Metric + Numeric + 'static,
|
|
||||||
T: Adaptor<M::Input>,
|
|
||||||
{
|
|
||||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
|
|
||||||
let update = self.metric.update(&item.item.adapt(), metadata);
|
|
||||||
let numeric = self.metric.value();
|
|
||||||
|
|
||||||
(update, numeric)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clear(&mut self) {
|
|
||||||
self.metric.clear()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
|
|
||||||
where
|
|
||||||
T: 'static,
|
|
||||||
M: Metric + 'static,
|
|
||||||
T: Adaptor<M::Input>,
|
|
||||||
{
|
|
||||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
|
|
||||||
self.metric.update(&item.item.adapt(), metadata)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clear(&mut self) {
|
|
||||||
self.metric.clear()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,7 +1,4 @@
|
||||||
/// Callback module for training progress.
|
/// State module.
|
||||||
pub mod callback;
|
|
||||||
|
|
||||||
/// State module for callback metrics.
|
|
||||||
pub mod state;
|
pub mod state;
|
||||||
|
|
||||||
mod acc;
|
mod acc;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use super::{format_float, MetricEntry, Numeric};
|
use crate::metric::{format_float, MetricEntry, Numeric};
|
||||||
|
|
||||||
/// Useful utility to implement numeric metrics.
|
/// Useful utility to implement numeric metrics.
|
||||||
///
|
///
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
use burn_core::data::dataloader::Progress;
|
||||||
|
|
||||||
|
use crate::metric::MetricEntry;
|
||||||
|
|
||||||
|
/// Trait for rendering metrics.
|
||||||
|
pub trait MetricsRenderer: Send + Sync {
|
||||||
|
/// Updates the training metric state.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `state` - The metric state.
|
||||||
|
fn update_train(&mut self, state: MetricState);
|
||||||
|
|
||||||
|
/// Updates the validation metric state.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `state` - The metric state.
|
||||||
|
fn update_valid(&mut self, state: MetricState);
|
||||||
|
|
||||||
|
/// Renders the training progress.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `item` - The training progress.
|
||||||
|
fn render_train(&mut self, item: TrainingProgress);
|
||||||
|
|
||||||
|
/// Renders the validation progress.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `item` - The validation progress.
|
||||||
|
fn render_valid(&mut self, item: TrainingProgress);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The state of a metric.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum MetricState {
|
||||||
|
/// A generic metric.
|
||||||
|
Generic(MetricEntry),
|
||||||
|
|
||||||
|
/// A numeric metric.
|
||||||
|
Numeric(MetricEntry, f64),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Training progress.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TrainingProgress {
|
||||||
|
/// The progress.
|
||||||
|
pub progress: Progress,
|
||||||
|
|
||||||
|
/// The epoch.
|
||||||
|
pub epoch: usize,
|
||||||
|
|
||||||
|
/// The total number of epochs.
|
||||||
|
pub epoch_total: usize,
|
||||||
|
|
||||||
|
/// The iteration.
|
||||||
|
pub iteration: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrainingProgress {
|
||||||
|
/// Creates a new empty training progress.
|
||||||
|
pub fn none() -> Self {
|
||||||
|
Self {
|
||||||
|
progress: Progress {
|
||||||
|
items_processed: 0,
|
||||||
|
items_total: 0,
|
||||||
|
},
|
||||||
|
epoch: 0,
|
||||||
|
epoch_total: 0,
|
||||||
|
iteration: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,4 @@
|
||||||
mod base;
|
mod base;
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
|
|
||||||
#[cfg(not(feature = "tui"))]
|
#[cfg(not(feature = "tui"))]
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::metric::callback::TrainingProgress;
|
use crate::renderer::TrainingProgress;
|
||||||
|
|
||||||
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
|
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
|
||||||
use crossterm::event::{Event, KeyCode};
|
use crossterm::event::{Event, KeyCode};
|
|
@ -1,5 +1,6 @@
|
||||||
|
use crate::renderer::TrainingProgress;
|
||||||
|
|
||||||
use super::TerminalFrame;
|
use super::TerminalFrame;
|
||||||
use crate::metric::callback::TrainingProgress;
|
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||||
style::{Color, Style, Stylize},
|
style::{Color, Style, Stylize},
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::metric::callback::tui::NumericMetricsState;
|
use crate::renderer::{tui::NumericMetricsState, MetricsRenderer};
|
||||||
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
|
use crate::renderer::{MetricState, TrainingProgress};
|
||||||
use crate::TrainingInterrupter;
|
use crate::TrainingInterrupter;
|
||||||
use crossterm::{
|
use crossterm::{
|
||||||
event::{self, Event, KeyCode},
|
event::{self, Event, KeyCode},
|
|
@ -1,5 +1,5 @@
|
||||||
use super::TerminalFrame;
|
use super::TerminalFrame;
|
||||||
use crate::metric::callback::TrainingProgress;
|
use crate::renderer::TrainingProgress;
|
||||||
use ratatui::{
|
use ratatui::{
|
||||||
prelude::{Alignment, Rect},
|
prelude::{Alignment, Rect},
|
||||||
style::{Color, Style, Stylize},
|
style::{Color, Style, Stylize},
|
|
@ -1,5 +1,5 @@
|
||||||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||||
use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
|
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
|
||||||
use burn::train::LearnerBuilder;
|
use burn::train::LearnerBuilder;
|
||||||
use burn::{
|
use burn::{
|
||||||
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
|
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
|
||||||
|
|
|
@ -88,10 +88,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
||||||
.build(MNISTDataset::test());
|
.build(MNISTDataset::test());
|
||||||
|
|
||||||
let learner = LearnerBuilder::new(artifact_dir)
|
let learner = LearnerBuilder::new(artifact_dir)
|
||||||
.metric_train_plot(AccuracyMetric::new())
|
.metric_train_numeric(AccuracyMetric::new())
|
||||||
.metric_valid_plot(AccuracyMetric::new())
|
.metric_valid_numeric(AccuracyMetric::new())
|
||||||
.metric_train_plot(LossMetric::new())
|
.metric_train_numeric(LossMetric::new())
|
||||||
.metric_valid_plot(LossMetric::new())
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.with_file_checkpointer(1, CompactRecorder::new())
|
.with_file_checkpointer(1, CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
|
|
|
@ -58,16 +58,16 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
||||||
|
|
||||||
// Model
|
// Model
|
||||||
let learner = LearnerBuilder::new(ARTIFACT_DIR)
|
let learner = LearnerBuilder::new(ARTIFACT_DIR)
|
||||||
.metric_train_plot(AccuracyMetric::new())
|
.metric_train_numeric(AccuracyMetric::new())
|
||||||
.metric_valid_plot(AccuracyMetric::new())
|
.metric_valid_numeric(AccuracyMetric::new())
|
||||||
.metric_train_plot(CpuUse::new())
|
.metric_train_numeric(CpuUse::new())
|
||||||
.metric_valid_plot(CpuUse::new())
|
.metric_valid_numeric(CpuUse::new())
|
||||||
.metric_train_plot(CpuMemory::new())
|
.metric_train_numeric(CpuMemory::new())
|
||||||
.metric_valid_plot(CpuMemory::new())
|
.metric_valid_numeric(CpuMemory::new())
|
||||||
.metric_train_plot(CpuTemperature::new())
|
.metric_train_numeric(CpuTemperature::new())
|
||||||
.metric_valid_plot(CpuTemperature::new())
|
.metric_valid_numeric(CpuTemperature::new())
|
||||||
.metric_train_plot(LossMetric::new())
|
.metric_train_numeric(LossMetric::new())
|
||||||
.metric_valid_plot(LossMetric::new())
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.with_file_checkpointer(1, CompactRecorder::new())
|
.with_file_checkpointer(1, CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
|
|
|
@ -95,9 +95,9 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
||||||
.metric_valid(CUDAMetric::new())
|
.metric_valid(CUDAMetric::new())
|
||||||
.metric_train(AccuracyMetric::new())
|
.metric_train(AccuracyMetric::new())
|
||||||
.metric_valid(AccuracyMetric::new())
|
.metric_valid(AccuracyMetric::new())
|
||||||
.metric_train_plot(LossMetric::new())
|
.metric_train_numeric(LossMetric::new())
|
||||||
.metric_valid_plot(LossMetric::new())
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.metric_train_plot(LearningRateMetric::new())
|
.metric_train_numeric(LearningRateMetric::new())
|
||||||
.with_file_checkpointer(2, CompactRecorder::new())
|
.with_file_checkpointer(2, CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
|
|
|
@ -70,11 +70,11 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
||||||
let learner = LearnerBuilder::new(artifact_dir)
|
let learner = LearnerBuilder::new(artifact_dir)
|
||||||
.metric_train(CUDAMetric::new())
|
.metric_train(CUDAMetric::new())
|
||||||
.metric_valid(CUDAMetric::new())
|
.metric_valid(CUDAMetric::new())
|
||||||
.metric_train_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
.metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||||
.metric_valid_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
.metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||||
.metric_train(LossMetric::new())
|
.metric_train(LossMetric::new())
|
||||||
.metric_valid(LossMetric::new())
|
.metric_valid(LossMetric::new())
|
||||||
.metric_train_plot(LearningRateMetric::new())
|
.metric_train_numeric(LearningRateMetric::new())
|
||||||
.with_file_checkpointer(2, CompactRecorder::new())
|
.with_file_checkpointer(2, CompactRecorder::new())
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.grads_accumulation(accum)
|
.grads_accumulation(accum)
|
||||||
|
|
Loading…
Reference in New Issue