# Burn Torch Backend
This crate provides a Torch backend for Burn using the `tch-rs` crate, which offers a Rust interface to the PyTorch C++ API.

The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS devices (macOS).
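In code, switching between these targets is a one-line device change. Below is a minimal sketch, assuming a recent `burn-tch` release whose `LibTorchDevice` enum exposes these variants:

```rust
use burn_tch::LibTorchDevice;

fn main() {
    // Each variant maps to one of the supported targets.
    let _cpu = LibTorchDevice::Cpu; // multithreaded CPU
    let _cuda = LibTorchDevice::Cuda(0); // CUDA GPU, selected by index
    let _mps = LibTorchDevice::Mps; // Metal Performance Shaders (macOS)
}
```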
## Installation
`tch-rs` requires the C++ PyTorch library (LibTorch) to be available on your system. By default, the CPU distribution of LibTorch v2.2.0 is installed, as required by `tch-rs`.
### CUDA
To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment variable before the `tch-rs` dependency is retrieved with `cargo`.
```shell
export TORCH_CUDA_VERSION=cu121
```
On Windows:

```powershell
$Env:TORCH_CUDA_VERSION = "cu121"
```
For example, running the validation sample for the first time could be done with the following commands:

```shell
export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release
```
**Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA Toolkit installation is not required, since LibTorch ships with the appropriate CUDA runtimes. Having the latest driver version is recommended, but you can always take a look at the toolkit driver version table or the minimum required driver version (limited feature set; might not work with all operations).
Once your installation is complete, you should be able to build/run your project. You can also validate your installation by running the appropriate `cpu`, `cuda`, or `mps` sample as below.
```shell
cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release
```
**Note:** no MPS distribution is available for automatic download at this time; please check out the manual instructions below.
## Manual Download
To install `tch-rs` with a different LibTorch distribution, you will have to download the desired distribution manually. The instructions for each platform are detailed in the sections below.
| Compute Platform | CPU     | GPU | Linux | macOS | Windows | Android | iOS | WASM |
| ---------------- | ------- | --- | ----- | ----- | ------- | ------- | --- | ---- |
| CPU              | Yes     | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| CUDA             | Yes [1] | Yes | Yes   | No    | Yes     | No      | No  | No   |
| Metal (MPS)      | No      | Yes | No    | Yes   | No      | No      | No  | No   |
| Vulkan           | Yes     | Yes | Yes   | Yes   | Yes     | Yes     | No  | No   |
[1] The LibTorch CUDA distribution also comes with CPU support.
### CPU

#### 🐧 Linux
First, download the LibTorch CPU distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
#### 🍎 Mac
First, download the LibTorch CPU distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
```
#### 🪟 Windows
First, download the LibTorch CPU distribution.
```powershell
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```
Then, set the `LIBTORCH` environment variable and append the library to your path using the PowerShell commands below before building `burn-tch` or a crate which depends on it.
```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
### CUDA
LibTorch 2.2.0 currently includes binary distributions with CUDA 11.8 or 12.1 runtimes. The manual installation instructions are detailed below.
#### CUDA 11.8

##### 🐧 Linux
First, download the LibTorch CUDA 11.8 distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.
##### 🪟 Windows
First, download the LibTorch CUDA 11.8 distribution.
```powershell
wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```
Then, set the `LIBTORCH` environment variable and append the library to your path using the PowerShell commands below before building `burn-tch` or a crate which depends on it.
```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
#### CUDA 12.1

##### 🐧 Linux
First, download the LibTorch CUDA 12.1 distribution.
```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip
```
Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it.
```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.
##### 🪟 Windows
First, download the LibTorch CUDA 12.1 distribution.
```powershell
wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```
Then, set the `LIBTORCH` environment variable and append the library to your path using the PowerShell commands below before building `burn-tch` or a crate which depends on it.
```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
### Metal (MPS)
There is no official LibTorch distribution with MPS support at this time, so the easiest alternative is to use a PyTorch installation. This requires a Python installation.
**Note:** MPS acceleration is available on macOS 12.3+.
```shell
pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
```
## Example Usage
For a simple example, check out any of the test programs in `src/bin/`. Each program sets the device to use and performs a simple element-wise addition.
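The sketch below shows the general shape of such a program. It is illustrative rather than a copy of those files, and assumes the `burn` and `burn-tch` crates as dependencies with an API matching recent releases:

```rust
use burn::tensor::Tensor;
use burn_tch::{LibTorch, LibTorchDevice};

// The backend is generic over the float element type; f32 here.
type Backend = LibTorch<f32>;

fn main() {
    // Swap in LibTorchDevice::Cuda(0) or LibTorchDevice::Mps to target a GPU.
    let device = LibTorchDevice::Cpu;

    // Two rank-1 tensors created on the chosen device.
    let lhs = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);
    let rhs = Tensor::<Backend, 1>::from_floats([4.0, 5.0, 6.0], &device);

    // The element-wise addition runs on the selected device.
    println!("{}", lhs + rhs);
}
```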
For a more complete example using the `tch` backend, take a look at the Burn mnist example.