mirror of https://github.com/tracel-ai/burn.git
Update model.bin mnist inference web + add cuda-jit flag for ag-news-infer (#2170)
* Update model.bin mnist inference web * Add cuda-jit flag for ag-news-infer
This commit is contained in:
parent
2755c36ed7
commit
784f57bee4
Binary file not shown.
|
@ -92,3 +92,16 @@ cargo run --example ag-news-infer --release --features wgpu # Run inference on
|
||||||
cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset
|
cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset
|
||||||
cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset
|
cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## CUDA backend
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/tracel-ai/burn.git
|
||||||
|
cd burn
|
||||||
|
|
||||||
|
# Use the --release flag to really speed up training.
|
||||||
|
|
||||||
|
# AG News
|
||||||
|
cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
|
||||||
|
cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
|
||||||
|
```
|
||||||
|
|
|
@ -81,6 +81,16 @@ mod wgpu {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda-jit")]
|
||||||
|
mod cuda_jit {
|
||||||
|
use crate::{launch, ElemType};
|
||||||
|
use burn::backend::{cuda_jit::CudaDevice, CudaJit};
|
||||||
|
|
||||||
|
pub fn run() {
|
||||||
|
launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
#[cfg(any(
|
#[cfg(any(
|
||||||
feature = "ndarray",
|
feature = "ndarray",
|
||||||
|
@ -95,4 +105,6 @@ fn main() {
|
||||||
tch_cpu::run();
|
tch_cpu::run();
|
||||||
#[cfg(feature = "wgpu")]
|
#[cfg(feature = "wgpu")]
|
||||||
wgpu::run();
|
wgpu::run();
|
||||||
|
#[cfg(feature = "cuda-jit")]
|
||||||
|
cuda_jit::run();
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,7 +97,7 @@ mod cuda_jit {
|
||||||
use burn::backend::{Autodiff, CudaJit};
|
use burn::backend::{Autodiff, CudaJit};
|
||||||
|
|
||||||
pub fn run() {
|
pub fn run() {
|
||||||
launch::<Autodiff<CudaJit>>(vec![Default::default()]);
|
launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue