mirror of https://github.com/tracel-ai/burn.git
Burn Torch Backend
This crate provides a Torch backend for Burn built on the tch-rs crate, which offers a Rust interface to the PyTorch C++ API.
The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS (macOS) devices.
Usage Example
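```rust
// GPU variant, compiled only when the `tch-gpu` feature is enabled.
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    // Training loop from the MNIST example.
    use mnist::training;

    pub fn run() {
        // Use the first CUDA GPU everywhere except macOS, where MPS is used instead.
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        // Wrap the LibTorch backend in Autodiff so gradients are available for training.
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}

// CPU variant, compiled only when the `tch-cpu` feature is enabled.
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        let device = LibTorchDevice::Cpu;
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}
```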
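The usage example above fixes the device at compile time with `#[cfg]` attributes. If you would rather fall back to the CPU when no GPU is present, the choice can also be made at run time. The sketch below is self-contained and does not depend on `burn-tch`: `DeviceChoice` and `pick_device` are hypothetical names, with `DeviceChoice` mirroring the `LibTorchDevice` variants used above (`Cpu`, `Cuda(0)`, `Mps`). The availability flags are assumed to come from a runtime query (for example tch-rs's `tch::Cuda::is_available()` for CUDA); here they are plain parameters. In a real project you would return `LibTorchDevice` directly.

```rust
/// Hypothetical stand-in mirroring the `LibTorchDevice` variants from the
/// usage example; a real project would return `burn_tch::LibTorchDevice`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DeviceChoice {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Picks a device at run time instead of with `#[cfg]` attributes.
/// `cuda_available` and `mps_available` are assumed to come from a runtime
/// query; they are parameters here so the selection policy is easy to test.
fn pick_device(cuda_available: bool, mps_available: bool) -> DeviceChoice {
    if cfg!(target_os = "macos") {
        // CUDA is not supported on macOS (see Platform Support), so try MPS.
        if mps_available {
            return DeviceChoice::Mps;
        }
    } else if cuda_available {
        // Use the first CUDA GPU, as the usage example does.
        return DeviceChoice::Cuda(0);
    }
    // Fall back to the multithreaded CPU device.
    DeviceChoice::Cpu
}

fn main() {
    // Hard-coded flags for illustration; real code would detect them.
    let device = pick_device(true, false);
    println!("selected device: {:?}", device);
}
```

Keeping detection separate from selection means the fallback policy can be unit-tested without a GPU or a LibTorch installation.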
Platform Support
| Device | CPU | GPU | Linux | macOS | Windows | Android | iOS | WASM |
|---|---|---|---|---|---|---|---|---|
| CPU | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| CUDA | No | Yes | Yes | No | Yes | No | No | No |
| MPS | No | Yes | No | Yes | No | No | No | No |
| Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | No | No |
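The table can also be encoded in code, for example to fail early in a build script or at startup when a device option is unavailable on the target. This sketch transcribes only the platform columns (Linux through WASM); `BackendOption`, `Platform`, and `is_supported` are hypothetical names, not part of `burn-tch`, and the answers are only as current as the table above.

```rust
/// Device options from the Platform Support table (hypothetical names).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BackendOption {
    Cpu,
    Cuda,
    Mps,
    Vulkan,
}

/// Target platforms, in table column order.
#[allow(dead_code)] // not every variant is constructed in `main`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Platform {
    Linux,
    MacOs,
    Windows,
    Android,
    Ios,
    Wasm,
}

/// Returns true when the table marks `option` as supported on `platform`.
fn is_supported(option: BackendOption, platform: Platform) -> bool {
    match option {
        // CPU: every platform except WASM.
        BackendOption::Cpu => platform != Platform::Wasm,
        // CUDA: Linux and Windows only.
        BackendOption::Cuda => matches!(platform, Platform::Linux | Platform::Windows),
        // MPS: macOS only.
        BackendOption::Mps => platform == Platform::MacOs,
        // Vulkan: Linux, macOS, Windows, and Android.
        BackendOption::Vulkan => matches!(
            platform,
            Platform::Linux | Platform::MacOs | Platform::Windows | Platform::Android
        ),
    }
}

fn main() {
    let options = [
        BackendOption::Cpu,
        BackendOption::Cuda,
        BackendOption::Mps,
        BackendOption::Vulkan,
    ];
    for option in options {
        println!("{:?} on Linux: {}", option, is_supported(option, Platform::Linux));
    }
}
```

Using an exhaustive `match` on `BackendOption` means the compiler flags `is_supported` if a new option is added without a corresponding row.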