# Burn Tensor

> [Burn](https://github.com/burn-rs/burn) Tensor Library

[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor.svg)](https://crates.io/crates/burn-tensor)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-tensor/blob/master/README.md)

This library provides multiple tensor implementations hidden behind an easy-to-use API that supports reverse-mode automatic differentiation.

## Features

* Flexible ✨
* CPU + GPU 🙏
* Multi-Threads 🚀
* Intuitive Usage 😌
* No Global State 🚫
* Multiple Backends 🦾
* Reverse Mode Autodiff 🔥

### Backends

For now, three backends are implemented, but adding new ones should not be that hard.

* [X] PyTorch using [tch-rs](https://github.com/LaurentMazare/tch-rs)
* [X] 100% Rust backend using [ndarray](https://github.com/rust-ndarray/ndarray)
* [X] [WGPU](https://github.com/gfx-rs/wgpu) backend
* [ ] TensorFlow using [tensorflow-rust](https://github.com/tensorflow/rust)
* [ ] cuDNN using [Rust-CUDA](https://github.com/Rust-GPU/Rust-CUDA)
* [ ] ...
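To illustrate what "hidden behind an API" can look like, here is a minimal sketch of the general pattern: generic code is written against a trait, and each backend supplies its own tensor type. The names `ToyBackend`, `VecBackend`, and `double` are hypothetical and exist only for this example; the crate's real backend trait is considerably richer.

```rust
/// Hypothetical, heavily simplified backend trait (NOT this crate's real trait).
trait ToyBackend {
    /// Backend-specific tensor storage.
    type Tensor;

    fn from_vec(data: Vec<f32>) -> Self::Tensor;
    fn add(lhs: &Self::Tensor, rhs: &Self::Tensor) -> Self::Tensor;
    fn to_vec(tensor: &Self::Tensor) -> Vec<f32>;
}

/// Toy CPU backend that stores a tensor as a flat `Vec<f32>`.
struct VecBackend;

impl ToyBackend for VecBackend {
    type Tensor = Vec<f32>;

    fn from_vec(data: Vec<f32>) -> Self::Tensor {
        data
    }

    // Element-wise addition of two equally sized tensors.
    fn add(lhs: &Self::Tensor, rhs: &Self::Tensor) -> Self::Tensor {
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }

    fn to_vec(tensor: &Self::Tensor) -> Vec<f32> {
        tensor.clone()
    }
}

/// Generic code only names the trait, so any backend can be plugged in.
fn double<B: ToyBackend>(data: Vec<f32>) -> Vec<f32> {
    let t = B::from_vec(data);
    B::to_vec(&B::add(&t, &t))
}
```

Under this sketch, adding a backend means implementing the trait for one more type; callers such as `double::<VecBackend>(...)` only change the type parameter.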

### Autodiff

Automatic differentiation is implemented as just another tensor backend without any global state.
This is possible because we keep track of the order in which each operation has been executed, and the tape is only created when calculating the gradients.
To do so, each operation creates a new node which has a reference to its parent nodes.
Therefore, creating the tape only requires a simple and efficient graph traversal algorithm.

```rust
// `x_ndarray` and `y_ndarray` are assumed to be tensors created beforehand with a regular backend.
let x = ADTensor::from_tensor(x_ndarray);
let y = ADTensor::from_tensor(y_ndarray);

// Each operation records a new node that references its parents.
let z = x.matmul(&y);

// The tape is only built here, by traversing the graph from `z`.
let grads = z.backward();

// Look up the gradient of each input in the returned gradients.
let x_grad = x.grad(&grads);
let y_grad = y.grad(&grads);
```
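The node-and-tape idea described above can be sketched on scalars with only the standard library. This is a toy illustration, not the crate's implementation: `Var`, `Node`, `Op`, `Grads`, and `example` are hypothetical names, and nodes are keyed by their pointer address as one simple way to avoid a global counter.

```rust
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

/// The operation that produced a node, holding references to its parents.
enum Op {
    Leaf,
    Add(Rc<Node>, Rc<Node>),
    Mul(Rc<Node>, Rc<Node>),
}

struct Node {
    value: f64,
    op: Op,
}

/// Hypothetical scalar stand-in for an autodiff tensor.
struct Var(Rc<Node>);

/// Gradients keyed by node address (valid while the nodes are alive).
struct Grads(HashMap<usize, f64>);

fn key(node: &Rc<Node>) -> usize {
    Rc::as_ptr(node) as usize
}

/// Post-order traversal: parents are pushed before the node that uses them.
fn visit(node: &Rc<Node>, visited: &mut HashSet<usize>, tape: &mut Vec<Rc<Node>>) {
    if !visited.insert(key(node)) {
        return; // already on the tape
    }
    match &node.op {
        Op::Leaf => {}
        Op::Add(a, b) | Op::Mul(a, b) => {
            visit(a, visited, tape);
            visit(b, visited, tape);
        }
    }
    tape.push(Rc::clone(node));
}

impl Var {
    fn leaf(value: f64) -> Var {
        Var(Rc::new(Node { value, op: Op::Leaf }))
    }

    fn add(&self, other: &Var) -> Var {
        let value = self.0.value + other.0.value;
        Var(Rc::new(Node { value, op: Op::Add(self.0.clone(), other.0.clone()) }))
    }

    fn mul(&self, other: &Var) -> Var {
        let value = self.0.value * other.0.value;
        Var(Rc::new(Node { value, op: Op::Mul(self.0.clone(), other.0.clone()) }))
    }

    /// Builds the tape by graph traversal, then walks it in reverse.
    fn backward(&self) -> Grads {
        let mut tape = Vec::new();
        visit(&self.0, &mut HashSet::new(), &mut tape);

        let mut grads: HashMap<usize, f64> = HashMap::new();
        grads.insert(key(&self.0), 1.0); // seed: d(self)/d(self) = 1
        for node in tape.iter().rev() {
            let g = grads.get(&key(node)).copied().unwrap_or(0.0);
            match &node.op {
                Op::Leaf => {}
                Op::Add(a, b) => {
                    *grads.entry(key(a)).or_insert(0.0) += g;
                    *grads.entry(key(b)).or_insert(0.0) += g;
                }
                Op::Mul(a, b) => {
                    *grads.entry(key(a)).or_insert(0.0) += g * b.value;
                    *grads.entry(key(b)).or_insert(0.0) += g * a.value;
                }
            }
        }
        Grads(grads)
    }

    fn grad(&self, grads: &Grads) -> Option<f64> {
        grads.0.get(&key(&self.0)).copied()
    }
}

/// z = x * y + x with x = 2 and y = 3, so dz/dx = y + 1 and dz/dy = x.
fn example() -> (f64, f64) {
    let x = Var::leaf(2.0);
    let y = Var::leaf(3.0);
    let z = x.mul(&y).add(&x);
    let grads = z.backward();
    (x.grad(&grads).unwrap(), y.grad(&grads).unwrap())
}
```

Here `example()` evaluates to `(4.0, 2.0)`. Nothing is recorded globally while the graph is built: the tape exists only inside `backward`, and reversing the post-order guarantees that every node is processed after all the nodes that consume it.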

## CUDA

To run with CUDA, set the `TORCH_CUDA_VERSION` environment variable to `cu113`.

## Notes

This crate can be used on its own, without the entire Burn stack, and with only selected backends for smaller binaries.

## Feature Flags

This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling
the default `std` feature.

* `std` - enables the standard library.
* `burn-tensor-testgen` - enables test macros for generating tensor tests.