Commit Graph

235 Commits

Author SHA1 Message Date
louisfd 1a17e974be Merge branch 'main' into feat/cube/tiling2d 2024-06-25 14:51:33 -04:00
Nathaniel Simard 2fbc4628f3
Feat/cube/array assign ops (#1914) 2024-06-25 09:55:55 -04:00
Dilshod Tadjibaev 2c51615471
Print model structure like with PyTorch - Part 1 (#1912) 2024-06-25 09:23:10 -04:00
Nathaniel Simard a5dfb87828
Feat/comptime expr (#1910)
* Support comptime expressions

* Add test

* Cleanup

* Fix
2024-06-20 16:00:22 -04:00
Nathaniel Simard efc13d9a38
Feat/cube/compile error (#1909) 2024-06-19 17:21:32 -04:00
Nathaniel Simard d50bac165e
feat cube support Array (#1907) 2024-06-19 17:03:02 -04:00
Arthur Brussee 14d1bbba64
Do not use default burn-compute features unless enabled. (#1908) 2024-06-19 10:12:11 -04:00
Nathaniel Simard 560d77d154
Doc: Improve module to_device/fork docs (#1901) 2024-06-18 16:45:38 -04:00
Nathaniel Simard e758fd43db
Fix: constant record loading (#1902) 2024-06-18 16:45:21 -04:00
Justin Restivo 263add23a0
Tanh nn wrapper (#1903) 2024-06-18 16:45:04 -04:00
phenylshima f8a7c54272
feat: Make RetroForward public (#1905) 2024-06-18 16:44:32 -04:00
jachym 96468fc3c9
feat: added reduce min onnx import (#1894) 2024-06-18 09:04:24 -04:00
Nathaniel Simard 4f6db974a1
Perf/dynamic mm (#1906) 2024-06-18 08:41:07 -04:00
Guillaume Lagrange 8071b637b8
Fix conv2d_weight_grad_groups (#1891) 2024-06-17 09:24:33 -04:00
github-actions[bot] a04da9a285
Combined PRs (#1900)
* Bump cudarc from 0.11.4 to 0.11.6

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.4 to 0.11.6.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.4...v0.11.6)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump derive_more from 0.99.17 to 0.99.18

Bumps [derive_more](https://github.com/JelteF/derive_more) from 0.99.17 to 0.99.18.
- [Release notes](https://github.com/JelteF/derive_more/releases)
- [Changelog](https://github.com/JelteF/derive_more/blob/v0.99.18/CHANGELOG.md)
- [Commits](https://github.com/JelteF/derive_more/compare/v0.99.17...v0.99.18)

---
updated-dependencies:
- dependency-name: derive_more
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokio from 1.37.0 to 1.38.0

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.37.0 to 1.38.0.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.37.0...tokio-1.38.0)

---
updated-dependencies:
- dependency-name: tokio
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump regex from 1.10.4 to 1.10.5

Bumps [regex](https://github.com/rust-lang/regex) from 1.10.4 to 1.10.5.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.10.4...1.10.5)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-17 09:06:01 -04:00
Arthur Brussee ac9f942a46
Remove GraphicsAPI generic for WgpuRuntime (#1888) 2024-06-17 09:04:25 -04:00
Joshua Ferguson eead748e90
add dependency management for python (#1887) 2024-06-17 09:00:38 -04:00
Louis Fortier-Dubois 8bf1cd60dc
Cube: variable reusability + refactor in cube macros (#1885) 2024-06-14 11:20:25 -04:00
louisfd cd8802ea8b clippy 2024-06-14 10:19:00 -04:00
Guillaume Lagrange 525244062f
Implement `Element` for `bool` (#1878)
* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
2024-06-14 09:02:38 -04:00
George b71c300638
Feat: Add `movedim` tensor operator (#1876)
*  (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
2024-06-14 09:01:38 -04:00
Arthur Brussee 47a81270e1
Make autodiff compile on wasm (#1889) 2024-06-14 08:12:14 -04:00
louisfd 02f39ed094 fix clippy 2024-06-13 14:05:51 -04:00
Nathaniel Simard 5e58ae1a02
Refactor the tuner to be used standalone (#1884)
* Refactor the tuner to be used standalone

* Add a name for the autotune cache

* Fix tests

* Fix typo
2024-06-13 13:23:58 -04:00
louisfd babeac6d80 Merge branch 'main' into feat/cube/tiling2d 2024-06-13 13:21:13 -04:00
Jonathan Richard 5de1517232
Add documentation to burn core nn (#1746)
* Updated documentation for unfold4d

Added links between the struct and the config. Added a link to the related burn_tensor function in the documentation for the forward function.

* Changing nn relu module documentation to functional api

Removing the formula for relu from the module API to the functional API,
citing a paper relevant to relu
and mentionning the functional API in the module API

* Linking gelu module API documentation to functional API documentation

* Linear module : adding documentation

Adding documentation to the Linear module
mentionning that LinearConfig struct
should be used when creating a Linear Layer

Also adding links to the documentation that points people toward
the right path

* Updated documentation for dropout

Added links between the struct and the config. Added a link to the struct in the forward function for more info.

* embedding + swiglu

* RotaryEncodying : adding documentation

Adding documentation stating the RotaryEncoding should be created using a RotaryEncodingConfig

* prelu: adding documentation

Adding documentation to the prelu module:
- Linking forward function documentation to the functional API
- Citing the first paper to mention prelu
- Adding documentation saying that prelu layer should be created using PReluConfig

* pos_encoding: adding documentation

* Updated documentation for mha

Added links for more info. Added shape info at some places.

* docs: Add documentation for Gru module

Provide documentation for the Gru module, including its configuration and usage. Include a link to the paper that introduced the Gated Recurrent Unit (GRU) and specify that the module should be created using GruConfig. Also, mention that the forward function returns a state tensor with specific dimensions.

* burn-core-nn-transformers: adding documentation

Adding documentation:
- Says to use config to create the layers
- Add mathematical formula to the pwff forward pass
- Add citation in the pwff to the "Attention is all you need" paper

* Updated documentation: ConvTranspose1d and ConvTranspose2d

* docs: Add documentation for Lstm and BiLstm modules

Provide documentation for the Lstm and BiLstm modules, including their configurations and usage. Include links to the papers that introduced Long Short-Term Memory (LSTM) and Bidirectional LSTM. Specify that the modules should be created using LstmConfig and BiLstmConfig respectively.

* docs: Update documentation for ConvTranspose1d and ConvTranspose2d modules

* loss: Adding documenntation to the loss layers

Adding documentation stating to use the config to create the layer

* chore: Refactor Conv1d module imports and update documentation

* docs: Add documentation for AdaptiveAvgPool1d and AdaptiveAvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Refactor Conv1d module imports and update documentation

* chore: Refactor Conv2d module imports and update documentation

* Add documentation for AvgPool1d and AvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for MaxPool1d and MaxPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for leaky_relu and removed Config generic

Added references to the burn_tensor associated functions. Added links between the struct and the config. Removed the backend generic from the config since it's not needed (might be a breaking change).

* refactor: Update BatchNormConfig initialization and add documentation.

* Added link to config in embedding struct documentation

* refactor: Update GroupNormConfig initialization and add documentation

* refactor: Update InstanceNormConfig initialization and add documentation

* feat: Update LayerNormConfig initialization and add documentation

* refactor: Update RmsNormConfig initialization and add documentation

* fixed: removed #derive accidentally

* Added missing backticks in pools' shapes

* Format nn doc

* Make config fields public in nn modules

* Update import statements in nn modules

Changed burn_tensor imports to crate::tensor

* Update import statements in nn modules' tests

Changed burn_tensor imports to crate::tensor

* breaking change refactor: Update GroupNormConfig and InstanceNormConfig initialization

* Make SwiGlu fields public

* grammar

* slashes

* input tensors grouping

* copy-pasta mistake

* a not an >:I

* Capitalization

* better desc

* math 'n ticks

* group_norm functional implementation

* removed the ... struct

* decoder typo

* fmt

* referring to private fn in docs

---------

Co-authored-by: Thierry Cantin-Demers <piertcd@gmail.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-06-13 12:50:21 -04:00
Louis Fortier-Dubois 4393b336bc
clippy on rust update (#1886) 2024-06-13 12:15:15 -04:00
louisfd ddb5961548 fix test 2024-06-13 11:01:34 -04:00
louisfd e7cdfc9e4b clippy 2024-06-13 10:58:34 -04:00
louisfd 9218cae576 cleanup in variable tracking 2024-06-13 10:57:38 -04:00
louisfd c2a57abc42 pass tests + cleanup 2024-06-13 10:08:21 -04:00
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00
louisfd 1b84a18789 Merge branch 'main' into feat/cube/tiling2d 2024-06-13 09:39:22 -04:00
louisfd 88dfa3afaa works 2024-06-12 09:45:01 -04:00
Mitchell Mosure 71bd5efbfa
feat: resize onnx import (#1863)
* feat: resize onnx import

* fix: resize import proc macro output

* fix: lint

* fix: simplify resize onnx

* fix: onnx-tests passing

* feedback: remove dead code and resolve merge conflicts
2024-06-11 13:22:33 -04:00
jachym 671ec8c679
feat: added slice onnx import (#1856)
* feat: added slice onnx import

* fix: axes, steps handling
2024-06-11 07:50:03 -04:00
github-actions[bot] dd60446946
Combined PRs (#1874)
* Bump cudarc from 0.11.0 to 0.11.4

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.0 to 0.11.4.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.0...v0.11.4)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.4 to 4.5.6

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.4 to 4.5.6.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.4...v4.5.6)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tar from 0.4.40 to 0.4.41

Bumps [tar](https://github.com/alexcrichton/tar-rs) from 0.4.40 to 0.4.41.
- [Commits](https://github.com/alexcrichton/tar-rs/compare/0.4.40...0.4.41)

---
updated-dependencies:
- dependency-name: tar
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump strum_macros from 0.26.2 to 0.26.4

Bumps [strum_macros](https://github.com/Peternator7/strum) from 0.26.2 to 0.26.4.
- [Release notes](https://github.com/Peternator7/strum/releases)
- [Changelog](https://github.com/Peternator7/strum/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Peternator7/strum/commits)

---
updated-dependencies:
- dependency-name: strum_macros
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump zip from 2.1.2 to 2.1.3

Bumps [zip](https://github.com/zip-rs/zip2) from 2.1.2 to 2.1.3.
- [Release notes](https://github.com/zip-rs/zip2/releases)
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md)
- [Commits](https://github.com/zip-rs/zip2/compare/v2.1.2...v2.1.3)

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-10 16:22:08 -04:00
louisfd fa72ed5d98 redeclare 2024-06-10 15:56:57 -04:00
Joshua Ferguson effce28b72
Optimize argument handling and improve ONNX graph building (#1857)
* draft for alternative burn import design

* passes onnx test, fails to build example

* pushing to test example on main

* fixed the issue with the example

* passes the test now

* spring cleaning and minor code changes

* removed pub visibility from most graph_data fields and functions

* comment fixes

* went ahead and removed the constant check for now

* removed unused function arg
2024-06-10 14:06:54 -05:00
Louis Fortier-Dubois de5b681b18
Cube: Vectorization + simple matmul implementation (#1866) 2024-06-07 14:05:51 -04:00
Arthur Brussee 4b174a88bd
Get resources from server (#1861) 2024-06-06 17:33:57 -04:00
Arthur Brussee 75e26d03c3
Speedup client.create for small allocations. (#1858)
* Speedup client.create for small allocations.
2024-06-06 17:09:01 -04:00
Arthur Brussee 675f6b3280
Make Param.id public (#1859)
* Make Param.id public

* Remove extra comment.
2024-06-06 11:03:14 -04:00
Icekey d28183c7e4
LearnerBuilder "with_checkpointing_strategy" should use builder pattern (#1841) 2024-06-05 07:55:44 -04:00
Arthur Brussee e0a1094f89
Add a feature to initialize from an existing wgpu adapter/device/queue (#1788)
* Add a feature to initialize from an existing wgpu adapter/device/queue

This is useful when interacting with other wgpu applications (eg. displaying a burn tensor as a texture in egui). The existing devices are keyed by the wgpu Device ID. Alternatively they could be keyed per adapter which would be more inline with other burn WgpuDevice's (one per adapter), but also there's no real inherent reason to.

This also involves making Queue into an Arc. Alternatively, this could give up ownership of the queue, but it's helpful to be able to synchronize burn operations and custom wgpu operations.
2024-06-05 07:19:52 -04:00
mepatrick73 36ed65a5cd
Feat/dynamic mm basic implementation + small refactor (#1844) 2024-06-04 17:01:33 -04:00
Louis Fortier-Dubois c42abadfe9
Cube: CubeType (no launch) and Comptime::map (#1853) 2024-06-04 13:43:43 -04:00
jachym a5af19b959
feat: add sum onnx import (#1846) 2024-06-03 15:30:44 -05:00
Louis Fortier-Dubois 5edaeabcee
Feat/cube/struct support (#1842)
* struct support (receive, use and modify fields)

* support struct with generics

* expect instead of unwrap

* fmt

* rename struc

* fmt

* Clippy

* Fix launcher

* Support creating private cube type without generics

* Cleanup

* generics support

* clippy

* minor

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-06-03 12:19:05 -04:00
Mathias Insley 92b0067693
Feat/gather import (#1843)
* Move and redirect GatherElements to new folders/nodes

* Create PyTorch script for gather

* Add onnx file for gather

* Add a gather test to onnx_tests

* Update gather.rs to use select

* Rename codegen test

* Update gather and gather_elements conversion functions

* Validate rank of input node and update output

* Add check for Gather
2024-06-03 08:28:32 -04:00