mirror of https://github.com/tracel-ai/burn.git
Refactor burn-train (#847)
This commit is contained in:
parent 9afc76303f
commit 904ff1a974
@@ -63,6 +63,6 @@ pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
 /// Type alias for the learning rate.
 ///
-/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LRScheduler) so it
+/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
 /// can be used for constant learning rate.
 pub type LearningRate = f64; // We could potentially change the type.
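Because `LearningRate` is a plain `f64` that itself implements `LrScheduler`, a constant rate needs no wrapper type. A minimal sketch, assuming the `burn::LearningRate` and `burn::lr_scheduler::LrScheduler` re-exports:

```rust
use burn::lr_scheduler::LrScheduler;
use burn::LearningRate;

// Drive any scheduler for a few steps; a bare f64 works because
// `LearningRate = f64` implements `LrScheduler` (see the hunk above).
fn first_steps<S: LrScheduler>(mut scheduler: S, n: usize) -> Vec<LearningRate> {
    (0..n).map(|_| scheduler.step()).collect()
}

fn main() {
    let lrs = first_steps(1.0e-3, 3);
    assert_eq!(lrs, vec![1.0e-3; 3]); // constant schedule
}
```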
@@ -1,7 +1,7 @@
 use crate::{record::Record, LearningRate};

 /// Learning rate scheduler defines how the learning rate will evolve during training.
-pub trait LRScheduler: Send + Sync {
+pub trait LrScheduler: Send + Sync {
     /// Scheduler associated type to be used when saving and loading the state.
     type Record: Record;

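For a custom schedule, only `step` carries logic; the `Record` machinery exists so the checkpointer can snapshot scheduler state. A sketch, assuming the trait's full surface is `step`/`to_record`/`load_record` as used elsewhere in this diff:

```rust
use burn::lr_scheduler::LrScheduler;
use burn::LearningRate;

// Hypothetical exponential-decay scheduler, for illustration only.
#[derive(Clone, Debug)]
pub struct ExponentialLr {
    init_lr: LearningRate,
    gamma: f64,
    step: usize,
}

impl LrScheduler for ExponentialLr {
    // The step counter is the whole state (same choice as the Noam scheduler).
    type Record = usize;

    fn step(&mut self) -> LearningRate {
        self.step += 1;
        self.init_lr * self.gamma.powi(self.step as i32)
    }

    // Assumed signatures, inferred from to_record()/load_record() calls
    // in the checkpointing hunks further down this diff.
    fn to_record(&self) -> Self::Record {
        self.step
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.step = record;
        self
    }
}
```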
@@ -1,23 +1,23 @@
-use super::LRScheduler;
+use super::LrScheduler;
 use crate::LearningRate;

-/// Constant learning rate implementing [learning rate scheduler](LRScheduler).
+/// Constant learning rate implementing [learning rate scheduler](LrScheduler).
 ///
 /// # Notes
 ///
 /// You can also use [learning rate](LearningRate) which has the same effect.
 #[derive(new, Clone, Debug)]
-pub struct ConstantLR {
+pub struct ConstantLr {
     lr: LearningRate,
 }

-impl From<LearningRate> for ConstantLR {
+impl From<LearningRate> for ConstantLr {
     fn from(lr: LearningRate) -> Self {
         Self { lr }
     }
 }

-impl LRScheduler for ConstantLR {
+impl LrScheduler for ConstantLr {
     type Record = ();

     fn step(&mut self) -> LearningRate {

@@ -31,7 +31,7 @@ impl LRScheduler for ConstantLR {
     }
 }

-impl LRScheduler for LearningRate {
+impl LrScheduler for LearningRate {
     type Record = ();

     fn step(&mut self) -> LearningRate {
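Both construction paths produce the same schedule; `new` comes from `#[derive(new)]` and the `From` impl accepts a plain float. A sketch, with the `constant` module path assumed:

```rust
use burn::lr_scheduler::constant::ConstantLr;

fn main() {
    let _a = ConstantLr::new(1.0e-3); // via #[derive(new)]
    let _b = ConstantLr::from(1.0e-3); // via the From<LearningRate> impl
}
```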
@@ -1,11 +1,11 @@
 use crate as burn;

-use super::LRScheduler;
+use super::LrScheduler;
 use crate::{config::Config, LearningRate};

-/// Configuration to create a [noam](NoamLRScheduler) learning rate scheduler.
+/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
 #[derive(Config)]
-pub struct NoamLRSchedulerConfig {
+pub struct NoamLrSchedulerConfig {
     /// The initial learning rate.
     init_lr: LearningRate,
     /// The number of steps before the exponential decay starts.

@@ -18,17 +18,17 @@ pub struct NoamLRSchedulerConfig {

 /// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
 #[derive(Clone, Debug)]
-pub struct NoamLRScheduler {
+pub struct NoamLrScheduler {
     warmup_steps: f64,
     embedding_size: f64,
     init_lr: LearningRate,
     step: f64,
 }

-impl NoamLRSchedulerConfig {
-    /// Initialize a new [noam](NoamLRScheduler) learning rate scheduler.
-    pub fn init(&self) -> NoamLRScheduler {
-        NoamLRScheduler {
+impl NoamLrSchedulerConfig {
+    /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.
+    pub fn init(&self) -> NoamLrScheduler {
+        NoamLrScheduler {
             warmup_steps: self.warmup_steps as f64,
             embedding_size: self.model_size as f64,
             init_lr: self.init_lr,

@@ -37,7 +37,7 @@ impl NoamLRSchedulerConfig {
     }
 }

-impl LRScheduler for NoamLRScheduler {
+impl LrScheduler for NoamLrScheduler {
     type Record = usize;

     fn step(&mut self) -> LearningRate {

@@ -66,7 +66,7 @@ mod tests {
     #[test]
     fn test_function_increase_and_decrease() {
         let warmup_steps = 100;
-        let mut scheduler = NoamLRSchedulerConfig::new(10.0)
+        let mut scheduler = NoamLrSchedulerConfig::new(10.0)
             .with_warmup_steps(warmup_steps)
             .init();
         let mut lr_current = 0.0;
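For reference, the Noam schedule from the paper scales the base rate by d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)): it ramps up linearly during warmup, then decays with the inverse square root of the step. A usage sketch, with values taken from the text-generation example further down this diff:

```rust
use burn::lr_scheduler::noam::NoamLrSchedulerConfig;

fn main() {
    // Warm up over 6000 steps, scaled to the transformer width.
    let _scheduler = NoamLrSchedulerConfig::new(0.01)
        .with_warmup_steps(6000)
        .with_model_size(512) // feeds the d_model^-0.5 factor
        .init();
}
```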
@@ -1,8 +1,5 @@
 use super::{LearnerCallback, LearnerItem};
-use std::{
-    sync::{mpsc, Mutex},
-    thread::JoinHandle,
-};
+use std::{sync::mpsc, thread::JoinHandle};

 enum Message<T, V> {
     LogTrain(LearnerItem<T>),

@@ -19,30 +16,29 @@ pub struct AsyncTrainerCallback<T, V> {
 }

 #[derive(new)]
-struct CallbackThread<T, V> {
-    callback: Mutex<Box<dyn LearnerCallback<T, V>>>,
+struct CallbackThread<C, T, V> {
+    callback: C,
     receiver: mpsc::Receiver<Message<T, V>>,
 }

-impl<T, V> CallbackThread<T, V> {
-    fn run(self) {
+impl<C, T, V> CallbackThread<C, T, V>
+where
+    C: LearnerCallback<ItemTrain = T, ItemValid = V>,
+{
+    fn run(mut self) {
         for item in self.receiver.iter() {
             match item {
                 Message::LogTrain(item) => {
-                    let mut callback = self.callback.lock().unwrap();
-                    callback.on_train_item(item);
+                    self.callback.on_train_item(item);
                 }
                 Message::ClearTrain(epoch) => {
-                    let mut callback = self.callback.lock().unwrap();
-                    callback.on_train_end_epoch(epoch);
+                    self.callback.on_train_end_epoch(epoch);
                 }
                 Message::LogValid(item) => {
-                    let mut callback = self.callback.lock().unwrap();
-                    callback.on_valid_item(item);
+                    self.callback.on_valid_item(item);
                 }
                 Message::ClearValid(epoch) => {
-                    let mut callback = self.callback.lock().unwrap();
-                    callback.on_valid_end_epoch(epoch);
+                    self.callback.on_valid_end_epoch(epoch);
                 }
                 Message::End => {
                     return;

@@ -54,9 +50,12 @@ impl<T, V> CallbackThread<T, V> {

 impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
     /// Create a new async trainer callback.
-    pub fn new(callback: Box<dyn LearnerCallback<T, V>>) -> Self {
+    pub fn new<C>(callback: C) -> Self
+    where
+        C: LearnerCallback<ItemTrain = T, ItemValid = V> + 'static,
+    {
         let (sender, receiver) = mpsc::channel();
-        let thread = CallbackThread::new(Mutex::new(callback), receiver);
+        let thread = CallbackThread::new(callback, receiver);

         let handler = std::thread::spawn(move || thread.run());
         let handler = Some(handler);

@@ -65,7 +64,10 @@ impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T,
     }
 }

-impl<T: Send, V: Send> LearnerCallback<T, V> for AsyncTrainerCallback<T, V> {
+impl<T: Send, V: Send> LearnerCallback for AsyncTrainerCallback<T, V> {
+    type ItemTrain = T;
+    type ItemValid = V;
+
     fn on_train_item(&mut self, item: LearnerItem<T>) {
         self.sender.send(Message::LogTrain(item)).unwrap();
     }
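The `Mutex` disappears because the callback is now moved into the worker thread, which then owns it exclusively; `run` takes `mut self` so the handlers can borrow mutably without locking. A self-contained sketch of that ownership pattern, using only the standard library:

```rust
use std::{sync::mpsc, thread};

// The worker thread owns its state outright, so no Mutex is required.
fn spawn_worker<C: FnMut(u32) + Send + 'static>(mut on_item: C) -> mpsc::Sender<u32> {
    let (sender, receiver) = mpsc::channel();
    thread::spawn(move || {
        for item in receiver.iter() {
            on_item(item); // exclusive access: `on_item` was moved into this thread
        }
    });
    sender
}

fn main() {
    let sender = spawn_worker(|item| println!("got {item}"));
    sender.send(42).unwrap();
}
```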
@@ -1,12 +1,17 @@
 use burn_core::{data::dataloader::Progress, LearningRate};

 /// The base trait for trainer callbacks.
-pub trait LearnerCallback<T, V>: Send {
+pub trait LearnerCallback: Send {
+    /// Training item.
+    type ItemTrain;
+    /// Validation item.
+    type ItemValid;
+
     /// Called when a training item is logged.
-    fn on_train_item(&mut self, _item: LearnerItem<T>) {}
+    fn on_train_item(&mut self, _item: LearnerItem<Self::ItemTrain>) {}

     /// Called when a validation item is logged.
-    fn on_valid_item(&mut self, _item: LearnerItem<V>) {}
+    fn on_valid_item(&mut self, _item: LearnerItem<Self::ItemValid>) {}

     /// Called when a training epoch is finished.
     fn on_train_end_epoch(&mut self, _epoch: usize) {}
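With associated types, an implementor names its item types once instead of every use site carrying `<T, V>`. A minimal sketch, with module paths and a bare-`f64` training item assumed for illustration:

```rust
use burn::train::{LearnerCallback, LearnerItem};

struct LossLogger;

impl LearnerCallback for LossLogger {
    type ItemTrain = f64; // hypothetical training output: a bare loss value
    type ItemValid = f64;

    fn on_train_item(&mut self, _item: LearnerItem<f64>) {
        // inspect or log the training item here
    }
    // the remaining hooks keep their default empty bodies
}
```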
@@ -1,22 +1,27 @@
 use super::{Checkpointer, CheckpointerError};
 use burn_core::record::Record;
-use std::sync::{mpsc, Arc};
+use std::sync::mpsc;

 enum Message<R> {
     Restore(usize, mpsc::SyncSender<Result<R, CheckpointerError>>),
     Save(usize, R),
     End,
 }

 #[derive(new)]
-struct CheckpointerThread<R> {
-    checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>,
+struct CheckpointerThread<C, R> {
+    checkpointer: C,
     receiver: mpsc::Receiver<Message<R>>,
 }

-impl<R: Record> CheckpointerThread<R> {
+impl<C: Checkpointer<R>, R: Record> CheckpointerThread<C, R> {
     fn run(self) {
         for item in self.receiver.iter() {
             match item {
                 Message::Restore(epoch, sender) => {
                     let record = self.checkpointer.restore(epoch);
                     sender.send(record).unwrap();
                 }
                 Message::Save(epoch, state) => self.checkpointer.save(epoch, state).unwrap(),
                 Message::End => {
                     return;

@@ -27,9 +32,8 @@ impl<R: Record> CheckpointerThread<R> {
 }

 /// Async checkpointer.
-pub struct AsyncCheckpointer<E> {
-    checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>,
-    sender: mpsc::SyncSender<Message<E>>,
+pub struct AsyncCheckpointer<Record> {
+    sender: mpsc::SyncSender<Message<Record>>,
     handler: Option<std::thread::JoinHandle<()>>,
 }

@@ -43,17 +47,16 @@ impl<R: Record + 'static> AsyncCheckpointer<R> {
     /// # Returns
     ///
     /// The async checkpointer.
-    pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
+    pub fn new<C>(checkpointer: C) -> Self
+    where
+        C: Checkpointer<R> + Send + 'static,
+    {
         // Only one checkpoint can be done in advance.
         let (sender, receiver) = mpsc::sync_channel(0);
-        let thread = CheckpointerThread::new(checkpointer.clone(), receiver);
+        let thread = CheckpointerThread::new(checkpointer, receiver);
         let handler = Some(std::thread::spawn(move || thread.run()));

-        Self {
-            checkpointer,
-            sender,
-            handler,
-        }
+        Self { sender, handler }
     }
 }

@@ -68,7 +71,16 @@ where
     }

     fn restore(&self, epoch: usize) -> Result<R, CheckpointerError> {
-        self.checkpointer.restore(epoch)
+        let (sender, receiver) = mpsc::sync_channel(1);
+        self.sender
+            .send(Message::Restore(epoch, sender))
+            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
+
+        if let Ok(record) = receiver.recv() {
+            return record;
+        };
+
+        Err(CheckpointerError::Unknown("Channel error.".to_string()))
     }
 }
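Construction no longer requires an `Arc`: the concrete checkpointer is moved into the worker thread, and `restore` now round-trips through the channel (serialized behind any pending save) instead of hitting the checkpointer directly. A sketch mirroring the builder hunk further down:

```rust
// Assumes a FileRecorder named `recorder`; arguments follow the builder hunk
// below: (recorder, directory, file name, number of checkpoints to keep).
let checkpointer_model = FileCheckpointer::new(
    recorder,
    "/tmp/checkpoint", // the builder uses format!("{}/checkpoint", directory)
    "model",
    2,
);
let checkpointer_model = AsyncCheckpointer::new(checkpointer_model);
```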
@@ -0,0 +1,64 @@
+use crate::{checkpoint::Checkpointer, LearnerCallback};
+use burn_core::{
+    lr_scheduler::LrScheduler,
+    module::{ADModule, Module},
+    optim::Optimizer,
+    tensor::backend::ADBackend,
+};
+use std::marker::PhantomData;
+
+/// All components necessary to train a model grouped in one trait.
+pub trait LearnerComponents {
+    /// The backend used for the training.
+    type Backend: ADBackend;
+    /// The learning rate scheduler used for the training.
+    type LrScheduler: LrScheduler;
+    /// The model to train.
+    type Model: ADModule<Self::Backend> + core::fmt::Display + 'static;
+    /// The optimizer used for the training.
+    type Optimizer: Optimizer<Self::Model, Self::Backend>;
+    /// The checkpointer used for the model.
+    type CheckpointerModel: Checkpointer<<Self::Model as Module<Self::Backend>>::Record>;
+    /// The checkpointer used for the optimizer.
+    type CheckpointerOptimizer: Checkpointer<
+        <Self::Optimizer as Optimizer<Self::Model, Self::Backend>>::Record,
+    >;
+    /// The checkpointer used for the scheduler.
+    type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
+    /// Callback used for training tracking.
+    type Callback: LearnerCallback + 'static;
+}
+
+/// Concrete type that implements [training components trait](LearnerComponents).
+pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C> {
+    _backend: PhantomData<B>,
+    _lr_scheduler: PhantomData<LR>,
+    _model: PhantomData<M>,
+    _optimizer: PhantomData<O>,
+    _checkpointer_model: PhantomData<CM>,
+    _checkpointer_optim: PhantomData<CO>,
+    _checkpointer_scheduler: PhantomData<CS>,
+    _callback: PhantomData<C>,
+}
+
+impl<B, LR, M, O, CM, CO, CS, C> LearnerComponents
+    for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C>
+where
+    B: ADBackend,
+    LR: LrScheduler,
+    M: ADModule<B> + core::fmt::Display + 'static,
+    O: Optimizer<M, B>,
+    CM: Checkpointer<M::Record>,
+    CO: Checkpointer<O::Record>,
+    CS: Checkpointer<LR::Record>,
+    C: LearnerCallback + 'static,
+{
+    type Backend = B;
+    type LrScheduler = LR;
+    type Model = M;
+    type Optimizer = O;
+    type CheckpointerModel = CM;
+    type CheckpointerOptimizer = CO;
+    type CheckpointerLrScheduler = CS;
+    type Callback = C;
+}
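The marker struct is zero-sized: `PhantomData` records the eight type parameters so the trait's associated types can name them, and downstream types then take one `LC: LearnerComponents` instead of eight generics. A self-contained sketch of the pattern (not the burn types themselves):

```rust
use std::marker::PhantomData;

trait Components {
    type Model;
    type Optimizer;
}

// Zero-sized marker bundling the generic parameters behind one trait.
struct Marker<M, O>(PhantomData<(M, O)>);

impl<M, O> Components for Marker<M, O> {
    type Model = M;
    type Optimizer = O;
}

// A consumer now needs only one type parameter:
struct Trainer<C: Components> {
    model: C::Model,
    optim: C::Optimizer,
}
```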
@@ -1,91 +1,67 @@
-use crate::checkpoint::Checkpointer;
-use crate::LearnerCallback;
-use burn_core::lr_scheduler::LRScheduler;
-use burn_core::module::{ADModule, Module};
+use crate::components::LearnerComponents;
+use burn_core::lr_scheduler::LrScheduler;
+use burn_core::module::Module;
 use burn_core::optim::Optimizer;
-use burn_core::tensor::backend::ADBackend;
+use burn_core::tensor::backend::Backend;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;

 /// Learner struct encapsulating all components necessary to train a Neural Network model.
 ///
 /// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
-pub struct Learner<B, Model, Optim, LR, TrainOutput, ValidOutput>
-where
-    B: ADBackend,
-    LR: LRScheduler,
-    Model: ADModule<B>,
-    Optim: Optimizer<Model, B>,
-{
-    pub(super) model: Model,
-    pub(super) optim: Optim,
-    pub(super) lr_scheduler: LR,
-    pub(super) num_epochs: usize,
-    pub(super) callback: Box<dyn LearnerCallback<TrainOutput, ValidOutput>>,
-    pub(super) checkpoint: Option<usize>,
-    pub(super) checkpointer_model: CheckpointModel<Model, B>,
-    pub(super) checkpointer_optimizer: CheckpointOptim<Optim, Model, B>,
-    pub(super) checkpointer_scheduler: CheckpointScheduler<LR>,
-    pub(super) grad_accumulation: Option<usize>,
-    pub(super) devices: Vec<B::Device>,
-    pub(super) interrupter: TrainingInterrupter,
+pub struct Learner<LC: LearnerComponents> {
+    pub(crate) model: LC::Model,
+    pub(crate) optim: LC::Optimizer,
+    pub(crate) lr_scheduler: LC::LrScheduler,
+    pub(crate) num_epochs: usize,
+    pub(crate) checkpoint: Option<usize>,
+    pub(crate) grad_accumulation: Option<usize>,
+    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
+    pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
+    pub(crate) callback: LC::Callback,
+    pub(crate) interrupter: TrainingInterrupter,
 }

-type CheckpointModel<M, B> = Option<Box<dyn Checkpointer<<M as Module<B>>::Record>>>;
-type CheckpointOptim<O, M, B> = Option<Box<dyn Checkpointer<<O as Optimizer<M, B>>::Record>>>;
-type CheckpointScheduler<LR> = Option<Box<dyn Checkpointer<<LR as LRScheduler>::Record>>>;
+#[derive(new)]
+pub(crate) struct LearnerCheckpointer<LC: LearnerComponents> {
+    model: LC::CheckpointerModel,
+    optim: LC::CheckpointerOptimizer,
+    lr_scheduler: LC::CheckpointerLrScheduler,
+}

-impl<B, Model, Optim, LR, TrainOutput, ValidOutput>
-    Learner<B, Model, Optim, LR, TrainOutput, ValidOutput>
-where
-    ValidOutput: Send + Sync + 'static,
-    TrainOutput: Send + Sync + 'static,
-    B: ADBackend,
-    Model: ADModule<B>,
-    Optim: Optimizer<Model, B>,
-    LR: LRScheduler,
-{
-    pub(super) fn checkpoint(
-        model: &Model,
-        optim: &Optim,
-        scheduler: &LR,
-        checkpointer_model: &CheckpointModel<Model, B>,
-        checkpointer_optimizer: &CheckpointOptim<Optim, Model, B>,
-        checkpointer_scheduler: &CheckpointScheduler<LR>,
+impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
+    pub(crate) fn checkpoint(
+        &self,
+        model: &LC::Model,
+        optim: &LC::Optimizer,
+        scheduler: &LC::LrScheduler,
         epoch: usize,
     ) {
-        if let Some(checkpointer) = &checkpointer_model {
-            checkpointer
-                .save(epoch, model.clone().into_record())
-                .unwrap();
-        }
-
-        if let Some(checkpointer) = &checkpointer_optimizer {
-            checkpointer.save(epoch, optim.to_record()).unwrap();
-        }
-
-        if let Some(checkpointer) = &checkpointer_scheduler {
-            checkpointer.save(epoch, scheduler.to_record()).unwrap();
-        }
+        self.model.save(epoch, model.clone().into_record()).unwrap();
+        self.optim.save(epoch, optim.to_record()).unwrap();
+        self.lr_scheduler
+            .save(epoch, scheduler.to_record())
+            .unwrap();
     }

-    pub(super) fn load_checkpoint(mut self, epoch: usize) -> Self {
-        if let Some(checkpointer) = &self.checkpointer_model {
-            let record = checkpointer.restore(epoch).unwrap();
-            self.model = self.model.load_record(record);
-        }
+    pub(crate) fn load_checkpoint(
+        &self,
+        model: LC::Model,
+        optim: LC::Optimizer,
+        scheduler: LC::LrScheduler,
+        epoch: usize,
+    ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
+        let record = self.model.restore(epoch).unwrap();
+        let model = model.load_record(record);

-        if let Some(checkpointer) = &self.checkpointer_optimizer {
-            let record = checkpointer.restore(epoch).unwrap();
-            self.optim = self.optim.load_record(record);
-        }
+        let record = self.optim.restore(epoch).unwrap();
+        let optim = optim.load_record(record);

-        if let Some(checkpointer) = &self.checkpointer_scheduler {
-            let record = checkpointer.restore(epoch).unwrap();
-            self.lr_scheduler = self.lr_scheduler.load_record(record);
-        }
+        let record = self.lr_scheduler.restore(epoch).unwrap();
+        let scheduler = scheduler.load_record(record);

-        self
+        (model, optim, scheduler)
     }
 }
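The checkpointing logic now lives in a dedicated `LearnerCheckpointer`, and restoring no longer mutates the learner in place: it consumes and returns the (model, optimizer, scheduler) triple explicitly. A sketch of the call site, matching the `fit` hunk further down:

```rust
// Quoted shape from the fit() hunk below; `epoch` is the checkpoint to load.
if let Some(checkpointer) = &self.checkpointer {
    (self.model, self.optim, self.lr_scheduler) =
        checkpointer.load_checkpoint(self.model, self.optim, self.lr_scheduler, epoch);
}
```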
@@ -1,21 +1,20 @@
 use super::log::install_file_logger;
 use super::Learner;
-use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
+use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
+use crate::components::LearnerComponentsMarker;
+use crate::learner::base::TrainingInterrupter;
 use crate::logger::{FileMetricLogger, MetricLogger};
-use crate::metric::dashboard::{
-    default_renderer, Dashboard, DashboardRenderer, MetricWrapper, Metrics,
+use crate::metric::callback::{
+    default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer,
 };
 use crate::metric::{Adaptor, Metric};
-use crate::AsyncTrainerCallback;
-use burn_core::lr_scheduler::LRScheduler;
+use crate::{AsyncTrainerCallback, LearnerCheckpointer};
+use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::ADModule;
 use burn_core::optim::Optimizer;
 use burn_core::record::FileRecorder;
 use burn_core::tensor::backend::ADBackend;

-use crate::learner::base::TrainingInterrupter;
-use std::sync::Arc;
-
 /// Struct to configure and create a [learner](Learner).
 pub struct LearnerBuilder<B, T, V, M, O, S>
 where

@@ -24,11 +23,17 @@ where
     B: ADBackend,
     M: ADModule<B>,
     O: Optimizer<M, B>,
-    S: LRScheduler,
+    S: LrScheduler,
 {
-    checkpointer_model: Option<Arc<dyn Checkpointer<M::Record> + Send + Sync>>,
-    checkpointer_optimizer: Option<Arc<dyn Checkpointer<O::Record> + Send + Sync>>,
-    checkpointer_scheduler: Option<Arc<dyn Checkpointer<S::Record> + Send + Sync>>,
+    // Not that complex and very convenient when the traits are
+    // already constrained correctly. Extracting in another type
+    // would be more complex.
+    #[allow(clippy::type_complexity)]
+    checkpointers: Option<(
+        AsyncCheckpointer<M::Record>,
+        AsyncCheckpointer<O::Record>,
+        AsyncCheckpointer<S::Record>,
+    )>,
     num_epochs: usize,
     checkpoint: Option<usize>,
     directory: String,

@@ -36,20 +41,20 @@ where
     devices: Vec<B::Device>,
     metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
     metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
-    renderer: Option<Box<dyn DashboardRenderer + 'static>>,
+    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
     metrics: Metrics<T, V>,
     interrupter: TrainingInterrupter,
     log_to_file: bool,
 }

-impl<B, T, V, Model, Optim, LR> LearnerBuilder<B, T, V, Model, Optim, LR>
+impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
 where
-    B: ADBackend,
     T: Send + Sync + 'static,
     V: Send + Sync + 'static,
+    B: ADBackend,
-    Model: ADModule<B>,
-    Optim: Optimizer<Model, B>,
-    LR: LRScheduler,
+    M: ADModule<B> + core::fmt::Display + 'static,
+    O: Optimizer<M, B>,
+    S: LrScheduler,
 {
     /// Creates a new learner builder.
     ///

@@ -60,9 +65,7 @@ where
         Self {
             num_epochs: 1,
             checkpoint: None,
-            checkpointer_model: None,
-            checkpointer_optimizer: None,
-            checkpointer_scheduler: None,
+            checkpointers: None,
             directory: directory.to_string(),
             grad_accumulation: None,
             devices: vec![B::Device::default()],

@@ -96,18 +99,18 @@ where
     /// # Arguments
     ///
     /// * `renderer` - The custom renderer.
-    pub fn renderer<DR>(mut self, renderer: DR) -> Self
+    pub fn renderer<MR>(mut self, renderer: MR) -> Self
     where
-        DR: DashboardRenderer + 'static,
+        MR: MetricsRenderer + 'static,
     {
         self.renderer = Some(Box::new(renderer));
         self
     }

     /// Register a training metric.
-    pub fn metric_train<M: Metric + 'static>(mut self, metric: M) -> Self
+    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
     where
-        T: Adaptor<M::Input>,
+        T: Adaptor<Me::Input>,
     {
         self.metrics
             .train

@@ -116,9 +119,9 @@ where
     }

     /// Register a validation metric.
-    pub fn metric_valid<M: Metric + 'static>(mut self, metric: M) -> Self
+    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
     where
-        V: Adaptor<M::Input>,
+        V: Adaptor<Me::Input>,
     {
         self.metrics
             .valid

@@ -148,10 +151,10 @@ where
     /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
     /// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
     /// the same graph will be used for both.
-    pub fn metric_train_plot<M>(mut self, metric: M) -> Self
+    pub fn metric_train_plot<Me>(mut self, metric: Me) -> Self
     where
-        M: Metric + crate::metric::Numeric + 'static,
-        T: Adaptor<M::Input>,
+        Me: Metric + crate::metric::Numeric + 'static,
+        T: Adaptor<Me::Input>,
     {
         self.metrics
             .train_numeric

@@ -166,12 +169,12 @@ where
     /// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
     /// If the same metric is also registered for the [training split](Self::metric_train_plot),
     /// the same graph will be used for both.
-    pub fn metric_valid_plot<M: Metric + crate::metric::Numeric + 'static>(
+    pub fn metric_valid_plot<Me: Metric + crate::metric::Numeric + 'static>(
         mut self,
-        metric: M,
+        metric: Me,
     ) -> Self
     where
-        V: Adaptor<M::Input>,
+        V: Adaptor<Me::Input>,
     {
         self.metrics
             .valid_numeric

@@ -219,41 +222,64 @@ where
     pub fn with_file_checkpointer<FR>(mut self, num_keep: usize, recorder: FR) -> Self
     where
         FR: FileRecorder + 'static,
         O::Record: 'static,
         M::Record: 'static,
         S::Record: 'static,
     {
-        self.checkpointer_model = Some(Arc::new(FileCheckpointer::new(
+        let checkpointer_model = FileCheckpointer::new(
            recorder.clone(),
            format!("{}/checkpoint", self.directory).as_str(),
            "model",
            num_keep,
-        )));
-        self.checkpointer_optimizer = Some(Arc::new(FileCheckpointer::new(
+        );
+        let checkpointer_optimizer = FileCheckpointer::new(
            recorder.clone(),
            format!("{}/checkpoint", self.directory).as_str(),
            "optim",
            num_keep,
-        )));
-        self.checkpointer_scheduler = Some(Arc::new(FileCheckpointer::new(
+        );
+        let checkpointer_scheduler = FileCheckpointer::new(
            recorder,
            format!("{}/checkpoint", self.directory).as_str(),
            "scheduler",
            num_keep,
-        )));
+        );
+
+        self.checkpointers = Some((
+            AsyncCheckpointer::new(checkpointer_model),
+            AsyncCheckpointer::new(checkpointer_optimizer),
+            AsyncCheckpointer::new(checkpointer_scheduler),
+        ));

         self
     }

     /// Create the [learner](Learner) from a [model](ADModule) and an [optimizer](Optimizer).
-    /// The [learning rate scheduler](LRScheduler) can also be a simple
+    /// The [learning rate scheduler](LrScheduler) can also be a simple
     /// [learning rate](burn_core::LearningRate).
+    #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
+    // create a clean learner.
     pub fn build(
         self,
-        model: Model,
-        optim: Optim,
-        lr_scheduler: LR,
-    ) -> Learner<B, Model, Optim, LR, T, V>
+        model: M,
+        optim: O,
+        lr_scheduler: S,
+    ) -> Learner<
+        LearnerComponentsMarker<
+            B,
+            S,
+            M,
+            O,
+            AsyncCheckpointer<M::Record>,
+            AsyncCheckpointer<O::Record>,
+            AsyncCheckpointer<S::Record>,
+            AsyncTrainerCallback<T, V>,
+        >,
+    >
     where
-        Model::Record: 'static,
-        Optim::Record: 'static,
-        LR::Record: 'static,
+        M::Record: 'static,
+        O::Record: 'static,
+        S::Record: 'static,
     {
         if self.log_to_file {
             self.init_logger();

@@ -268,45 +294,25 @@ where
         let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
             Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
         });
-        let dashboard = Dashboard::new(renderer, self.metrics, logger_train, logger_valid);
-        let callback = Box::new(dashboard);
-        let callback = Box::new(AsyncTrainerCallback::new(callback));
+        let callback = AsyncTrainerCallback::new(MetricsCallback::new(
+            renderer,
+            self.metrics,
+            logger_train,
+            logger_valid,
+        ));

-        let checkpointer_optimizer = match self.checkpointer_optimizer {
-            Some(checkpointer) => {
-                let checkpointer: Box<dyn Checkpointer<Optim::Record>> =
-                    Box::new(AsyncCheckpointer::new(checkpointer));
-                Some(checkpointer)
-            }
-            None => None,
-        };
-        let checkpointer_model = match self.checkpointer_model {
-            Some(checkpointer) => {
-                let checkpointer: Box<dyn Checkpointer<Model::Record>> =
-                    Box::new(AsyncCheckpointer::new(checkpointer));
-                Some(checkpointer)
-            }
-            None => None,
-        };
-        let checkpointer_scheduler = match self.checkpointer_scheduler {
-            Some(checkpointer) => {
-                let checkpointer: Box<dyn Checkpointer<LR::Record>> =
-                    Box::new(AsyncCheckpointer::new(checkpointer));
-                Some(checkpointer)
-            }
-            None => None,
-        };
+        let checkpointer = self
+            .checkpointers
+            .map(|(model, optim, scheduler)| LearnerCheckpointer::new(model, optim, scheduler));

         Learner {
             model,
             optim,
             lr_scheduler,
+            checkpointer,
             num_epochs: self.num_epochs,
             callback,
             checkpoint: self.checkpoint,
-            checkpointer_model,
-            checkpointer_optimizer,
-            checkpointer_scheduler,
             grad_accumulation: self.grad_accumulation,
             devices: self.devices,
             interrupter: self.interrupter,
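Taken together, the builder now threads concrete types (async checkpointers, the metrics callback) into a `LearnerComponentsMarker` instead of boxing trait objects. A usage sketch, assuming `model`, `optim`, and a `FileRecorder` named `recorder` from the wider crate, plus a hypothetical `LossMetric`:

```rust
// Hypothetical call site; only builder methods shown in this diff are used.
let learner = LearnerBuilder::new("/tmp/experiment")
    .metric_train_plot(LossMetric::new())
    .with_file_checkpointer(2, recorder) // keep the two most recent checkpoints
    .build(model, optim, 1.0e-3); // a bare f64 is a valid scheduler
```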
@@ -1,6 +1,6 @@
 use burn_core::{
     data::dataloader::DataLoader,
-    lr_scheduler::LRScheduler,
+    lr_scheduler::LrScheduler,
     module::ADModule,
     optim::{GradientsAccumulator, Optimizer},
     tensor::backend::ADBackend,

@@ -27,7 +27,7 @@ pub struct TrainEpoch<TI> {
     grad_accumulation: Option<usize>,
 }

-impl<I> ValidEpoch<I> {
+impl<VI> ValidEpoch<VI> {
     /// Runs the validation epoch.
     ///
     /// # Arguments

@@ -37,12 +37,12 @@ impl<I> ValidEpoch<I> {
     pub fn run<B, M, TO, VO>(
         &self,
         model: &M,
-        callback: &mut Box<dyn LearnerCallback<TO, VO>>,
+        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
         interrupter: &TrainingInterrupter,
     ) where
         B: ADBackend,
         M: ADModule<B>,
-        M::InnerModule: ValidStep<I, VO>,
+        M::InnerModule: ValidStep<VI, VO>,
     {
         log::info!("Executing validation step for epoch {}", self.epoch);
         let model = model.valid();

@@ -92,14 +92,14 @@ impl<TI> TrainEpoch<TI> {
         mut model: M,
         mut optim: O,
         scheduler: &mut LR,
-        callback: &mut Box<dyn LearnerCallback<TO, VO>>,
+        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
         interrupter: &TrainingInterrupter,
     ) -> (M, O)
     where
         B: ADBackend,
         M: TrainStep<TI, TO> + ADModule<B>,
         O: Optimizer<M, B>,
-        LR: LRScheduler,
+        LR: LrScheduler,
     {
         log::info!("Executing training step for epoch {}", self.epoch,);

@@ -170,7 +170,7 @@ impl<TI> TrainEpoch<TI> {
         mut model: M,
         mut optim: O,
         lr_scheduler: &mut S,
-        callback: &mut Box<dyn LearnerCallback<TO, VO>>,
+        callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
         devices: Vec<B::Device>,
         interrupter: &TrainingInterrupter,
     ) -> (M, O)

@@ -179,7 +179,7 @@ impl<TI> TrainEpoch<TI> {
         M: ADModule<B> + 'static,
         O: Optimizer<M, B>,
         M: TrainStep<TI, TO>,
-        S: LRScheduler,
+        S: LrScheduler,
         TI: Send + 'static,
         TO: Send + 'static,
     {

@@ -1 +0,0 @@
-
@@ -1,9 +1,7 @@
-use super::Learner;
-
-use crate::{TrainEpoch, ValidEpoch};
+use crate::components::LearnerComponents;
+use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch};
 use burn_core::data::dataloader::DataLoader;
-use burn_core::lr_scheduler::LRScheduler;
-use burn_core::module::ADModule;
+use burn_core::module::{ADModule, Module};
 use burn_core::optim::{GradientsParams, Optimizer};
 use burn_core::tensor::backend::ADBackend;
 use std::sync::Arc;

@@ -94,15 +92,7 @@ pub trait ValidStep<VI, VO> {
     fn step(&self, item: VI) -> VO;
 }

-impl<B, M, O, LR, TO, VO> Learner<B, M, O, LR, TO, VO>
-where
-    VO: Send + Sync + 'static,
-    TO: Send + Sync + 'static,
-    B: ADBackend,
-    M: ADModule<B> + core::fmt::Display,
-    O: Optimizer<M, B>,
-    LR: LRScheduler,
-{
+impl<LC: LearnerComponents> Learner<LC> {
     /// Fits the model.
     ///
     /// # Arguments

@@ -113,16 +103,19 @@ where
     /// # Returns
     ///
     /// The fitted model.
-    pub fn fit<TI, VI>(
+    pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
         mut self,
-        dataloader_train: Arc<dyn DataLoader<TI>>,
-        dataloader_valid: Arc<dyn DataLoader<VI>>,
-    ) -> M
+        dataloader_train: Arc<dyn DataLoader<InputTrain>>,
+        dataloader_valid: Arc<dyn DataLoader<InputValid>>,
+    ) -> LC::Model
     where
-        TI: Send + 'static,
-        TO: Send + 'static,
-        M: TrainStep<TI, TO> + Send + Clone + 'static,
-        M::InnerModule: ValidStep<VI, VO>,
+        InputTrain: Send + 'static,
+        InputValid: Send,
+        OutputTrain: Send + 'static,
+        OutputValid: Send,
+        LC::Model: TrainStep<InputTrain, OutputTrain>,
+        <LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
+        LC::Callback: LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
     {
         log::info!("Fitting {}", self.model.to_string());
         // The reference model is always on the first device provided.

@@ -132,14 +125,22 @@ where

         let starting_epoch = match self.checkpoint {
             Some(checkpoint) => {
-                self = self.load_checkpoint(checkpoint);
+                if let Some(checkpointer) = &self.checkpointer {
+                    (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
+                        self.model,
+                        self.optim,
+                        self.lr_scheduler,
+                        checkpoint,
+                    );
+                }
                 checkpoint + 1
             }
             None => 1,
         };

-        let mut model = self.model;
-        let mut optim = self.optim;
+        let mut callback: Box<
+            dyn LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
+        > = Box::new(self.callback);

         for epoch in starting_epoch..self.num_epochs + 1 {
             let epoch_train = TrainEpoch::new(

@@ -150,20 +151,20 @@ where
             );

             if self.devices.len() > 1 {
-                (model, optim) = epoch_train.run_multi_device(
-                    model,
-                    optim,
+                (self.model, self.optim) = epoch_train.run_multi_device(
+                    self.model,
+                    self.optim,
                     &mut self.lr_scheduler,
-                    &mut self.callback,
+                    &mut callback,
                     self.devices.clone(),
                     &self.interrupter,
                 )
             } else {
-                (model, optim) = epoch_train.run(
-                    model,
-                    optim,
+                (self.model, self.optim) = epoch_train.run(
+                    self.model,
+                    self.optim,
                     &mut self.lr_scheduler,
-                    &mut self.callback,
+                    &mut callback,
                     &self.interrupter,
                 );
             }

@@ -173,19 +174,13 @@ where
             }

             let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
-            epoch_valid.run(&model, &mut self.callback, &self.interrupter);
+            epoch_valid.run(&self.model, &mut callback, &self.interrupter);

-            Self::checkpoint(
-                &model,
-                &optim,
-                &self.lr_scheduler,
-                &self.checkpointer_model,
-                &self.checkpointer_optimizer,
-                &self.checkpointer_scheduler,
-                epoch,
-            );
+            if let Some(checkpointer) = &self.checkpointer {
+                checkpointer.checkpoint(&self.model, &self.optim, &self.lr_scheduler, epoch);
+            }
         }

-        model
+        self.model
     }
 }
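The `fit` call site is unchanged by the refactor; the four generics are inferred from the dataloaders and from the model's `TrainStep`/`ValidStep` implementations:

```rust
// Sketch, assuming `learner` and the two Arc<dyn DataLoader<_>> handles exist.
let trained_model = learner.fit(dataloader_train, dataloader_valid);
```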
@@ -8,6 +8,8 @@ extern crate derive_new;
 /// The checkpoint module.
 pub mod checkpoint;

+pub(crate) mod components;
+
 /// The logger module.
 pub mod logger;

@@ -5,6 +5,115 @@ use crate::{
 };
 use burn_core::data::dataloader::Progress;

+/// Holds all metrics, metric loggers, and a metrics renderer.
+pub struct MetricsCallback<T, V>
+where
+    T: Send + Sync + 'static,
+    V: Send + Sync + 'static,
+{
+    metrics: Metrics<T, V>,
+    logger_train: Box<dyn MetricLogger>,
+    logger_valid: Box<dyn MetricLogger>,
+    renderer: Box<dyn MetricsRenderer>,
+}
+
+impl<T, V> MetricsCallback<T, V>
+where
+    T: Send + Sync + 'static,
+    V: Send + Sync + 'static,
+{
+    /// Creates a new metrics callback.
+    ///
+    /// # Arguments
+    ///
+    /// * `renderer` - The metrics renderer.
+    /// * `metrics` - The metrics holder.
+    /// * `logger_train` - The training logger.
+    /// * `logger_valid` - The validation logger.
+    ///
+    /// # Returns
+    ///
+    /// A new metrics callback.
+    pub(crate) fn new(
+        renderer: Box<dyn MetricsRenderer>,
+        metrics: Metrics<T, V>,
+        logger_train: Box<dyn MetricLogger>,
+        logger_valid: Box<dyn MetricLogger>,
+    ) -> Self {
+        Self {
+            metrics,
+            logger_train,
+            logger_valid,
+            renderer,
+        }
+    }
+}
+
+impl<T, V> LearnerCallback for MetricsCallback<T, V>
+where
+    T: Send + Sync + 'static,
+    V: Send + Sync + 'static,
+{
+    type ItemTrain = T;
+    type ItemValid = V;
+
+    fn on_train_item(&mut self, item: LearnerItem<T>) {
+        let metadata = (&item).into();
+        for metric in self.metrics.train.iter_mut() {
+            let state = metric.update(&item, &metadata);
+            self.logger_train.log(&state);
+
+            self.renderer.update_train(MetricState::Generic(state));
+        }
+        for metric in self.metrics.train_numeric.iter_mut() {
+            let (state, value) = metric.update(&item, &metadata);
+            self.logger_train.log(&state);
+
+            self.renderer
+                .update_train(MetricState::Numeric(state, value));
+        }
+        self.renderer.render_train(item.into());
+    }
+
+    fn on_valid_item(&mut self, item: LearnerItem<V>) {
+        let metadata = (&item).into();
+        for metric in self.metrics.valid.iter_mut() {
+            let state = metric.update(&item, &metadata);
+            self.logger_valid.log(&state);
+
+            self.renderer.update_valid(MetricState::Generic(state));
+        }
+        for metric in self.metrics.valid_numeric.iter_mut() {
+            let (state, value) = metric.update(&item, &metadata);
+            self.logger_valid.log(&state);
+
+            self.renderer
+                .update_valid(MetricState::Numeric(state, value));
+        }
+        self.renderer.render_valid(item.into());
+    }
+
+    fn on_train_end_epoch(&mut self, epoch: usize) {
+        for metric in self.metrics.train.iter_mut() {
+            metric.clear();
+        }
+        for metric in self.metrics.train_numeric.iter_mut() {
+            metric.clear();
+        }
+        self.logger_train.epoch(epoch + 1);
+    }
+
+    fn on_valid_end_epoch(&mut self, epoch: usize) {
+        for metric in self.metrics.valid.iter_mut() {
+            metric.clear();
+        }
+        for metric in self.metrics.valid_numeric.iter_mut() {
+            metric.clear();
+        }
+        self.logger_valid.epoch(epoch + 1);
+    }
+}
+
 /// Training progress.
 #[derive(Debug)]
 pub struct TrainingProgress {

@@ -36,9 +145,9 @@ impl TrainingProgress {
     }
 }

-/// A dashboard metric.
+/// The state of a metric.
 #[derive(Debug)]
-pub enum DashboardMetricState {
+pub enum MetricState {
     /// A generic metric.
     Generic(MetricEntry),

@@ -46,21 +155,21 @@ pub enum DashboardMetricState {
     Numeric(MetricEntry, f64),
 }

-/// Trait for rendering dashboard metrics.
-pub trait DashboardRenderer: Send + Sync {
+/// Trait for rendering metrics.
+pub trait MetricsRenderer: Send + Sync {
     /// Updates the training metric state.
     ///
     /// # Arguments
     ///
     /// * `state` - The metric state.
-    fn update_train(&mut self, state: DashboardMetricState);
+    fn update_train(&mut self, state: MetricState);

     /// Updates the validation metric state.
     ///
     /// # Arguments
     ///
     /// * `state` - The metric state.
-    fn update_valid(&mut self, state: DashboardMetricState);
+    fn update_valid(&mut self, state: MetricState);

     /// Renders the training progress.
     ///

@@ -77,16 +186,16 @@ pub trait DashboardRenderer: Send + Sync {
     fn render_valid(&mut self, item: TrainingProgress);
 }

-/// A container for the metrics held by a dashboard.
+/// A container for the metrics held by a metrics callback.
 pub(crate) struct Metrics<T, V>
 where
     T: Send + Sync + 'static,
     V: Send + Sync + 'static,
 {
-    pub(crate) train: Vec<Box<dyn DashboardMetric<T>>>,
-    pub(crate) valid: Vec<Box<dyn DashboardMetric<V>>>,
-    pub(crate) train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
-    pub(crate) valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
+    pub(crate) train: Vec<Box<dyn MetricUpdater<T>>>,
+    pub(crate) valid: Vec<Box<dyn MetricUpdater<V>>>,
+    pub(crate) train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
+    pub(crate) valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
 }

 impl<T, V> Metrics<T, V>

@@ -104,50 +213,6 @@ where
     }
 }

-/// Holds all metrics, metric loggers, and a dashboard renderer.
-pub struct Dashboard<T, V>
-where
-    T: Send + Sync + 'static,
-    V: Send + Sync + 'static,
-{
-    metrics: Metrics<T, V>,
-    logger_train: Box<dyn MetricLogger>,
-    logger_valid: Box<dyn MetricLogger>,
-    renderer: Box<dyn DashboardRenderer>,
-}
-
-impl<T, V> Dashboard<T, V>
-where
-    T: Send + Sync + 'static,
-    V: Send + Sync + 'static,
-{
-    /// Creates a new dashboard.
-    ///
-    /// # Arguments
-    ///
-    /// * `renderer` - The dashboard renderer.
-    /// * `metrics` - The dashboard's metrics
-    /// * `logger_train` - The training logger.
-    /// * `logger_valid` - The validation logger.
-    ///
-    /// # Returns
-    ///
-    /// A new dashboard.
-    pub(crate) fn new(
-        renderer: Box<dyn DashboardRenderer>,
-        metrics: Metrics<T, V>,
-        logger_train: Box<dyn MetricLogger>,
-        logger_valid: Box<dyn MetricLogger>,
-    ) -> Self {
-        Self {
-            metrics,
-            logger_train,
-            logger_valid,
-            renderer,
-        }
-    }
-}
-
 impl<T> From<LearnerItem<T>> for TrainingProgress {
     fn from(item: LearnerItem<T>) -> Self {
         Self {

@@ -171,76 +236,12 @@ impl<T> From<&LearnerItem<T>> for MetricMetadata {
     }
 }

-impl<T, V> LearnerCallback<T, V> for Dashboard<T, V>
-where
-    T: Send + Sync + 'static,
-    V: Send + Sync + 'static,
-{
-    fn on_train_item(&mut self, item: LearnerItem<T>) {
-        let metadata = (&item).into();
-        for metric in self.metrics.train.iter_mut() {
-            let state = metric.update(&item, &metadata);
-            self.logger_train.log(&state);
-
-            self.renderer
-                .update_train(DashboardMetricState::Generic(state));
-        }
-        for metric in self.metrics.train_numeric.iter_mut() {
-            let (state, value) = metric.update(&item, &metadata);
-            self.logger_train.log(&state);
-
-            self.renderer
-                .update_train(DashboardMetricState::Numeric(state, value));
-        }
-        self.renderer.render_train(item.into());
-    }
-
-    fn on_valid_item(&mut self, item: LearnerItem<V>) {
-        let metadata = (&item).into();
-        for metric in self.metrics.valid.iter_mut() {
-            let state = metric.update(&item, &metadata);
-            self.logger_valid.log(&state);
-
-            self.renderer
-                .update_valid(DashboardMetricState::Generic(state));
-        }
-        for metric in self.metrics.valid_numeric.iter_mut() {
-            let (state, value) = metric.update(&item, &metadata);
-            self.logger_valid.log(&state);
-
-            self.renderer
-                .update_valid(DashboardMetricState::Numeric(state, value));
-        }
-        self.renderer.render_valid(item.into());
-    }
-
-    fn on_train_end_epoch(&mut self, epoch: usize) {
-        for metric in self.metrics.train.iter_mut() {
-            metric.clear();
-        }
-        for metric in self.metrics.train_numeric.iter_mut() {
-            metric.clear();
-        }
-        self.logger_train.epoch(epoch + 1);
-    }
-
-    fn on_valid_end_epoch(&mut self, epoch: usize) {
-        for metric in self.metrics.valid.iter_mut() {
-            metric.clear();
-        }
-        for metric in self.metrics.valid_numeric.iter_mut() {
-            metric.clear();
-        }
-        self.logger_valid.epoch(epoch + 1);
-    }
-}
-
-pub(crate) trait DashboardNumericMetric<T>: Send + Sync {
+pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
     fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
     fn clear(&mut self);
 }

-pub(crate) trait DashboardMetric<T>: Send + Sync {
+pub(crate) trait MetricUpdater<T>: Send + Sync {
     fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
     fn clear(&mut self);
 }

@@ -250,7 +251,7 @@ pub(crate) struct MetricWrapper<M> {
     metric: M,
 }

-impl<T, M> DashboardNumericMetric<T> for MetricWrapper<M>
+impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
 where
     T: 'static,
     M: Metric + Numeric + 'static,

@@ -268,7 +269,7 @@ where
     }
 }

-impl<T, M> DashboardMetric<T> for MetricWrapper<M>
+impl<T, M> MetricUpdater<T> for MetricWrapper<M>
 where
     T: 'static,
     M: Metric + 'static,
@@ -0,0 +1,25 @@
+use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
+
+/// A simple renderer for when the cli feature is not enabled.
+pub struct CliMetricsRenderer;
+
+impl CliMetricsRenderer {
+    /// Create a new instance.
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl MetricsRenderer for CliMetricsRenderer {
+    fn update_train(&mut self, _state: MetricState) {}
+
+    fn update_valid(&mut self, _state: MetricState) {}
+
+    fn render_train(&mut self, item: TrainingProgress) {
+        dbg!(item);
+    }
+
+    fn render_valid(&mut self, item: TrainingProgress) {
+        dbg!(item);
+    }
+}
@@ -3,25 +3,25 @@ mod base;
 pub use base::*;

 #[cfg(not(feature = "tui"))]
-mod cli_stub;
+mod cli;
 #[cfg(not(feature = "tui"))]
-pub use cli_stub::CLIDashboardRenderer as SelectedDashboardRenderer;
+pub use cli::CliMetricsRenderer as SelectedMetricsRenderer;

 #[cfg(feature = "tui")]
 mod tui;
 use crate::TrainingInterrupter;
 #[cfg(feature = "tui")]
-pub use tui::TuiDashboardRenderer as SelectedDashboardRenderer;
+pub use tui::TuiMetricsRenderer as SelectedMetricsRenderer;

 /// The TUI renderer, or a simple stub if the tui feature is not enabled.
 #[allow(unused_variables)]
 pub(crate) fn default_renderer(
     interuptor: TrainingInterrupter,
     checkpoint: Option<usize>,
-) -> SelectedDashboardRenderer {
+) -> SelectedMetricsRenderer {
     #[cfg(feature = "tui")]
-    return SelectedDashboardRenderer::new(interuptor, checkpoint);
+    return SelectedMetricsRenderer::new(interuptor, checkpoint);

     #[cfg(not(feature = "tui"))]
-    return SelectedDashboardRenderer::new();
+    return SelectedMetricsRenderer::new();
 }
@@ -4,7 +4,7 @@ use super::{
 use ratatui::prelude::{Constraint, Direction, Layout, Rect};

 #[derive(new)]
-pub(crate) struct DashboardView<'a> {
+pub(crate) struct MetricsView<'a> {
     metric_numeric: NumericMetricView<'a>,
     metric_text: TextMetricView,
     progress: ProgressBarView,

@@ -12,7 +12,7 @@ pub(crate) struct DashboardView<'a> {
     status: StatusView,
 }

-impl<'a> DashboardView<'a> {
+impl<'a> MetricsView<'a> {
     pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
         let chunks = Layout::default()
             .direction(Direction::Vertical)
@@ -1,4 +1,4 @@
-use crate::metric::dashboard::TrainingProgress;
+use crate::metric::callback::TrainingProgress;

 use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
 use crossterm::event::{Event, KeyCode};
@@ -1,5 +1,5 @@
 use super::TerminalFrame;
-use crate::metric::dashboard::TrainingProgress;
+use crate::metric::callback::TrainingProgress;
 use ratatui::{
     prelude::{Alignment, Constraint, Direction, Layout, Rect},
     style::{Color, Style, Stylize},
@@ -1,5 +1,5 @@
-use crate::metric::dashboard::tui::NumericMetricsState;
-use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
+use crate::metric::callback::tui::NumericMetricsState;
+use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
 use crate::TrainingInterrupter;
 use crossterm::{
     event::{self, Event, KeyCode},
@@ -14,7 +14,7 @@ use std::{
 };

 use super::{
-    Callback, CallbackFn, ControlsView, DashboardView, PopupState, ProgressBarState, StatusState,
+    Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState,
     TextMetricsState,
 };

@@ -25,8 +25,8 @@ pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a, TerminalBackend>;

 const MAX_REFRESH_RATE_MILLIS: u64 = 100;

-/// The CLI dashboard renderer.
-pub struct TuiDashboardRenderer {
+/// The terminal UI metrics renderer.
+pub struct TuiMetricsRenderer {
     terminal: Terminal<TerminalBackend>,
     last_update: std::time::Instant,
     progress: ProgressBarState,
@@ -37,25 +37,25 @@ pub struct TuiDashboardRenderer {
     popup: PopupState,
 }

-impl DashboardRenderer for TuiDashboardRenderer {
-    fn update_train(&mut self, state: DashboardMetricState) {
+impl MetricsRenderer for TuiMetricsRenderer {
+    fn update_train(&mut self, state: MetricState) {
         match state {
-            DashboardMetricState::Generic(entry) => {
+            MetricState::Generic(entry) => {
                 self.metrics_text.update_train(entry);
             }
-            DashboardMetricState::Numeric(entry, value) => {
+            MetricState::Numeric(entry, value) => {
                 self.metrics_numeric.push_train(entry.name.clone(), value);
                 self.metrics_text.update_train(entry);
             }
         };
     }

-    fn update_valid(&mut self, state: DashboardMetricState) {
+    fn update_valid(&mut self, state: MetricState) {
         match state {
-            DashboardMetricState::Generic(entry) => {
+            MetricState::Generic(entry) => {
                 self.metrics_text.update_valid(entry);
             }
-            DashboardMetricState::Numeric(entry, value) => {
+            MetricState::Numeric(entry, value) => {
                 self.metrics_numeric.push_valid(entry.name.clone(), value);
                 self.metrics_text.update_valid(entry);
             }
@@ -77,8 +77,8 @@ impl DashboardRenderer for TuiDashboardRenderer {
     }
 }

-impl TuiDashboardRenderer {
-    /// Create a new CLI dashboard renderer.
+impl TuiMetricsRenderer {
+    /// Create a new terminal UI renderer.
     pub fn new(interuptor: TrainingInterrupter, checkpoint: Option<usize>) -> Self {
         let mut stdout = io::stdout();
         execute!(stdout, EnterAlternateScreen).unwrap();

@@ -118,7 +118,7 @@ impl TuiDashboardRenderer {
         match self.popup.view() {
             Some(view) => view.render(frame, size),
             None => {
-                let view = DashboardView::new(
+                let view = MetricsView::new(
                     self.metrics_numeric.view(),
                     self.metrics_text.view(),
                     self.progress.view(),
@@ -194,7 +194,7 @@ impl CallbackFn for PopupCancel {
     }
 }

-impl Drop for TuiDashboardRenderer {
+impl Drop for TuiMetricsRenderer {
     fn drop(&mut self) {
         disable_raw_mode().ok();
         execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
@@ -1,5 +1,5 @@
 use super::TerminalFrame;
-use crate::metric::dashboard::TrainingProgress;
+use crate::metric::callback::TrainingProgress;
 use ratatui::{
     prelude::{Alignment, Rect},
     style::{Color, Style, Stylize},
@@ -1,25 +0,0 @@
-use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
-
-/// A simple renderer for when the cli feature is not enabled.
-pub struct CLIDashboardRenderer;
-
-impl CLIDashboardRenderer {
-    /// Create a new instance.
-    pub fn new() -> Self {
-        Self {}
-    }
-}
-
-impl DashboardRenderer for CLIDashboardRenderer {
-    fn update_train(&mut self, _state: DashboardMetricState) {}
-
-    fn update_valid(&mut self, _state: DashboardMetricState) {}
-
-    fn render_train(&mut self, item: TrainingProgress) {
-        dbg!(item);
-    }
-
-    fn render_valid(&mut self, item: TrainingProgress) {
-        dbg!(item);
-    }
-}
@@ -1,7 +1,7 @@
-/// Dashboard module for training progress.
-pub mod dashboard;
+/// Callback module for training progress.
+pub mod callback;

-/// State module for dashboard metrics.
+/// State module for callback metrics.
 pub mod state;

 mod acc;
@@ -1,5 +1,5 @@
 use burn::data::dataset::source::huggingface::MNISTDataset;
-use burn::train::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
+use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
 use burn::train::LearnerBuilder;
 use burn::{
     config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,

@@ -25,10 +25,10 @@ pub struct MnistTrainingConfig {

 struct CustomRenderer {}

-impl DashboardRenderer for CustomRenderer {
-    fn update_train(&mut self, _state: DashboardMetricState) {}
+impl MetricsRenderer for CustomRenderer {
+    fn update_train(&mut self, _state: MetricState) {}

-    fn update_valid(&mut self, _state: DashboardMetricState) {}
+    fn update_valid(&mut self, _state: MetricState) {}

     fn render_train(&mut self, item: TrainingProgress) {
         dbg!(item);
@@ -12,7 +12,7 @@ use crate::{
 use burn::{
     config::Config,
     data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},
-    lr_scheduler::noam::NoamLRSchedulerConfig,
+    lr_scheduler::noam::NoamLrSchedulerConfig,
     module::Module,
     nn::transformer::TransformerEncoderConfig,
     optim::AdamConfig,

@@ -84,7 +84,7 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     let optim = config.optimizer.init();

     // Initialize learning rate scheduler
-    let lr_scheduler = NoamLRSchedulerConfig::new(0.25)
+    let lr_scheduler = NoamLrSchedulerConfig::new(0.25)
         .with_warmup_steps(1000)
         .with_model_size(config.transformer.d_model)
         .init();
@@ -6,7 +6,7 @@ use burn::data::dataset::transform::SamplerDataset;
 use burn::{
     config::Config,
     data::{dataloader::DataLoaderBuilder, dataset::Dataset},
-    lr_scheduler::noam::NoamLRSchedulerConfig,
+    lr_scheduler::noam::NoamLrSchedulerConfig,
     module::Module,
     nn::transformer::TransformerEncoderConfig,
     optim::AdamConfig,

@@ -62,7 +62,7 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(

     let accum = 6; // Effective batch size = 6 * 6 = 32.
     let optim = config.optimizer.init();
-    let lr_scheduler = NoamLRSchedulerConfig::new(0.01 / accum as f64)
+    let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64)
         .with_warmup_steps(6000)
         .with_model_size(config.transformer.d_model)
         .init();