Text Classification

This project provides an example implementation for training and running inference with text classification models on the AG News and DbPedia datasets, using the Rust-based Burn deep learning library.

Note
This example uses the HuggingFace datasets library to download the datasets, so make sure Python is installed on your machine.
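The download step relies on the Python side and the HuggingFace `datasets` library. As a rough, hedged sketch of what that step amounts to (the function name below is hypothetical, and actually calling it requires network access and `pip install datasets`):

```python
# Hypothetical sketch of the Python-side download step, assuming the
# HuggingFace `datasets` library is installed (`pip install datasets`).
def download_ag_news():
    # Import inside the function so this sketch loads even when the
    # `datasets` package is absent; calling it needs network access.
    from datasets import load_dataset
    return load_dataset("ag_news")  # standard splits: "train" and "test"
```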

Dataset Details

  • AG News: a collection of news articles from more than 2,000 news sources. This library helps you load and process the dataset, categorizing articles into four classes: "World", "Sports", "Business", and "Technology".

  • DbPedia: a large multi-class text classification dataset extracted from Wikipedia. This library helps you load and process the dataset, categorizing articles into 14 classes, including "Company", "Educational Institution", and "Artist".
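For reference, the class labels above can be written as plain id-to-name mappings. The AG News names follow this README; the full 14-class DbPedia list below goes beyond the three names mentioned above and follows the commonly used DbPedia-14 ontology classes, so treat it as an assumption:

```python
# Label-id to class-name mappings for the two datasets.
# AG News names follow this README ("Technology" is often also
# written "Sci/Tech"); DbPedia names assume the DbPedia-14 ontology.
AG_NEWS_CLASSES = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Technology",
}

DBPEDIA_CLASSES = {
    0: "Company", 1: "Educational Institution", 2: "Artist",
    3: "Athlete", 4: "Office Holder", 5: "Mean Of Transportation",
    6: "Building", 7: "Natural Place", 8: "Village",
    9: "Animal", 10: "Plant", 11: "Album",
    12: "Film", 13: "Written Work",
}
```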

Usage

Torch GPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.
# Use the f16 feature if your CUDA device supports FP16 (half precision) operations. May not work well on every device.

export TORCH_CUDA_VERSION=cu121  # Set the CUDA version (CUDA users only)

# AG News
cargo run --example ag-news-train --release --features tch-gpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-gpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features tch-gpu  # Run inference on the db pedia dataset

Torch CPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features tch-cpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features tch-cpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features tch-cpu  # Run inference on the db pedia dataset

ndarray backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# Replace ndarray with ndarray-blas-netlib, ndarray-blas-openblas, or ndarray-blas-accelerate to use a different matmul implementation

# AG News
cargo run --example ag-news-train --release --features ndarray   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features ndarray   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features ndarray  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features ndarray  # Run inference on the db pedia dataset

WGPU backend

git clone https://github.com/tracel-ai/burn.git
cd burn

# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features wgpu   # Train on the ag news dataset
cargo run --example ag-news-infer --release --features wgpu   # Run inference on the ag news dataset

# DbPedia
cargo run --example db-pedia-train --release --features wgpu  # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features wgpu  # Run inference on the db pedia dataset