Burn Torch Backend

This crate provides a Torch backend for Burn utilizing the tch-rs crate, which offers a Rust interface to the PyTorch C++ API.

The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS (macOS) devices.
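
In code, the device is selected through the backend's device type. The snippet below is a minimal sketch assuming the LibTorchDevice enum exposed by burn-tch; variant names may differ slightly between Burn versions.

use burn_tch::LibTorchDevice;

fn main() {
    // Select the device the backend should run on.
    let _cpu = LibTorchDevice::Cpu;      // multithreaded CPU
    let _cuda = LibTorchDevice::Cuda(0); // CUDA GPU, selected by index
    let _mps = LibTorchDevice::Mps;      // Apple Metal (MPS) on macOS
}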

Installation

tch-rs requires the C++ PyTorch library (LibTorch) to be available on your system.

By default, the CPU distribution of LibTorch v2.2.0 (the version required by tch-rs) is installed.

CUDA

To install the latest compatible CUDA distribution, set the TORCH_CUDA_VERSION environment variable before the tch-rs dependency is retrieved with cargo.

export TORCH_CUDA_VERSION=cu121

On Windows:

$Env:TORCH_CUDA_VERSION = "cu121"

For example, you can run the CUDA validation sample for the first time with the following commands:

export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release

Important: make sure your driver version is compatible with the selected CUDA version. A CUDA Toolkit installation is not required, since LibTorch ships with the appropriate CUDA runtimes. Having the latest driver version is recommended, but you can always check the toolkit driver version table or the minimum required driver version (limited feature set; it might not work with all operations).


Once your installation is complete, you should be able to build/run your project. You can also validate your installation by running the appropriate cpu, cuda, or mps sample as shown below.

cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release

Note: no MPS distribution is available for automatic download at this time; please check out the manual instructions below.

Manual Download

To use a different LibTorch distribution with tch-rs, you will have to download the desired distribution manually. The instructions for each platform are detailed in the sections below.

Compute Platform   CPU       GPU   Linux   macOS   Windows   Android   iOS   WASM
CPU                Yes       No    Yes     Yes     Yes       Yes       Yes   No
CUDA               Yes [1]   Yes   Yes     No      Yes       No        No    No
Metal (MPS)        No        Yes   No      Yes     No        No        No    No
Vulkan             Yes       Yes   Yes     Yes     Yes       Yes       No    No

[1] The LibTorch CUDA distribution also comes with CPU support.

CPU

🐧 Linux

First, download the LibTorch CPU distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

🍎 Mac

First, download the LibTorch CPU distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and DYLD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH

🪟 Windows

First, download the LibTorch CPU distribution.

wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the library to your path using the PowerShell commands below, before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

CUDA

LibTorch 2.2.0 currently includes binary distributions with CUDA 11.8 or 12.1 runtimes. The manual installation instructions are detailed below.

CUDA 11.8

🐧 Linux

First, download the LibTorch CUDA 11.8 distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

Note: make sure your CUDA installation is in your PATH and LD_LIBRARY_PATH.


🪟 Windows

First, download the LibTorch CUDA 11.8 distribution.

wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the library to your path using the PowerShell commands below, before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

CUDA 12.1

🐧 Linux

First, download the LibTorch CUDA 12.1 distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

Note: make sure your CUDA installation is in your PATH and LD_LIBRARY_PATH.


🪟 Windows

First, download the LibTorch CUDA 12.1 distribution.

wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the library to your path using the PowerShell commands below, before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

Metal (MPS)

There is no official LibTorch distribution with MPS support at this time, so the easiest alternative is to use a PyTorch installation. This requires a Python installation.

Note: MPS acceleration is available on macOS 12.3+.

pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH

Example Usage

For a simple example, check out any of the test programs in src/bin/. Each program sets the device to use and performs a simple elementwise addition.
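
A minimal sketch along those lines is shown below, assuming the LibTorch backend and LibTorchDevice types from burn-tch and the Tensor API from the burn crate; exact constructor signatures may vary between Burn versions, so treat the programs in src/bin/ as the reference.

use burn::tensor::Tensor;
use burn_tch::{LibTorch, LibTorchDevice};

fn main() {
    // Pick a device; swap in Cuda(0) or Mps as appropriate for your setup.
    let device = LibTorchDevice::Cpu;
    // Create two 2x2 tensors on the chosen device (LibTorch defaults to f32).
    let lhs = Tensor::<LibTorch, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let rhs = Tensor::<LibTorch, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    // Simple elementwise addition, as in the src/bin/ samples.
    println!("{}", lhs + rhs);
}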

For a more complete example using the tch backend, take a look at the Burn mnist example.