mirror of https://github.com/tracel-ai/burn.git
Add `burn::data::network::downloader` (#1283)
Parent: fb6cc2db62
Commit: 88f5a3e88c

@@ -16,6 +16,8 @@ jobs:
    uses: tracel-ai/burn/.github/workflows/publish-template.yml@main
    with:
      crate: burn-dataset
    needs:
      - publish-burn-common
    secrets: inherit

  publish-burn-common:

@@ -340,9 +340,12 @@ dependencies = [
 "dashmap",
 "derive-new",
 "getrandom",
 "indicatif",
 "rand",
 "reqwest",
 "serde",
 "spin",
 "tokio",
 "uuid",
 "web-time",
]

@@ -399,6 +402,7 @@ dependencies = [
name = "burn-dataset"
version = "0.13.0"
dependencies = [
 "burn-common",
 "csv",
 "derive-new",
 "dirs 5.0.1",

@@ -408,12 +412,10 @@ dependencies = [
 "globwalk",
 "hound",
 "image",
 "indicatif",
 "r2d2",
 "r2d2_sqlite",
 "rand",
 "rayon",
 "reqwest",
 "rmp-serde",
 "rstest",
 "rusqlite",

@@ -425,7 +427,6 @@ dependencies = [
 "strum_macros",
 "tempfile",
 "thiserror",
 "tokio",
]

[[package]]

@@ -15,6 +15,7 @@ default = ["std"]
std = ["rand/std"]
doc = ["default"]
wasm-sync = []
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]

[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }

@@ -31,8 +32,13 @@ uuid = { workspace = true }
derive-new = { workspace = true }
serde = { workspace = true }

# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }

[dev-dependencies]
dashmap = { workspace = true }

[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]

@@ -26,3 +26,7 @@ pub mod benchmark;
pub mod reader;

extern crate alloc;

/// Network utilities.
#[cfg(feature = "network")]
pub mod network;

@@ -0,0 +1,57 @@
/// Network download utilities.
pub mod downloader {
    use indicatif::{ProgressBar, ProgressState, ProgressStyle};
    use reqwest::Client;
    #[cfg(feature = "std")]
    use std::io::Write;

    /// Download the file at the specified url.
    /// File download progress is reported with the help of a [progress bar](indicatif).
    ///
    /// # Arguments
    ///
    /// * `url` - The file URL to download.
    /// * `message` - The message to display on the progress bar during download.
    ///
    /// # Returns
    ///
    /// A vector of bytes containing the downloaded file data.
    #[cfg(feature = "std")]
    #[tokio::main(flavor = "current_thread")]
    pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
        // Get file from web
        let mut response = Client::new().get(url).send().await.unwrap();
        let total_size = response.content_length().unwrap();

        // Pretty progress bar
        let pb = ProgressBar::new(total_size);
        let msg = message.to_owned();
        pb.set_style(
            ProgressStyle::with_template(
                "{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
            )
            .unwrap()
            .with_key(
                "eta",
                |state: &ProgressState, w: &mut dyn std::fmt::Write| {
                    write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
                },
            )
            .progress_chars("▬ "),
        );
        pb.set_message(msg.clone());

        // Read stream into bytes
        let mut downloaded: u64 = 0;
        let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
        while let Some(chunk) = response.chunk().await.unwrap() {
            let num_bytes = bytes.write(&chunk).unwrap();
            let new = std::cmp::min(downloaded + (num_bytes as u64), total_size);
            downloaded = new;
            pb.set_position(new);
        }
        pb.finish_with_message(msg);

        bytes
    }
}
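
For context, a minimal usage sketch of the new API (not part of the commit; it assumes burn-common is built with the `network` and `std` features, and the URL and message below are illustrative):

use burn_common::network::downloader;

fn main() {
    // The #[tokio::main] attribute above turns this into a blocking call that
    // spins up a single-threaded runtime and returns the whole file in memory.
    let bytes: Vec<u8> = downloader::download_file_as_bytes(
        "https://example.com/archive.tar.gz", // hypothetical URL
        "Downloading archive",
    );
    println!("downloaded {} bytes", bytes.len());
}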

@@ -57,9 +57,10 @@ doc = [
    "burn-wgpu/doc",
]
dataset = ["burn-dataset"]
network = ["burn-common/network"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
vision = ["burn-dataset?/vision"]
vision = ["burn-dataset?/vision", "burn-common/network"]

wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

@@ -7,3 +7,9 @@ pub mod dataloader;
pub mod dataset {
    pub use burn_dataset::*;
}

/// Network module.
#[cfg(feature = "network")]
pub mod network {
    pub use burn_common::network::*;
}

@@ -21,7 +21,7 @@ fake = ["dep:fake"]
sqlite = ["__sqlite-shared", "dep:rusqlite"]
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]

vision = ["dep:flate2", "dep:globwalk", "dep:image", "dep:indicatif", "dep:reqwest", "dep:tokio"]
vision = ["dep:flate2", "dep:globwalk", "dep:burn-common"]

# internal
__sqlite-shared = [

@@ -33,20 +33,21 @@ __sqlite-shared = [
]

[dependencies]
burn-common = { path = "../burn-common", version = "0.13.0", optional = true, features = [
    "network",
] }
csv = { workspace = true }
derive-new = { workspace = true }
dirs = { workspace = true }
fake = { workspace = true, optional = true }
flate2 = { workspace = true, optional = true }
gix-tempfile = { workspace = true, optional = true }
globwalk = { workspace = true, optional = true}
globwalk = { workspace = true, optional = true }
hound = { workspace = true, optional = true }
image = { workspace = true, optional = true }
indicatif = { workspace = true, optional = true }
r2d2 = { workspace = true, optional = true }
r2d2_sqlite = { workspace = true, optional = true }
rand = { workspace = true, features = ["std"] }
reqwest = { workspace = true, optional = true }
rmp-serde = { workspace = true }
rusqlite = { workspace = true, optional = true }
sanitize-filename = { workspace = true }

@@ -57,7 +58,6 @@ strum = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, optional = true }

[dev-dependencies]
rayon = { workspace = true }

@@ -68,4 +68,4 @@ fake = { workspace = true }
normal = ["strum", "strum_macros"]

[package.metadata.docs.rs]
features = ["doc"]
features = ["doc"]

@@ -1,44 +0,0 @@
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::cmp::min;
use std::io::Write;

/// Download the file at the specified url to a bytes vector.
/// File download progress is reported with the help of a [progress bar](indicatif).
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
    // Get file from web
    let mut response = Client::new().get(url).send().await.unwrap();
    let total_size = response.content_length().unwrap();

    // Pretty progress bar
    let pb = ProgressBar::new(total_size);
    let msg = message.to_owned();
    pb.set_style(
        ProgressStyle::with_template(
            "{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
        )
        .unwrap()
        .with_key(
            "eta",
            |state: &ProgressState, w: &mut dyn std::fmt::Write| {
                write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
            },
        )
        .progress_chars("▬ "),
    );
    pb.set_message(msg.clone());

    // Read stream into bytes
    let mut downloaded: u64 = 0;
    let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
    while let Some(chunk) = response.chunk().await.unwrap() {
        let num_bytes = bytes.write(&chunk).unwrap();
        let new = min(downloaded + (num_bytes as u64), total_size);
        downloaded = new;
        pb.set_position(new);
    }
    pb.finish_with_message(msg);

    bytes
}

@@ -10,6 +10,8 @@ use crate::{
    Dataset, InMemDataset,
};

use burn_common::network::downloader::download_file_as_bytes;

// CVDF mirror of http://yann.lecun.com/exdb/mnist/
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";

@@ -151,7 +153,7 @@ impl MNISTDataset {

        if !file_name.exists() {
            // Download gzip file
            let bytes = super::downloader::download_file_as_bytes(&format!("{URL}{name}.gz"), name);
            let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);

            // Create file to write the downloaded content to
            let mut output_file = File::create(&file_name).unwrap();
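
The hunk stops right after the file handle is created; a minimal sketch of the download-and-cache pattern this code follows (`cache_download` is a hypothetical helper, not part of the commit):

use std::fs::File;
use std::io::Write;
use std::path::Path;

// Persist the downloaded gzip archive next to the dataset so later runs can
// skip the network round trip.
fn cache_download(bytes: &[u8], file_name: &Path) {
    let mut output_file = File::create(file_name).unwrap();
    output_file.write_all(bytes).unwrap();
}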

@@ -1,4 +1,3 @@
mod downloader;
mod image_folder;
mod mnist;

@@ -58,6 +58,9 @@ wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]

# Network utils
network = ["burn-core/network"]

# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]

@@ -80,6 +80,7 @@
//! - `autodiff`: Makes available the Autodiff backend
//! - Others:
//! - `std`: Activates the standard library (deactivate for no_std)
//! - `network`: Enables network utilities (currently, only a file downloader with progress bar)
//! - `experimental-named-tensor`: Enables named tensors (experimental)

pub use burn_core::*;
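
Combined with the Cargo feature wiring above, enabling `network` on the `burn` crate exposes the downloader at `burn::data::network::downloader`. A hedged end-to-end sketch (the URL and message are illustrative, not from this commit):

use burn::data::network::downloader;

fn main() {
    // Requires burn to be compiled with the "network" feature.
    let bytes = downloader::download_file_as_bytes(
        "https://example.com/dataset.bin", // hypothetical URL
        "Fetching dataset",
    );
    assert!(!bytes.is_empty());
}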

@@ -13,7 +13,7 @@ tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = {path = "../../burn", features = ["train", "vision"]}
burn = { path = "../../burn", features = ["train", "vision", "network"] }

# File download
flate2 = { workspace = true }

@@ -2,7 +2,7 @@ use flate2::read::GzDecoder;
use std::path::{Path, PathBuf};
use tar::Archive;

use burn::data::dataset::vision::ImageFolderDataset;
use burn::data::{dataset::vision::ImageFolderDataset, network::downloader};

/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).
/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).

@@ -44,7 +44,7 @@ fn download() -> PathBuf {
    let labels_file = cifar_dir.join("labels.txt");
    if !labels_file.exists() {
        // Download gzip file
        let bytes = super::downloader::download_file_as_bytes(URL, "cifar10.tgz");
        let bytes = downloader::download_file_as_bytes(URL, "cifar10.tgz");

        // Decode gzip file content and unpack archive
        let gz_buffer = GzDecoder::new(&bytes[..]);
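
The hunk cuts off at the gzip decoder; a plausible sketch of the decode-and-unpack step that follows (`unpack_tgz` is a hypothetical standalone helper, not code from this commit):

use flate2::read::GzDecoder;
use std::path::Path;
use tar::Archive;

// Wrap the downloaded bytes in a gzip decoder, then extract the enclosed tar
// archive into the target directory.
fn unpack_tgz(bytes: &[u8], target_dir: &Path) {
    let gz_buffer = GzDecoder::new(bytes);
    let mut archive = Archive::new(gz_buffer);
    archive.unpack(target_dir).unwrap();
}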

@@ -1,46 +0,0 @@
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::cmp::min;
use std::io::Write;

/// Download the file at the specified url to a bytes vector.
/// File download progress is reported with the help of a [progress bar](indicatif).
///
/// Taken from [burn-dataset](https://github.com/tracel-ai/burn/blob/main/burn-dataset/src/vision/downloader.rs).
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
    // Get file from web
    let mut response = Client::new().get(url).send().await.unwrap();
    let total_size = response.content_length().unwrap();

    // Pretty progress bar
    let pb = ProgressBar::new(total_size);
    let msg = message.to_owned();
    pb.set_style(
        ProgressStyle::with_template(
            "{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
        )
        .unwrap()
        .with_key(
            "eta",
            |state: &ProgressState, w: &mut dyn std::fmt::Write| {
                write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
            },
        )
        .progress_chars("▬ "),
    );
    pb.set_message(msg.clone());

    // Read stream into bytes
    let mut downloaded: u64 = 0;
    let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
    while let Some(chunk) = response.chunk().await.unwrap() {
        let num_bytes = bytes.write(&chunk).unwrap();
        let new = min(downloaded + (num_bytes as u64), total_size);
        downloaded = new;
        pb.set_position(new);
    }
    pb.finish_with_message(msg);

    bytes
}

@@ -1,5 +1,4 @@
pub mod data;
pub mod dataset;
mod downloader;
pub mod model;
pub mod training;

@@ -151,14 +151,14 @@ impl<B: Backend> Model<B> {
        let output = self.model.forward(input);

        // Convert the model output into probability distribution using softmax formula
        let probabilies = softmax(output, 1);
        let probabilities = softmax(output, 1);

        #[cfg(not(target_family = "wasm"))]
        let result = probabilies.into_data().convert::<f32>().value;
        let result = probabilities.into_data().convert::<f32>().value;

        // Forces the result to be computed
        #[cfg(target_family = "wasm")]
        let result = probabilies.into_data().await.convert::<f32>().value;
        let result = probabilities.into_data().await.convert::<f32>().value;

        result
    }

@@ -173,18 +173,18 @@ pub struct InferenceResult {
}

/// Returns the top 5 classes and convert them into a JsValue
fn top_5_classes(probabilies: Vec<f32>) -> Result<JsValue, JsValue> {
fn top_5_classes(probabilities: Vec<f32>) -> Result<JsValue, JsValue> {
    // Convert the probabilities into a vector of (index, probability)
    let mut probabilies: Vec<_> = probabilies.iter().enumerate().collect();
    let mut probabilities: Vec<_> = probabilities.iter().enumerate().collect();

    // Sort the probabilities in descending order
    probabilies.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());

    // Take the top 5 probabilities
    probabilies.truncate(5);
    probabilities.truncate(5);

    // Convert the probabilities into InferenceResult
    let result: Vec<InferenceResult> = probabilies
    let result: Vec<InferenceResult> = probabilities
        .into_iter()
        .map(|(index, probability)| InferenceResult {
            index,