Feat: Some tweaks to make it more practical to integrate in a GUI app (#706)

* feat: Add support for using a custom renderer

When integrating into an app, the CLI display is undesirable. This will
allow us to collect the progress of iterations so that it can be displayed
in a GUI.

Because CLIDashboardRenderer() writes to the console when ::new() is
called, the code has had to be refactored to defer creation until .build()
is called. This meant that instead of delegating the metric assignments
to the already-created dashboard, we instead need to store them and add
them later.

* feat: Allow opt-out of experiment.log
This commit is contained in:
Damien Elmes 2023-08-29 06:23:31 +10:00 committed by GitHub
parent 968cd6e390
commit a4a9844da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 104 deletions

View File

@ -3,7 +3,7 @@ use super::Learner;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::dashboard::cli::CLIDashboardRenderer;
use crate::metric::dashboard::Dashboard;
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics};
use crate::metric::{Adaptor, Metric, Numeric};
use crate::AsyncTrainerCallback;
use burn_core::lr_scheduler::LRScheduler;
@ -24,7 +24,6 @@ where
O: Optimizer<M, B>,
S: LRScheduler,
{
dashboard: Dashboard<T, V>,
checkpointer_model: Option<Arc<dyn Checkpointer<M::Record> + Send + Sync>>,
checkpointer_optimizer: Option<Arc<dyn Checkpointer<O::Record> + Send + Sync>>,
checkpointer_scheduler: Option<Arc<dyn Checkpointer<S::Record> + Send + Sync>>,
@ -33,6 +32,11 @@ where
directory: String,
grad_accumulation: Option<usize>,
devices: Vec<B::Device>,
metric_logger_train: Option<Box<dyn MetricLogger + 'static>>,
metric_logger_valid: Option<Box<dyn MetricLogger + 'static>>,
renderer: Option<Box<dyn DashboardRenderer + 'static>>,
metrics: Metrics<T, V>,
log_to_file: bool,
}
impl<B, T, V, Model, Optim, LR> LearnerBuilder<B, T, V, Model, Optim, LR>
@ -50,12 +54,7 @@ where
///
/// * `directory` - The directory to save the checkpoints.
pub fn new(directory: &str) -> Self {
let renderer = Box::new(CLIDashboardRenderer::new());
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));
let logger_valid = Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()));
Self {
dashboard: Dashboard::new(renderer, logger_train, logger_valid),
num_epochs: 1,
checkpoint: None,
checkpointer_model: None,
@ -64,6 +63,11 @@ where
directory: directory.to_string(),
grad_accumulation: None,
devices: vec![B::Device::default()],
metric_logger_train: None,
metric_logger_valid: None,
metrics: Metrics::new(),
renderer: None,
log_to_file: true,
}
}
@ -78,8 +82,21 @@ where
MT: MetricLogger + 'static,
MV: MetricLogger + 'static,
{
self.dashboard
.replace_loggers(Box::new(logger_train), Box::new(logger_valid));
self.metric_logger_train = Some(Box::new(logger_train));
self.metric_logger_valid = Some(Box::new(logger_valid));
self
}
/// Replace the default CLI renderer with a custom one.
///
/// The provided renderer will receive all dashboard updates in place of
/// `CLIDashboardRenderer` once the learner is built.
///
/// # Arguments
///
/// * `renderer` - The custom renderer.
pub fn renderer<DR: DashboardRenderer + 'static>(mut self, renderer: DR) -> Self {
    self.renderer = Some(Box::new(renderer));
    self
}
@ -88,7 +105,9 @@ where
where
T: Adaptor<M::Input>,
{
self.dashboard.register_train(metric);
self.metrics
.train
.push(Box::new(MetricWrapper::new(metric)));
self
}
@ -97,7 +116,9 @@ where
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid(metric);
self.metrics
.valid
.push(Box::new(MetricWrapper::new(metric)));
self
}
@ -128,7 +149,9 @@ where
M: Metric + Numeric + 'static,
T: Adaptor<M::Input>,
{
self.dashboard.register_train_plot(metric);
self.metrics
.train_numeric
.push(Box::new(MetricWrapper::new(metric)));
self
}
@ -143,7 +166,9 @@ where
where
V: Adaptor<M::Input>,
{
self.dashboard.register_valid_plot(metric);
self.metrics
.valid_numeric
.push(Box::new(MetricWrapper::new(metric)));
self
}
@ -165,6 +190,14 @@ where
self
}
/// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling
/// will apply.
pub fn log_to_file(mut self, enabled: bool) -> Self {
self.log_to_file = enabled;
self
}
/// Register a checkpointer that will save the [optimizer](Optimizer) and the
/// [model](ADModule).
///
@ -210,8 +243,21 @@ where
Optim::Record: 'static,
LR::Record: 'static,
{
self.init_logger();
let callback = Box::new(self.dashboard);
if self.log_to_file {
self.init_logger();
}
let renderer = self
.renderer
.unwrap_or_else(|| Box::new(CLIDashboardRenderer::new()));
let directory = &self.directory;
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
});
let logger_valid = self.metric_logger_valid.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()))
});
let dashboard = Dashboard::new(renderer, self.metrics, logger_train, logger_valid);
let callback = Box::new(dashboard);
let callback = Box::new(AsyncTrainerCallback::new(callback));
let checkpointer_optimizer = match self.checkpointer_optimizer {

View File

@ -69,7 +69,7 @@ pub trait Numeric {
}
/// Data type that contains the current state of a metric at a given time.
#[derive(new)]
#[derive(new, Debug)]
pub struct MetricEntry {
/// The name of the metric.
pub name: String,

View File

@ -6,6 +6,7 @@ use crate::{
use burn_core::data::dataloader::Progress;
/// Training progress.
#[derive(Debug)]
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
@ -36,6 +37,7 @@ impl TrainingProgress {
}
/// A dashboard metric.
#[derive(Debug)]
pub enum DashboardMetricState {
/// A generic metric.
Generic(MetricEntry),
@ -75,16 +77,40 @@ pub trait DashboardRenderer: Send + Sync {
fn render_valid(&mut self, item: TrainingProgress);
}
/// A dashboard container for all metrics.
/// A container for the metrics held by a dashboard.
pub(crate) struct Metrics<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Generic (non-numeric) metrics updated on each training item.
pub(crate) train: Vec<Box<dyn DashboardMetric<T>>>,
/// Generic (non-numeric) metrics updated on each validation item.
pub(crate) valid: Vec<Box<dyn DashboardMetric<V>>>,
/// Numeric training metrics; their `update` also yields an `f64` value.
pub(crate) train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
/// Numeric validation metrics; their `update` also yields an `f64` value.
pub(crate) valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
}
impl<T, V> Metrics<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Creates an empty container with no metrics registered yet.
pub fn new() -> Self {
    Self {
        train: Vec::new(),
        valid: Vec::new(),
        train_numeric: Vec::new(),
        valid_numeric: Vec::new(),
    }
}
}
/// Holds all metrics, metric loggers, and a dashboard renderer.
pub struct Dashboard<T, V>
where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
metrics_train: Vec<Box<dyn DashboardMetric<T>>>,
metrics_valid: Vec<Box<dyn DashboardMetric<V>>>,
metrics_train_numeric: Vec<Box<dyn DashboardNumericMetric<T>>>,
metrics_valid_numeric: Vec<Box<dyn DashboardNumericMetric<V>>>,
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
renderer: Box<dyn DashboardRenderer>,
@ -100,94 +126,26 @@ where
/// # Arguments
///
/// * `renderer` - The dashboard renderer.
/// * `metrics` - The dashboard's metrics
/// * `logger_train` - The training logger.
/// * `logger_valid` - The validation logger.
///
/// # Returns
///
/// A new dashboard.
pub fn new(
pub(crate) fn new(
renderer: Box<dyn DashboardRenderer>,
metrics: Metrics<T, V>,
logger_train: Box<dyn MetricLogger>,
logger_valid: Box<dyn MetricLogger>,
) -> Self {
Self {
metrics_train: Vec::new(),
metrics_valid: Vec::new(),
metrics_train_numeric: Vec::new(),
metrics_valid_numeric: Vec::new(),
metrics,
logger_train,
logger_valid,
renderer,
}
}
/// Replace the current loggers with the provided ones.
///
/// # Arguments
///
/// * `logger_train` - The logger used for training metrics.
/// * `logger_valid` - The logger used for validation metrics.
pub fn replace_loggers(
    &mut self,
    logger_train: Box<dyn MetricLogger>,
    logger_valid: Box<dyn MetricLogger>,
) {
    // The two assignments are independent of each other.
    self.logger_valid = logger_valid;
    self.logger_train = logger_train;
}
/// Registers a training metric.
///
/// # Arguments
///
/// * `metric` - The metric to track on training items.
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
where
    T: Adaptor<M::Input>,
{
    // Wrap the concrete metric so it can live behind the dashboard's trait object.
    let wrapped = MetricWrapper::new(metric);
    self.metrics_train.push(Box::new(wrapped));
}
/// Registers a training numeric metric.
///
/// # Arguments
///
/// * `metric` - The numeric metric to track on training items.
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
    T: Adaptor<M::Input>,
{
    // Wrap the concrete metric so it can live behind the dashboard's trait object.
    let wrapped = MetricWrapper::new(metric);
    self.metrics_train_numeric.push(Box::new(wrapped));
}
/// Registers a validation metric.
///
/// # Arguments
///
/// * `metric` - The metric to track on validation items.
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
where
    V: Adaptor<M::Input>,
{
    // Wrap the concrete metric so it can live behind the dashboard's trait object.
    let wrapped = MetricWrapper::new(metric);
    self.metrics_valid.push(Box::new(wrapped));
}
/// Registers a validation numeric metric.
///
/// # Arguments
///
/// * `metric` - The numeric metric to track on validation items.
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
    V: Adaptor<M::Input>,
{
    // Wrap the concrete metric so it can live behind the dashboard's trait object.
    let wrapped = MetricWrapper::new(metric);
    self.metrics_valid_numeric.push(Box::new(wrapped));
}
}
impl<T> From<LearnerItem<T>> for TrainingProgress {
@ -220,14 +178,14 @@ where
{
fn on_train_item(&mut self, item: LearnerItem<T>) {
let metadata = (&item).into();
for metric in self.metrics_train.iter_mut() {
for metric in self.metrics.train.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_train.log(&state);
self.renderer
.update_train(DashboardMetricState::Generic(state));
}
for metric in self.metrics_train_numeric.iter_mut() {
for metric in self.metrics.train_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_train.log(&state);
@ -239,14 +197,14 @@ where
fn on_valid_item(&mut self, item: LearnerItem<V>) {
let metadata = (&item).into();
for metric in self.metrics_valid.iter_mut() {
for metric in self.metrics.valid.iter_mut() {
let state = metric.update(&item, &metadata);
self.logger_valid.log(&state);
self.renderer
.update_valid(DashboardMetricState::Generic(state));
}
for metric in self.metrics_valid_numeric.iter_mut() {
for metric in self.metrics.valid_numeric.iter_mut() {
let (state, value) = metric.update(&item, &metadata);
self.logger_valid.log(&state);
@ -257,38 +215,38 @@ where
}
fn on_train_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics_train.iter_mut() {
for metric in self.metrics.train.iter_mut() {
metric.clear();
}
for metric in self.metrics_train_numeric.iter_mut() {
for metric in self.metrics.train_numeric.iter_mut() {
metric.clear();
}
self.logger_train.epoch(epoch + 1);
}
fn on_valid_end_epoch(&mut self, epoch: usize) {
for metric in self.metrics_valid.iter_mut() {
for metric in self.metrics.valid.iter_mut() {
metric.clear();
}
for metric in self.metrics_valid_numeric.iter_mut() {
for metric in self.metrics.valid_numeric.iter_mut() {
metric.clear();
}
self.logger_valid.epoch(epoch + 1);
}
}
trait DashboardNumericMetric<T>: Send + Sync {
pub(crate) trait DashboardNumericMetric<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> (MetricEntry, f64);
fn clear(&mut self);
}
trait DashboardMetric<T>: Send + Sync {
pub(crate) trait DashboardMetric<T>: Send + Sync {
fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
fn clear(&mut self);
}
#[derive(new)]
struct MetricWrapper<M> {
pub(crate) struct MetricWrapper<M> {
metric: M,
}