Feat/dashboard tui (#790)

This commit is contained in:
Nathaniel Simard 2023-09-13 10:45:14 -04:00 committed by GitHub
parent 4f72578260
commit 57d6a566be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1599 additions and 472 deletions

View File

@ -1,5 +1,6 @@
[default] [default]
extend-ignore-identifiers-re = [ extend-ignore-identifiers-re = [
"ratatui",
"NdArray*", "NdArray*",
"ND" "ND"
] ]

View File

@ -62,11 +62,16 @@ where
let mut iterator = dataloader_cloned.iter(); let mut iterator = dataloader_cloned.iter();
while let Some(item) = iterator.next() { while let Some(item) = iterator.next() {
let progress = iterator.progress(); let progress = iterator.progress();
sender_cloned
.send(Message::Batch(index, item, progress)) match sender_cloned.send(Message::Batch(index, item, progress)) {
.unwrap(); Ok(_) => {}
// The receiver is probably gone, no need to panic, just need to stop
// iterating.
Err(_) => return,
};
} }
sender_cloned.send(Message::Done).unwrap(); // Same thing.
sender_cloned.send(Message::Done).ok();
}) })
}) })
.collect(); .collect();

View File

@ -11,17 +11,15 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-train"
version = "0.10.0" version = "0.10.0"
[features] [features]
default = ["metrics", "ui"] default = ["metrics", "tui"]
metrics = [ metrics = [
"nvml-wrapper", "nvml-wrapper",
"sysinfo", "sysinfo",
"systemstat" "systemstat"
] ]
ui = [ tui = [
"indicatif", "ratatui",
"rgb", "crossterm"
"terminal_size",
"textplots",
] ]
[dependencies] [dependencies]
@ -38,10 +36,8 @@ sysinfo = { version = "0.29.8", optional = true }
systemstat = { version = "0.2.3", optional = true } systemstat = { version = "0.2.3", optional = true }
# Text UI # Text UI
indicatif = { version = "0.17.5", optional = true } ratatui = { version = "0.23", optional = true, features = ["all-widgets"] }
rgb = { version = "0.8.36", optional = true } crossterm = { version = "0.27", optional = true }
terminal_size = { version = "0.2.6", optional = true }
textplots = { version = "0.8.0", optional = true }
# Utilities # Utilities
derive-new = {workspace = true} derive-new = {workspace = true}

View File

@ -2,8 +2,9 @@ use super::log::install_file_logger;
use super::Learner; use super::Learner;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer}; use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::logger::{FileMetricLogger, MetricLogger}; use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::dashboard::CLIDashboardRenderer; use crate::metric::dashboard::{
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics}; Dashboard, DashboardRenderer, MetricWrapper, Metrics, SelectedDashboardRenderer,
};
use crate::metric::{Adaptor, Metric}; use crate::metric::{Adaptor, Metric};
use crate::AsyncTrainerCallback; use crate::AsyncTrainerCallback;
use burn_core::lr_scheduler::LRScheduler; use burn_core::lr_scheduler::LRScheduler;
@ -259,7 +260,7 @@ where
} }
let renderer = self let renderer = self
.renderer .renderer
.unwrap_or_else(|| Box::new(CLIDashboardRenderer::new())); .unwrap_or_else(|| Box::new(SelectedDashboardRenderer::new(self.interrupter.clone())));
let directory = &self.directory; let directory = &self.directory;
let logger_train = self.metric_logger_train.unwrap_or_else(|| { let logger_train = self.metric_logger_train.unwrap_or_else(|| {
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str())) Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))

View File

@ -199,6 +199,7 @@ impl<TI> TrainEpoch<TI> {
// The main device is always the first in the list. // The main device is always the first in the list.
let device_main = devices.get(0).unwrap().clone(); let device_main = devices.get(0).unwrap().clone();
let mut interrupted = false;
loop { loop {
let items = step.step(&mut iterator, &model); let items = step.step(&mut iterator, &model);
@ -234,9 +235,14 @@ impl<TI> TrainEpoch<TI> {
callback.on_train_item(item); callback.on_train_item(item);
if interrupter.should_stop() { if interrupter.should_stop() {
log::info!("Training interrupted."); log::info!("Training interrupted.");
interrupted = true;
break; break;
} }
} }
if interrupted {
break;
}
} }
callback.on_train_end_epoch(self.epoch); callback.on_train_end_epoch(self.epoch);

View File

@ -168,6 +168,10 @@ where
); );
} }
if self.interrupter.should_stop() {
break;
}
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
epoch_valid.run(&model, &mut self.callback, &self.interrupter); epoch_valid.run(&model, &mut self.callback, &self.interrupter);

View File

@ -78,3 +78,13 @@ pub struct MetricEntry {
/// The string to be saved. /// The string to be saved.
pub serialize: String, pub serialize: String,
} }
/// Format a float with the given precision. Will use scientific notation if necessary.
pub fn format_float(float: f64, precision: usize) -> String {
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
match scientific_notation_threshold >= float {
true => format!("{float:.precision$e}"),
false => format!("{float:.precision$}"),
}
}

View File

@ -1,5 +1,5 @@
/// The CPU use metric. /// The CPU use metric.
use super::MetricMetadata; use super::{MetricMetadata, Numeric};
use crate::metric::{Metric, MetricEntry}; use crate::metric::{Metric, MetricEntry};
use sysinfo::{CpuExt, System, SystemExt}; use sysinfo::{CpuExt, System, SystemExt};
@ -59,3 +59,9 @@ impl Metric for CpuUse {
fn clear(&mut self) {} fn clear(&mut self) {}
} }
impl Numeric for CpuUse {
fn value(&self) -> f64 {
self.use_percentage as f64
}
}

View File

@ -1,266 +0,0 @@
use super::{DashboardMetricState, DashboardRenderer, TextPlot, TrainingProgress};
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
use std::{collections::HashMap, fmt::Write};
static MAX_REFRESH_RATE_MILLIS: u128 = 250;
/// The CLI dashboard renderer.
pub struct CLIDashboardRenderer {
pb_epoch: ProgressBar,
pb_iteration: ProgressBar,
last_update: std::time::Instant,
progress: TrainingProgress,
metric_train: HashMap<String, String>,
metric_valid: HashMap<String, String>,
metric_both_plot: HashMap<String, TextPlot>,
metric_train_plot: HashMap<String, TextPlot>,
metric_valid_plot: HashMap<String, TextPlot>,
}
impl TrainingProgress {
fn finished(&self) -> bool {
self.epoch == self.epoch_total && self.progress.items_processed == self.progress.items_total
}
}
impl Default for CLIDashboardRenderer {
fn default() -> Self {
CLIDashboardRenderer::new()
}
}
impl Drop for CLIDashboardRenderer {
fn drop(&mut self) {
self.pb_iteration.finish();
self.pb_epoch.finish();
}
}
impl DashboardRenderer for CLIDashboardRenderer {
fn update_train(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(state) => {
self.metric_train.insert(state.name, state.formatted);
}
DashboardMetricState::Numeric(state, value) => {
let name = &state.name;
self.metric_train.insert(name.clone(), state.formatted);
if let Some(mut plot) = self.text_plot_in_both(name) {
plot.update_train(value as f32);
self.metric_both_plot.insert(name.clone(), plot);
return;
}
if let Some(plot) = self.metric_train_plot.get_mut(name) {
plot.update_train(value as f32);
} else {
let mut plot = TextPlot::new();
plot.update_train(value as f32);
self.metric_train_plot.insert(state.name, plot);
}
}
};
}
fn update_valid(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(state) => {
self.metric_valid.insert(state.name, state.formatted);
}
DashboardMetricState::Numeric(state, value) => {
let name = &state.name;
self.metric_valid.insert(name.clone(), state.formatted);
if let Some(mut plot) = self.text_plot_in_both(name) {
plot.update_valid(value as f32);
self.metric_both_plot.insert(name.clone(), plot);
return;
}
if let Some(plot) = self.metric_valid_plot.get_mut(name) {
plot.update_valid(value as f32);
} else {
let mut plot = TextPlot::new();
plot.update_valid(value as f32);
self.metric_valid_plot.insert(state.name, plot);
}
}
};
}
fn render_train(&mut self, item: TrainingProgress) {
self.progress = item;
self.render();
}
fn render_valid(&mut self, item: TrainingProgress) {
self.progress = item;
self.render();
}
}
impl CLIDashboardRenderer {
/// Create a new CLI dashboard renderer.
pub fn new() -> Self {
let pb = MultiProgress::new();
let pb_epoch = ProgressBar::new(0);
let pb_iteration = ProgressBar::new(0);
let pb_iteration = pb.add(pb_iteration);
let pb_epoch = pb.add(pb_epoch);
Self {
pb_epoch,
pb_iteration,
last_update: std::time::Instant::now(),
progress: TrainingProgress::none(),
metric_train: HashMap::new(),
metric_valid: HashMap::new(),
metric_both_plot: HashMap::new(),
metric_train_plot: HashMap::new(),
metric_valid_plot: HashMap::new(),
}
}
fn text_plot_in_both(&mut self, key: &str) -> Option<TextPlot> {
if let Some(plot) = self.metric_both_plot.remove(key) {
return Some(plot);
}
if self.metric_train_plot.contains_key(key) && self.metric_valid_plot.contains_key(key) {
let plot_train = self.metric_train_plot.remove(key).unwrap();
let plot_valid = self.metric_valid_plot.remove(key).unwrap();
return Some(plot_train.merge(plot_valid));
}
None
}
fn register_template_plots(&self, template: String) -> String {
let mut template = template;
let mut metrics_keys = Vec::new();
for (name, metric) in self.metric_both_plot.iter() {
metrics_keys.push(format!(
" - {} RED: train | BLUE: valid \n{}",
name,
metric.render()
));
}
for (name, metric) in self.metric_train_plot.iter() {
metrics_keys.push(format!(" - Train {}: \n{}", name, metric.render()));
}
for (name, metric) in self.metric_valid_plot.iter() {
metrics_keys.push(format!(" - Valid {}: \n{}", name, metric.render()));
}
if !metrics_keys.is_empty() {
let metrics_template = metrics_keys.join("\n");
template += format!("{PLOTS_TAG}\n{metrics_template}\n").as_str();
}
template
}
fn register_template_metrics(&self, template: String) -> String {
let mut template = template;
let mut metrics_keys = Vec::new();
for (name, metric) in self.metric_train.iter() {
metrics_keys.push(format!(" - Train {name}: {metric}"));
}
for (name, metric) in self.metric_valid.iter() {
metrics_keys.push(format!(" - Valid {name}: {metric}"));
}
if !metrics_keys.is_empty() {
let metrics_template = metrics_keys.join("\n");
template += format!("{METRICS_TAG}\n{metrics_template}\n").as_str();
}
template
}
fn register_style_progress(
&self,
name: &'static str,
style: ProgressStyle,
value: String,
) -> ProgressStyle {
self.register_key_item(name, style, name.to_string(), value)
}
fn register_template_progress(&self, progress: &str, template: String) -> String {
let mut template = template;
let bar = "[{wide_bar:.cyan/blue}]";
template += format!(" - {progress} {bar}").as_str();
template
}
fn render(&mut self) {
if !self.progress.finished()
&& std::time::Instant::now()
.duration_since(self.last_update)
.as_millis()
< MAX_REFRESH_RATE_MILLIS
{
return;
}
let template = self.register_template_plots(String::default());
let template = self.register_template_metrics(template);
let template = template
+ format!(
"\n{}\n - Iteration {} Epoch {}/{}\n",
PROGRESS_TAG,
self.progress.iteration,
self.progress.epoch,
self.progress.epoch_total
)
.as_str();
let template = self.register_template_progress("iteration", template);
let style_iteration = ProgressStyle::with_template(&template).unwrap();
let style_iteration = self.register_style_progress(
"iteration",
style_iteration,
format!("{}", self.progress.iteration),
);
let template = self.register_template_progress("epoch ", String::default());
let style_epoch = ProgressStyle::with_template(&template).unwrap();
let style_epoch =
self.register_style_progress("epoch", style_epoch, format!("{}", self.progress.epoch));
self.pb_iteration
.set_style(style_iteration.progress_chars("#>-"));
self.pb_iteration
.set_position(self.progress.progress.items_processed as u64);
self.pb_iteration
.set_length(self.progress.progress.items_total as u64);
self.pb_epoch.set_style(style_epoch.progress_chars("#>-"));
self.pb_epoch.set_position(self.progress.epoch as u64 - 1);
self.pb_epoch.set_length(self.progress.epoch_total as u64);
self.last_update = std::time::Instant::now();
}
/// Registers a new metric to be displayed.
pub fn register_key_item(
&self,
key: &'static str,
style: ProgressStyle,
name: String,
formatted: String,
) -> ProgressStyle {
style.with_key(key, move |_state: &ProgressState, w: &mut dyn Write| {
write!(w, "{name}: {formatted}").unwrap()
})
}
}
static METRICS_TAG: &str = "[Metrics]";
static PLOTS_TAG: &str = "[Plots]";
static PROGRESS_TAG: &str = "[Progress]";

View File

@ -1,16 +1,13 @@
/// Command line interface module for the dashboard.
#[cfg(feature = "ui")]
mod cli;
#[cfg(not(feature = "ui"))]
mod cli_stub;
mod base; mod base;
mod plot;
pub use base::*; pub use base::*;
pub use plot::*;
#[cfg(feature = "ui")] #[cfg(not(feature = "tui"))]
pub use cli::CLIDashboardRenderer; mod cli_stub;
#[cfg(not(feature = "ui"))] #[cfg(not(feature = "tui"))]
pub use cli_stub::CLIDashboardRenderer; pub use cli_stub::CLIDashboardRenderer as SelectedDashboardRenderer;
#[cfg(feature = "tui")]
mod tui;
#[cfg(feature = "tui")]
pub use tui::TuiDashboardRenderer as SelectedDashboardRenderer;

View File

@ -1,160 +0,0 @@
/// Text plot.
pub struct TextPlot {
train: Vec<(f32, f32)>,
valid: Vec<(f32, f32)>,
max_values: usize,
iteration: usize,
}
impl Default for TextPlot {
fn default() -> Self {
Self::new()
}
}
impl TextPlot {
/// Creates a new text plot.
pub fn new() -> Self {
Self {
train: Vec::new(),
valid: Vec::new(),
max_values: 10000,
iteration: 0,
}
}
/// Merges two text plots.
///
/// # Arguments
///
/// * `self` - The first text plot.
/// * `other` - The second text plot.
///
/// # Returns
///
/// The merged text plot.
pub fn merge(self, other: Self) -> Self {
let mut other = other;
let mut train = self.train;
let mut valid = self.valid;
train.append(&mut other.train);
valid.append(&mut other.valid);
Self {
train,
valid,
max_values: usize::min(self.max_values, other.max_values),
iteration: self.iteration + other.iteration,
}
}
/// Updates the text plot with a new item for training.
///
/// # Arguments
///
/// * `item` - The new item.
pub fn update_train(&mut self, item: f32) {
self.iteration += 1;
self.train.push((self.iteration as f32, item));
let x_max = self
.train
.last()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MIN);
let x_min = self
.train
.first()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MAX);
if x_max - x_min > self.max_values as f32 && !self.train.is_empty() {
self.train.remove(0);
}
}
/// Updates the text plot with a new item for validation.
///
/// # Arguments
///
/// * `item` - The new item.
pub fn update_valid(&mut self, item: f32) {
self.iteration += 1;
self.valid.push((self.iteration as f32, item));
let x_max = self
.valid
.last()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MIN);
let x_min = self
.valid
.first()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MAX);
if x_max - x_min > self.max_values as f32 && !self.valid.is_empty() {
self.valid.remove(0);
}
}
/// Renders the text plot.
///
/// # Returns
///
/// The rendered text plot.
#[cfg(feature = "ui")]
pub fn render(&self) -> String {
use rgb::RGB8;
use terminal_size::{terminal_size, Height, Width};
use textplots::{Chart, ColorPlot, Shape};
let train_color = RGB8::new(255, 140, 140);
let valid_color = RGB8::new(140, 140, 255);
let x_max_valid = self
.valid
.last()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MIN);
let x_max_train = self
.train
.last()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MIN);
let x_max = f32::max(x_max_train, x_max_valid);
let x_min_valid = self
.valid
.first()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MAX);
let x_min_train = self
.train
.first()
.map(|(iteration, _)| *iteration)
.unwrap_or(f32::MAX);
let x_min = f32::min(x_min_train, x_min_valid);
let (width, height) = match terminal_size() {
Some((Width(w), Height(_))) => (u32::max(64, w.into()) * 2 - 16, 64),
None => (256, 64),
};
Chart::new(width, height, x_min, x_max)
.linecolorplot(&Shape::Lines(&self.train), train_color)
.linecolorplot(&Shape::Lines(&self.valid), valid_color)
.to_string()
}
/// Renders the text plot.
///
/// # Returns
///
/// The rendered text plot.
#[cfg(not(feature = "ui"))]
pub fn render(&self) -> String {
panic!("ui feature not enabled on burn-train")
}
}

View File

@ -0,0 +1,45 @@
use super::{
ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView,
};
use ratatui::prelude::{Constraint, Direction, Layout, Rect};
#[derive(new)]
pub(crate) struct DashboardView<'a> {
metric_numeric: NumericMetricView<'a>,
metric_text: TextMetricView,
progress: ProgressBarView,
controls: ControlsView,
status: StatusView,
}
impl<'a> DashboardView<'a> {
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([Constraint::Min(16), Constraint::Max(3)].as_ref())
.split(size);
let size_other = chunks[0];
let size_progress = chunks[1];
let chunks = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref())
.split(size_other);
let size_other = chunks[0];
let size_metric_numeric = chunks[1];
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref())
.split(size_other);
let size_controls = chunks[0];
let size_metric_text = chunks[1];
let size_status = chunks[2];
self.metric_numeric.render(frame, size_metric_numeric);
self.metric_text.render(frame, size_metric_text);
self.controls.render(frame, size_controls);
self.progress.render(frame, size_progress);
self.status.render(frame, size_status);
}
}

View File

@ -0,0 +1,46 @@
use super::TerminalFrame;
use ratatui::{
prelude::{Alignment, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Paragraph, Wrap},
};
/// Controls view.
pub(crate) struct ControlsView;
impl ControlsView {
/// Render the view.
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let lines = vec![
vec![
Span::from(" Quit : ").yellow().bold(),
Span::from("q ").bold(),
Span::from(" Stop the training.").italic(),
],
vec![
Span::from(" Plots Metrics : ").yellow().bold(),
Span::from("⬅ ➡").bold(),
Span::from(" Switch between metrics.").italic(),
],
vec![
Span::from(" Plots Type : ").yellow().bold(),
Span::from("⬆ ⬇").bold(),
Span::from(" Switch between types.").italic(),
],
];
let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::<Vec<_>>())
.alignment(Alignment::Left)
.wrap(Wrap { trim: false })
.style(Style::default().fg(Color::Gray))
.block(
Block::default()
.borders(Borders::ALL)
.style(Style::default().fg(Color::Gray))
.title_alignment(Alignment::Left)
.title("Controls"),
);
frame.render_widget(paragraph, size);
}
}

View File

@ -0,0 +1,216 @@
use super::PlotAxes;
use ratatui::{
style::{Color, Style, Stylize},
symbols,
widgets::{Dataset, GraphType},
};
/// A plot that shows the full history at a reduced resolution.
pub(crate) struct FullHistoryPlot {
pub(crate) axes: PlotAxes,
train: FullHistoryPoints,
valid: FullHistoryPoints,
next_x_state: usize,
}
struct FullHistoryPoints {
min_x: f64,
max_x: f64,
min_y: f64,
max_y: f64,
points: Vec<(f64, f64)>,
max_samples: usize,
step_size: usize,
}
impl FullHistoryPlot {
/// Create a new history plot.
pub(crate) fn new(max_samples: usize) -> Self {
Self {
axes: PlotAxes::default(),
train: FullHistoryPoints::new(max_samples),
valid: FullHistoryPoints::new(max_samples),
next_x_state: 0,
}
}
/// Update the maximum amount of sample to display for the validation points.
///
/// This is necessary if we want the validation line to have the same point density as the
/// training line.
pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) {
if self.valid.step_size == 1 {
self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize;
}
}
/// Register a training data point.
pub(crate) fn push_train(&mut self, data: f64) {
let x_current = self.next_x();
self.train.push((x_current, data));
self.update_bounds();
}
/// Register a validation data point.
pub(crate) fn push_valid(&mut self, data: f64) {
let x_current = self.next_x();
self.valid.push((x_current, data));
self.update_bounds();
}
/// Create the training and validation datasets from the data points.
pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {
let mut datasets = Vec::with_capacity(2);
if !self.train.is_empty() {
datasets.push(self.train.dataset("Train", Color::LightRed));
}
if !self.valid.is_empty() {
datasets.push(self.valid.dataset("Valid", Color::LightBlue));
}
datasets
}
fn next_x(&mut self) -> f64 {
let value = self.next_x_state;
self.next_x_state += 1;
value as f64
}
fn update_bounds(&mut self) {
self.axes.update_bounds(
(self.train.min_x, self.train.max_x),
(self.valid.min_x, self.valid.max_x),
(self.train.min_y, self.train.max_y),
(self.valid.min_y, self.valid.max_y),
);
}
}
impl FullHistoryPoints {
fn new(max_samples: usize) -> Self {
Self {
min_x: 0.,
max_x: 0.,
min_y: f64::MAX,
max_y: f64::MIN,
points: Vec::with_capacity(max_samples),
max_samples,
step_size: 1,
}
}
fn push(&mut self, (x, y): (f64, f64)) {
if x as usize % self.step_size != 0 {
return;
}
if x > self.max_x {
self.max_x = x;
}
if x < self.min_x {
self.min_x = x;
}
if y > self.max_y {
self.max_y = y;
}
if y < self.min_y {
self.min_y = y
}
self.points.push((x, y));
if self.points.len() > self.max_samples {
self.resize();
}
}
/// We keep only half the points and we double the step size.
///
/// This ensure that we have the same amount of points across the X axis.
fn resize(&mut self) {
let mut points = Vec::with_capacity(self.max_samples / 2);
let mut max_x = f64::MIN;
let mut max_y = f64::MIN;
let mut min_x = f64::MAX;
let mut min_y = f64::MAX;
for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() {
if i % 2 == 0 {
if x > max_x {
max_x = x;
}
if x < min_x {
min_x = x;
}
if y > max_y {
max_y = y;
}
if y < min_y {
min_y = y;
}
points.push((x, y));
}
}
self.points = points;
self.step_size *= 2;
self.min_x = min_x;
self.max_x = max_x;
self.min_y = min_y;
self.max_y = max_y;
}
fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> {
Dataset::default()
.name(name)
.marker(symbols::Marker::Braille)
.style(Style::default().fg(color).bold())
.graph_type(GraphType::Line)
.data(&self.points)
}
fn is_empty(&self) -> bool {
self.points.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_points() {
let mut chart = FullHistoryPlot::new(10);
chart.update_max_sample_valid(0.6);
for i in 0..100 {
chart.push_train(i as f64);
}
for i in 0..60 {
chart.push_valid(i as f64);
}
let expected_train = vec![
(0.0, 0.0),
(16.0, 16.0),
(32.0, 32.0),
(48.0, 48.0),
(64.0, 64.0),
(80.0, 80.0),
(96.0, 96.0),
];
let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)];
assert_eq!(chart.train.points, expected_train);
assert_eq!(chart.valid.points, expected_valid);
}
}

View File

@ -0,0 +1,232 @@
use crate::metric::dashboard::TrainingProgress;
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
use crossterm::event::{Event, KeyCode};
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Modifier, Style, Stylize},
text::Line,
widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs},
};
use std::collections::HashMap;
/// 1000 seems to be required to see some improvement.
const MAX_NUM_SAMPLES_RECENT: usize = 1000;
/// 250 seems to be the right resolution when plotting all history.
/// Otherwise, there is too much points and the lines arent't smooth enough.
const MAX_NUM_SAMPLES_FULL: usize = 250;
/// Numeric metrics state that handles creating plots.
#[derive(Default)]
pub(crate) struct NumericMetricsState {
data: HashMap<String, (RecentHistoryPlot, FullHistoryPlot)>,
names: Vec<String>,
selected: usize,
kind: PlotKind,
num_samples_train: Option<usize>,
num_samples_valid: Option<usize>,
}
/// The kind of plot to display.
#[derive(Default, Clone, Copy)]
pub(crate) enum PlotKind {
/// Display the full history of the metric with reduced resolution.
#[default]
Full,
/// Display only the recent history of the metric, but with more resolution.
Recent,
}
impl NumericMetricsState {
/// Register a new training value for the metric with the given name.
pub(crate) fn push_train(&mut self, name: String, data: f64) {
if let Some((recent, full)) = self.data.get_mut(&name) {
recent.push_train(data);
full.push_train(data);
} else {
let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);
let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);
recent.push_train(data);
full.push_train(data);
self.names.push(name.clone());
self.data.insert(name, (recent, full));
}
}
/// Register a new validation value for the metric with the given name.
pub(crate) fn push_valid(&mut self, key: String, data: f64) {
if let Some((recent, full)) = self.data.get_mut(&key) {
recent.push_valid(data);
full.push_valid(data);
} else {
let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);
let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);
recent.push_valid(data);
full.push_valid(data);
self.data.insert(key, (recent, full));
}
}
/// Update the state with the training progress.
pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {
if self.num_samples_train.is_some() {
return;
}
self.num_samples_train = Some(progress.progress.items_total);
}
/// Update the state with the validation progress.
pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) {
if self.num_samples_valid.is_some() {
return;
}
if let Some(num_sample_train) = self.num_samples_train {
for (_, (_recent, full)) in self.data.iter_mut() {
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
full.update_max_sample_valid(ratio);
}
}
self.num_samples_valid = Some(progress.progress.items_total);
}
/// Create a view to display the numeric metrics.
pub(crate) fn view(&self) -> NumericMetricView<'_> {
match self.names.is_empty() {
true => NumericMetricView::None,
false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind),
}
}
/// Handle the current event.
pub(crate) fn on_event(&mut self, event: &Event) {
if let Event::Key(key) = event {
match key.code {
KeyCode::Right => self.next_metric(),
KeyCode::Left => self.previous_metric(),
KeyCode::Up => self.switch_kind(),
KeyCode::Down => self.switch_kind(),
_ => {}
}
}
}
fn switch_kind(&mut self) {
self.kind = match self.kind {
PlotKind::Full => PlotKind::Recent,
PlotKind::Recent => PlotKind::Full,
};
}
fn next_metric(&mut self) {
self.selected = (self.selected + 1) % {
let this = &self;
this.data.len()
};
}
fn previous_metric(&mut self) {
if self.selected > 0 {
self.selected -= 1;
} else {
self.selected = ({
let this = &self;
this.data.len()
}) - 1;
}
}
fn chart<'a>(&'a self) -> Chart<'a> {
let name = self.names.get(self.selected).unwrap();
let (recent, full) = self.data.get(name).unwrap();
let (datasets, axes) = match self.kind {
PlotKind::Full => (full.datasets(), &full.axes),
PlotKind::Recent => (recent.datasets(), &recent.axes),
};
Chart::<'a>::new(datasets)
.block(Block::default())
.x_axis(
Axis::default()
.style(Style::default().fg(Color::DarkGray))
.title("Iteration")
.labels(axes.labels_x.iter().map(|s| s.bold()).collect())
.bounds(axes.bounds_x),
)
.y_axis(
Axis::default()
.style(Style::default().fg(Color::DarkGray))
.labels(axes.labels_y.iter().map(|s| s.bold()).collect())
.bounds(axes.bounds_y),
)
}
}
#[derive(new)]
pub(crate) enum NumericMetricView<'a> {
Plots(&'a [String], usize, Chart<'a>, PlotKind),
None,
}
impl<'a> NumericMetricView<'a> {
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
match self {
Self::Plots(titles, selected, chart, kind) => {
let block = Block::default()
.borders(Borders::ALL)
.title("Plots")
.title_alignment(Alignment::Left);
let size_new = block.inner(size);
frame.render_widget(block, size);
let size = size_new;
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints(
[
Constraint::Length(2),
Constraint::Length(1),
Constraint::Min(0),
]
.as_ref(),
)
.split(size);
let titles = titles
.iter()
.map(|i| Line::from(vec![i.yellow()]))
.collect();
let tabs = Tabs::new(titles)
.select(selected)
.style(Style::default())
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.add_modifier(Modifier::UNDERLINED)
.fg(Color::LightYellow),
);
let title = match kind {
PlotKind::Full => "Full History",
PlotKind::Recent => "Recent History",
};
let plot_type =
Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);
frame.render_widget(tabs, chunks[0]);
frame.render_widget(plot_type, chunks[1]);
frame.render_widget(chart, chunks[2]);
}
Self::None => {}
};
}
}

View File

@ -0,0 +1,101 @@
use super::TerminalFrame;
use crate::metric::MetricEntry;
use ratatui::{
prelude::{Alignment, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Paragraph, Wrap},
};
use std::collections::HashMap;
#[derive(Default)]
pub(crate) struct TextMetricsState {
data: HashMap<String, MetricData>,
names: Vec<String>,
}
#[derive(new)]
pub(crate) struct MetricData {
train: Option<MetricEntry>,
valid: Option<MetricEntry>,
}
impl TextMetricsState {
pub(crate) fn update_train(&mut self, metric: MetricEntry) {
if let Some(existing) = self.data.get_mut(&metric.name) {
existing.train = Some(metric);
} else {
let key = metric.name.clone();
let value = MetricData::new(Some(metric), None);
self.names.push(key.clone());
self.data.insert(key, value);
}
}
pub(crate) fn update_valid(&mut self, metric: MetricEntry) {
if let Some(existing) = self.data.get_mut(&metric.name) {
existing.valid = Some(metric);
} else {
let key = metric.name.clone();
let value = MetricData::new(None, Some(metric));
self.names.push(key.clone());
self.data.insert(key, value);
}
}
pub(crate) fn view(&self) -> TextMetricView {
TextMetricView::new(&self.names, &self.data)
}
}
pub(crate) struct TextMetricView {
lines: Vec<Vec<Span<'static>>>,
}
impl TextMetricView {
fn new(names: &[String], data: &HashMap<String, MetricData>) -> Self {
let mut lines = Vec::with_capacity(names.len() * 4);
let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()];
let train_line = |formatted: &str| {
vec![
Span::from(" Train ").bold(),
Span::from(formatted.to_string()).italic(),
]
};
let valid_line = |formatted: &str| {
vec![
Span::from(" Valid ").bold(),
Span::from(formatted.to_string()).italic(),
]
};
for name in names {
lines.push(start_line(name));
let entry = data.get(name).unwrap();
if let Some(entry) = &entry.train {
lines.push(train_line(&entry.formatted));
}
if let Some(entry) = &entry.valid {
lines.push(valid_line(&entry.formatted));
}
lines.push(vec![Span::from("")]);
}
Self { lines }
}
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())
.alignment(Alignment::Left)
.wrap(Wrap { trim: false })
.block(Block::default().borders(Borders::ALL).title("Metrics"))
.style(Style::default().fg(Color::Gray));
frame.render_widget(paragraph, size);
}
}

View File

@ -0,0 +1,23 @@
mod base;
mod controls;
mod full_history;
mod metric_numeric;
mod metric_text;
mod plot_utils;
mod popup;
mod progress;
mod recent_history;
mod renderer;
mod status;
pub(crate) use base::*;
pub(crate) use controls::*;
pub(crate) use full_history::*;
pub(crate) use metric_numeric::*;
pub(crate) use metric_text::*;
pub(crate) use plot_utils::*;
pub(crate) use popup::*;
pub(crate) use progress::*;
pub(crate) use recent_history::*;
pub use renderer::*;
pub(crate) use status::*;

View File

@ -0,0 +1,48 @@
use crate::metric::format_float;
const AXIS_TITLE_PRECISION: usize = 2;
/// The data describing both X and Y axes.
pub(crate) struct PlotAxes {
pub(crate) labels_x: Vec<String>,
pub(crate) labels_y: Vec<String>,
pub(crate) bounds_x: [f64; 2],
pub(crate) bounds_y: [f64; 2],
}
impl Default for PlotAxes {
fn default() -> Self {
Self {
bounds_x: [f64::MAX, f64::MIN],
bounds_y: [f64::MAX, f64::MIN],
labels_x: Vec::new(),
labels_y: Vec::new(),
}
}
}
impl PlotAxes {
/// Update the bounds based on the min max of each X and Y axes with both train and valid data.
pub(crate) fn update_bounds(
&mut self,
(x_train_min, x_train_max): (f64, f64),
(x_valid_min, x_valid_max): (f64, f64),
(y_train_min, y_train_max): (f64, f64),
(y_valid_min, y_valid_max): (f64, f64),
) {
let x_min = f64::min(x_train_min, x_valid_min);
let x_max = f64::max(x_train_max, x_valid_max);
let y_min = f64::min(y_train_min, y_valid_min);
let y_max = f64::max(y_train_max, y_valid_max);
self.bounds_x = [x_min, x_max];
self.bounds_y = [y_min, y_max];
// We know x are integers.
self.labels_x = vec![format!("{x_min}"), format!("{x_max}")];
self.labels_y = vec![
format_float(y_min, AXIS_TITLE_PRECISION),
format_float(y_max, AXIS_TITLE_PRECISION),
];
}
}

View File

@ -0,0 +1,144 @@
use crossterm::event::{Event, KeyCode};
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Modifier, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Paragraph, Wrap},
};
use super::TerminalFrame;
/// Popup callback function.
pub(crate) trait CallbackFn: Send + Sync {
/// Call the function and return if the popup state should be reset.
fn call(&self) -> bool;
}
/// Popup callback.
pub(crate) struct Callback {
title: String,
description: String,
trigger: char,
callback: Box<dyn CallbackFn>,
}
impl Callback {
/// Create a new popup.
pub(crate) fn new<T, D, C>(title: T, description: D, trigger: char, callback: C) -> Self
where
T: Into<String>,
D: Into<String>,
C: CallbackFn + 'static,
{
Self {
title: title.into(),
description: description.into(),
trigger,
callback: Box::new(callback),
}
}
}
/// Popup state.
pub(crate) enum PopupState {
Empty,
Full(String, Vec<Callback>),
}
impl PopupState {
/// If the popup is empty.
pub(crate) fn is_empty(&self) -> bool {
matches!(&self, PopupState::Empty)
}
/// Handle popup events.
pub(crate) fn on_event(&mut self, event: &Event) {
let mut reset = false;
match self {
PopupState::Empty => {}
PopupState::Full(_, callbacks) => {
for callback in callbacks.iter() {
if let Event::Key(key) = event {
if let KeyCode::Char(key) = &key.code {
if &callback.trigger == key && callback.callback.call() {
reset = true;
}
}
}
}
}
};
if reset {
*self = Self::Empty;
}
}
/// Create the popup view.
pub(crate) fn view(&self) -> Option<PopupView<'_>> {
match self {
PopupState::Empty => None,
PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)),
}
}
}
#[derive(new)]
pub(crate) struct PopupView<'a> {
title: &'a String,
callbacks: &'a [Callback],
}
impl<'a> PopupView<'a> {
/// Render the view.
pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) {
let lines = self
.callbacks
.iter()
.flat_map(|callback| {
vec![
Line::from(vec![
Span::from(format!("[{}] ", callback.trigger)).bold(),
Span::from(format!("{} ", callback.title)).yellow().bold(),
]),
Line::from(Span::from("")),
Line::from(Span::from(callback.description.to_string()).italic()),
Line::from(Span::from("")),
]
})
.collect::<Vec<_>>();
let paragraph = Paragraph::new(lines)
.alignment(Alignment::Left)
.wrap(Wrap { trim: false })
.style(Style::default().fg(Color::Gray))
.block(
Block::default()
.borders(Borders::ALL)
.title_alignment(Alignment::Center)
.style(Style::default().fg(Color::Gray))
.title(Span::styled(
self.title,
Style::default().add_modifier(Modifier::BOLD),
)),
);
let area = centered_percent(20, size, Direction::Horizontal);
let area = centered_percent(20, area, Direction::Vertical);
frame.render_widget(paragraph, area);
}
}
/// The percent represents the amount of space that will be taken by each side.
fn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect {
let center = 100 - (percent * 2);
Layout::default()
.direction(direction)
.constraints([
Constraint::Percentage(percent),
Constraint::Percentage(center),
Constraint::Percentage(percent),
])
.split(size)[1]
}

View File

@ -0,0 +1,144 @@
use super::TerminalFrame;
use crate::metric::dashboard::TrainingProgress;
use ratatui::{
prelude::{Alignment, Constraint, Direction, Layout, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Gauge, Paragraph},
};
use std::time::Instant;
/// Simple progress bar for the training.
///
/// We currently ignore the time taken for the validation part.
pub(crate) struct ProgressBarState {
progress_train: f64,
started: Instant,
}
impl Default for ProgressBarState {
fn default() -> Self {
Self {
progress_train: 0.0,
started: Instant::now(),
}
}
}
const MINUTE: u64 = 60;
const HOUR: u64 = 60 * 60;
const DAY: u64 = 24 * 60 * 60;
impl ProgressBarState {
/// Update the training progress.
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
let total_items = progress.progress.items_total * progress.epoch_total;
let epoch_items = (progress.epoch - 1) * progress.progress.items_total;
let iteration_items = progress.progress.items_processed as f64;
self.progress_train = (epoch_items as f64 + iteration_items) / total_items as f64
}
/// Update the validation progress.
pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) {
// We don't use the validation for the progress yet.
}
/// Create a view for the current progress.
pub(crate) fn view(&self) -> ProgressBarView {
let eta = self.started.elapsed();
let total_estimated = (eta.as_secs() as f64) / self.progress_train;
let eta = if total_estimated.is_normal() {
let remaining = 1.0 - self.progress_train;
let eta = (total_estimated * remaining) as u64;
format_eta(eta)
} else {
"---".to_string()
};
ProgressBarView::new(self.progress_train, eta)
}
}
#[derive(new)]
pub(crate) struct ProgressBarView {
progress: f64,
eta: String,
}
impl ProgressBarView {
/// Render the view.
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let block = Block::default()
.borders(Borders::ALL)
.title("Progress")
.title_alignment(Alignment::Left);
let size_new = block.inner(size);
frame.render_widget(block, size);
let size = size_new;
let chunks = Layout::default()
.direction(Direction::Horizontal)
.constraints(
[
Constraint::Length(1), // Empty space
Constraint::Min(0),
Constraint::Length(self.eta.len() as u16 + 4),
]
.as_ref(),
)
.split(size);
let size_gauge = chunks[1];
let size_eta = chunks[2];
let iteration = Gauge::default()
.gauge_style(Style::default().fg(Color::Yellow))
.ratio(self.progress);
let eta = Paragraph::new(Line::from(vec![
Span::from(" ("),
Span::from(self.eta).italic(),
Span::from(") "),
]));
frame.render_widget(iteration, size_gauge);
frame.render_widget(eta, size_eta);
}
}
fn format_eta(eta_secs: u64) -> String {
let seconds = eta_secs % 60;
let minutes = eta_secs / MINUTE % 60;
let hours = eta_secs / HOUR % 24;
let days = eta_secs / DAY;
if days > 0 {
return format!("{days} days");
}
if hours > 0 {
return format!("{hours} hours");
}
if minutes > 0 {
return format!("{minutes} mins");
}
format!("{seconds} secs")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_format_eta() {
assert_eq!("55 secs", format_eta(55), "Less than 1 minutes");
assert_eq!("1 mins", format_eta(61), "More than 1 minutes");
assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes");
assert_eq!("1 hours", format_eta(3601), "More than 1 hour");
assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour");
assert_eq!("1 days", format_eta(24 * 3601), "More than 1 day");
assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day");
}
}

View File

@ -0,0 +1,244 @@
use super::PlotAxes;
use ratatui::{
style::{Color, Style, Stylize},
symbols,
widgets::{Dataset, GraphType},
};
const FACTOR_BEFORE_RESIZE: usize = 2;
/// A plot that shows the recent history at full resolution.
pub(crate) struct RecentHistoryPlot {
pub(crate) axes: PlotAxes,
train: RecentHistoryPoints,
valid: RecentHistoryPoints,
max_samples: usize,
}
struct RecentHistoryPoints {
min_x: f64,
max_x: f64,
min_y: f64,
max_y: f64,
cursor: usize,
points: Vec<(f64, f64)>,
max_samples: usize,
factor_before_resize: usize,
}
impl RecentHistoryPlot {
pub(crate) fn new(max_samples: usize) -> Self {
Self {
axes: PlotAxes::default(),
train: RecentHistoryPoints::new(max_samples),
valid: RecentHistoryPoints::new(max_samples),
max_samples,
}
}
pub(crate) fn push_train(&mut self, data: f64) {
let (x_min, x_current) = self.x();
self.train.push((x_current, data));
self.train.update_cursor(x_min);
self.valid.update_cursor(x_min);
self.update_bounds();
}
pub(crate) fn push_valid(&mut self, data: f64) {
let (x_min, x_current) = self.x();
self.valid.push((x_current, data));
self.valid.update_cursor(x_min);
self.train.update_cursor(x_min);
self.update_bounds();
}
pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {
let mut datasets = Vec::with_capacity(2);
if self.train.num_visible_points() > 0 {
datasets.push(self.train.dataset("Train", Color::LightRed));
}
if self.valid.num_visible_points() > 0 {
datasets.push(self.valid.dataset("Valid", Color::LightBlue));
}
datasets
}
fn x(&mut self) -> (f64, f64) {
let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0;
let mut x_min = f64::min(self.train.min_x, self.valid.min_x);
if x_current - x_min >= self.max_samples as f64 {
x_min += 1.0;
}
(x_min, x_current)
}
fn update_bounds(&mut self) {
self.axes.update_bounds(
(self.train.min_x, self.train.max_x),
(self.valid.min_x, self.valid.max_x),
(self.train.min_y, self.train.max_y),
(self.valid.min_y, self.valid.max_y),
);
}
}
impl RecentHistoryPoints {
fn new(max_samples: usize) -> Self {
let factor_before_resize = FACTOR_BEFORE_RESIZE;
Self {
min_x: 0.,
max_x: 0.,
min_y: f64::MAX,
max_y: f64::MIN,
points: Vec::with_capacity(factor_before_resize * max_samples),
cursor: 0,
max_samples,
factor_before_resize,
}
}
fn num_visible_points(&self) -> usize {
self.points.len()
}
fn push(&mut self, (x, y): (f64, f64)) {
if x > self.max_x {
self.max_x = x;
}
if x < self.min_x {
self.min_x = x;
}
if y > self.max_y {
self.max_y = y;
}
if y < self.min_y {
self.min_y = y
}
self.points.push((x, y));
}
fn update_cursor(&mut self, min_x: f64) {
if self.min_x >= min_x {
return;
}
self.min_x = min_x;
let mut update_y_max = false;
let mut update_y_min = false;
while let Some((x, y)) = self.points.get(self.cursor) {
if *x >= self.min_x {
break;
}
if *y == self.max_y {
update_y_max = true
}
if *y == self.min_y {
update_y_min = true;
}
self.cursor += 1;
}
if update_y_max {
self.max_y = self.calculate_max_y();
}
if update_y_min {
self.min_y = self.calculate_min_y();
}
if self.points.len() >= self.max_samples * self.factor_before_resize {
self.resize();
}
}
fn slice(&self) -> &[(f64, f64)] {
&self.points[self.cursor..self.points.len()]
}
fn calculate_max_y(&self) -> f64 {
let mut max_y = f64::MIN;
for (_x, y) in self.slice() {
if *y > max_y {
max_y = *y;
}
}
max_y
}
fn calculate_min_y(&self) -> f64 {
let mut min_y = f64::MAX;
for (_x, y) in self.slice() {
if *y < min_y {
min_y = *y;
}
}
min_y
}
fn resize(&mut self) {
let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize);
for i in self.cursor..self.points.len() {
points.push(self.points[i]);
}
self.points = points;
self.cursor = 0;
}
fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> {
let data = &self.points[self.cursor..self.points.len()];
Dataset::default()
.name(name)
.marker(symbols::Marker::Braille)
.style(Style::default().fg(color).bold())
.graph_type(GraphType::Scatter)
.data(data)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_push_update_bounds_max_y() {
let mut chart = RecentHistoryPlot::new(3);
chart.push_train(15.0);
chart.push_train(10.0);
chart.push_train(14.0);
assert_eq!(chart.axes.bounds_y[1], 15.);
chart.push_train(10.0);
assert_eq!(chart.axes.bounds_y[1], 14.);
}
#[test]
fn test_push_update_bounds_min_y() {
let mut chart = RecentHistoryPlot::new(3);
chart.push_train(5.0);
chart.push_train(10.0);
chart.push_train(14.0);
assert_eq!(chart.axes.bounds_y[0], 5.);
chart.push_train(10.0);
assert_eq!(chart.axes.bounds_y[0], 10.);
}
}

View File

@ -0,0 +1,203 @@
use crate::metric::dashboard::tui::NumericMetricsState;
use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
use crate::TrainingInterrupter;
use crossterm::{
event::{self, Event, KeyCode},
execute,
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};
use ratatui::{prelude::*, Terminal};
use std::{
error::Error,
io::{self, Stdout},
time::{Duration, Instant},
};
use super::{
Callback, CallbackFn, ControlsView, DashboardView, PopupState, ProgressBarState, StatusState,
TextMetricsState,
};
/// The current terminal backend.
pub(crate) type TerminalBackend = CrosstermBackend<Stdout>;
/// The current terminal frame.
pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a, TerminalBackend>;
const MAX_REFRESH_RATE_MILLIS: u64 = 100;
/// The CLI dashboard renderer.
pub struct TuiDashboardRenderer {
terminal: Terminal<TerminalBackend>,
last_update: std::time::Instant,
progress: ProgressBarState,
metrics_numeric: NumericMetricsState,
metrics_text: TextMetricsState,
status: StatusState,
interuptor: TrainingInterrupter,
popup: PopupState,
}
impl DashboardRenderer for TuiDashboardRenderer {
fn update_train(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(entry) => {
self.metrics_text.update_train(entry);
}
DashboardMetricState::Numeric(entry, value) => {
self.metrics_numeric.push_train(entry.name.clone(), value);
self.metrics_text.update_train(entry);
}
};
}
fn update_valid(&mut self, state: DashboardMetricState) {
match state {
DashboardMetricState::Generic(entry) => {
self.metrics_text.update_valid(entry);
}
DashboardMetricState::Numeric(entry, value) => {
self.metrics_numeric.push_valid(entry.name.clone(), value);
self.metrics_text.update_valid(entry);
}
};
}
fn render_train(&mut self, item: TrainingProgress) {
self.progress.update_train(&item);
self.metrics_numeric.update_progress_train(&item);
self.status.update_train(item);
self.render().unwrap();
}
fn render_valid(&mut self, item: TrainingProgress) {
self.progress.update_valid(&item);
self.metrics_numeric.update_progress_valid(&item);
self.status.update_valid(item);
self.render().unwrap();
}
}
impl TuiDashboardRenderer {
/// Create a new CLI dashboard renderer.
pub fn new(interuptor: TrainingInterrupter) -> Self {
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen).unwrap();
enable_raw_mode().unwrap();
let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap();
Self {
terminal,
last_update: Instant::now(),
progress: ProgressBarState::default(),
metrics_numeric: NumericMetricsState::default(),
metrics_text: TextMetricsState::default(),
status: StatusState::default(),
interuptor,
popup: PopupState::Empty,
}
}
fn render(&mut self) -> Result<(), Box<dyn Error>> {
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
if self.last_update.elapsed() < tick_rate {
return Ok(());
}
self.draw()?;
self.handle_events()?;
self.last_update = Instant::now();
Ok(())
}
fn draw(&mut self) -> Result<(), Box<dyn Error>> {
self.terminal.draw(|frame| {
let size = frame.size();
match self.popup.view() {
Some(view) => view.render(frame, size),
None => {
let view = DashboardView::new(
self.metrics_numeric.view(),
self.metrics_text.view(),
self.progress.view(),
ControlsView,
self.status.view(),
);
view.render(frame, size);
}
};
})?;
Ok(())
}
fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
while event::poll(Duration::from_secs(0))? {
let event = event::read()?;
self.popup.on_event(&event);
if self.popup.is_empty() {
self.metrics_numeric.on_event(&event);
if let Event::Key(key) = event {
if let KeyCode::Char('q') = key.code {
self.popup = PopupState::Full(
"Quit".to_string(),
vec![
Callback::new(
"Stop the training.",
"Stop the training immediately. This will break from the training loop, but any remaining code after the loop will be executed.",
's',
QuitPopupAccept(self.interuptor.clone()),
),
Callback::new(
"Stop the training immediately.",
"Kill the program. This will create a panic! which will make the current training fails. Any code following the training won't be executed.",
'k',
KillPopupAccept,
),
Callback::new("Cancel", "Cancel the action, continue the training.", 'c', PopupCancel),
],
);
}
}
}
}
Ok(())
}
}
struct QuitPopupAccept(TrainingInterrupter);
struct KillPopupAccept;
struct PopupCancel;
impl CallbackFn for KillPopupAccept {
fn call(&self) -> bool {
panic!("Killing training from user input.");
}
}
impl CallbackFn for QuitPopupAccept {
fn call(&self) -> bool {
self.0.stop();
true
}
}
impl CallbackFn for PopupCancel {
fn call(&self) -> bool {
true
}
}
impl Drop for TuiDashboardRenderer {
fn drop(&mut self) {
disable_raw_mode().ok();
execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
self.terminal.show_cursor().ok();
}
}

View File

@ -0,0 +1,91 @@
use super::TerminalFrame;
use crate::metric::dashboard::TrainingProgress;
use ratatui::{
prelude::{Alignment, Rect},
style::{Color, Style, Stylize},
text::{Line, Span},
widgets::{Block, Borders, Paragraph, Wrap},
};
/// Show the training status with various information.
pub(crate) struct StatusState {
progress: TrainingProgress,
mode: Mode,
}
enum Mode {
Valid,
Train,
}
impl Default for StatusState {
fn default() -> Self {
Self {
progress: TrainingProgress::none(),
mode: Mode::Train,
}
}
}
impl StatusState {
/// Update the training information.
pub(crate) fn update_train(&mut self, progress: TrainingProgress) {
self.progress = progress;
self.mode = Mode::Train;
}
/// Update the validation information.
pub(crate) fn update_valid(&mut self, progress: TrainingProgress) {
self.progress = progress;
self.mode = Mode::Valid;
}
/// Create a view.
pub(crate) fn view(&self) -> StatusView {
StatusView::new(&self.progress, &self.mode)
}
}
pub(crate) struct StatusView {
lines: Vec<Vec<Span<'static>>>,
}
impl StatusView {
fn new(progress: &TrainingProgress, mode: &Mode) -> Self {
let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow();
let value = |value: String| Span::from(value).italic();
let mode = match mode {
Mode::Valid => "Validating",
Mode::Train => "Training",
};
Self {
lines: vec![
vec![title("Mode :"), value(mode.to_string())],
vec![
title("Epoch :"),
value(format!("{}/{}", progress.epoch, progress.epoch_total)),
],
vec![
title("Iteration :"),
value(format!("{}", progress.iteration)),
],
vec![
title("Items :"),
value(format!(
"{}/{}",
progress.progress.items_processed, progress.progress.items_total
)),
],
],
}
}
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())
.alignment(Alignment::Left)
.block(Block::default().borders(Borders::ALL).title("Status"))
.wrap(Wrap { trim: false })
.style(Style::default().fg(Color::Gray));
frame.render_widget(paragraph, size);
}
}

View File

@ -1,4 +1,4 @@
use super::{MetricEntry, Numeric}; use super::{format_float, MetricEntry, Numeric};
/// Usefull utility to implement numeric metrics. /// Usefull utility to implement numeric metrics.
/// ///
@ -70,20 +70,10 @@ impl NumericMetricState {
let serialized = value_current.to_string(); let serialized = value_current.to_string();
let (formatted_current, formatted_running) = match format.precision { let (formatted_current, formatted_running) = match format.precision {
Some(precision) => { Some(precision) => (
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); format_float(value_current, precision),
format_float(value_running, precision),
( ),
match scientific_notation_threshold >= value_current {
true => format!("{value_current:.precision$e}"),
false => format!("{value_current:.precision$}"),
},
match scientific_notation_threshold >= value_running {
true => format!("{value_running:.precision$e}"),
false => format!("{value_running:.precision$}"),
},
)
}
None => (format!("{value_current}"), format!("{value_running}")), None => (format!("{value_current}"), format!("{value_running}")),
}; };

View File

@ -22,7 +22,7 @@ train = ["burn-train/default", "autodiff", "dataset"]
train-minimal = ["burn-train"] train-minimal = ["burn-train"]
## Includes the Text UI (progress bars, metric plots) ## Includes the Text UI (progress bars, metric plots)
train-ui = ["burn-train/ui"] train-tui = ["burn-train/tui"]
## Includes system info metrics (CPU/GPU usage, etc) ## Includes system info metrics (CPU/GPU usage, etc)
train-metrics = ["burn-train/metrics"] train-metrics = ["burn-train/metrics"]