burn/burn-tch
Nathaniel Simard ab1b5890f5
Chore/release (#1031)
2023-12-01 14:33:28 -05:00

Burn Torch Backend


This crate provides a Torch backend for Burn using the tch-rs crate, which offers Rust bindings to the PyTorch C++ API (LibTorch).

The backend supports CPU (multithreaded), CUDA (including multiple GPUs) and MPS (macOS) devices.
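To make the device list concrete, here is a minimal sketch that enumerates the devices a job could target on a given machine. It assumes the OS name and the number of visible CUDA GPUs are already known. `Device` is a hypothetical stand-in that mirrors the `Cpu`, `Cuda(index)` and `Mps` variants of `LibTorchDevice`, so the snippet compiles without LibTorch installed; `candidate_devices` is likewise illustrative and not part of `burn_tch`.

```rust
/// Hypothetical stand-in mirroring the `LibTorchDevice` variants used in this
/// README (`Cpu`, `Cuda(index)`, `Mps`); this is NOT the real `burn_tch` type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Hypothetical helper: list the devices a training job could target.
/// `target_os` uses Rust's OS names ("linux", "macos", "windows", ...) and
/// `cuda_gpus` is the number of visible CUDA GPUs, assumed to be detected elsewhere.
pub fn candidate_devices(target_os: &str, cuda_gpus: usize) -> Vec<Device> {
    // The CPU device is always listed (multithreaded, per this README).
    let mut devices = vec![Device::Cpu];
    match target_os {
        // MPS is only offered on macOS; CUDA is not supported there.
        "macos" => devices.push(Device::Mps),
        // CUDA: one device per GPU, addressed by its index.
        "linux" | "windows" => devices.extend((0..cuda_gpus).map(Device::Cuda)),
        // Other targets fall back to the CPU only.
        _ => {}
    }
    devices
}

fn main() {
    // Use the OS this program was compiled for and assume two GPUs for illustration.
    let devices = candidate_devices(std::env::consts::OS, 2);
    println!("{devices:?}");
}
```

With the real crate, you would pass one of these devices (for example `LibTorchDevice::Cuda(1)` for the second GPU) to your training code instead of the stand-in.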

Usage Example

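```rust
// Assumes the `mnist` example crate from the Burn repository, whose
// `training::run` function trains a model on the given backend and device.
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        // First CUDA GPU on Linux/Windows; Metal Performance Shaders (MPS) on macOS.
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        // Wrap the LibTorch backend in `Autodiff` so gradients are tracked during training.
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn_autodiff::Autodiff;
    use burn_tch::{LibTorch, LibTorchDevice};
    use mnist::training;

    pub fn run() {
        // Multithreaded CPU execution through LibTorch.
        let device = LibTorchDevice::Cpu;
        training::run::<Autodiff<LibTorch<f32>>>(device);
    }
}

// Run whichever backend the enabled Cargo feature selects.
fn main() {
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
}
```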

Platform Support

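| Option | CPU | GPU | Linux | macOS | Windows | Android | iOS | WASM |
| :----- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| CPU    | Yes | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| CUDA   | No  | Yes | Yes   | No    | Yes     | No      | No  | No   |
| MPS    | No  | Yes | No    | Yes   | No      | No      | No  | No   |
| Vulkan | Yes | Yes | Yes   | Yes   | Yes     | Yes     | No  | No   |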
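If an application needs to choose a device option at build or start time, the platform support table can be encoded as data. The sketch below does this with hypothetical `DeviceOption` and `Platform` enums and helper functions (none of these are `burn_tch` types); every yes/no value is copied directly from the table.

```rust
/// Device options from the platform support table (hypothetical, not `burn_tch` types).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceOption {
    Cpu,
    Cuda,
    Mps,
    Vulkan,
}

/// Platforms listed in the table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Platform {
    Linux,
    MacOs,
    Windows,
    Android,
    Ios,
    Wasm,
}

/// Returns `(runs_on_cpu, runs_on_gpu)` from the table's CPU and GPU columns.
pub fn hardware(option: DeviceOption) -> (bool, bool) {
    match option {
        DeviceOption::Cpu => (true, false),
        DeviceOption::Cuda | DeviceOption::Mps => (false, true),
        DeviceOption::Vulkan => (true, true),
    }
}

/// Returns whether `option` is supported on `platform`, per the table.
pub fn is_supported(option: DeviceOption, platform: Platform) -> bool {
    use Platform::*;
    match option {
        // CPU: every platform except WASM.
        DeviceOption::Cpu => !matches!(platform, Wasm),
        // CUDA: Linux and Windows only.
        DeviceOption::Cuda => matches!(platform, Linux | Windows),
        // MPS: macOS only.
        DeviceOption::Mps => matches!(platform, MacOs),
        // Vulkan: desktop platforms and Android.
        DeviceOption::Vulkan => matches!(platform, Linux | MacOs | Windows | Android),
    }
}

fn main() {
    let options = [DeviceOption::Cpu, DeviceOption::Cuda, DeviceOption::Mps, DeviceOption::Vulkan];
    let on_macos: Vec<DeviceOption> = options
        .iter()
        .copied()
        .filter(|&o| is_supported(o, Platform::MacOs))
        .collect();
    println!("{on_macos:?}"); // prints "[Cpu, Mps, Vulkan]"
}
```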