mirror of https://github.com/tracel-ai/burn.git
Add a dataset/dataloader/batcher usage section (#2161)
* Add a dataset/dataloader/batcher usage section
* Fix typos
This commit is contained in: parent 75a2850047, commit 73d4b11aa2
The commit enables MathJax support in the book configuration, so the inline math used in the new section renders correctly:

```diff
@@ -11,3 +11,6 @@ language = "en"
 multilingual = false
 src = "src"
 title = "The Burn Book 🔥"
+
+[output.html]
+mathjax-support = true
```
The introduction of the dataset page is reworked:

````diff
@@ -1,8 +1,13 @@
 # Dataset
 
-In most deep learning training performed on datasets (with perhaps the exception of reinforcement
-learning), it is essential to provide a convenient and performant API. The dataset trait is quite
-similar to the dataset abstract class in PyTorch:
+At its core, a dataset is a collection of data typically related to a specific analysis or
+processing task. The data modality can vary depending on the task, but most datasets primarily
+consist of images, texts, audio or videos.
+
+This data source is an integral part of successfully training a machine learning model. Thus, it
+is essential to provide a convenient and performant API to handle your data. Since this process
+varies wildly from one problem to another, it is defined as a trait that should be implemented on
+your type. The dataset trait is quite similar to the dataset abstract class in PyTorch:
 
 ```rust, ignore
 pub trait Dataset<I>: Send + Sync {
````
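The `get(index)` and `len()` methods of this trait are all an implementor needs to provide (both appear in the `MnistDataset` implementation later in this diff). A minimal sketch, assuming the trait is imported from `burn::data::dataset` and using a hypothetical `GreetingDataset` type:

```rust, ignore
use burn::data::dataset::Dataset;

/// A toy dataset backed by an in-memory vector of strings.
pub struct GreetingDataset {
    items: Vec<String>,
}

impl Dataset<String> for GreetingDataset {
    fn get(&self, index: usize) -> Option<String> {
        // Out-of-range indices simply yield `None`.
        self.items.get(index).cloned()
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}
```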
The new usage section below is appended after the existing discussion of dataset length, which ends as follows:

…multiple times over the dataset and only checkpoint when done. You can consider the length of the dataset as the number of iterations before performing checkpointing and running the validation. There is nothing stopping you from returning different items even when called with the same `index` multiple times.
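For instance, a dataset that applies random augmentation on the fly is free to be non-deterministic. A contrived sketch of this idea (the `NoisyDataset` wrapper and its jitter logic are hypothetical, and the `rand` crate is assumed):

```rust, ignore
use burn::data::dataset::Dataset;

/// Wraps another dataset and adds random jitter to each value, so two
/// calls with the same `index` may return different items.
pub struct NoisyDataset<D: Dataset<f32>> {
    inner: D,
}

impl<D: Dataset<f32>> Dataset<f32> for NoisyDataset<D> {
    fn get(&self, index: usize) -> Option<f32> {
        self.inner.get(index).map(|x| x + 0.1 * rand::random::<f32>())
    }

    fn len(&self) -> usize {
        self.inner.len()
    }
}
```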
## How Is The Dataset Used?

During training, the dataset is used to access the data samples and, for most use cases in supervised learning, their corresponding ground-truth labels. Remember that the `Dataset` trait implementation is responsible for retrieving the data from its source, usually some sort of data storage. At this point, the dataset could be naively iterated over to provide the model with a single sample to process at a time, but this is not very efficient.
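Concretely, the naive approach amounts to a sequential loop. A sketch (the `model.forward` call is a placeholder):

```rust, ignore
// Inefficient: fetch and process one sample at a time, on a single thread.
for index in 0..dataset.len() {
    if let Some(item) = dataset.get(index) {
        // let output = model.forward(item);
    }
}
```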
Instead, we collect multiple samples that the model can process as a _batch_ to fully leverage modern hardware (e.g., GPUs, which have impressive parallel processing capabilities). Since each data sample in the dataset can be collected independently, the data loading is typically done in parallel to further speed things up. In this case, we parallelize the data loading using a multi-threaded `BatchDataLoader` to obtain a sequence of items from the `Dataset` implementation. Finally, the sequence of items is combined into a batched tensor that can be used as input to a model with the `Batcher` trait implementation. Other tensor operations can be performed during this step to prepare the batch data, as is done [in the basic workflow guide](../basic-workflow/data.md). The process is illustrated in the figure below for the MNIST dataset.
<img title="Burn Data Loading Pipeline" alt="Burn Data Loading Pipeline" src="./dataset.png">
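In code, this pipeline is typically assembled with a `DataLoaderBuilder`, as in the basic workflow guide. A sketch, assuming the `MnistBatcher` defined in that guide and placeholder hyperparameter values:

```rust, ignore
use burn::data::dataloader::DataLoaderBuilder;

let batcher = MnistBatcher::new(device.clone());

// Build a multi-threaded, shuffling dataloader over the training split.
let dataloader_train = DataLoaderBuilder::new(batcher)
    .batch_size(64)
    .shuffle(42) // seed
    .num_workers(4)
    .build(MnistDataset::train());

// Each iteration yields a batch assembled by the `Batcher`.
for batch in dataloader_train.iter() {
    // let output = model.forward(batch.images);
}
```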
Although we have conveniently implemented the [`MnistDataset`](https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/vision/mnist.rs) used in the guide, we'll go over its implementation to demonstrate how the `Dataset` and `Batcher` traits are used.
The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a \\(28 \times 28\\) pixel black-and-white image (stored as raw bytes) with its corresponding label (a digit between \\(0\\) and \\(9\\)). This is defined by the `MnistItemRaw` struct.
```rust, ignore
# #[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    pub image_bytes: Vec<u8>,
    pub label: u8,
}
```
With single-channel images of such low resolution, the entire training and test sets can be loaded in memory at once. Therefore, we leverage the already existing `InMemDataset` to retrieve the raw image and label data. At this point, the image data is still just a bunch of bytes, but we want to retrieve the _structured_ image data in its intended form. For that, we can define a `MapperDataset` that transforms the raw image bytes to a 2D array image (which we convert to float while we're at it).
```rust, ignore
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

# /// MNIST item.
# #[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MnistItem {
    /// Image as a 2D array of floats.
    pub image: [[f32; WIDTH]; HEIGHT],

    /// Label of the image.
    pub label: u8,
}

struct BytesToImage;

impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
    fn map(&self, item: &MnistItemRaw) -> MnistItem {
        // Ensure the image dimensions are correct.
        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);

        // Convert the image to a 2D array of floats.
        let mut image_array = [[0f32; WIDTH]; HEIGHT];
        for (i, pixel) in item.image_bytes.iter().enumerate() {
            let x = i % WIDTH;
            let y = i / WIDTH; // row index (rows are WIDTH pixels long)
            image_array[y][x] = *pixel as f32;
        }

        MnistItem {
            image: image_array,
            label: item.label,
        }
    }
}

type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;

# /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digit), with 7,000
# /// images per class. There are 60,000 training images and 10,000 test images.
# ///
# /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MnistDataset {
    dataset: MappedDataset,
}
```
To construct the `MnistDataset`, the data source must be parsed into the expected `MappedDataset` type. Since both the train and test sets use the same file format, the loading logic can be shared by the `train()` and `test()` constructors.
```rust, ignore
impl MnistDataset {
    /// Creates a new train dataset.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates a new test dataset.
    pub fn test() -> Self {
        Self::new("test")
    }

    fn new(split: &str) -> Self {
        // Download dataset
        let root = MnistDataset::download(split);

        // Parse data as vector of image bytes and vector of labels
        let images: Vec<Vec<u8>> = MnistDataset::read_images(&root, split);
        let labels: Vec<u8> = MnistDataset::read_labels(&root, split);

        // Collect as vector of MnistItemRaw
        let items: Vec<_> = images
            .into_iter()
            .zip(labels)
            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
            .collect();

        // Create the MapperDataset for InMemDataset<MnistItemRaw> to transform
        // items (MnistItemRaw -> MnistItem)
        let dataset = InMemDataset::new(items);
        let dataset = MapperDataset::new(dataset, BytesToImage);

        Self { dataset }
    }

#    /// Download the MNIST dataset files from the web.
#    /// Panics if the download cannot be completed or the content of the file cannot be written to disk.
#    fn download(split: &str) -> PathBuf {
#        // Dataset files are stored in the burn-dataset cache directory
#        let cache_dir = dirs::home_dir()
#            .expect("Could not get home directory")
#            .join(".cache")
#            .join("burn-dataset");
#        let split_dir = cache_dir.join("mnist").join(split);
#
#        if !split_dir.exists() {
#            create_dir_all(&split_dir).expect("Failed to create base directory");
#        }
#
#        // Download split files
#        match split {
#            "train" => {
#                MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
#                MnistDataset::download_file(TRAIN_LABELS, &split_dir);
#            }
#            "test" => {
#                MnistDataset::download_file(TEST_IMAGES, &split_dir);
#                MnistDataset::download_file(TEST_LABELS, &split_dir);
#            }
#            _ => panic!("Invalid split specified {}", split),
#        };
#
#        split_dir
#    }
#
#    /// Download a file from the MNIST dataset URL to the destination directory.
#    /// File download progress is reported with the help of a [progress bar](indicatif).
#    fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
#        // Output file name
#        let file_name = dest_dir.as_ref().join(name);
#
#        if !file_name.exists() {
#            // Download gzip file
#            let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);
#
#            // Create file to write the downloaded content to
#            let mut output_file = File::create(&file_name).unwrap();
#
#            // Decode gzip file content and write to disk
#            let mut gz_buffer = GzDecoder::new(&bytes[..]);
#            std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
#        }
#
#        file_name
#    }
#
#    /// Read images at the provided path for the specified split.
#    /// Each image is a vector of bytes.
#    fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
#        let file_name = if split == "train" {
#            TRAIN_IMAGES
#        } else {
#            TEST_IMAGES
#        };
#        let file_name = root.as_ref().join(file_name);
#
#        // Read number of images from 16-byte header metadata
#        let mut f = File::open(file_name).unwrap();
#        let mut buf = [0u8; 4];
#        let _ = f.seek(SeekFrom::Start(4)).unwrap();
#        f.read_exact(&mut buf)
#            .expect("Should be able to read image file header");
#        let size = u32::from_be_bytes(buf);
#
#        let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
#        let _ = f.seek(SeekFrom::Start(16)).unwrap();
#        f.read_exact(&mut buf_images)
#            .expect("Should be able to read images from file");
#
#        buf_images
#            .chunks(WIDTH * HEIGHT)
#            .map(|chunk| chunk.to_vec())
#            .collect()
#    }
#
#    /// Read labels at the provided path for the specified split.
#    fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
#        let file_name = if split == "train" {
#            TRAIN_LABELS
#        } else {
#            TEST_LABELS
#        };
#        let file_name = root.as_ref().join(file_name);
#
#        // Read number of labels from 8-byte header metadata
#        let mut f = File::open(file_name).unwrap();
#        let mut buf = [0u8; 4];
#        let _ = f.seek(SeekFrom::Start(4)).unwrap();
#        f.read_exact(&mut buf)
#            .expect("Should be able to read label file header");
#        let size = u32::from_be_bytes(buf);
#
#        let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
#        let _ = f.seek(SeekFrom::Start(8)).unwrap();
#        f.read_exact(&mut buf_labels)
#            .expect("Should be able to read labels from file");
#
#        buf_labels
#    }
}
```
Since the `MnistDataset` simply wraps a `MapperDataset` built on an `InMemDataset`, we can easily implement the `Dataset` trait by delegating to the inner dataset.
```rust, ignore
impl Dataset<MnistItem> for MnistDataset {
    fn get(&self, index: usize) -> Option<MnistItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}
```
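With the trait implemented, the dataset can be consumed like any other `Dataset`. A quick sketch (the printed sizes follow from the dataset description above):

```rust, ignore
let dataset = MnistDataset::train();

println!("Training samples: {}", dataset.len()); // 60000
if let Some(item) = dataset.get(0) {
    println!("First label: {}", item.label); // a digit between 0 and 9
}
```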
The only thing missing now is the `Batcher`, which we already went over [in the basic workflow guide](../basic-workflow/data.md). The `Batcher` takes a list of `MnistItem` values retrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their targets.
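For completeness, here is a condensed sketch of such a `Batcher`, modeled after the `MnistBatcher` from the basic workflow guide (the tensor conversions are simplified here, the guide's normalization step is omitted, and trait signatures may differ across Burn versions):

```rust, ignore
use burn::{data::dataloader::batcher::Batcher, prelude::*};

#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        // Convert each 2D image array into a [1, 28, 28] float tensor.
        let images = items
            .iter()
            .map(|item| Tensor::<B, 2>::from_floats(item.image, &self.device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            .collect();

        // Convert each label into a single-element integer tensor.
        let targets = items
            .iter()
            .map(|item| Tensor::<B, 1, Int>::from_ints([item.label as i32], &self.device))
            .collect();

        // Concatenate along the batch (first) dimension.
        MnistBatch {
            images: Tensor::cat(images, 0),
            targets: Tensor::cat(targets, 0),
        }
    }
}
```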
(The commit also adds the `dataset.png` figure referenced above; 99 KiB.)