mirror of https://github.com/tracel-ai/burn.git
Replaced `str` with `Path` (#1919)
* replaced str with Path * minor change (Path to AsRef<Path>) * fixed clippy lint
This commit is contained in:
parent
98a58c867d
commit
a7efc102b9
|
@ -1,3 +1,5 @@
|
|||
use std::path::{Path, PathBuf};
|
||||
|
||||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::{
|
||||
record::{FileRecorder, Record},
|
||||
|
@ -6,7 +8,7 @@ use burn_core::{
|
|||
|
||||
/// The file checkpointer.
|
||||
pub struct FileCheckpointer<FR> {
|
||||
directory: String,
|
||||
directory: PathBuf,
|
||||
name: String,
|
||||
recorder: FR,
|
||||
}
|
||||
|
@ -19,17 +21,19 @@ impl<FR> FileCheckpointer<FR> {
|
|||
/// * `recorder` - The file recorder.
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `name` - The name of the checkpoint.
|
||||
pub fn new(recorder: FR, directory: &str, name: &str) -> Self {
|
||||
pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
|
||||
let directory = directory.as_ref();
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
||||
Self {
|
||||
directory: directory.to_string(),
|
||||
directory: directory.to_path_buf(),
|
||||
name: name.to_string(),
|
||||
recorder,
|
||||
}
|
||||
}
|
||||
fn path_for_epoch(&self, epoch: usize) -> String {
|
||||
format!("{}/{}-{}", self.directory, self.name, epoch)
|
||||
|
||||
fn path_for_epoch(&self, epoch: usize) -> PathBuf {
|
||||
self.directory.join(format!("{}-{}", self.name, epoch))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,10 +45,10 @@ where
|
|||
{
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Saving checkpoint {} to {}", epoch, file_path);
|
||||
log::info!("Saving checkpoint {} to {}", epoch, file_path.display());
|
||||
|
||||
self.recorder
|
||||
.record(record, file_path.into())
|
||||
.record(record, file_path)
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -52,17 +56,25 @@ where
|
|||
|
||||
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
|
||||
let file_path = self.path_for_epoch(epoch);
|
||||
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
|
||||
log::info!(
|
||||
"Restoring checkpoint {} from {}",
|
||||
epoch,
|
||||
file_path.display()
|
||||
);
|
||||
let record = self
|
||||
.recorder
|
||||
.load(file_path.into(), device)
|
||||
.load(file_path, device)
|
||||
.map_err(CheckpointerError::RecorderError)?;
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
|
||||
let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),);
|
||||
let file_to_remove = format!(
|
||||
"{}.{}",
|
||||
self.path_for_epoch(epoch).display(),
|
||||
FR::file_extension(),
|
||||
);
|
||||
|
||||
if std::path::Path::new(&file_to_remove).exists() {
|
||||
log::info!("Removing checkpoint {}", file_to_remove);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::filter::filter_fn;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
@ -12,14 +12,14 @@ pub trait ApplicationLoggerInstaller {
|
|||
|
||||
/// This struct is used to install a local file application logger to output logs to a given file path.
|
||||
pub struct FileApplicationLoggerInstaller {
|
||||
path: String,
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl FileApplicationLoggerInstaller {
|
||||
/// Create a new file application logger.
|
||||
pub fn new(path: &str) -> Self {
|
||||
pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
path: path.to_string(),
|
||||
path: path.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,8 +29,9 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
|
|||
let path = Path::new(&self.path);
|
||||
let writer = tracing_appender::rolling::never(
|
||||
path.parent().unwrap_or_else(|| Path::new(".")),
|
||||
path.file_name()
|
||||
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
|
||||
path.file_name().unwrap_or_else(|| {
|
||||
panic!("The path '{}' to point to a file.", self.path.display())
|
||||
}),
|
||||
);
|
||||
let layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
|
@ -51,13 +52,14 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
|
|||
}
|
||||
|
||||
let hook = std::panic::take_hook();
|
||||
let file_path: String = self.path.to_owned();
|
||||
let file_path = self.path.to_owned();
|
||||
|
||||
std::panic::set_hook(Box::new(move |info| {
|
||||
log::error!("PANIC => {}", info.to_string());
|
||||
eprintln!(
|
||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
|
||||
'{file_path}'\n============="
|
||||
'{}'\n=============",
|
||||
file_path.display()
|
||||
);
|
||||
hook(info);
|
||||
}));
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
|
||||
use super::Learner;
|
||||
|
@ -45,7 +46,7 @@ where
|
|||
)>,
|
||||
num_epochs: usize,
|
||||
checkpoint: Option<usize>,
|
||||
directory: String,
|
||||
directory: PathBuf,
|
||||
grad_accumulation: Option<usize>,
|
||||
devices: Vec<B::Device>,
|
||||
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
|
||||
|
@ -74,12 +75,14 @@ where
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
pub fn new(directory: &str) -> Self {
|
||||
pub fn new(directory: impl AsRef<Path>) -> Self {
|
||||
let directory = directory.as_ref().to_path_buf();
|
||||
let experiment_log_file = directory.join("experiment.log");
|
||||
Self {
|
||||
num_epochs: 1,
|
||||
checkpoint: None,
|
||||
checkpointers: None,
|
||||
directory: directory.to_string(),
|
||||
directory,
|
||||
grad_accumulation: None,
|
||||
devices: vec![B::Device::default()],
|
||||
metrics: Metrics::default(),
|
||||
|
@ -87,7 +90,7 @@ where
|
|||
renderer: None,
|
||||
interrupter: TrainingInterrupter::new(),
|
||||
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
|
||||
format!("{}/experiment.log", directory).as_str(),
|
||||
experiment_log_file,
|
||||
))),
|
||||
num_loggers: 0,
|
||||
checkpointer_strategy: Box::new(
|
||||
|
@ -256,21 +259,12 @@ where
|
|||
M::Record: 'static,
|
||||
S::Record: 'static,
|
||||
{
|
||||
let checkpointer_model = FileCheckpointer::new(
|
||||
recorder.clone(),
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"model",
|
||||
);
|
||||
let checkpointer_optimizer = FileCheckpointer::new(
|
||||
recorder.clone(),
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"optim",
|
||||
);
|
||||
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
|
||||
recorder,
|
||||
format!("{}/checkpoint", self.directory).as_str(),
|
||||
"scheduler",
|
||||
);
|
||||
let checkpoint_dir = self.directory.join("checkpoint");
|
||||
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
|
||||
let checkpointer_optimizer =
|
||||
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
|
||||
let checkpointer_scheduler: FileCheckpointer<FR> =
|
||||
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
|
||||
|
||||
self.checkpointers = Some((
|
||||
AsyncCheckpointer::new(checkpointer_model),
|
||||
|
@ -325,17 +319,12 @@ where
|
|||
let renderer = self.renderer.unwrap_or_else(|| {
|
||||
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
|
||||
});
|
||||
let directory = &self.directory;
|
||||
|
||||
if self.num_loggers == 0 {
|
||||
self.event_store
|
||||
.register_logger_train(FileMetricLogger::new(
|
||||
format!("{directory}/train").as_str(),
|
||||
));
|
||||
.register_logger_train(FileMetricLogger::new(self.directory.join("train")));
|
||||
self.event_store
|
||||
.register_logger_valid(FileMetricLogger::new(
|
||||
format!("{directory}/valid").as_str(),
|
||||
));
|
||||
.register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
|
||||
}
|
||||
|
||||
let event_store = Rc::new(EventStoreClient::new(self.event_store));
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use core::cmp::Ordering;
|
||||
use std::{fmt::Display, path::Path};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
logger::FileMetricLogger,
|
||||
|
@ -73,16 +76,20 @@ impl LearnerSummary {
|
|||
///
|
||||
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
|
||||
/// * `metrics` - The list of metrics to collect for the summary.
|
||||
pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
|
||||
let directory_path = Path::new(directory);
|
||||
if !directory_path.exists() {
|
||||
return Err(format!("Artifact directory does not exist at: {directory}"));
|
||||
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
|
||||
let directory = directory.as_ref();
|
||||
if !directory.exists() {
|
||||
return Err(format!(
|
||||
"Artifact directory does not exist at: {}",
|
||||
directory.display()
|
||||
));
|
||||
}
|
||||
let train_dir = directory_path.join("train");
|
||||
let valid_dir = directory_path.join("valid");
|
||||
let train_dir = directory.join("train");
|
||||
let valid_dir = directory.join("valid");
|
||||
if !train_dir.exists() & !valid_dir.exists() {
|
||||
return Err(format!(
|
||||
"No training or validation artifacts found at: {directory}"
|
||||
"No training or validation artifacts found at: {}",
|
||||
directory.display()
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -219,7 +226,7 @@ impl Display for LearnerSummary {
|
|||
}
|
||||
|
||||
pub(crate) struct LearnerSummaryConfig {
|
||||
pub(crate) directory: String,
|
||||
pub(crate) directory: PathBuf,
|
||||
pub(crate) metrics: Vec<String>,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::Logger;
|
||||
use std::{fs::File, io::Write};
|
||||
use std::{fs::File, io::Write, path::Path};
|
||||
|
||||
/// File logger.
|
||||
pub struct FileLogger {
|
||||
|
@ -16,14 +16,21 @@ impl FileLogger {
|
|||
/// # Returns
|
||||
///
|
||||
/// The file logger.
|
||||
pub fn new(path: &str) -> Self {
|
||||
pub fn new(path: impl AsRef<Path>) -> Self {
|
||||
let path = path.as_ref();
|
||||
let mut options = std::fs::File::options();
|
||||
let file = options
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.create(true)
|
||||
.open(path)
|
||||
.unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}"));
|
||||
.unwrap_or_else(|err| {
|
||||
panic!(
|
||||
"Should be able to create the new file '{}': {}",
|
||||
path.display(),
|
||||
err
|
||||
)
|
||||
});
|
||||
|
||||
Self { file }
|
||||
}
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
|
||||
use crate::metric::{MetricEntry, NumericEntry};
|
||||
use std::{collections::HashMap, fs};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
const EPOCH_PREFIX: &str = "epoch-";
|
||||
|
||||
|
@ -27,7 +31,7 @@ pub trait MetricLogger: Send {
|
|||
/// The file metric logger.
|
||||
pub struct FileMetricLogger {
|
||||
loggers: HashMap<String, AsyncLogger<String>>,
|
||||
directory: String,
|
||||
directory: PathBuf,
|
||||
epoch: usize,
|
||||
}
|
||||
|
||||
|
@ -41,10 +45,10 @@ impl FileMetricLogger {
|
|||
/// # Returns
|
||||
///
|
||||
/// The file metric logger.
|
||||
pub fn new(directory: &str) -> Self {
|
||||
pub fn new(directory: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
loggers: HashMap::new(),
|
||||
directory: directory.to_string(),
|
||||
directory: directory.as_ref().to_path_buf(),
|
||||
epoch: 1,
|
||||
}
|
||||
}
|
||||
|
@ -76,15 +80,18 @@ impl FileMetricLogger {
|
|||
max_epoch
|
||||
}
|
||||
|
||||
fn epoch_directory(&self, epoch: usize) -> String {
|
||||
format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
|
||||
fn epoch_directory(&self, epoch: usize) -> PathBuf {
|
||||
let name = format!("{}{}", EPOCH_PREFIX, epoch);
|
||||
self.directory.join(name)
|
||||
}
|
||||
fn file_path(&self, name: &str, epoch: usize) -> String {
|
||||
|
||||
fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
|
||||
let directory = self.epoch_directory(epoch);
|
||||
let name = name.replace(' ', "_");
|
||||
|
||||
format!("{directory}/{name}.log")
|
||||
let name = format!("{name}.log");
|
||||
directory.join(name)
|
||||
}
|
||||
|
||||
fn create_directory(&self, epoch: usize) {
|
||||
let directory = self.epoch_directory(epoch);
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
@ -102,7 +109,7 @@ impl MetricLogger for FileMetricLogger {
|
|||
self.create_directory(self.epoch);
|
||||
|
||||
let file_path = self.file_path(key, self.epoch);
|
||||
let logger = FileLogger::new(&file_path);
|
||||
let logger = FileLogger::new(file_path);
|
||||
let logger = AsyncLogger::new(logger);
|
||||
|
||||
self.loggers.insert(key.clone(), logger);
|
||||
|
|
Loading…
Reference in New Issue