mirror of https://github.com/tracel-ai/burn.git
[backend-comparison] Add all choice to --benches and --backends (#1567)
+ Make some tweaks in logs
This commit is contained in:
parent
8d210a152f
commit
c4eac86ce5
|
@ -64,6 +64,8 @@ struct RunArgs {
|
|||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
|
||||
pub(crate) enum BackendValues {
|
||||
#[strum(to_string = "all")]
|
||||
All,
|
||||
#[strum(to_string = "candle-cpu")]
|
||||
CandleCpu,
|
||||
#[strum(to_string = "candle-cuda")]
|
||||
|
@ -90,6 +92,8 @@ pub(crate) enum BackendValues {
|
|||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
|
||||
pub(crate) enum BenchmarkValues {
|
||||
#[strum(to_string = "all")]
|
||||
All,
|
||||
#[strum(to_string = "binary")]
|
||||
Binary,
|
||||
#[strum(to_string = "custom-gelu")]
|
||||
|
@ -142,20 +146,26 @@ fn command_run(run_args: RunArgs) {
|
|||
if run_args.share {
|
||||
tokens = get_tokens();
|
||||
}
|
||||
let total_combinations = run_args.backends.len() * run_args.benches.len();
|
||||
println!(
|
||||
"Executing benchmark and backend combinations in total: {}",
|
||||
total_combinations
|
||||
);
|
||||
// collect benchmarks and benches to execute
|
||||
let mut backends = run_args.backends.clone();
|
||||
if backends.contains(&BackendValues::All) {
|
||||
backends = BackendValues::iter()
|
||||
.filter(|b| b != &BackendValues::All)
|
||||
.collect();
|
||||
}
|
||||
let mut benches = run_args.benches.clone();
|
||||
if benches.contains(&BenchmarkValues::All) {
|
||||
benches = BenchmarkValues::iter()
|
||||
.filter(|b| b != &BenchmarkValues::All)
|
||||
.collect();
|
||||
}
|
||||
|
||||
let total_combinations = backends.len() * benches.len();
|
||||
let mut app = App::new();
|
||||
app.init();
|
||||
println!("Running benchmarks...\n");
|
||||
println!("Running {} benchmark(s)...\n", total_combinations);
|
||||
let access_token = tokens.map(|t| t.access_token);
|
||||
app.run(
|
||||
&run_args.benches,
|
||||
&run_args.backends,
|
||||
access_token.as_deref(),
|
||||
);
|
||||
app.run(&benches, &backends, access_token.as_deref());
|
||||
app.cleanup();
|
||||
}
|
||||
|
||||
|
@ -177,6 +187,8 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
backends: &[BackendValues],
|
||||
token: Option<&str>,
|
||||
) {
|
||||
let total_count = backends.len() * benches.len();
|
||||
let mut current_index = 0;
|
||||
// Prefix and postfix for titles
|
||||
let filler = ["="; 10].join("");
|
||||
|
||||
|
@ -195,9 +207,10 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
for backend in backends.iter() {
|
||||
let bench_str = bench.to_string();
|
||||
let backend_str = backend.to_string();
|
||||
current_index += 1;
|
||||
println!(
|
||||
"{}Benchmarking {} on {}{}",
|
||||
filler, bench_str, backend_str, filler
|
||||
"{} ({}/{}) Benchmarking {} on {} {}",
|
||||
filler, current_index, total_count, bench_str, backend_str, filler
|
||||
);
|
||||
let url = format!("{}benchmarks", super::USER_BENCHMARK_SERVER_URL);
|
||||
let mut args = vec![
|
||||
|
@ -244,7 +257,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
};
|
||||
}
|
||||
println!(
|
||||
"{}Benchmark Results{}\n\n{}",
|
||||
"{} Benchmark Results {}\n\n{}",
|
||||
filler, filler, benchmark_results
|
||||
);
|
||||
fs::remove_file(benchmark_results_file).ok();
|
||||
|
|
|
@ -255,7 +255,8 @@ impl Display for BenchmarkCollection {
|
|||
let mut max_feature_len = "Feature".len();
|
||||
for record in self.records.iter() {
|
||||
max_name_len = max_name_len.max(record.results.name.len());
|
||||
max_backend_len = max_backend_len.max(record.backend.len());
|
||||
// + 2 because if the added backticks
|
||||
max_backend_len = max_backend_len.max(record.backend.len() + 2);
|
||||
max_device_len = max_device_len.max(record.device.len());
|
||||
max_feature_len = max_feature_len.max(record.feature.len());
|
||||
}
|
||||
|
@ -276,7 +277,7 @@ impl Display for BenchmarkCollection {
|
|||
"| {:<width_name$} | {:<width_feature$} | {:<width_backend$} | {:<width_device$} | {:<15.3?}|",
|
||||
record.results.name,
|
||||
record.feature,
|
||||
record.backend,
|
||||
format!("`{}`", record.backend),
|
||||
record.device,
|
||||
record.results.computed.median,
|
||||
width_name = max_name_len,
|
||||
|
|
Loading…
Reference in New Issue