mirror of https://github.com/tracel-ai/burn.git
Add configurable application logger to learner builder (#1774)
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct * Remove unused log module and renames, fmt * Renamed tracing subscriber logger * renamed to application logger installer * book learner configuration update * fix typo * unused import
This commit is contained in:
parent
7ab2ba1809
commit
8de05e1419
|
@ -31,6 +31,7 @@ The learner builder provides numerous options when it comes to configurations.
|
||||||
| Num Epochs | Set the number of epochs. |
|
| Num Epochs | Set the number of epochs. |
|
||||||
| Devices | Set the devices to be used |
|
| Devices | Set the devices to be used |
|
||||||
| Checkpoint | Restart training from a checkpoint |
|
| Checkpoint | Restart training from a checkpoint |
|
||||||
|
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
|
||||||
|
|
||||||
When the builder is configured at your liking, you can then move forward to build the learner. The
|
When the builder is configured at your liking, you can then move forward to build the learner. The
|
||||||
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
|
build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
use std::path::Path;
|
||||||
|
use tracing_core::{Level, LevelFilter};
|
||||||
|
use tracing_subscriber::filter::filter_fn;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
use tracing_subscriber::{registry, Layer};
|
||||||
|
|
||||||
|
/// Abstraction over the installation of an application logger.
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// # Errors
    ///
    /// When installation fails, the returned `Err` carries a textual description of the problem.
    fn install(&self) -> Result<(), String>;
}
|
||||||
|
|
||||||
|
/// Application logger installer that sends log output to a local file at a given path.
pub struct FileApplicationLoggerInstaller {
    /// Location of the log file, exactly as it was passed to `new`.
    path: String,
}

impl FileApplicationLoggerInstaller {
    /// Create a file application logger installer targeting `path`.
    ///
    /// The path is only stored here; it is not inspected or opened by this constructor.
    pub fn new(path: &str) -> Self {
        let path = String::from(path);
        Self { path }
    }
}
|
||||||
|
|
||||||
|
impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
|
||||||
|
fn install(&self) -> Result<(), String> {
|
||||||
|
let path = Path::new(&self.path);
|
||||||
|
let writer = tracing_appender::rolling::never(
|
||||||
|
path.parent().unwrap_or_else(|| Path::new(".")),
|
||||||
|
path.file_name()
|
||||||
|
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
|
||||||
|
);
|
||||||
|
let layer = tracing_subscriber::fmt::layer()
|
||||||
|
.with_ansi(false)
|
||||||
|
.with_writer(writer)
|
||||||
|
.with_filter(LevelFilter::INFO)
|
||||||
|
.with_filter(filter_fn(|m| {
|
||||||
|
if let Some(path) = m.module_path() {
|
||||||
|
// The wgpu crate is logging too much, so we skip `info` level.
|
||||||
|
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}));
|
||||||
|
|
||||||
|
if registry().with(layer).try_init().is_err() {
|
||||||
|
return Err("Failed to install the file logger.".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let hook = std::panic::take_hook();
|
||||||
|
let file_path: String = self.path.to_owned();
|
||||||
|
|
||||||
|
std::panic::set_hook(Box::new(move |info| {
|
||||||
|
log::error!("PANIC => {}", info.to_string());
|
||||||
|
eprintln!(
|
||||||
|
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
|
||||||
|
'{file_path}'\n============="
|
||||||
|
);
|
||||||
|
hook(info);
|
||||||
|
}));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
use super::log::install_file_logger;
|
|
||||||
use super::Learner;
|
use super::Learner;
|
||||||
use crate::checkpoint::{
|
use crate::checkpoint::{
|
||||||
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
|
||||||
|
@ -15,7 +14,10 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
|
||||||
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
|
||||||
use crate::metric::{Adaptor, LossMetric, Metric};
|
use crate::metric::{Adaptor, LossMetric, Metric};
|
||||||
use crate::renderer::{default_renderer, MetricsRenderer};
|
use crate::renderer::{default_renderer, MetricsRenderer};
|
||||||
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
|
use crate::{
|
||||||
|
ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
|
||||||
|
LearnerSummaryConfig,
|
||||||
|
};
|
||||||
use burn_core::lr_scheduler::LrScheduler;
|
use burn_core::lr_scheduler::LrScheduler;
|
||||||
use burn_core::module::AutodiffModule;
|
use burn_core::module::AutodiffModule;
|
||||||
use burn_core::optim::Optimizer;
|
use burn_core::optim::Optimizer;
|
||||||
|
@ -50,7 +52,7 @@ where
|
||||||
metrics: Metrics<T, V>,
|
metrics: Metrics<T, V>,
|
||||||
event_store: LogEventStore,
|
event_store: LogEventStore,
|
||||||
interrupter: TrainingInterrupter,
|
interrupter: TrainingInterrupter,
|
||||||
log_to_file: bool,
|
tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||||
num_loggers: usize,
|
num_loggers: usize,
|
||||||
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
|
||||||
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
|
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
|
||||||
|
@ -84,7 +86,9 @@ where
|
||||||
event_store: LogEventStore::default(),
|
event_store: LogEventStore::default(),
|
||||||
renderer: None,
|
renderer: None,
|
||||||
interrupter: TrainingInterrupter::new(),
|
interrupter: TrainingInterrupter::new(),
|
||||||
log_to_file: true,
|
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||||
|
format!("{}/experiment.log", directory).as_str(),
|
||||||
|
))),
|
||||||
num_loggers: 0,
|
num_loggers: 0,
|
||||||
checkpointer_strategy: Box::new(
|
checkpointer_strategy: Box::new(
|
||||||
ComposedCheckpointingStrategy::builder()
|
ComposedCheckpointingStrategy::builder()
|
||||||
|
@ -233,8 +237,11 @@ where
|
||||||
/// By default, Rust logs are captured and written into
|
/// By default, Rust logs are captured and written into
|
||||||
/// `experiment.log`. If disabled, standard Rust log handling
|
/// `experiment.log`. If disabled, standard Rust log handling
|
||||||
/// will apply.
|
/// will apply.
|
||||||
pub fn log_to_file(mut self, enabled: bool) -> Self {
|
pub fn with_application_logger(
|
||||||
self.log_to_file = enabled;
|
mut self,
|
||||||
|
logger: Option<Box<dyn ApplicationLoggerInstaller>>,
|
||||||
|
) -> Self {
|
||||||
|
self.tracing_logger = logger;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,7 +265,7 @@ where
|
||||||
format!("{}/checkpoint", self.directory).as_str(),
|
format!("{}/checkpoint", self.directory).as_str(),
|
||||||
"optim",
|
"optim",
|
||||||
);
|
);
|
||||||
let checkpointer_scheduler = FileCheckpointer::new(
|
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
|
||||||
recorder,
|
recorder,
|
||||||
format!("{}/checkpoint", self.directory).as_str(),
|
format!("{}/checkpoint", self.directory).as_str(),
|
||||||
"scheduler",
|
"scheduler",
|
||||||
|
@ -309,8 +316,10 @@ where
|
||||||
O::Record: 'static,
|
O::Record: 'static,
|
||||||
S::Record: 'static,
|
S::Record: 'static,
|
||||||
{
|
{
|
||||||
if self.log_to_file {
|
if self.tracing_logger.is_some() {
|
||||||
self.init_logger();
|
if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
|
||||||
|
log::warn!("Failed to install the experiment logger: {}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let renderer = self.renderer.unwrap_or_else(|| {
|
let renderer = self.renderer.unwrap_or_else(|| {
|
||||||
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
||||||
|
@ -360,9 +369,4 @@ where
|
||||||
summary,
|
summary,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_logger(&self) {
|
|
||||||
let file_path = format!("{}/experiment.log", self.directory);
|
|
||||||
install_file_logger(file_path.as_str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
use std::path::Path;
|
|
||||||
use tracing_core::{Level, LevelFilter};
|
|
||||||
use tracing_subscriber::filter::filter_fn;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
use tracing_subscriber::{registry, Layer};
|
|
||||||
|
|
||||||
/// If a global tracing subscriber is not already configured, set up logging to a file,
|
|
||||||
/// and add our custom panic hook.
|
|
||||||
pub(crate) fn install_file_logger(file_path: &str) {
|
|
||||||
let path = Path::new(file_path);
|
|
||||||
let writer = tracing_appender::rolling::never(
|
|
||||||
path.parent().unwrap_or_else(|| Path::new(".")),
|
|
||||||
path.file_name()
|
|
||||||
.unwrap_or_else(|| panic!("The path '{file_path}' to point to a file.")),
|
|
||||||
);
|
|
||||||
let layer = tracing_subscriber::fmt::layer()
|
|
||||||
.with_ansi(false)
|
|
||||||
.with_writer(writer)
|
|
||||||
.with_filter(LevelFilter::INFO)
|
|
||||||
.with_filter(filter_fn(|m| {
|
|
||||||
if let Some(path) = m.module_path() {
|
|
||||||
// The wgpu crate is logging too much, so we skip `info` level.
|
|
||||||
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
true
|
|
||||||
}));
|
|
||||||
|
|
||||||
if registry().with(layer).try_init().is_ok() {
|
|
||||||
update_panic_hook(file_path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_panic_hook(file_path: &str) {
|
|
||||||
let hook = std::panic::take_hook();
|
|
||||||
let file_path = file_path.to_owned();
|
|
||||||
|
|
||||||
std::panic::set_hook(Box::new(move |info| {
|
|
||||||
log::error!("PANIC => {}", info.to_string());
|
|
||||||
eprintln!(
|
|
||||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
|
|
||||||
'{file_path}'\n============="
|
|
||||||
);
|
|
||||||
hook(info);
|
|
||||||
}));
|
|
||||||
}
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
mod application_logger;
|
||||||
mod base;
|
mod base;
|
||||||
mod builder;
|
mod builder;
|
||||||
mod classification;
|
mod classification;
|
||||||
|
@ -8,8 +9,7 @@ mod step;
|
||||||
mod summary;
|
mod summary;
|
||||||
mod train_val;
|
mod train_val;
|
||||||
|
|
||||||
pub(crate) mod log;
|
pub use application_logger::*;
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
pub use builder::*;
|
pub use builder::*;
|
||||||
pub use classification::*;
|
pub use classification::*;
|
||||||
|
|
|
@ -76,7 +76,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
|
||||||
.devices(vec![device])
|
.devices(vec![device])
|
||||||
.num_epochs(config.num_epochs)
|
.num_epochs(config.num_epochs)
|
||||||
.renderer(CustomRenderer {})
|
.renderer(CustomRenderer {})
|
||||||
.log_to_file(false);
|
.with_application_logger(None);
|
||||||
// can be used to interrupt training
|
// can be used to interrupt training
|
||||||
let _interrupter = builder.interrupter();
|
let _interrupter = builder.interrupter();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue