mirror of https://github.com/tracel-ai/burn.git
Feat training events (#857)
This commit is contained in:
parent
097fd956d0
commit
620b86de98
|
@ -111,10 +111,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
|||
.build(MNISTDataset::test());
|
||||
|
||||
let learner = LearnerBuilder::new(artifact_dir)
|
||||
.metric_train_plot(AccuracyMetric::new())
|
||||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(1, CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
use super::{LearnerCallback, LearnerItem};
|
||||
use std::{sync::mpsc, thread::JoinHandle};
|
||||
|
||||
enum Message<T, V> {
|
||||
LogTrain(LearnerItem<T>),
|
||||
LogValid(LearnerItem<V>),
|
||||
ClearTrain(usize),
|
||||
ClearValid(usize),
|
||||
End,
|
||||
}
|
||||
|
||||
/// Async trainer callback tracker.
|
||||
pub struct AsyncTrainerCallback<T, V> {
|
||||
sender: mpsc::Sender<Message<T, V>>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct CallbackThread<C, T, V> {
|
||||
callback: C,
|
||||
receiver: mpsc::Receiver<Message<T, V>>,
|
||||
}
|
||||
|
||||
impl<C, T, V> CallbackThread<C, T, V>
|
||||
where
|
||||
C: LearnerCallback<ItemTrain = T, ItemValid = V>,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::LogTrain(item) => {
|
||||
self.callback.on_train_item(item);
|
||||
}
|
||||
Message::ClearTrain(epoch) => {
|
||||
self.callback.on_train_end_epoch(epoch);
|
||||
}
|
||||
Message::LogValid(item) => {
|
||||
self.callback.on_valid_item(item);
|
||||
}
|
||||
Message::ClearValid(epoch) => {
|
||||
self.callback.on_valid_end_epoch(epoch);
|
||||
}
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
|
||||
/// Create a new async trainer callback.
|
||||
pub fn new<C>(callback: C) -> Self
|
||||
where
|
||||
C: LearnerCallback<ItemTrain = T, ItemValid = V> + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = CallbackThread::new(callback, receiver);
|
||||
|
||||
let handler = std::thread::spawn(move || thread.run());
|
||||
let handler = Some(handler);
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send, V: Send> LearnerCallback for AsyncTrainerCallback<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
self.sender.send(Message::LogTrain(item)).unwrap();
|
||||
}
|
||||
|
||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
self.sender.send(Message::LogValid(item)).unwrap();
|
||||
}
|
||||
|
||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||
self.sender.send(Message::ClearTrain(epoch)).unwrap();
|
||||
}
|
||||
|
||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||
self.sender.send(Message::ClearValid(epoch)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> Drop for AsyncTrainerCallback<T, V> {
|
||||
fn drop(&mut self) {
|
||||
self.sender.send(Message::End).unwrap();
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
use burn_core::{data::dataloader::Progress, LearningRate};
|
||||
|
||||
/// The base trait for trainer callbacks.
|
||||
pub trait LearnerCallback: Send {
|
||||
/// Training item.
|
||||
type ItemTrain;
|
||||
/// Validation item.
|
||||
type ItemValid;
|
||||
|
||||
/// Called when a training item is logged.
|
||||
fn on_train_item(&mut self, _item: LearnerItem<Self::ItemTrain>) {}
|
||||
|
||||
/// Called when a validation item is logged.
|
||||
fn on_valid_item(&mut self, _item: LearnerItem<Self::ItemValid>) {}
|
||||
|
||||
/// Called when a training epoch is finished.
|
||||
fn on_train_end_epoch(&mut self, _epoch: usize) {}
|
||||
|
||||
/// Called when a validation epoch is finished.
|
||||
fn on_valid_end_epoch(&mut self, _epoch: usize) {}
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
mod async_callback;
|
||||
mod base;
|
||||
|
||||
pub use async_callback::*;
|
||||
pub use base::*;
|
|
@ -0,0 +1,118 @@
|
|||
use super::EventCollector;
|
||||
use crate::{Aggregate, Direction, Event, Split};
|
||||
use std::{sync::mpsc, thread::JoinHandle};
|
||||
|
||||
enum Message<T, V> {
|
||||
OnEventTrain(Event<T>),
|
||||
OnEventValid(Event<V>),
|
||||
End,
|
||||
FindEpoch(
|
||||
String,
|
||||
Aggregate,
|
||||
Direction,
|
||||
Split,
|
||||
mpsc::SyncSender<Option<usize>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// Async [event collector](EventCollector).
|
||||
///
|
||||
/// This will create a worker thread where all the computation is done ensuring that the training loop is
|
||||
/// never blocked by metric calculation.
|
||||
pub struct AsyncEventCollector<T, V> {
|
||||
sender: mpsc::Sender<Message<T, V>>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct WorkerThread<C, T, V> {
|
||||
collector: C,
|
||||
receiver: mpsc::Receiver<Message<T, V>>,
|
||||
}
|
||||
|
||||
impl<C, T, V> WorkerThread<C, T, V>
|
||||
where
|
||||
C: EventCollector<ItemTrain = T, ItemValid = V>,
|
||||
{
|
||||
fn run(mut self) {
|
||||
for item in self.receiver.iter() {
|
||||
match item {
|
||||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::FindEpoch(name, aggregate, direction, split, sender) => {
|
||||
let response = self
|
||||
.collector
|
||||
.find_epoch(&name, aggregate, direction, split);
|
||||
sender.send(response).unwrap();
|
||||
}
|
||||
Message::OnEventTrain(event) => self.collector.on_event_train(event),
|
||||
Message::OnEventValid(event) => self.collector.on_event_valid(event),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncEventCollector<T, V> {
|
||||
/// Create a new async [event collector](EventCollector).
|
||||
pub fn new<C>(collector: C) -> Self
|
||||
where
|
||||
C: EventCollector<ItemTrain = T, ItemValid = V> + 'static,
|
||||
{
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = WorkerThread::new(collector, receiver);
|
||||
|
||||
let handler = std::thread::spawn(move || thread.run());
|
||||
let handler = Some(handler);
|
||||
|
||||
Self { sender, handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send, V: Send> EventCollector for AsyncEventCollector<T, V> {
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
self.sender.send(Message::OnEventTrain(event)).unwrap();
|
||||
}
|
||||
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
self.sender.send(Message::OnEventValid(event)).unwrap();
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
let (sender, receiver) = mpsc::sync_channel(1);
|
||||
self.sender
|
||||
.send(Message::FindEpoch(
|
||||
name.to_string(),
|
||||
aggregate,
|
||||
direction,
|
||||
split,
|
||||
sender,
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv() {
|
||||
Ok(value) => value,
|
||||
Err(err) => panic!("Async server crashed: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> Drop for AsyncEventCollector<T, V> {
|
||||
fn drop(&mut self) {
|
||||
self.sender.send(Message::End).unwrap();
|
||||
let handler = self.handler.take();
|
||||
|
||||
if let Some(handler) = handler {
|
||||
handler.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
use burn_core::{data::dataloader::Progress, LearningRate};
|
||||
|
||||
/// Event happening during the training/validation process.
|
||||
pub enum Event<T> {
|
||||
/// Signal that an item have been processed.
|
||||
ProcessedItem(LearnerItem<T>),
|
||||
/// Signal the end of an epoch.
|
||||
EndEpoch(usize),
|
||||
}
|
||||
|
||||
/// Defines how training and validation events are collected.
|
||||
///
|
||||
/// This trait also exposes methods that uses the collected data to compute useful information.
|
||||
pub trait EventCollector: Send {
|
||||
/// Training item.
|
||||
type ItemTrain;
|
||||
/// Validation item.
|
||||
type ItemValid;
|
||||
|
||||
/// Collect the training event.
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>);
|
||||
|
||||
/// Collect the validaion event.
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>);
|
||||
|
||||
/// Find the epoch following the given criteria from the collected data.
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize>;
|
||||
}
|
||||
|
||||
/// How to aggregate the metric.
|
||||
pub enum Aggregate {
|
||||
/// Compute the average.
|
||||
Mean,
|
||||
}
|
||||
|
||||
/// The split to use.
|
||||
pub enum Split {
|
||||
/// The training split.
|
||||
Train,
|
||||
/// The validation split.
|
||||
Valid,
|
||||
}
|
||||
|
||||
/// The direction of the query.
|
||||
pub enum Direction {
|
||||
/// Lower is better.
|
||||
Lowest,
|
||||
/// Higher is better.
|
||||
Highest,
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
use crate::{
|
||||
info::MetricsInfo,
|
||||
metric::MetricMetadata,
|
||||
renderer::{MetricState, MetricsRenderer, TrainingProgress},
|
||||
Aggregate, Direction, Event, EventCollector, LearnerItem, Split,
|
||||
};
|
||||
|
||||
/// Collect training events in order to display metrics with a metrics renderer.
|
||||
#[derive(new)]
|
||||
pub(crate) struct RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
info: MetricsInfo<T, V>,
|
||||
}
|
||||
|
||||
impl<T, V> EventCollector for RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_event_train(&mut self, event: Event<Self::ItemTrain>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => self.on_train_item(item),
|
||||
Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn on_event_valid(&mut self, event: Event<Self::ItemValid>) {
|
||||
match event {
|
||||
Event::ProcessedItem(item) => self.on_valid_item(item),
|
||||
Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch),
|
||||
}
|
||||
}
|
||||
|
||||
fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
self.info.find_epoch(name, aggregate, direction, split)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> RenderedMetricsEventCollector<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.info.update_train(&item, &metadata);
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_train(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
}
|
||||
|
||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
let progress = (&item).into();
|
||||
let metadata = (&item).into();
|
||||
|
||||
let update = self.info.update_valid(&item, &metadata);
|
||||
|
||||
update
|
||||
.entries
|
||||
.into_iter()
|
||||
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
|
||||
|
||||
update
|
||||
.entries_numeric
|
||||
.into_iter()
|
||||
.for_each(|(entry, value)| {
|
||||
self.renderer
|
||||
.update_valid(MetricState::Numeric(entry, value))
|
||||
});
|
||||
|
||||
self.renderer.render_train(progress);
|
||||
}
|
||||
|
||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||
self.info.end_epoch_train(epoch);
|
||||
}
|
||||
|
||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||
self.info.end_epoch_valid(epoch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for TrainingProgress {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
mod base;
|
||||
|
||||
pub(crate) use base::*;
|
|
@ -0,0 +1,8 @@
|
|||
mod async_collector;
|
||||
mod base;
|
||||
|
||||
pub use async_collector::*;
|
||||
pub use base::*;
|
||||
|
||||
/// Metrics collector module.
|
||||
pub mod metrics;
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{checkpoint::Checkpointer, LearnerCallback};
|
||||
use crate::{checkpoint::Checkpointer, EventCollector};
|
||||
use burn_core::{
|
||||
lr_scheduler::LrScheduler,
|
||||
module::{ADModule, Module},
|
||||
|
@ -25,8 +25,8 @@ pub trait LearnerComponents {
|
|||
>;
|
||||
/// The checkpointer used for the scheduler.
|
||||
type CheckpointerLrScheduler: Checkpointer<<Self::LrScheduler as LrScheduler>::Record>;
|
||||
/// Callback used for training tracking.
|
||||
type Callback: LearnerCallback + 'static;
|
||||
/// Training event collector used for training tracking.
|
||||
type EventCollector: EventCollector + 'static;
|
||||
}
|
||||
|
||||
/// Concrete type that implements [training components trait](TrainingComponents).
|
||||
|
@ -41,8 +41,8 @@ pub struct LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C> {
|
|||
_callback: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<B, LR, M, O, CM, CO, CS, C> LearnerComponents
|
||||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, C>
|
||||
impl<B, LR, M, O, CM, CO, CS, EC> LearnerComponents
|
||||
for LearnerComponentsMarker<B, LR, M, O, CM, CO, CS, EC>
|
||||
where
|
||||
B: ADBackend,
|
||||
LR: LrScheduler,
|
||||
|
@ -51,7 +51,7 @@ where
|
|||
CM: Checkpointer<M::Record>,
|
||||
CO: Checkpointer<O::Record>,
|
||||
CS: Checkpointer<LR::Record>,
|
||||
C: LearnerCallback + 'static,
|
||||
EC: EventCollector + 'static,
|
||||
{
|
||||
type Backend = B;
|
||||
type LrScheduler = LR;
|
||||
|
@ -60,5 +60,5 @@ where
|
|||
type CheckpointerModel = CM;
|
||||
type CheckpointerOptimizer = CO;
|
||||
type CheckpointerLrScheduler = CS;
|
||||
type Callback = C;
|
||||
type EventCollector = EC;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
use crate::{logger::MetricLogger, Aggregate, Direction};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Type that can be used to fetch and use numeric metric aggregates.
|
||||
#[derive(Default, Debug)]
|
||||
pub(crate) struct NumericMetricsAggregate {
|
||||
mean_for_each_epoch: HashMap<Key, f64>,
|
||||
}
|
||||
|
||||
#[derive(new, Hash, PartialEq, Eq, Debug)]
|
||||
struct Key {
|
||||
name: String,
|
||||
epoch: usize,
|
||||
}
|
||||
|
||||
impl NumericMetricsAggregate {
|
||||
pub(crate) fn mean(
|
||||
&mut self,
|
||||
name: &str,
|
||||
epoch: usize,
|
||||
loggers: &mut [Box<dyn MetricLogger>],
|
||||
) -> Option<f64> {
|
||||
let key = Key::new(name.to_string(), epoch);
|
||||
|
||||
if let Some(value) = self.mean_for_each_epoch.get(&key) {
|
||||
return Some(*value);
|
||||
}
|
||||
|
||||
let points = || {
|
||||
let mut errors = Vec::new();
|
||||
for logger in loggers {
|
||||
match logger.read_numeric(name, epoch) {
|
||||
Ok(points) => return Ok(points),
|
||||
Err(err) => errors.push(err),
|
||||
};
|
||||
}
|
||||
|
||||
Err(errors.join(" "))
|
||||
};
|
||||
|
||||
let points = points().expect("Can read values");
|
||||
|
||||
if points.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let num_points = points.len();
|
||||
let mean = points.into_iter().sum::<f64>() / num_points as f64;
|
||||
|
||||
self.mean_for_each_epoch.insert(key, mean);
|
||||
Some(mean)
|
||||
}
|
||||
|
||||
pub(crate) fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
loggers: &mut [Box<dyn MetricLogger>],
|
||||
) -> Option<usize> {
|
||||
let mut data = Vec::new();
|
||||
let mut current_epoch = 1;
|
||||
|
||||
loop {
|
||||
match aggregate {
|
||||
Aggregate::Mean => match self.mean(name, current_epoch, loggers) {
|
||||
Some(value) => {
|
||||
data.push(value);
|
||||
}
|
||||
None => break,
|
||||
},
|
||||
};
|
||||
|
||||
current_epoch += 1;
|
||||
}
|
||||
|
||||
if data.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut current_value = match &direction {
|
||||
Direction::Lowest => f64::MAX,
|
||||
Direction::Highest => f64::MIN,
|
||||
};
|
||||
|
||||
for (i, value) in data.into_iter().enumerate() {
|
||||
match &direction {
|
||||
Direction::Lowest => {
|
||||
if value < current_value {
|
||||
current_value = value;
|
||||
current_epoch = i + 1;
|
||||
}
|
||||
}
|
||||
Direction::Highest => {
|
||||
if value > current_value {
|
||||
current_value = value;
|
||||
current_epoch = i + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(current_epoch)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{logger::FileMetricLogger, metric::MetricEntry};
|
||||
|
||||
use super::*;
|
||||
|
||||
struct TestLogger {
|
||||
logger: FileMetricLogger,
|
||||
epoch: usize,
|
||||
}
|
||||
const NAME: &str = "test-logger";
|
||||
|
||||
impl TestLogger {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
logger: FileMetricLogger::new("/tmp"),
|
||||
epoch: 1,
|
||||
}
|
||||
}
|
||||
fn log(&mut self, num: f64) {
|
||||
self.logger.log(&MetricEntry::new(
|
||||
NAME.into(),
|
||||
num.to_string(),
|
||||
num.to_string(),
|
||||
));
|
||||
}
|
||||
fn new_epoch(&mut self) {
|
||||
self.epoch += 1;
|
||||
self.logger.epoch(self.epoch);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_find_epoch() {
|
||||
let mut logger = TestLogger::new();
|
||||
let mut aggregate = NumericMetricsAggregate::default();
|
||||
|
||||
logger.log(500.); // Epoch 1
|
||||
logger.log(1000.); // Epoch 1
|
||||
logger.new_epoch();
|
||||
logger.log(200.); // Epoch 2
|
||||
logger.log(1000.); // Epoch 2
|
||||
logger.new_epoch();
|
||||
logger.log(10000.); // Epoch 3
|
||||
|
||||
let value = aggregate
|
||||
.find_epoch(
|
||||
NAME,
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
&mut [Box::new(logger.logger)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(value, 2);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,253 @@
|
|||
use super::NumericMetricsAggregate;
|
||||
use crate::{
|
||||
logger::MetricLogger,
|
||||
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
||||
Aggregate, Direction, LearnerItem, Split,
|
||||
};
|
||||
|
||||
/// Metrics information collected during training.
|
||||
pub struct MetricsInfo<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
train: Vec<Box<dyn MetricUpdater<T>>>,
|
||||
valid: Vec<Box<dyn MetricUpdater<V>>>,
|
||||
train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
|
||||
valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
|
||||
loggers_train: Vec<Box<dyn MetricLogger>>,
|
||||
loggers_valid: Vec<Box<dyn MetricLogger>>,
|
||||
aggregate_train: NumericMetricsAggregate,
|
||||
aggregate_valid: NumericMetricsAggregate,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MetricsUpdate {
|
||||
pub(crate) entries: Vec<MetricEntry>,
|
||||
pub(crate) entries_numeric: Vec<(MetricEntry, f64)>,
|
||||
}
|
||||
|
||||
impl<T, V> MetricsInfo<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
train: vec![],
|
||||
valid: vec![],
|
||||
train_numeric: vec![],
|
||||
valid_numeric: vec![],
|
||||
loggers_train: vec![],
|
||||
loggers_valid: vec![],
|
||||
aggregate_train: NumericMetricsAggregate::default(),
|
||||
aggregate_valid: NumericMetricsAggregate::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal the end of a training epoch.
|
||||
pub(crate) fn end_epoch_train(&mut self, epoch: usize) {
|
||||
for metric in self.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.epoch(epoch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal the end of a validation epoch.
|
||||
pub(crate) fn end_epoch_valid(&mut self, epoch: usize) {
|
||||
for metric in self.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.epoch(epoch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the training information from the training item.
|
||||
pub(crate) fn update_train(
|
||||
&mut self,
|
||||
item: &LearnerItem<T>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.train.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
|
||||
|
||||
for metric in self.train.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.train_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(item, metadata);
|
||||
for logger in self.loggers_train.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries_numeric.push((state, value));
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Update the training information from the validation item.
|
||||
pub(crate) fn update_valid(
|
||||
&mut self,
|
||||
item: &LearnerItem<V>,
|
||||
metadata: &MetricMetadata,
|
||||
) -> MetricsUpdate {
|
||||
let mut entries = Vec::with_capacity(self.valid.len());
|
||||
let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
|
||||
|
||||
for metric in self.valid.iter_mut() {
|
||||
let state = metric.update(item, metadata);
|
||||
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries.push(state);
|
||||
}
|
||||
|
||||
for metric in self.valid_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(item, metadata);
|
||||
for logger in self.loggers_valid.iter_mut() {
|
||||
logger.log(&state);
|
||||
}
|
||||
|
||||
entries_numeric.push((state, value));
|
||||
}
|
||||
|
||||
MetricsUpdate::new(entries, entries_numeric)
|
||||
}
|
||||
|
||||
/// Find the epoch corresponding to the given criteria.
|
||||
pub(crate) fn find_epoch(
|
||||
&mut self,
|
||||
name: &str,
|
||||
aggregate: Aggregate,
|
||||
direction: Direction,
|
||||
split: Split,
|
||||
) -> Option<usize> {
|
||||
match split {
|
||||
Split::Train => {
|
||||
self.aggregate_train
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_train)
|
||||
}
|
||||
Split::Valid => {
|
||||
self.aggregate_valid
|
||||
.find_epoch(name, aggregate, direction, &mut self.loggers_valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a logger for training metrics.
|
||||
pub(crate) fn register_logger_train<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_train.push(Box::new(logger));
|
||||
}
|
||||
|
||||
/// Register a logger for validation metrics.
|
||||
pub(crate) fn register_logger_valid<ML: MetricLogger + 'static>(&mut self, logger: ML) {
|
||||
self.loggers_valid.push(Box::new(logger));
|
||||
}
|
||||
|
||||
/// Register a training metric.
|
||||
pub(crate) fn register_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a validation metric.
|
||||
pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
|
||||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric training metric.
|
||||
pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.train_numeric.push(Box::new(metric))
|
||||
}
|
||||
|
||||
/// Register a numeric validation metric.
|
||||
pub(crate) fn register_valid_metric_numeric<Me: Metric + Numeric + 'static>(
|
||||
&mut self,
|
||||
metric: Me,
|
||||
) where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
let metric = MetricWrapper::new(metric);
|
||||
self.valid_numeric.push(Box::new(metric))
|
||||
}
|
||||
}
|
||||
|
||||
trait NumericMetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
trait MetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct MetricWrapper<M> {
|
||||
metric: M,
|
||||
}
|
||||
|
||||
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
|
||||
let update = self.metric.update(&item.item.adapt(), metadata);
|
||||
let numeric = self.metric.value();
|
||||
|
||||
(update, numeric)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
|
||||
self.metric.update(&item.item.adapt(), metadata)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
mod aggregates;
|
||||
mod metrics;
|
||||
|
||||
pub(crate) use aggregates::*;
|
||||
pub use metrics::*;
|
|
@ -19,7 +19,7 @@ pub struct Learner<LC: LearnerComponents> {
|
|||
pub(crate) grad_accumulation: Option<usize>,
|
||||
pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
|
||||
pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
|
||||
pub(crate) callback: LC::Callback,
|
||||
pub(crate) collector: LC::EventCollector,
|
||||
pub(crate) interrupter: TrainingInterrupter,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use super::log::install_file_logger;
|
||||
use super::Learner;
|
||||
use crate::checkpoint::{AsyncCheckpointer, FileCheckpointer};
|
||||
use crate::collector::metrics::RenderedMetricsEventCollector;
|
||||
use crate::components::LearnerComponentsMarker;
|
||||
use crate::info::MetricsInfo;
|
||||
use crate::learner::base::TrainingInterrupter;
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::callback::{
|
||||
default_renderer, MetricWrapper, Metrics, MetricsCallback, MetricsRenderer,
|
||||
};
|
||||
use crate::metric::{Adaptor, Metric};
|
||||
use crate::{AsyncTrainerCallback, LearnerCheckpointer};
|
||||
use crate::renderer::{default_renderer, MetricsRenderer};
|
||||
use crate::{AsyncEventCollector, LearnerCheckpointer};
|
||||
use burn_core::lr_scheduler::LrScheduler;
|
||||
use burn_core::module::ADModule;
|
||||
use burn_core::optim::Optimizer;
|
||||
|
@ -39,12 +39,11 @@ where
|
|||
directory: String,
|
||||
grad_accumulation: Option<usize>,
|
||||
devices: Vec<B::Device>,
|
||||
metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
|
||||
metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
metrics: Metrics<T, V>,
|
||||
info: MetricsInfo<T, V>,
|
||||
interrupter: TrainingInterrupter,
|
||||
log_to_file: bool,
|
||||
num_loggers: usize,
|
||||
}
|
||||
|
||||
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
|
||||
|
@ -69,12 +68,11 @@ where
|
|||
directory: directory.to_string(),
|
||||
grad_accumulation: None,
|
||||
devices: vec![B::Device::default()],
|
||||
metric_logger_train: None,
|
||||
metric_logger_valid: None,
|
||||
metrics: Metrics::new(),
|
||||
info: MetricsInfo::new(),
|
||||
renderer: None,
|
||||
interrupter: TrainingInterrupter::new(),
|
||||
log_to_file: true,
|
||||
num_loggers: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,8 +87,9 @@ where
|
|||
MT: MetricLogger + 'static,
|
||||
MV: MetricLogger + 'static,
|
||||
{
|
||||
self.metric_logger_train = Some(Box::new(logger_train));
|
||||
self.metric_logger_valid = Some(Box::new(logger_valid));
|
||||
self.info.register_logger_train(logger_train);
|
||||
self.info.register_logger_valid(logger_valid);
|
||||
self.num_loggers += 1;
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -112,9 +111,7 @@ where
|
|||
where
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics
|
||||
.train
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self.info.register_metric_train(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -123,9 +120,7 @@ where
|
|||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics
|
||||
.valid
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self.info.register_valid_metric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -144,41 +139,25 @@ where
|
|||
self
|
||||
}
|
||||
|
||||
/// Register a training metric and displays it on a plot.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
|
||||
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
|
||||
/// the same graph will be used for both.
|
||||
pub fn metric_train_plot<Me>(mut self, metric: Me) -> Self
|
||||
/// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
|
||||
pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
|
||||
where
|
||||
Me: Metric + crate::metric::Numeric + 'static,
|
||||
T: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics
|
||||
.train_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self.info.register_train_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a validation metric and displays it on a plot.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
|
||||
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
|
||||
/// the same graph will be used for both.
|
||||
pub fn metric_valid_plot<Me: Metric + crate::metric::Numeric + 'static>(
|
||||
/// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
|
||||
pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
|
||||
mut self,
|
||||
metric: Me,
|
||||
) -> Self
|
||||
where
|
||||
V: Adaptor<Me::Input>,
|
||||
{
|
||||
self.metrics
|
||||
.valid_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self.info.register_valid_metric_numeric(metric);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -260,7 +239,7 @@ where
|
|||
#[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
|
||||
// creates a clean learner.
|
||||
pub fn build(
|
||||
self,
|
||||
mut self,
|
||||
model: M,
|
||||
optim: O,
|
||||
lr_scheduler: S,
|
||||
|
@ -273,7 +252,7 @@ where
|
|||
AsyncCheckpointer<M::Record>,
|
||||
AsyncCheckpointer<O::Record>,
|
||||
AsyncCheckpointer<S::Record>,
|
||||
AsyncTrainerCallback<T, V>,
|
||||
AsyncEventCollector<T, V>,
|
||||
>,
|
||||
>
|
||||
where
|
||||
|
@ -288,18 +267,18 @@ where
|
|||
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
||||
});
|
||||
let directory = &self.directory;
|
||||
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
|
||||
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
|
||||
});
|
||||
let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
|
||||
Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
|
||||
});
|
||||
let callback = AsyncTrainerCallback::new(MetricsCallback::new(
|
||||
renderer,
|
||||
self.metrics,
|
||||
logger_train,
|
||||
logger_valid,
|
||||
));
|
||||
|
||||
if self.num_loggers == 0 {
|
||||
self.info.register_logger_train(FileMetricLogger::new(
|
||||
format!("{directory}/train").as_str(),
|
||||
));
|
||||
self.info.register_logger_valid(FileMetricLogger::new(
|
||||
format!("{directory}/valid").as_str(),
|
||||
));
|
||||
}
|
||||
|
||||
let collector =
|
||||
AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info));
|
||||
|
||||
let checkpointer = self
|
||||
.checkpointers
|
||||
|
@ -311,7 +290,7 @@ where
|
|||
lr_scheduler,
|
||||
checkpointer,
|
||||
num_epochs: self.num_epochs,
|
||||
callback,
|
||||
collector,
|
||||
checkpoint: self.checkpoint,
|
||||
grad_accumulation: self.grad_accumulation,
|
||||
devices: self.devices,
|
||||
|
|
|
@ -7,8 +7,8 @@ use burn_core::{
|
|||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::learner::base::TrainingInterrupter;
|
||||
use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||
use crate::{learner::base::TrainingInterrupter, Event};
|
||||
use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||
|
||||
/// A validation epoch.
|
||||
#[derive(new)]
|
||||
|
@ -37,7 +37,7 @@ impl<VI> ValidEpoch<VI> {
|
|||
pub fn run<B, M, TO, VO>(
|
||||
&self,
|
||||
model: &M,
|
||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
||||
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) where
|
||||
B: ADBackend,
|
||||
|
@ -64,13 +64,14 @@ impl<VI> ValidEpoch<VI> {
|
|||
None,
|
||||
);
|
||||
|
||||
callback.on_valid_item(item);
|
||||
callback.on_event_valid(Event::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
callback.on_valid_end_epoch(self.epoch);
|
||||
callback.on_event_valid(Event::EndEpoch(self.epoch));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,7 +93,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
mut model: M,
|
||||
mut optim: O,
|
||||
scheduler: &mut LR,
|
||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
||||
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) -> (M, O)
|
||||
where
|
||||
|
@ -139,13 +140,13 @@ impl<TI> TrainEpoch<TI> {
|
|||
Some(lr),
|
||||
);
|
||||
|
||||
callback.on_train_item(item);
|
||||
callback.on_event_train(Event::ProcessedItem(item));
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
callback.on_train_end_epoch(self.epoch);
|
||||
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||
|
||||
(model, optim)
|
||||
}
|
||||
|
@ -170,7 +171,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
mut model: M,
|
||||
mut optim: O,
|
||||
lr_scheduler: &mut S,
|
||||
callback: &mut Box<dyn LearnerCallback<ItemTrain = TO, ItemValid = VO>>,
|
||||
callback: &mut Box<dyn EventCollector<ItemTrain = TO, ItemValid = VO>>,
|
||||
devices: Vec<B::Device>,
|
||||
interrupter: &TrainingInterrupter,
|
||||
) -> (M, O)
|
||||
|
@ -232,7 +233,8 @@ impl<TI> TrainEpoch<TI> {
|
|||
Some(lr),
|
||||
);
|
||||
|
||||
callback.on_train_item(item);
|
||||
callback.on_event_train(Event::ProcessedItem(item));
|
||||
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
interrupted = true;
|
||||
|
@ -245,7 +247,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
}
|
||||
}
|
||||
|
||||
callback.on_train_end_epoch(self.epoch);
|
||||
callback.on_event_train(Event::EndEpoch(self.epoch));
|
||||
|
||||
(model, optim)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::components::LearnerComponents;
|
||||
use crate::{Learner, LearnerCallback, TrainEpoch, ValidEpoch};
|
||||
use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch};
|
||||
use burn_core::data::dataloader::DataLoader;
|
||||
use burn_core::module::{ADModule, Module};
|
||||
use burn_core::optim::{GradientsParams, Optimizer};
|
||||
|
@ -115,7 +115,7 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
OutputValid: Send,
|
||||
LC::Model: TrainStep<InputTrain, OutputTrain>,
|
||||
<LC::Model as ADModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
|
||||
LC::Callback: LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
LC::EventCollector: EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
{
|
||||
log::info!("Fitting {}", self.model.to_string());
|
||||
// The reference model is always on the first device provided.
|
||||
|
@ -139,8 +139,8 @@ impl<LC: LearnerComponents> Learner<LC> {
|
|||
};
|
||||
|
||||
let mut callback: Box<
|
||||
dyn LearnerCallback<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
> = Box::new(self.callback);
|
||||
dyn EventCollector<ItemTrain = OutputTrain, ItemValid = OutputValid>,
|
||||
> = Box::new(self.collector);
|
||||
|
||||
for epoch in starting_epoch..self.num_epochs + 1 {
|
||||
let epoch_train = TrainEpoch::new(
|
||||
|
|
|
@ -10,16 +10,22 @@ pub mod checkpoint;
|
|||
|
||||
pub(crate) mod components;
|
||||
|
||||
/// Renderer modules to display metrics and training information.
|
||||
pub mod renderer;
|
||||
|
||||
/// The logger module.
|
||||
pub mod logger;
|
||||
|
||||
/// The metric module.
|
||||
pub mod metric;
|
||||
|
||||
mod callback;
|
||||
/// All information collected during training.
|
||||
pub mod info;
|
||||
|
||||
mod collector;
|
||||
mod learner;
|
||||
|
||||
pub use callback::*;
|
||||
pub use collector::*;
|
||||
pub use learner::*;
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::sync::mpsc;
|
|||
enum Message<T> {
|
||||
Log(T),
|
||||
End,
|
||||
Sync(mpsc::Sender<()>),
|
||||
}
|
||||
/// Async logger.
|
||||
pub struct AsyncLogger<T> {
|
||||
|
@ -30,6 +31,9 @@ where
|
|||
Message::End => {
|
||||
return;
|
||||
}
|
||||
Message::Sync(callback) => {
|
||||
callback.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,6 +52,17 @@ impl<T: Send + Sync + 'static> AsyncLogger<T> {
|
|||
|
||||
Self { sender, handler }
|
||||
}
|
||||
|
||||
/// Sync the async logger.
|
||||
pub(crate) fn sync(&self) {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
|
||||
self.sender.send(Message::Sync(sender)).unwrap();
|
||||
|
||||
receiver
|
||||
.recv()
|
||||
.expect("Should sync, otherwise the thread is dead.");
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send> Logger<T> for AsyncLogger<T> {
|
||||
|
|
|
@ -49,11 +49,14 @@ impl FileMetricLogger {
|
|||
|
||||
fn file_path(&self, name: &str, epoch: usize) -> String {
|
||||
let directory = format!("{}/epoch-{}", self.directory, epoch);
|
||||
std::fs::create_dir_all(&directory).ok();
|
||||
let name = name.replace(' ', "_");
|
||||
|
||||
format!("{directory}/{name}.log")
|
||||
}
|
||||
fn create_directory(&self, epoch: usize) {
|
||||
let directory = format!("{}/epoch-{}", self.directory, epoch);
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
}
|
||||
}
|
||||
|
||||
impl MetricLogger for FileMetricLogger {
|
||||
|
@ -64,6 +67,8 @@ impl MetricLogger for FileMetricLogger {
|
|||
let logger = match self.loggers.get_mut(key) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
self.create_directory(self.epoch);
|
||||
|
||||
let file_path = self.file_path(key, self.epoch);
|
||||
let logger = FileLogger::new(&file_path);
|
||||
let logger = AsyncLogger::new(logger);
|
||||
|
@ -82,6 +87,10 @@ impl MetricLogger for FileMetricLogger {
|
|||
}
|
||||
|
||||
fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
|
||||
if let Some(value) = self.loggers.get(name) {
|
||||
value.sync()
|
||||
}
|
||||
|
||||
let file_path = self.file_path(name, epoch);
|
||||
|
||||
let mut errors = false;
|
||||
|
|
|
@ -1,285 +0,0 @@
|
|||
use crate::{
|
||||
logger::MetricLogger,
|
||||
metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric},
|
||||
LearnerCallback, LearnerItem,
|
||||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
|
||||
/// Holds all metrics, metric loggers, and a metrics renderer.
|
||||
pub struct MetricsCallback<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
metrics: Metrics<T, V>,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
logger_valid: Box<dyn MetricLogger>,
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
}
|
||||
|
||||
impl<T, V> MetricsCallback<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
/// Creates a new metrics callback.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The metrics renderer.
|
||||
/// * `metrics` - The metrics holder.
|
||||
/// * `logger_train` - The training logger.
|
||||
/// * `logger_valid` - The validation logger.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new metrics callback.
|
||||
pub(crate) fn new(
|
||||
renderer: Box<dyn MetricsRenderer>,
|
||||
metrics: Metrics<T, V>,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
logger_valid: Box<dyn MetricLogger>,
|
||||
) -> Self {
|
||||
Self {
|
||||
metrics,
|
||||
logger_train,
|
||||
logger_valid,
|
||||
renderer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, V> LearnerCallback for MetricsCallback<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
type ItemTrain = T;
|
||||
type ItemValid = V;
|
||||
|
||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
let metadata = (&item).into();
|
||||
for metric in self.metrics.train.iter_mut() {
|
||||
let state = metric.update(&item, &metadata);
|
||||
self.logger_train.log(&state);
|
||||
|
||||
self.renderer.update_train(MetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics.train_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item, &metadata);
|
||||
self.logger_train.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_train(MetricState::Numeric(state, value));
|
||||
}
|
||||
self.renderer.render_train(item.into());
|
||||
}
|
||||
|
||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
let metadata = (&item).into();
|
||||
for metric in self.metrics.valid.iter_mut() {
|
||||
let state = metric.update(&item, &metadata);
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
self.renderer.update_valid(MetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item, &metadata);
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_valid(MetricState::Numeric(state, value));
|
||||
}
|
||||
self.renderer.render_valid(item.into());
|
||||
}
|
||||
|
||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||
for metric in self.metrics.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.metrics.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
self.logger_train.epoch(epoch + 1);
|
||||
}
|
||||
|
||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||
for metric in self.metrics.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
self.logger_valid.epoch(epoch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Training progress.
|
||||
#[derive(Debug)]
|
||||
pub struct TrainingProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
}
|
||||
|
||||
impl TrainingProgress {
|
||||
/// Creates a new empty training progress.
|
||||
pub fn none() -> Self {
|
||||
Self {
|
||||
progress: Progress {
|
||||
items_processed: 0,
|
||||
items_total: 0,
|
||||
},
|
||||
epoch: 0,
|
||||
epoch_total: 0,
|
||||
iteration: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The state of a metric.
|
||||
#[derive(Debug)]
|
||||
pub enum MetricState {
|
||||
/// A generic metric.
|
||||
Generic(MetricEntry),
|
||||
|
||||
/// A numeric metric.
|
||||
Numeric(MetricEntry, f64),
|
||||
}
|
||||
|
||||
/// Trait for rendering metrics.
|
||||
pub trait MetricsRenderer: Send + Sync {
|
||||
/// Updates the training metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_train(&mut self, state: MetricState);
|
||||
|
||||
/// Updates the validation metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_valid(&mut self, state: MetricState);
|
||||
|
||||
/// Renders the training progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The training progress.
|
||||
fn render_train(&mut self, item: TrainingProgress);
|
||||
|
||||
/// Renders the validation progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The validation progress.
|
||||
fn render_valid(&mut self, item: TrainingProgress);
|
||||
}
|
||||
|
||||
/// A container for the metrics held by a metrics callback.
|
||||
pub(crate) struct Metrics<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub(crate) train: Vec<Box<dyn MetricUpdater<T>>>,
|
||||
pub(crate) valid: Vec<Box<dyn MetricUpdater<V>>>,
|
||||
pub(crate) train_numeric: Vec<Box<dyn NumericMetricUpdater<T>>>,
|
||||
pub(crate) valid_numeric: Vec<Box<dyn NumericMetricUpdater<V>>>,
|
||||
}
|
||||
|
||||
impl<T, V> Metrics<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
train: vec![],
|
||||
valid: vec![],
|
||||
train_numeric: vec![],
|
||||
valid_numeric: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<LearnerItem<T>> for TrainingProgress {
|
||||
fn from(item: LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress,
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<&LearnerItem<T>> for MetricMetadata {
|
||||
fn from(item: &LearnerItem<T>) -> Self {
|
||||
Self {
|
||||
progress: item.progress.clone(),
|
||||
epoch: item.epoch,
|
||||
epoch_total: item.epoch_total,
|
||||
iteration: item.iteration,
|
||||
lr: item.lr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
pub(crate) trait MetricUpdater<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MetricWrapper<M> {
|
||||
metric: M,
|
||||
}
|
||||
|
||||
impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64) {
|
||||
let update = self.metric.update(&item.item.adapt(), metadata);
|
||||
let numeric = self.metric.value();
|
||||
|
||||
(update, numeric)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, M> MetricUpdater<T> for MetricWrapper<M>
|
||||
where
|
||||
T: 'static,
|
||||
M: Metric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
|
||||
self.metric.update(&item.item.adapt(), metadata)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.metric.clear()
|
||||
}
|
||||
}
|
|
@ -1,7 +1,4 @@
|
|||
/// Callback module for training progress.
|
||||
pub mod callback;
|
||||
|
||||
/// State module for callback metrics.
|
||||
/// State module.
|
||||
pub mod state;
|
||||
|
||||
mod acc;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{format_float, MetricEntry, Numeric};
|
||||
use crate::metric::{format_float, MetricEntry, Numeric};
|
||||
|
||||
/// Usefull utility to implement numeric metrics.
|
||||
///
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
use burn_core::data::dataloader::Progress;
|
||||
|
||||
use crate::metric::MetricEntry;
|
||||
|
||||
/// Trait for rendering metrics.
|
||||
pub trait MetricsRenderer: Send + Sync {
|
||||
/// Updates the training metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_train(&mut self, state: MetricState);
|
||||
|
||||
/// Updates the validation metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_valid(&mut self, state: MetricState);
|
||||
|
||||
/// Renders the training progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The training progress.
|
||||
fn render_train(&mut self, item: TrainingProgress);
|
||||
|
||||
/// Renders the validation progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The validation progress.
|
||||
fn render_valid(&mut self, item: TrainingProgress);
|
||||
}
|
||||
|
||||
/// The state of a metric.
|
||||
#[derive(Debug)]
|
||||
pub enum MetricState {
|
||||
/// A generic metric.
|
||||
Generic(MetricEntry),
|
||||
|
||||
/// A numeric metric.
|
||||
Numeric(MetricEntry, f64),
|
||||
}
|
||||
|
||||
/// Training progress.
|
||||
#[derive(Debug)]
|
||||
pub struct TrainingProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
}
|
||||
|
||||
impl TrainingProgress {
|
||||
/// Creates a new empty training progress.
|
||||
pub fn none() -> Self {
|
||||
Self {
|
||||
progress: Progress {
|
||||
items_processed: 0,
|
||||
items_total: 0,
|
||||
},
|
||||
epoch: 0,
|
||||
epoch_total: 0,
|
||||
iteration: 0,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,4 @@
|
|||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(not(feature = "tui"))]
|
|
@ -1,4 +1,4 @@
|
|||
use crate::metric::callback::TrainingProgress;
|
||||
use crate::renderer::TrainingProgress;
|
||||
|
||||
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
|
||||
use crossterm::event::{Event, KeyCode};
|
|
@ -1,5 +1,6 @@
|
|||
use crate::renderer::TrainingProgress;
|
||||
|
||||
use super::TerminalFrame;
|
||||
use crate::metric::callback::TrainingProgress;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Style, Stylize},
|
|
@ -1,5 +1,5 @@
|
|||
use crate::metric::callback::tui::NumericMetricsState;
|
||||
use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use crate::renderer::{tui::NumericMetricsState, MetricsRenderer};
|
||||
use crate::renderer::{MetricState, TrainingProgress};
|
||||
use crate::TrainingInterrupter;
|
||||
use crossterm::{
|
||||
event::{self, Event, KeyCode},
|
|
@ -1,5 +1,5 @@
|
|||
use super::TerminalFrame;
|
||||
use crate::metric::callback::TrainingProgress;
|
||||
use crate::renderer::TrainingProgress;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Rect},
|
||||
style::{Color, Style, Stylize},
|
|
@ -1,5 +1,5 @@
|
|||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::train::metric::callback::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use burn::train::LearnerBuilder;
|
||||
use burn::{
|
||||
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
|
||||
|
|
|
@ -88,10 +88,10 @@ pub fn train<B: ADBackend>(artifact_dir: &str, config: TrainingConfig, device: B
|
|||
.build(MNISTDataset::test());
|
||||
|
||||
let learner = LearnerBuilder::new(artifact_dir)
|
||||
.metric_train_plot(AccuracyMetric::new())
|
||||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(1, CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
|
|
|
@ -58,16 +58,16 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
|
||||
// Model
|
||||
let learner = LearnerBuilder::new(ARTIFACT_DIR)
|
||||
.metric_train_plot(AccuracyMetric::new())
|
||||
.metric_valid_plot(AccuracyMetric::new())
|
||||
.metric_train_plot(CpuUse::new())
|
||||
.metric_valid_plot(CpuUse::new())
|
||||
.metric_train_plot(CpuMemory::new())
|
||||
.metric_valid_plot(CpuMemory::new())
|
||||
.metric_train_plot(CpuTemperature::new())
|
||||
.metric_valid_plot(CpuTemperature::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(CpuUse::new())
|
||||
.metric_valid_numeric(CpuUse::new())
|
||||
.metric_train_numeric(CpuMemory::new())
|
||||
.metric_valid_numeric(CpuMemory::new())
|
||||
.metric_train_numeric(CpuTemperature::new())
|
||||
.metric_valid_numeric(CpuTemperature::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(1, CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
|
|
|
@ -95,9 +95,9 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
.metric_valid(CUDAMetric::new())
|
||||
.metric_train(AccuracyMetric::new())
|
||||
.metric_valid(AccuracyMetric::new())
|
||||
.metric_train_plot(LossMetric::new())
|
||||
.metric_valid_plot(LossMetric::new())
|
||||
.metric_train_plot(LearningRateMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.metric_train_numeric(LearningRateMetric::new())
|
||||
.with_file_checkpointer(2, CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
|
|
|
@ -70,11 +70,11 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
let learner = LearnerBuilder::new(artifact_dir)
|
||||
.metric_train(CUDAMetric::new())
|
||||
.metric_valid(CUDAMetric::new())
|
||||
.metric_train_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||
.metric_valid_plot(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||
.metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||
.metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))
|
||||
.metric_train(LossMetric::new())
|
||||
.metric_valid(LossMetric::new())
|
||||
.metric_train_plot(LearningRateMetric::new())
|
||||
.metric_train_numeric(LearningRateMetric::new())
|
||||
.with_file_checkpointer(2, CompactRecorder::new())
|
||||
.devices(vec![device])
|
||||
.grads_accumulation(accum)
|
||||
|
|
Loading…
Reference in New Issue