burn/examples/text-classification
Guillaume Lagrange 0cbe9a927d
Add learner training report summary (#1591)
* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
2024-04-11 12:32:25 -04:00
examples docs(book-&-examples): modify book and examples with new `prelude` module (#1372) 2024-02-28 13:25:25 -05:00
src Add learner training report summary (#1591) 2024-04-11 12:32:25 -04:00
Cargo.toml [refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00
README.md Update TORCH_CUDA_VERSION usage (#1284) 2024-02-10 12:01:45 -05:00


Text Classification

This project provides an example implementation for training text classification models on the AG News and DbPedia datasets, and for running inference with them, using the Rust-based Burn deep learning library.

Note
This example uses the Hugging Face datasets library to download the datasets, so make sure Python is installed on your computer.
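Because the datasets are fetched through Python, a quick pre-flight check can catch a missing interpreter before the first run. The sketch below is not part of the example. It assumes a POSIX shell and only reports whether a python3 or python interpreter is on the PATH, because this README does not state which interpreter name the dataset loader invokes. The function name find_python is hypothetical.

```shell
#!/bin/sh
# Pre-flight check (not part of the Burn example): look for a Python interpreter on PATH.
# Assumption: either "python3" or "python" is acceptable; the loader's exact requirement may differ.
find_python() {
  for candidate in python3 python; do
    if command -v "$candidate" >/dev/null 2>&1; then
      echo "$candidate"
      return 0
    fi
  done
  return 1
}

if py=$(find_python); then
  echo "Using $py: $("$py" --version 2>&1)"
else
  echo "No Python interpreter found; install Python before running the examples." >&2
fi
```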

Dataset Details

  • AG News: A collection of news articles from more than 2000 news sources. This example loads and processes the dataset, in which each article is labeled with one of four classes: "World", "Sports", "Business", and "Technology".

  • DbPedia: A large multi-class text classification dataset extracted from Wikipedia. This example loads and processes the dataset, in which each article is labeled with one of 14 classes, including "Company", "Educational Institution", and "Artist".
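Both datasets store each article's class as an integer label. The sketch below shows one way to map those labels back to the class names listed above. The index order is an assumption based on the Hugging Face ag_news and dbpedia_14 label orderings (the Hugging Face dataset names label 3 of AG News "Sci/Tech", which this README calls "Technology"), so check the example's source before relying on it. The function names ag_news_class and db_pedia_class are hypothetical.

```shell
#!/bin/sh
# Hypothetical helpers: map an integer label to a class name.
# Assumption: label order follows the Hugging Face ag_news and dbpedia_14 datasets.
ag_news_class() {
  case "$1" in
    0) echo "World" ;;
    1) echo "Sports" ;;
    2) echo "Business" ;;
    3) echo "Technology" ;;
    *) echo "unknown label: $1" >&2; return 1 ;;
  esac
}

db_pedia_class() {
  case "$1" in
    0) echo "Company" ;;
    1) echo "Educational Institution" ;;
    2) echo "Artist" ;;
    3) echo "Athlete" ;;
    4) echo "Office Holder" ;;
    5) echo "Mean Of Transportation" ;;
    6) echo "Building" ;;
    7) echo "Natural Place" ;;
    8) echo "Village" ;;
    9) echo "Animal" ;;
    10) echo "Plant" ;;
    11) echo "Album" ;;
    12) echo "Film" ;;
    13) echo "Written Work" ;;
    *) echo "unknown label: $1" >&2; return 1 ;;
  esac
}

ag_news_class 1     # prints "Sports"
db_pedia_class 13   # prints "Written Work"
```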

Usage

Torch GPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag: it speeds up training considerably.
# Use the f16 feature if your CUDA device supports FP16 (half-precision) operations; this may not work well on every device.

export TORCH_CUDA_VERSION=cu121  # Set the CUDA version (CUDA users only)

# AG News
cargo run --example ag-news-train --release --features tch-gpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-gpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-gpu  # Run inference on the DbPedia dataset
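Every command above follows the same pattern: a dataset prefix, a train or infer step, and a backend selected through a Cargo feature. The sketch below wraps that pattern in a small helper that prints the command. The function name example_cmd and its optional fourth argument are hypothetical; passing f16 assumes the example exposes a feature with that name, as the f16 comment in the commands suggests. Cargo accepts a comma-separated list for --features.

```shell
#!/bin/sh
# Hypothetical helper: print the cargo command for one dataset, step, and backend.
#   $1: dataset prefix (ag-news or db-pedia)
#   $2: step (train or infer)
#   $3: backend feature (tch-gpu, tch-cpu, ndarray, wgpu, ...)
#   $4: optional extra features, e.g. f16 (assumed feature name)
example_cmd() {
  features="$3"
  if [ -n "$4" ]; then
    features="$features,$4"   # Cargo accepts a comma-separated feature list
  fi
  echo "cargo run --example $1-$2 --release --features $features"
}

# To actually train and then run inference (stopping if training fails),
# run from the repository root with a working backend:
# eval "$(example_cmd ag-news train tch-gpu)" && eval "$(example_cmd ag-news infer tch-gpu)"

example_cmd ag-news train tch-gpu        # prints the AG News training command
example_cmd db-pedia infer tch-gpu f16   # adds f16 to the feature list
```

Printing the command instead of running it keeps the helper safe to try outside the repository; the commented eval line shows how to chain the two steps once the setup works.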

Torch CPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag: it speeds up training considerably.

# AG News
cargo run --example ag-news-train --release --features tch-cpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-cpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-cpu  # Run inference on the DbPedia dataset

ndarray backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag: it speeds up training considerably.

# Replace ndarray with ndarray-blas-netlib, ndarray-blas-openblas or ndarray-blas-accelerate to perform matrix multiplication with a BLAS library

# AG News
cargo run --example ag-news-train --release --features ndarray   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features ndarray   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features ndarray  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features ndarray  # Run inference on the DbPedia dataset
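Which BLAS variant makes sense depends on the platform. The sketch below picks an ndarray feature from the operating-system name. The mapping is an assumption, not a recommendation from the Burn project: Accelerate is Apple's framework and only applies on macOS, and the OpenBLAS and Netlib variants need the corresponding BLAS library (or the toolchain to build it) on your system. The function name ndarray_feature_for is hypothetical.

```shell
#!/bin/sh
# Hypothetical helper: choose an ndarray feature from the OS name (as reported by `uname -s`).
# Assumption: OpenBLAS can be found or built on Linux; other systems fall back to plain ndarray.
ndarray_feature_for() {
  case "$1" in
    Darwin) echo "ndarray-blas-accelerate" ;;  # Apple Accelerate, macOS only
    Linux)  echo "ndarray-blas-openblas" ;;    # OpenBLAS, assumed available
    *)      echo "ndarray" ;;                  # no external BLAS library needed
  esac
}

feature=$(ndarray_feature_for "$(uname -s)")
echo "cargo run --example ag-news-train --release --features $feature"
```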

WGPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag: it speeds up training considerably.

# AG News
cargo run --example ag-news-train --release --features wgpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features wgpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features wgpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features wgpu  # Run inference on the DbPedia dataset