mirror of https://github.com/tracel-ai/burn.git
Feat/dashboard tui (#790)
This commit is contained in:
parent
4f72578260
commit
57d6a566be
|
@ -1,5 +1,6 @@
|
|||
[default]
|
||||
extend-ignore-identifiers-re = [
|
||||
"ratatui",
|
||||
"NdArray*",
|
||||
"ND"
|
||||
]
|
||||
|
|
|
@ -62,11 +62,16 @@ where
|
|||
let mut iterator = dataloader_cloned.iter();
|
||||
while let Some(item) = iterator.next() {
|
||||
let progress = iterator.progress();
|
||||
sender_cloned
|
||||
.send(Message::Batch(index, item, progress))
|
||||
.unwrap();
|
||||
|
||||
match sender_cloned.send(Message::Batch(index, item, progress)) {
|
||||
Ok(_) => {}
|
||||
// The receiver is probably gone, no need to panic, just need to stop
|
||||
// iterating.
|
||||
Err(_) => return,
|
||||
};
|
||||
}
|
||||
sender_cloned.send(Message::Done).unwrap();
|
||||
// Same thing.
|
||||
sender_cloned.send(Message::Done).ok();
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
|
|
@ -11,17 +11,15 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-train"
|
|||
version = "0.10.0"
|
||||
|
||||
[features]
|
||||
default = ["metrics", "ui"]
|
||||
default = ["metrics", "tui"]
|
||||
metrics = [
|
||||
"nvml-wrapper",
|
||||
"sysinfo",
|
||||
"systemstat"
|
||||
]
|
||||
ui = [
|
||||
"indicatif",
|
||||
"rgb",
|
||||
"terminal_size",
|
||||
"textplots",
|
||||
tui = [
|
||||
"ratatui",
|
||||
"crossterm"
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
|
@ -38,10 +36,8 @@ sysinfo = { version = "0.29.8", optional = true }
|
|||
systemstat = { version = "0.2.3", optional = true }
|
||||
|
||||
# Text UI
|
||||
indicatif = { version = "0.17.5", optional = true }
|
||||
rgb = { version = "0.8.36", optional = true }
|
||||
terminal_size = { version = "0.2.6", optional = true }
|
||||
textplots = { version = "0.8.0", optional = true }
|
||||
ratatui = { version = "0.23", optional = true, features = ["all-widgets"] }
|
||||
crossterm = { version = "0.27", optional = true }
|
||||
|
||||
# Utilities
|
||||
derive-new = {workspace = true}
|
||||
|
|
|
@ -2,8 +2,9 @@ use super::log::install_file_logger;
|
|||
use super::Learner;
|
||||
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
|
||||
use crate::logger::{FileMetricLogger, MetricLogger};
|
||||
use crate::metric::dashboard::CLIDashboardRenderer;
|
||||
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics};
|
||||
use crate::metric::dashboard::{
|
||||
Dashboard, DashboardRenderer, MetricWrapper, Metrics, SelectedDashboardRenderer,
|
||||
};
|
||||
use crate::metric::{Adaptor, Metric};
|
||||
use crate::AsyncTrainerCallback;
|
||||
use burn_core::lr_scheduler::LRScheduler;
|
||||
|
@ -259,7 +260,7 @@ where
|
|||
}
|
||||
let renderer = self
|
||||
.renderer
|
||||
.unwrap_or_else(|| Box::new(CLIDashboardRenderer::new()));
|
||||
.unwrap_or_else(|| Box::new(SelectedDashboardRenderer::new(self.interrupter.clone())));
|
||||
let directory = &self.directory;
|
||||
let logger_train = self.metric_logger_train.unwrap_or_else(|| {
|
||||
Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()))
|
||||
|
|
|
@ -199,6 +199,7 @@ impl<TI> TrainEpoch<TI> {
|
|||
|
||||
// The main device is always the first in the list.
|
||||
let device_main = devices.get(0).unwrap().clone();
|
||||
let mut interrupted = false;
|
||||
|
||||
loop {
|
||||
let items = step.step(&mut iterator, &model);
|
||||
|
@ -234,9 +235,14 @@ impl<TI> TrainEpoch<TI> {
|
|||
callback.on_train_item(item);
|
||||
if interrupter.should_stop() {
|
||||
log::info!("Training interrupted.");
|
||||
interrupted = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if interrupted {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
callback.on_train_end_epoch(self.epoch);
|
||||
|
|
|
@ -168,6 +168,10 @@ where
|
|||
);
|
||||
}
|
||||
|
||||
if self.interrupter.should_stop() {
|
||||
break;
|
||||
}
|
||||
|
||||
let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
|
||||
epoch_valid.run(&model, &mut self.callback, &self.interrupter);
|
||||
|
||||
|
|
|
@ -78,3 +78,13 @@ pub struct MetricEntry {
|
|||
/// The string to be saved.
|
||||
pub serialize: String,
|
||||
}
|
||||
|
||||
/// Format a float with the given precision. Will use scientific notation if necessary.
|
||||
pub fn format_float(float: f64, precision: usize) -> String {
|
||||
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
|
||||
|
||||
match scientific_notation_threshold >= float {
|
||||
true => format!("{float:.precision$e}"),
|
||||
false => format!("{float:.precision$}"),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/// The CPU use metric.
|
||||
use super::MetricMetadata;
|
||||
use super::{MetricMetadata, Numeric};
|
||||
use crate::metric::{Metric, MetricEntry};
|
||||
use sysinfo::{CpuExt, System, SystemExt};
|
||||
|
||||
|
@ -59,3 +59,9 @@ impl Metric for CpuUse {
|
|||
|
||||
fn clear(&mut self) {}
|
||||
}
|
||||
|
||||
impl Numeric for CpuUse {
|
||||
fn value(&self) -> f64 {
|
||||
self.use_percentage as f64
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,266 +0,0 @@
|
|||
use super::{DashboardMetricState, DashboardRenderer, TextPlot, TrainingProgress};
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressState, ProgressStyle};
|
||||
use std::{collections::HashMap, fmt::Write};
|
||||
|
||||
static MAX_REFRESH_RATE_MILLIS: u128 = 250;
|
||||
|
||||
/// The CLI dashboard renderer.
|
||||
pub struct CLIDashboardRenderer {
|
||||
pb_epoch: ProgressBar,
|
||||
pb_iteration: ProgressBar,
|
||||
last_update: std::time::Instant,
|
||||
progress: TrainingProgress,
|
||||
metric_train: HashMap<String, String>,
|
||||
metric_valid: HashMap<String, String>,
|
||||
metric_both_plot: HashMap<String, TextPlot>,
|
||||
metric_train_plot: HashMap<String, TextPlot>,
|
||||
metric_valid_plot: HashMap<String, TextPlot>,
|
||||
}
|
||||
|
||||
impl TrainingProgress {
|
||||
fn finished(&self) -> bool {
|
||||
self.epoch == self.epoch_total && self.progress.items_processed == self.progress.items_total
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CLIDashboardRenderer {
|
||||
fn default() -> Self {
|
||||
CLIDashboardRenderer::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CLIDashboardRenderer {
|
||||
fn drop(&mut self) {
|
||||
self.pb_iteration.finish();
|
||||
self.pb_epoch.finish();
|
||||
}
|
||||
}
|
||||
|
||||
impl DashboardRenderer for CLIDashboardRenderer {
|
||||
fn update_train(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(state) => {
|
||||
self.metric_train.insert(state.name, state.formatted);
|
||||
}
|
||||
DashboardMetricState::Numeric(state, value) => {
|
||||
let name = &state.name;
|
||||
self.metric_train.insert(name.clone(), state.formatted);
|
||||
|
||||
if let Some(mut plot) = self.text_plot_in_both(name) {
|
||||
plot.update_train(value as f32);
|
||||
self.metric_both_plot.insert(name.clone(), plot);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(plot) = self.metric_train_plot.get_mut(name) {
|
||||
plot.update_train(value as f32);
|
||||
} else {
|
||||
let mut plot = TextPlot::new();
|
||||
plot.update_train(value as f32);
|
||||
self.metric_train_plot.insert(state.name, plot);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn update_valid(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(state) => {
|
||||
self.metric_valid.insert(state.name, state.formatted);
|
||||
}
|
||||
DashboardMetricState::Numeric(state, value) => {
|
||||
let name = &state.name;
|
||||
self.metric_valid.insert(name.clone(), state.formatted);
|
||||
|
||||
if let Some(mut plot) = self.text_plot_in_both(name) {
|
||||
plot.update_valid(value as f32);
|
||||
self.metric_both_plot.insert(name.clone(), plot);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(plot) = self.metric_valid_plot.get_mut(name) {
|
||||
plot.update_valid(value as f32);
|
||||
} else {
|
||||
let mut plot = TextPlot::new();
|
||||
plot.update_valid(value as f32);
|
||||
self.metric_valid_plot.insert(state.name, plot);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn render_train(&mut self, item: TrainingProgress) {
|
||||
self.progress = item;
|
||||
self.render();
|
||||
}
|
||||
|
||||
fn render_valid(&mut self, item: TrainingProgress) {
|
||||
self.progress = item;
|
||||
self.render();
|
||||
}
|
||||
}
|
||||
|
||||
impl CLIDashboardRenderer {
|
||||
/// Create a new CLI dashboard renderer.
|
||||
pub fn new() -> Self {
|
||||
let pb = MultiProgress::new();
|
||||
let pb_epoch = ProgressBar::new(0);
|
||||
let pb_iteration = ProgressBar::new(0);
|
||||
|
||||
let pb_iteration = pb.add(pb_iteration);
|
||||
let pb_epoch = pb.add(pb_epoch);
|
||||
|
||||
Self {
|
||||
pb_epoch,
|
||||
pb_iteration,
|
||||
last_update: std::time::Instant::now(),
|
||||
progress: TrainingProgress::none(),
|
||||
metric_train: HashMap::new(),
|
||||
metric_valid: HashMap::new(),
|
||||
metric_both_plot: HashMap::new(),
|
||||
metric_train_plot: HashMap::new(),
|
||||
metric_valid_plot: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn text_plot_in_both(&mut self, key: &str) -> Option<TextPlot> {
|
||||
if let Some(plot) = self.metric_both_plot.remove(key) {
|
||||
return Some(plot);
|
||||
}
|
||||
if self.metric_train_plot.contains_key(key) && self.metric_valid_plot.contains_key(key) {
|
||||
let plot_train = self.metric_train_plot.remove(key).unwrap();
|
||||
let plot_valid = self.metric_valid_plot.remove(key).unwrap();
|
||||
|
||||
return Some(plot_train.merge(plot_valid));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn register_template_plots(&self, template: String) -> String {
|
||||
let mut template = template;
|
||||
let mut metrics_keys = Vec::new();
|
||||
|
||||
for (name, metric) in self.metric_both_plot.iter() {
|
||||
metrics_keys.push(format!(
|
||||
" - {} RED: train | BLUE: valid \n{}",
|
||||
name,
|
||||
metric.render()
|
||||
));
|
||||
}
|
||||
for (name, metric) in self.metric_train_plot.iter() {
|
||||
metrics_keys.push(format!(" - Train {}: \n{}", name, metric.render()));
|
||||
}
|
||||
for (name, metric) in self.metric_valid_plot.iter() {
|
||||
metrics_keys.push(format!(" - Valid {}: \n{}", name, metric.render()));
|
||||
}
|
||||
|
||||
if !metrics_keys.is_empty() {
|
||||
let metrics_template = metrics_keys.join("\n");
|
||||
template += format!("{PLOTS_TAG}\n{metrics_template}\n").as_str();
|
||||
}
|
||||
|
||||
template
|
||||
}
|
||||
fn register_template_metrics(&self, template: String) -> String {
|
||||
let mut template = template;
|
||||
let mut metrics_keys = Vec::new();
|
||||
|
||||
for (name, metric) in self.metric_train.iter() {
|
||||
metrics_keys.push(format!(" - Train {name}: {metric}"));
|
||||
}
|
||||
for (name, metric) in self.metric_valid.iter() {
|
||||
metrics_keys.push(format!(" - Valid {name}: {metric}"));
|
||||
}
|
||||
|
||||
if !metrics_keys.is_empty() {
|
||||
let metrics_template = metrics_keys.join("\n");
|
||||
template += format!("{METRICS_TAG}\n{metrics_template}\n").as_str();
|
||||
}
|
||||
|
||||
template
|
||||
}
|
||||
|
||||
fn register_style_progress(
|
||||
&self,
|
||||
name: &'static str,
|
||||
style: ProgressStyle,
|
||||
value: String,
|
||||
) -> ProgressStyle {
|
||||
self.register_key_item(name, style, name.to_string(), value)
|
||||
}
|
||||
|
||||
fn register_template_progress(&self, progress: &str, template: String) -> String {
|
||||
let mut template = template;
|
||||
|
||||
let bar = "[{wide_bar:.cyan/blue}]";
|
||||
template += format!(" - {progress} {bar}").as_str();
|
||||
template
|
||||
}
|
||||
|
||||
fn render(&mut self) {
|
||||
if !self.progress.finished()
|
||||
&& std::time::Instant::now()
|
||||
.duration_since(self.last_update)
|
||||
.as_millis()
|
||||
< MAX_REFRESH_RATE_MILLIS
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let template = self.register_template_plots(String::default());
|
||||
let template = self.register_template_metrics(template);
|
||||
let template = template
|
||||
+ format!(
|
||||
"\n{}\n - Iteration {} Epoch {}/{}\n",
|
||||
PROGRESS_TAG,
|
||||
self.progress.iteration,
|
||||
self.progress.epoch,
|
||||
self.progress.epoch_total
|
||||
)
|
||||
.as_str();
|
||||
|
||||
let template = self.register_template_progress("iteration", template);
|
||||
let style_iteration = ProgressStyle::with_template(&template).unwrap();
|
||||
let style_iteration = self.register_style_progress(
|
||||
"iteration",
|
||||
style_iteration,
|
||||
format!("{}", self.progress.iteration),
|
||||
);
|
||||
|
||||
let template = self.register_template_progress("epoch ", String::default());
|
||||
let style_epoch = ProgressStyle::with_template(&template).unwrap();
|
||||
let style_epoch =
|
||||
self.register_style_progress("epoch", style_epoch, format!("{}", self.progress.epoch));
|
||||
|
||||
self.pb_iteration
|
||||
.set_style(style_iteration.progress_chars("#>-"));
|
||||
self.pb_iteration
|
||||
.set_position(self.progress.progress.items_processed as u64);
|
||||
self.pb_iteration
|
||||
.set_length(self.progress.progress.items_total as u64);
|
||||
|
||||
self.pb_epoch.set_style(style_epoch.progress_chars("#>-"));
|
||||
self.pb_epoch.set_position(self.progress.epoch as u64 - 1);
|
||||
self.pb_epoch.set_length(self.progress.epoch_total as u64);
|
||||
|
||||
self.last_update = std::time::Instant::now();
|
||||
}
|
||||
|
||||
/// Registers a new metric to be displayed.
|
||||
pub fn register_key_item(
|
||||
&self,
|
||||
key: &'static str,
|
||||
style: ProgressStyle,
|
||||
name: String,
|
||||
formatted: String,
|
||||
) -> ProgressStyle {
|
||||
style.with_key(key, move |_state: &ProgressState, w: &mut dyn Write| {
|
||||
write!(w, "{name}: {formatted}").unwrap()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
static METRICS_TAG: &str = "[Metrics]";
|
||||
static PLOTS_TAG: &str = "[Plots]";
|
||||
static PROGRESS_TAG: &str = "[Progress]";
|
|
@ -1,16 +1,13 @@
|
|||
/// Command line interface module for the dashboard.
|
||||
#[cfg(feature = "ui")]
|
||||
mod cli;
|
||||
#[cfg(not(feature = "ui"))]
|
||||
mod cli_stub;
|
||||
|
||||
mod base;
|
||||
mod plot;
|
||||
|
||||
pub use base::*;
|
||||
pub use plot::*;
|
||||
|
||||
#[cfg(feature = "ui")]
|
||||
pub use cli::CLIDashboardRenderer;
|
||||
#[cfg(not(feature = "ui"))]
|
||||
pub use cli_stub::CLIDashboardRenderer;
|
||||
#[cfg(not(feature = "tui"))]
|
||||
mod cli_stub;
|
||||
#[cfg(not(feature = "tui"))]
|
||||
pub use cli_stub::CLIDashboardRenderer as SelectedDashboardRenderer;
|
||||
|
||||
#[cfg(feature = "tui")]
|
||||
mod tui;
|
||||
#[cfg(feature = "tui")]
|
||||
pub use tui::TuiDashboardRenderer as SelectedDashboardRenderer;
|
||||
|
|
|
@ -1,160 +0,0 @@
|
|||
/// Text plot.
|
||||
pub struct TextPlot {
|
||||
train: Vec<(f32, f32)>,
|
||||
valid: Vec<(f32, f32)>,
|
||||
max_values: usize,
|
||||
iteration: usize,
|
||||
}
|
||||
|
||||
impl Default for TextPlot {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TextPlot {
|
||||
/// Creates a new text plot.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
train: Vec::new(),
|
||||
valid: Vec::new(),
|
||||
max_values: 10000,
|
||||
iteration: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Merges two text plots.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - The first text plot.
|
||||
/// * `other` - The second text plot.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The merged text plot.
|
||||
pub fn merge(self, other: Self) -> Self {
|
||||
let mut other = other;
|
||||
let mut train = self.train;
|
||||
let mut valid = self.valid;
|
||||
|
||||
train.append(&mut other.train);
|
||||
valid.append(&mut other.valid);
|
||||
|
||||
Self {
|
||||
train,
|
||||
valid,
|
||||
max_values: usize::min(self.max_values, other.max_values),
|
||||
iteration: self.iteration + other.iteration,
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates the text plot with a new item for training.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The new item.
|
||||
pub fn update_train(&mut self, item: f32) {
|
||||
self.iteration += 1;
|
||||
self.train.push((self.iteration as f32, item));
|
||||
|
||||
let x_max = self
|
||||
.train
|
||||
.last()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MIN);
|
||||
let x_min = self
|
||||
.train
|
||||
.first()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MAX);
|
||||
|
||||
if x_max - x_min > self.max_values as f32 && !self.train.is_empty() {
|
||||
self.train.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates the text plot with a new item for validation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The new item.
|
||||
pub fn update_valid(&mut self, item: f32) {
|
||||
self.iteration += 1;
|
||||
self.valid.push((self.iteration as f32, item));
|
||||
|
||||
let x_max = self
|
||||
.valid
|
||||
.last()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MIN);
|
||||
let x_min = self
|
||||
.valid
|
||||
.first()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MAX);
|
||||
|
||||
if x_max - x_min > self.max_values as f32 && !self.valid.is_empty() {
|
||||
self.valid.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders the text plot.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The rendered text plot.
|
||||
#[cfg(feature = "ui")]
|
||||
pub fn render(&self) -> String {
|
||||
use rgb::RGB8;
|
||||
use terminal_size::{terminal_size, Height, Width};
|
||||
use textplots::{Chart, ColorPlot, Shape};
|
||||
|
||||
let train_color = RGB8::new(255, 140, 140);
|
||||
let valid_color = RGB8::new(140, 140, 255);
|
||||
|
||||
let x_max_valid = self
|
||||
.valid
|
||||
.last()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MIN);
|
||||
let x_max_train = self
|
||||
.train
|
||||
.last()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MIN);
|
||||
let x_max = f32::max(x_max_train, x_max_valid);
|
||||
|
||||
let x_min_valid = self
|
||||
.valid
|
||||
.first()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MAX);
|
||||
let x_min_train = self
|
||||
.train
|
||||
.first()
|
||||
.map(|(iteration, _)| *iteration)
|
||||
.unwrap_or(f32::MAX);
|
||||
let x_min = f32::min(x_min_train, x_min_valid);
|
||||
|
||||
let (width, height) = match terminal_size() {
|
||||
Some((Width(w), Height(_))) => (u32::max(64, w.into()) * 2 - 16, 64),
|
||||
None => (256, 64),
|
||||
};
|
||||
|
||||
Chart::new(width, height, x_min, x_max)
|
||||
.linecolorplot(&Shape::Lines(&self.train), train_color)
|
||||
.linecolorplot(&Shape::Lines(&self.valid), valid_color)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Renders the text plot.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The rendered text plot.
|
||||
#[cfg(not(feature = "ui"))]
|
||||
pub fn render(&self) -> String {
|
||||
panic!("ui feature not enabled on burn-train")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
use super::{
|
||||
ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView,
|
||||
};
|
||||
use ratatui::prelude::{Constraint, Direction, Layout, Rect};
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct DashboardView<'a> {
|
||||
metric_numeric: NumericMetricView<'a>,
|
||||
metric_text: TextMetricView,
|
||||
progress: ProgressBarView,
|
||||
controls: ControlsView,
|
||||
status: StatusView,
|
||||
}
|
||||
|
||||
impl<'a> DashboardView<'a> {
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([Constraint::Min(16), Constraint::Max(3)].as_ref())
|
||||
.split(size);
|
||||
let size_other = chunks[0];
|
||||
let size_progress = chunks[1];
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Horizontal)
|
||||
.constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref())
|
||||
.split(size_other);
|
||||
let size_other = chunks[0];
|
||||
let size_metric_numeric = chunks[1];
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref())
|
||||
.split(size_other);
|
||||
let size_controls = chunks[0];
|
||||
let size_metric_text = chunks[1];
|
||||
let size_status = chunks[2];
|
||||
|
||||
self.metric_numeric.render(frame, size_metric_numeric);
|
||||
self.metric_text.render(frame, size_metric_text);
|
||||
self.controls.render(frame, size_controls);
|
||||
self.progress.render(frame, size_progress);
|
||||
self.status.render(frame, size_status);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
use super::TerminalFrame;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph, Wrap},
|
||||
};
|
||||
|
||||
/// Controls view.
|
||||
pub(crate) struct ControlsView;
|
||||
|
||||
impl ControlsView {
|
||||
/// Render the view.
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
let lines = vec![
|
||||
vec![
|
||||
Span::from(" Quit : ").yellow().bold(),
|
||||
Span::from("q ").bold(),
|
||||
Span::from(" Stop the training.").italic(),
|
||||
],
|
||||
vec![
|
||||
Span::from(" Plots Metrics : ").yellow().bold(),
|
||||
Span::from("⬅ ➡").bold(),
|
||||
Span::from(" Switch between metrics.").italic(),
|
||||
],
|
||||
vec![
|
||||
Span::from(" Plots Type : ").yellow().bold(),
|
||||
Span::from("⬆ ⬇").bold(),
|
||||
Span::from(" Switch between types.").italic(),
|
||||
],
|
||||
];
|
||||
let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::<Vec<_>>())
|
||||
.alignment(Alignment::Left)
|
||||
.wrap(Wrap { trim: false })
|
||||
.style(Style::default().fg(Color::Gray))
|
||||
.block(
|
||||
Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.style(Style::default().fg(Color::Gray))
|
||||
.title_alignment(Alignment::Left)
|
||||
.title("Controls"),
|
||||
);
|
||||
|
||||
frame.render_widget(paragraph, size);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,216 @@
|
|||
use super::PlotAxes;
|
||||
use ratatui::{
|
||||
style::{Color, Style, Stylize},
|
||||
symbols,
|
||||
widgets::{Dataset, GraphType},
|
||||
};
|
||||
|
||||
/// A plot that shows the full history at a reduced resolution.
|
||||
pub(crate) struct FullHistoryPlot {
|
||||
pub(crate) axes: PlotAxes,
|
||||
train: FullHistoryPoints,
|
||||
valid: FullHistoryPoints,
|
||||
next_x_state: usize,
|
||||
}
|
||||
|
||||
struct FullHistoryPoints {
|
||||
min_x: f64,
|
||||
max_x: f64,
|
||||
min_y: f64,
|
||||
max_y: f64,
|
||||
points: Vec<(f64, f64)>,
|
||||
max_samples: usize,
|
||||
step_size: usize,
|
||||
}
|
||||
|
||||
impl FullHistoryPlot {
|
||||
/// Create a new history plot.
|
||||
pub(crate) fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
axes: PlotAxes::default(),
|
||||
train: FullHistoryPoints::new(max_samples),
|
||||
valid: FullHistoryPoints::new(max_samples),
|
||||
next_x_state: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the maximum amount of sample to display for the validation points.
|
||||
///
|
||||
/// This is necessary if we want the validation line to have the same point density as the
|
||||
/// training line.
|
||||
pub(crate) fn update_max_sample_valid(&mut self, ratio_train: f64) {
|
||||
if self.valid.step_size == 1 {
|
||||
self.valid.max_samples = (ratio_train * self.train.max_samples as f64) as usize;
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a training data point.
|
||||
pub(crate) fn push_train(&mut self, data: f64) {
|
||||
let x_current = self.next_x();
|
||||
self.train.push((x_current, data));
|
||||
|
||||
self.update_bounds();
|
||||
}
|
||||
|
||||
/// Register a validation data point.
|
||||
pub(crate) fn push_valid(&mut self, data: f64) {
|
||||
let x_current = self.next_x();
|
||||
|
||||
self.valid.push((x_current, data));
|
||||
|
||||
self.update_bounds();
|
||||
}
|
||||
|
||||
/// Create the training and validation datasets from the data points.
|
||||
pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {
|
||||
let mut datasets = Vec::with_capacity(2);
|
||||
|
||||
if !self.train.is_empty() {
|
||||
datasets.push(self.train.dataset("Train", Color::LightRed));
|
||||
}
|
||||
|
||||
if !self.valid.is_empty() {
|
||||
datasets.push(self.valid.dataset("Valid", Color::LightBlue));
|
||||
}
|
||||
|
||||
datasets
|
||||
}
|
||||
|
||||
fn next_x(&mut self) -> f64 {
|
||||
let value = self.next_x_state;
|
||||
self.next_x_state += 1;
|
||||
value as f64
|
||||
}
|
||||
|
||||
fn update_bounds(&mut self) {
|
||||
self.axes.update_bounds(
|
||||
(self.train.min_x, self.train.max_x),
|
||||
(self.valid.min_x, self.valid.max_x),
|
||||
(self.train.min_y, self.train.max_y),
|
||||
(self.valid.min_y, self.valid.max_y),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl FullHistoryPoints {
|
||||
fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
min_x: 0.,
|
||||
max_x: 0.,
|
||||
min_y: f64::MAX,
|
||||
max_y: f64::MIN,
|
||||
points: Vec::with_capacity(max_samples),
|
||||
max_samples,
|
||||
step_size: 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn push(&mut self, (x, y): (f64, f64)) {
|
||||
if x as usize % self.step_size != 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if x > self.max_x {
|
||||
self.max_x = x;
|
||||
}
|
||||
if x < self.min_x {
|
||||
self.min_x = x;
|
||||
}
|
||||
if y > self.max_y {
|
||||
self.max_y = y;
|
||||
}
|
||||
if y < self.min_y {
|
||||
self.min_y = y
|
||||
}
|
||||
|
||||
self.points.push((x, y));
|
||||
|
||||
if self.points.len() > self.max_samples {
|
||||
self.resize();
|
||||
}
|
||||
}
|
||||
|
||||
/// We keep only half the points and we double the step size.
|
||||
///
|
||||
/// This ensure that we have the same amount of points across the X axis.
|
||||
fn resize(&mut self) {
|
||||
let mut points = Vec::with_capacity(self.max_samples / 2);
|
||||
let mut max_x = f64::MIN;
|
||||
let mut max_y = f64::MIN;
|
||||
let mut min_x = f64::MAX;
|
||||
let mut min_y = f64::MAX;
|
||||
|
||||
for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() {
|
||||
if i % 2 == 0 {
|
||||
if x > max_x {
|
||||
max_x = x;
|
||||
}
|
||||
if x < min_x {
|
||||
min_x = x;
|
||||
}
|
||||
if y > max_y {
|
||||
max_y = y;
|
||||
}
|
||||
if y < min_y {
|
||||
min_y = y;
|
||||
}
|
||||
|
||||
points.push((x, y));
|
||||
}
|
||||
}
|
||||
|
||||
self.points = points;
|
||||
self.step_size *= 2;
|
||||
|
||||
self.min_x = min_x;
|
||||
self.max_x = max_x;
|
||||
self.min_y = min_y;
|
||||
self.max_y = max_y;
|
||||
}
|
||||
|
||||
fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> {
|
||||
Dataset::default()
|
||||
.name(name)
|
||||
.marker(symbols::Marker::Braille)
|
||||
.style(Style::default().fg(color).bold())
|
||||
.graph_type(GraphType::Line)
|
||||
.data(&self.points)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.points.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_points() {
|
||||
let mut chart = FullHistoryPlot::new(10);
|
||||
chart.update_max_sample_valid(0.6);
|
||||
|
||||
for i in 0..100 {
|
||||
chart.push_train(i as f64);
|
||||
}
|
||||
for i in 0..60 {
|
||||
chart.push_valid(i as f64);
|
||||
}
|
||||
|
||||
let expected_train = vec![
|
||||
(0.0, 0.0),
|
||||
(16.0, 16.0),
|
||||
(32.0, 32.0),
|
||||
(48.0, 48.0),
|
||||
(64.0, 64.0),
|
||||
(80.0, 80.0),
|
||||
(96.0, 96.0),
|
||||
];
|
||||
|
||||
let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)];
|
||||
|
||||
assert_eq!(chart.train.points, expected_train);
|
||||
assert_eq!(chart.valid.points, expected_valid);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,232 @@
|
|||
use crate::metric::dashboard::TrainingProgress;
|
||||
|
||||
use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame};
|
||||
use crossterm::event::{Event, KeyCode};
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Modifier, Style, Stylize},
|
||||
text::Line,
|
||||
widgets::{Axis, Block, Borders, Chart, Paragraph, Tabs},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// 1000 seems to be required to see some improvement.
|
||||
const MAX_NUM_SAMPLES_RECENT: usize = 1000;
|
||||
/// 250 seems to be the right resolution when plotting all history.
|
||||
/// Otherwise, there is too much points and the lines arent't smooth enough.
|
||||
const MAX_NUM_SAMPLES_FULL: usize = 250;
|
||||
|
||||
/// Numeric metrics state that handles creating plots.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct NumericMetricsState {
|
||||
data: HashMap<String, (RecentHistoryPlot, FullHistoryPlot)>,
|
||||
names: Vec<String>,
|
||||
selected: usize,
|
||||
kind: PlotKind,
|
||||
num_samples_train: Option<usize>,
|
||||
num_samples_valid: Option<usize>,
|
||||
}
|
||||
|
||||
/// The kind of plot to display.
|
||||
#[derive(Default, Clone, Copy)]
|
||||
pub(crate) enum PlotKind {
|
||||
/// Display the full history of the metric with reduced resolution.
|
||||
#[default]
|
||||
Full,
|
||||
/// Display only the recent history of the metric, but with more resolution.
|
||||
Recent,
|
||||
}
|
||||
|
||||
impl NumericMetricsState {
|
||||
/// Register a new training value for the metric with the given name.
|
||||
pub(crate) fn push_train(&mut self, name: String, data: f64) {
|
||||
if let Some((recent, full)) = self.data.get_mut(&name) {
|
||||
recent.push_train(data);
|
||||
full.push_train(data);
|
||||
} else {
|
||||
let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);
|
||||
let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);
|
||||
|
||||
recent.push_train(data);
|
||||
full.push_train(data);
|
||||
|
||||
self.names.push(name.clone());
|
||||
self.data.insert(name, (recent, full));
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new validation value for the metric with the given name.
|
||||
pub(crate) fn push_valid(&mut self, key: String, data: f64) {
|
||||
if let Some((recent, full)) = self.data.get_mut(&key) {
|
||||
recent.push_valid(data);
|
||||
full.push_valid(data);
|
||||
} else {
|
||||
let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);
|
||||
let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);
|
||||
|
||||
recent.push_valid(data);
|
||||
full.push_valid(data);
|
||||
|
||||
self.data.insert(key, (recent, full));
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the state with the training progress.
|
||||
pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {
|
||||
if self.num_samples_train.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.num_samples_train = Some(progress.progress.items_total);
|
||||
}
|
||||
|
||||
/// Update the state with the validation progress.
|
||||
pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) {
|
||||
if self.num_samples_valid.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(num_sample_train) = self.num_samples_train {
|
||||
for (_, (_recent, full)) in self.data.iter_mut() {
|
||||
let ratio = progress.progress.items_total as f64 / num_sample_train as f64;
|
||||
full.update_max_sample_valid(ratio);
|
||||
}
|
||||
}
|
||||
|
||||
self.num_samples_valid = Some(progress.progress.items_total);
|
||||
}
|
||||
|
||||
/// Create a view to display the numeric metrics.
|
||||
pub(crate) fn view(&self) -> NumericMetricView<'_> {
|
||||
match self.names.is_empty() {
|
||||
true => NumericMetricView::None,
|
||||
false => NumericMetricView::Plots(&self.names, self.selected, self.chart(), self.kind),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle the current event.
|
||||
pub(crate) fn on_event(&mut self, event: &Event) {
|
||||
if let Event::Key(key) = event {
|
||||
match key.code {
|
||||
KeyCode::Right => self.next_metric(),
|
||||
KeyCode::Left => self.previous_metric(),
|
||||
KeyCode::Up => self.switch_kind(),
|
||||
KeyCode::Down => self.switch_kind(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn switch_kind(&mut self) {
|
||||
self.kind = match self.kind {
|
||||
PlotKind::Full => PlotKind::Recent,
|
||||
PlotKind::Recent => PlotKind::Full,
|
||||
};
|
||||
}
|
||||
|
||||
fn next_metric(&mut self) {
|
||||
self.selected = (self.selected + 1) % {
|
||||
let this = &self;
|
||||
this.data.len()
|
||||
};
|
||||
}
|
||||
|
||||
fn previous_metric(&mut self) {
|
||||
if self.selected > 0 {
|
||||
self.selected -= 1;
|
||||
} else {
|
||||
self.selected = ({
|
||||
let this = &self;
|
||||
this.data.len()
|
||||
}) - 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn chart<'a>(&'a self) -> Chart<'a> {
|
||||
let name = self.names.get(self.selected).unwrap();
|
||||
let (recent, full) = self.data.get(name).unwrap();
|
||||
|
||||
let (datasets, axes) = match self.kind {
|
||||
PlotKind::Full => (full.datasets(), &full.axes),
|
||||
PlotKind::Recent => (recent.datasets(), &recent.axes),
|
||||
};
|
||||
|
||||
Chart::<'a>::new(datasets)
|
||||
.block(Block::default())
|
||||
.x_axis(
|
||||
Axis::default()
|
||||
.style(Style::default().fg(Color::DarkGray))
|
||||
.title("Iteration")
|
||||
.labels(axes.labels_x.iter().map(|s| s.bold()).collect())
|
||||
.bounds(axes.bounds_x),
|
||||
)
|
||||
.y_axis(
|
||||
Axis::default()
|
||||
.style(Style::default().fg(Color::DarkGray))
|
||||
.labels(axes.labels_y.iter().map(|s| s.bold()).collect())
|
||||
.bounds(axes.bounds_y),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) enum NumericMetricView<'a> {
|
||||
Plots(&'a [String], usize, Chart<'a>, PlotKind),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<'a> NumericMetricView<'a> {
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
match self {
|
||||
Self::Plots(titles, selected, chart, kind) => {
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title("Plots")
|
||||
.title_alignment(Alignment::Left);
|
||||
let size_new = block.inner(size);
|
||||
frame.render_widget(block, size);
|
||||
|
||||
let size = size_new;
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints(
|
||||
[
|
||||
Constraint::Length(2),
|
||||
Constraint::Length(1),
|
||||
Constraint::Min(0),
|
||||
]
|
||||
.as_ref(),
|
||||
)
|
||||
.split(size);
|
||||
|
||||
let titles = titles
|
||||
.iter()
|
||||
.map(|i| Line::from(vec![i.yellow()]))
|
||||
.collect();
|
||||
|
||||
let tabs = Tabs::new(titles)
|
||||
.select(selected)
|
||||
.style(Style::default())
|
||||
.highlight_style(
|
||||
Style::default()
|
||||
.add_modifier(Modifier::BOLD)
|
||||
.add_modifier(Modifier::UNDERLINED)
|
||||
.fg(Color::LightYellow),
|
||||
);
|
||||
let title = match kind {
|
||||
PlotKind::Full => "Full History",
|
||||
PlotKind::Recent => "Recent History",
|
||||
};
|
||||
|
||||
let plot_type =
|
||||
Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);
|
||||
|
||||
frame.render_widget(tabs, chunks[0]);
|
||||
frame.render_widget(plot_type, chunks[1]);
|
||||
frame.render_widget(chart, chunks[2]);
|
||||
}
|
||||
Self::None => {}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
use super::TerminalFrame;
|
||||
use crate::metric::MetricEntry;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph, Wrap},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TextMetricsState {
|
||||
data: HashMap<String, MetricData>,
|
||||
names: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct MetricData {
|
||||
train: Option<MetricEntry>,
|
||||
valid: Option<MetricEntry>,
|
||||
}
|
||||
|
||||
impl TextMetricsState {
|
||||
pub(crate) fn update_train(&mut self, metric: MetricEntry) {
|
||||
if let Some(existing) = self.data.get_mut(&metric.name) {
|
||||
existing.train = Some(metric);
|
||||
} else {
|
||||
let key = metric.name.clone();
|
||||
let value = MetricData::new(Some(metric), None);
|
||||
|
||||
self.names.push(key.clone());
|
||||
self.data.insert(key, value);
|
||||
}
|
||||
}
|
||||
pub(crate) fn update_valid(&mut self, metric: MetricEntry) {
|
||||
if let Some(existing) = self.data.get_mut(&metric.name) {
|
||||
existing.valid = Some(metric);
|
||||
} else {
|
||||
let key = metric.name.clone();
|
||||
let value = MetricData::new(None, Some(metric));
|
||||
|
||||
self.names.push(key.clone());
|
||||
self.data.insert(key, value);
|
||||
}
|
||||
}
|
||||
pub(crate) fn view(&self) -> TextMetricView {
|
||||
TextMetricView::new(&self.names, &self.data)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct TextMetricView {
|
||||
lines: Vec<Vec<Span<'static>>>,
|
||||
}
|
||||
|
||||
impl TextMetricView {
|
||||
fn new(names: &[String], data: &HashMap<String, MetricData>) -> Self {
|
||||
let mut lines = Vec::with_capacity(names.len() * 4);
|
||||
|
||||
let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()];
|
||||
let train_line = |formatted: &str| {
|
||||
vec![
|
||||
Span::from(" Train ").bold(),
|
||||
Span::from(formatted.to_string()).italic(),
|
||||
]
|
||||
};
|
||||
let valid_line = |formatted: &str| {
|
||||
vec![
|
||||
Span::from(" Valid ").bold(),
|
||||
Span::from(formatted.to_string()).italic(),
|
||||
]
|
||||
};
|
||||
|
||||
for name in names {
|
||||
lines.push(start_line(name));
|
||||
|
||||
let entry = data.get(name).unwrap();
|
||||
|
||||
if let Some(entry) = &entry.train {
|
||||
lines.push(train_line(&entry.formatted));
|
||||
}
|
||||
|
||||
if let Some(entry) = &entry.valid {
|
||||
lines.push(valid_line(&entry.formatted));
|
||||
}
|
||||
|
||||
lines.push(vec![Span::from("")]);
|
||||
}
|
||||
|
||||
Self { lines }
|
||||
}
|
||||
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())
|
||||
.alignment(Alignment::Left)
|
||||
.wrap(Wrap { trim: false })
|
||||
.block(Block::default().borders(Borders::ALL).title("Metrics"))
|
||||
.style(Style::default().fg(Color::Gray));
|
||||
|
||||
frame.render_widget(paragraph, size);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
mod base;
|
||||
mod controls;
|
||||
mod full_history;
|
||||
mod metric_numeric;
|
||||
mod metric_text;
|
||||
mod plot_utils;
|
||||
mod popup;
|
||||
mod progress;
|
||||
mod recent_history;
|
||||
mod renderer;
|
||||
mod status;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use controls::*;
|
||||
pub(crate) use full_history::*;
|
||||
pub(crate) use metric_numeric::*;
|
||||
pub(crate) use metric_text::*;
|
||||
pub(crate) use plot_utils::*;
|
||||
pub(crate) use popup::*;
|
||||
pub(crate) use progress::*;
|
||||
pub(crate) use recent_history::*;
|
||||
pub use renderer::*;
|
||||
pub(crate) use status::*;
|
|
@ -0,0 +1,48 @@
|
|||
use crate::metric::format_float;
|
||||
|
||||
const AXIS_TITLE_PRECISION: usize = 2;
|
||||
|
||||
/// The data describing both X and Y axes.
|
||||
pub(crate) struct PlotAxes {
|
||||
pub(crate) labels_x: Vec<String>,
|
||||
pub(crate) labels_y: Vec<String>,
|
||||
pub(crate) bounds_x: [f64; 2],
|
||||
pub(crate) bounds_y: [f64; 2],
|
||||
}
|
||||
|
||||
impl Default for PlotAxes {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounds_x: [f64::MAX, f64::MIN],
|
||||
bounds_y: [f64::MAX, f64::MIN],
|
||||
labels_x: Vec::new(),
|
||||
labels_y: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PlotAxes {
|
||||
/// Update the bounds based on the min max of each X and Y axes with both train and valid data.
|
||||
pub(crate) fn update_bounds(
|
||||
&mut self,
|
||||
(x_train_min, x_train_max): (f64, f64),
|
||||
(x_valid_min, x_valid_max): (f64, f64),
|
||||
(y_train_min, y_train_max): (f64, f64),
|
||||
(y_valid_min, y_valid_max): (f64, f64),
|
||||
) {
|
||||
let x_min = f64::min(x_train_min, x_valid_min);
|
||||
let x_max = f64::max(x_train_max, x_valid_max);
|
||||
let y_min = f64::min(y_train_min, y_valid_min);
|
||||
let y_max = f64::max(y_train_max, y_valid_max);
|
||||
|
||||
self.bounds_x = [x_min, x_max];
|
||||
self.bounds_y = [y_min, y_max];
|
||||
|
||||
// We know x are integers.
|
||||
self.labels_x = vec![format!("{x_min}"), format!("{x_max}")];
|
||||
self.labels_y = vec![
|
||||
format_float(y_min, AXIS_TITLE_PRECISION),
|
||||
format_float(y_max, AXIS_TITLE_PRECISION),
|
||||
];
|
||||
}
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
use crossterm::event::{Event, KeyCode};
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Modifier, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph, Wrap},
|
||||
};
|
||||
|
||||
use super::TerminalFrame;
|
||||
|
||||
/// Popup callback function.
|
||||
pub(crate) trait CallbackFn: Send + Sync {
|
||||
/// Call the function and return if the popup state should be reset.
|
||||
fn call(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Popup callback.
|
||||
pub(crate) struct Callback {
|
||||
title: String,
|
||||
description: String,
|
||||
trigger: char,
|
||||
callback: Box<dyn CallbackFn>,
|
||||
}
|
||||
|
||||
impl Callback {
|
||||
/// Create a new popup.
|
||||
pub(crate) fn new<T, D, C>(title: T, description: D, trigger: char, callback: C) -> Self
|
||||
where
|
||||
T: Into<String>,
|
||||
D: Into<String>,
|
||||
C: CallbackFn + 'static,
|
||||
{
|
||||
Self {
|
||||
title: title.into(),
|
||||
description: description.into(),
|
||||
trigger,
|
||||
callback: Box::new(callback),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Popup state.
|
||||
pub(crate) enum PopupState {
|
||||
Empty,
|
||||
Full(String, Vec<Callback>),
|
||||
}
|
||||
|
||||
impl PopupState {
|
||||
/// If the popup is empty.
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
matches!(&self, PopupState::Empty)
|
||||
}
|
||||
/// Handle popup events.
|
||||
pub(crate) fn on_event(&mut self, event: &Event) {
|
||||
let mut reset = false;
|
||||
|
||||
match self {
|
||||
PopupState::Empty => {}
|
||||
PopupState::Full(_, callbacks) => {
|
||||
for callback in callbacks.iter() {
|
||||
if let Event::Key(key) = event {
|
||||
if let KeyCode::Char(key) = &key.code {
|
||||
if &callback.trigger == key && callback.callback.call() {
|
||||
reset = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if reset {
|
||||
*self = Self::Empty;
|
||||
}
|
||||
}
|
||||
/// Create the popup view.
|
||||
pub(crate) fn view(&self) -> Option<PopupView<'_>> {
|
||||
match self {
|
||||
PopupState::Empty => None,
|
||||
PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct PopupView<'a> {
|
||||
title: &'a String,
|
||||
callbacks: &'a [Callback],
|
||||
}
|
||||
|
||||
impl<'a> PopupView<'a> {
|
||||
/// Render the view.
|
||||
pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) {
|
||||
let lines = self
|
||||
.callbacks
|
||||
.iter()
|
||||
.flat_map(|callback| {
|
||||
vec![
|
||||
Line::from(vec![
|
||||
Span::from(format!("[{}] ", callback.trigger)).bold(),
|
||||
Span::from(format!("{} ", callback.title)).yellow().bold(),
|
||||
]),
|
||||
Line::from(Span::from("")),
|
||||
Line::from(Span::from(callback.description.to_string()).italic()),
|
||||
Line::from(Span::from("")),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let paragraph = Paragraph::new(lines)
|
||||
.alignment(Alignment::Left)
|
||||
.wrap(Wrap { trim: false })
|
||||
.style(Style::default().fg(Color::Gray))
|
||||
.block(
|
||||
Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title_alignment(Alignment::Center)
|
||||
.style(Style::default().fg(Color::Gray))
|
||||
.title(Span::styled(
|
||||
self.title,
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
)),
|
||||
);
|
||||
|
||||
let area = centered_percent(20, size, Direction::Horizontal);
|
||||
let area = centered_percent(20, area, Direction::Vertical);
|
||||
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
/// The percent represents the amount of space that will be taken by each side.
|
||||
fn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect {
|
||||
let center = 100 - (percent * 2);
|
||||
|
||||
Layout::default()
|
||||
.direction(direction)
|
||||
.constraints([
|
||||
Constraint::Percentage(percent),
|
||||
Constraint::Percentage(center),
|
||||
Constraint::Percentage(percent),
|
||||
])
|
||||
.split(size)[1]
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
use super::TerminalFrame;
|
||||
use crate::metric::dashboard::TrainingProgress;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Gauge, Paragraph},
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Simple progress bar for the training.
|
||||
///
|
||||
/// We currently ignore the time taken for the validation part.
|
||||
pub(crate) struct ProgressBarState {
|
||||
progress_train: f64,
|
||||
started: Instant,
|
||||
}
|
||||
|
||||
impl Default for ProgressBarState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
progress_train: 0.0,
|
||||
started: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const MINUTE: u64 = 60;
|
||||
const HOUR: u64 = 60 * 60;
|
||||
const DAY: u64 = 24 * 60 * 60;
|
||||
|
||||
impl ProgressBarState {
|
||||
/// Update the training progress.
|
||||
pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {
|
||||
let total_items = progress.progress.items_total * progress.epoch_total;
|
||||
let epoch_items = (progress.epoch - 1) * progress.progress.items_total;
|
||||
let iteration_items = progress.progress.items_processed as f64;
|
||||
|
||||
self.progress_train = (epoch_items as f64 + iteration_items) / total_items as f64
|
||||
}
|
||||
|
||||
/// Update the validation progress.
|
||||
pub(crate) fn update_valid(&mut self, _progress: &TrainingProgress) {
|
||||
// We don't use the validation for the progress yet.
|
||||
}
|
||||
|
||||
/// Create a view for the current progress.
|
||||
pub(crate) fn view(&self) -> ProgressBarView {
|
||||
let eta = self.started.elapsed();
|
||||
let total_estimated = (eta.as_secs() as f64) / self.progress_train;
|
||||
|
||||
let eta = if total_estimated.is_normal() {
|
||||
let remaining = 1.0 - self.progress_train;
|
||||
let eta = (total_estimated * remaining) as u64;
|
||||
format_eta(eta)
|
||||
} else {
|
||||
"---".to_string()
|
||||
};
|
||||
ProgressBarView::new(self.progress_train, eta)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct ProgressBarView {
|
||||
progress: f64,
|
||||
eta: String,
|
||||
}
|
||||
|
||||
impl ProgressBarView {
|
||||
/// Render the view.
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title("Progress")
|
||||
.title_alignment(Alignment::Left);
|
||||
let size_new = block.inner(size);
|
||||
frame.render_widget(block, size);
|
||||
let size = size_new;
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Horizontal)
|
||||
.constraints(
|
||||
[
|
||||
Constraint::Length(1), // Empty space
|
||||
Constraint::Min(0),
|
||||
Constraint::Length(self.eta.len() as u16 + 4),
|
||||
]
|
||||
.as_ref(),
|
||||
)
|
||||
.split(size);
|
||||
|
||||
let size_gauge = chunks[1];
|
||||
let size_eta = chunks[2];
|
||||
|
||||
let iteration = Gauge::default()
|
||||
.gauge_style(Style::default().fg(Color::Yellow))
|
||||
.ratio(self.progress);
|
||||
let eta = Paragraph::new(Line::from(vec![
|
||||
Span::from(" ("),
|
||||
Span::from(self.eta).italic(),
|
||||
Span::from(") "),
|
||||
]));
|
||||
|
||||
frame.render_widget(iteration, size_gauge);
|
||||
frame.render_widget(eta, size_eta);
|
||||
}
|
||||
}
|
||||
|
||||
fn format_eta(eta_secs: u64) -> String {
|
||||
let seconds = eta_secs % 60;
|
||||
let minutes = eta_secs / MINUTE % 60;
|
||||
let hours = eta_secs / HOUR % 24;
|
||||
let days = eta_secs / DAY;
|
||||
|
||||
if days > 0 {
|
||||
return format!("{days} days");
|
||||
}
|
||||
|
||||
if hours > 0 {
|
||||
return format!("{hours} hours");
|
||||
}
|
||||
|
||||
if minutes > 0 {
|
||||
return format!("{minutes} mins");
|
||||
}
|
||||
|
||||
format!("{seconds} secs")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_format_eta() {
|
||||
assert_eq!("55 secs", format_eta(55), "Less than 1 minutes");
|
||||
assert_eq!("1 mins", format_eta(61), "More than 1 minutes");
|
||||
assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes");
|
||||
assert_eq!("1 hours", format_eta(3601), "More than 1 hour");
|
||||
assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour");
|
||||
assert_eq!("1 days", format_eta(24 * 3601), "More than 1 day");
|
||||
assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,244 @@
|
|||
use super::PlotAxes;
|
||||
use ratatui::{
|
||||
style::{Color, Style, Stylize},
|
||||
symbols,
|
||||
widgets::{Dataset, GraphType},
|
||||
};
|
||||
|
||||
const FACTOR_BEFORE_RESIZE: usize = 2;
|
||||
|
||||
/// A plot that shows the recent history at full resolution.
|
||||
pub(crate) struct RecentHistoryPlot {
|
||||
pub(crate) axes: PlotAxes,
|
||||
train: RecentHistoryPoints,
|
||||
valid: RecentHistoryPoints,
|
||||
max_samples: usize,
|
||||
}
|
||||
|
||||
struct RecentHistoryPoints {
|
||||
min_x: f64,
|
||||
max_x: f64,
|
||||
min_y: f64,
|
||||
max_y: f64,
|
||||
cursor: usize,
|
||||
points: Vec<(f64, f64)>,
|
||||
max_samples: usize,
|
||||
factor_before_resize: usize,
|
||||
}
|
||||
|
||||
impl RecentHistoryPlot {
|
||||
pub(crate) fn new(max_samples: usize) -> Self {
|
||||
Self {
|
||||
axes: PlotAxes::default(),
|
||||
train: RecentHistoryPoints::new(max_samples),
|
||||
valid: RecentHistoryPoints::new(max_samples),
|
||||
max_samples,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn push_train(&mut self, data: f64) {
|
||||
let (x_min, x_current) = self.x();
|
||||
|
||||
self.train.push((x_current, data));
|
||||
self.train.update_cursor(x_min);
|
||||
self.valid.update_cursor(x_min);
|
||||
|
||||
self.update_bounds();
|
||||
}
|
||||
|
||||
pub(crate) fn push_valid(&mut self, data: f64) {
|
||||
let (x_min, x_current) = self.x();
|
||||
|
||||
self.valid.push((x_current, data));
|
||||
self.valid.update_cursor(x_min);
|
||||
self.train.update_cursor(x_min);
|
||||
|
||||
self.update_bounds();
|
||||
}
|
||||
|
||||
pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {
|
||||
let mut datasets = Vec::with_capacity(2);
|
||||
|
||||
if self.train.num_visible_points() > 0 {
|
||||
datasets.push(self.train.dataset("Train", Color::LightRed));
|
||||
}
|
||||
|
||||
if self.valid.num_visible_points() > 0 {
|
||||
datasets.push(self.valid.dataset("Valid", Color::LightBlue));
|
||||
}
|
||||
|
||||
datasets
|
||||
}
|
||||
|
||||
fn x(&mut self) -> (f64, f64) {
|
||||
let x_current = f64::max(self.train.max_x, self.valid.max_x) + 1.0;
|
||||
let mut x_min = f64::min(self.train.min_x, self.valid.min_x);
|
||||
if x_current - x_min >= self.max_samples as f64 {
|
||||
x_min += 1.0;
|
||||
}
|
||||
|
||||
(x_min, x_current)
|
||||
}
|
||||
|
||||
fn update_bounds(&mut self) {
|
||||
self.axes.update_bounds(
|
||||
(self.train.min_x, self.train.max_x),
|
||||
(self.valid.min_x, self.valid.max_x),
|
||||
(self.train.min_y, self.train.max_y),
|
||||
(self.valid.min_y, self.valid.max_y),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl RecentHistoryPoints {
|
||||
fn new(max_samples: usize) -> Self {
|
||||
let factor_before_resize = FACTOR_BEFORE_RESIZE;
|
||||
|
||||
Self {
|
||||
min_x: 0.,
|
||||
max_x: 0.,
|
||||
min_y: f64::MAX,
|
||||
max_y: f64::MIN,
|
||||
points: Vec::with_capacity(factor_before_resize * max_samples),
|
||||
cursor: 0,
|
||||
max_samples,
|
||||
factor_before_resize,
|
||||
}
|
||||
}
|
||||
|
||||
fn num_visible_points(&self) -> usize {
|
||||
self.points.len()
|
||||
}
|
||||
|
||||
fn push(&mut self, (x, y): (f64, f64)) {
|
||||
if x > self.max_x {
|
||||
self.max_x = x;
|
||||
}
|
||||
if x < self.min_x {
|
||||
self.min_x = x;
|
||||
}
|
||||
if y > self.max_y {
|
||||
self.max_y = y;
|
||||
}
|
||||
if y < self.min_y {
|
||||
self.min_y = y
|
||||
}
|
||||
self.points.push((x, y));
|
||||
}
|
||||
|
||||
fn update_cursor(&mut self, min_x: f64) {
|
||||
if self.min_x >= min_x {
|
||||
return;
|
||||
}
|
||||
self.min_x = min_x;
|
||||
|
||||
let mut update_y_max = false;
|
||||
let mut update_y_min = false;
|
||||
|
||||
while let Some((x, y)) = self.points.get(self.cursor) {
|
||||
if *x >= self.min_x {
|
||||
break;
|
||||
}
|
||||
|
||||
if *y == self.max_y {
|
||||
update_y_max = true
|
||||
}
|
||||
if *y == self.min_y {
|
||||
update_y_min = true;
|
||||
}
|
||||
|
||||
self.cursor += 1;
|
||||
}
|
||||
|
||||
if update_y_max {
|
||||
self.max_y = self.calculate_max_y();
|
||||
}
|
||||
|
||||
if update_y_min {
|
||||
self.min_y = self.calculate_min_y();
|
||||
}
|
||||
|
||||
if self.points.len() >= self.max_samples * self.factor_before_resize {
|
||||
self.resize();
|
||||
}
|
||||
}
|
||||
|
||||
fn slice(&self) -> &[(f64, f64)] {
|
||||
&self.points[self.cursor..self.points.len()]
|
||||
}
|
||||
|
||||
fn calculate_max_y(&self) -> f64 {
|
||||
let mut max_y = f64::MIN;
|
||||
|
||||
for (_x, y) in self.slice() {
|
||||
if *y > max_y {
|
||||
max_y = *y;
|
||||
}
|
||||
}
|
||||
|
||||
max_y
|
||||
}
|
||||
|
||||
fn calculate_min_y(&self) -> f64 {
|
||||
let mut min_y = f64::MAX;
|
||||
|
||||
for (_x, y) in self.slice() {
|
||||
if *y < min_y {
|
||||
min_y = *y;
|
||||
}
|
||||
}
|
||||
|
||||
min_y
|
||||
}
|
||||
|
||||
fn resize(&mut self) {
|
||||
let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize);
|
||||
|
||||
for i in self.cursor..self.points.len() {
|
||||
points.push(self.points[i]);
|
||||
}
|
||||
|
||||
self.points = points;
|
||||
self.cursor = 0;
|
||||
}
|
||||
|
||||
fn dataset<'a>(&'a self, name: &'a str, color: Color) -> Dataset<'a> {
|
||||
let data = &self.points[self.cursor..self.points.len()];
|
||||
|
||||
Dataset::default()
|
||||
.name(name)
|
||||
.marker(symbols::Marker::Braille)
|
||||
.style(Style::default().fg(color).bold())
|
||||
.graph_type(GraphType::Scatter)
|
||||
.data(data)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_push_update_bounds_max_y() {
|
||||
let mut chart = RecentHistoryPlot::new(3);
|
||||
chart.push_train(15.0);
|
||||
chart.push_train(10.0);
|
||||
chart.push_train(14.0);
|
||||
|
||||
assert_eq!(chart.axes.bounds_y[1], 15.);
|
||||
chart.push_train(10.0);
|
||||
assert_eq!(chart.axes.bounds_y[1], 14.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_push_update_bounds_min_y() {
|
||||
let mut chart = RecentHistoryPlot::new(3);
|
||||
chart.push_train(5.0);
|
||||
chart.push_train(10.0);
|
||||
chart.push_train(14.0);
|
||||
|
||||
assert_eq!(chart.axes.bounds_y[0], 5.);
|
||||
chart.push_train(10.0);
|
||||
assert_eq!(chart.axes.bounds_y[0], 10.);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,203 @@
|
|||
use crate::metric::dashboard::tui::NumericMetricsState;
|
||||
use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};
|
||||
use crate::TrainingInterrupter;
|
||||
use crossterm::{
|
||||
event::{self, Event, KeyCode},
|
||||
execute,
|
||||
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
|
||||
};
|
||||
use ratatui::{prelude::*, Terminal};
|
||||
use std::{
|
||||
error::Error,
|
||||
io::{self, Stdout},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use super::{
|
||||
Callback, CallbackFn, ControlsView, DashboardView, PopupState, ProgressBarState, StatusState,
|
||||
TextMetricsState,
|
||||
};
|
||||
|
||||
/// The current terminal backend.
|
||||
pub(crate) type TerminalBackend = CrosstermBackend<Stdout>;
|
||||
/// The current terminal frame.
|
||||
pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a, TerminalBackend>;
|
||||
|
||||
const MAX_REFRESH_RATE_MILLIS: u64 = 100;
|
||||
|
||||
/// The CLI dashboard renderer.
|
||||
pub struct TuiDashboardRenderer {
|
||||
terminal: Terminal<TerminalBackend>,
|
||||
last_update: std::time::Instant,
|
||||
progress: ProgressBarState,
|
||||
metrics_numeric: NumericMetricsState,
|
||||
metrics_text: TextMetricsState,
|
||||
status: StatusState,
|
||||
interuptor: TrainingInterrupter,
|
||||
popup: PopupState,
|
||||
}
|
||||
|
||||
impl DashboardRenderer for TuiDashboardRenderer {
|
||||
fn update_train(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(entry) => {
|
||||
self.metrics_text.update_train(entry);
|
||||
}
|
||||
DashboardMetricState::Numeric(entry, value) => {
|
||||
self.metrics_numeric.push_train(entry.name.clone(), value);
|
||||
self.metrics_text.update_train(entry);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn update_valid(&mut self, state: DashboardMetricState) {
|
||||
match state {
|
||||
DashboardMetricState::Generic(entry) => {
|
||||
self.metrics_text.update_valid(entry);
|
||||
}
|
||||
DashboardMetricState::Numeric(entry, value) => {
|
||||
self.metrics_numeric.push_valid(entry.name.clone(), value);
|
||||
self.metrics_text.update_valid(entry);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn render_train(&mut self, item: TrainingProgress) {
|
||||
self.progress.update_train(&item);
|
||||
self.metrics_numeric.update_progress_train(&item);
|
||||
self.status.update_train(item);
|
||||
self.render().unwrap();
|
||||
}
|
||||
|
||||
fn render_valid(&mut self, item: TrainingProgress) {
|
||||
self.progress.update_valid(&item);
|
||||
self.metrics_numeric.update_progress_valid(&item);
|
||||
self.status.update_valid(item);
|
||||
self.render().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl TuiDashboardRenderer {
|
||||
/// Create a new CLI dashboard renderer.
|
||||
pub fn new(interuptor: TrainingInterrupter) -> Self {
|
||||
let mut stdout = io::stdout();
|
||||
execute!(stdout, EnterAlternateScreen).unwrap();
|
||||
enable_raw_mode().unwrap();
|
||||
let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap();
|
||||
|
||||
Self {
|
||||
terminal,
|
||||
last_update: Instant::now(),
|
||||
progress: ProgressBarState::default(),
|
||||
metrics_numeric: NumericMetricsState::default(),
|
||||
metrics_text: TextMetricsState::default(),
|
||||
status: StatusState::default(),
|
||||
interuptor,
|
||||
popup: PopupState::Empty,
|
||||
}
|
||||
}
|
||||
|
||||
fn render(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
|
||||
if self.last_update.elapsed() < tick_rate {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.draw()?;
|
||||
self.handle_events()?;
|
||||
|
||||
self.last_update = Instant::now();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn draw(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
self.terminal.draw(|frame| {
|
||||
let size = frame.size();
|
||||
|
||||
match self.popup.view() {
|
||||
Some(view) => view.render(frame, size),
|
||||
None => {
|
||||
let view = DashboardView::new(
|
||||
self.metrics_numeric.view(),
|
||||
self.metrics_text.view(),
|
||||
self.progress.view(),
|
||||
ControlsView,
|
||||
self.status.view(),
|
||||
);
|
||||
|
||||
view.render(frame, size);
|
||||
}
|
||||
};
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_events(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
while event::poll(Duration::from_secs(0))? {
|
||||
let event = event::read()?;
|
||||
self.popup.on_event(&event);
|
||||
|
||||
if self.popup.is_empty() {
|
||||
self.metrics_numeric.on_event(&event);
|
||||
|
||||
if let Event::Key(key) = event {
|
||||
if let KeyCode::Char('q') = key.code {
|
||||
self.popup = PopupState::Full(
|
||||
"Quit".to_string(),
|
||||
vec![
|
||||
Callback::new(
|
||||
"Stop the training.",
|
||||
"Stop the training immediately. This will break from the training loop, but any remaining code after the loop will be executed.",
|
||||
's',
|
||||
QuitPopupAccept(self.interuptor.clone()),
|
||||
),
|
||||
Callback::new(
|
||||
"Stop the training immediately.",
|
||||
"Kill the program. This will create a panic! which will make the current training fails. Any code following the training won't be executed.",
|
||||
'k',
|
||||
KillPopupAccept,
|
||||
),
|
||||
Callback::new("Cancel", "Cancel the action, continue the training.", 'c', PopupCancel),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct QuitPopupAccept(TrainingInterrupter);
|
||||
struct KillPopupAccept;
|
||||
struct PopupCancel;
|
||||
|
||||
impl CallbackFn for KillPopupAccept {
|
||||
fn call(&self) -> bool {
|
||||
panic!("Killing training from user input.");
|
||||
}
|
||||
}
|
||||
|
||||
impl CallbackFn for QuitPopupAccept {
|
||||
fn call(&self) -> bool {
|
||||
self.0.stop();
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl CallbackFn for PopupCancel {
|
||||
fn call(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TuiDashboardRenderer {
|
||||
fn drop(&mut self) {
|
||||
disable_raw_mode().ok();
|
||||
execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
|
||||
self.terminal.show_cursor().ok();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
use super::TerminalFrame;
|
||||
use crate::metric::dashboard::TrainingProgress;
|
||||
use ratatui::{
|
||||
prelude::{Alignment, Rect},
|
||||
style::{Color, Style, Stylize},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph, Wrap},
|
||||
};
|
||||
|
||||
/// Show the training status with various information.
|
||||
pub(crate) struct StatusState {
|
||||
progress: TrainingProgress,
|
||||
mode: Mode,
|
||||
}
|
||||
|
||||
enum Mode {
|
||||
Valid,
|
||||
Train,
|
||||
}
|
||||
|
||||
impl Default for StatusState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
progress: TrainingProgress::none(),
|
||||
mode: Mode::Train,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StatusState {
|
||||
/// Update the training information.
|
||||
pub(crate) fn update_train(&mut self, progress: TrainingProgress) {
|
||||
self.progress = progress;
|
||||
self.mode = Mode::Train;
|
||||
}
|
||||
/// Update the validation information.
|
||||
pub(crate) fn update_valid(&mut self, progress: TrainingProgress) {
|
||||
self.progress = progress;
|
||||
self.mode = Mode::Valid;
|
||||
}
|
||||
/// Create a view.
|
||||
pub(crate) fn view(&self) -> StatusView {
|
||||
StatusView::new(&self.progress, &self.mode)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct StatusView {
|
||||
lines: Vec<Vec<Span<'static>>>,
|
||||
}
|
||||
|
||||
impl StatusView {
|
||||
fn new(progress: &TrainingProgress, mode: &Mode) -> Self {
|
||||
let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow();
|
||||
let value = |value: String| Span::from(value).italic();
|
||||
let mode = match mode {
|
||||
Mode::Valid => "Validating",
|
||||
Mode::Train => "Training",
|
||||
};
|
||||
|
||||
Self {
|
||||
lines: vec![
|
||||
vec![title("Mode :"), value(mode.to_string())],
|
||||
vec![
|
||||
title("Epoch :"),
|
||||
value(format!("{}/{}", progress.epoch, progress.epoch_total)),
|
||||
],
|
||||
vec![
|
||||
title("Iteration :"),
|
||||
value(format!("{}", progress.iteration)),
|
||||
],
|
||||
vec![
|
||||
title("Items :"),
|
||||
value(format!(
|
||||
"{}/{}",
|
||||
progress.progress.items_processed, progress.progress.items_total
|
||||
)),
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {
|
||||
let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())
|
||||
.alignment(Alignment::Left)
|
||||
.block(Block::default().borders(Borders::ALL).title("Status"))
|
||||
.wrap(Wrap { trim: false })
|
||||
.style(Style::default().fg(Color::Gray));
|
||||
|
||||
frame.render_widget(paragraph, size);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use super::{MetricEntry, Numeric};
|
||||
use super::{format_float, MetricEntry, Numeric};
|
||||
|
||||
/// Usefull utility to implement numeric metrics.
|
||||
///
|
||||
|
@ -70,20 +70,10 @@ impl NumericMetricState {
|
|||
let serialized = value_current.to_string();
|
||||
|
||||
let (formatted_current, formatted_running) = match format.precision {
|
||||
Some(precision) => {
|
||||
let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
|
||||
|
||||
(
|
||||
match scientific_notation_threshold >= value_current {
|
||||
true => format!("{value_current:.precision$e}"),
|
||||
false => format!("{value_current:.precision$}"),
|
||||
},
|
||||
match scientific_notation_threshold >= value_running {
|
||||
true => format!("{value_running:.precision$e}"),
|
||||
false => format!("{value_running:.precision$}"),
|
||||
},
|
||||
)
|
||||
}
|
||||
Some(precision) => (
|
||||
format_float(value_current, precision),
|
||||
format_float(value_running, precision),
|
||||
),
|
||||
None => (format!("{value_current}"), format!("{value_running}")),
|
||||
};
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ train = ["burn-train/default", "autodiff", "dataset"]
|
|||
train-minimal = ["burn-train"]
|
||||
|
||||
## Includes the Text UI (progress bars, metric plots)
|
||||
train-ui = ["burn-train/ui"]
|
||||
train-tui = ["burn-train/tui"]
|
||||
|
||||
## Includes system info metrics (CPU/GPU usage, etc)
|
||||
train-metrics = ["burn-train/metrics"]
|
||||
|
|
Loading…
Reference in New Issue