mirror of https://github.com/tracel-ai/burn.git
Add cuda gpu example + doc (#91)
This commit is contained in:
parent
947ed00301
commit
3da122db09
19
README.md
19
README.md
|
@ -44,14 +44,6 @@ Also, this may be a good idea to checkout the main [components](#components) to
|
|||
|
||||
For now there is only one example, but more to come 💪..
|
||||
|
||||
The `mnist` example can be run like so:
|
||||
|
||||
```console
|
||||
$ git clone https://github.com/burn-rs/burn.git
|
||||
$ cd burn
|
||||
$ cargo run --example mnist
|
||||
```
|
||||
|
||||
#### MNIST
|
||||
|
||||
The [MNIST](https://github.com/burn-rs/burn/blob/main/examples/mnist) example is not just of small script that shows you how to train a basic model, but it's a quick one showing you how to:
|
||||
|
@ -60,6 +52,17 @@ The [MNIST](https://github.com/burn-rs/burn/blob/main/examples/mnist) example is
|
|||
* Create the data pipeline from a raw dataset to a batched multi-threaded fast DataLoader.
|
||||
* Configure a [learner](#learner) to display and log metrics as well as to keep training checkpoints.
|
||||
|
||||
The example can be run like so:
|
||||
|
||||
```console
|
||||
$ git clone https://github.com/burn-rs/burn.git
|
||||
$ cd burn
|
||||
$ export TORCH_CUDA_VERSION=cu113 # Set the cuda version
|
||||
$ # Use the --release flag to really speed up training.
|
||||
$ cargo run --example mnist --release # CPU NdArray Backend
|
||||
$ cargo run --example mnist_cuda_gpu --release # GPU Tch Backend
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
Knowing the main components will be of great help when starting playing with `burn`.
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
use mnist::training;
|
||||
|
||||
fn main() {
|
||||
use burn::tensor::backend::{TchADBackend, TchDevice};
|
||||
|
||||
let device = TchDevice::Cuda(0);
|
||||
training::run::<TchADBackend<burn::tensor::f16>>(device);
|
||||
println!("Done.");
|
||||
}
|
Loading…
Reference in New Issue