burn/burn-tch
Nathaniel Simard 3d6c738776
Refactor/fusion/graph (#988)
2023-11-22 09:55:42 -05:00

Burn Torch Backend

This crate provides a Torch backend for Burn utilizing the tch-rs crate, which offers a Rust interface to the PyTorch C++ API.

The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS (macOS) devices.

Usage Example

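// GPU variant: compiled only when the `tch-gpu` feature is enabled.
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        // Use the first CUDA device everywhere except macOS, where MPS is used.
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        // Train on the LibTorch backend (f32 elements) wrapped with autodiff.
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}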

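// CPU variant: compiled only when the `tch-cpu` feature is enabled.
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        let device = LibTorchDevice::Cpu;
        // Same training entry point as the GPU variant, only the device differs.
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}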
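
The example above picks the device at compile time through `cfg` attributes. As a rough sketch, the same decision could be made at run time. To stay self-contained, the snippet below uses a local stand-in enum, `Device`, that only mirrors the `LibTorchDevice` variants shown above; `Device` and `select_device` are hypothetical names for illustration and are not part of `burn_tch`.

```rust
/// Local stand-in mirroring the `LibTorchDevice` variants used in the
/// usage example (hypothetical; the real type lives in `burn_tch`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Picks a device from the target OS name and a GPU opt-in flag.
/// `os` is expected in the form returned by `std::env::consts::OS`.
pub fn select_device(os: &str, use_gpu: bool, cuda_index: usize) -> Device {
    match (use_gpu, os) {
        // GPU not requested: always fall back to the CPU.
        (false, _) => Device::Cpu,
        // On macOS the GPU is reached through MPS.
        (true, "macos") => Device::Mps,
        // Elsewhere, use the CUDA device with the requested index.
        (true, _) => Device::Cuda(cuda_index),
    }
}
```

A caller might write `select_device(std::env::consts::OS, true, 0)` and then map the result onto the matching `LibTorchDevice` variant. Passing a different `cuda_index` is one way to address a specific GPU on a multi-GPU host.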

Platform Support

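| Option | CPU | GPU | Linux | macOS | Windows | Android | iOS | WASM |
| :----- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| CPU    | Yes | No  |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |
| CUDA   | No  | Yes |  Yes  |  No   |   Yes   |   No    | No  |  No  |
| MPS    | No  | Yes |  No   |  Yes  |   No    |   No    | No  |  No  |
| Vulkan | Yes | Yes |  Yes  |  Yes  |   Yes   |   Yes   | No  |  No  |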
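
For programmatic checks, the operating-system columns of the table above can be encoded as a small lookup. This is only an illustrative sketch: `DeviceOption`, `Platform`, and `is_supported` are hypothetical names that do not exist in `burn_tch`, and the values are copied directly from the table.

```rust
/// Rows of the platform support table (hypothetical type for illustration).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceOption {
    Cpu,
    Cuda,
    Mps,
    Vulkan,
}

/// Operating-system columns of the platform support table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Platform {
    Linux,
    MacOs,
    Windows,
    Android,
    Ios,
    Wasm,
}

/// Returns true when the table lists `option` as supported on `platform`.
pub fn is_supported(option: DeviceOption, platform: Platform) -> bool {
    use DeviceOption::*;
    use Platform::*;
    match option {
        // CPU: every listed platform except WASM.
        Cpu => !matches!(platform, Wasm),
        // CUDA: Linux and Windows only.
        Cuda => matches!(platform, Linux | Windows),
        // MPS: macOS only.
        Mps => matches!(platform, MacOs),
        // Vulkan: desktop platforms and Android.
        Vulkan => matches!(platform, Linux | MacOs | Windows | Android),
    }
}
```

For example, `is_supported(DeviceOption::Cuda, Platform::MacOs)` evaluates to `false`, matching the CUDA row of the table.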