mirror of https://github.com/tracel-ai/burn.git
refactor: fix all clippy warnings (#137)
This commit is contained in:
parent
85f98b9d54
commit
567adfb93e
|
@ -163,14 +163,14 @@ fn download(
|
|||
if !config_named.is_empty() {
|
||||
command.arg("--config-named");
|
||||
for (key, value) in config_named {
|
||||
command.arg(format!("{}={}", key, value));
|
||||
command.arg(format!("{key}={value}"));
|
||||
}
|
||||
}
|
||||
|
||||
let mut handle = command.spawn().unwrap();
|
||||
handle
|
||||
.wait()
|
||||
.map_err(|err| DownloaderError::Unknown(format!("{:?}", err)))?;
|
||||
.map_err(|err| DownloaderError::Unknown(format!("{err:?}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -179,14 +179,14 @@ fn cache_dir() -> String {
|
|||
let home_dir = home_dir().unwrap();
|
||||
let home_dir = home_dir.to_str().map(|s| s.to_string());
|
||||
let home_dir = home_dir.unwrap();
|
||||
let cache_dir = format!("{}/.cache/burn-dataset", home_dir);
|
||||
let cache_dir = format!("{home_dir}/.cache/burn-dataset");
|
||||
std::fs::create_dir_all(&cache_dir).ok();
|
||||
cache_dir
|
||||
}
|
||||
|
||||
fn dataset_downloader_file_path() -> String {
|
||||
let path_dir = cache_dir();
|
||||
let path_file = format!("{}/dataset.py", path_dir);
|
||||
let path_file = format!("{path_dir}/dataset.py");
|
||||
|
||||
fs::write(path_file.as_str(), PYTHON_SOURCE).expect("Write python dataset downloader");
|
||||
path_file
|
||||
|
|
|
@ -36,7 +36,7 @@ impl ConfigEnumAnalyzer {
|
|||
let mut output = Vec::new();
|
||||
|
||||
for i in 0..num {
|
||||
let arg_name = Ident::new(&format!("arg_{}", i), self.name.span());
|
||||
let arg_name = Ident::new(&format!("arg_{i}"), self.name.span());
|
||||
|
||||
input.push(quote! { #arg_name });
|
||||
output.push(quote! { #arg_name.clone() });
|
||||
|
|
|
@ -180,7 +180,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
|
|||
for (field, _) in self.fields_default.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let fn_name = Ident::new(&format!("with_{}", name), name.span());
|
||||
let fn_name = Ident::new(&format!("with_{name}"), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
|
@ -193,7 +193,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
|
|||
for field in self.fields_option.iter() {
|
||||
let name = field.ident();
|
||||
let ty = &field.field.ty;
|
||||
let fn_name = Ident::new(&format!("with_{}", name), name.span());
|
||||
let fn_name = Ident::new(&format!("with_{name}"), name.span());
|
||||
|
||||
body.extend(quote! {
|
||||
pub fn #fn_name(mut self, #name: #ty) -> Self {
|
||||
|
|
|
@ -18,7 +18,7 @@ impl AttributeAnalyzer {
|
|||
pub fn items(&self) -> Vec<AttributeItem> {
|
||||
let config = match self.attr.parse_meta() {
|
||||
Ok(val) => val,
|
||||
Err(err) => panic!("Fail to parse items: {:?}", err),
|
||||
Err(err) => panic!("Fail to parse items: {err:?}"),
|
||||
};
|
||||
let nested = match config {
|
||||
Meta::List(val) => val.nested,
|
||||
|
|
|
@ -468,7 +468,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
4 => keepdim!(3, dim, tensor, mean),
|
||||
5 => keepdim!(4, dim, tensor, mean),
|
||||
6 => keepdim!(5, dim, tensor, mean),
|
||||
_ => panic!("Dim not supported {}", D),
|
||||
_ => panic!("Dim not supported {D}"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -480,7 +480,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
4 => keepdim!(3, dim, tensor, sum),
|
||||
5 => keepdim!(4, dim, tensor, sum),
|
||||
6 => keepdim!(5, dim, tensor, sum),
|
||||
_ => panic!("Dim not supported {}", D),
|
||||
_ => panic!("Dim not supported {D}"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -569,10 +569,7 @@ where
|
|||
/// ```
|
||||
pub fn unsqueeze<const D2: usize>(&self) -> Tensor<B, D2> {
|
||||
if D2 < D {
|
||||
panic!(
|
||||
"Can't unsqueeze smaller tensor, got dim {}, expected > {}",
|
||||
D2, D
|
||||
)
|
||||
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
|
||||
}
|
||||
|
||||
let mut dims = [1; D2];
|
||||
|
|
|
@ -223,7 +223,7 @@ impl<P: Into<f64> + Clone + std::fmt::Debug + PartialEq, const D: usize> Data<P,
|
|||
let b = f64::round(10.0_f64.powi(precision as i32) * b);
|
||||
|
||||
if a != b {
|
||||
println!("a {:?}, b {:?}", a, b);
|
||||
println!("a {a:?}, b {b:?}");
|
||||
eq = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,10 +12,10 @@ impl std::fmt::Display for ConfigError {
|
|||
|
||||
match self {
|
||||
Self::InvalidFormat(err) => {
|
||||
message += format!("Invalid format: {}", err).as_str();
|
||||
message += format!("Invalid format: {err}").as_str();
|
||||
}
|
||||
Self::FileNotFound(err) => {
|
||||
message += format!("File not found: {}", err).as_str();
|
||||
message += format!("File not found: {err}").as_str();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -41,5 +41,5 @@ pub fn config_to_json<C: Config>(config: &C) -> String {
|
|||
}
|
||||
|
||||
fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
|
||||
serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{}", err)))
|
||||
serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ impl<M: Module> Module for Param<Vec<M>> {
|
|||
let mut state = StateNamed::new();
|
||||
|
||||
for (i, module) in self.value.iter().enumerate() {
|
||||
state.register_state(format!("mod-{}", i).as_str(), module.state());
|
||||
state.register_state(format!("mod-{i}").as_str(), module.state());
|
||||
}
|
||||
|
||||
let state = State::StateNamed(state);
|
||||
|
@ -90,15 +90,12 @@ impl<M: Module> Module for Param<Vec<M>> {
|
|||
let num = self.value.len();
|
||||
for (i, module) in self.value.iter_mut().enumerate() {
|
||||
module
|
||||
.load(state.get(format!("mod-{}", i).as_str()).ok_or_else(|| {
|
||||
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
|
||||
LoadingError::new(format!(
|
||||
"Invalid number of modules, expected {} modules missing #{}",
|
||||
num, i
|
||||
"Invalid number of modules, expected {num} modules missing #{i}"
|
||||
))
|
||||
})?)
|
||||
.map_err(|err| {
|
||||
LoadingError::new(format!("Can't load modules mod-{}: {}", i, err))
|
||||
})?;
|
||||
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -30,10 +30,10 @@ impl std::fmt::Display for StateError {
|
|||
|
||||
match self {
|
||||
Self::InvalidFormat(err) => {
|
||||
message += format!("Invalid format: {}", err).as_str();
|
||||
message += format!("Invalid format: {err}").as_str();
|
||||
}
|
||||
Self::FileNotFound(err) => {
|
||||
message += format!("File not found: {}", err).as_str();
|
||||
message += format!("File not found: {err}").as_str();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -122,7 +122,7 @@ where
|
|||
pub fn load(file: &str) -> Result<Self, StateError> {
|
||||
let path = Path::new(file);
|
||||
let reader =
|
||||
File::open(path).map_err(|err| StateError::FileNotFound(format!("{:?}", err)))?;
|
||||
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))?;
|
||||
let reader = GzDecoder::new(reader);
|
||||
let state = serde_json::from_reader(reader).unwrap();
|
||||
|
||||
|
|
|
@ -57,6 +57,6 @@ impl<B: ADBackend> WeightDecay<B> {
|
|||
}
|
||||
|
||||
fn state_key(id: &ParamId) -> String {
|
||||
format!("weight-decay-{}", id)
|
||||
format!("weight-decay-{id}")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,6 +76,6 @@ impl<B: ADBackend> Momentum<B> {
|
|||
}
|
||||
|
||||
fn state_key(id: &ParamId) -> String {
|
||||
format!("momentum-{}", id)
|
||||
format!("momentum-{id}")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,12 +37,8 @@ where
|
|||
{
|
||||
pub fn new(directory: &str) -> Self {
|
||||
let renderer = Box::new(CLIDashboardRenderer::new());
|
||||
let logger_train = Box::new(FileMetricLogger::new(
|
||||
format!("{}/train", directory).as_str(),
|
||||
));
|
||||
let logger_valid = Box::new(FileMetricLogger::new(
|
||||
format!("{}/valid", directory).as_str(),
|
||||
));
|
||||
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));
|
||||
let logger_valid = Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()));
|
||||
|
||||
Self {
|
||||
dashboard: Dashboard::new(renderer, logger_train, logger_valid),
|
||||
|
|
|
@ -54,8 +54,7 @@ fn update_panic_hook(file_path: &str) {
|
|||
std::panic::set_hook(Box::new(move |info| {
|
||||
log::error!("PANIC => {}", info.to_string());
|
||||
eprintln!(
|
||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{}'\n=============",
|
||||
file_path
|
||||
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{file_path}'\n============="
|
||||
);
|
||||
hook(info);
|
||||
}));
|
||||
|
|
|
@ -24,6 +24,6 @@ where
|
|||
T: std::fmt::Display,
|
||||
{
|
||||
fn log(&mut self, item: T) {
|
||||
writeln!(&mut self.file, "{}", item).unwrap();
|
||||
writeln!(&mut self.file, "{item}").unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ impl MetricLogger for FileMetricLogger {
|
|||
Some(val) => val,
|
||||
None => {
|
||||
let directory = format!("{}/epoch-{}", self.directory, self.epoch);
|
||||
let file_path = format!("{}/{}.log", directory, key);
|
||||
let file_path = format!("{directory}/{key}.log");
|
||||
std::fs::create_dir_all(&directory).ok();
|
||||
|
||||
let logger = FileLogger::new(&file_path);
|
||||
|
|
|
@ -50,7 +50,7 @@ impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for Accura
|
|||
|
||||
let name = String::from("Accurracy");
|
||||
let running = self.total as f64 / self.count as f64;
|
||||
let raw_running = format!("{}", running);
|
||||
let raw_running = format!("{running}");
|
||||
let raw_current = format!("{}", self.current);
|
||||
let formatted = format!(
|
||||
"running {:.2} % current {:.2} %",
|
||||
|
|
|
@ -33,18 +33,15 @@ impl<T> Metric<T> for CUDAMetric {
|
|||
let used_gb = memory_info.used as f64 * 1e-9;
|
||||
let total_gb = memory_info.total as f64 * 1e-9;
|
||||
|
||||
let memory_info_formatted = format!("{:.2}/{:.2} Gb", used_gb, total_gb);
|
||||
let memory_info_raw = format!("{}/{}", used_gb, total_gb);
|
||||
let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
|
||||
let memory_info_raw = format!("{used_gb}/{total_gb}");
|
||||
|
||||
formatted = format!(
|
||||
"{} GPU #{} - Memory {}",
|
||||
formatted, index, memory_info_formatted
|
||||
);
|
||||
raw_running = format!("{} ", memory_info_raw);
|
||||
formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
|
||||
raw_running = format!("{memory_info_raw} ");
|
||||
|
||||
let utilization_rates = device.utilization_rates().unwrap();
|
||||
let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
|
||||
formatted = format!("{} - Usage {}", formatted, utilization_rate_formatted);
|
||||
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
|
||||
}
|
||||
|
||||
Box::new(RunningMetricResult {
|
||||
|
|
|
@ -149,7 +149,7 @@ impl CLIDashboardRenderer {
|
|||
|
||||
if !metrics_keys.is_empty() {
|
||||
let metrics_template = metrics_keys.join("\n");
|
||||
template += format!("{}\n{}\n", PLOTS_TAG, metrics_template).as_str();
|
||||
template += format!("{PLOTS_TAG}\n{metrics_template}\n").as_str();
|
||||
}
|
||||
|
||||
template
|
||||
|
@ -159,15 +159,15 @@ impl CLIDashboardRenderer {
|
|||
let mut metrics_keys = Vec::new();
|
||||
|
||||
for (name, metric) in self.metric_train.iter() {
|
||||
metrics_keys.push(format!(" - Train {}: {}", name, metric));
|
||||
metrics_keys.push(format!(" - Train {name}: {metric}"));
|
||||
}
|
||||
for (name, metric) in self.metric_valid.iter() {
|
||||
metrics_keys.push(format!(" - Valid {}: {}", name, metric));
|
||||
metrics_keys.push(format!(" - Valid {name}: {metric}"));
|
||||
}
|
||||
|
||||
if !metrics_keys.is_empty() {
|
||||
let metrics_template = metrics_keys.join("\n");
|
||||
template += format!("{}\n{}\n", METRICS_TAG, metrics_template).as_str();
|
||||
template += format!("{METRICS_TAG}\n{metrics_template}\n").as_str();
|
||||
}
|
||||
|
||||
template
|
||||
|
@ -186,7 +186,7 @@ impl CLIDashboardRenderer {
|
|||
let mut template = template;
|
||||
|
||||
let bar = "[{wide_bar:.cyan/blue}] ({eta})";
|
||||
template += format!(" - {} {}", progress, bar).as_str();
|
||||
template += format!(" - {progress} {bar}").as_str();
|
||||
template
|
||||
}
|
||||
|
||||
|
@ -246,7 +246,7 @@ impl CLIDashboardRenderer {
|
|||
formatted: String,
|
||||
) -> ProgressStyle {
|
||||
style.with_key(key, move |_state: &ProgressState, w: &mut dyn Write| {
|
||||
write!(w, "{}: {}", name, formatted).unwrap()
|
||||
write!(w, "{name}: {formatted}").unwrap()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ impl<B: Backend> Metric<Tensor<B, 1>> for LossMetric {
|
|||
|
||||
let name = String::from("Loss");
|
||||
let running = self.total / self.count as f64;
|
||||
let raw_running = format!("{}", running);
|
||||
let raw_running = format!("{running}");
|
||||
let raw_current = format!("{}", self.current);
|
||||
let formatted = format!("running {:.3} current {:.3}", running, self.current);
|
||||
|
||||
|
|
|
@ -17,9 +17,9 @@ pub struct TestStructConfig {
|
|||
|
||||
#[derive(Config, Debug, PartialEq)]
|
||||
pub enum TestEnumConfig {
|
||||
WithoutValue,
|
||||
WithOneValue(f32),
|
||||
WithMultipleValue(f32, String),
|
||||
None,
|
||||
Single(f32),
|
||||
Multiple(f32, String),
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -47,7 +47,7 @@ fn struct_config_should_impl_display() {
|
|||
|
||||
#[test]
|
||||
fn enum_config_no_value_should_impl_serde() {
|
||||
let config = TestEnumConfig::WithoutValue;
|
||||
let config = TestEnumConfig::None;
|
||||
let file_path = "/tmp/test_enum_no_value_config.json";
|
||||
|
||||
config.save(file_path).unwrap();
|
||||
|
@ -58,7 +58,7 @@ fn enum_config_no_value_should_impl_serde() {
|
|||
|
||||
#[test]
|
||||
fn enum_config_one_value_should_impl_serde() {
|
||||
let config = TestEnumConfig::WithOneValue(42.0);
|
||||
let config = TestEnumConfig::Single(42.0);
|
||||
let file_path = "/tmp/test_enum_one_value_config.json";
|
||||
|
||||
config.save(file_path).unwrap();
|
||||
|
@ -69,7 +69,7 @@ fn enum_config_one_value_should_impl_serde() {
|
|||
|
||||
#[test]
|
||||
fn enum_config_multiple_values_should_impl_serde() {
|
||||
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
|
||||
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
||||
let file_path = "/tmp/test_enum_multiple_values_config.json";
|
||||
|
||||
config.save(file_path).unwrap();
|
||||
|
@ -80,12 +80,12 @@ fn enum_config_multiple_values_should_impl_serde() {
|
|||
|
||||
#[test]
|
||||
fn enum_config_should_impl_clone() {
|
||||
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
|
||||
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
||||
assert_eq!(config, config.clone());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enum_config_should_impl_display() {
|
||||
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
|
||||
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
|
||||
assert_eq!(burn::config::config_to_json(&config), config.to_string());
|
||||
}
|
||||
|
|
|
@ -55,6 +55,6 @@ pub fn run<B: ADBackend>(device: B::Device) {
|
|||
let _model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(format!("{}/config.json", ARTIFACT_DIR).as_str())
|
||||
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
@ -40,8 +40,8 @@ pub fn run<B: Backend>() {
|
|||
|
||||
let permut = output.swap_dims::<_, 1, 2>();
|
||||
|
||||
println!("Weights => {}", weights);
|
||||
println!("Input => {}", input);
|
||||
println!("Output => {}", output);
|
||||
println!("Permut => {}", permut);
|
||||
println!("Weights => {weights}");
|
||||
println!("Input => {input}");
|
||||
println!("Output => {output}");
|
||||
println!("Permut => {permut}");
|
||||
}
|
||||
|
|
|
@ -85,12 +85,10 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
|
|||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(&format!("{}/config.json", artifact_dir))
|
||||
.unwrap();
|
||||
config.save(&format!("{artifact_dir}/config.json")).unwrap();
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{}/model.json.gz", artifact_dir))
|
||||
.save(&format!("{artifact_dir}/model.json.gz"))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
@ -83,12 +83,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
|
|||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(&format!("{}/config.json", artifact_dir))
|
||||
.unwrap();
|
||||
config.save(&format!("{artifact_dir}/config.json")).unwrap();
|
||||
model_trained
|
||||
.state()
|
||||
.convert::<f32>()
|
||||
.save(&format!("{}/model.json.gz", artifact_dir))
|
||||
.save(&format!("{artifact_dir}/model.json.gz"))
|
||||
.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue