mirror of https://github.com/tracel-ai/burn.git
Update model.bin mnist inference web + add cuda-jit flag for ag-news-infer (#2170)
* Update model.bin mnist inference web * Add cuda-jit flag for ag-news-infer
This commit is contained in:
parent
2755c36ed7
commit
784f57bee4
Binary file not shown.
|
@ -92,3 +92,16 @@ cargo run --example ag-news-infer --release --features wgpu # Run inference on
|
|||
cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset
|
||||
cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset
|
||||
```
|
||||
|
||||
## CUDA backend
|
||||
|
||||
```bash
|
||||
git clone https://github.com/tracel-ai/burn.git
|
||||
cd burn
|
||||
|
||||
# Use the --release flag to really speed up training.
|
||||
|
||||
# AG News
|
||||
cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
|
||||
cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
|
||||
```
|
||||
|
|
|
@ -81,6 +81,16 @@ mod wgpu {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda-jit")]
|
||||
mod cuda_jit {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::{cuda_jit::CudaDevice, CudaJit};
|
||||
|
||||
pub fn run() {
|
||||
launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
#[cfg(any(
|
||||
feature = "ndarray",
|
||||
|
@ -95,4 +105,6 @@ fn main() {
|
|||
tch_cpu::run();
|
||||
#[cfg(feature = "wgpu")]
|
||||
wgpu::run();
|
||||
#[cfg(feature = "cuda-jit")]
|
||||
cuda_jit::run();
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ mod cuda_jit {
|
|||
use burn::backend::{Autodiff, CudaJit};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<CudaJit>>(vec![Default::default()]);
|
||||
launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue