burn/burn-tch
Louis Fortier-Dubois 8fc52113bc
Chore/bump v12 (#1048)
2023-12-04 10:47:54 -05:00
src Fix double broadcast with tch (#1026) 2023-12-01 10:02:57 -05:00
Cargo.toml Chore/bump v12 (#1048) 2023-12-04 10:47:54 -05:00
LICENSE-APACHE Refactor/extract tch backend (#103) 2022-11-15 21:06:40 -05:00
LICENSE-MIT Refactor/extract tch backend (#103) 2022-11-15 21:06:40 -05:00
README.md Chore/release (#1031) 2023-12-01 14:33:28 -05:00

Burn Torch Backend

This crate provides a Torch backend for Burn using the tch-rs crate, which offers Rust bindings to the PyTorch C++ API (LibTorch).

The backend supports CPU (multithreaded), CUDA (multiple GPUs), MPS (macOS), and Vulkan devices.
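When the target device comes from a command-line flag or a config file, a small parser keeps the selection in one place. The sketch below uses a hypothetical `Device` enum that mirrors the variants of `LibTorchDevice` (`Cpu`, `Cuda(index)`, `Mps`, `Vulkan`) so it compiles without LibTorch installed; `parse_device` and the string format it accepts (`"cuda:1"` and so on) are assumptions for illustration, not part of burn-tch. The CUDA index is how one GPU is chosen among several.

```rust
// Hypothetical stand-in for `burn_tch::LibTorchDevice`; not the real type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize), // index of the CUDA GPU: Cuda(0) is the first, Cuda(1) the second
    Mps,
    Vulkan,
}

/// Parses "cpu", "cuda", "cuda:<index>", "mps" or "vulkan" (case-insensitive).
/// A bare "cuda" selects the first GPU; anything else yields `None`.
pub fn parse_device(s: &str) -> Option<Device> {
    let s = s.trim().to_ascii_lowercase();
    match s.as_str() {
        "cpu" => Some(Device::Cpu),
        "cuda" => Some(Device::Cuda(0)),
        "mps" => Some(Device::Mps),
        "vulkan" => Some(Device::Vulkan),
        other => other
            .strip_prefix("cuda:")
            .and_then(|idx| idx.parse::<usize>().ok())
            .map(Device::Cuda),
    }
}
```

In a real program, each `Device` variant would then be mapped one-to-one onto the corresponding `LibTorchDevice` variant before calling into Burn.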

Usage Example

```rust
// GPU training: CUDA device 0 on Linux/Windows, MPS on macOS.
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        // The device is chosen at compile time from the target OS.
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        // Wrap the LibTorch backend in Autodiff so gradients can be computed.
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}

// CPU training (multithreaded).
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        let device = LibTorchDevice::Cpu;
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}
```
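The example fixes the GPU device at compile time with `cfg` attributes. If the same binary should also run on machines without a usable GPU, the choice can be made at runtime instead. In this sketch, `Device` is a hypothetical mirror of the `LibTorchDevice` variants used above and `choose_device` is an illustrative helper; the `gpu_available` flag is supplied by the caller (on CUDA systems, tch-rs exposes `tch::Cuda::is_available()`, which could provide it; MPS detection is not shown).

```rust
// Hypothetical mirror of the device variants used in the usage example.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Same preference order as the `tch-gpu` example (MPS on macOS, otherwise
/// CUDA device 0), but falls back to the CPU when no GPU is reported.
pub fn choose_device(is_macos: bool, gpu_available: bool) -> Device {
    if !gpu_available {
        Device::Cpu
    } else if is_macos {
        Device::Mps
    } else {
        Device::Cuda(0)
    }
}

/// Convenience wrapper that takes the target OS from the compiler.
pub fn choose_device_for_this_target(gpu_available: bool) -> Device {
    choose_device(cfg!(target_os = "macos"), gpu_available)
}
```

Separating the OS flag from the availability flag keeps `choose_device` testable on any host.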

Platform Support

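| Option | CPU | GPU | Linux | macOS | Windows | Android | iOS | WASM |
| :----- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| CPU    | Yes | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| CUDA   | No  | Yes | Yes   | No    | Yes     | No      | No  | No   |
| MPS    | No  | Yes | No    | Yes   | No      | No      | No  | No   |
| Vulkan | Yes | Yes | Yes   | Yes   | Yes     | Yes     | No  | No   |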
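Tooling that needs to check this matrix programmatically (for example, to fail early when a build targets an unsupported combination) can encode the platform columns directly. The `Backend` and `Platform` enums and the `is_supported` function are hypothetical helpers, not burn-tch APIs; the truth values are transcribed from the Linux through WASM columns of the table above (the CPU/GPU columns describe hardware type and are not encoded).

```rust
// Illustrative encoding of the platform support table; not part of burn-tch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend { Cpu, Cuda, Mps, Vulkan }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Platform { Linux, MacOs, Windows, Android, Ios, Wasm }

/// Returns true if the table marks `backend` as available on `platform`.
pub fn is_supported(backend: Backend, platform: Platform) -> bool {
    use Backend::*;
    use Platform::*;
    match (backend, platform) {
        (Cpu, Wasm) => false,
        (Cpu, _) => true, // every other listed platform
        (Cuda, Linux | Windows) => true,
        (Mps, MacOs) => true,
        (Vulkan, Linux | MacOs | Windows | Android) => true,
        _ => false,
    }
}
```

Listing only the supported combinations and defaulting to `false` means a new platform is treated as unsupported until the table says otherwise.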