# Text Classification

This project provides an example implementation for training text classification models, and running
inference with them, on the AG News and DbPedia datasets using Burn, the Rust-based deep learning library.

> **Note**
> This example uses the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
> installed on your computer.
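
Before the first run, it can help to confirm that a Python interpreter is reachable. The snippet below is a minimal sketch that assumes the interpreter is named `python3`. The `check_python` function and the `PYTHON` override variable are hypothetical and are not read by Burn itself.

```bash
#!/usr/bin/env bash
# Sketch: verify that a Python interpreter is on the PATH before running the
# examples. Assumes the interpreter is called python3; set PYTHON to override
# (a hypothetical variable used only by this helper, not by Burn).
check_python() {
  local python="${PYTHON:-python3}"
  if ! command -v "$python" >/dev/null 2>&1; then
    echo "missing: $python (install Python from https://www.python.org/downloads/)"
    return 1
  fi
  # Print the interpreter version, e.g. "Python 3.11.4".
  # Older interpreters print it on stderr, hence 2>&1.
  "$python" --version 2>&1
}

# Usage: check_python || exit 1
```
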
## Dataset Details

- AG News: The AG News dataset is a collection of news articles from more than 2,000 news sources.
  This example loads and processes the dataset, classifying articles into four classes:
  "World", "Sports", "Business", and "Technology".
- DbPedia: The DbPedia dataset is a large multi-class text classification dataset extracted from
  Wikipedia. This example loads and processes the dataset, classifying articles into 14 classes,
  including "Company", "Educational Institution", and "Artist".
# Usage
## Torch GPU backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# The --release flag significantly speeds up training.
# Enable the f16 feature if your CUDA device supports FP16 (half-precision) operations; it may not work well on every device.
export TORCH_CUDA_VERSION=cu121  # Set the CUDA version (CUDA users only)

# AG News
cargo run --example ag-news-train --release --features tch-gpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-gpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-gpu  # Run inference on the DbPedia dataset
```
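
The `TORCH_CUDA_VERSION` value above is hard-coded to `cu121`. As a sketch, the tag can instead be derived from the locally installed CUDA toolkit. This assumes `nvcc` is on the `PATH` and prints its usual `release X.Y` line. The `cuda_tag_from_nvcc` helper is hypothetical. A LibTorch build must also exist for the resulting tag, so treat the output as a starting point rather than a guarantee.

```bash
#!/usr/bin/env bash
# Sketch: derive a TORCH_CUDA_VERSION tag such as cu121 from the CUDA toolkit
# reported by nvcc. Assumes nvcc prints a line like
# "Cuda compilation tools, release 12.1, V12.1.105".
cuda_tag_from_nvcc() {
  # Read nvcc's version text on stdin and print "cu" + major + minor
  # for the first "release X.Y" match (nothing if there is no match).
  sed -n 's/.*release \([0-9][0-9]*\)\.\([0-9][0-9]*\).*/cu\1\2/p' | head -n 1
}

if command -v nvcc >/dev/null 2>&1; then
  TORCH_CUDA_VERSION="$(nvcc --version | cuda_tag_from_nvcc)"
  export TORCH_CUDA_VERSION
  echo "TORCH_CUDA_VERSION=$TORCH_CUDA_VERSION"
else
  echo "nvcc not found; set TORCH_CUDA_VERSION by hand (e.g. cu121)" >&2
fi
```

Standard Cargo syntax combines features with a comma. To also enable half precision, you would use `--features tch-gpu,f16`, assuming the example crate defines the `f16` feature that the comment above refers to.
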
## Torch CPU backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# The --release flag significantly speeds up training.

# AG News
cargo run --example ag-news-train --release --features tch-cpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-cpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features tch-cpu  # Run inference on the DbPedia dataset
```
## ndarray backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# The --release flag significantly speeds up training.
# Replace ndarray with ndarray-blas-netlib, ndarray-blas-openblas or ndarray-blas-accelerate to use a BLAS-backed matrix multiplication.

# AG News
cargo run --example ag-news-train --release --features ndarray   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features ndarray   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features ndarray  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features ndarray  # Run inference on the DbPedia dataset
```
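
When switching matmul implementations as the comment in the ndarray commands describes, a small guard can reject any name outside the documented list. This is a sketch. The `ndarray_feature` helper is hypothetical, and the accepted names are exactly the four listed in that comment. `ndarray-blas-accelerate` relies on Apple's Accelerate framework, so it is only expected to work on macOS.

```bash
#!/usr/bin/env bash
# Sketch: validate an ndarray backend variant and print it as a feature name.
# Accepted names are the four listed in the README comment; the default is ndarray.
ndarray_feature() {
  local variant="${1:-ndarray}"
  case "$variant" in
    ndarray|ndarray-blas-netlib|ndarray-blas-openblas|ndarray-blas-accelerate)
      echo "$variant"
      ;;
    *)
      echo "unknown ndarray variant: $variant" >&2
      return 1
      ;;
  esac
}

# Usage:
#   FEATURE="$(ndarray_feature ndarray-blas-openblas)" || exit 1
#   cargo run --example ag-news-train --release --features "$FEATURE"
```
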
## WGPU backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# The --release flag significantly speeds up training.

# AG News
cargo run --example ag-news-train --release --features wgpu   # Train on the AG News dataset
cargo run --example ag-news-infer --release --features wgpu   # Run inference on the AG News dataset

# DbPedia
cargo run --example db-pedia-train --release --features wgpu  # Train on the DbPedia dataset
cargo run --example db-pedia-infer --release --features wgpu  # Run inference on the DbPedia dataset
```
## CUDA backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn

# The --release flag significantly speeds up training.

# AG News
cargo run --example ag-news-train --release --features cuda-jit  # Train on the AG News dataset
cargo run --example ag-news-infer --release --features cuda-jit  # Run inference on the AG News dataset
```
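
The CUDA section lists only the AG News commands. Every section follows the same pattern: a `<dataset>-train` example followed by a `<dataset>-infer` example, both built with one backend feature. A small wrapper can therefore run both steps for any dataset and backend. This is a hypothetical helper, not part of the Burn repository. Running DbPedia with `cuda-jit` is an assumption that this README does not confirm.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the Burn repository): train, then run
# inference, for one dataset and one backend feature, following the
# <dataset>-train / <dataset>-infer naming used by the commands in this README.
# Set DRY_RUN=1 to print the commands instead of running them.
run_example() {
  local dataset="$1" features="$2" task
  case "$dataset" in
    ag-news|db-pedia) ;;
    *) echo "unknown dataset: $dataset" >&2; return 1 ;;
  esac
  if [ -z "$features" ]; then
    echo "missing backend feature (e.g. wgpu, ndarray, tch-cpu)" >&2
    return 1
  fi
  for task in train infer; do
    local cmd=(cargo run --example "${dataset}-${task}" --release --features "$features")
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "${cmd[*]}"
    else
      # Stop before inference if training fails.
      "${cmd[@]}" || return $?
    fi
  done
}
```

For example, `DRY_RUN=1 run_example db-pedia cuda-jit` prints the two commands without running them. `run_example ag-news wgpu` trains and then runs inference with the WGPU backend, stopping early if training fails.
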