mirror of https://github.com/tracel-ai/burn.git
Add configurable application logger to learner builder (#1774)
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct * Remove unused log module and renames, fmt * Renamed tracing subscriber logger * renamed to application logger installer * book learner configuration update * fix typo * unused import
This commit is contained in:
parent
7ab2ba1809
commit
8de05e1419
|
@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
|
|||
| Num Epochs | Set the number of epochs. |
|
||||
| Devices | Set the devices to be used |
|
||||
| Checkpoint | Restart training from a checkpoint |
|
||||
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
|
||||
|
||||
When the builder is configured at your liking, you can then move forward to build the learner. The
|
||||
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
use std::path::Path;
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::filter::filter_fn;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{registry, Layer};
|
||||
|
||||
/// Hook for wiring up an application-level logger.
///
/// Implementations perform whatever one-time setup their logging backend
/// needs (e.g. opening a file and registering a subscriber).
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// Returns a human-readable error message when installation fails.
    fn install(&self) -> Result<(), String>;
}
|
||||
|
||||
/// Application-logger installer that writes log output to a local file.
pub struct FileApplicationLoggerInstaller {
    // Destination file for the log output (relative or absolute path).
    path: String,
}

impl FileApplicationLoggerInstaller {
    /// Create a new file application logger.
    ///
    /// `path` is the file the logs will be written to.
    pub fn new(path: &str) -> Self {
        let path = path.to_owned();
        Self { path }
    }
}
|
||||
|
||||
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
|
||||
fn install(&self) -> Result<(), String> {
|
||||
let path = Path::new(&self.path);
|
||||
let writer = tracing_appender::rolling::never(
|
||||
path.parent().unwrap_or_else(|| Path::new(".")),
|
||||
path.file_name()
|
||||
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
|
||||
);
|
||||
let layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
.with_writer(writer)
|
||||
.with_filter(LevelFilter::INFO)
|
||||
.with_filter(filter_fn(|m| {
|
||||
if let Some(path) = m.module_path() {
|
||||
// The wgpu crate is logging too much, so we skip `info` level.
|
||||
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}));
|
||||
|
||||
if registry().with(layer).try_init().is_err() {
|
||||
return Err("Failed to install the file logger.".to_string());
|
||||
}
|
||||
|
||||
let hook = std::panic::take_hook();
|
||||
let file_path: String = self.path.to_owned();
|
||||
|
||||
std::panic::set_hook(Box::new(move |info| {
|
||||
log::error!("PANIC => {}", info.to_string());
|
||||
eprintln!(
|
||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
|
||||
'{file_path}'\n============="
|
||||
);
|
||||
hook(info);
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
use std::collections::HashSet;
|
||||
use std::rc::Rc;
|
||||
|
||||
use super::log::install_file_logger;
|
||||
use super::Learner;
|
||||
use crate::checkpoint::{
|
||||
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
||||
|
@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
|
|||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||
use crate::metric::{Adaptor, LossMetric, Metric};
|
||||
use crate::renderer::{default_renderer, MetricsRenderer};
|
||||
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
|
||||
use crate::{
|
||||
ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
|
||||
LearnerSummaryConfig,
|
||||
};
|
||||
use burn_core::lr_scheduler::LrScheduler;
|
||||
use burn_core::module::AutodiffModule;
|
||||
use burn_core::optim::Optimizer;
|
||||
|
@ -50,7 +52,7 @@ where
|
|||
metrics: Metrics<T, V>,
|
||||
event_store: LogEventStore,
|
||||
interrupter: TrainingInterrupter,
|
||||
log_to_file: bool,
|
||||
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||
num_loggers: usize,
|
||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
|
||||
|
@ -84,7 +86,9 @@ where
|
|||
event_store: LogEventStore::default(),
|
||||
renderer: None,
|
||||
interrupter: TrainingInterrupter::new(),
|
||||
log_to_file: true,
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||
format!("{}/experiment.log", directory).as_str(),
|
||||
))),
|
||||
num_loggers: 0,
|
||||
checkpointer_strategy: Box::new(
|
||||
ComposedCheckpointingStrategy::builder()
|
||||
|
@ -233,8 +237,11 @@ where
|
|||
/// By default, Rust logs are captured and written into
|
||||
/// `experiment.log`. If disabled, standard Rust log handling
|
||||
/// will apply.
|
||||
pub fn log_to_file(mut self, enabled: bool) -> Self {
|
||||
self.log_to_file = enabled;
|
||||
/// Set the application logger installer to use, replacing the default
/// file-based one; pass `None` to install no application logger at all.
pub fn with_application_logger(
    mut self,
    logger: Option<Box<dyn ApplicationLoggerInstaller>>,
) -> Self {
    self.tracing_logger = logger;
    self
}
|
||||
|
||||
|
@ -258,7 +265,7 @@ where
|
|||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"optim",
|
||||
);
|
||||
let checkpointer_scheduler = FileCheckpointer::new(
|
||||
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
|
||||
recorder,
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"scheduler",
|
||||
|
@ -309,8 +316,10 @@ where
|
|||
O::Record: 'static,
|
||||
S::Record: 'static,
|
||||
{
|
||||
if self.log_to_file {
|
||||
self.init_logger();
|
||||
if self.tracing_logger.is_some() {
|
||||
if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
|
||||
log::warn!("Failed to install the experiment logger: {}", e);
|
||||
}
|
||||
}
|
||||
let renderer = self.renderer.unwrap_or_else(|| {
|
||||
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
||||
|
@ -360,9 +369,4 @@ where
|
|||
summary,
|
||||
}
|
||||
}
|
||||
|
||||
fn init_logger(&self) {
|
||||
let file_path = format!("{}/experiment.log", self.directory);
|
||||
install_file_logger(file_path.as_str());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
use std::path::Path;
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::filter::filter_fn;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{registry, Layer};
|
||||
|
||||
/// If a global tracing subscriber is not already configured, set up logging to a file,
/// and add our custom panic hook.
pub(crate) fn install_file_logger(file_path: &str) {
    let path = Path::new(file_path);
    // `rolling::never` appends to a single file (no rotation); it takes the
    // directory and the file name as separate arguments.
    let writer = tracing_appender::rolling::never(
        path.parent().unwrap_or_else(|| Path::new(".")),
        path.file_name()
            // NOTE(review): panics when the path has no file-name component;
            // the message wording looks truncated — confirm intended text.
            .unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")),
    );
    // Plain-text layer (no ANSI color codes), filtered to INFO and above.
    let layer = tracing_subscriber::fmt::layer()
        .with_ansi(false)
        .with_writer(writer)
        .with_filter(LevelFilter::INFO)
        .with_filter(filter_fn(|m| {
            if let Some(path) = m.module_path() {
                // The wgpu crate is logging too much, so we skip `info` level.
                if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                    return false;
                }
            }
            true
        }));

    // Only take over the panic hook when we actually own the global
    // subscriber; `try_init` returns Err if one was installed elsewhere.
    if registry().with(layer).try_init().is_ok() {
        update_panic_hook(file_path);
    }
}
|
||||
|
||||
/// Chain a custom hook in front of the current panic hook: the panic is
/// recorded in the log file and a pointer to that file is printed to stderr,
/// then the previous hook runs.
fn update_panic_hook(file_path: &str) {
    let hook = std::panic::take_hook();
    // Owned copy so the 'static closure below can capture it by move.
    let file_path = file_path.to_owned();

    std::panic::set_hook(Box::new(move |info| {
        // Record the panic in the experiment log first...
        log::error!("PANIC => {}", info.to_string());
        // ...then tell the user on stderr where the details landed.
        eprintln!(
            "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
            '{file_path}'\n============="
        );
        // Preserve the previous hook's behavior (e.g. default backtrace print).
        hook(info);
    }));
}
|
|
@ -1,3 +1,4 @@
|
|||
mod application_logger;
|
||||
mod base;
|
||||
mod builder;
|
||||
mod classification;
|
||||
|
@ -8,8 +9,7 @@ mod step;
|
|||
mod summary;
|
||||
mod train_val;
|
||||
|
||||
pub(crate) mod log;
|
||||
|
||||
pub use application_logger::*;
|
||||
pub use base::*;
|
||||
pub use builder::*;
|
||||
pub use classification::*;
|
||||
|
|
|
@ -76,7 +76,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
|
|||
.devices(vec![device])
|
||||
.num_epochs(config.num_epochs)
|
||||
.renderer(CustomRenderer {})
|
||||
.log_to_file(false);
|
||||
.with_application_logger(None);
|
||||
// can be used to interrupt training
|
||||
let _interrupter = builder.interrupter();
|
||||
|
||||
|
|
Loading…
Reference in New Issue