burn/crates
AlteredOxide 0292967000
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented multi-dimensional index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases has been added
    - test axes 0 and 1
    - both use 2D index tensors
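In ONNX `Gather`, the output shape is `data.shape[..axis] ++ indices.shape ++ data.shape[axis + 1..]`, so a rank-`r` input gathered with a rank-`q` index tensor yields rank `r + q - 1`. Negative indices count back from the end of `axis`. The following is a minimal plain-Rust sketch of those semantics over row-major buffers. It is not the burn codegen: `Nd` and `onnx_gather` are hypothetical names used only for illustration.

```rust
/// Hypothetical row-major tensor stand-in; not burn's `Tensor` type.
#[derive(Debug, Clone, PartialEq)]
pub struct Nd<T> {
    pub shape: Vec<usize>,
    pub data: Vec<T>,
}

/// Sketch of ONNX `Gather` semantics:
/// out.shape = data.shape[..axis] ++ indices.shape ++ data.shape[axis + 1..]
/// so the output rank is r + q - 1 (r = data rank, q = indices rank).
pub fn onnx_gather<T: Copy>(data: &Nd<T>, indices: &Nd<i64>, axis: usize) -> Nd<T> {
    assert!(axis < data.shape.len(), "axis out of range");
    // Element counts of the dimensions before, along, and after `axis`.
    let outer: usize = data.shape[..axis].iter().product();
    let axis_len = data.shape[axis];
    let inner: usize = data.shape[axis + 1..].iter().product();

    // Output shape: leading data dims, then the full index shape, then trailing data dims.
    let shape: Vec<usize> = data.shape[..axis]
        .iter()
        .chain(indices.shape.iter())
        .chain(data.shape[axis + 1..].iter())
        .copied()
        .collect();

    let mut out = Vec::with_capacity(outer * indices.data.len() * inner);
    for o in 0..outer {
        for &idx in &indices.data {
            // ONNX allows negative indices, counted back from the end of `axis`.
            let k = if idx < 0 { idx + axis_len as i64 } else { idx };
            assert!(k >= 0 && (k as usize) < axis_len, "index {idx} out of bounds");
            // Copy the contiguous `inner`-sized slab selected by (o, k).
            let start = (o * axis_len + k as usize) * inner;
            out.extend_from_slice(&data.data[start..start + inner]);
        }
    }
    Nd { shape, data: out }
}
```

For example, gathering rows `[[0, 1], [1, 2]]` from a `3x2` tensor along axis 0 gives a `2x2x2` result. A rank-0 index (shape `[]`) drops the gathered axis entirely, which is the scalar-index case.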

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to not be confused with the current
  `gather`
- these implementations follow the `Gather` ONNX spec
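The naming concern arises because, as far as its documentation describes, burn's existing `gather` uses element-wise semantics. Under those semantics the index tensor has the same rank and shape as the output, as in ONNX `GatherElements` or PyTorch's `torch.gather`. ONNX `Gather` instead inserts the whole index shape into the output. The following plain-Rust sketch shows the element-wise variant on a 2-D tensor for contrast. `gather_elements_2d` is a hypothetical name and not burn code.

```rust
/// Hypothetical sketch of element-wise gather (ONNX `GatherElements` / `torch.gather` style)
/// on a 2-D tensor stored as rows. The output has the same shape as `indices`:
///   dim 0: out[i][j] = data[indices[i][j]][j]
///   dim 1: out[i][j] = data[i][indices[i][j]]
pub fn gather_elements_2d<T: Copy>(
    data: &[Vec<T>],
    indices: &[Vec<usize>],
    dim: usize,
) -> Vec<Vec<T>> {
    assert!(dim < 2, "dim must be 0 or 1 for a 2-D tensor");
    indices
        .iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                // Each index replaces exactly one coordinate; the other is kept as-is.
                .map(|(j, &k)| if dim == 0 { data[k][j] } else { data[i][k] })
                .collect::<Vec<T>>()
        })
        .collect()
}
```

Take the 2x2 index `[[0, 0], [1, 0]]` on `[[1, 2], [3, 4]]` with dim 1. This element-wise form returns the 2x2 result `[[1, 1], [4, 3]]`. ONNX `Gather` on axis 1 with the same indices would return a 2x2x2 tensor.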

Updated the gather*.py variants and their onnx outputs

* add modified files missing from the last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant to the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen inlined code
  after translating from public api (removed in previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combine match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural loops
- include a 1d index gather onnx test
- remove superfluous tests

* delete unused gather.onnx
2024-08-28 07:51:19 -04:00
burn Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-autodiff Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-candle Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-common Make compatible with thumbv6m-none-eabi + add raspberry pi pico example (#2096) 2024-08-23 07:39:39 -04:00
burn-core Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-cuda Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-dataset Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-derive Fix module derive with generics (#2127) 2024-08-08 16:24:51 -04:00
burn-fusion Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-import Feature/codegen gather indices greater than rank 1 (#2199) 2024-08-28 07:51:19 -04:00
burn-jit Select kernel from CPA to CubeCL (#2168) 2024-08-27 15:17:58 -04:00
burn-ndarray Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-no-std-tests Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-tch Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-tensor Feature/codegen gather indices greater than rank 1 (#2199) 2024-08-28 07:51:19 -04:00
burn-tensor-testgen Fix Cargo.toml repository links (#1749) 2024-05-09 15:40:05 -04:00
burn-train Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
burn-wgpu Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
onnx-ir Feature/codegen gather indices greater than rank 1 (#2199) 2024-08-28 07:51:19 -04:00