Add configurable application logger to learner builder (#1774)

* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct

* Remove unused log module and renames, fmt

* Renamed tracing subscriber logger

* renamed to application logger installer

* book learner configuration update

* fix typo

* unused import
This commit is contained in:
Jonathan Richard 2024-05-16 16:25:33 -04:00 committed by GitHub
parent 7ab2ba1809
commit 8de05e1419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 89 additions and 64 deletions

View File

@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
| Num Epochs | Set the number of epochs. | | Num Epochs | Set the number of epochs. |
| Devices | Set the devices to be used | | Devices | Set the devices to be used |
| Checkpoint | Restart training from a checkpoint | | Checkpoint | Restart training from a checkpoint |
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
When the builder is configured at your liking, you can then move forward to build the learner. The When the builder is configured at your liking, you can then move forward to build the learner. The
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note

View File

@ -0,0 +1,67 @@
use std::path::Path;
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{registry, Layer};
/// This trait is used to install an application logger.
///
/// Implementors set up a global logging destination (e.g. a file) for the
/// training process; `install` is expected to be called once, before training.
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// Returns `Err` with a human-readable message if installation fails
    /// (for example, when a global subscriber is already registered).
    fn install(&self) -> Result<(), String>;
}
/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
    // Full path to the log file; its parent directory and file name are
    // split apart when the file appender is created in `install`.
    path: String,
}
impl FileApplicationLoggerInstaller {
    /// Create a new installer that writes application logs to `path`.
    pub fn new(path: &str) -> Self {
        let path = String::from(path);
        Self { path }
    }
}
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
    /// Install a tracing subscriber writing to the configured file, and chain
    /// a panic hook that records the panic in the log and points the user at
    /// the log file.
    fn install(&self) -> Result<(), String> {
        let path = Path::new(&self.path);
        // `rolling::never` produces a single, non-rotating log file:
        // parent directory + file name are taken from `self.path`.
        let writer = tracing_appender::rolling::never(
            path.parent().unwrap_or_else(|| Path::new(".")),
            path.file_name()
                // Fixed message grammar: was "The path '{}' to point to a file."
                .unwrap_or_else(|| panic!("The path '{}' should point to a file.", self.path)),
        );
        let layer = tracing_subscriber::fmt::layer()
            .with_ansi(false)
            .with_writer(writer)
            .with_filter(LevelFilter::INFO)
            .with_filter(filter_fn(|m| {
                if let Some(path) = m.module_path() {
                    // The wgpu crate is logging too much, so we skip `info` level.
                    if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                        return false;
                    }
                }
                true
            }));

        // `try_init` fails when a global subscriber is already set; surface
        // that as an error instead of panicking so the caller can decide.
        registry()
            .with(layer)
            .try_init()
            .map_err(|_| "Failed to install the file logger.".to_string())?;

        // Wrap the existing panic hook: log the panic, tell the user where
        // the experiment log lives, then delegate to the previous hook.
        let hook = std::panic::take_hook();
        let file_path: String = self.path.to_owned();
        std::panic::set_hook(Box::new(move |info| {
            log::error!("PANIC => {}", info.to_string());
            eprintln!(
                "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
                '{file_path}'\n============="
            );
            hook(info);
        }));

        Ok(())
    }
}

View File

@ -1,7 +1,6 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::rc::Rc; use std::rc::Rc;
use super::log::install_file_logger;
use super::Learner; use super::Learner;
use crate::checkpoint::{ use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split}; use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric}; use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer}; use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{LearnerCheckpointer, LearnerSummaryConfig}; use crate::{
ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
LearnerSummaryConfig,
};
use burn_core::lr_scheduler::LrScheduler; use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule; use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer; use burn_core::optim::Optimizer;
@ -50,7 +52,7 @@ where
metrics: Metrics<T, V>, metrics: Metrics<T, V>,
event_store: LogEventStore, event_store: LogEventStore,
interrupter: TrainingInterrupter, interrupter: TrainingInterrupter,
log_to_file: bool, tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
num_loggers: usize, num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy>, checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>, early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
@ -84,7 +86,9 @@ where
event_store: LogEventStore::default(), event_store: LogEventStore::default(),
renderer: None, renderer: None,
interrupter: TrainingInterrupter::new(), interrupter: TrainingInterrupter::new(),
log_to_file: true, tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
format!("{}/experiment.log", directory).as_str(),
))),
num_loggers: 0, num_loggers: 0,
checkpointer_strategy: Box::new( checkpointer_strategy: Box::new(
ComposedCheckpointingStrategy::builder() ComposedCheckpointingStrategy::builder()
@ -233,8 +237,11 @@ where
/// By default, Rust logs are captured and written into /// By default, Rust logs are captured and written into
/// `experiment.log`. If disabled, standard Rust log handling /// `experiment.log`. If disabled, standard Rust log handling
/// will apply. /// will apply.
pub fn log_to_file(mut self, enabled: bool) -> Self { pub fn with_application_logger(
self.log_to_file = enabled; mut self,
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
self.tracing_logger = logger;
self self
} }
@ -258,7 +265,7 @@ where
format!("{}/checkpoint", self.directory).as_str(), format!("{}/checkpoint", self.directory).as_str(),
"optim", "optim",
); );
let checkpointer_scheduler = FileCheckpointer::new( let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
recorder, recorder,
format!("{}/checkpoint", self.directory).as_str(), format!("{}/checkpoint", self.directory).as_str(),
"scheduler", "scheduler",
@ -309,8 +316,10 @@ where
O::Record: 'static, O::Record: 'static,
S::Record: 'static, S::Record: 'static,
{ {
if self.log_to_file { if self.tracing_logger.is_some() {
self.init_logger(); if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
log::warn!("Failed to install the experiment logger: {}", e);
}
} }
let renderer = self.renderer.unwrap_or_else(|| { let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint)) Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
@ -360,9 +369,4 @@ where
summary, summary,
} }
} }
fn init_logger(&self) {
let file_path = format!("{}/experiment.log", self.directory);
install_file_logger(file_path.as_str());
}
} }

View File

@ -1,47 +0,0 @@
use std::path::Path;
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{registry, Layer};
/// If a global tracing subscriber is not already configured, set up logging to a file,
/// and add our custom panic hook.
// NOTE(review): this function is removed by this commit, superseded by
// `FileApplicationLoggerInstaller::install` in `application_logger.rs`.
pub(crate) fn install_file_logger(file_path: &str) {
    // Split the path into directory + file name for the rolling appender.
    let path = Path::new(file_path);
    // `rolling::never` writes to a single fixed file (no rotation).
    let writer = tracing_appender::rolling::never(
        path.parent().unwrap_or_else(|| Path::new(".")),
        path.file_name()
            // NOTE(review): message grammar is broken — likely meant
            // "The path '{file_path}' should point to a file."
            .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")),
    );
    let layer = tracing_subscriber::fmt::layer()
        .with_ansi(false)
        .with_writer(writer)
        .with_filter(LevelFilter::INFO)
        .with_filter(filter_fn(|m| {
            if let Some(path) = m.module_path() {
                // The wgpu crate is logging too much, so we skip `info` level.
                if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                    return false;
                }
            }
            true
        }));
    // Only add the panic hook when we actually became the global subscriber;
    // if another subscriber is already installed, do nothing.
    if registry().with(layer).try_init().is_ok() {
        update_panic_hook(file_path);
    }
}
// Chain a custom panic hook in front of the existing one: log the panic via
// `log::error!` (so it lands in the file logger), print a pointer to the log
// file on stderr, then delegate to the previously-installed hook.
// NOTE(review): removed by this commit; the same logic now lives inline in
// `FileApplicationLoggerInstaller::install`.
fn update_panic_hook(file_path: &str) {
    let hook = std::panic::take_hook();
    // Move an owned copy of the path into the 'static hook closure.
    let file_path = file_path.to_owned();
    std::panic::set_hook(Box::new(move |info| {
        log::error!("PANIC => {}", info.to_string());
        eprintln!(
            "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
            '{file_path}'\n============="
        );
        hook(info);
    }));
}

View File

@ -1,3 +1,4 @@
mod application_logger;
mod base; mod base;
mod builder; mod builder;
mod classification; mod classification;
@ -8,8 +9,7 @@ mod step;
mod summary; mod summary;
mod train_val; mod train_val;
pub(crate) mod log; pub use application_logger::*;
pub use base::*; pub use base::*;
pub use builder::*; pub use builder::*;
pub use classification::*; pub use classification::*;

View File

@ -76,7 +76,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
.devices(vec![device]) .devices(vec![device])
.num_epochs(config.num_epochs) .num_epochs(config.num_epochs)
.renderer(CustomRenderer {}) .renderer(CustomRenderer {})
.log_to_file(false); .with_application_logger(None);
// can be used to interrupt training // can be used to interrupt training
let _interrupter = builder.interrupter(); let _interrupter = builder.interrupter();