mirror of https://github.com/tracel-ai/burn.git
[backend-comparison] Refresh access token and display authenticated user name (#1483)
* [backend-comparison] Serialize both auth tokens to cache file We need to refresh token to be able to renew an expired access token. * [backend-comparison] Refresh access token * [backend-comparison] Display user name with auth command * [backend-comparison] Update README * [backend-comparison] Fix PR comments * [backend-comparison] Fix hyphen consistency in benchmark names * [backend-comparison] Fix release build error when refreshing token The reqwest must have an explicit empty body otherwise the release build returns a 411 when refreshing the tokens without even calling the benchmark server endpoint.
This commit is contained in:
parent
47a84cc980
commit
e8863dafd2
|
@ -208,6 +208,7 @@ dependencies = [
|
|||
"ratatui",
|
||||
"reqwest",
|
||||
"rstest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serial_test",
|
||||
"strum",
|
||||
|
|
|
@ -37,6 +37,7 @@ github-device-flow = { workspace = true }
|
|||
rand = { workspace = true }
|
||||
ratatui = { workspace = true, optional = true }
|
||||
reqwest = {workspace = true, features = ["blocking", "json"]}
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
strum = { workspace = true }
|
||||
strum_macros = { workspace = true }
|
||||
|
@ -54,7 +55,8 @@ name = "binary"
|
|||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "max_pool2d"
|
||||
name = "max-pool2d"
|
||||
path = "benches/max_pool2d.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
|
@ -66,7 +68,8 @@ name = "data"
|
|||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "custom_gelu"
|
||||
name = "custom-gelu"
|
||||
path = "benches/custom_gelu.rs"
|
||||
harness = false
|
||||
|
||||
[[bin]]
|
||||
|
|
|
@ -10,7 +10,10 @@ within the corresponding backend crate.
|
|||
## burnbench CLI
|
||||
|
||||
This crate comes with a CLI binary called `burnbench` which can be executed via
|
||||
`cargo run --bin burnbench`.
|
||||
`cargo run --release --bin burnbench`.
|
||||
|
||||
Note that you need to run the `release` target of `burnbench` otherwise you won't
|
||||
be able to share your benchmark results.
|
||||
|
||||
The end of options argument `--` is used to pass arguments to the `burnbench`
|
||||
application. For instance `cargo run --bin burnbench -- list` passes the `list`
|
||||
|
@ -23,7 +26,7 @@ argument to `burnbench` effectively calling `burnbench list`.
|
|||
To list all the available benches and backends use the `list` command:
|
||||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- list
|
||||
> cargo run --release --bin burnbench -- list
|
||||
Finished dev [unoptimized] target(s) in 0.10s
|
||||
Running `target/debug/burnbench list`
|
||||
Available Backends:
|
||||
|
@ -54,13 +57,13 @@ with the arguments `--benches` and `--backends` respectively. In the following
|
|||
example we execute the `unary` benchmark against the `wgpu-fusion` backend:
|
||||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- run --benches unary --backends wgpu-fusion
|
||||
> cargo run --release --bin burnbench -- run --benches unary --backends wgpu-fusion
|
||||
```
|
||||
|
||||
Shorthands can be used, the following command line is the same:
|
||||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- run -b unary -B wgpu-fusion
|
||||
> cargo run --release --bin burnbench -- run -b unary -B wgpu-fusion
|
||||
```
|
||||
|
||||
Multiple benchmarks and backends can be passed on the same command line. In this
|
||||
|
@ -68,8 +71,7 @@ case, all the combinations of benchmarks with backends will be executed.
|
|||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- run --benches unary binary --backends wgpu-fusion tch-gpu
|
||||
Finished dev [unoptimized] target(s) in 0.09s
|
||||
Running `target/debug/burnbench run --benches unary binary --backends wgpu-fusion wgpu`
|
||||
Running `target/release/burnbench run --benches unary binary --backends wgpu-fusion wgpu`
|
||||
Executing the following benchmark and backend combinations (Total: 4):
|
||||
- Benchmark: unary, Backend: wgpu-fusion
|
||||
- Benchmark: binary, Backend: wgpu-fusion
|
||||
|
@ -88,7 +90,7 @@ Sharing results is opt-in and it is enabled with the `--share` arguments passed
|
|||
to the `run` command:
|
||||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- run --share --benches unary --backends wgpu-fusion
|
||||
> cargo run --release --bin burnbench -- run --share --benches unary --backends wgpu-fusion
|
||||
```
|
||||
|
||||
To be able to upload results you must be authenticated. We only support GitHub
|
||||
|
@ -96,14 +98,26 @@ authentication. To authenticate run the `auth` command, then follow the URL
|
|||
to enter your device code and authorize the Burnbench application:
|
||||
|
||||
```sh
|
||||
> cargo run --bin burnbench -- run auth
|
||||
> cargo run --release --bin burnbench -- auth
|
||||
```
|
||||
|
||||
If everything is fine you should get a confirmation in the terminal that your
|
||||
token has been saved to the burn cache directory.
|
||||
|
||||
We don't store any of your personal information. An anonymized user name will
|
||||
be attributed to you and displayed in the terminal once you are authenticated.
|
||||
For instance:
|
||||
|
||||
```
|
||||
🔑 Your username is: CuteFlame
|
||||
```
|
||||
|
||||
You can now use the `--share` argument to upload and share your benchmarks!
|
||||
|
||||
Note that your access token will be refreshed automatically so you should not
|
||||
need to reauthorize the application again except if your refresh token itself
|
||||
becomes invalid.
|
||||
|
||||
### Terminal UI
|
||||
|
||||
This is a work in progress and is not usable for now.
|
||||
|
|
|
@ -1,18 +1,169 @@
|
|||
use arboard::Clipboard;
|
||||
use burn::serde::{Deserialize, Serialize};
|
||||
use github_device_flow::{self, DeviceFlow};
|
||||
use reqwest;
|
||||
use std::io::Write;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
path::{Path, PathBuf},
|
||||
thread, time,
|
||||
};
|
||||
|
||||
pub(crate) static CLIENT_ID: &str = "Iv1.692f6a61b6086810";
|
||||
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
|
||||
static GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
|
||||
static GITHUB_API_VERSION: &str = "2022-11-28";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct Tokens {
|
||||
/// Token returned once the Burnbench Github app has been authorized by the user.
|
||||
/// This token is used to authenticate the user to the Burn benchmark server.
|
||||
/// This token is a short lived token (about 8 hours).
|
||||
pub access_token: String,
|
||||
/// Along with the access token, a refresh token is provided once the Burnbench
|
||||
/// GitHub app has been authorized by the user.
|
||||
/// This token can be presented to the Burn benchmark server in order to re-issue
|
||||
/// a new access token for the user.
|
||||
/// This token is longer lived (around 6 months).
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct UserInfo {
|
||||
pub nickname: String,
|
||||
}
|
||||
|
||||
/// Retrieve cached tokens and refresh them if necessary then save the new tokens.
|
||||
/// If there is no cached token or if the access token cannot be resfresh then
|
||||
/// ask for the user to reauthorize the Burnbench github application.
|
||||
pub(crate) fn get_tokens() -> Option<Tokens> {
|
||||
get_tokens_from_cache().map_or_else(
|
||||
// no token saved yet
|
||||
auth,
|
||||
// cached tokens found
|
||||
|tokens| {
|
||||
if verify_tokens(&tokens) {
|
||||
Some(tokens)
|
||||
} else {
|
||||
refresh_tokens(&tokens).map_or_else(
|
||||
|| {
|
||||
println!("⚠ Cannot refresh the access token. You need to reauthorize the Burnbench application.");
|
||||
auth()
|
||||
},
|
||||
|new_tokens| {
|
||||
save_tokens(&new_tokens);
|
||||
Some(new_tokens)
|
||||
})
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the authenticated user name from access token
|
||||
pub(crate) fn get_username(access_token: &str) -> Option<UserInfo> {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(format!("{}users/me", super::USER_BENCHMARK_SERVER_URL))
|
||||
.header(reqwest::header::USER_AGENT, "burnbench")
|
||||
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", access_token),
|
||||
)
|
||||
.send()
|
||||
.ok()?;
|
||||
response.json::<UserInfo>().ok()
|
||||
}
|
||||
|
||||
fn auth() -> Option<Tokens> {
|
||||
let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
|
||||
Ok(flow) => flow,
|
||||
Err(e) => {
|
||||
eprintln!("Error authenticating: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
println!("🌐 Please visit for following URL in your browser (CTRL+click if your terminal supports it):");
|
||||
println!("\n {}\n", flow.verification_uri.clone().unwrap());
|
||||
let user_code = flow.user_code.clone().unwrap();
|
||||
println!("👉 And enter code: {}", &user_code);
|
||||
if let Ok(mut clipboard) = Clipboard::new() {
|
||||
if clipboard.set_text(user_code).is_ok() {
|
||||
println!("📋 Code has been successfully copied to clipboard.")
|
||||
};
|
||||
};
|
||||
// Wait for the minimum allowed interval to poll for authentication update
|
||||
// see: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#step-3-app-polls-github-to-check-if-the-user-authorized-the-device
|
||||
thread::sleep(FIVE_SECONDS);
|
||||
match flow.poll(20) {
|
||||
Ok(creds) => {
|
||||
let tokens = Tokens {
|
||||
access_token: creds.token.clone(),
|
||||
refresh_token: creds.refresh_token.clone(),
|
||||
};
|
||||
save_tokens(&tokens);
|
||||
Some(tokens)
|
||||
}
|
||||
Err(e) => {
|
||||
eprint!("Authentication error: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the token saved in the cache file
|
||||
#[inline]
|
||||
fn get_tokens_from_cache() -> Option<Tokens> {
|
||||
let path = get_auth_cache_file_path();
|
||||
let file = File::open(path).ok()?;
|
||||
let tokens: Tokens = serde_json::from_reader(file).ok()?;
|
||||
Some(tokens)
|
||||
}
|
||||
|
||||
/// Returns true if the token is still valid
|
||||
fn verify_tokens(tokens: &Tokens) -> bool {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get("https://api.github.com/user")
|
||||
.header(reqwest::header::USER_AGENT, "burnbench")
|
||||
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", tokens.access_token),
|
||||
)
|
||||
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
|
||||
.send();
|
||||
response.map_or(false, |resp| resp.status().is_success())
|
||||
}
|
||||
|
||||
fn refresh_tokens(tokens: &Tokens) -> Option<Tokens> {
|
||||
println!("Access token must be refreshed.");
|
||||
println!("Refreshing token...");
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.post(format!(
|
||||
"{}auth/refresh-token",
|
||||
super::USER_BENCHMARK_SERVER_URL
|
||||
))
|
||||
.header(reqwest::header::USER_AGENT, "burnbench")
|
||||
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer-Refresh {}", tokens.refresh_token),
|
||||
)
|
||||
// it is important to explicitly add an empty body otherwise
|
||||
// reqwest won't send the request in release build
|
||||
.body(reqwest::blocking::Body::from(""))
|
||||
.send();
|
||||
response.ok()?.json::<Tokens>().ok().map(|new_tokens| {
|
||||
println!("✅ Token refreshed!");
|
||||
new_tokens
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the file path for the auth cache on disk
|
||||
pub(crate) fn get_auth_cache_file_path() -> PathBuf {
|
||||
fn get_auth_cache_file_path() -> PathBuf {
|
||||
let home_dir = dirs::home_dir().expect("an home directory should exist");
|
||||
let path_dir = home_dir.join(".cache").join("burn").join("burnbench");
|
||||
#[cfg(test)]
|
||||
|
@ -21,26 +172,13 @@ pub(crate) fn get_auth_cache_file_path() -> PathBuf {
|
|||
path.join("token.txt")
|
||||
}
|
||||
|
||||
/// Returns true if the token is still valid
|
||||
pub(crate) fn verify_token(token: &str) -> bool {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get("https://api.github.com/user")
|
||||
.header(reqwest::header::USER_AGENT, "burnbench")
|
||||
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
|
||||
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", token))
|
||||
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
|
||||
.send();
|
||||
response.map_or(false, |resp| resp.status().is_success())
|
||||
}
|
||||
|
||||
/// Save token in Burn cache directory and adjust file permissions
|
||||
pub(crate) fn save_token(token: &str) {
|
||||
fn save_tokens(tokens: &Tokens) {
|
||||
let path = get_auth_cache_file_path();
|
||||
fs::create_dir_all(path.parent().expect("path should have a parent directory"))
|
||||
.expect("directory should be created");
|
||||
let mut file = File::create(&path).expect("file should be created");
|
||||
write!(file, "{}", token).expect("token should be written to file");
|
||||
let file = File::create(&path).expect("file should be created");
|
||||
serde_json::to_writer_pretty(file, &tokens).expect("Tokens should be saved to cache file.");
|
||||
// On unix systems we lower the permissions on the cache file to be readable
|
||||
// just by the current user
|
||||
#[cfg(unix)]
|
||||
|
@ -49,23 +187,23 @@ pub(crate) fn save_token(token: &str) {
|
|||
println!("✅ Token saved at location: {}", path.to_str().unwrap());
|
||||
}
|
||||
|
||||
/// Return the token saved in the cache file
|
||||
#[inline]
|
||||
pub(crate) fn get_token_from_cache() -> Option<String> {
|
||||
let path = get_auth_cache_file_path();
|
||||
fs::read_to_string(path)
|
||||
.ok()
|
||||
.and_then(|contents| contents.lines().next().map(str::to_string))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use serial_test::serial;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
use std::fs;
|
||||
|
||||
#[fixture]
|
||||
fn tokens() -> Tokens {
|
||||
Tokens {
|
||||
access_token: "unique_test_token".to_string(),
|
||||
refresh_token: "unique_refresh_token".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup_test_environment() {
|
||||
let path = get_auth_cache_file_path();
|
||||
if path.exists() {
|
||||
|
@ -79,67 +217,35 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rstest]
|
||||
#[serial]
|
||||
fn test_save_token_when_file_does_not_exist() {
|
||||
fn test_save_token_when_file_does_not_exist(tokens: Tokens) {
|
||||
cleanup_test_environment();
|
||||
let token = "unique_test_token";
|
||||
// Ensure the file does not exist
|
||||
let path = get_auth_cache_file_path();
|
||||
if path.exists() {
|
||||
fs::remove_file(&path).unwrap();
|
||||
}
|
||||
save_token(token);
|
||||
assert_eq!(fs::read_to_string(path).unwrap(), token);
|
||||
save_tokens(&tokens);
|
||||
let retrieved_tokens = get_tokens_from_cache().unwrap();
|
||||
assert_eq!(retrieved_tokens.access_token, tokens.access_token);
|
||||
assert_eq!(retrieved_tokens.refresh_token, tokens.refresh_token);
|
||||
cleanup_test_environment();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[rstest]
|
||||
#[serial]
|
||||
fn test_overwrite_saved_token_when_file_already_exists() {
|
||||
fn test_overwrite_saved_token_when_file_already_exists(tokens: Tokens) {
|
||||
cleanup_test_environment();
|
||||
let initial_token = "initial_test_token";
|
||||
let new_token = "new_test_token";
|
||||
// Save initial token
|
||||
save_token(initial_token);
|
||||
// Save new token that should overwrite the initial one
|
||||
save_token(new_token);
|
||||
let path = get_auth_cache_file_path();
|
||||
assert_eq!(fs::read_to_string(path).unwrap(), new_token);
|
||||
cleanup_test_environment();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn test_get_saved_token_from_cache_when_it_exists() {
|
||||
cleanup_test_environment();
|
||||
let token = "existing_test_token";
|
||||
// Save the token first
|
||||
save_token(token);
|
||||
// Now retrieve it
|
||||
let retrieved_token = get_token_from_cache().unwrap();
|
||||
assert_eq!(retrieved_token, token);
|
||||
cleanup_test_environment();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn test_return_only_first_line_of_cache_as_token() {
|
||||
cleanup_test_environment();
|
||||
let path = get_auth_cache_file_path();
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).expect("directory tree should be created");
|
||||
}
|
||||
// Create a file with multiple lines
|
||||
let mut file = File::create(&path).expect("test file should be created");
|
||||
write!(file, "first_line_token\nsecond_line\nthird_line")
|
||||
.expect("test file should contain several lines");
|
||||
// Test that only the first line is returned as the token
|
||||
let token = get_token_from_cache().expect("token should be present");
|
||||
assert_eq!(
|
||||
token, "first_line_token",
|
||||
"The token should match only the first line of the file"
|
||||
);
|
||||
save_tokens(&tokens);
|
||||
let new_tokens = Tokens {
|
||||
access_token: "new_test_token".to_string(),
|
||||
refresh_token: "new_refresh_token".to_string(),
|
||||
};
|
||||
save_tokens(&new_tokens);
|
||||
let retrieved_tokens = get_tokens_from_cache().unwrap();
|
||||
assert_eq!(retrieved_tokens.access_token, new_tokens.access_token);
|
||||
assert_eq!(retrieved_tokens.refresh_token, new_tokens.refresh_token);
|
||||
cleanup_test_environment();
|
||||
}
|
||||
|
||||
|
@ -152,7 +258,7 @@ mod tests {
|
|||
if path.exists() {
|
||||
fs::remove_file(&path).unwrap();
|
||||
}
|
||||
assert!(get_token_from_cache().is_none());
|
||||
assert!(get_tokens_from_cache().is_none());
|
||||
cleanup_test_environment();
|
||||
}
|
||||
|
||||
|
@ -167,7 +273,7 @@ mod tests {
|
|||
}
|
||||
File::create(&path).expect("empty file should be created");
|
||||
assert!(
|
||||
get_token_from_cache().is_none(),
|
||||
get_tokens_from_cache().is_none(),
|
||||
"Expected None for empty cache file, got Some"
|
||||
);
|
||||
cleanup_test_environment();
|
||||
|
|
|
@ -1,31 +1,18 @@
|
|||
use super::{
|
||||
auth::{save_token, CLIENT_ID},
|
||||
App,
|
||||
};
|
||||
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
|
||||
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
|
||||
use arboard::Clipboard;
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use github_device_flow::{self, DeviceFlow};
|
||||
use serde_json;
|
||||
use std::fs;
|
||||
use std::io::{BufRead, BufReader, Result as ioResult};
|
||||
use std::process::ExitStatus;
|
||||
use std::{
|
||||
process::{Command, ExitStatus, Stdio},
|
||||
thread, time,
|
||||
fs,
|
||||
io::{BufRead, BufReader, Result as ioResult},
|
||||
process::{Command, Stdio},
|
||||
};
|
||||
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::{Display, EnumIter};
|
||||
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
|
||||
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
|
||||
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
|
||||
// development
|
||||
"http://localhost:8000/benchmarks"
|
||||
} else {
|
||||
// production
|
||||
"https://user-benchmark-server-gvtbw64teq-nn.a.run.app/benchmarks"
|
||||
};
|
||||
|
||||
use crate::burnbenchapp::auth::Tokens;
|
||||
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
|
||||
|
||||
use super::auth::get_username;
|
||||
use super::{auth::get_tokens, App};
|
||||
|
||||
/// Base trait to define an application
|
||||
pub(crate) trait Application {
|
||||
|
@ -105,7 +92,7 @@ pub(crate) enum BackendValues {
|
|||
pub(crate) enum BenchmarkValues {
|
||||
#[strum(to_string = "binary")]
|
||||
Binary,
|
||||
#[strum(to_string = "custom_gelu")]
|
||||
#[strum(to_string = "custom-gelu")]
|
||||
CustomGelu,
|
||||
#[strum(to_string = "data")]
|
||||
Data,
|
||||
|
@ -113,7 +100,7 @@ pub(crate) enum BenchmarkValues {
|
|||
Matmul,
|
||||
#[strum(to_string = "unary")]
|
||||
Unary,
|
||||
#[strum(to_string = "max_pool2d")]
|
||||
#[strum(to_string = "max-pool2d")]
|
||||
MaxPool2d,
|
||||
}
|
||||
|
||||
|
@ -126,34 +113,17 @@ pub fn execute() {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create an access token from GitHub Burnbench application and store it
|
||||
/// to be used with the user benchmark backend.
|
||||
/// Create an access token from GitHub Burnbench application, store it,
|
||||
/// and display the name of the authenticated user.
|
||||
fn command_auth() {
|
||||
let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
|
||||
Ok(flow) => flow,
|
||||
Err(e) => {
|
||||
eprintln!("Error authenticating: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
println!("🌐 Please visit for following URL in your browser (CTRL+click if your terminal supports it):");
|
||||
println!("\n {}\n", flow.verification_uri.clone().unwrap());
|
||||
let user_code = flow.user_code.clone().unwrap();
|
||||
println!("👉 And enter code: {}", &user_code);
|
||||
if let Ok(mut clipboard) = Clipboard::new() {
|
||||
if clipboard.set_text(user_code).is_ok() {
|
||||
println!("📋 Code has been successfully copied to clipboard.")
|
||||
};
|
||||
};
|
||||
// Wait for the minimum allowed interval to poll for authentication update
|
||||
// see: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#step-3-app-polls-github-to-check-if-the-user-authorized-the-device
|
||||
thread::sleep(FIVE_SECONDS);
|
||||
match flow.poll(20) {
|
||||
Ok(creds) => {
|
||||
save_token(&creds.token);
|
||||
}
|
||||
Err(e) => eprint!("Authentication error: {}", e),
|
||||
};
|
||||
get_tokens()
|
||||
.and_then(|t| get_username(&t.access_token))
|
||||
.map(|user_info| {
|
||||
println!("🔑 Your username is: {}", user_info.nickname);
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
println!("Failed to display your username.");
|
||||
});
|
||||
}
|
||||
|
||||
fn command_list() {
|
||||
|
@ -168,21 +138,9 @@ fn command_list() {
|
|||
}
|
||||
|
||||
fn command_run(run_args: RunArgs) {
|
||||
let token = get_token_from_cache();
|
||||
let mut tokens: Option<Tokens> = None;
|
||||
if run_args.share {
|
||||
// Verify if a token is saved
|
||||
if token.is_none() {
|
||||
eprintln!("You need to be authenticated to be able to share benchmark results.");
|
||||
eprintln!("Run the command 'burnbench auth' to authenticate.");
|
||||
return;
|
||||
}
|
||||
// TODO refresh the token when it is expired
|
||||
// Check for the validity of the saved token
|
||||
if !verify_token(token.as_deref().unwrap()) {
|
||||
eprintln!("Your access token is no longer valid.");
|
||||
eprintln!("Run the command 'burnbench auth' again to get a new token.");
|
||||
return;
|
||||
}
|
||||
tokens = get_tokens();
|
||||
}
|
||||
let total_combinations = run_args.backends.len() * run_args.benches.len();
|
||||
println!(
|
||||
|
@ -192,10 +150,11 @@ fn command_run(run_args: RunArgs) {
|
|||
let mut app = App::new();
|
||||
app.init();
|
||||
println!("Running benchmarks...\n");
|
||||
let access_token = tokens.map(|t| t.access_token);
|
||||
app.run(
|
||||
&run_args.benches,
|
||||
&run_args.backends,
|
||||
token.as_deref().filter(|_| run_args.share),
|
||||
access_token.as_deref(),
|
||||
);
|
||||
app.cleanup();
|
||||
}
|
||||
|
@ -240,6 +199,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
"{}Benchmarking {} on {}{}",
|
||||
filler, bench_str, backend_str, filler
|
||||
);
|
||||
let url = format!("{}benchmarks", super::USER_BENCHMARK_SERVER_URL);
|
||||
let mut args = vec![
|
||||
"-p",
|
||||
"backend-comparison",
|
||||
|
@ -248,12 +208,12 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
"--features",
|
||||
&backend_str,
|
||||
"--target-dir",
|
||||
BENCHMARKS_TARGET_DIR,
|
||||
super::BENCHMARKS_TARGET_DIR,
|
||||
];
|
||||
if let Some(t) = token {
|
||||
args.push("--");
|
||||
args.push("--sharing-url");
|
||||
args.push(USER_BENCHMARK_SERVER_URL);
|
||||
args.push(url.as_str());
|
||||
args.push("--sharing-token");
|
||||
args.push(t);
|
||||
}
|
||||
|
|
|
@ -12,3 +12,12 @@ use tui::TuiApplication as App;
|
|||
mod term;
|
||||
#[cfg(not(feature = "tui"))]
|
||||
use term::TermApplication as App;
|
||||
|
||||
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
|
||||
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
|
||||
// development
|
||||
"http://localhost:8000/"
|
||||
} else {
|
||||
// production
|
||||
"https://user-benchmark-server-gvtbw64teq-nn.a.run.app/"
|
||||
};
|
||||
|
|
|
@ -82,7 +82,8 @@ pub fn save<B: Backend>(
|
|||
serde_json::to_writer_pretty(file, &record)
|
||||
.expect("Benchmark file should be updated with benchmark results");
|
||||
|
||||
// Append the benchmark result filepath in the benchmark_results.tx file of cache folder to be later picked by benchrun
|
||||
// Append the benchmark result filepath in the benchmark_results.tx file of
|
||||
// cache folder to be later picked by benchrun
|
||||
let benchmark_results_path = cache_dir.join("benchmark_results.txt");
|
||||
let mut benchmark_results_file = fs::OpenOptions::new()
|
||||
.append(true)
|
||||
|
@ -93,39 +94,39 @@ pub fn save<B: Backend>(
|
|||
.write_all(format!("{}\n", file_path.to_string_lossy()).as_bytes())
|
||||
.unwrap();
|
||||
|
||||
if url.is_some() {
|
||||
println!("Sharing results...");
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(USER_AGENT, "burnbench".parse().unwrap());
|
||||
headers.insert(ACCEPT, "application/json".parse().unwrap());
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
format!(
|
||||
"Bearer {}",
|
||||
token.expect("An auth token should be provided.")
|
||||
)
|
||||
.parse()
|
||||
.unwrap(),
|
||||
if let Some(upload_url) = url {
|
||||
upload_record(
|
||||
&record,
|
||||
token.expect("An auth token should be provided."),
|
||||
upload_url,
|
||||
);
|
||||
// post the benchmark record
|
||||
let response = client
|
||||
.post(url.expect("A benchmark server URL should be provided."))
|
||||
.headers(headers)
|
||||
.json(&record)
|
||||
.send()
|
||||
.expect("Request should be sent successfully.");
|
||||
if response.status().is_success() {
|
||||
println!("Results shared successfully.");
|
||||
} else {
|
||||
println!("Failed to share results. Status: {}", response.status());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn upload_record(record: &BenchmarkRecord, token: &str, url: &str) {
|
||||
println!("Sharing results...");
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(USER_AGENT, "burnbench".parse().unwrap());
|
||||
headers.insert(ACCEPT, "application/json".parse().unwrap());
|
||||
headers.insert(AUTHORIZATION, format!("Bearer {}", token).parse().unwrap());
|
||||
// post the benchmark record
|
||||
let response = client
|
||||
.post(url)
|
||||
.headers(headers)
|
||||
.json(record)
|
||||
.send()
|
||||
.expect("Request should be sent successfully.");
|
||||
if response.status().is_success() {
|
||||
println!("Results shared successfully.");
|
||||
} else {
|
||||
println!("Failed to share results. Status: {}", response.status());
|
||||
}
|
||||
}
|
||||
|
||||
/// Macro to easily serialize each field in a flatten manner.
|
||||
/// This macro automatically computes the number of fields to serialize
|
||||
/// and allows specifying a custom serialization key for each field.
|
||||
|
|
Loading…
Reference in New Issue