diff --git a/examples/mnist-inference-web/model.bin b/examples/mnist-inference-web/model.bin index f1f2b6e87..4e552ca67 100644 Binary files a/examples/mnist-inference-web/model.bin and b/examples/mnist-inference-web/model.bin differ diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index cefd5e378..8bc611361 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -92,3 +92,16 @@ cargo run --example ag-news-infer --release --features wgpu # Run inference on cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset ``` + +## CUDA backend + +```bash +git clone https://github.com/tracel-ai/burn.git +cd burn + +# Use the --release flag to really speed up training. + +# AG News +cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset +cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset +``` diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index f609d842f..9af5c6c6e 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -81,6 +81,16 @@ mod wgpu { } } +#[cfg(feature = "cuda-jit")] +mod cuda_jit { + use crate::{launch, ElemType}; + use burn::backend::{cuda_jit::CudaDevice, CudaJit}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + fn main() { #[cfg(any( feature = "ndarray", @@ -95,4 +105,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); + #[cfg(feature = "cuda-jit")] + cuda_jit::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 693aa1520..e21f4f230 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -97,7 +97,7 @@ mod cuda_jit { use burn::backend::{Autodiff, CudaJit}; pub fn run() { - launch::>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } }