Update model.bin mnist inference web + add cuda-jit flag for ag-news-infer (#2170)

* Update model.bin mnist inference web

* Add cuda-jit flag for ag-news-infer
This commit is contained in:
Guillaume Lagrange 2024-08-19 12:53:15 -04:00 committed by GitHub
parent 2755c36ed7
commit 784f57bee4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 1 deletions

View File

@ -92,3 +92,16 @@ cargo run --example ag-news-infer --release --features wgpu # Run inference on
cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset
cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset
```
## CUDA backend
```bash
git clone https://github.com/tracel-ai/burn.git
cd burn
# Use the --release flag to really speed up training.
# AG News
cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
```

View File

@ -81,6 +81,16 @@ mod wgpu {
}
}
#[cfg(feature = "cuda-jit")]
mod cuda_jit {
use crate::{launch, ElemType};
use burn::backend::{cuda_jit::CudaDevice, CudaJit};
pub fn run() {
launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
}
}
fn main() {
#[cfg(any(
feature = "ndarray",
@ -95,4 +105,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
}

View File

@ -97,7 +97,7 @@ mod cuda_jit {
use burn::backend::{Autodiff, CudaJit};
pub fn run() {
launch::<Autodiff<CudaJit>>(vec![Default::default()]);
launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
}
}