Training on a Custom Image Dataset
In this example, a simple CNN model is trained from scratch on the CIFAR-10 dataset by leveraging the `ImageFolderDataset` struct to retrieve images from a folder structure on disk.
Since the original source is distributed in a binary format, the data is instead downloaded from a fastai mirror, which provides it as a folder structure of `.png` images:
```
cifar10
├── labels.txt
├── test
│   ├── airplane
│   ├── automobile
│   ├── bird
│   ├── cat
│   ├── deer
│   ├── dog
│   ├── frog
│   ├── horse
│   ├── ship
│   └── truck
└── train
    ├── airplane
    ├── automobile
    ├── bird
    ├── cat
    ├── deer
    ├── dog
    ├── frog
    ├── horse
    ├── ship
    └── truck
```
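In this layout, the folder names double as class labels. The std-only sketch below illustrates that idea; it is not burn's implementation. `assign_labels` and `class_of` are hypothetical helpers, and assigning integer labels in alphabetical order of class names is an assumption made here for determinism.

```rust
use std::collections::BTreeSet;
use std::path::Path;

/// Returns the name of the folder directly containing `path`, if any.
fn class_of(path: &str) -> Option<&str> {
    Path::new(path).parent()?.file_name()?.to_str()
}

/// Hypothetical helper: maps image paths such as "cat/0001.png" to
/// (path, label) pairs. Class names are the parent folder names and a
/// label is the index of its class in sorted (alphabetical) order.
/// Paths without a parent folder are skipped.
fn assign_labels(paths: &[&str]) -> (Vec<String>, Vec<(String, usize)>) {
    // BTreeSet removes duplicates and keeps the class names sorted.
    let classes: Vec<String> = paths
        .iter()
        .copied()
        .filter_map(class_of)
        .map(String::from)
        .collect::<BTreeSet<_>>()
        .into_iter()
        .collect();

    // Pair every path with the index of its class name.
    let items = paths
        .iter()
        .copied()
        .filter_map(|p| {
            let class = class_of(p)?;
            let label = classes.iter().position(|c| c.as_str() == class)?;
            Some((p.to_string(), label))
        })
        .collect();

    (classes, items)
}
```

For example, `assign_labels(&["truck/0001.png", "airplane/0001.png"])` yields the classes `["airplane", "truck"]` and labels `1` and `0` respectively.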
Loading the training and test splits is as simple as providing the root path of each folder, as is done in `CIFAR10Loader` for this example:

```rust
use burn::data::dataset::vision::ImageFolderDataset;

let train_ds = ImageFolderDataset::new_classification("/path/to/cifar10/train").unwrap();
let test_ds = ImageFolderDataset::new_classification("/path/to/cifar10/test").unwrap();
```
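Because each class is simply a subfolder, one way to sanity-check the download before handing the paths to `ImageFolderDataset` is to count the `.png` files in each class folder. The sketch below uses only the Rust standard library; `count_images_per_class` is a hypothetical helper and is not part of burn or of this example.

```rust
use std::collections::BTreeMap;
use std::fs;
use std::io;
use std::path::Path;

/// Hypothetical sanity check: counts the `.png` files in each class
/// folder directly under `split_root` (e.g. "/path/to/cifar10/train").
/// Files placed directly in `split_root` are ignored.
fn count_images_per_class(split_root: &Path) -> io::Result<BTreeMap<String, usize>> {
    let mut counts = BTreeMap::new();
    for entry in fs::read_dir(split_root)? {
        let class_dir = entry?.path();
        // Only subdirectories are treated as classes.
        if !class_dir.is_dir() {
            continue;
        }
        let class = class_dir
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or_default()
            .to_string();

        let mut n = 0;
        for file in fs::read_dir(&class_dir)? {
            let path = file?.path();
            // Count regular files with a .png extension only.
            if path.is_file() && path.extension().map_or(false, |e| e == "png") {
                n += 1;
            }
        }
        counts.insert(class, n);
    }
    Ok(counts)
}
```

The standard CIFAR-10 split has 5,000 training and 1,000 test images for each of its ten classes, so a complete download should report those counts for `train` and `test` respectively.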
Example Usage
The CNN model and training recipe used in this example are deliberately simple, since the objective is to demonstrate how to load a custom image classification dataset from disk. Nonetheless, the model still reaches 70-80% accuracy on the test set after just 30 epochs.
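Test accuracy here is the fraction of test images whose predicted class (the index of the largest of the model's ten outputs) matches the label given by its folder. On the 10,000-image CIFAR-10 test split, 70-80% therefore corresponds to roughly 7,000 to 8,000 correct predictions. The std-only sketch below shows just that arithmetic; `argmax` and `accuracy` are hypothetical helpers, not the metric code the example itself uses.

```rust
/// Index of the largest logit, i.e. the predicted class.
/// `total_cmp` gives floats a total order; on ties, `max_by`
/// returns the last maximal element. Empty input yields None.
fn argmax(logits: &[f32]) -> Option<usize> {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

/// Fraction of predictions that equal their target label.
/// Returns None for empty input or mismatched lengths.
fn accuracy(predictions: &[usize], targets: &[usize]) -> Option<f64> {
    if predictions.is_empty() || predictions.len() != targets.len() {
        return None;
    }
    let correct = predictions
        .iter()
        .zip(targets)
        .filter(|(p, t)| p == t)
        .count();
    Some(correct as f64 / predictions.len() as f64)
}
```

For instance, three correct predictions out of four give `Some(0.75)`.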
Run it with the Torch GPU backend:
```sh
export TORCH_CUDA_VERSION=cu121
cargo run --example custom-image-dataset --release --features tch-gpu
```
Run it with our WGPU backend:
```sh
cargo run --example custom-image-dataset --release --features wgpu
```