diff --git a/burn-book/src/building-blocks/learner.md b/burn-book/src/building-blocks/learner.md
index d9bfcf750..bae1a4fed 100644
--- a/burn-book/src/building-blocks/learner.md
+++ b/burn-book/src/building-blocks/learner.md
@@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
 | Num Epochs          | Set the number of epochs.                                                             |
 | Devices             | Set the devices to be used                                                            |
 | Checkpoint          | Restart training from a checkpoint                                                    |
+| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
 
 When the builder is configured to your liking, you can then move forward to build the learner. The
 build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
diff --git a/crates/burn-train/src/learner/application_logger.rs b/crates/burn-train/src/learner/application_logger.rs
new file mode 100644
index 000000000..793ac6ada
--- /dev/null
+++ b/crates/burn-train/src/learner/application_logger.rs
@@ -0,0 +1,67 @@
+use std::path::Path;
+use tracing_core::{Level, LevelFilter};
+use tracing_subscriber::filter::filter_fn;
+use tracing_subscriber::prelude::*;
+use tracing_subscriber::{registry, Layer};
+
+/// This trait is used to install an application logger.
+pub trait ApplicationLoggerInstaller {
+    /// Install the application logger.
+    fn install(&self) -> Result<(), String>;
+}
+
+/// This struct is used to install a local file application logger to output logs to a given file path.
+pub struct FileApplicationLoggerInstaller {
+    path: String,
+}
+
+impl FileApplicationLoggerInstaller {
+    /// Create a new file application logger.
+    pub fn new(path: &str) -> Self {
+        Self {
+            path: path.to_string(),
+        }
+    }
+}
+
+impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
+    fn install(&self) -> Result<(), String> {
+        let path = Path::new(&self.path);
+        let writer = tracing_appender::rolling::never(
+            path.parent().unwrap_or_else(|| Path::new(".")),
+            path.file_name()
+                .unwrap_or_else(|| panic!("The path '{}' should point to a file.", self.path)),
+        );
+        let layer = tracing_subscriber::fmt::layer()
+            .with_ansi(false)
+            .with_writer(writer)
+            .with_filter(LevelFilter::INFO)
+            .with_filter(filter_fn(|m| {
+                if let Some(path) = m.module_path() {
+                    // The wgpu crate is logging too much, so we skip `info` level.
+                    if path.starts_with("wgpu") && *m.level() >= Level::INFO {
+                        return false;
+                    }
+                }
+                true
+            }));
+
+        if registry().with(layer).try_init().is_err() {
+            return Err("Failed to install the file logger.".to_string());
+        }
+
+        let hook = std::panic::take_hook();
+        let file_path: String = self.path.to_owned();
+
+        std::panic::set_hook(Box::new(move |info| {
+            log::error!("PANIC => {}", info.to_string());
+            eprintln!(
+                "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
+                 '{file_path}'\n============="
+            );
+            hook(info);
+        }));
+
+        Ok(())
+    }
+}
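Any type implementing `ApplicationLoggerInstaller` can now be plugged into the learner. As a minimal sketch of a custom installer, assuming the trait is re-exported at the crate root as the `mod.rs` change below suggests (the `StdoutLoggerInstaller` type is hypothetical and not part of this change):

```rust
use burn_train::ApplicationLoggerInstaller;
use tracing_core::LevelFilter;
use tracing_subscriber::prelude::*;
use tracing_subscriber::registry;

/// Hypothetical installer that writes application logs to stdout
/// instead of a file. Illustrative only; not part of this diff.
pub struct StdoutLoggerInstaller;

impl ApplicationLoggerInstaller for StdoutLoggerInstaller {
    fn install(&self) -> Result<(), String> {
        let layer = tracing_subscriber::fmt::layer()
            .with_writer(std::io::stdout)
            .with_filter(LevelFilter::INFO);

        // Fail gracefully when a global subscriber is already set,
        // mirroring the file installer above.
        registry()
            .with(layer)
            .try_init()
            .map_err(|_| "Failed to install the stdout logger.".to_string())
    }
}
```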
diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs
index cf8c5a92a..c8c4de629 100644
--- a/crates/burn-train/src/learner/builder.rs
+++ b/crates/burn-train/src/learner/builder.rs
@@ -1,7 +1,6 @@
 use std::collections::HashSet;
 use std::rc::Rc;
 
-use super::log::install_file_logger;
 use super::Learner;
 use crate::checkpoint::{
     AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
@@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
 use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
 use crate::metric::{Adaptor, LossMetric, Metric};
 use crate::renderer::{default_renderer, MetricsRenderer};
-use crate::{LearnerCheckpointer, LearnerSummaryConfig};
+use crate::{
+    ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
+    LearnerSummaryConfig,
+};
 use burn_core::lr_scheduler::LrScheduler;
 use burn_core::module::AutodiffModule;
 use burn_core::optim::Optimizer;
@@ -50,7 +52,7 @@ where
     metrics: Metrics<T, V>,
     event_store: LogEventStore,
     interrupter: TrainingInterrupter,
-    log_to_file: bool,
+    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
     num_loggers: usize,
     checkpointer_strategy: Box<dyn CheckpointingStrategy>,
     early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
@@ -84,7 +86,9 @@ where
             event_store: LogEventStore::default(),
             renderer: None,
             interrupter: TrainingInterrupter::new(),
-            log_to_file: true,
+            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
+                format!("{}/experiment.log", directory).as_str(),
+            ))),
             num_loggers: 0,
             checkpointer_strategy: Box::new(
                 ComposedCheckpointingStrategy::builder()
@@ -233,8 +237,11 @@ where
     /// By default, Rust logs are captured and written into
     /// `experiment.log`. If disabled, standard Rust log handling
     /// will apply.
-    pub fn log_to_file(mut self, enabled: bool) -> Self {
-        self.log_to_file = enabled;
+    pub fn with_application_logger(
+        mut self,
+        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
+    ) -> Self {
+        self.tracing_logger = logger;
         self
     }
 
@@ -258,7 +265,7 @@ where
             format!("{}/checkpoint", self.directory).as_str(),
             "optim",
         );
-        let checkpointer_scheduler = FileCheckpointer::new(
+        let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
             recorder,
             format!("{}/checkpoint", self.directory).as_str(),
             "scheduler",
@@ -309,8 +316,10 @@ where
         O::Record: 'static,
         S::Record: 'static,
     {
-        if self.log_to_file {
-            self.init_logger();
+        if self.tracing_logger.is_some() {
+            if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
+                log::warn!("Failed to install the experiment logger: {}", e);
+            }
         }
         let renderer = self.renderer.unwrap_or_else(|| {
             Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
@@ -360,9 +369,4 @@ where
             summary,
         }
     }
-
-    fn init_logger(&self) {
-        let file_path = format!("{}/experiment.log", self.directory);
-        install_file_logger(file_path.as_str());
-    }
 }
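On the caller side, the old boolean flag maps onto an `Option` of installers. A rough migration sketch, assuming a `LearnerBuilder` already in scope as `builder` and a placeholder log path:

```rust
use burn_train::FileApplicationLoggerInstaller;

// Route application logs to a custom file instead of the default
// `experiment.log` ("/tmp/artifacts/custom.log" is a placeholder).
let builder = builder.with_application_logger(Some(Box::new(
    FileApplicationLoggerInstaller::new("/tmp/artifacts/custom.log"),
)));

// Or disable application logging entirely, the equivalent of the
// old `.log_to_file(false)`.
let builder = builder.with_application_logger(None);
```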
diff --git a/crates/burn-train/src/learner/log.rs b/crates/burn-train/src/learner/log.rs
deleted file mode 100644
index 86cc4893d..000000000
--- a/crates/burn-train/src/learner/log.rs
+++ /dev/null
@@ -1,47 +0,0 @@
-use std::path::Path;
-use tracing_core::{Level, LevelFilter};
-use tracing_subscriber::filter::filter_fn;
-use tracing_subscriber::prelude::*;
-use tracing_subscriber::{registry, Layer};
-
-/// If a global tracing subscriber is not already configured, set up logging to a file,
-/// and add our custom panic hook.
-pub(crate) fn install_file_logger(file_path: &str) {
-    let path = Path::new(file_path);
-    let writer = tracing_appender::rolling::never(
-        path.parent().unwrap_or_else(|| Path::new(".")),
-        path.file_name()
-            .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")),
-    );
-    let layer = tracing_subscriber::fmt::layer()
-        .with_ansi(false)
-        .with_writer(writer)
-        .with_filter(LevelFilter::INFO)
-        .with_filter(filter_fn(|m| {
-            if let Some(path) = m.module_path() {
-                // The wgpu crate is logging too much, so we skip `info` level.
-                if path.starts_with("wgpu") && *m.level() >= Level::INFO {
-                    return false;
-                }
-            }
-            true
-        }));
-
-    if registry().with(layer).try_init().is_ok() {
-        update_panic_hook(file_path);
-    }
-}
-
-fn update_panic_hook(file_path: &str) {
-    let hook = std::panic::take_hook();
-    let file_path = file_path.to_owned();
-
-    std::panic::set_hook(Box::new(move |info| {
-        log::error!("PANIC => {}", info.to_string());
-        eprintln!(
-            "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
-             '{file_path}'\n============="
-        );
-        hook(info);
-    }));
-}
diff --git a/crates/burn-train/src/learner/mod.rs b/crates/burn-train/src/learner/mod.rs
index be4cc258b..1df31701b 100644
--- a/crates/burn-train/src/learner/mod.rs
+++ b/crates/burn-train/src/learner/mod.rs
@@ -1,3 +1,4 @@
+mod application_logger;
 mod base;
 mod builder;
 mod classification;
@@ -8,8 +9,7 @@ mod step;
 mod summary;
 mod train_val;
 
-pub(crate) mod log;
-
+pub use application_logger::*;
 pub use base::*;
 pub use builder::*;
 pub use classification::*;
diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs
index ab26a4531..313396854 100644
--- a/examples/custom-renderer/src/lib.rs
+++ b/examples/custom-renderer/src/lib.rs
@@ -76,7 +76,7 @@ pub fn run(device: B::Device) {
         .devices(vec![device])
         .num_epochs(config.num_epochs)
         .renderer(CustomRenderer {})
-        .log_to_file(false);
+        .with_application_logger(None);
 
     // can be used to interrupt training
     let _interrupter = builder.interrupter();
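Because the installer is now an ordinary public trait object, it can also be driven directly outside of a `Learner`. A minimal standalone sketch (the log path is a placeholder):

```rust
use burn_train::{ApplicationLoggerInstaller, FileApplicationLoggerInstaller};

fn main() {
    // Install the same file logger the builder wires up by default
    // ("/tmp/experiment.log" is a placeholder path).
    let installer = FileApplicationLoggerInstaller::new("/tmp/experiment.log");
    if let Err(e) = installer.install() {
        eprintln!("could not install the application logger: {e}");
    }

    // From here on, `log`/`tracing` records at INFO and above are written
    // to the file, and panics are reported through the installed hook.
    log::info!("application logger installed");
}
```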