mirror of https://github.com/tracel-ai/burn.git
0292967000
* implemented multi-dim index for GatherNode

  The `NodeCodegen` impl for `GatherNode` now performs gather in complete accordance with the ONNX Gather spec.
  - a `gather` function was added to the gather.rs file
  - `gather()` is now called within the codegen instead of `tensor.select()`
  - a test with two test cases has been added
  - test axes 0 and 1
  - both use 2D index tensors

* add gather_onnx to numeric api

  Added int and float implementations of gather to the burn-tensor numeric api:
  - named the methods `gather_onnx` so they are not confused with the current `gather`
  - these implementations follow the `Gather` ONNX spec

  Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

  The implementation of gather for the ONNX `Gather` spec is tentatively complete:
  - py test models are updated
  - onnx_tests are modified and passing: `gather`, `gather_scalar`, and `gather_shape`
  - node/gather tests are passing

  NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test the actual functionality of gather are likely to be deleted, since they are redundant with the tests in crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

  - Some necessary changes and improvements to the codegen inlined code after translating from public api (removed in previous commit).
  - Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

  Modified the `Gather` codegen per requested changes:
  - combined match statements on index
  - remove use of `alloc::vec::Vec`
  - use map -> collect instead of procedural
  - include a 1d index gather onnx test
  - remove superfluous tests

* delete unused gather.onnx
- burn
- burn-autodiff
- burn-candle
- burn-common
- burn-core
- burn-cuda
- burn-dataset
- burn-derive
- burn-fusion
- burn-import
- burn-jit
- burn-ndarray
- burn-no-std-tests
- burn-tch
- burn-tensor
- burn-tensor-testgen
- burn-train
- burn-wgpu
- onnx-ir