burn/examples/text-classification
nathaniel fa5cc760b9 Merge branch 'fix/feature-flags' into fix/cuda/stability 2024-08-03 12:49:47 -04:00
examples enabled in example 2024-08-02 13:55:43 -04:00
src Refactor tensor data (#1916) 2024-06-26 20:22:19 -04:00
Cargo.toml Merge branch 'fix/feature-flags' into fix/cuda/stability 2024-08-03 12:49:47 -04:00
README.md Update TORCH_CUDA_VERSION usage (#1284) 2024-02-10 12:01:45 -05:00


Text Classification

This project provides an example of training text classification models on the AG News and DbPedia datasets and running inference with them, using the Rust-based Burn deep learning library.

Note
This example uses the Hugging Face datasets library to download the datasets, so make sure Python is installed on your computer.
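A quick way to confirm that a Python interpreter is available before running the examples is a small POSIX shell check like the one below. It is only a sketch: it assumes the interpreter is installed as python3 or python, and it does not check whether the datasets library itself is present.

```sh
#!/bin/sh
# Print the name of the first Python interpreter found on PATH.
# Assumption: the interpreter is installed as python3 or python.
find_python() {
  for candidate in python3 python; do
    if command -v "$candidate" >/dev/null 2>&1; then
      echo "$candidate"
      return 0
    fi
  done
  return 1
}

if py=$(find_python); then
  # --version may print to stderr on older interpreters, so merge the streams.
  echo "Found $("$py" --version 2>&1)"
else
  echo "Python not found: install it before running the examples." >&2
fi
```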

Dataset Details

  • AG News: The AG News dataset is a collection of news articles from more than 2,000 news sources. This example loads and processes the dataset, whose articles are labeled with one of four classes: "World", "Sports", "Business", and "Technology".

  • DbPedia: The DbPedia dataset is a large multi-class text classification dataset extracted from Wikipedia. This example loads and processes the dataset, whose articles are labeled with one of 14 classes, including "Company", "Educational Institution", and "Artist".
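When reading inference output, it can help to map numeric labels back to class names. The sketch below assumes the label order used by the Hugging Face ag_news and dbpedia_14 datasets (label 0 is "World" for AG News and "Company" for DbPedia); the function names are hypothetical and are not part of the example's source code.

```sh
#!/bin/sh
# Hypothetical helpers: map a numeric label to its class name.
# Assumption: labels follow the order of the Hugging Face ag_news and
# dbpedia_14 datasets. Unknown labels print nothing and return 1.
ag_news_class() {
  case "$1" in
    0) echo "World" ;;
    1) echo "Sports" ;;
    2) echo "Business" ;;
    3) echo "Technology" ;;
    *) return 1 ;;
  esac
}

db_pedia_class() {
  case "$1" in
    0) echo "Company" ;;
    1) echo "Educational Institution" ;;
    2) echo "Artist" ;;
    3) echo "Athlete" ;;
    4) echo "Office Holder" ;;
    5) echo "Mean Of Transportation" ;;
    6) echo "Building" ;;
    7) echo "Natural Place" ;;
    8) echo "Village" ;;
    9) echo "Animal" ;;
    10) echo "Plant" ;;
    11) echo "Album" ;;
    12) echo "Film" ;;
    13) echo "Written Work" ;;
    *) return 1 ;;
  esac
}

ag_news_class 1    # prints Sports
db_pedia_class 13  # prints Written Work
```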

Usage

Torch GPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.
# Add the f16 feature (e.g. --features tch-gpu,f16) if your CUDA device supports FP16 (half-precision) operations. It may not work well on every device.

export TORCH_CUDA_VERSION=cu121  # Set the CUDA version (CUDA users only)

# AG News
cargo run --example ag-news-train --release --features tch-gpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-gpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-gpu  # Run inference on the DbPedia dataset
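If you switch between full and half precision often, a small wrapper can assemble the feature list for you. This is a hypothetical helper, not part of the repository: USE_F16 and DRY_RUN are made-up variables, and cu121 is simply the value used above. By default it only prints the command; set DRY_RUN=0 to execute it.

```sh
#!/bin/sh
# Print the command, or execute it when DRY_RUN=0.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical helper: Cargo feature list for the Torch GPU backend.
# Appends the f16 feature when USE_F16=1 (only if your device supports FP16).
tch_gpu_features() {
  if [ "${USE_F16:-0}" = "1" ]; then
    echo "tch-gpu,f16"
  else
    echo "tch-gpu"
  fi
}

# Keep an existing TORCH_CUDA_VERSION; otherwise default to cu121 as above.
export TORCH_CUDA_VERSION="${TORCH_CUDA_VERSION:-cu121}"

run cargo run --example ag-news-train --release --features "$(tch_gpu_features)"
```

For example, USE_F16=1 DRY_RUN=0 sh script.sh would train on AG News with half precision enabled.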

Torch CPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features tch-cpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-cpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-cpu  # Run inference on the DbPedia dataset
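Training and inference are separate examples, so a common pattern is to chain them and run inference only after training has succeeded. The sketch below does that for any dataset and backend, assuming inference reads the model saved by the preceding training run; train_then_infer and DRY_RUN are hypothetical names. By default it only prints the commands.

```sh
#!/bin/sh
# Print the command, or execute it when DRY_RUN=0.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical helper: train, then run inference only if training succeeded.
#   $1: dataset prefix (ag-news or db-pedia)
#   $2: backend feature (tch-cpu, tch-gpu, ndarray, wgpu, ...)
train_then_infer() {
  run cargo run --example "$1-train" --release --features "$2" &&
    run cargo run --example "$1-infer" --release --features "$2"
}

train_then_infer ag-news tch-cpu
```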

ndarray backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# Replace ndarray with ndarray-blas-netlib, ndarray-blas-openblas, or ndarray-blas-accelerate to use a different BLAS implementation for matrix multiplication

# AG News
cargo run --example ag-news-train --release --features ndarray   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features ndarray   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features ndarray  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features ndarray  # Run inference on the DbPedia dataset
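With four ndarray variants to choose from, a small helper can reject typos in the feature name and pick a platform default. This is a sketch: the helper names and the NDARRAY_FEATURE variable are hypothetical, and defaulting to ndarray-blas-accelerate on macOS is an assumption based on Accelerate being Apple's framework.

```sh
#!/bin/sh
# Hypothetical helper: accept only the ndarray feature names listed above.
ndarray_feature() {
  case "$1" in
    ndarray | ndarray-blas-netlib | ndarray-blas-openblas | ndarray-blas-accelerate)
      echo "$1"
      ;;
    *)
      echo "unknown ndarray feature: $1" >&2
      return 1
      ;;
  esac
}

# Assumption: prefer Apple's Accelerate framework on macOS, plain ndarray elsewhere.
default_ndarray_feature() {
  if [ "$(uname -s)" = "Darwin" ]; then
    echo "ndarray-blas-accelerate"
  else
    echo "ndarray"
  fi
}

# Use NDARRAY_FEATURE if set, otherwise the platform default.
if feature=$(ndarray_feature "${NDARRAY_FEATURE:-$(default_ndarray_feature)}"); then
  echo "cargo run --example ag-news-train --release --features $feature"
fi
```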

WGPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features wgpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features wgpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features wgpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features wgpu  # Run inference on the DbPedia dataset
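The four commands in each backend section differ only in the dataset prefix, the stage, and the backend feature, so they can be generated with a loop. In this sketch BACKEND and DRY_RUN are hypothetical variables; BACKEND defaults to wgpu and can be set to any feature used in this README. By default it only prints the commands.

```sh
#!/bin/sh
# Print the command, or execute it when DRY_RUN=0.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical helper: train and then run inference on both datasets
# with a single backend, stopping at the first failing command.
run_all() {
  backend="${1:-wgpu}"
  for dataset in ag-news db-pedia; do
    for stage in train infer; do
      run cargo run --example "$dataset-$stage" --release --features "$backend" || return 1
    done
  done
}

run_all "${BACKEND:-wgpu}"
```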