Add `burn::data::network::downloader` (#1283)

Guillaume Lagrange 2024-02-10 11:54:33 -05:00 committed by GitHub
parent fb6cc2db62
commit 88f5a3e88c
18 changed files with 106 additions and 115 deletions
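For orientation, here is a minimal sketch of how the relocated downloader is meant to be called through the umbrella `burn` crate once this change lands, assuming the new `network` feature is enabled; the URL and file name are hypothetical:

use burn::data::network::downloader;

fn main() {
    // No async setup is needed at the call site: the #[tokio::main] attribute
    // on download_file_as_bytes spins up its own single-threaded runtime and
    // blocks until the download completes.
    let bytes = downloader::download_file_as_bytes(
        "https://example.com/archive.gz", // hypothetical URL
        "Downloading archive.gz",
    );
    println!("downloaded {} bytes", bytes.len());
}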

View File

@@ -16,6 +16,8 @@ jobs:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
with:
crate: burn-dataset
needs:
- publish-burn-common
secrets: inherit
publish-burn-common:

Cargo.lock generated
View File

@@ -340,9 +340,12 @@ dependencies = [
"dashmap",
"derive-new",
"getrandom",
"indicatif",
"rand",
"reqwest",
"serde",
"spin",
"tokio",
"uuid",
"web-time",
]
@@ -399,6 +402,7 @@ dependencies = [
name = "burn-dataset"
version = "0.13.0"
dependencies = [
"burn-common",
"csv",
"derive-new",
"dirs 5.0.1",
@@ -408,12 +412,10 @@ dependencies = [
"globwalk",
"hound",
"image",
"indicatif",
"r2d2",
"r2d2_sqlite",
"rand",
"rayon",
"reqwest",
"rmp-serde",
"rstest",
"rusqlite",
@@ -425,7 +427,6 @@ dependencies = [
"strum_macros",
"tempfile",
"thiserror",
"tokio",
]
[[package]]

View File

@@ -15,6 +15,7 @@ default = ["std"]
std = ["rand/std"]
doc = ["default"]
wasm-sync = []
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
@@ -31,8 +32,13 @@ uuid = { workspace = true }
derive-new = { workspace = true }
serde = { workspace = true }
# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
[dev-dependencies]
dashmap = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]

View File

@@ -26,3 +26,7 @@ pub mod benchmark;
pub mod reader;
extern crate alloc;
/// Network utilities.
#[cfg(feature = "network")]
pub mod network;
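Crates that depend on `burn-common` directly can gate their own call sites the same way; a minimal sketch, assuming the downstream crate defines its own `network` feature that forwards to `burn-common/network` (the `fetch` helper is illustrative):

// Compiled only when this crate's "network" feature is active.
#[cfg(feature = "network")]
pub fn fetch(url: &str) -> Vec<u8> {
    burn_common::network::downloader::download_file_as_bytes(url, "Downloading")
}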

View File

@@ -0,0 +1,57 @@
/// Network download utilities.
pub mod downloader {
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
#[cfg(feature = "std")]
use std::io::Write;
/// Download the file at the specified URL.
/// File download progress is reported with the help of a [progress bar](indicatif).
///
/// # Arguments
///
/// * `url` - The file URL to download.
/// * `message` - The message to display on the progress bar during download.
///
/// # Returns
///
/// A vector of bytes containing the downloaded file data.
#[cfg(feature = "std")]
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
// Get file from web
let mut response = Client::new().get(url).send().await.unwrap();
let total_size = response.content_length().unwrap();
// Pretty progress bar
let pb = ProgressBar::new(total_size);
let msg = message.to_owned();
pb.set_style(
ProgressStyle::with_template(
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.with_key(
"eta",
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
},
)
.progress_chars("#>-"),
);
pb.set_message(msg.clone());
// Read stream into bytes
let mut downloaded: u64 = 0;
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
while let Some(chunk) = response.chunk().await.unwrap() {
let num_bytes = bytes.write(&chunk).unwrap();
let new = std::cmp::min(downloaded + (num_bytes as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(msg);
bytes
}
}
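A note on the design: `#[tokio::main(flavor = "current_thread")]` on an `async fn` rewrites it into a plain blocking function that builds a single-threaded runtime and drives the async body to completion, so callers need no async context of their own. A hand-written sketch of the same pattern, with a stand-in body so it compiles without network access:

fn fetch_blocking(url: String) -> usize {
    // Roughly the shape the attribute generates: build a current-thread
    // runtime, then block on the async body.
    tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async move {
            // Stand-in for the download logic shown above.
            url.len()
        })
}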

View File

@@ -57,9 +57,10 @@ doc = [
"burn-wgpu/doc",
]
dataset = ["burn-dataset"]
network = ["burn-common/network"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
vision = ["burn-dataset?/vision"]
vision = ["burn-dataset?/vision", "burn-common/network"]
wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

View File

@@ -7,3 +7,9 @@ pub mod dataloader;
pub mod dataset {
pub use burn_dataset::*;
}
/// Network module.
#[cfg(feature = "network")]
pub mod network {
pub use burn_common::network::*;
}
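Since `network` here is a plain re-export, downstream code can reach the downloader through either path: `use burn::data::network::downloader;` and `use burn_common::network::downloader;` name the same module.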

View File

@@ -21,7 +21,7 @@ fake = ["dep:fake"]
sqlite = ["__sqlite-shared", "dep:rusqlite"]
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
vision = ["dep:flate2", "dep:globwalk", "dep:image", "dep:indicatif", "dep:reqwest", "dep:tokio"]
vision = ["dep:flate2", "dep:globwalk", "dep:burn-common"]
# internal
__sqlite-shared = [
@@ -33,20 +33,21 @@ __sqlite-shared = [
]
[dependencies]
burn-common = { path = "../burn-common", version = "0.13.0", optional = true, features = [
"network",
] }
csv = { workspace = true }
derive-new = { workspace = true }
dirs = { workspace = true }
fake = { workspace = true, optional = true }
flate2 = { workspace = true, optional = true }
gix-tempfile = { workspace = true, optional = true }
globwalk = { workspace = true, optional = true}
globwalk = { workspace = true, optional = true }
hound = { workspace = true, optional = true }
image = { workspace = true, optional = true }
indicatif = { workspace = true, optional = true }
r2d2 = { workspace = true, optional = true }
r2d2_sqlite = { workspace = true, optional = true }
rand = { workspace = true, features = ["std"] }
reqwest = { workspace = true, optional = true }
rmp-serde = { workspace = true }
rusqlite = { workspace = true, optional = true }
sanitize-filename = { workspace = true }
@@ -57,7 +58,6 @@ strum = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, optional = true }
[dev-dependencies]
rayon = { workspace = true }
@@ -68,4 +68,4 @@ fake = { workspace = true }
normal = ["strum", "strum_macros"]
[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]

View File

@@ -1,44 +0,0 @@
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::cmp::min;
use std::io::Write;
/// Download the file at the specified url to a bytes vector.
/// File download progress is reported with the help of a [progress bar](indicatif).
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
// Get file from web
let mut response = Client::new().get(url).send().await.unwrap();
let total_size = response.content_length().unwrap();
// Pretty progress bar
let pb = ProgressBar::new(total_size);
let msg = message.to_owned();
pb.set_style(
ProgressStyle::with_template(
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.with_key(
"eta",
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
},
)
.progress_chars("#>-"),
);
pb.set_message(msg.clone());
// Read stream into bytes
let mut downloaded: u64 = 0;
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
while let Some(chunk) = response.chunk().await.unwrap() {
let num_bytes = bytes.write(&chunk).unwrap();
let new = min(downloaded + (num_bytes as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(msg);
bytes
}

View File

@@ -10,6 +10,8 @@ use crate::{
Dataset, InMemDataset,
};
use burn_common::network::downloader::download_file_as_bytes;
// CVDF mirror of http://yann.lecun.com/exdb/mnist/
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";
@@ -151,7 +153,7 @@ impl MNISTDataset {
if !file_name.exists() {
// Download gzip file
let bytes = super::downloader::download_file_as_bytes(&format!("{URL}{name}.gz"), name);
let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
// Create file to write the downloaded content to
let mut output_file = File::create(&file_name).unwrap();

View File

@@ -1,4 +1,3 @@
mod downloader;
mod image_folder;
mod mnist;

View File

@@ -58,6 +58,9 @@ wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]
# Network utils
network = ["burn-core/network"]
# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]

View File

@@ -80,6 +80,7 @@
//! - `autodiff`: Makes available the Autodiff backend
//! - Others:
//! - `std`: Activates the standard library (deactivate for no_std)
//! - `network`: Enables network utilities (currently, only a file downloader with progress bar)
//! - `experimental-named-tensor`: Enables named tensors (experimental)
pub use burn_core::*;

View File

@@ -13,7 +13,7 @@ tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
[dependencies]
burn = {path = "../../burn", features = ["train", "vision"]}
burn = { path = "../../burn", features = ["train", "vision", "network"] }
# File download
flate2 = { workspace = true }

View File

@@ -2,7 +2,7 @@ use flate2::read::GzDecoder;
use std::path::{Path, PathBuf};
use tar::Archive;
use burn::data::dataset::vision::ImageFolderDataset;
use burn::data::{dataset::vision::ImageFolderDataset, network::downloader};
/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).
@@ -44,7 +44,7 @@ fn download() -> PathBuf {
let labels_file = cifar_dir.join("labels.txt");
if !labels_file.exists() {
// Download gzip file
let bytes = super::downloader::download_file_as_bytes(URL, "cifar10.tgz");
let bytes = downloader::download_file_as_bytes(URL, "cifar10.tgz");
// Decode gzip file content and unpack archive
let gz_buffer = GzDecoder::new(&bytes[..]);

View File

@@ -1,46 +0,0 @@
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::cmp::min;
use std::io::Write;
/// Download the file at the specified url to a bytes vector.
/// File download progress is reported with the help of a [progress bar](indicatif).
///
/// Taken from [burn-dataset](https://github.com/tracel-ai/burn/blob/main/burn-dataset/src/vision/downloader.rs).
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
// Get file from web
let mut response = Client::new().get(url).send().await.unwrap();
let total_size = response.content_length().unwrap();
// Pretty progress bar
let pb = ProgressBar::new(total_size);
let msg = message.to_owned();
pb.set_style(
ProgressStyle::with_template(
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.with_key(
"eta",
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
},
)
.progress_chars("#>-"),
);
pb.set_message(msg.clone());
// Read stream into bytes
let mut downloaded: u64 = 0;
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
while let Some(chunk) = response.chunk().await.unwrap() {
let num_bytes = bytes.write(&chunk).unwrap();
let new = min(downloaded + (num_bytes as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(msg);
bytes
}

View File

@@ -1,5 +1,4 @@
pub mod data;
pub mod dataset;
mod downloader;
pub mod model;
pub mod training;

View File

@@ -151,14 +151,14 @@ impl<B: Backend> Model<B> {
let output = self.model.forward(input);
// Convert the model output into a probability distribution using the softmax function
let probabilies = softmax(output, 1);
let probabilities = softmax(output, 1);
#[cfg(not(target_family = "wasm"))]
let result = probabilies.into_data().convert::<f32>().value;
let result = probabilities.into_data().convert::<f32>().value;
// Forces the result to be computed
#[cfg(target_family = "wasm")]
let result = probabilies.into_data().await.convert::<f32>().value;
let result = probabilities.into_data().await.convert::<f32>().value;
result
}
@@ -173,18 +173,18 @@ pub struct InferenceResult {
}
/// Returns the top 5 classes and converts them into a JsValue
fn top_5_classes(probabilies: Vec<f32>) -> Result<JsValue, JsValue> {
fn top_5_classes(probabilities: Vec<f32>) -> Result<JsValue, JsValue> {
// Convert the probabilities into a vector of (index, probability)
let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect();
let mut probabilities: Vec<_> = probabilities.iter().enumerate().collect();
// Sort the probabilities in descending order
probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
// Take the top 5 probabilities
probabilies.truncate(5);
probabilities.truncate(5);
// Convert the probabilities into InferenceResult
let result: Vec<InferenceResult> = probabilies
let result: Vec<InferenceResult> = probabilities
.into_iter()
.map(|(index, probability)| InferenceResult {
index,