burn/crates/burn-tensor
AlteredOxide 0292967000
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented multi-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases has been added
    - test axes 0 and 1
    - both use 2D index tensors

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to not be confused with the current
  `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant to the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen inlined code
  after translating from public api (removed in previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural
- include a 1d index gather onnx test
- remove superfluous tests

* delete unused gather.onnx
2024-08-28 07:51:19 -04:00
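The ONNX `Gather` semantics that this change targets can be sketched in plain Rust. The helper below, `gather_onnx`, is hypothetical. It is not the code Burn generates, and it operates on row-major `f32` buffers with explicit shapes rather than on Burn tensors. It assumes a valid `axis` and in-range indices. As the ONNX spec describes, the output shape is `data.shape[..axis]`, followed by `indices.shape`, followed by `data.shape[axis + 1..]`. Negative indices count from the end of the gathered axis.

```rust
/// Hypothetical helper sketching ONNX `Gather` on row-major buffers.
/// Output shape: data_shape[..axis] ++ indices_shape ++ data_shape[axis + 1..].
fn gather_onnx(
    data: &[f32],
    data_shape: &[usize],
    indices: &[i64],
    indices_shape: &[usize],
    axis: usize,
) -> (Vec<f32>, Vec<usize>) {
    // A rank-0 (scalar) index has an empty shape, whose product is 1.
    debug_assert_eq!(indices.len(), indices_shape.iter().product::<usize>());

    let axis_len = data_shape[axis];
    // View `data` as [outer, axis_len, inner] in row-major order.
    let outer: usize = data_shape[..axis].iter().product();
    let inner: usize = data_shape[axis + 1..].iter().product();

    let mut out_shape = data_shape[..axis].to_vec();
    out_shape.extend_from_slice(indices_shape);
    out_shape.extend_from_slice(&data_shape[axis + 1..]);

    let mut out = Vec::with_capacity(outer * indices.len() * inner);
    for o in 0..outer {
        for &idx in indices {
            // Negative indices count back from the end of the gathered axis.
            let i = (if idx < 0 { idx + axis_len as i64 } else { idx }) as usize;
            // Copy the contiguous block data[o, i, ..].
            let start = (o * axis_len + i) * inner;
            out.extend_from_slice(&data[start..start + inner]);
        }
    }
    (out, out_shape)
}
```

You can gather rows of a 3×2 matrix with a 2×2 index tensor by calling `gather_onnx(&data, &[3, 2], &[0, 1, 1, 2], &[2, 2], 0)`, which yields shape `[2, 2, 2]`. If you pass a single index with `indices_shape` set to `&[]`, the same function covers the scalar-index case. There, the gathered axis drops out of the output shape.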
src Feature/codegen gather indices greater than rank 1 (#2199) 2024-08-28 07:51:19 -04:00
Cargo.toml Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
LICENSE-APACHE Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
LICENSE-MIT Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
README.md [refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00
env.bash [refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00

README.md

Burn Tensor

Burn Tensor Library


This library provides multiple tensor implementations behind an easy-to-use API that supports reverse-mode automatic differentiation.
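One way to put several tensor implementations behind a single API is to define a trait that each backend implements. The user-facing tensor type is then generic over that trait. Everything below (`Backend`, `VecBackend`, `SharedBackend`, `Tensor`, `double`) is a hypothetical and heavily simplified sketch. It is only loosely modeled on Burn's backend-generic `Tensor`. Burn's real `Backend` trait is much larger and is not reproduced here.

```rust
use std::rc::Rc;

/// Hypothetical minimal backend trait: each backend picks its own storage.
trait Backend {
    type Primitive;
    fn from_data(data: Vec<f32>) -> Self::Primitive;
    fn add(lhs: &Self::Primitive, rhs: &Self::Primitive) -> Self::Primitive;
    fn to_data(t: &Self::Primitive) -> Vec<f32>;
}

/// Backend 1: owned, contiguous `Vec` storage.
struct VecBackend;

impl Backend for VecBackend {
    type Primitive = Vec<f32>;
    fn from_data(data: Vec<f32>) -> Vec<f32> {
        data
    }
    fn add(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a + b).collect()
    }
    fn to_data(t: &Vec<f32>) -> Vec<f32> {
        t.clone()
    }
}

/// Backend 2: shared, reference-counted storage.
struct SharedBackend;

impl Backend for SharedBackend {
    type Primitive = Rc<[f32]>;
    fn from_data(data: Vec<f32>) -> Rc<[f32]> {
        Rc::from(data)
    }
    fn add(lhs: &Rc<[f32]>, rhs: &Rc<[f32]>) -> Rc<[f32]> {
        lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect()
    }
    fn to_data(t: &Rc<[f32]>) -> Vec<f32> {
        t.to_vec()
    }
}

/// User-facing tensor: the backend is a type parameter, not a runtime value.
struct Tensor<B: Backend> {
    primitive: B::Primitive,
}

impl<B: Backend> Tensor<B> {
    fn from_data(data: Vec<f32>) -> Self {
        Tensor { primitive: B::from_data(data) }
    }
    fn add(&self, other: &Self) -> Self {
        Tensor { primitive: B::add(&self.primitive, &other.primitive) }
    }
    fn to_data(&self) -> Vec<f32> {
        B::to_data(&self.primitive)
    }
}

/// Generic user code that works with any backend.
fn double<B: Backend>(x: &Tensor<B>) -> Tensor<B> {
    x.add(x)
}
```

In this sketch, code written against `Tensor<B>`, such as `double`, runs unchanged on either backend. You switch backends by changing a type, not by adding a runtime branch.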

Features

  • Flexible
  • CPU + GPU 🙏
  • Multi-Threads 🚀
  • Intuitive Usage 😌
  • No Global State 🚫
  • Multiple Backends 🦾
  • Reverse Mode Autodiff 🔥

Backends

For now, three backends are implemented, and more are planned.

Autodiff

Automatic differentiation is implemented as just another tensor backend, without any global state. This works because we record the order in which each operation has been executed, and we only create the tape when calculating the gradients. To do so, each operation creates a new node that holds a reference to its parent nodes. Creating the tape therefore only requires a simple and efficient graph traversal.

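    use burn_autodiff::Autodiff;
    use burn_ndarray::NdArray;
    use burn_tensor::Tensor;

    // Wrapping a backend in `Autodiff` enables gradient tracking.
    type B = Autodiff<NdArray>;

    let device = Default::default();
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device).require_grad();
    let y = Tensor::<B, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device).require_grad();

    let z = x.clone().matmul(y.clone());

    // The tape is built here, by traversing the graph back from `z`.
    let grads = z.backward();

    let x_grad = x.grad(&grads); // Option<Tensor<NdArray, 2>>
    let y_grad = y.grad(&grads);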
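The tape mechanism described above can be illustrated on scalars. The sketch below is a hypothetical toy, not Burn's autodiff backend, and all names in it (`Node`, `Var`, `leaf`, `backward`, `grad`) are illustrative. Each operation returns a new node that holds reference-counted links to its parents, together with the local partial derivatives. Only `backward` builds the tape. It uses a depth-first traversal to produce a topological order, then walks that order in reverse to accumulate gradients with the chain rule. No global state is involved, because the graph lives entirely in the `Rc` links between nodes.

```rust
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

/// Hypothetical scalar node: a value plus links to its parents, each paired
/// with the local partial derivative d(self)/d(parent).
struct Node {
    value: f64,
    parents: Vec<(Var, f64)>,
}

/// A cheap, clonable handle to a node in the graph.
#[derive(Clone)]
struct Var(Rc<Node>);

impl Var {
    fn leaf(value: f64) -> Var {
        Var(Rc::new(Node { value, parents: Vec::new() }))
    }

    /// Identity of the node, derived from its address (valid while it is alive).
    fn id(&self) -> usize {
        Rc::as_ptr(&self.0) as usize
    }

    fn value(&self) -> f64 {
        self.0.value
    }

    fn add(&self, other: &Var) -> Var {
        Var(Rc::new(Node {
            value: self.value() + other.value(),
            parents: vec![(self.clone(), 1.0), (other.clone(), 1.0)],
        }))
    }

    fn mul(&self, other: &Var) -> Var {
        Var(Rc::new(Node {
            value: self.value() * other.value(),
            parents: vec![(self.clone(), other.value()), (other.clone(), self.value())],
        }))
    }

    /// Builds the tape only now: a depth-first traversal yields a topological
    /// order, which is walked in reverse to accumulate gradients.
    fn backward(&self) -> HashMap<usize, f64> {
        fn visit(v: &Var, seen: &mut HashSet<usize>, order: &mut Vec<Var>) {
            if !seen.insert(v.id()) {
                return;
            }
            for (parent, _) in &v.0.parents {
                visit(parent, seen, order);
            }
            order.push(v.clone());
        }

        let mut order = Vec::new();
        visit(self, &mut HashSet::new(), &mut order);

        let mut grads = HashMap::new();
        grads.insert(self.id(), 1.0);
        for v in order.iter().rev() {
            let g = grads.get(&v.id()).copied().unwrap_or(0.0);
            for (parent, local) in &v.0.parents {
                *grads.entry(parent.id()).or_insert(0.0) += g * local;
            }
        }
        grads
    }

    fn grad(&self, grads: &HashMap<usize, f64>) -> Option<f64> {
        grads.get(&self.id()).copied()
    }
}
```

For `z = x * y + x` with `x = 3` and `y = 4`, `z.backward()` gives `dz/dx = y + 1 = 5` and `dz/dy = x = 3`. Because `x` feeds into `z` twice, its gradient is the sum of two contributions.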

CUDA

To run with CUDA, set the environment variable TORCH_CUDA_VERSION=cu121.

Notes

This crate can be used on its own, without the rest of the Burn stack. You can also enable only the backends you need, which keeps binaries smaller.

Feature Flags

This crate can be used without the standard library (#![no_std]), relying on alloc instead, by disabling the default std feature.

  • std - enables the standard library.
  • burn-tensor-testgen - enables test macros for generating tensor tests.