refactor: fix all clippy warnings (#137)

This commit is contained in:
Visual 2022-12-25 18:22:25 +02:00 committed by GitHub
parent 85f98b9d54
commit 567adfb93e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 59 additions and 77 deletions

View File

@ -163,14 +163,14 @@ fn download(
if !config_named.is_empty() {
command.arg("--config-named");
for (key, value) in config_named {
command.arg(format!("{}={}", key, value));
command.arg(format!("{key}={value}"));
}
}
let mut handle = command.spawn().unwrap();
handle
.wait()
.map_err(|err| DownloaderError::Unknown(format!("{:?}", err)))?;
.map_err(|err| DownloaderError::Unknown(format!("{err:?}")))?;
Ok(())
}
@ -179,14 +179,14 @@ fn cache_dir() -> String {
let home_dir = home_dir().unwrap();
let home_dir = home_dir.to_str().map(|s| s.to_string());
let home_dir = home_dir.unwrap();
let cache_dir = format!("{}/.cache/burn-dataset", home_dir);
let cache_dir = format!("{home_dir}/.cache/burn-dataset");
std::fs::create_dir_all(&cache_dir).ok();
cache_dir
}
fn dataset_downloader_file_path() -> String {
let path_dir = cache_dir();
let path_file = format!("{}/dataset.py", path_dir);
let path_file = format!("{path_dir}/dataset.py");
fs::write(path_file.as_str(), PYTHON_SOURCE).expect("Write python dataset downloader");
path_file

View File

@ -36,7 +36,7 @@ impl ConfigEnumAnalyzer {
let mut output = Vec::new();
for i in 0..num {
let arg_name = Ident::new(&format!("arg_{}", i), self.name.span());
let arg_name = Ident::new(&format!("arg_{i}"), self.name.span());
input.push(quote! { #arg_name });
output.push(quote! { #arg_name.clone() });

View File

@ -180,7 +180,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
for (field, _) in self.fields_default.iter() {
let name = field.ident();
let ty = &field.field.ty;
let fn_name = Ident::new(&format!("with_{}", name), name.span());
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
pub fn #fn_name(mut self, #name: #ty) -> Self {
@ -193,7 +193,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
for field in self.fields_option.iter() {
let name = field.ident();
let ty = &field.field.ty;
let fn_name = Ident::new(&format!("with_{}", name), name.span());
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
pub fn #fn_name(mut self, #name: #ty) -> Self {

View File

@ -18,7 +18,7 @@ impl AttributeAnalyzer {
pub fn items(&self) -> Vec<AttributeItem> {
let config = match self.attr.parse_meta() {
Ok(val) => val,
Err(err) => panic!("Fail to parse items: {:?}", err),
Err(err) => panic!("Fail to parse items: {err:?}"),
};
let nested = match config {
Meta::List(val) => val.nested,

View File

@ -468,7 +468,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
4 => keepdim!(3, dim, tensor, mean),
5 => keepdim!(4, dim, tensor, mean),
6 => keepdim!(5, dim, tensor, mean),
_ => panic!("Dim not supported {}", D),
_ => panic!("Dim not supported {D}"),
}
}
@ -480,7 +480,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
4 => keepdim!(3, dim, tensor, sum),
5 => keepdim!(4, dim, tensor, sum),
6 => keepdim!(5, dim, tensor, sum),
_ => panic!("Dim not supported {}", D),
_ => panic!("Dim not supported {D}"),
}
}

View File

@ -569,10 +569,7 @@ where
/// ```
pub fn unsqueeze<const D2: usize>(&self) -> Tensor<B, D2> {
if D2 < D {
panic!(
"Can't unsqueeze smaller tensor, got dim {}, expected > {}",
D2, D
)
panic!("Can't unsqueeze smaller tensor, got dim {D2}, expected > {D}")
}
let mut dims = [1; D2];

View File

@ -223,7 +223,7 @@ impl<P: Into<f64> + Clone + std::fmt::Debug + PartialEq, const D: usize> Data<P,
let b = f64::round(10.0_f64.powi(precision as i32) * b);
if a != b {
println!("a {:?}, b {:?}", a, b);
println!("a {a:?}, b {b:?}");
eq = false;
}
}

View File

@ -12,10 +12,10 @@ impl std::fmt::Display for ConfigError {
match self {
Self::InvalidFormat(err) => {
message += format!("Invalid format: {}", err).as_str();
message += format!("Invalid format: {err}").as_str();
}
Self::FileNotFound(err) => {
message += format!("File not found: {}", err).as_str();
message += format!("File not found: {err}").as_str();
}
};
@ -41,5 +41,5 @@ pub fn config_to_json<C: Config>(config: &C) -> String {
}
fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{}", err)))
serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
}

View File

@ -75,7 +75,7 @@ impl<M: Module> Module for Param<Vec<M>> {
let mut state = StateNamed::new();
for (i, module) in self.value.iter().enumerate() {
state.register_state(format!("mod-{}", i).as_str(), module.state());
state.register_state(format!("mod-{i}").as_str(), module.state());
}
let state = State::StateNamed(state);
@ -90,15 +90,12 @@ impl<M: Module> Module for Param<Vec<M>> {
let num = self.value.len();
for (i, module) in self.value.iter_mut().enumerate() {
module
.load(state.get(format!("mod-{}", i).as_str()).ok_or_else(|| {
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
LoadingError::new(format!(
"Invalid number of modules, expected {} modules missing #{}",
num, i
"Invalid number of modules, expected {num} modules missing #{i}"
))
})?)
.map_err(|err| {
LoadingError::new(format!("Can't load modules mod-{}: {}", i, err))
})?;
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
}
Ok(())

View File

@ -30,10 +30,10 @@ impl std::fmt::Display for StateError {
match self {
Self::InvalidFormat(err) => {
message += format!("Invalid format: {}", err).as_str();
message += format!("Invalid format: {err}").as_str();
}
Self::FileNotFound(err) => {
message += format!("File not found: {}", err).as_str();
message += format!("File not found: {err}").as_str();
}
};
@ -122,7 +122,7 @@ where
pub fn load(file: &str) -> Result<Self, StateError> {
let path = Path::new(file);
let reader =
File::open(path).map_err(|err| StateError::FileNotFound(format!("{:?}", err)))?;
File::open(path).map_err(|err| StateError::FileNotFound(format!("{err:?}")))?;
let reader = GzDecoder::new(reader);
let state = serde_json::from_reader(reader).unwrap();

View File

@ -57,6 +57,6 @@ impl<B: ADBackend> WeightDecay<B> {
}
fn state_key(id: &ParamId) -> String {
format!("weight-decay-{}", id)
format!("weight-decay-{id}")
}
}

View File

@ -76,6 +76,6 @@ impl<B: ADBackend> Momentum<B> {
}
fn state_key(id: &ParamId) -> String {
format!("momentum-{}", id)
format!("momentum-{id}")
}
}

View File

@ -37,12 +37,8 @@ where
{
pub fn new(directory: &str) -> Self {
let renderer = Box::new(CLIDashboardRenderer::new());
let logger_train = Box::new(FileMetricLogger::new(
format!("{}/train", directory).as_str(),
));
let logger_valid = Box::new(FileMetricLogger::new(
format!("{}/valid", directory).as_str(),
));
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));
let logger_valid = Box::new(FileMetricLogger::new(format!("{directory}/valid").as_str()));
Self {
dashboard: Dashboard::new(renderer, logger_train, logger_valid),

View File

@ -54,8 +54,7 @@ fn update_panic_hook(file_path: &str) {
std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{}'\n=============",
file_path
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{file_path}'\n============="
);
hook(info);
}));

View File

@ -24,6 +24,6 @@ where
T: std::fmt::Display,
{
fn log(&mut self, item: T) {
writeln!(&mut self.file, "{}", item).unwrap();
writeln!(&mut self.file, "{item}").unwrap();
}
}

View File

@ -32,7 +32,7 @@ impl MetricLogger for FileMetricLogger {
Some(val) => val,
None => {
let directory = format!("{}/epoch-{}", self.directory, self.epoch);
let file_path = format!("{}/{}.log", directory, key);
let file_path = format!("{directory}/{key}.log");
std::fs::create_dir_all(&directory).ok();
let logger = FileLogger::new(&file_path);

View File

@ -50,7 +50,7 @@ impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for Accura
let name = String::from("Accurracy");
let running = self.total as f64 / self.count as f64;
let raw_running = format!("{}", running);
let raw_running = format!("{running}");
let raw_current = format!("{}", self.current);
let formatted = format!(
"running {:.2} % current {:.2} %",

View File

@ -33,18 +33,15 @@ impl<T> Metric<T> for CUDAMetric {
let used_gb = memory_info.used as f64 * 1e-9;
let total_gb = memory_info.total as f64 * 1e-9;
let memory_info_formatted = format!("{:.2}/{:.2} Gb", used_gb, total_gb);
let memory_info_raw = format!("{}/{}", used_gb, total_gb);
let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb");
let memory_info_raw = format!("{used_gb}/{total_gb}");
formatted = format!(
"{} GPU #{} - Memory {}",
formatted, index, memory_info_formatted
);
raw_running = format!("{} ", memory_info_raw);
formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}");
raw_running = format!("{memory_info_raw} ");
let utilization_rates = device.utilization_rates().unwrap();
let utilization_rate_formatted = format!("{}%", utilization_rates.gpu);
formatted = format!("{} - Usage {}", formatted, utilization_rate_formatted);
formatted = format!("{formatted} - Usage {utilization_rate_formatted}");
}
Box::new(RunningMetricResult {

View File

@ -149,7 +149,7 @@ impl CLIDashboardRenderer {
if !metrics_keys.is_empty() {
let metrics_template = metrics_keys.join("\n");
template += format!("{}\n{}\n", PLOTS_TAG, metrics_template).as_str();
template += format!("{PLOTS_TAG}\n{metrics_template}\n").as_str();
}
template
@ -159,15 +159,15 @@ impl CLIDashboardRenderer {
let mut metrics_keys = Vec::new();
for (name, metric) in self.metric_train.iter() {
metrics_keys.push(format!(" - Train {}: {}", name, metric));
metrics_keys.push(format!(" - Train {name}: {metric}"));
}
for (name, metric) in self.metric_valid.iter() {
metrics_keys.push(format!(" - Valid {}: {}", name, metric));
metrics_keys.push(format!(" - Valid {name}: {metric}"));
}
if !metrics_keys.is_empty() {
let metrics_template = metrics_keys.join("\n");
template += format!("{}\n{}\n", METRICS_TAG, metrics_template).as_str();
template += format!("{METRICS_TAG}\n{metrics_template}\n").as_str();
}
template
@ -186,7 +186,7 @@ impl CLIDashboardRenderer {
let mut template = template;
let bar = "[{wide_bar:.cyan/blue}] ({eta})";
template += format!(" - {} {}", progress, bar).as_str();
template += format!(" - {progress} {bar}").as_str();
template
}
@ -246,7 +246,7 @@ impl CLIDashboardRenderer {
formatted: String,
) -> ProgressStyle {
style.with_key(key, move |_state: &ProgressState, w: &mut dyn Write| {
write!(w, "{}: {}", name, formatted).unwrap()
write!(w, "{name}: {formatted}").unwrap()
})
}
}

View File

@ -42,7 +42,7 @@ impl<B: Backend> Metric<Tensor<B, 1>> for LossMetric {
let name = String::from("Loss");
let running = self.total / self.count as f64;
let raw_running = format!("{}", running);
let raw_running = format!("{running}");
let raw_current = format!("{}", self.current);
let formatted = format!("running {:.3} current {:.3}", running, self.current);

View File

@ -17,9 +17,9 @@ pub struct TestStructConfig {
#[derive(Config, Debug, PartialEq)]
pub enum TestEnumConfig {
WithoutValue,
WithOneValue(f32),
WithMultipleValue(f32, String),
None,
Single(f32),
Multiple(f32, String),
}
#[test]
@ -47,7 +47,7 @@ fn struct_config_should_impl_display() {
#[test]
fn enum_config_no_value_should_impl_serde() {
let config = TestEnumConfig::WithoutValue;
let config = TestEnumConfig::None;
let file_path = "/tmp/test_enum_no_value_config.json";
config.save(file_path).unwrap();
@ -58,7 +58,7 @@ fn enum_config_no_value_should_impl_serde() {
#[test]
fn enum_config_one_value_should_impl_serde() {
let config = TestEnumConfig::WithOneValue(42.0);
let config = TestEnumConfig::Single(42.0);
let file_path = "/tmp/test_enum_one_value_config.json";
config.save(file_path).unwrap();
@ -69,7 +69,7 @@ fn enum_config_one_value_should_impl_serde() {
#[test]
fn enum_config_multiple_values_should_impl_serde() {
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
let file_path = "/tmp/test_enum_multiple_values_config.json";
config.save(file_path).unwrap();
@ -80,12 +80,12 @@ fn enum_config_multiple_values_should_impl_serde() {
#[test]
fn enum_config_should_impl_clone() {
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
assert_eq!(config, config.clone());
}
#[test]
fn enum_config_should_impl_display() {
let config = TestEnumConfig::WithMultipleValue(42.0, "Allo".to_string());
let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
assert_eq!(burn::config::config_to_json(&config), config.to_string());
}

View File

@ -55,6 +55,6 @@ pub fn run<B: ADBackend>(device: B::Device) {
let _model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(format!("{}/config.json", ARTIFACT_DIR).as_str())
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
.unwrap();
}

View File

@ -40,8 +40,8 @@ pub fn run<B: Backend>() {
let permut = output.swap_dims::<_, 1, 2>();
println!("Weights => {}", weights);
println!("Input => {}", input);
println!("Output => {}", output);
println!("Permut => {}", permut);
println!("Weights => {weights}");
println!("Input => {input}");
println!("Output => {output}");
println!("Permut => {permut}");
}

View File

@ -85,12 +85,10 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(&format!("{}/config.json", artifact_dir))
.unwrap();
config.save(&format!("{artifact_dir}/config.json")).unwrap();
model_trained
.state()
.convert::<f32>()
.save(&format!("{}/model.json.gz", artifact_dir))
.save(&format!("{artifact_dir}/model.json.gz"))
.unwrap();
}

View File

@ -83,12 +83,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(&format!("{}/config.json", artifact_dir))
.unwrap();
config.save(&format!("{artifact_dir}/config.json")).unwrap();
model_trained
.state()
.convert::<f32>()
.save(&format!("{}/model.json.gz", artifact_dir))
.save(&format!("{artifact_dir}/model.json.gz"))
.unwrap();
}