Training on a Custom Image Dataset

In this example, a simple CNN model is trained from scratch on the CIFAR-10 dataset by leveraging the ImageFolderDataset struct to retrieve images from a folder structure on disk.

Since CIFAR-10 is originally distributed in a binary format, the data is instead downloaded from a fastai mirror that provides it as a folder structure of .png images:

cifar10
├── labels.txt
├── test
│   ├── airplane
│   ├── automobile
│   ├── bird
│   ├── cat
│   ├── deer
│   ├── dog
│   ├── frog
│   ├── horse
│   ├── ship
│   └── truck
└── train
    ├── airplane
    ├── automobile
    ├── bird
    ├── cat
    ├── deer
    ├── dog
    ├── frog
    ├── horse
    ├── ship
    └── truck
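ImageFolderDataset relies on the folder-per-class convention shown above: each image's label comes from the name of the folder it sits in. The std-only sketch below illustrates that idea on plain path strings. It makes these assumptions:

- class_of and labels_from_paths are hypothetical helpers, not burn's implementation.
- Label indices are assigned in sorted class-name order so that the mapping is deterministic.

```rust
use std::collections::BTreeSet;
use std::path::Path;

/// Hypothetical helper: the class of an image is the name of its parent folder,
/// e.g. "train/cat/0001.png" -> "cat".
fn class_of(path: &str) -> String {
    Path::new(path)
        .parent()
        .and_then(|dir| dir.file_name())
        .map(|name| name.to_string_lossy().into_owned())
        .unwrap_or_default()
}

/// Hypothetical helper: returns the sorted class names and, for every path,
/// the index of its class in that sorted list.
fn labels_from_paths(paths: &[&str]) -> (Vec<String>, Vec<usize>) {
    // A BTreeSet removes duplicates and keeps class names in sorted order.
    let classes: Vec<String> = paths
        .iter()
        .copied()
        .map(class_of)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect();

    // `classes` is sorted, so a binary search recovers each label index.
    let labels = paths
        .iter()
        .map(|&p| {
            classes
                .binary_search(&class_of(p))
                .expect("every class was collected above")
        })
        .collect();

    (classes, labels)
}
```

Under this sketch's assumptions, the CIFAR-10 tree yields ten classes, airplane through truck, with labels 0 to 9 in alphabetical order.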

Loading the training and test splits is as simple as providing the root path of each split folder:

use burn::data::dataset::vision::ImageFolderDataset;

let train_ds = ImageFolderDataset::new_classification("/path/to/cifar10/train").unwrap();
let test_ds = ImageFolderDataset::new_classification("/path/to/cifar10/test").unwrap();

This is how CIFAR10Loader loads both splits in this example.
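Each split is loaded from its own root path, so the train and test folders should contain the same set of class subfolders. Assuming label indices are derived from the class folders found in each split, a class missing from one split would shift the indices in that split. A pre-flight check using only the standard library could look like the sketch below. class_dirs and check_splits are hypothetical names and are not part of the example.

```rust
use std::collections::BTreeSet;
use std::fs;
use std::io;
use std::path::Path;

/// Hypothetical helper: names of the class subfolders directly under a split folder.
fn class_dirs(split_root: &Path) -> io::Result<BTreeSet<String>> {
    let mut classes = BTreeSet::new();
    for entry in fs::read_dir(split_root)? {
        let entry = entry?;
        // Only directories count as classes; stray files are ignored.
        if entry.file_type()?.is_dir() {
            classes.insert(entry.file_name().to_string_lossy().into_owned());
        }
    }
    Ok(classes)
}

/// Hypothetical pre-flight check: `root/train` and `root/test` must contain the
/// same class folders. Returns the shared, sorted class names on success.
fn check_splits(root: &Path) -> Result<Vec<String>, String> {
    let train = class_dirs(&root.join("train")).map_err(|e| format!("train: {e}"))?;
    let test = class_dirs(&root.join("test")).map_err(|e| format!("test: {e}"))?;
    if train != test {
        let only_train: Vec<&String> = train.difference(&test).collect();
        let only_test: Vec<&String> = test.difference(&train).collect();
        return Err(format!(
            "class folders differ: only in train {only_train:?}, only in test {only_test:?}"
        ));
    }
    Ok(train.into_iter().collect())
}
```

You could call check_splits(Path::new("/path/to/cifar10")) before constructing the datasets. A mismatch then fails early with a readable message instead of surfacing only later during training or evaluation.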

Example Usage

The CNN model and training recipe used in this example are deliberately simple, since the goal is to demonstrate how to load a custom image classification dataset from disk. Nonetheless, the model still reaches 70-80% accuracy on the test set after just 30 epochs.
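Test accuracy here means the share of test images whose highest-scoring class (the argmax of the model's output logits) equals the label derived from their folder. The sketch below shows that computation using only the standard library. It assumes per-image logits are already available as Vec<f32> rows. argmax and accuracy_percent are hypothetical helpers and are not the example's training code.

```rust
/// Hypothetical helper: index of the largest logit, i.e. the predicted class.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &value) in logits.iter().enumerate() {
        if value > logits[best] {
            best = i;
        }
    }
    best
}

/// Hypothetical helper: percentage of samples whose predicted class equals the target label.
fn accuracy_percent(logits: &[Vec<f32>], targets: &[usize]) -> f64 {
    assert_eq!(logits.len(), targets.len(), "expected one target per prediction");
    if targets.is_empty() {
        return 0.0;
    }
    let mut correct = 0usize;
    for (row, &target) in logits.iter().zip(targets) {
        if argmax(row) == target {
            correct += 1;
        }
    }
    100.0 * correct as f64 / targets.len() as f64
}
```

The CIFAR-10 test split has 10,000 images, so 70-80% accuracy corresponds to roughly 7,000 to 8,000 correctly classified images.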

Run it with the Torch GPU backend:

export TORCH_CUDA_VERSION=cu121
cargo run --example custom-image-dataset --release --features tch-gpu

Run it with our WGPU backend:

cargo run --example custom-image-dataset --release --features wgpu