Add vision/mnist dataset (#1176)

This commit is contained in:
Guillaume Lagrange 2024-01-25 16:16:39 -05:00 committed by GitHub
parent 0368409eb3
commit b9bd42959b
33 changed files with 335 additions and 121 deletions

View File

@ -82,6 +82,7 @@ strum_macros = "0.25.3"
syn = { version = "2.0", features = ["full", "extra-traits"] }
tempfile = "3.8.1"
thiserror = "1.0.50"
tokio = { version = "1.35.1", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.32"
tracing-subscriber = "0.3.18"

View File

@ -11,7 +11,7 @@ used as input to our previously defined model.
```rust , ignore
use burn::{
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};
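// A rough sketch of the per-image work the batcher performs. This is not the
// guide's actual implementation (which builds backend tensors); the mean/std
// constants below are the commonly used MNIST values and are assumed here.
fn normalize(item: &MNISTItem) -> Vec<f32> {
    // Flatten the 28x28 `[[f32; 28]; 28]` image and scale each pixel.
    item.image
        .iter()
        .flatten()
        .map(|&pixel| ((pixel / 255.0) - 0.1307) / 0.3081)
        .collect()
}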

View File

@ -165,16 +165,17 @@ Tensor {
```
While the previous example is somewhat trivial, the upcoming
[basic workflow section](./basic-workflow/README.md) will walk you through a much more relevant
example for deep learning applications.
[basic workflow section](./basic-workflow/) will walk you through a much more relevant example for
deep learning applications.
## Running examples
Burn uses a [Python library by HuggingFace](https://huggingface.co/docs/datasets/index) to download
datasets. Therefore, in order to run examples, you will need to install Python. Follow the
instructions on the [official website](https://www.python.org/downloads/) to install Python on your
computer.
Many Burn examples are available in the
Many additional Burn examples are available in the
[examples](https://github.com/tracel-ai/burn/tree/main/examples) directory. To run one, please refer
to the example's README.md for the specific command to execute.
Note that some examples use the
[`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download the
datasets required in the examples. This is a Python library, which means that you will need to
install Python before running these examples. This requirement will be clearly indicated in the
example's README when applicable.
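For instance, the regression and text examples load their data through Burn's
`HuggingfaceDatasetLoader`, which invokes that Python library and caches the converted data in a
local SQLite database (with the `sqlite` or `sqlite-bundled` feature enabled). A rough sketch, with
a hypothetical item type standing in for the example's own:
```rust , ignore
use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, SqliteDataset};

// Hypothetical item type for illustration; each example defines its own.
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
struct NewsItem {
    text: String,
    label: usize,
}

fn load_train_split() -> SqliteDataset<NewsItem> {
    // Fails if Python and the `datasets` package are not installed.
    HuggingfaceDatasetLoader::new("ag_news")
        .dataset("train")
        .expect("could not download and convert the dataset")
}
```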

View File

@ -42,6 +42,7 @@ std = [
dataset = ["burn-dataset"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
vision = ["burn-dataset?/vision"]
wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

View File

@ -20,6 +20,8 @@ fake = ["dep:fake"]
sqlite = ["__sqlite-shared", "dep:rusqlite"]
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
vision = ["dep:flate2", "dep:indicatif", "dep:reqwest", "dep:tokio"]
# internal
__sqlite-shared = [
"dep:r2d2",
@ -34,12 +36,15 @@ csv = { workspace = true }
derive-new = { workspace = true }
dirs = { workspace = true }
fake = { workspace = true, optional = true }
flate2 = { workspace = true, optional = true }
gix-tempfile = { workspace = true, optional = true }
hound = { version = "3.5.1", optional = true }
image = { version = "0.24.7", features = ["png"], optional = true }
indicatif = { workspace = true, optional = true }
r2d2 = { workspace = true, optional = true }
r2d2_sqlite = { workspace = true, optional = true }
rand = { workspace = true, features = ["std"] }
reqwest = { workspace = true, optional = true }
rmp-serde = { workspace = true }
rusqlite = { workspace = true, optional = true }
sanitize-filename = { workspace = true }
@ -50,6 +55,7 @@ strum = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, optional = true }
[dev-dependencies]
rayon = { workspace = true }

View File

@ -19,6 +19,10 @@ pub mod transform;
#[cfg(feature = "audio")]
pub mod audio;
/// Vision datasets.
#[cfg(feature = "vision")]
pub mod vision;
mod dataset;
pub use dataset::*;
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]

View File

@ -1,92 +0,0 @@
use crate::source::huggingface::downloader::HuggingfaceDatasetLoader;
use crate::transform::{Mapper, MapperDataset};
use crate::{Dataset, SqliteDataset};
use image;
use serde::{Deserialize, Serialize};
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
/// Label of the image.
pub label: usize,
}
#[derive(Deserialize, Debug, Clone)]
struct MNISTItemRaw {
pub image_bytes: Vec<u8>,
pub label: usize,
}
struct BytesToImage;
impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
let image = image::load_from_memory(&item.image_bytes).unwrap();
let image = image.as_luma8().unwrap();
// Ensure the image dimensions are correct.
debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32));
// Convert the image to a 2D array of floats.
let mut image_array = [[0f32; WIDTH]; HEIGHT];
for (i, pixel) in image.as_raw().iter().enumerate() {
let x = i % WIDTH;
let y = i / HEIGHT;
image_array[y][x] = *pixel as f32;
}
MNISTItem {
image: image_array,
label: item.label,
}
}
}
type MappedDataset = MapperDataset<SqliteDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
/// MNIST dataset from Huggingface.
///
/// The data is downloaded from Huggingface and stored in a SQLite database.
pub struct MNISTDataset {
dataset: MappedDataset,
}
impl Dataset<MNISTItem> for MNISTDataset {
fn get(&self, index: usize) -> Option<MNISTItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl MNISTDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
}
/// Creates a new test dataset.
pub fn test() -> Self {
Self::new("test")
}
fn new(split: &str) -> Self {
let dataset = HuggingfaceDatasetLoader::new("mnist")
.dataset(split)
.unwrap();
let dataset = MapperDataset::new(dataset, BytesToImage);
Self { dataset }
}
}

View File

@ -1,5 +1,3 @@
pub(crate) mod downloader;
mod mnist;
pub use downloader::*;
pub use mnist::*;
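With the MNIST types moved out of the Hugging Face source module, downstream code only needs to
swap the import path (every example below makes this change). The item's `label` field is also now
a `u8` rather than a `usize`, which is why the ONNX inference examples adjust their casts; a
minimal sketch:

```rust
use burn::data::dataset::vision::MNISTItem;

// Old path (removed): burn::data::dataset::source::huggingface::MNISTItem
// `label` is now `u8`, so a `usize` prediction is narrowed before comparing,
// as the ONNX inference examples below do.
fn is_correct(predicted_class: usize, item: &MNISTItem) -> bool {
    predicted_class as u8 == item.label
}
```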

View File

@ -0,0 +1,44 @@
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use reqwest::Client;
use std::cmp::min;
use std::io::Write;
/// Download the file at the specified url to a bytes vector.
/// File download progress is reported with the help of a [progress bar](indicatif).
#[tokio::main(flavor = "current_thread")]
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
// Get file from web
let mut response = Client::new().get(url).send().await.unwrap();
let total_size = response.content_length().unwrap();
// Pretty progress bar
let pb = ProgressBar::new(total_size);
let msg = message.to_owned();
pb.set_style(
ProgressStyle::with_template(
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.with_key(
"eta",
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
},
)
.progress_chars(""),
);
pb.set_message(msg.clone());
// Read stream into bytes
let mut downloaded: u64 = 0;
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
while let Some(chunk) = response.chunk().await.unwrap() {
let num_bytes = bytes.write(&chunk).unwrap();
let new = min(downloaded + (num_bytes as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(msg);
bytes
}
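Although the body is `async`, the `#[tokio::main(flavor = "current_thread")]` attribute wraps the
function in its own single-threaded runtime, so crate-internal call sites use it as an ordinary
blocking function. A minimal sketch of such a call site (the MNIST loader below does exactly this;
the URL shown is the mirror it uses):

```rust
fn fetch_test_labels_archive() -> Vec<u8> {
    // Blocks until the download finishes and returns the gzip-compressed bytes.
    super::downloader::download_file_as_bytes(
        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        "t10k-labels-idx1-ubyte",
    )
}
```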

View File

@ -0,0 +1,220 @@
use std::fs::{create_dir_all, File};
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use flate2::read::GzDecoder;
use serde::{Deserialize, Serialize};
use crate::{
transform::{Mapper, MapperDataset},
Dataset, InMemDataset,
};
// CVDF mirror of http://yann.lecun.com/exdb/mnist/
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";
const TRAIN_LABELS: &str = "train-labels-idx1-ubyte";
const TEST_IMAGES: &str = "t10k-images-idx3-ubyte";
const TEST_LABELS: &str = "t10k-labels-idx1-ubyte";
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
/// Label of the image.
pub label: u8,
}
#[derive(Deserialize, Debug, Clone)]
struct MNISTItemRaw {
pub image_bytes: Vec<u8>,
pub label: u8,
}
struct BytesToImage;
impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
// Ensure the image dimensions are correct.
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
// Convert the image to a 2D array of floats.
let mut image_array = [[0f32; WIDTH]; HEIGHT];
for (i, pixel) in item.image_bytes.iter().enumerate() {
let x = i % WIDTH;
let y = i / HEIGHT;
image_array[y][x] = *pixel as f32;
}
MNISTItem {
image: image_array,
label: item.label,
}
}
}
type MappedDataset = MapperDataset<InMemDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digit), with 7,000
/// images per class. There are 60,000 training images and 10,000 test images.
///
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MNISTDataset {
dataset: MappedDataset,
}
impl Dataset<MNISTItem> for MNISTDataset {
fn get(&self, index: usize) -> Option<MNISTItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl MNISTDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
}
/// Creates a new test dataset.
pub fn test() -> Self {
Self::new("test")
}
fn new(split: &str) -> Self {
// Download dataset
let root = MNISTDataset::download(split);
// MNIST is tiny so we can load it in-memory
// Train images (u8): 28 * 28 * 60000 = 47.04 MB
// Test images (u8): 28 * 28 * 10000 = 7.84 MB
let images = MNISTDataset::read_images(&root, split);
let labels = MNISTDataset::read_labels(&root, split);
// Collect as vector of MNISTItemRaw
let items: Vec<_> = images
.into_iter()
.zip(labels)
.map(|(image_bytes, label)| MNISTItemRaw { image_bytes, label })
.collect();
let dataset = InMemDataset::new(items);
let dataset = MapperDataset::new(dataset, BytesToImage);
Self { dataset }
}
/// Download the MNIST dataset files from the web.
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
fn download(split: &str) -> PathBuf {
// Dataset files are stored in the burn-dataset cache directory
let cache_dir = dirs::home_dir()
.expect("Could not get home directory")
.join(".cache")
.join("burn-dataset");
let split_dir = cache_dir.join("mnist").join(split);
if !split_dir.exists() {
create_dir_all(&split_dir).expect("Failed to create base directory");
}
// Download split files
match split {
"train" => {
MNISTDataset::download_file(TRAIN_IMAGES, &split_dir);
MNISTDataset::download_file(TRAIN_LABELS, &split_dir);
}
"test" => {
MNISTDataset::download_file(TEST_IMAGES, &split_dir);
MNISTDataset::download_file(TEST_LABELS, &split_dir);
}
_ => panic!("Invalid split specified {}", split),
};
split_dir
}
/// Download a file from the MNIST dataset URL to the destination directory.
/// File download progress is reported with the help of a [progress bar](indicatif).
fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
// Output file name
let file_name = dest_dir.as_ref().join(name);
if !file_name.exists() {
// Download gzip file
let bytes = super::downloader::download_file_as_bytes(&format!("{URL}{name}.gz"), name);
// Create file to write the downloaded content to
let mut output_file = File::create(&file_name).unwrap();
// Decode gzip file content and write to disk
let mut gz_buffer = GzDecoder::new(&bytes[..]);
std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
}
file_name
}
/// Read images at the provided path for the specified split.
/// Each image is a vector of bytes.
fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
let file_name = if split == "train" {
TRAIN_IMAGES
} else {
TEST_IMAGES
};
let file_name = root.as_ref().join(file_name);
// Read number of images from 16-byte header metadata
let mut f = File::open(file_name).unwrap();
let mut buf = [0u8; 4];
let _ = f.seek(SeekFrom::Start(4)).unwrap();
f.read_exact(&mut buf)
.expect("Should be able to read image file header");
let size = u32::from_be_bytes(buf);
let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
let _ = f.seek(SeekFrom::Start(16)).unwrap();
f.read_exact(&mut buf_images)
.expect("Should be able to read images from file");
buf_images
.chunks(WIDTH * HEIGHT)
.map(|chunk| chunk.to_vec())
.collect()
}
/// Read labels at the provided path for the specified split.
fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
let file_name = if split == "train" {
TRAIN_LABELS
} else {
TEST_LABELS
};
let file_name = root.as_ref().join(file_name);
// Read number of labels from 8-byte header metadata
let mut f = File::open(file_name).unwrap();
let mut buf = [0u8; 4];
let _ = f.seek(SeekFrom::Start(4)).unwrap();
f.read_exact(&mut buf)
.expect("Should be able to read label file header");
let size = u32::from_be_bytes(buf);
let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
let _ = f.seek(SeekFrom::Start(8)).unwrap();
f.read_exact(&mut buf_labels)
.expect("Should be able to read labels from file");
buf_labels
}
}
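For reference, a minimal sketch of consuming the new dataset from a crate that depends on `burn`
with the `vision` feature enabled; the first call downloads the IDX files into
`~/.cache/burn-dataset/mnist/<split>` and then serves items from memory:

```rust
use burn::data::dataset::{vision::MNISTDataset, Dataset};

fn main() {
    let train = MNISTDataset::train();
    println!("training items: {}", train.len()); // 60,000

    let item = train.get(0).expect("split should not be empty");
    // `image` is a [[f32; 28]; 28] array with raw pixel values in 0.0..=255.0,
    // `label` is the digit class as a u8.
    println!("label: {}, top-left pixel: {}", item.label, item.image[0][0]);
}
```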

View File

@ -0,0 +1,4 @@
mod downloader;
mod mnist;
pub use mnist::*;

View File

@ -33,6 +33,8 @@ dataset = ["burn-core/dataset"]
sqlite = ["burn-core/sqlite"]
sqlite-bundled = ["burn-core/sqlite-bundled"]
vision = ["burn-core/vision"]
# Backends
autodiff = ["burn-core/autodiff"]
fusion = ["burn-core/fusion"]

View File

@ -61,6 +61,7 @@
//! - `audio`: Enables audio datasets (SpeechCommandsDataset)
//! - `sqlite`: Stores datasets in SQLite database
//! - `sqlite-bundled`: Use bundled version of SQLite
//! - `vision`: Enables vision datasets (MNISTDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `candle`: Makes available the Candle backend

View File

@ -8,7 +8,7 @@ publish = false
version.workspace = true
[dependencies]
burn = {path = "../../burn", features=["autodiff", "wgpu", "train", "dataset"], default-features=false}
burn = {path = "../../burn", features=["autodiff", "wgpu", "train", "dataset", "vision"], default-features=false}
guide = {path = "../guide", default-features=false}
# Serialization

View File

@ -1,4 +1,4 @@
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::vision::MNISTDataset;
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{

View File

@ -7,7 +7,7 @@ publish = false
version.workspace = true
[dependencies]
burn = {path = "../../burn", features=["autodiff", "wgpu"]}
burn = {path = "../../burn", features=["autodiff", "wgpu", "vision"]}
guide = {path = "../guide"}
# Serialization

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::vision::MNISTDataset;
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,

View File

@ -10,7 +10,7 @@ version.workspace = true
default = ["burn/default"]
[dependencies]
burn = {path = "../../burn", features = ["wgpu", "train"]}
burn = {path = "../../burn", features = ["wgpu", "train", "vision"]}
# Serialization
log = {workspace = true}

examples/guide/README.md (new file)
View File

@ -0,0 +1,9 @@
# Basic Workflow: From Training to Inference
This example corresponds to the [book's guide](https://burn.dev/book/basic-workflow/).
## Example Usage
```sh
cargo run --example guide
```

View File

@ -18,7 +18,7 @@ fn main() {
guide::inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::source::huggingface::MNISTDataset::test()
burn::data::dataset::vision::MNISTDataset::test()
.get(42)
.unwrap(),
);

View File

@ -1,5 +1,5 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};

View File

@ -1,5 +1,5 @@
use crate::{data::MNISTBatcher, training::TrainingConfig};
use burn::data::dataset::source::huggingface::MNISTItem;
use burn::data::dataset::vision::MNISTItem;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,

View File

@ -2,7 +2,7 @@ use crate::{
data::{MNISTBatch, MNISTBatcher},
model::{Model, ModelConfig},
};
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::vision::MNISTDataset;
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,

View File

@ -7,7 +7,7 @@ publish = false
version.workspace = true
[features]
default = ["burn/dataset", "burn/sqlite-bundled"]
default = ["burn/dataset", "burn/vision"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]

View File

@ -1,5 +1,5 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};

View File

@ -10,7 +10,7 @@ use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},

View File

@ -12,7 +12,7 @@ default = ["embedded-model"]
embedded-model = []
[dependencies]
burn = { path = "../../burn", features = ["ndarray", "dataset", "sqlite-bundled"] }
burn = { path = "../../burn", features = ["ndarray", "dataset", "vision"] }
serde = { workspace = true }
[build-dependencies]

View File

@ -3,7 +3,7 @@ use std::env::args;
use burn::backend::ndarray::NdArray;
use burn::tensor::Tensor;
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::vision::MNISTDataset;
use burn::data::dataset::Dataset;
use onnx_inference::mnist::Model;
@ -49,7 +49,7 @@ fn main() {
let output = model.forward(input);
// Get the index of the maximum value
let arg_max = output.argmax(1).into_scalar() as usize;
let arg_max = output.argmax(1).into_scalar() as u8;
// Check if the index matches the label
assert!(arg_max == item.label);

View File

@ -10,7 +10,7 @@ version = "0.12.0"
burn = { path = "../../burn", features = [
"ndarray",
"dataset",
"sqlite-bundled",
"vision",
] }
model = { path = "./model" }

View File

@ -5,7 +5,7 @@ use burn::backend::ndarray::NdArray;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn::tensor::Tensor;
use burn::data::dataset::source::huggingface::MNISTDataset;
use burn::data::dataset::vision::MNISTDataset;
use burn::data::dataset::Dataset;
use model::Model;
@ -57,7 +57,7 @@ fn main() {
let output = model.forward(input);
// Get the index of the maximum value
let arg_max = output.argmax(1).into_scalar() as usize;
let arg_max = output.argmax(1).into_scalar() as u8;
// Check if the index matches the label
assert!(arg_max == item.label);

View File

@ -7,6 +7,11 @@ from HuggingFace hub. The dataset is also available as part of toy regression da
- Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature scaling.
- Define a Simple NN model for regression using Burn Modules.
> **Note**
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
> installed on your computer.
The example can be run like so:
```bash

View File

@ -3,6 +3,11 @@
This project provides an example implementation for training and inferencing text classification
models on AG News and DbPedia datasets using the Rust-based Burn Deep Learning Library.
> **Note**
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
> installed on your computer.
## Dataset Details
- AG News: The AG News dataset is a collection of news articles from more than 2000 news sources.

View File

@ -1,5 +1,10 @@
# Text Generation
> **Note**
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
> installed on your computer.
The example can be run like so:
## CUDA users