Add a dataset/dataloader/batcher usage section (#2161)

* Add a dataset/dataloader/batcher usage section

* Fix typos
This commit is contained in:
Guillaume Lagrange 2024-08-22 11:52:56 -04:00 committed by GitHub
parent 75a2850047
commit 73d4b11aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 271 additions and 3 deletions

View File

@ -11,3 +11,6 @@ language = "en"
multilingual = false
src = "src"
title = "The Burn Book 🔥"
[output.html]
mathjax-support = true

View File

@ -1,8 +1,13 @@
# Dataset
In most deep learning training performed on datasets (with perhaps the exception of reinforcement
learning), it is essential to provide a convenient and performant API. The dataset trait is quite
similar to the dataset abstract class in PyTorch:
At its core, a dataset is a collection of data typically related to a specific analysis or
processing task. The data modality can vary depending on the task, but most datasets primarily
consist of images, texts, audio or videos.
This data source represents an integral part of machine learning to successfully train a model. Thus,
it is essential to provide a convenient and performant API to handle your data. Since this process
varies wildly from one problem to another, it is defined as a trait that should be implemented on
your type. The dataset trait is quite similar to the dataset abstract class in PyTorch:
```rust, ignore
pub trait Dataset<I>: Send + Sync {
@ -133,3 +138,263 @@ multiple times over the dataset and only checkpoint when done. You can consider
dataset as the number of iterations before performing checkpointing and running the validation.
There is nothing stopping you from returning different items even when called with the same `index`
multiple times.
## How Is The Dataset Used?
During training, the dataset is used to access the data samples and, for most use cases in
supervised learning, their corresponding ground-truth labels. Remember that the `Dataset` trait
implementation is responsible to retrieve the data from its source, usually some sort of data
storage. At this point, the dataset could be naively iterated over to provide the model a single
sample to process at a time, but this is not very efficient.
Instead, we collect multiple samples that the model can process as a _batch_ to fully leverage
modern hardware (e.g., GPUs - which have impressing parallel processing capabilities). Since each
data sample in the dataset can be collected independently, the data loading is typically done in
parallel to further speed things up. In this case, we parallelize the data loading using a
multi-threaded `BatchDataLoader` to obtain a sequence of items from the `Dataset` implementation.
Finally, the sequence of items is combined into a batched tensor that can be used as input to a
model with the `Batcher` trait implementation. Other tensor operations can be performed during this
step to prepare the batch data, as is done [in the basic workflow guide](../basic-workflow/data.md).
The process is illustrated in the figure below for the MNIST dataset.
<img title="Burn Data Loading Pipeline" alt="Burn Data Loading Pipeline" src="./dataset.png">
Although we have conveniently implemented the
[`MnistDataset`](https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/vision/mnist.rs)
used in the guide, we'll go over its implementation to demonstrate how the `Dataset` and `Batcher`
traits are used.
The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits has a training set of
60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a
\\(28 \times 28\\) pixels black-and-white image (stored as raw bytes) with its corresponding label
(a digit between \\(0\\) and \\(9\\)). This is defined by the `MnistItemRaw` struct.
```rust, ignore
# #[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
pub image_bytes: Vec<u8>,
pub label: u8,
}
```
With single-channel images of such low resolution, the entire training and test sets can be loaded
in memory at once. Therefore, we leverage the already existing `InMemDataset` to retrieve the raw
images and labels data. At this point, the image data is still just a bunch of bytes, but we want to
retrieve the _structured_ image data in its intended form. For that, we can define a `MapperDataset`
that transforms the raw image bytes to a 2D array image (which we convert to float while we're at
it).
```rust, ignore
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
# /// MNIST item.
# #[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MnistItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
/// Label of the image.
pub label: u8,
}
struct BytesToImage;
impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
fn map(&self, item: &MnistItemRaw) -> MnistItem {
// Ensure the image dimensions are correct.
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
// Convert the image to a 2D array of floats.
let mut image_array = [[0f32; WIDTH]; HEIGHT];
for (i, pixel) in item.image_bytes.iter().enumerate() {
let x = i % WIDTH;
let y = i / HEIGHT;
image_array[y][x] = *pixel as f32;
}
MnistItem {
image: image_array,
label: item.label,
}
}
}
type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;
# /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000
# /// images per class. There are 60,000 training images and 10,000 test images.
# ///
# /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MnistDataset {
dataset: MappedDataset,
}
```
To construct the `MnistDataset`, the data source must be parsed into the expected `MappedDataset`
type. Since both the train and test sets use the same file format, we can separate the functionality
to load the `train()` and `test()` dataset.
```rust, ignore
impl MnistDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
}
/// Creates a new test dataset.
pub fn test() -> Self {
Self::new("test")
}
fn new(split: &str) -> Self {
// Download dataset
let root = MnistDataset::download(split);
// Parse data as vector of images bytes and vector of labels
let images: Vec<Vec<u8>> = MnistDataset::read_images(&root, split);
let labels: Vec<u8> = MnistDataset::read_labels(&root, split);
// Collect as vector of MnistItemRaw
let items: Vec<_> = images
.into_iter()
.zip(labels)
.map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
.collect();
// Create the MapperDataset for InMemDataset<MnistItemRaw> to transform
// items (MnistItemRaw -> MnistItem)
let dataset = InMemDataset::new(items);
let dataset = MapperDataset::new(dataset, BytesToImage);
Self { dataset }
}
# /// Download the MNIST dataset files from the web.
# /// Panics if the download cannot be completed or the content of the file cannot be written to disk.
# fn download(split: &str) -> PathBuf {
# // Dataset files are stored un the burn-dataset cache directory
# let cache_dir = dirs::home_dir()
# .expect("Could not get home directory")
# .join(".cache")
# .join("burn-dataset");
# let split_dir = cache_dir.join("mnist").join(split);
#
# if !split_dir.exists() {
# create_dir_all(&split_dir).expect("Failed to create base directory");
# }
#
# // Download split files
# match split {
# "train" => {
# MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
# MnistDataset::download_file(TRAIN_LABELS, &split_dir);
# }
# "test" => {
# MnistDataset::download_file(TEST_IMAGES, &split_dir);
# MnistDataset::download_file(TEST_LABELS, &split_dir);
# }
# _ => panic!("Invalid split specified {}", split),
# };
#
# split_dir
# }
#
# /// Download a file from the MNIST dataset URL to the destination directory.
# /// File download progress is reported with the help of a [progress bar](indicatif).
# fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
# // Output file name
# let file_name = dest_dir.as_ref().join(name);
#
# if !file_name.exists() {
# // Download gzip file
# let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
#
# // Create file to write the downloaded content to
# let mut output_file = File::create(&file_name).unwrap();
#
# // Decode gzip file content and write to disk
# let mut gz_buffer = GzDecoder::new(&bytes[..]);
# std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
# }
#
# file_name
# }
#
# /// Read images at the provided path for the specified split.
# /// Each image is a vector of bytes.
# fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
# let file_name = if split == "train" {
# TRAIN_IMAGES
# } else {
# TEST_IMAGES
# };
# let file_name = root.as_ref().join(file_name);
#
# // Read number of images from 16-byte header metadata
# let mut f = File::open(file_name).unwrap();
# let mut buf = [0u8; 4];
# let _ = f.seek(SeekFrom::Start(4)).unwrap();
# f.read_exact(&mut buf)
# .expect("Should be able to read image file header");
# let size = u32::from_be_bytes(buf);
#
# let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
# let _ = f.seek(SeekFrom::Start(16)).unwrap();
# f.read_exact(&mut buf_images)
# .expect("Should be able to read image file header");
#
# buf_images
# .chunks(WIDTH * HEIGHT)
# .map(|chunk| chunk.to_vec())
# .collect()
# }
#
# /// Read labels at the provided path for the specified split.
# fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
# let file_name = if split == "train" {
# TRAIN_LABELS
# } else {
# TEST_LABELS
# };
# let file_name = root.as_ref().join(file_name);
#
# // Read number of labels from 8-byte header metadata
# let mut f = File::open(file_name).unwrap();
# let mut buf = [0u8; 4];
# let _ = f.seek(SeekFrom::Start(4)).unwrap();
# f.read_exact(&mut buf)
# .expect("Should be able to read label file header");
# let size = u32::from_be_bytes(buf);
#
# let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
# let _ = f.seek(SeekFrom::Start(8)).unwrap();
# f.read_exact(&mut buf_labels)
# .expect("Should be able to read labels from file");
#
# buf_labels
# }
}
```
Since the `MnistDataset` simply wraps a `MapperDataset` instance with `InMemDataset`, we can easily
implement the `Dataset` trait.
```rust, ignore
impl Dataset<MnistItem> for MnistDataset {
fn get(&self, index: usize) -> Option<MnistItem> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}
```
The only thing missing now is the `Batcher`, which we already went over
[in the basic workflow guide](../basic-workflow/data.md). The `Batcher` takes a list of `MnistItem`
retrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their
targets.

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB