[backend-comparison] Refresh access token and display authenticated user name (#1483)

* [backend-comparison] Serialize both auth tokens to cache file

We need to refresh token to be able to renew an expired access token.

* [backend-comparison] Refresh access token

* [backend-comparison] Display user name with auth command

* [backend-comparison] Update README

* [backend-comparison] Fix PR comments

* [backend-comparison] Fix hyphen consistency in benchmark names

* [backend-comparison] Fix release build error when refreshing token

The reqwest must have an explicit empty body otherwise the release
build returns a 411 when refreshing the tokens without even calling
the benchmark server endpoint.
This commit is contained in:
Sylvain Benner 2024-03-20 15:39:32 -04:00 committed by GitHub
parent 47a84cc980
commit e8863dafd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 278 additions and 184 deletions

1
Cargo.lock generated
View File

@ -208,6 +208,7 @@ dependencies = [
"ratatui",
"reqwest",
"rstest",
"serde",
"serde_json",
"serial_test",
"strum",

View File

@ -37,6 +37,7 @@ github-device-flow = { workspace = true }
rand = { workspace = true }
ratatui = { workspace = true, optional = true }
reqwest = {workspace = true, features = ["blocking", "json"]}
serde = { workspace = true }
serde_json = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }
@ -54,7 +55,8 @@ name = "binary"
harness = false
[[bench]]
name = "max_pool2d"
name = "max-pool2d"
path = "benches/max_pool2d.rs"
harness = false
[[bench]]
@ -66,7 +68,8 @@ name = "data"
harness = false
[[bench]]
name = "custom_gelu"
name = "custom-gelu"
path = "benches/custom_gelu.rs"
harness = false
[[bin]]

View File

@ -10,7 +10,10 @@ within the corresponding backend crate.
## burnbench CLI
This crate comes with a CLI binary called `burnbench` which can be executed via
`cargo run --bin burnbench`.
`cargo run --release --bin burnbench`.
Note that you need to run the `release` target of `burnbench` otherwise you won't
be able to share your benchmark results.
The end of options argument `--` is used to pass arguments to the `burnbench`
application. For instance `cargo run --bin burnbench -- list` passes the `list`
@ -23,7 +26,7 @@ argument to `burnbench` effectively calling `burnbench list`.
To list all the available benches and backends use the `list` command:
```sh
> cargo run --bin burnbench -- list
> cargo run --release --bin burnbench -- list
Finished dev [unoptimized] target(s) in 0.10s
Running `target/debug/burnbench list`
Available Backends:
@ -54,13 +57,13 @@ with the arguments `--benches` and `--backends` respectively. In the following
example we execute the `unary` benchmark against the `wgpu-fusion` backend:
```sh
> cargo run --bin burnbench -- run --benches unary --backends wgpu-fusion
> cargo run --release --bin burnbench -- run --benches unary --backends wgpu-fusion
```
Shorthands can be used, the following command line is the same:
```sh
> cargo run --bin burnbench -- run -b unary -B wgpu-fusion
> cargo run --release --bin burnbench -- run -b unary -B wgpu-fusion
```
Multiple benchmarks and backends can be passed on the same command line. In this
@ -68,8 +71,7 @@ case, all the combinations of benchmarks with backends will be executed.
```sh
> cargo run --bin burnbench -- run --benches unary binary --backends wgpu-fusion tch-gpu
Finished dev [unoptimized] target(s) in 0.09s
Running `target/debug/burnbench run --benches unary binary --backends wgpu-fusion wgpu`
Running `target/release/burnbench run --benches unary binary --backends wgpu-fusion wgpu`
Executing the following benchmark and backend combinations (Total: 4):
- Benchmark: unary, Backend: wgpu-fusion
- Benchmark: binary, Backend: wgpu-fusion
@ -88,7 +90,7 @@ Sharing results is opt-in and it is enabled with the `--share` arguments passed
to the `run` command:
```sh
> cargo run --bin burnbench -- run --share --benches unary --backends wgpu-fusion
> cargo run --release --bin burnbench -- run --share --benches unary --backends wgpu-fusion
```
To be able to upload results you must be authenticated. We only support GitHub
@ -96,14 +98,26 @@ authentication. To authenticate run the `auth` command, then follow the URL
to enter your device code and authorize the Burnbench application:
```sh
> cargo run --bin burnbench -- run auth
> cargo run --release --bin burnbench -- auth
```
If everything is fine you should get a confirmation in the terminal that your
token has been saved to the burn cache directory.
We don't store any of your personal information. An anonymized user name will
be attributed to you and displayed in the terminal once you are authenticated.
For instance:
```
🔑 Your username is: CuteFlame
```
You can now use the `--share` argument to upload and share your benchmarks!
Note that your access token will be refreshed automatically so you should not
need to reauthorize the application again except if your refresh token itself
becomes invalid.
### Terminal UI
This is a work in progress and is not usable for now.

View File

@ -1,18 +1,169 @@
use arboard::Clipboard;
use burn::serde::{Deserialize, Serialize};
use github_device_flow::{self, DeviceFlow};
use reqwest;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::{
fs::{self, File},
path::{Path, PathBuf},
thread, time,
};
pub(crate) static CLIENT_ID: &str = "Iv1.692f6a61b6086810";
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
static GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
static GITHUB_API_VERSION: &str = "2022-11-28";
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Tokens {
/// Token returned once the Burnbench Github app has been authorized by the user.
/// This token is used to authenticate the user to the Burn benchmark server.
/// This token is a short lived token (about 8 hours).
pub access_token: String,
/// Along with the access token, a refresh token is provided once the Burnbench
/// GitHub app has been authorized by the user.
/// This token can be presented to the Burn benchmark server in order to re-issue
/// a new access token for the user.
/// This token is longer lived (around 6 months).
pub refresh_token: String,
}
#[derive(Debug, Deserialize)]
pub(crate) struct UserInfo {
pub nickname: String,
}
/// Retrieve cached tokens and refresh them if necessary then save the new tokens.
/// If there is no cached token or if the access token cannot be resfresh then
/// ask for the user to reauthorize the Burnbench github application.
pub(crate) fn get_tokens() -> Option<Tokens> {
get_tokens_from_cache().map_or_else(
// no token saved yet
auth,
// cached tokens found
|tokens| {
if verify_tokens(&tokens) {
Some(tokens)
} else {
refresh_tokens(&tokens).map_or_else(
|| {
println!("⚠ Cannot refresh the access token. You need to reauthorize the Burnbench application.");
auth()
},
|new_tokens| {
save_tokens(&new_tokens);
Some(new_tokens)
})
}
},
)
}
/// Returns the authenticated user name from access token
pub(crate) fn get_username(access_token: &str) -> Option<UserInfo> {
let client = reqwest::blocking::Client::new();
let response = client
.get(format!("{}users/me", super::USER_BENCHMARK_SERVER_URL))
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", access_token),
)
.send()
.ok()?;
response.json::<UserInfo>().ok()
}
fn auth() -> Option<Tokens> {
let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
Ok(flow) => flow,
Err(e) => {
eprintln!("Error authenticating: {}", e);
return None;
}
};
println!("🌐 Please visit for following URL in your browser (CTRL+click if your terminal supports it):");
println!("\n {}\n", flow.verification_uri.clone().unwrap());
let user_code = flow.user_code.clone().unwrap();
println!("👉 And enter code: {}", &user_code);
if let Ok(mut clipboard) = Clipboard::new() {
if clipboard.set_text(user_code).is_ok() {
println!("📋 Code has been successfully copied to clipboard.")
};
};
// Wait for the minimum allowed interval to poll for authentication update
// see: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#step-3-app-polls-github-to-check-if-the-user-authorized-the-device
thread::sleep(FIVE_SECONDS);
match flow.poll(20) {
Ok(creds) => {
let tokens = Tokens {
access_token: creds.token.clone(),
refresh_token: creds.refresh_token.clone(),
};
save_tokens(&tokens);
Some(tokens)
}
Err(e) => {
eprint!("Authentication error: {}", e);
None
}
}
}
/// Return the token saved in the cache file
#[inline]
fn get_tokens_from_cache() -> Option<Tokens> {
let path = get_auth_cache_file_path();
let file = File::open(path).ok()?;
let tokens: Tokens = serde_json::from_reader(file).ok()?;
Some(tokens)
}
/// Returns true if the token is still valid
fn verify_tokens(tokens: &Tokens) -> bool {
let client = reqwest::blocking::Client::new();
let response = client
.get("https://api.github.com/user")
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", tokens.access_token),
)
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.send();
response.map_or(false, |resp| resp.status().is_success())
}
fn refresh_tokens(tokens: &Tokens) -> Option<Tokens> {
println!("Access token must be refreshed.");
println!("Refreshing token...");
let client = reqwest::blocking::Client::new();
let response = client
.post(format!(
"{}auth/refresh-token",
super::USER_BENCHMARK_SERVER_URL
))
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer-Refresh {}", tokens.refresh_token),
)
// it is important to explicitly add an empty body otherwise
// reqwest won't send the request in release build
.body(reqwest::blocking::Body::from(""))
.send();
response.ok()?.json::<Tokens>().ok().map(|new_tokens| {
println!("✅ Token refreshed!");
new_tokens
})
}
/// Return the file path for the auth cache on disk
pub(crate) fn get_auth_cache_file_path() -> PathBuf {
fn get_auth_cache_file_path() -> PathBuf {
let home_dir = dirs::home_dir().expect("an home directory should exist");
let path_dir = home_dir.join(".cache").join("burn").join("burnbench");
#[cfg(test)]
@ -21,26 +172,13 @@ pub(crate) fn get_auth_cache_file_path() -> PathBuf {
path.join("token.txt")
}
/// Returns true if the token is still valid
pub(crate) fn verify_token(token: &str) -> bool {
let client = reqwest::blocking::Client::new();
let response = client
.get("https://api.github.com/user")
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", token))
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.send();
response.map_or(false, |resp| resp.status().is_success())
}
/// Save token in Burn cache directory and adjust file permissions
pub(crate) fn save_token(token: &str) {
fn save_tokens(tokens: &Tokens) {
let path = get_auth_cache_file_path();
fs::create_dir_all(path.parent().expect("path should have a parent directory"))
.expect("directory should be created");
let mut file = File::create(&path).expect("file should be created");
write!(file, "{}", token).expect("token should be written to file");
let file = File::create(&path).expect("file should be created");
serde_json::to_writer_pretty(file, &tokens).expect("Tokens should be saved to cache file.");
// On unix systems we lower the permissions on the cache file to be readable
// just by the current user
#[cfg(unix)]
@ -49,23 +187,23 @@ pub(crate) fn save_token(token: &str) {
println!("✅ Token saved at location: {}", path.to_str().unwrap());
}
/// Return the token saved in the cache file
#[inline]
pub(crate) fn get_token_from_cache() -> Option<String> {
let path = get_auth_cache_file_path();
fs::read_to_string(path)
.ok()
.and_then(|contents| contents.lines().next().map(str::to_string))
}
#[cfg(test)]
use serial_test::serial;
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use std::fs;
#[fixture]
fn tokens() -> Tokens {
Tokens {
access_token: "unique_test_token".to_string(),
refresh_token: "unique_refresh_token".to_string(),
}
}
fn cleanup_test_environment() {
let path = get_auth_cache_file_path();
if path.exists() {
@ -79,67 +217,35 @@ mod tests {
}
}
#[test]
#[rstest]
#[serial]
fn test_save_token_when_file_does_not_exist() {
fn test_save_token_when_file_does_not_exist(tokens: Tokens) {
cleanup_test_environment();
let token = "unique_test_token";
// Ensure the file does not exist
let path = get_auth_cache_file_path();
if path.exists() {
fs::remove_file(&path).unwrap();
}
save_token(token);
assert_eq!(fs::read_to_string(path).unwrap(), token);
save_tokens(&tokens);
let retrieved_tokens = get_tokens_from_cache().unwrap();
assert_eq!(retrieved_tokens.access_token, tokens.access_token);
assert_eq!(retrieved_tokens.refresh_token, tokens.refresh_token);
cleanup_test_environment();
}
#[test]
#[rstest]
#[serial]
fn test_overwrite_saved_token_when_file_already_exists() {
fn test_overwrite_saved_token_when_file_already_exists(tokens: Tokens) {
cleanup_test_environment();
let initial_token = "initial_test_token";
let new_token = "new_test_token";
// Save initial token
save_token(initial_token);
// Save new token that should overwrite the initial one
save_token(new_token);
let path = get_auth_cache_file_path();
assert_eq!(fs::read_to_string(path).unwrap(), new_token);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_get_saved_token_from_cache_when_it_exists() {
cleanup_test_environment();
let token = "existing_test_token";
// Save the token first
save_token(token);
// Now retrieve it
let retrieved_token = get_token_from_cache().unwrap();
assert_eq!(retrieved_token, token);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_return_only_first_line_of_cache_as_token() {
cleanup_test_environment();
let path = get_auth_cache_file_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).expect("directory tree should be created");
}
// Create a file with multiple lines
let mut file = File::create(&path).expect("test file should be created");
write!(file, "first_line_token\nsecond_line\nthird_line")
.expect("test file should contain several lines");
// Test that only the first line is returned as the token
let token = get_token_from_cache().expect("token should be present");
assert_eq!(
token, "first_line_token",
"The token should match only the first line of the file"
);
save_tokens(&tokens);
let new_tokens = Tokens {
access_token: "new_test_token".to_string(),
refresh_token: "new_refresh_token".to_string(),
};
save_tokens(&new_tokens);
let retrieved_tokens = get_tokens_from_cache().unwrap();
assert_eq!(retrieved_tokens.access_token, new_tokens.access_token);
assert_eq!(retrieved_tokens.refresh_token, new_tokens.refresh_token);
cleanup_test_environment();
}
@ -152,7 +258,7 @@ mod tests {
if path.exists() {
fs::remove_file(&path).unwrap();
}
assert!(get_token_from_cache().is_none());
assert!(get_tokens_from_cache().is_none());
cleanup_test_environment();
}
@ -167,7 +273,7 @@ mod tests {
}
File::create(&path).expect("empty file should be created");
assert!(
get_token_from_cache().is_none(),
get_tokens_from_cache().is_none(),
"Expected None for empty cache file, got Some"
);
cleanup_test_environment();

View File

@ -1,31 +1,18 @@
use super::{
auth::{save_token, CLIENT_ID},
App,
};
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
use arboard::Clipboard;
use clap::{Parser, Subcommand, ValueEnum};
use github_device_flow::{self, DeviceFlow};
use serde_json;
use std::fs;
use std::io::{BufRead, BufReader, Result as ioResult};
use std::process::ExitStatus;
use std::{
process::{Command, ExitStatus, Stdio},
thread, time,
fs,
io::{BufRead, BufReader, Result as ioResult},
process::{Command, Stdio},
};
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter};
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
// development
"http://localhost:8000/benchmarks"
} else {
// production
"https://user-benchmark-server-gvtbw64teq-nn.a.run.app/benchmarks"
};
use crate::burnbenchapp::auth::Tokens;
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
use super::auth::get_username;
use super::{auth::get_tokens, App};
/// Base trait to define an application
pub(crate) trait Application {
@ -105,7 +92,7 @@ pub(crate) enum BackendValues {
pub(crate) enum BenchmarkValues {
#[strum(to_string = "binary")]
Binary,
#[strum(to_string = "custom_gelu")]
#[strum(to_string = "custom-gelu")]
CustomGelu,
#[strum(to_string = "data")]
Data,
@ -113,7 +100,7 @@ pub(crate) enum BenchmarkValues {
Matmul,
#[strum(to_string = "unary")]
Unary,
#[strum(to_string = "max_pool2d")]
#[strum(to_string = "max-pool2d")]
MaxPool2d,
}
@ -126,34 +113,17 @@ pub fn execute() {
}
}
/// Create an access token from GitHub Burnbench application and store it
/// to be used with the user benchmark backend.
/// Create an access token from GitHub Burnbench application, store it,
/// and display the name of the authenticated user.
fn command_auth() {
let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
Ok(flow) => flow,
Err(e) => {
eprintln!("Error authenticating: {}", e);
return;
}
};
println!("🌐 Please visit for following URL in your browser (CTRL+click if your terminal supports it):");
println!("\n {}\n", flow.verification_uri.clone().unwrap());
let user_code = flow.user_code.clone().unwrap();
println!("👉 And enter code: {}", &user_code);
if let Ok(mut clipboard) = Clipboard::new() {
if clipboard.set_text(user_code).is_ok() {
println!("📋 Code has been successfully copied to clipboard.")
};
};
// Wait for the minimum allowed interval to poll for authentication update
// see: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#step-3-app-polls-github-to-check-if-the-user-authorized-the-device
thread::sleep(FIVE_SECONDS);
match flow.poll(20) {
Ok(creds) => {
save_token(&creds.token);
}
Err(e) => eprint!("Authentication error: {}", e),
};
get_tokens()
.and_then(|t| get_username(&t.access_token))
.map(|user_info| {
println!("🔑 Your username is: {}", user_info.nickname);
})
.unwrap_or_else(|| {
println!("Failed to display your username.");
});
}
fn command_list() {
@ -168,21 +138,9 @@ fn command_list() {
}
fn command_run(run_args: RunArgs) {
let token = get_token_from_cache();
let mut tokens: Option<Tokens> = None;
if run_args.share {
// Verify if a token is saved
if token.is_none() {
eprintln!("You need to be authenticated to be able to share benchmark results.");
eprintln!("Run the command 'burnbench auth' to authenticate.");
return;
}
// TODO refresh the token when it is expired
// Check for the validity of the saved token
if !verify_token(token.as_deref().unwrap()) {
eprintln!("Your access token is no longer valid.");
eprintln!("Run the command 'burnbench auth' again to get a new token.");
return;
}
tokens = get_tokens();
}
let total_combinations = run_args.backends.len() * run_args.benches.len();
println!(
@ -192,10 +150,11 @@ fn command_run(run_args: RunArgs) {
let mut app = App::new();
app.init();
println!("Running benchmarks...\n");
let access_token = tokens.map(|t| t.access_token);
app.run(
&run_args.benches,
&run_args.backends,
token.as_deref().filter(|_| run_args.share),
access_token.as_deref(),
);
app.cleanup();
}
@ -240,6 +199,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
"{}Benchmarking {} on {}{}",
filler, bench_str, backend_str, filler
);
let url = format!("{}benchmarks", super::USER_BENCHMARK_SERVER_URL);
let mut args = vec![
"-p",
"backend-comparison",
@ -248,12 +208,12 @@ pub(crate) fn run_backend_comparison_benchmarks(
"--features",
&backend_str,
"--target-dir",
BENCHMARKS_TARGET_DIR,
super::BENCHMARKS_TARGET_DIR,
];
if let Some(t) = token {
args.push("--");
args.push("--sharing-url");
args.push(USER_BENCHMARK_SERVER_URL);
args.push(url.as_str());
args.push("--sharing-token");
args.push(t);
}

View File

@ -12,3 +12,12 @@ use tui::TuiApplication as App;
mod term;
#[cfg(not(feature = "tui"))]
use term::TermApplication as App;
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
// development
"http://localhost:8000/"
} else {
// production
"https://user-benchmark-server-gvtbw64teq-nn.a.run.app/"
};

View File

@ -82,7 +82,8 @@ pub fn save<B: Backend>(
serde_json::to_writer_pretty(file, &record)
.expect("Benchmark file should be updated with benchmark results");
// Append the benchmark result filepath in the benchmark_results.tx file of cache folder to be later picked by benchrun
// Append the benchmark result filepath in the benchmark_results.tx file of
// cache folder to be later picked by benchrun
let benchmark_results_path = cache_dir.join("benchmark_results.txt");
let mut benchmark_results_file = fs::OpenOptions::new()
.append(true)
@ -93,39 +94,39 @@ pub fn save<B: Backend>(
.write_all(format!("{}\n", file_path.to_string_lossy()).as_bytes())
.unwrap();
if url.is_some() {
println!("Sharing results...");
let client = reqwest::blocking::Client::new();
let mut headers = HeaderMap::new();
headers.insert(USER_AGENT, "burnbench".parse().unwrap());
headers.insert(ACCEPT, "application/json".parse().unwrap());
headers.insert(
AUTHORIZATION,
format!(
"Bearer {}",
token.expect("An auth token should be provided.")
)
.parse()
.unwrap(),
if let Some(upload_url) = url {
upload_record(
&record,
token.expect("An auth token should be provided."),
upload_url,
);
// post the benchmark record
let response = client
.post(url.expect("A benchmark server URL should be provided."))
.headers(headers)
.json(&record)
.send()
.expect("Request should be sent successfully.");
if response.status().is_success() {
println!("Results shared successfully.");
} else {
println!("Failed to share results. Status: {}", response.status());
}
}
}
Ok(records)
}
fn upload_record(record: &BenchmarkRecord, token: &str, url: &str) {
println!("Sharing results...");
let client = reqwest::blocking::Client::new();
let mut headers = HeaderMap::new();
headers.insert(USER_AGENT, "burnbench".parse().unwrap());
headers.insert(ACCEPT, "application/json".parse().unwrap());
headers.insert(AUTHORIZATION, format!("Bearer {}", token).parse().unwrap());
// post the benchmark record
let response = client
.post(url)
.headers(headers)
.json(record)
.send()
.expect("Request should be sent successfully.");
if response.status().is_success() {
println!("Results shared successfully.");
} else {
println!("Failed to share results. Status: {}", response.status());
}
}
/// Macro to easily serialize each field in a flatten manner.
/// This macro automatically computes the number of fields to serialize
/// and allows specifying a custom serialization key for each field.