Refactor burn-train (#847)

This commit is contained in:
Nathaniel Simard 2023-10-05 13:10:54 -04:00 committed by GitHub
parent 9afc76303f
commit 904ff1a974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 503 additions and 441 deletions

View File

@ -63,6 +63,6 @@ pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
/// Type alias for the learning rate.
///
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LRScheduler) so it
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
/// can be used for constant learning rate.
pub type LearningRate = f64; // We could potentially change the type.

View File

@ -1,7 +1,7 @@
use crate::{record::Record, LearningRate};
/// Learning rate scheduler defines how the learning rate will evolve during training.
pub trait LRScheduler: Send + Sync {
pub trait LrScheduler: Send + Sync {
/// Scheduler associative type to be used when saving and loading the state.
type Record: Record;

View File

@ -1,23 +1,23 @@
use super::LRScheduler;
use super::LrScheduler;
use crate::LearningRate;
/// Constant learning rate implementing [learning rate scheduler](LRScheduler).
/// Constant learning rate implementing [learning rate scheduler](LrScheduler).
///
/// # Notes
///
/// You can also use [learning rate](LearningRate) which the same effect.
#[derive(new, Clone, Debug)]
pub struct ConstantLR {
pub struct ConstantLr {
lr: LearningRate,
}
impl From<LearningRate> for ConstantLR {
impl From<LearningRate> for ConstantLr {
fn from(lr: LearningRate) -> Self {
Self { lr }
}
}
impl LRScheduler for ConstantLR {
impl LrScheduler for ConstantLr {
type Record = ();
fn step(&mut self) -> LearningRate {
@ -31,7 +31,7 @@ impl LRScheduler for ConstantLR {
}
}
impl LRScheduler for LearningRate {
impl LrScheduler for LearningRate {
type Record = ();
fn step(&mut self) -> LearningRate {

View File

@ -1,11 +1,11 @@
use crate as burn;
use super::LRScheduler;
use super::LrScheduler;
use crate::{config::Config, LearningRate};
/// Configuration to create a [noam](NoamLRScheduler) learning rate scheduler.
/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
#[derive(Config)]
pub struct NoamLRSchedulerConfig {
pub struct NoamLrSchedulerConfig {
/// The initial learning rate.
init_lr: LearningRate,
/// The number of steps before the exponential decay stats.
@ -18,17 +18,17 @@ pub struct NoamLRSchedulerConfig {
/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
#[derive(Clone, Debug)]
pub struct NoamLRScheduler {
pub struct NoamLrScheduler {
warmup_steps: f64,
embedding_size: f64,
init_lr: LearningRate,
step: f64,
}
impl NoamLRSchedulerConfig {
/// Initialize a new [noam](NoamLRScheduler) learning rate scheduler.
pub fn init(&self) -> NoamLRScheduler {
NoamLRScheduler {
impl NoamLrSchedulerConfig {
/// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
pub fn init(&self) -> NoamLrScheduler {
NoamLrScheduler {
warmup_steps: self.warmup_steps as f64,
embedding_size: self.model_size as f64,
init_lr: self.init_lr,
@ -37,7 +37,7 @@ impl NoamLRSchedulerConfig {
}
}
impl LRScheduler for NoamLRScheduler {
impl LrScheduler for NoamLrScheduler {
type Record = usize;
fn step(&mut self) -> LearningRate {
@ -66,7 +66,7 @@ mod tests {
#[test]
fn test_function_increase_and_decrease() {
let warmup_steps = 100;
let mut scheduler = NoamLRSchedulerConfig::new(10.0)
let mut scheduler = NoamLrSchedulerConfig::new(10.0)
.with_warmup_steps(warmup_steps)
.init();
let mut lr_current = 0.0;

View File

@ -1,8 +1,5 @@
use super::{LearnerCallback, LearnerItem};
use std::{
sync::{mpsc, Mutex},
thread::JoinHandle,
};
use std::{sync::mpsc, thread::JoinHandle};
enum Message<T, V> {
LogTrain(LearnerItem<T>),
@ -19,30 +16,29 @@ pub struct AsyncTrainerCallback<T, V> {
}
#[derive(new)]
struct CallbackThread<T, V> {
callback: Mutex<Box<dyn LearnerCallback<T, V>>>,
struct CallbackThread<C, T, V> {
callback: C,
receiver: mpsc::Receiver<Message<T, V>>,
}
impl<T, V> CallbackThread<T, V> {
fn run(self) {
impl<C, T, V> CallbackThread<C, T, V>
where
C: LearnerCallback<ItemTrain = T, ItemValid = V>,
{
fn run(mut self) {
for item in self.receiver.iter() {
match item {
Message::LogTrain(item) => {
let mut callback = self.callback.lock().unwrap();
callback.on_train_item(item);
self.callback.on_train_item(item);
}
Message::ClearTrain(epoch) => {
let mut callback = self.callback.lock().unwrap();
callback.on_train_end_epoch(epoch);
self.callback.on_train_end_epoch(epoch);
}
Message::LogValid(item) => {
let mut callback = self.callback.lock().unwrap();
callback.on_valid_item(item);
self.callback.on_valid_item(item);
}
Message::ClearValid(epoch) => {
let mut callback = self.callback.lock().unwrap();
callback.on_valid_end_epoch(epoch);
self.callback.on_valid_end_epoch(epoch);
}
Message::End => {
return;
@ -54,9 +50,12 @@ impl<T, V> CallbackThread<T, V> {
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
/// Create a new async trainer callback.
pub fn new(callback: Box<dyn LearnerCallback<T, V>>) -> Self {
pub fn new<C>(callback: C) -> Self
where
C: LearnerCallback<ItemTrain = T, ItemValid = V> + 'static,
{
let (sender, receiver) = mpsc::channel();
let thread = CallbackThread::new(Mutex::new(callback), receiver);
let thread = CallbackThread::new(callback, receiver);
let handler = std::thread::spawn(move || thread.run());
let handler = Some(handler);
@ -65,7 +64,10 @@ impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T,
}
}
impl<T: Send, V: Send> LearnerCallback<T, V> for AsyncTrainerCallback<T, V> {
impl<T: Send, V: Send> LearnerCallback for AsyncTrainerCallback<T, V> {
type ItemTrain = T;
type ItemValid = V;
fn on_train_item(&mut self, item: LearnerItem<T>) {
self.sender.send(Message::LogTrain(item)).unwrap();
}

View File

@ -1,12 +1,17 @@
use burn_core::{data::dataloader::Progress, LearningRate};
/// The base trait for trainer callbacks.
pub trait LearnerCallback<T, V>: Send {
pub trait LearnerCallback: Send {
/// Training item.
type ItemTrain;
/// Validation item.
type ItemValid;
/// Called when a training item is logged.
fn on_train_item(&mut self, _item: LearnerItem<T>) {}
fn on_train_item(&mut self, _item: LearnerItem<Self::ItemTrain>) {}
/// Called when a validation item is logged.
fn on_valid_item(&mut self, _item: LearnerItem<V>) {}
fn on_valid_item(&mut self, _item: LearnerItem<Self::ItemValid>) {}
/// Called when a training epoch is finished.
fn on_train_end_epoch(&mut self, _epoch: usize) {}

View File

@ -1,22 +1,27 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::record::Record;
use std::sync::{mpsc, Arc};
use std::sync::mpsc;
enum Message<R> {
Restore(usize, mpsc::SyncSender<Result<R, CheckpointerError>>),
Save(usize, R),
End,
}
#[derive(new)]
struct CheckpointerThread<R> {
checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>,
struct CheckpointerThread<C, R> {
checkpointer: C,
receiver: mpsc::Receiver<Message<R>>,
}
impl<R: Record> CheckpointerThread<R> {
impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
fn run(self) {
for item in self.receiver.iter() {
match item {
Message::Restore(epoch, sender) => {
let record = self.checkpointer.restore(epoch);
sender.send(record).unwrap();
}
Message::Save(epoch, state) => self.checkpointer.save(epoch, state).unwrap(),
Message::End => {
return;
@ -27,9 +32,8 @@ impl<R: Record> CheckpointerThread<R> {
}
/// Async checkpointer.
pub struct AsyncCheckpointer<E> {
checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>,
sender: mpsc::SyncSender<Message<E>>,
pub struct AsyncCheckpointer<Record> {
sender: mpsc::SyncSender<Message<Record>>,
handler: Option<std::thread::JoinHandle<()>>,
}
@ -43,17 +47,16 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
/// # Returns
///
/// The async checkpointer.
pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
pub fn new<C>(checkpointer: C) -> Self
where
C: Checkpointer<R> + Send + 'static,
{
// Only on checkpoint can be done in advance.
let (sender, receiver) = mpsc::sync_channel(0);
let thread = CheckpointerThread::new(checkpointer.clone(), receiver);
let thread = CheckpointerThread::new(checkpointer, receiver);
let handler = Some(std::thread::spawn(move || thread.run()));
Self {
checkpointer,
sender,
handler,
}
Self { sender, handler }
}
}
@ -68,7 +71,16 @@ where
}
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
self.checkpointer.restore(epoch)
let (sender, receiver) = mpsc::sync_channel(1);
self.sender
.send(Message::Restore(epoch, sender))
.map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
if let Ok(record) = receiver.recv() {
return record;
};
Err(CheckpointerError::Unknown("Channel error.".to_string()))
}
}

View File

@ -0,0 +1,64 @@
use crate::{checkpoint::Checkpointer, LearnerCallback};
use burn_core::{
lr_scheduler::LrScheduler,
module::{ADModule, Module},
optim::Optimizer,
tensor::backend::ADBackend,
};
use std::marker::PhantomData;
/// All components necessary to train a model grouped in one trait.
pub trait LearnerComponents {
/// The backend in used for the training.
type Backend: ADBackend;
/// The learning rate scheduler used for the training.
type LrScheduler: LrScheduler;
/// The model to train.
type Model: ADModule<Self::Backend> + core::fmt::Display + 'static;
/// The optimizer used for the training.
type Optimizer: Optimizer<Self::Model, Self::Backend>;
/// The checkpointer used for the model.
type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record>;
/// The checkpointer used for the optimizer.
type CheckpointerOptimizer: Checkpointer<
<Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
>;
/// The checkpointer used for the scheduler.
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
/// Callback used for training tracking.
type Callback: LearnerCallback + 'static;
}
/// Concrete type that implements [training components trait](TrainingComponents).
pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C> {
_backend: PhantomData<B>,
_lr_scheduler: PhantomData<LR>,
_model: PhantomData<M>,
_optimizer: PhantomData<O>,
_checkpointer_model: PhantomData<CM>,
_checkpointer_optim: PhantomData<CO>,
_checkpointer_scheduler: PhantomData<CS>,
_callback: PhantomData<C>,
}
impl<B, LR, M, O, CM, CO, CS, C> LearnerComponents
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C>
where
B: ADBackend,
LR: LrScheduler,
M: ADModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
CM: Checkpointer<M::Record>,
CO: Checkpointer<O::Record>,
CS: Checkpointer<LR::Record>,
C: LearnerCallback + 'static,
{
type Backend = B;
type LrScheduler = LR;
type Model = M;
type Optimizer = O;
type CheckpointerModel = CM;
type CheckpointerOptimizer = CO;
type CheckpointerLrScheduler = CS;
type Callback = C;
}

View File

@ -1,91 +1,67 @@
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::lr_scheduler::LRScheduler;
use burn_core::module::{ADModule, Module};
use crate::components::LearnerComponents;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::ADBackend;
use burn_core::tensor::backend::Backend;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
pub struct Learner<B, Model, Optim, LR, TrainOutput, ValidOutput>
where
B: ADBackend,
LR: LRScheduler,
Model: ADModule<B>,
Optim: Optimizer<Model, B>,
{
pub(super) model: Model,
pub(super) optim: Optim,
pub(super) lr_scheduler: LR,
pub(super) num_epochs: usize,
pub(super) callback: Box<dyn LearnerCallback<TrainOutput, ValidOutput>>,
pub(super) checkpoint: Option<usize>,
pub(super) checkpointer_model: CheckpointModel<Model, B>,
pub(super) checkpointer_optimizer: CheckpointOptim<Optim, Model, B>,
pub(super) checkpointer_scheduler: CheckpointScheduler<LR>,
pub(super) grad_accumulation: Option<usize>,
pub(super) devices: Vec<B::Device>,
pub(super) interrupter: TrainingInterrupter,
pub struct Learner<LC: LearnerComponents> {
pub(crate) model: LC::Model,
pub(crate) optim: LC::Optimizer,
pub(crate) lr_scheduler: LC::LrScheduler,
pub(crate) num_epochs: usize,
pub(crate) checkpoint: Option<usize>,
pub(crate) grad_accumulation: Option<usize>,
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
pub(crate) callback: LC::Callback,
pub(crate) interrupter: TrainingInterrupter,
}
type CheckpointModel<M, B> = Option<Box<dyn Checkpointer<<M as Module<B>>::Record>>>;
type CheckpointOptim<O, M, B> = Option<Box<dyn Checkpointer<<O as Optimizer<M, B>>::Record>>>;
type CheckpointScheduler<LR> = Option<Box<dyn Checkpointer<<LR as LRScheduler>::Record>>>;
#[derive(new)]
pub(crate) struct LearnerCheckpointer<LC: LearnerComponents> {
model: LC::CheckpointerModel,
optim: LC::CheckpointerOptimizer,
lr_scheduler: LC::CheckpointerLrScheduler,
}
impl<B, Model, Optim, LR, TrainOutput, ValidOutput>
Learner<B, Model, Optim, LR, TrainOutput, ValidOutput>
where
ValidOutput: Send + Sync + 'static,
TrainOutput: Send + Sync + 'static,
B: ADBackend,
Model: ADModule<B>,
Optim: Optimizer<Model, B>,
LR: LRScheduler,
{
pub(super) fn checkpoint(
model: &Model,
optim: &Optim,
scheduler: &LR,
checkpointer_model: &CheckpointModel<Model, B>,
checkpointer_optimizer: &CheckpointOptim<Optim, Model, B>,
checkpointer_scheduler: &CheckpointScheduler<LR>,
impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
pub(crate) fn checkpoint(
&self,
model: &LC::Model,
optim: &LC::Optimizer,
scheduler: &LC::LrScheduler,
epoch: usize,
) {
if let Some(checkpointer) = &checkpointer_model {
checkpointer
.save(epoch, model.clone().into_record())
self.model.save(epoch, model.clone().into_record()).unwrap();
self.optim.save(epoch, optim.to_record()).unwrap();
self.lr_scheduler
.save(epoch, scheduler.to_record())
.unwrap();
}
if let Some(checkpointer) = &checkpointer_optimizer {
checkpointer.save(epoch, optim.to_record()).unwrap();
}
pub(crate) fn load_checkpoint(
&self,
model: LC::Model,
optim: LC::Optimizer,
scheduler: LC::LrScheduler,
epoch: usize,
) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
let record = self.model.restore(epoch).unwrap();
let model = model.load_record(record);
if let Some(checkpointer) = &checkpointer_scheduler {
checkpointer.save(epoch, scheduler.to_record()).unwrap();
}
}
let record = self.optim.restore(epoch).unwrap();
let optim = optim.load_record(record);
pub(super) fn load_checkpoint(mut self, epoch: usize) -> Self {
if let Some(checkpointer) = &self.checkpointer_model {
let record = checkpointer.restore(epoch).unwrap();
self.model = self.model.load_record(record);
}
let record = self.lr_scheduler.restore(epoch).unwrap();
let scheduler = scheduler.load_record(record);
if let Some(checkpointer) = &self.checkpointer_optimizer {
let record = checkpointer.restore(epoch).unwrap();
self.optim = self.optim.load_record(record);
}
if let Some(checkpointer) = &self.checkpointer_scheduler {
let record = checkpointer.restore(epoch).unwrap();
self.lr_scheduler = self.lr_scheduler.load_record(record);
}
self
(model, optim, scheduler)
}
}

View File

@ -1,21 +1,20 @@
use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
use crate::components::LearnerComponentsMarker;
use crate::learner::base::TrainingInterrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::dashboard::{
default_renderer, Dashboard, DashboardRenderer, MetricWrapper, Metrics,
use crate::metric::callback::{
default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer,
};
use crate::metric::{Adaptor, Metric};
use crate::AsyncTrainerCallback;
use burn_core::lr_scheduler::LRScheduler;
use crate::{AsyncTrainerCallback, LearnerCheckpointer};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::ADBackend;
use crate::learner::base::TrainingInterrupter;
use std::sync::Arc;
/// Struct to configure and create a [learner](Learner).
pub struct LearnerBuilder<B, T, V, M, O, S>
where
@ -24,11 +23,17 @@ where
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
S: LRScheduler,
S: LrScheduler,
{
checkpointer_model: Option<Arc<dyn Checkpointer<M::Record> + Send + Sync>>,
checkpointer_optimizer: Option<Arc<dyn Checkpointer<O::Record> + Send + Sync>>,
checkpointer_scheduler: Option<Arc<dyn Checkpointer<S::Record> + Send + Sync>>,
// Not that complex and very convenient when the traits are
// already constrained correctly. Extracting in another type
// would be more complex.
#[allow(clippy::type_complexity)]
checkpointers: Option<(
AsyncCheckpointer<M::Record>,
AsyncCheckpointer<O::Record>,
AsyncCheckpointer<S::Record>,
)>,
num_epochs: usize,
checkpoint: Option<usize>,
directory: String,
@ -36,20 +41,20 @@ where
devices: Vec<B::Device>,
metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
renderer: Option<Box<dyn DashboardRenderer + 'static>>,
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
metrics: Metrics<T, V>,
interrupter: TrainingInterrupter,
log_to_file: bool,
}
impl<B, T, V, Model, Optim, LR> LearnerBuilder<B, T, V, Model, Optim, LR>
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where
B: ADBackend,
T: Send + Sync + 'static,
V: Send + Sync + 'static,
B: ADBackend,
Model: ADModule<B>,
Optim: Optimizer<Model, B>,
LR: LRScheduler,
M: ADModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler,
{
/// Creates a new learner builder.
///
@ -60,9 +65,7 @@ where
Self {
num_epochs: 1,
checkpoint: None,
checkpointer_model: None,
checkpointer_optimizer: None,
checkpointer_scheduler: None,
checkpointers: None,
directory: directory.to_string(),
grad_accumulation: None,
devices: vec![B::Device::default()],
@ -96,18 +99,18 @@ where
/// # Arguments
///
/// * `renderer` - The custom renderer.
pub fn renderer<DR>(mut self, renderer: DR) -> Self
pub fn renderer<MR>(mut self, renderer: MR) -> Self
where
DR: DashboardRenderer + 'static,
MR: MetricsRenderer + 'static,
{
self.renderer = Some(Box::new(renderer));
self
}
/// Register a training metric.
pub fn metric_train<M: Metric + 'static>(mut self, metric: M) -> Self
pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
T: Adaptor<M::Input>,
T: Adaptor<Me::Input>,
{
self.metrics
.train
@ -116,9 +119,9 @@ where
}
/// Register a validation metric.
pub fn metric_valid<M: Metric + 'static>(mut self, metric: M) -> Self
pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
where
V: Adaptor<M::Input>,
V: Adaptor<Me::Input>,
{
self.metrics
.valid
@ -148,10 +151,10 @@ where
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
/// the same graph will be used for both.
pub fn metric_train_plot<M>(mut self, metric: M) -> Self
pub fn metric_train_plot<Me>(mut self, metric: Me) -> Self
where
M: Metric + crate::metric::Numeric + 'static,
T: Adaptor<M::Input>,
Me: Metric + crate::metric::Numeric + 'static,
T: Adaptor<Me::Input>,
{
self.metrics
.train_numeric
@ -166,12 +169,12 @@ where
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
/// the same graph will be used for both.
pub fn metric_valid_plot<M: Metric + crate::metric::Numeric + 'static>(
pub fn metric_valid_plot<Me: Metric + crate::metric::Numeric + 'static>(
mut self,
metric: M,
metric: Me,
) -> Self
where
V: Adaptor<M::Input>,
V: Adaptor<Me::Input>,
{
self.metrics
.valid_numeric
@ -219,41 +222,64 @@ where
pub fn with_file_checkpointer<FR>(mut self, num_keep: usize, recorder: FR) -> Self
where
FR: FileRecorder + 'static,
O::Record: 'static,
M::Record: 'static,
S::Record: 'static,
{
self.checkpointer_model = Some(Arc::new(FileCheckpointer::new(
let checkpointer_model = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"model",
num_keep,
)));
self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::new(
);
let checkpointer_optimizer = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"optim",
num_keep,
)));
self.checkpointer_scheduler = Some(Arc::new(FileCheckpointer::new(
);
let checkpointer_scheduler = FileCheckpointer::new(
recorder,
format!("{}/checkpoint", self.directory).as_str(),
"scheduler",
num_keep,
)));
);
self.checkpointers = Some((
AsyncCheckpointer::new(checkpointer_model),
AsyncCheckpointer::new(checkpointer_optimizer),
AsyncCheckpointer::new(checkpointer_scheduler),
));
self
}
/// Create the [learner](Learner) from a [model](ADModule) and an [optimizer](Optimizer).
/// The [learning rate scheduler](LRScheduler) can also be a simple
/// The [learning rate scheduler](LrScheduler) can also be a simple
/// [learning rate](burn_core::LearningRate).
#[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
// creates a clean learner.
pub fn build(
self,
model: Model,
optim: Optim,
lr_scheduler: LR,
) -> Learner<B, Model, Optim, LR, T, V>
model: M,
optim: O,
lr_scheduler: S,
) -> Learner<
LearnerComponentsMarker<
B,
S,
M,
O,
AsyncCheckpointer<M::Record>,
AsyncCheckpointer<O::Record>,
AsyncCheckpointer<S::Record>,
AsyncTrainerCallback<T, V>,
>,
>
where
Model::Record: 'static,
Optim::Record: 'static,
LR::Record: 'static,
M::Record: 'static,
O::Record: 'static,
S::Record: 'static,
{
if self.log_to_file {
self.init_logger();
@ -268,45 +294,25 @@ where
let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
});
let dashboard = Dashboard::new(renderer, self.metrics, logger_train, logger_valid);
let callback = Box::new(dashboard);
let callback = Box::new(AsyncTrainerCallback::new(callback));
let callback = AsyncTrainerCallback::new(MetricsCallback::new(
renderer,
self.metrics,
logger_train,
logger_valid,
));
let checkpointer_optimizer = match self.checkpointer_optimizer {
Some(checkpointer) => {
let checkpointer: Box<dyn Checkpointer<Optim::Record>> =
Box::new(AsyncCheckpointer::new(checkpointer));
Some(checkpointer)
}
None => None,
};
let checkpointer_model = match self.checkpointer_model {
Some(checkpointer) => {
let checkpointer: Box<dyn Checkpointer<Model::Record>> =
Box::new(AsyncCheckpointer::new(checkpointer));
Some(checkpointer)
}
None => None,
};
let checkpointer_scheduler = match self.checkpointer_scheduler {
Some(checkpointer) => {
let checkpointer: Box<dyn Checkpointer<LR::Record>> =
Box::new(AsyncCheckpointer::new(checkpointer));
Some(checkpointer)
}
None => None,
};
let checkpointer = self
.checkpointers
.map(|(model, optim, scheduler)| LearnerCheckpointer::new(model, optim, scheduler));
Learner {
model,
optim,
lr_scheduler,
checkpointer,
num_epochs: self.num_epochs,
callback,
checkpoint: self.checkpoint,
checkpointer_model,
checkpointer_optimizer,
checkpointer_scheduler,
grad_accumulation: self.grad_accumulation,
devices: self.devices,
interrupter: self.interrupter,

View File

@ -1,6 +1,6 @@
use burn_core::{
data::dataloader::DataLoader,
lr_scheduler::LRScheduler,
lr_scheduler::LrScheduler,
module::ADModule,
optim::{GradientsAccumulator, Optimizer},
tensor::backend::ADBackend,
@ -27,7 +27,7 @@ pub struct TrainEpoch<TI> {
grad_accumulation: Option<usize>,
}
impl<I> ValidEpoch<I> {
impl<VI> ValidEpoch<VI> {
/// Runs the validation epoch.
///
/// # Arguments
@ -37,12 +37,12 @@ impl<I> ValidEpoch<I> {
pub fn run<B, M, TO, VO>(
&self,
model: &M,
callback: &mut Box<dyn LearnerCallback<TO, VO>>,
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
interrupter: &TrainingInterrupter,
) where
B: ADBackend,
M: ADModule<B>,
M::InnerModule: ValidStep<I, VO>,
M::InnerModule: ValidStep<VI, VO>,
{
log::info!("Executing validation step for epoch {}", self.epoch);
let model = model.valid();
@ -92,14 +92,14 @@ impl<TI> TrainEpoch<TI> {
mut model: M,
mut optim: O,
scheduler: &mut LR,
callback: &mut Box<dyn LearnerCallback<TO, VO>>,
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
interrupter: &TrainingInterrupter,
) -> (M, O)
where
B: ADBackend,
M: TrainStep<TI, TO> + ADModule<B>,
O: Optimizer<M, B>,
LR: LRScheduler,
LR: LrScheduler,
{
log::info!("Executing training step for epoch {}", self.epoch,);
@ -170,7 +170,7 @@ impl<TI> TrainEpoch<TI> {
mut model: M,
mut optim: O,
lr_scheduler: &mut S,
callback: &mut Box<dyn LearnerCallback<TO, VO>>,
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
devices: Vec<B::Device>,
interrupter: &TrainingInterrupter,
) -> (M, O)
@ -179,7 +179,7 @@ impl<TI> TrainEpoch<TI> {
M: ADModule<B> + 'static,
O: Optimizer<M, B>,
M: TrainStep<TI, TO>,
S: LRScheduler,
S: LrScheduler,
TI: Send + 'static,
TO: Send + 'static,
{

View File

@ -1 +0,0 @@

View File

@ -1,9 +1,7 @@
use super::Learner;
use crate::{TrainEpoch, ValidEpoch};
use crate::components::LearnerComponents;
use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::lr_scheduler::LRScheduler;
use burn_core::module::ADModule;
use burn_core::module::{ADModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::ADBackend;
use std::sync::Arc;
@ -94,15 +92,7 @@ pub trait ValidStep<VI, VO> {
fn step(&self, item: VI) -> VO;
}
impl<B, M, O, LR, TO, VO> Learner<B, M, O, LR, TO, VO>
where
VO: Send + Sync + 'static,
TO: Send + Sync + 'static,
B: ADBackend,
M: ADModule<B> + core::fmt::Display,
O: Optimizer<M, B>,
LR: LRScheduler,
{
impl<LC: LearnerComponents> Learner<LC> {
/// Fits the model.
///
/// # Arguments
@ -113,16 +103,19 @@ where
/// # Returns
///
/// The fitted model.
pub fn fit<TI, VI>(
pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
mut self,
dataloader_train: Arc<dyn DataLoader<TI>>,
dataloader_valid: Arc<dyn DataLoader<VI>>,
) -> M
dataloader_train: Arc<dyn DataLoader<InputTrain>>,
dataloader_valid: Arc<dyn DataLoader<InputValid>>,
) -> LC::Model
where
TI: Send + 'static,
TO: Send + 'static,
M: TrainStep<TI, TO> + Send + Clone + 'static,
M::InnerModule: ValidStep<VI, VO>,
InputTrain: Send + 'static,
InputValid: Send,
OutputTrain: Send + 'static,
OutputValid: Send,
LC::Model: TrainStep<InputTrain, OutputTrain>,
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
LC::Callback: LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
{
log::info!("Fitting {}", self.model.to_string());
// The reference model is always on the first device provided.
@ -132,14 +125,22 @@ where
let starting_epoch = match self.checkpoint {
Some(checkpoint) => {
self = self.load_checkpoint(checkpoint);
if let Some(checkpointer) = &self.checkpointer {
(self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
self.model,
self.optim,
self.lr_scheduler,
checkpoint,
);
}
checkpoint + 1
}
None => 1,
};
let mut model = self.model;
let mut optim = self.optim;
let mut callback: Box<
dyn LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
> = Box::new(self.callback);
for epoch in starting_epoch..self.num_epochs + 1 {
let epoch_train = TrainEpoch::new(
@ -150,20 +151,20 @@ where
);
if self.devices.len() > 1 {
(model, optim) = epoch_train.run_multi_device(
model,
optim,
(self.model, self.optim) = epoch_train.run_multi_device(
self.model,
self.optim,
&mut self.lr_scheduler,
&mut self.callback,
&mut callback,
self.devices.clone(),
&self.interrupter,
)
} else {
(model, optim) = epoch_train.run(
model,
optim,
(self.model, self.optim) = epoch_train.run(
self.model,
self.optim,
&mut self.lr_scheduler,
&mut self.callback,
&mut callback,
&self.interrupter,
);
}
@ -173,19 +174,13 @@ where
}
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
epoch_valid.run(&model, &mut self.callback, &self.interrupter);
epoch_valid.run(&self.model, &mut callback, &self.interrupter);
Self::checkpoint(
&model,
&optim,
&self.lr_scheduler,
&self.checkpointer_model,
&self.checkpointer_optimizer,
&self.checkpointer_scheduler,
epoch,
);
if let Some(checkpointer) = &self.checkpointer {
checkpointer.checkpoint(&self.model, &self.optim, &self.lr_scheduler, epoch);
}
}
model
self.model
}
}

View File

@ -8,6 +8,8 @@ extern crate derive_new;
/// The checkpoint module.
pub mod checkpoint;
pub(crate) mod components;
/// The logger module.
pub mod logger;

View File

@ -5,6 +5,115 @@ use crate::{
};
use burn_core::data::dataloader::Progress;
/// Holds all metrics, metric loggers, and a metrics renderer.
pub struct MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
renderer: Box<dyn MetricsRenderer>,
}
impl<T, V> MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Creates a new metrics callback.
///
/// # Arguments
///
/// * `renderer` - The metrics renderer.
/// * `metrics` - The metrics holder.
/// * `logger_train` - The training logger.
/// * `logger_valid` - The validation logger.
///
/// # Returns
///
/// A new metrics callback.
pub(crate) fn new(
renderer: Box<dyn MetricsRenderer>,
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
) -> Self {
Self {
metrics,
logger_train,
logger_valid,
renderer,
}
}
}
impl<T, V> LearnerCallback for MetricsCallback<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
type ItemTrain = T;
type ItemValid = V;
fn on_train_item(&mut self, item: LearnerItem<T>) {
let metadata = (&item).into();
for metric in self.metrics.train.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer.update_train(MetricState::Generic(state));
}
for metric in self.metrics.train_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer
.update_train(MetricState::Numeric(state, value));
}
self.renderer.render_train(item.into());
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let metadata = (&item).into();
for metric in self.metrics.valid.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer.update_valid(MetricState::Generic(state));
}
for metric in self.metrics.valid_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer
.update_valid(MetricState::Numeric(state, value));
}
self.renderer.render_valid(item.into());
}
fn on_train_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.train.iter_mut() {
metric.clear();
}
for metric in self.metrics.train_numeric.iter_mut() {
metric.clear();
}
self.logger_train.epoch(epoch + 1);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.valid.iter_mut() {
metric.clear();
}
for metric in self.metrics.valid_numeric.iter_mut() {
metric.clear();
}
self.logger_valid.epoch(epoch + 1);
}
}
/// Training progress.
#[derive(Debug)]
pub struct TrainingProgress {
@ -36,9 +145,9 @@ impl TrainingProgress {
}
}
/// A dashboard metric.
/// The state of a metric.
#[derive(Debug)]
pub enum DashboardMetricState {
pub enum MetricState {
/// A generic metric.
Generic(MetricEntry),
@ -46,21 +155,21 @@ pub enum DashboardMetricState {
Numeric(MetricEntry, f64),
}
/// Trait for rendering dashboard metrics.
pub trait DashboardRenderer: Send + Sync {
/// Trait for rendering metrics.
pub trait MetricsRenderer: Send + Sync {
/// Updates the training metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_train(&mut self, state: DashboardMetricState);
fn update_train(&mut self, state: MetricState);
/// Updates the validation metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_valid(&mut self, state: DashboardMetricState);
fn update_valid(&mut self, state: MetricState);
/// Renders the training progress.
///
@ -77,16 +186,16 @@ pub trait DashboardRenderer: Send + Sync {
fn render_valid(&mut self, item: TrainingProgress);
}
/// A container for the metrics held by a dashboard.
/// A container for the metrics held by a metrics callback.
pub(crate) struct Metrics<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub(crate) train: Vec<Box<dyn DashboardMetric<T>>>,
pub(crate) valid: Vec<Box<dyn DashboardMetric<V>>>,
pub(crate) train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
pub(crate) valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
pub(crate) train: Vec<Box<dyn MetricUpdater<T>>>,
pub(crate) valid: Vec<Box<dyn MetricUpdater<V>>>,
pub(crate) train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
pub(crate) valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
}
impl<T, V> Metrics<T, V>
@ -104,50 +213,6 @@ where
}
}
/// Holds all metrics, metric loggers, and a dashboard renderer.
pub struct Dashboard<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
renderer: Box<dyn DashboardRenderer>,
}
impl<T, V> Dashboard<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Creates a new dashboard.
///
/// # Arguments
///
/// * `renderer` - The dashboard renderer.
/// * `metrics` - The dashboard's metrics
/// * `logger_train` - The training logger.
/// * `logger_valid` - The validation logger.
///
/// # Returns
///
/// A new dashboard.
pub(crate) fn new(
renderer: Box<dyn DashboardRenderer>,
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
) -> Self {
Self {
metrics,
logger_train,
logger_valid,
renderer,
}
}
}
impl<T> From<LearnerItem<T>> for TrainingProgress {
fn from(item: LearnerItem<T>) -> Self {
Self {
@ -171,76 +236,12 @@ impl<T> From<&LearnerItem<T>> for MetricMetadata {
}
}
impl<T, V> LearnerCallback<T, V> for Dashboard<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
fn on_train_item(&mut self, item: LearnerItem<T>) {
let metadata = (&item).into();
for metric in self.metrics.train.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer
.update_train(DashboardMetricState::Generic(state));
}
for metric in self.metrics.train_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer
.update_train(DashboardMetricState::Numeric(state, value));
}
self.renderer.render_train(item.into());
}
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let metadata = (&item).into();
for metric in self.metrics.valid.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer
.update_valid(DashboardMetricState::Generic(state));
}
for metric in self.metrics.valid_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer
.update_valid(DashboardMetricState::Numeric(state, value));
}
self.renderer.render_valid(item.into());
}
fn on_train_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.train.iter_mut() {
metric.clear();
}
for metric in self.metrics.train_numeric.iter_mut() {
metric.clear();
}
self.logger_train.epoch(epoch + 1);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics.valid.iter_mut() {
metric.clear();
}
for metric in self.metrics.valid_numeric.iter_mut() {
metric.clear();
}
self.logger_valid.epoch(epoch + 1);
}
}
pub(crate) trait DashboardNumericMetric<T>: Send + Sync {
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
fn clear(&mut self);
}
pub(crate) trait DashboardMetric<T>: Send + Sync {
pub(crate) trait MetricUpdater<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
fn clear(&mut self);
}
@ -250,7 +251,7 @@ pub(crate) struct MetricWrapper<M> {
metric: M,
}
impl<T, M> DashboardNumericMetric<T> for MetricWrapper<M>
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + Numeric + 'static,
@ -268,7 +269,7 @@ where
}
}
impl<T, M> DashboardMetric<T> for MetricWrapper<M>
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
where
T: 'static,
M: Metric + 'static,

View File

@ -0,0 +1,25 @@
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
/// A simple renderer for when the cli feature is not enabled.
pub struct CliMetricsRenderer;
impl CliMetricsRenderer {
/// Create a new instance.
pub fn new() -> Self {
Self {}
}
}
impl MetricsRenderer for CliMetricsRenderer {
fn update_train(&mut self, _state: MetricState) {}
fn update_valid(&mut self, _state: MetricState) {}
fn render_train(&mut self, item: TrainingProgress) {
dbg!(item);
}
fn render_valid(&mut self, item: TrainingProgress) {
dbg!(item);
}
}

View File

@ -3,25 +3,25 @@ mod base;
pub use base::*;
#[cfg(not(feature = "tui"))]
mod cli_stub;
mod cli;
#[cfg(not(feature = "tui"))]
pub use cli_stub::CLIDashboardRenderer as SelectedDashboardRenderer;
pub use cli::CliMetricsRenderer as SelectedMetricsRenderer;
#[cfg(feature = "tui")]
mod tui;
use crate::TrainingInterrupter;
#[cfg(feature = "tui")]
pub use tui::TuiDashboardRenderer as SelectedDashboardRenderer;
pub use tui::TuiMetricsRenderer as SelectedMetricsRenderer;
/// The TUI renderer, or a simple stub if the tui feature is not enabled.
#[allow(unused_variables)]
pub(crate) fn default_renderer(
interuptor: TrainingInterrupter,
checkpoint: Option<usize>,
) -> SelectedDashboardRenderer {
) -> SelectedMetricsRenderer {
#[cfg(feature = "tui")]
return SelectedDashboardRenderer::new(interuptor, checkpoint);
return SelectedMetricsRenderer::new(interuptor, checkpoint);
#[cfg(not(feature = "tui"))]
return SelectedDashboardRenderer::new();
return SelectedMetricsRenderer::new();
}

View File

@ -4,7 +4,7 @@ use super::{
use ratatui::prelude::{Constraint, Direction, Layout, Rect};
#[derive(new)]
pub(crate) struct DashboardView<'a> {
pub(crate) struct MetricsView<'a> {
metric_numeric: NumericMetricView<'a>,
metric_text: TextMetricView,
progress: ProgressBarView,
@ -12,7 +12,7 @@ pub(crate) struct DashboardView<'a> {
status: StatusView,
}
impl<'a> DashboardView<'a> {
impl<'a> MetricsView<'a> {
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let chunks = Layout::default()
.direction(Direction::Vertical)

View File

@ -1,4 +1,4 @@
use crate::metric::dashboard::TrainingProgress;
use crate::metric::callback::TrainingProgress;
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
use crossterm::event::{Event, KeyCode};

View File

@ -1,5 +1,5 @@
use super::TerminalFrame;
use crate::metric::dashboard::TrainingProgress;
use crate::metric::callback::TrainingProgress;
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Style, Stylize},

View File

@ -1,5 +1,5 @@
use crate::metric::dashboard::tui::NumericMetricsState;
use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
use crate::metric::callback::tui::NumericMetricsState;
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
use crate::TrainingInterrupter;
use crossterm::{
event::{self, Event, KeyCode},
@ -14,7 +14,7 @@ use std::{
};
use super::{
Callback, CallbackFn, ControlsView, DashboardView, PopupState, ProgressBarState, StatusState,
Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState,
TextMetricsState,
};
@ -25,8 +25,8 @@ pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a, TerminalBackend>;
const MAX_REFRESH_RATE_MILLIS: u64 = 100;
/// The CLI dashboard renderer.
pub struct TuiDashboardRenderer {
/// The terminal UI metrics renderer.
pub struct TuiMetricsRenderer {
terminal: Terminal<TerminalBackend>,
last_update: std::time::Instant,
progress: ProgressBarState,
@ -37,25 +37,25 @@ pub struct TuiDashboardRenderer {
popup: PopupState,
}
impl DashboardRenderer for TuiDashboardRenderer {
fn update_train(&mut self, state: DashboardMetricState) {
impl MetricsRenderer for TuiMetricsRenderer {
fn update_train(&mut self, state: MetricState) {
match state {
DashboardMetricState::Generic(entry) => {
MetricState::Generic(entry) => {
self.metrics_text.update_train(entry);
}
DashboardMetricState::Numeric(entry, value) => {
MetricState::Numeric(entry, value) => {
self.metrics_numeric.push_train(entry.name.clone(), value);
self.metrics_text.update_train(entry);
}
};
}
fn update_valid(&mut self, state: DashboardMetricState) {
fn update_valid(&mut self, state: MetricState) {
match state {
DashboardMetricState::Generic(entry) => {
MetricState::Generic(entry) => {
self.metrics_text.update_valid(entry);
}
DashboardMetricState::Numeric(entry, value) => {
MetricState::Numeric(entry, value) => {
self.metrics_numeric.push_valid(entry.name.clone(), value);
self.metrics_text.update_valid(entry);
}
@ -77,8 +77,8 @@ impl DashboardRenderer for TuiDashboardRenderer {
}
}
impl TuiDashboardRenderer {
/// Create a new CLI dashboard renderer.
impl TuiMetricsRenderer {
/// Create a new terminal UI renderer.
pub fn new(interuptor: TrainingInterrupter, checkpoint: Option<usize>) -> Self {
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen).unwrap();
@ -118,7 +118,7 @@ impl TuiDashboardRenderer {
match self.popup.view() {
Some(view) => view.render(frame, size),
None => {
let view = DashboardView::new(
let view = MetricsView::new(
self.metrics_numeric.view(),
self.metrics_text.view(),
self.progress.view(),
@ -194,7 +194,7 @@ impl CallbackFn for PopupCancel {
}
}
impl Drop for TuiDashboardRenderer {
impl Drop for TuiMetricsRenderer {
fn drop(&mut self) {
disable_raw_mode().ok();
execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();

View File

@ -1,5 +1,5 @@
use super::TerminalFrame;
use crate::metric::dashboard::TrainingProgress;
use crate::metric::callback::TrainingProgress;
use ratatui::{
prelude::{Alignment, Rect},
style::{Color, Style, Stylize},

View File

@ -1,25 +0,0 @@
use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
/// A simple renderer for when the cli feature is not enabled.
pub struct CLIDashboardRenderer;
impl CLIDashboardRenderer {
/// Create a new instance.
pub fn new() -> Self {
Self {}
}
}
impl DashboardRenderer for CLIDashboardRenderer {
fn update_train(&mut self, _state: DashboardMetricState) {}
fn update_valid(&mut self, _state: DashboardMetricState) {}
fn render_train(&mut self, item: TrainingProgress) {
dbg!(item);
}
fn render_valid(&mut self, item: TrainingProgress) {
dbg!(item);
}
}

View File

@ -1,7 +1,7 @@
/// Dashboard module for training progress.
pub mod dashboard;
/// Callback module for training progress.
pub mod callback;
/// State module for dashboard metrics.
/// State module for callback metrics.
pub mod state;
mod acc;

View File

@ -1,5 +1,5 @@
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::train::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
@ -25,10 +25,10 @@ pub struct MnistTrainingConfig {
struct CustomRenderer {}
impl DashboardRenderer for CustomRenderer {
fn update_train(&mut self, _state: DashboardMetricState) {}
impl MetricsRenderer for CustomRenderer {
fn update_train(&mut self, _state: MetricState) {}
fn update_valid(&mut self, _state: DashboardMetricState) {}
fn update_valid(&mut self, _state: MetricState) {}
fn render_train(&mut self, item: TrainingProgress) {
dbg!(item);

View File

@ -12,7 +12,7 @@ use crate::{
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
lr_scheduler::noam::NoamLRSchedulerConfig,
lr_scheduler::noam::NoamLrSchedulerConfig,
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::AdamConfig,
@ -84,7 +84,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let optim = config.optimizer.init();
// Initialize learning rate scheduler
let lr_scheduler = NoamLRSchedulerConfig::new(0.25)
let lr_scheduler = NoamLrSchedulerConfig::new(0.25)
.with_warmup_steps(1000)
.with_model_size(config.transformer.d_model)
.init();

View File

@ -6,7 +6,7 @@ use burn::data::dataset::transform::SamplerDataset;
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
lr_scheduler::noam::NoamLRSchedulerConfig,
lr_scheduler::noam::NoamLrSchedulerConfig,
module::Module,
nn::transformer::TransformerEncoderConfig,
optim::AdamConfig,
@ -62,7 +62,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let accum = 6; // Effective batch size = 6 * 6 = 32.
let optim = config.optimizer.init();
let lr_scheduler = NoamLRSchedulerConfig::new(0.01 / accum as f64)
let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64)
.with_warmup_steps(6000)
.with_model_size(config.transformer.d_model)
.init();