mirror of https://github.com/tracel-ai/burn.git
Add a dataset/dataloader/batcher usage section (#2161)
* Add a dataset/dataloader/batcher usage section
* Fix typos
This commit is contained in: parent 75a2850047, commit 73d4b11aa2
@@ -11,3 +11,6 @@ language = "en"
 multilingual = false
 src = "src"
 title = "The Burn Book 🔥"
+
+[output.html]
+mathjax-support = true
@@ -1,8 +1,13 @@
# Dataset
At its core, a dataset is a collection of data typically related to a specific analysis or
processing task. The data modality can vary depending on the task, but most datasets primarily
consist of images, texts, audio or videos.

This data source represents an integral part of machine learning to successfully train a model.
Thus, it is essential to provide a convenient and performant API to handle your data. Since this
process varies wildly from one problem to another, it is defined as a trait that should be
implemented on your type. The dataset trait is quite similar to the dataset abstract class in
PyTorch:

```rust, ignore
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}
```
@@ -133,3 +138,263 @@
multiple times over the dataset and only checkpoint when done. You can consider
dataset as the number of iterations before performing checkpointing and running the validation.
There is nothing stopping you from returning different items even when called with the same `index`
multiple times.
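This last point is worth illustrating: `get` is free to be non-deterministic. Below is a minimal sketch, with the two-method `Dataset` trait re-declared locally so it compiles without the burn crate, and `CountingDataset` being a made-up type that returns a different item each time the same index is queried (as one might do for on-the-fly sampling or augmentation):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Local stand-in for burn's two-method `Dataset` trait, so this sketch
// compiles without the burn crate.
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}

// A hypothetical dataset that yields a different item every time the same
// `index` is queried.
pub struct CountingDataset {
    pub calls: AtomicUsize,
}

impl Dataset<usize> for CountingDataset {
    fn get(&self, index: usize) -> Option<usize> {
        // Each call advances a counter, so repeated calls with the same
        // `index` return different items.
        let epoch = self.calls.fetch_add(1, Ordering::Relaxed);
        Some(index + epoch * 1_000)
    }

    fn len(&self) -> usize {
        1_000
    }
}
```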
## How Is The Dataset Used?

During training, the dataset is used to access the data samples and, for most use cases in
supervised learning, their corresponding ground-truth labels. Remember that the `Dataset` trait
implementation is responsible for retrieving the data from its source, usually some sort of data
storage. At this point, the dataset could be naively iterated over to provide the model a single
sample to process at a time, but this is not very efficient.

Instead, we collect multiple samples that the model can process as a _batch_ to fully leverage
modern hardware (e.g., GPUs, which have impressive parallel processing capabilities). Since each
data sample in the dataset can be collected independently, the data loading is typically done in
parallel to further speed things up. In this case, we parallelize the data loading using a
multi-threaded `BatchDataLoader` to obtain a sequence of items from the `Dataset` implementation.
Finally, the sequence of items is combined into a batched tensor that can be used as input to a
model with the `Batcher` trait implementation. Other tensor operations can be performed during this
step to prepare the batch data, as is done [in the basic workflow guide](../basic-workflow/data.md).
The process is illustrated in the figure below for the MNIST dataset.

<img title="Burn Data Loading Pipeline" alt="Burn Data Loading Pipeline" src="./dataset.png">
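The fetch-in-parallel-then-combine idea can be sketched with plain standard-library threads. This is not burn's actual `BatchDataLoader` API, just a toy `load_batch` function (a name chosen for illustration) in which worker threads each fetch a disjoint slice of items before the per-worker sequences are combined into one batch buffer:

```rust
use std::thread;

// Sketch of parallel data loading: split the items across workers, fetch
// each slice on its own thread, then combine the results in order.
pub fn load_batch(data: &[f32], num_workers: usize) -> Vec<f32> {
    // Size of the slice handled by each worker (guard against zero).
    let chunk = data.len().div_ceil(num_workers.max(1)).max(1);
    thread::scope(|s| {
        // Scoped threads may borrow `data` directly.
        let handles: Vec<_> = data
            .chunks(chunk)
            .map(|part| s.spawn(move || part.to_vec()))
            .collect();
        // Combine the per-worker sequences into one batched buffer,
        // preserving the original item order.
        handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect()
    })
}
```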

Although we have conveniently implemented the
[`MnistDataset`](https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/vision/mnist.rs)
used in the guide, we'll go over its implementation to demonstrate how the `Dataset` and `Batcher`
traits are used.

The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits has a training set of
60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a
\\(28 \times 28\\) pixels black-and-white image (stored as raw bytes) with its corresponding label
(a digit between \\(0\\) and \\(9\\)). This is defined by the `MnistItemRaw` struct.

```rust, ignore
# #[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    pub image_bytes: Vec<u8>,
    pub label: u8,
}
```

With single-channel images of such low resolution, the entire training and test sets can be loaded
in memory at once. Therefore, we leverage the already existing `InMemDataset` to retrieve the raw
images and labels data. At this point, the image data is still just a bunch of bytes, but we want to
retrieve the _structured_ image data in its intended form. For that, we can define a `MapperDataset`
that transforms the raw image bytes to a 2D array image (which we convert to float while we're at
it).
```rust, ignore
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

# /// MNIST item.
# #[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MnistItem {
    /// Image as a 2D array of floats.
    pub image: [[f32; WIDTH]; HEIGHT],

    /// Label of the image.
    pub label: u8,
}

struct BytesToImage;

impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
    fn map(&self, item: &MnistItemRaw) -> MnistItem {
        // Ensure the image dimensions are correct.
        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);

        // Convert the image to a 2D array of floats.
        let mut image_array = [[0f32; WIDTH]; HEIGHT];
        for (i, pixel) in item.image_bytes.iter().enumerate() {
            let x = i % WIDTH;
            let y = i / WIDTH;
            image_array[y][x] = *pixel as f32;
        }

        MnistItem {
            image: image_array,
            label: item.label,
        }
    }
}

type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;

# /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digit), with 7,000
# /// images per class. There are 60,000 training images and 10,000 test images.
# ///
# /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MnistDataset {
    dataset: MappedDataset,
}
```
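The core of the mapper can be checked in isolation. Here is a standalone sketch of the bytes-to-image conversion as a free function (`bytes_to_image` is a name chosen for illustration; burn wraps this logic in the `Mapper` implementation above):

```rust
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

// Convert one image's worth of row-major raw bytes into a 2D array of
// floats, mirroring the `BytesToImage` mapping.
pub fn bytes_to_image(bytes: &[u8]) -> [[f32; WIDTH]; HEIGHT] {
    assert_eq!(bytes.len(), WIDTH * HEIGHT, "expected one full image");
    let mut image = [[0f32; WIDTH]; HEIGHT];
    for (i, pixel) in bytes.iter().enumerate() {
        // Row index is `i / WIDTH`, column index is `i % WIDTH`.
        image[i / WIDTH][i % WIDTH] = *pixel as f32;
    }
    image
}
```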

To construct the `MnistDataset`, the data source must be parsed into the expected `MappedDataset`
type. Since both the train and test sets use the same file format, we can separate the functionality
to load the `train()` and `test()` datasets.

```rust, ignore
impl MnistDataset {
    /// Creates a new train dataset.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates a new test dataset.
    pub fn test() -> Self {
        Self::new("test")
    }

    fn new(split: &str) -> Self {
        // Download dataset
        let root = MnistDataset::download(split);

        // Parse data as vector of images bytes and vector of labels
        let images: Vec<Vec<u8>> = MnistDataset::read_images(&root, split);
        let labels: Vec<u8> = MnistDataset::read_labels(&root, split);

        // Collect as vector of MnistItemRaw
        let items: Vec<_> = images
            .into_iter()
            .zip(labels)
            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
            .collect();

        // Create the MapperDataset for InMemDataset<MnistItemRaw> to transform
        // items (MnistItemRaw -> MnistItem)
        let dataset = InMemDataset::new(items);
        let dataset = MapperDataset::new(dataset, BytesToImage);

        Self { dataset }
    }

#     /// Download the MNIST dataset files from the web.
#     /// Panics if the download cannot be completed or the content of the file cannot be written to disk.
#     fn download(split: &str) -> PathBuf {
#         // Dataset files are stored in the burn-dataset cache directory
#         let cache_dir = dirs::home_dir()
#             .expect("Could not get home directory")
#             .join(".cache")
#             .join("burn-dataset");
#         let split_dir = cache_dir.join("mnist").join(split);
#
#         if !split_dir.exists() {
#             create_dir_all(&split_dir).expect("Failed to create base directory");
#         }
#
#         // Download split files
#         match split {
#             "train" => {
#                 MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
#                 MnistDataset::download_file(TRAIN_LABELS, &split_dir);
#             }
#             "test" => {
#                 MnistDataset::download_file(TEST_IMAGES, &split_dir);
#                 MnistDataset::download_file(TEST_LABELS, &split_dir);
#             }
#             _ => panic!("Invalid split specified {}", split),
#         };
#
#         split_dir
#     }
#
#     /// Download a file from the MNIST dataset URL to the destination directory.
#     /// File download progress is reported with the help of a [progress bar](indicatif).
#     fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
#         // Output file name
#         let file_name = dest_dir.as_ref().join(name);
#
#         if !file_name.exists() {
#             // Download gzip file
#             let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
#
#             // Create file to write the downloaded content to
#             let mut output_file = File::create(&file_name).unwrap();
#
#             // Decode gzip file content and write to disk
#             let mut gz_buffer = GzDecoder::new(&bytes[..]);
#             std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
#         }
#
#         file_name
#     }
#
#     /// Read images at the provided path for the specified split.
#     /// Each image is a vector of bytes.
#     fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
#         let file_name = if split == "train" {
#             TRAIN_IMAGES
#         } else {
#             TEST_IMAGES
#         };
#         let file_name = root.as_ref().join(file_name);
#
#         // Read number of images from 16-byte header metadata
#         let mut f = File::open(file_name).unwrap();
#         let mut buf = [0u8; 4];
#         let _ = f.seek(SeekFrom::Start(4)).unwrap();
#         f.read_exact(&mut buf)
#             .expect("Should be able to read image file header");
#         let size = u32::from_be_bytes(buf);
#
#         let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
#         let _ = f.seek(SeekFrom::Start(16)).unwrap();
#         f.read_exact(&mut buf_images)
#             .expect("Should be able to read images from file");
#
#         buf_images
#             .chunks(WIDTH * HEIGHT)
#             .map(|chunk| chunk.to_vec())
#             .collect()
#     }
#
#     /// Read labels at the provided path for the specified split.
#     fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
#         let file_name = if split == "train" {
#             TRAIN_LABELS
#         } else {
#             TEST_LABELS
#         };
#         let file_name = root.as_ref().join(file_name);
#
#         // Read number of labels from 8-byte header metadata
#         let mut f = File::open(file_name).unwrap();
#         let mut buf = [0u8; 4];
#         let _ = f.seek(SeekFrom::Start(4)).unwrap();
#         f.read_exact(&mut buf)
#             .expect("Should be able to read label file header");
#         let size = u32::from_be_bytes(buf);
#
#         let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
#         let _ = f.seek(SeekFrom::Start(8)).unwrap();
#         f.read_exact(&mut buf_labels)
#             .expect("Should be able to read labels from file");
#
#         buf_labels
#     }
}
```
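The header parsing used by `read_images` and `read_labels` relies on the MNIST files being stored in the IDX format: a big-endian header followed by the raw payload. As a small sketch, the item count sits in bytes 4..8 of the header and can be decoded with `u32::from_be_bytes` (the `read_item_count` helper below is hypothetical, introduced only to show the decoding):

```rust
// Decode the big-endian item count from bytes 4..8 of an IDX header,
// mirroring the `u32::from_be_bytes` calls in the reader functions above.
pub fn read_item_count(header: &[u8]) -> u32 {
    let mut buf = [0u8; 4];
    buf.copy_from_slice(&header[4..8]);
    u32::from_be_bytes(buf)
}
```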

Since the `MnistDataset` simply wraps a `MapperDataset` instance with `InMemDataset`, we can easily
implement the `Dataset` trait.

```rust, ignore
impl Dataset<MnistItem> for MnistDataset {
    fn get(&self, index: usize) -> Option<MnistItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}
```

The only thing missing now is the `Batcher`, which we already went over
[in the basic workflow guide](../basic-workflow/data.md). The `Batcher` takes a list of `MnistItem`
retrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their
targets.
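Stripped of burn's tensor types, the batching step amounts to copying every item's 2D image into one contiguous buffer of logical shape `[batch_size, HEIGHT, WIDTH]` while gathering the labels into a targets vector. A tensor-free sketch (the `batch` function is hypothetical; burn's `Batcher` produces real tensors instead):

```rust
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

// Combine a list of (image, label) items into one contiguous image buffer
// of logical shape [batch_size, HEIGHT, WIDTH] plus a targets vector.
pub fn batch(items: &[([[f32; WIDTH]; HEIGHT], u8)]) -> (Vec<f32>, Vec<u8>) {
    let mut images = Vec::with_capacity(items.len() * WIDTH * HEIGHT);
    let mut targets = Vec::with_capacity(items.len());
    for (image, label) in items {
        // Append the image rows in order, keeping the batch contiguous.
        for row in image {
            images.extend_from_slice(row);
        }
        targets.push(*label);
    }
    (images, targets)
}
```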
dataset.png: binary image file added (99 KiB).