
Burn Torch Backend

This crate provides a Torch backend for Burn built on the tch-rs crate, which offers a Rust interface to the PyTorch C++ API (LibTorch).

The backend supports three kinds of devices: CPU (multithreaded), CUDA (including machines with multiple GPUs), and MPS (Apple's Metal Performance Shaders, on macOS).

Usage Example

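```rust
// Train the MNIST example on the GPU: CUDA on Linux/Windows, MPS on macOS.
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn_autodiff::ADBackendDecorator;
    use burn_tch::{TchBackend, TchDevice};
    // `training` comes from the MNIST example crate in the Burn repository.
    use mnist::training;

    pub fn run() {
        // First CUDA GPU (index 0) on every platform except macOS...
        #[cfg(not(target_os = "macos"))]
        let device = TchDevice::Cuda(0);
        // ...where Apple's Metal Performance Shaders device is used instead.
        #[cfg(target_os = "macos")]
        let device = TchDevice::Mps;

        // ADBackendDecorator adds automatic differentiation on top of the
        // Torch backend, which training requires.
        training::run::<ADBackendDecorator<TchBackend<f32>>>(device);
    }
}

// Train the same model on the CPU.
#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn_autodiff::ADBackendDecorator;
    use burn_tch::{TchBackend, TchDevice};
    use mnist::training;

    pub fn run() {
        let device = TchDevice::Cpu;
        training::run::<ADBackendDecorator<TchBackend<f32>>>(device);
    }
}

fn main() {
    // Run whichever variants were enabled through Cargo features.
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
}
```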
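The example above always trains on CUDA device 0. On a machine with several GPUs, each one is addressed by its index in `TchDevice::Cuda(index)`. The sketch below shows one possible way to choose a device at runtime. To stay self-contained and compile without libtorch, it uses a hypothetical local `Device` enum standing in for `burn_tch::TchDevice`; the `select_device` helper and its parameters are likewise illustrative assumptions, not part of burn-tch.

```rust
/// Hypothetical stand-in mirroring the `burn_tch::TchDevice` variants used above,
/// so this sketch compiles without libtorch.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
    Mps,
}

/// Hypothetical helper: map an optional GPU index (e.g. from a CLI flag) to a device.
/// On macOS any GPU request goes to MPS, mirroring the README's choice above.
fn select_device(gpu: Option<usize>, is_macos: bool) -> Device {
    match gpu {
        Some(_) if is_macos => Device::Mps,
        Some(index) => Device::Cuda(index),
        None => Device::Cpu,
    }
}

fn main() {
    // Address each GPU of a two-GPU machine by its index.
    let gpus: Vec<Device> = (0..2).map(|i| select_device(Some(i), false)).collect();
    println!("{:?}", gpus); // [Cuda(0), Cuda(1)]

    // The same request on macOS, then the CPU fallback.
    println!("{:?}", select_device(Some(0), true)); // Mps
    println!("{:?}", select_device(None, cfg!(target_os = "macos"))); // Cpu
}
```

In real code, the selected variant would take the place of the hard-coded `TchDevice::Cuda(0)` passed to `training::run`.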