Text Classification

This project provides an example of training and running inference with text classification models on the AG News and DbPedia datasets, using the Rust-based Burn deep learning framework.
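At a high level, inference for text classification reduces each input to a vector of per-class scores (logits) and picks the highest-scoring class. A minimal, self-contained sketch of that final step (illustrative only; the actual example operates on Burn tensors, and the `argmax` helper below is hypothetical, not part of the example's API):

```rust
// Pick the index of the highest-scoring class from a slice of logits.
// Illustrative sketch only; the real example works with Burn tensors.
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    // Hypothetical logits for the four AG News classes.
    let logits = [0.1_f32, 2.3, -0.5, 0.9];
    let class = argmax(&logits);
    println!("{class}"); // index 1, i.e. the second class
}
```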

Note
This example uses the HuggingFace datasets library to download the datasets, so make sure Python is installed on your machine.

Dataset Details

  • AG News: a collection of news articles from more than 2,000 news sources. The example loads and processes this dataset, categorizing articles into four classes: "World", "Sports", "Business", and "Technology".

  • DbPedia: a large multi-class text classification dataset extracted from Wikipedia. The example loads and processes this dataset, categorizing articles into 14 classes, including "Company", "Educational Institution", and "Artist".
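A predicted class index maps back to a human-readable label through a simple label table. A hedged sketch of that idea for AG News (the index ordering here is illustrative and may differ from the mapping the example actually uses; check the example's source for the real one):

```rust
// Label table for AG News as described above.
// Index ordering is illustrative only.
const AG_NEWS_LABELS: [&str; 4] = ["World", "Sports", "Business", "Technology"];

fn main() {
    // Map a hypothetical predicted class index back to its label,
    // falling back to "unknown" for out-of-range indices.
    let predicted = 2usize;
    let label = AG_NEWS_LABELS.get(predicted).copied().unwrap_or("unknown");
    println!("{label}");
}
```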

Usage

Torch GPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to speed up training significantly.
# Enable the f16 feature if your CUDA device supports FP16 (half-precision)
# operations; it may not work well on every device.

export TORCH_CUDA_VERSION=cu121  # Set the CUDA version (CUDA users only)

# AG News
cargo run --example ag-news-train --release --features tch-gpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-gpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features tch-gpu  # Run inference on the db pedia dataset

Torch CPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to speed up training significantly.

# AG News
cargo run --example ag-news-train --release --features tch-cpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-cpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features tch-cpu  # Run inference on the db pedia dataset

ndarray backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to speed up training significantly.

# Replace ndarray with ndarray-blas-netlib, ndarray-blas-openblas, or
# ndarray-blas-accelerate to use a different BLAS implementation for matrix multiplication.

# AG News
cargo run --example ag-news-train --release --features ndarray   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features ndarray   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features ndarray  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features ndarray  # Run inference on the db pedia dataset

WGPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to speed up training significantly.

# AG News
cargo run --example ag-news-train --release --features wgpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features wgpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features wgpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features wgpu  # Run inference on the db pedia dataset

CUDA backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to speed up training significantly.

# AG News
cargo run --example ag-news-train --release --features cuda-jit   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features cuda-jit   # Run inference on the ag news dataset