mirror of https://github.com/tracel-ai/burn.git
Feat: Some tweaks to make it more practical to integrate in a GUI app (#706)
* feat: Add support for using a custom renderer When integrating in an app, the CLI display is undesirable. This will allow us to collect the progress of iterations, so they can be displayed in a GUI. Because CLIDashboardRenderer() writes to the console when ::new() is called, the code has had to be refactored to defer creation until .build() is called. This meant that instead of delegating the metric assignments to the already-created dashboard, we instead need to store them and add them later. * feat: Allow opt-out of experiment.log
This commit is contained in:
parent
968cd6e390
commit
a4a9844da3
|
@ -3,7 +3,7 @@ use super::Learner;
|
|||
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::dashboard::cli::CLIDashboardRenderer;
|
||||
use crate::metric::dashboard::Dashboard;
|
||||
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics};
|
||||
use crate::metric::{Adaptor, Metric, Numeric};
|
||||
use crate::AsyncTrainerCallback;
|
||||
use burn_core::lr_scheduler::LRScheduler;
|
||||
|
@ -24,7 +24,6 @@ where
|
|||
O: Optimizer<M, B>,
|
||||
S: LRScheduler,
|
||||
{
|
||||
dashboard: Dashboard<T, V>,
|
||||
checkpointer_model: Option<Arc<dyn Checkpointer<M::Record> + Send + Sync>>,
|
||||
checkpointer_optimizer: Option<Arc<dyn Checkpointer<O::Record> + Send + Sync>>,
|
||||
checkpointer_scheduler: Option<Arc<dyn Checkpointer<S::Record> + Send + Sync>>,
|
||||
|
@ -33,6 +32,11 @@ where
|
|||
directory: String,
|
||||
grad_accumulation: Option<usize>,
|
||||
devices: Vec<B::Device>,
|
||||
metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
|
||||
metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
|
||||
renderer: Option<Box<dyn DashboardRenderer + 'static>>,
|
||||
metrics: Metrics<T, V>,
|
||||
log_to_file: bool,
|
||||
}
|
||||
|
||||
impl<B, T, V, Model, Optim, LR> LearnerBuilder<B, T, V, Model, Optim, LR>
|
||||
|
@ -50,12 +54,7 @@ where
|
|||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
pub fn new(directory: &str) -> Self {
|
||||
let renderer = Box::new(CLIDashboardRenderer::new());
|
||||
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));
|
||||
let logger_valid = Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()));
|
||||
|
||||
Self {
|
||||
dashboard: Dashboard::new(renderer, logger_train, logger_valid),
|
||||
num_epochs: 1,
|
||||
checkpoint: None,
|
||||
checkpointer_model: None,
|
||||
|
@ -64,6 +63,11 @@ where
|
|||
directory: directory.to_string(),
|
||||
grad_accumulation: None,
|
||||
devices: vec![B::Device::default()],
|
||||
metric_logger_train: None,
|
||||
metric_logger_valid: None,
|
||||
metrics: Metrics::new(),
|
||||
renderer: None,
|
||||
log_to_file: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,8 +82,21 @@ where
|
|||
MT: MetricLogger + 'static,
|
||||
MV: MetricLogger + 'static,
|
||||
{
|
||||
self.dashboard
|
||||
.replace_loggers(Box::new(logger_train), Box::new(logger_valid));
|
||||
self.metric_logger_train = Some(Box::new(logger_train));
|
||||
self.metric_logger_valid = Some(Box::new(logger_valid));
|
||||
self
|
||||
}
|
||||
|
||||
/// Replace the default CLI renderer with a custom one.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The custom renderer.
|
||||
pub fn renderer<DR>(mut self, renderer: DR) -> Self
|
||||
where
|
||||
DR: DashboardRenderer + 'static,
|
||||
{
|
||||
self.renderer = Some(Box::new(renderer));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -88,7 +105,9 @@ where
|
|||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_train(metric);
|
||||
self.metrics
|
||||
.train
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -97,7 +116,9 @@ where
|
|||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_valid(metric);
|
||||
self.metrics
|
||||
.valid
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -128,7 +149,9 @@ where
|
|||
M: Metric + Numeric + 'static,
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_train_plot(metric);
|
||||
self.metrics
|
||||
.train_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -143,7 +166,9 @@ where
|
|||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.dashboard.register_valid_plot(metric);
|
||||
self.metrics
|
||||
.valid_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -165,6 +190,14 @@ where
|
|||
self
|
||||
}
|
||||
|
||||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn log_to_file(mut self, enabled: bool) -> Self {
|
||||
self.log_to_file = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// Register a checkpointer that will save the [optimizer](Optimizer) and the
|
||||
/// [model](ADModule).
|
||||
///
|
||||
|
@ -210,8 +243,21 @@ where
|
|||
Optim::Record: 'static,
|
||||
LR::Record: 'static,
|
||||
{
|
||||
self.init_logger();
|
||||
let callback = Box::new(self.dashboard);
|
||||
if self.log_to_file {
|
||||
self.init_logger();
|
||||
}
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| Box::new(CLIDashboardRenderer::new()));
|
||||
let directory = &self.directory;
|
||||
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
|
||||
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
|
||||
});
|
||||
let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
|
||||
Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
|
||||
});
|
||||
let dashboard = Dashboard::new(renderer, self.metrics, logger_train, logger_valid);
|
||||
let callback = Box::new(dashboard);
|
||||
let callback = Box::new(AsyncTrainerCallback::new(callback));
|
||||
|
||||
let checkpointer_optimizer = match self.checkpointer_optimizer {
|
||||
|
|
|
@ -69,7 +69,7 @@ pub trait Numeric {
|
|||
}
|
||||
|
||||
/// Data type that contains the current state of a metric at a given time.
|
||||
#[derive(new)]
|
||||
#[derive(new, Debug)]
|
||||
pub struct MetricEntry {
|
||||
/// The name of the metric.
|
||||
pub name: String,
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::{
|
|||
use burn_core::data::dataloader::Progress;
|
||||
|
||||
/// Training progress.
|
||||
#[derive(Debug)]
|
||||
pub struct TrainingProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
@ -36,6 +37,7 @@ impl TrainingProgress {
|
|||
}
|
||||
|
||||
/// A dashboard metric.
|
||||
#[derive(Debug)]
|
||||
pub enum DashboardMetricState {
|
||||
/// A generic metric.
|
||||
Generic(MetricEntry),
|
||||
|
@ -75,16 +77,40 @@ pub trait DashboardRenderer: Send + Sync {
|
|||
fn render_valid(&mut self, item: TrainingProgress);
|
||||
}
|
||||
|
||||
/// A dashboard container for all metrics.
|
||||
/// A container for the metrics held by a dashboard.
|
||||
pub(crate) struct Metrics<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub(crate) train: Vec<Box<dyn DashboardMetric<T>>>,
|
||||
pub(crate) valid: Vec<Box<dyn DashboardMetric<V>>>,
|
||||
pub(crate) train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
|
||||
pub(crate) valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
|
||||
}
|
||||
|
||||
impl<T, V> Metrics<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
train: vec![],
|
||||
valid: vec![],
|
||||
train_numeric: vec![],
|
||||
valid_numeric: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds all metrics, metric loggers, and a dashboard renderer.
|
||||
pub struct Dashboard<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
metrics_train: Vec<Box<dyn DashboardMetric<T>>>,
|
||||
metrics_valid: Vec<Box<dyn DashboardMetric<V>>>,
|
||||
metrics_train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
|
||||
metrics_valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
|
||||
metrics: Metrics<T, V>,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
logger_valid: Box<dyn MetricLogger>,
|
||||
renderer: Box<dyn DashboardRenderer>,
|
||||
|
@ -100,94 +126,26 @@ where
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The dashboard renderer.
|
||||
/// * `metrics` - The dashboard's metrics
|
||||
/// * `logger_train` - The training logger.
|
||||
/// * `logger_valid` - The validation logger.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new dashboard.
|
||||
pub fn new(
|
||||
pub(crate) fn new(
|
||||
renderer: Box<dyn DashboardRenderer>,
|
||||
metrics: Metrics<T, V>,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
logger_valid: Box<dyn MetricLogger>,
|
||||
) -> Self {
|
||||
Self {
|
||||
metrics_train: Vec::new(),
|
||||
metrics_valid: Vec::new(),
|
||||
metrics_train_numeric: Vec::new(),
|
||||
metrics_valid_numeric: Vec::new(),
|
||||
metrics,
|
||||
logger_train,
|
||||
logger_valid,
|
||||
renderer,
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the current loggers with the provided ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logger_train` - The training logger.
|
||||
/// * `logger_valid` - The validation logger.
|
||||
pub fn replace_loggers(
|
||||
&mut self,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
logger_valid: Box<dyn MetricLogger>,
|
||||
) {
|
||||
self.logger_train = logger_train;
|
||||
self.logger_valid = logger_valid;
|
||||
}
|
||||
|
||||
/// Registers a training metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_train
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a training numeric metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_train_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a validation metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_valid
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a validation numeric metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
{
|
||||
self.metrics_valid_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<LearnerItem<T>> for TrainingProgress {
|
||||
|
@ -220,14 +178,14 @@ where
|
|||
{
|
||||
fn on_train_item(&mut self, item: LearnerItem<T>) {
|
||||
let metadata = (&item).into();
|
||||
for metric in self.metrics_train.iter_mut() {
|
||||
for metric in self.metrics.train.iter_mut() {
|
||||
let state = metric.update(&item, &metadata);
|
||||
self.logger_train.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_train(DashboardMetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics_train_numeric.iter_mut() {
|
||||
for metric in self.metrics.train_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item, &metadata);
|
||||
self.logger_train.log(&state);
|
||||
|
||||
|
@ -239,14 +197,14 @@ where
|
|||
|
||||
fn on_valid_item(&mut self, item: LearnerItem<V>) {
|
||||
let metadata = (&item).into();
|
||||
for metric in self.metrics_valid.iter_mut() {
|
||||
for metric in self.metrics.valid.iter_mut() {
|
||||
let state = metric.update(&item, &metadata);
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
self.renderer
|
||||
.update_valid(DashboardMetricState::Generic(state));
|
||||
}
|
||||
for metric in self.metrics_valid_numeric.iter_mut() {
|
||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
||||
let (state, value) = metric.update(&item, &metadata);
|
||||
self.logger_valid.log(&state);
|
||||
|
||||
|
@ -257,38 +215,38 @@ where
|
|||
}
|
||||
|
||||
fn on_train_end_epoch(&mut self, epoch: usize) {
|
||||
for metric in self.metrics_train.iter_mut() {
|
||||
for metric in self.metrics.train.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.metrics_train_numeric.iter_mut() {
|
||||
for metric in self.metrics.train_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
self.logger_train.epoch(epoch + 1);
|
||||
}
|
||||
|
||||
fn on_valid_end_epoch(&mut self, epoch: usize) {
|
||||
for metric in self.metrics_valid.iter_mut() {
|
||||
for metric in self.metrics.valid.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
for metric in self.metrics_valid_numeric.iter_mut() {
|
||||
for metric in self.metrics.valid_numeric.iter_mut() {
|
||||
metric.clear();
|
||||
}
|
||||
self.logger_valid.epoch(epoch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
trait DashboardNumericMetric<T>: Send + Sync {
|
||||
pub(crate) trait DashboardNumericMetric<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
trait DashboardMetric<T>: Send + Sync {
|
||||
pub(crate) trait DashboardMetric<T>: Send + Sync {
|
||||
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
|
||||
fn clear(&mut self);
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct MetricWrapper<M> {
|
||||
pub(crate) struct MetricWrapper<M> {
|
||||
metric: M,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue