Commit Graph

237 Commits

Author SHA1 Message Date
Louis Fortier-Dubois 5edaeabcee
Feat/cube/struct support (#1842)
* struct support (receive, use and modify fields)

* support struct with generics

* expect instead of unwrap

* fmt

* rename struc

* fmt

* Clippy

* Fix launcher

* Support creating private cube type without generics

* Cleanup

* generics support

* clippy

* minor

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-06-03 12:19:05 -04:00
Mathias Insley 92b0067693
Feat/gather import (#1843)
* Move and redirect GatherElements to new folders/nodes

* Create PyTorch script for gather

* Add onnx file for gather

* Add a gather test to onnx_tests

* Update gather.rs to use select

* Rename codegen test

* Update gather and gather_elements conversion functions

* Validate rank of input node and update output

* Add check for Gather
2024-06-03 08:28:32 -04:00
Jonas Kantic fba1e27e0c
Remainder operator (#1726)
* Adds remainder ops implementation for Tensor.

* Adds test for % operator.
2024-06-01 16:47:02 -05:00
jachym.putta 99e1ba4864
feat: expand onnx import (#1813)
* feat: added expand to import
2024-05-31 16:48:02 -05:00
jachym.putta 44f1053219
feat: added range onnx import (#1834)
* feat: added range onnx import

* fix: range input types
2024-05-31 16:40:54 -05:00
Nathaniel Simard 36d4bcd705
[Refactor - Breaking] Refactor cube operations with better names & Support subgroup operations (#1839) 2024-05-31 17:07:21 -04:00
will-maclean 13a6f84bc3
Feature/onnx argmax (#1814)
* pre-test

* implementing argmax for burn-import from onnx

* tidying

* fixing return types and tests

* addressing feedback

* only warn when select_last_index!=0
2024-05-31 14:46:09 -04:00
Louis Fortier-Dubois de0b49e4a3
Cube: Topology constants (#1838)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-30 12:03:30 -04:00
Louis Fortier-Dubois 61c9fdbbc8
Cube: cleaner use of topology values (#1835)
* constant keyword parsing

* works
2024-05-29 09:08:10 -04:00
McArthur a2ad424fc8
Indices Operator (#1735) 2024-05-29 09:05:31 -04:00
Louis Fortier-Dubois cacc764205
Cube: support for shared memory (#1831) 2024-05-29 08:22:04 -04:00
Guillaume Lagrange e4836241e1
Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
Louis Fortier-Dubois e61b026918
Cube: support method call + prettier tensor metadata (#1829) 2024-05-27 15:18:17 -04:00
Nathaniel Simard fd54a8b470
Add vectorization support into cube (#1830) 2024-05-27 14:21:29 -04:00
Louis Fortier-Dubois dc85daa1c6
Cube: support for return + conv2d early return (#1828) 2024-05-27 13:19:00 -04:00
Nathaniel Simard 15d2055de8
Feat/cube/launch (#1827) 2024-05-27 12:15:06 -04:00
Adrian Müller cccd96de48
Feat: Implement ONNX RandomUniform + RandomNormal in burn-import (#1806) 2024-05-27 10:07:04 -04:00
github-actions[bot] 85ba167582
Combined PRs (#1823)
* Bump cudarc from 0.10.0 to 0.11.0

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump libc from 0.2.154 to 0.2.155

Bumps [libc](https://github.com/rust-lang/libc) from 0.2.154 to 0.2.155.
- [Release notes](https://github.com/rust-lang/libc/releases)
- [Commits](https://github.com/rust-lang/libc/compare/0.2.154...0.2.155)

---
updated-dependencies:
- dependency-name: libc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump ratatui from 0.26.2 to 0.26.3

Bumps [ratatui](https://github.com/ratatui-org/ratatui) from 0.26.2 to 0.26.3.
- [Release notes](https://github.com/ratatui-org/ratatui/releases)
- [Changelog](https://github.com/ratatui-org/ratatui/blob/main/CHANGELOG.md)
- [Commits](https://github.com/ratatui-org/ratatui/compare/v0.26.2...v0.26.3)

---
updated-dependencies:
- dependency-name: ratatui
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump bytemuck from 1.15.0 to 1.16.0

Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.15.0 to 1.16.0.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.15.0...v1.16.0)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump proc-macro2 from 1.0.83 to 1.0.84

Bumps [proc-macro2](https://github.com/dtolnay/proc-macro2) from 1.0.83 to 1.0.84.
- [Release notes](https://github.com/dtolnay/proc-macro2/releases)
- [Commits](https://github.com/dtolnay/proc-macro2/compare/1.0.83...1.0.84)

---
updated-dependencies:
- dependency-name: proc-macro2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-05-27 10:04:27 -04:00
Ikko Eltociear Ashimine 1c5e65ab26
docs: update README.md (#1810)
minor fix
2024-05-27 09:21:03 -04:00
Nathaniel Simard c7ad25ab60
Update cuda-jit (#1799) 2024-05-24 11:31:47 -04:00
Louis Fortier-Dubois 23c622a9f8
Feat/cube/remaining ops (#1807) 2024-05-24 09:48:34 -04:00
Justin Restivo 1670a71711
Fix burn-jit compile error (#1803) 2024-05-23 20:24:42 -04:00
jachym.putta ef4646c90f
feat: Greater + GreaterOrEqual onnx import (#1801) 2024-05-23 08:59:15 -04:00
jachym.putta 1f31e20ce8
feat: Less + LessOrEqual onnx import (#1800) 2024-05-23 08:04:44 -04:00
Louis Fortier-Dubois e39b4d2da0
refactor reduce into separate traits (#1798) 2024-05-22 16:01:27 -04:00
Louis Fortier-Dubois 033171920c
Cube: first ported kernel + comptime support + variable reuse + cleanup (#1797) 2024-05-22 14:08:21 -04:00
Guillaume Lagrange b466fd7606
Add seq start position when applying RoPE encoding (#1796) 2024-05-22 13:18:31 -04:00
jachym.putta 0918cf00c6
feat: added min onnx import (#1778) 2024-05-22 10:52:19 -04:00
Guillaume Lagrange 550086a5c1
Fix record nested value de/serialization (#1751) 2024-05-22 09:15:32 -04:00
Louis Fortier-Dubois 6137d42c10
fix prng bug during autotune (#1791) 2024-05-22 09:11:13 -04:00
jachym.putta 8c01444fc5
Adding max import (#1769)
* feat: add max import

* feat: implement the right max operation (hopefully)
2024-05-22 08:31:55 -04:00
Mathias Insley 81ecd14f83
Feat/squeeze dims (#1779) 2024-05-22 07:53:51 -04:00
Louis Fortier-Dubois 76fe0ed881
Refactor/cube/vectorization (#1781) 2024-05-19 13:20:55 -04:00
Louis Fortier-Dubois 499ff0dd26
Feat/enable cube cl (#1777)
* Ben WIP

* Compile burn-jit

* WGPU works

* Remove old code

* move language cube stuff

* cleaning up

* some import reworking

* remove cube reexport

* template feature flag in cube

* ci

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-19 10:55:04 -04:00
Mathias Insley 9c5b07c833
Squeeze Onnx Import (#1753) 2024-05-17 12:00:34 -04:00
Jonathan Richard 8de05e1419
Add configurable application logger to learner builder (#1774)
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct

* Remove unused log module and renames, fmt

* Renamed tracing subscriber logger

* renamed to application logger installer

* book learner configuration update update

* fix typo

* unused import
2024-05-16 16:25:33 -04:00
Nathaniel Simard 7ab2ba1809
Feat/cubecl ir (#1776)
---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-05-16 15:08:53 -04:00
Louis Fortier-Dubois 542790e17e
CubeCL first iteration (#1756) 2024-05-15 10:24:37 -04:00
getumen e823338750
Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers (#1770) 2024-05-15 09:18:09 -04:00
Ben Barber d3cd6c4928
Replace opaque return types in optim (#1767)
* update ARCHITECTURE.md links to project architecture section in contributor book

* replace opaque return type in optim
2024-05-13 22:21:20 -04:00
Nathaniel Simard 9dcec0b998
Refactor/jit fusion (#1750)
* Reads & Writes with index_ref

* WIP

* Fix operations

* Cleanup
2024-05-13 12:48:23 -04:00
Ahmed Yarub Hani Al Nuaimi 10737527d8
#1747 Upgrade Rust dependencies (#1748)
* #1747
Upgrade Rust dependencies

* Revert upgrade for tch

The update of tch on windows gives an error:

INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.

* Keep only .cargo/config.toml file which works with rust > 1.75

---------

Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers b09d8431df
Fix Cargo.toml repository links (#1749)
* Fix wgpu github link

* Fix burn-train repo link

* Fix burn-tensor github repo

* Fix burn-tensor repo link

* Fix remaining repo links in crates Cargo.toml

---------

Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Arjun31415 5bbc5ea944
Added ONNX AvgPool1d (#1744) 2024-05-07 16:10:18 -05:00
Nathaniel Simard a6e3b4e81e
Fix select assign backward (#1739) 2024-05-07 11:37:43 -04:00
Sébastien Boisvert bd06b38fac
Refactor: replace trait TemplateKernel by existing trait JitKernel (#1737)
* Refactor: replace trait TemplateKernel by existing trait JitKernel

* Refactor: implement trait JitKernel for struct Kernel
2024-05-06 20:59:00 -04:00
Arjun31415 7f94f4c219
Add MaxPool1d ONNX Op(#1725) 2024-05-06 10:51:00 -05:00
Anton Blomström fb13503fa9
Add reduce sum onnx ops to burn imports (#1723) 2024-05-06 10:49:17 -05:00
Anton Blomström f8994e044c
Fix unstable tests when run concurrently (#1724) 2024-05-05 15:27:42 -05:00
Arjun31415 152509c378
PReLu ONNX import (#1721)
* added prelu onnx operator

* bug fix

* added onnx tests and burn codegen tests

* fix tests

* added prelu to supported onnx ops and add prelu to dim_inference
2024-05-04 13:45:42 -05:00
Louis Fortier-Dubois a8661a2f53
Autodiff Memory Management: BFS (#1710) 2024-05-03 09:45:21 -04:00
Nathaniel Simard 5d959e2884
[Fusion] Support multi-precision fusion (#1718) 2024-05-02 18:22:56 -04:00
Louis Fortier-Dubois 2e4c82fa64
Fix repeat for dims > 1 (#1713) 2024-05-01 09:11:38 -04:00
Dilshod Tadjibaev 3a02a54e55
Update SUPPORTED-ONNX-OPS.md (#1717)
gather ONNX was checked off but actually GatherElements should have been updated.
2024-05-01 08:02:59 -04:00
Dilshod Tadjibaev ff9e875321
ONNX debug improvements (#1712)
* Minor debug improvements

* Change warn to panic

* Log improvements
2024-04-30 16:36:55 -05:00
Nathaniel Simard 587b8f80b3
First draft CUDA runtime (#1685)
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
Jonathan Merritt ab501431b1
Handle ndarray matmul broadcasting (#1679)
* Handle ndarray matmul broadcasting

- Use strides to map linear batch indices from
  the output back to the input arrays.

* Fix typos
2024-04-29 17:25:27 -05:00
Dilshod Tadjibaev 1cdceb590f
Skip updating shape for linear if not present (#1700) 2024-04-29 14:53:18 -05:00
WU Chen b387829731
Implement bidirectional LSTM (#1035)
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer when init

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Louis Fortier-Dubois 6ae3926006
New autodiff graph memory management strategy (#1698)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-26 12:25:53 -04:00
Nathaniel Simard 2f294c5092
Fix lstm batch size bug (#1695) 2024-04-26 08:54:12 -04:00
Guillaume Lagrange ce2429eb10
Refactor element type to be decoupled from runtime (#1693) 2024-04-26 08:53:55 -04:00
Dilshod Tadjibaev 67ec06d5d8
ONNX support for scalar unsqueeze (#1690)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
2024-04-25 16:05:28 -05:00
Nathaniel Simard 599a20d586
Upgrade wgpu (#1692) 2024-04-25 16:32:50 -04:00
Dilshod Tadjibaev a1bd14c5ae
Reshape bug fix (#1684)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze
2024-04-24 19:31:53 -05:00
Nathaniel Simard 886a1de235
Refactor/burn compute (#1580) 2024-04-23 13:05:15 -04:00
Sylvain Benner c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandlerContainer and Tensor Ops description to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now added modules to burn-tensor are excluded from no-std as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on dosctring

* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Guillaume Lagrange e6b1b7a317
Add layer norm onnx op support (#1680) 2024-04-23 11:19:07 -04:00
Dilshod Tadjibaev 1718da5210
Fix reshape bug (support for opset version 1) (#1667)
* Make reshape op version 1

* Refactor per PR feedback
2024-04-22 17:52:25 -05:00
Nathaniel Simard 29fa2ee76c
Support linear 1d (#1682) 2024-04-22 18:39:09 -04:00
Alex Errant d62b344d5b
`Arc<EventStoreClient>` to `Rc<EventStoreClient>` (#1668) 2024-04-22 18:21:53 -04:00
신희제(Heejae Shin/joel.barish) 2a7b296a1b
Add sign ONNX op import support (#1663)
* Add sign ONNX op support

* Update SUPPORTED-ONNX-OPS.md
2024-04-22 09:10:50 -04:00
Louis Fortier-Dubois 2140d9b568
remove JIT subsequent RNG tests (#1652) 2024-04-21 09:48:11 -04:00
Dilshod Tadjibaev 1433284a0f
Fix bug 1645 (Unsqueeze OpSet 11) (#1661)
* Add unsqueeze opset 16 test

* Fix for unsqueeze opset 11

* Remove println statement
2024-04-19 14:17:44 -05:00
Guillaume Lagrange b65a487300
Fix transpose onnx op (permute) (#1657) 2024-04-19 09:34:03 -04:00
Nico Zweifel ee12aee2e7
fix: `window` -> `pub window` in `dataset/mod.rs` (#1658)
* update dataset/mod.rs

* Update mod.rs

* Update window.rs
2024-04-19 09:33:21 -04:00
Guillaume Lagrange 9fbcbed20f
Add where onnx op support (#1653)
* Add where onnx op support

* Add broadcasting support

* Remove broadcasting limitation comment

* Fix broadcasting in mask where

* Forgot to reflect changes in codegen test

* Fix clippy
2024-04-18 15:46:02 -04:00
Guillaume Lagrange 7705fd9c25
Add matmul ONNX op support (#1638)
* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change supscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
2024-04-18 09:20:31 -04:00
Dilshod Tadjibaev 2a721a9d0c
Enable native sign operation for Candle backend (#1647)
* Enable native sign operation for Candle backend

* Use fixed revision
2024-04-17 09:07:56 -04:00
Guillaume Lagrange 424033283a
Add reduce max ONNX op support (#1636)
* Add reduce max onnx op support

* Fix comments on tensor rank 1 result
2024-04-17 08:26:46 -04:00
Nico Zweifel 5a3f345734
WindowDataset/windows function (#1553) 2024-04-17 07:51:53 -04:00
Guillaume Lagrange 35b36bbe62
Add shape ONNX op support (#1639)
* Add shape onnx op support

* Remove cast node from onnx graph

* Fix shape implementation

* Fix shape config error message

* Fix typo

* Fix clippy type complexity for generated code
2024-04-16 09:28:21 -04:00
Guillaume Lagrange 6d96e8d808
[ONNX] Add not op and extend cast support to tensors (#1634)
* Add not onnx op support

* Extend cast onnx support to tensors

* Fix clippy
2024-04-16 08:45:25 -04:00
Mathias Insley 7377bbe31c
Feat/remainder (#1597)
* Add remainder_scalar op to numeric trait and associated int/float functions

* Update burn-tch crate

* Update ndarray crate

* Update jit crate

* Update candle crate

* Update fusion crate

* Update autodiff crate

* Forgot float.rs for fusion

* Add burn-tensor tests

* Redirect to the pre-existing modulus op

* Fix sign

* Remove mut from burn-tch

* Use sign trick to make wgpu backend work

* Add more unit tests in to cover bases

* Naming fix for burn-fusion

* Update tests w/PyTorch link

* Use different WGSL instructions for remainder

* Redirect to remainder Operator instead of modulo

* Revert Modulo in instruction.rs
2024-04-16 08:35:20 -04:00
Mathias Insley 48c61ebb81
Docs/update contributor book (#1622)
* Update links to latest commit off main

* Some pedantry

* Update links and add jit

* Update instructions for burn-jit and wgpu

* Updated import section with more recent links

* Some grammar/typo/styling fixes

* Code added to burn-wgpu too
2024-04-16 08:33:59 -04:00
Guillaume Lagrange d5f20e2711
Add reduce mean ONNX op support (#1637)
* Add reduce mean onnx op support

* Fix comment
2024-04-16 07:59:35 -04:00
Dilshod Tadjibaev 340a12463a
Update SUPPORTED-ONNX-OPS.md (#1641) 2024-04-16 07:52:15 -04:00
Guillaume Lagrange 81a67b6a09
Add sin onnx op support (#1633) 2024-04-15 15:28:16 -04:00
Sylvain Benner e303e31c8b
Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
Guillaume Lagrange cf7b279e5e
Fix burn README symlink (#1617) 2024-04-12 16:00:47 -04:00
Guillaume Lagrange 9980db440d
Remove unused assets (#1616) 2024-04-12 15:48:16 -04:00
Guillaume Lagrange 264c167c11
Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
Nathaniel Simard ff844b1667
Fix candle backend sync (#1579)
* Fix candle backend sync

* tch mps sync

* clippy

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-04-12 12:15:50 -04:00
Aasheesh Singh fb1da53a38
support for rotary positional encoding to transformer modules. (#1604)
* add rotary positional encoding to transformer modules.

* fix f64 error

* use num_traits

* add panic condition
2024-04-12 11:45:49 -04:00
Louis Fortier-Dubois 23210f05f2
JIT: Autotune matmul tiling 2d unroll (#1601)
* autotune tiling 2d unroll

* clippy

* forgotten important stuff
2024-04-12 10:15:21 -04:00
Nathaniel Simard 07a61a1cec
Fix autodiff memory management graph cleaning (#1602) 2024-04-11 16:21:00 -04:00
Guillaume Lagrange 0cbe9a927d
Add learner training report summary (#1591)
* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
2024-04-11 12:32:25 -04:00
Louis Fortier-Dubois bdb62fbcd0
Repeat ops autodiff & fusion + fix autodiff ones & zeros (#1600)
* added repeat to autodiff and fusion + zero one backend init in autodiff

* autodiff for repeat
2024-04-11 11:32:45 -04:00
Dilshod Tadjibaev 2f885480ed
Use num-traits for float ops (#1584) 2024-04-08 10:16:20 -05:00
Guillaume Lagrange f3e0aa6689
Add multi-label classification dataset and metric (#1572)
* Add multilabel classification dataset

- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods

* Fix custom-image-classification example deps

* Add image_folder_dataset_multilabel test

* Do not change class names order when provided

* Add hamming score and multi-label classification output

* Add new_classification_with_items test

* Fix clippy suggestions

* Implement default trait for hamming score

* Remove de/serialization and use AnnotationRaw as type

* Fix clippy

* Fix metric backend phantom data
2024-04-05 13:16:46 -04:00