Commit Graph

237 Commits

Author SHA1 Message Date
mepatrick73 505609ff38 Merge branch 'main' into feat/dynamic-small-pool 2024-06-27 14:18:35 -04:00
mepatrick73 a09ceab28b working memory extension, but not fast 2024-06-27 14:17:45 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is not generic over the element type anymore no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprinln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
mepatrick73 1c7780aaac
Feat/dynamic small pool (#1931) 2024-06-26 15:42:04 -04:00
mepatrick73 9e49cc9e58 Merge branch 'main' into feat/dynamic-small-pool 2024-06-26 14:06:39 -04:00
mepatrick73 c1aecc22ff review changes 2024-06-26 14:06:31 -04:00
Nathaniel Simard f9ec2e1006
Handle visibility in cube (#1929) 2024-06-26 12:57:47 -04:00
Nathaniel Simard d772a1cfd5
Fix: launch without generics (#1932) 2024-06-26 12:57:32 -04:00
mepatrick73 23ccb2ff15 clippy fix v3 2024-06-26 04:32:51 -04:00
mepatrick73 74004dbf2e clippy 2024-06-26 04:15:01 -04:00
mepatrick73 bc84a0296b small fix for run checks 2024-06-26 04:06:29 -04:00
mepatrick73 a0f15bef8f adding small memory pool 2024-06-26 03:11:52 -04:00
mepatrick73 4c9097030f
Perf/dynamic mm slice adressing (#1917)
* basic implementation of virtual memory addressing for fast index + merging (there is a bug with slice padding
2024-06-25 18:16:46 -04:00
Nathaniel Simard 2fbc4628f3
Feat/cube/array assign ops (#1914) 2024-06-25 09:55:55 -04:00
Dilshod Tadjibaev 2c51615471
Print model structure like with PyTorch - Part 1 (#1912) 2024-06-25 09:23:10 -04:00
Nathaniel Simard a5dfb87828
Feat/comptime expr (#1910)
* Support comptime expressions

* Add test

* Cleanup

* Fix
2024-06-20 16:00:22 -04:00
Nathaniel Simard efc13d9a38
Feat/cube/compile error (#1909) 2024-06-19 17:21:32 -04:00
Nathaniel Simard d50bac165e
feat cube support Array (#1907) 2024-06-19 17:03:02 -04:00
Arthur Brussee 14d1bbba64
Do not use default burn-compute features unless enabled. (#1908) 2024-06-19 10:12:11 -04:00
Nathaniel Simard 560d77d154
Doc: Improve module to_device/fork docs (#1901) 2024-06-18 16:45:38 -04:00
Nathaniel Simard e758fd43db
Fix: constant record loading (#1902) 2024-06-18 16:45:21 -04:00
Justin Restivo 263add23a0
Tanh nn wrapper (#1903) 2024-06-18 16:45:04 -04:00
phenylshima f8a7c54272
feat: Make RetroForward public (#1905) 2024-06-18 16:44:32 -04:00
jachym 96468fc3c9
feat: added reduce min onnx import (#1894) 2024-06-18 09:04:24 -04:00
Nathaniel Simard 4f6db974a1
Perf/dynamic mm (#1906) 2024-06-18 08:41:07 -04:00
Guillaume Lagrange 8071b637b8
Fix conv2d_weight_grad_groups (#1891) 2024-06-17 09:24:33 -04:00
github-actions[bot] a04da9a285
Combined PRs (#1900)
* Bump cudarc from 0.11.4 to 0.11.6

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.4 to 0.11.6.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.4...v0.11.6)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump derive_more from 0.99.17 to 0.99.18

Bumps [derive_more](https://github.com/JelteF/derive_more) from 0.99.17 to 0.99.18.
- [Release notes](https://github.com/JelteF/derive_more/releases)
- [Changelog](https://github.com/JelteF/derive_more/blob/v0.99.18/CHANGELOG.md)
- [Commits](https://github.com/JelteF/derive_more/compare/v0.99.17...v0.99.18)

---
updated-dependencies:
- dependency-name: derive_more
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokio from 1.37.0 to 1.38.0

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.37.0 to 1.38.0.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.37.0...tokio-1.38.0)

---
updated-dependencies:
- dependency-name: tokio
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump regex from 1.10.4 to 1.10.5

Bumps [regex](https://github.com/rust-lang/regex) from 1.10.4 to 1.10.5.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.10.4...1.10.5)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-17 09:06:01 -04:00
Arthur Brussee ac9f942a46
Remove GraphicsAPI generic for WgpuRuntime (#1888) 2024-06-17 09:04:25 -04:00
Joshua Ferguson eead748e90
add dependency management for python (#1887) 2024-06-17 09:00:38 -04:00
Louis Fortier-Dubois 8bf1cd60dc
Cube: variable reusability + refactor in cube macros (#1885) 2024-06-14 11:20:25 -04:00
Guillaume Lagrange 525244062f
Implement `Element` for `bool` (#1878)
* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
2024-06-14 09:02:38 -04:00
George b71c300638
Feat: Add `movedim` tensor operator (#1876)
*  (burn-tensor): add movedim function to tensor API

---------

Co-authored-by: Georgy Andreev <g.andreev@insilicomedicine.com>
2024-06-14 09:01:38 -04:00
Arthur Brussee 47a81270e1
Make autodiff compile on wasm (#1889) 2024-06-14 08:12:14 -04:00
Nathaniel Simard 5e58ae1a02
Refactor the tuner to be used standalone (#1884)
* Refactor the tuner to be used standalone

* Add a name for the autotune cache

* Fix tests

* Fix typo
2024-06-13 13:23:58 -04:00
Jonathan Richard 5de1517232
Add documentation to burn core nn (#1746)
* Updated documentation for unfold4d

Added links between the struct and the config. Added a link to the related burn_tensor function in the documentation for the forward function.

* Changing nn relu module documentation to functional api

Removing the formula for relu from the module API to the functional API,
citing a paper relevant to relu
and mentionning the functional API in the module API

* Linking gelu module API documentation to functional API documentation

* Linear module : adding documentation

Adding documentation to the Linear module
mentionning that LinearConfig struct
should be used when creating a Linear Layer

Also adding links to the documentation that points people toward
the right path

* Updated documentation for dropout

Added links between the struct and the config. Added a link to the struct in the forward function for more info.

* embedding + swiglu

* RotaryEncodying : adding documentation

Adding documentation stating the RotaryEncoding should be created using a RotaryEncodingConfig

* prelu: adding documentation

Adding documentation to the prelu module:
- Linking forward function documentation to the functional API
- Citing the first paper to mention prelu
- Adding documentation saying that prelu layer should be created using PReluConfig

* pos_encoding: adding documentation

* Updated documentation for mha

Added links for more info. Added shape info at some places.

* docs: Add documentation for Gru module

Provide documentation for the Gru module, including its configuration and usage. Include a link to the paper that introduced the Gated Recurrent Unit (GRU) and specify that the module should be created using GruConfig. Also, mention that the forward function returns a state tensor with specific dimensions.

* burn-core-nn-transformers: adding documentation

Adding documentation:
- Says to use config to create the layers
- Add mathematical formula to the pwff forward pass
- Add citation in the pwff to the "Attention is all you need" paper

* Updated documentation: ConvTranspose1d and ConvTranspose2d

* docs: Add documentation for Lstm and BiLstm modules

Provide documentation for the Lstm and BiLstm modules, including their configurations and usage. Include links to the papers that introduced Long Short-Term Memory (LSTM) and Bidirectional LSTM. Specify that the modules should be created using LstmConfig and BiLstmConfig respectively.

* docs: Update documentation for ConvTranspose1d and ConvTranspose2d modules

* loss: Adding documenntation to the loss layers

Adding documentation stating to use the config to create the layer

* chore: Refactor Conv1d module imports and update documentation

* docs: Add documentation for AdaptiveAvgPool1d and AdaptiveAvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Refactor Conv1d module imports and update documentation

* chore: Refactor Conv2d module imports and update documentation

* Add documentation for AvgPool1d and AvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for MaxPool1d and MaxPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for leaky_relu and removed Config generic

Added references to the burn_tensor associated functions. Added links between the struct and the config. Removed the backend generic from the config since it's not needed (might be a breaking change).

* refactor: Update BatchNormConfig initialization and add documentation.

* Added link to config in embedding struct documentation

* refactor: Update GroupNormConfig initialization and add documentation

* refactor: Update InstanceNormConfig initialization and add documentation

* feat: Update LayerNormConfig initialization and add documentation

* refactor: Update RmsNormConfig initialization and add documentation

* fixed: removed #derive accidentally

* Added missing backticks in pools' shapes

* Format nn doc

* Make config fields public in nn modules

* Update import statements in nn modules

Changed burn_tensor imports to crate::tensor

* Update import statements in nn modules' tests

Changed burn_tensor imports to crate::tensor

* breaking change refactor: Update GroupNormConfig and InstanceNormConfig initialization

* Make SwiGlu fields public

* grammar

* slashes

* input tensors grouping

* copy-pasta mistake

* a not an >:I

* Capitalization

* better desc

* math 'n ticks

* group_norm functional implementation

* removed the ... struct

* decoder typo

* fmt

* referring to private fn in docs

---------

Co-authored-by: Thierry Cantin-Demers <piertcd@gmail.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-06-13 12:50:21 -04:00
Louis Fortier-Dubois 4393b336bc
clippy on rust update (#1886) 2024-06-13 12:15:15 -04:00
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00
Mitchell Mosure 71bd5efbfa
feat: resize onnx import (#1863)
* feat: resize onnx import

* fix: resize import proc macro output

* fix: lint

* fix: simplify resize onnx

* fix: onnx-tests passing

* feedback: remove dead code and resolve merge conflicts
2024-06-11 13:22:33 -04:00
jachym 671ec8c679
feat: added slice onnx import (#1856)
* feat: added slice onnx import

* fix: axes, steps handling
2024-06-11 07:50:03 -04:00
github-actions[bot] dd60446946
Combined PRs (#1874)
* Bump cudarc from 0.11.0 to 0.11.4

Bumps [cudarc](https://github.com/coreylowman/cudarc) from 0.11.0 to 0.11.4.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.11.0...v0.11.4)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.4 to 4.5.6

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.4 to 4.5.6.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.4...v4.5.6)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tar from 0.4.40 to 0.4.41

Bumps [tar](https://github.com/alexcrichton/tar-rs) from 0.4.40 to 0.4.41.
- [Commits](https://github.com/alexcrichton/tar-rs/compare/0.4.40...0.4.41)

---
updated-dependencies:
- dependency-name: tar
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump strum_macros from 0.26.2 to 0.26.4

Bumps [strum_macros](https://github.com/Peternator7/strum) from 0.26.2 to 0.26.4.
- [Release notes](https://github.com/Peternator7/strum/releases)
- [Changelog](https://github.com/Peternator7/strum/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Peternator7/strum/commits)

---
updated-dependencies:
- dependency-name: strum_macros
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump zip from 2.1.2 to 2.1.3

Bumps [zip](https://github.com/zip-rs/zip2) from 2.1.2 to 2.1.3.
- [Release notes](https://github.com/zip-rs/zip2/releases)
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md)
- [Commits](https://github.com/zip-rs/zip2/compare/v2.1.2...v2.1.3)

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-10 16:22:08 -04:00
Joshua Ferguson effce28b72
Optimize argument handling and improve ONNX graph building (#1857)
* draft for alternative burn import design

* passes onnx test, fails to build example

* pushing to test example on main

* fixed the issue with the example

* passes the test now

* spring cleaning and minor code changes

* removed pub visibility from most graph_data fields and functions

* comment fixes

* went ahead and removed the constant check for now

* removed unused function arg
2024-06-10 14:06:54 -05:00
Louis Fortier-Dubois de5b681b18
Cube: Vectorization + simple matmul implementation (#1866) 2024-06-07 14:05:51 -04:00
Arthur Brussee 4b174a88bd
Get resources from server (#1861) 2024-06-06 17:33:57 -04:00
Arthur Brussee 75e26d03c3
Speedup client.create for small allocations. (#1858)
* Speedup client.create for small allocations.
2024-06-06 17:09:01 -04:00
Arthur Brussee 675f6b3280
Make Param.id public (#1859)
* Make Param.id public

* Remove extra comment.
2024-06-06 11:03:14 -04:00
Icekey d28183c7e4
LearnerBuilder "with_checkpointing_strategy" should use builder pattern (#1841) 2024-06-05 07:55:44 -04:00
Arthur Brussee e0a1094f89
Add a feature to initialize from an existing wgpu adapter/device/queue (#1788)
* Add a feature to initialize from an existing wgpu adapter/device/queue

This is useful when interacting with other wgpu applications (eg. displaying a burn tensor as a texture in egui). The existing devices are keyed by the wgpu Device ID. Alternatively they could be keyed per adapter which would be more inline with other burn WgpuDevice's (one per adapter), but also there's no real inherent reason to.

This also involves making Queue into an Arc. Alternatively, this could give up ownership of the queue, but it's helpful to be able to synchronize burn operations and custom wgpu operations.
2024-06-05 07:19:52 -04:00
mepatrick73 36ed65a5cd
Feat/dynamic mm basic implementation + small refactor (#1844) 2024-06-04 17:01:33 -04:00
Louis Fortier-Dubois c42abadfe9
Cube: CubeType (no launch) and Comptime::map (#1853) 2024-06-04 13:43:43 -04:00
jachym a5af19b959
feat: add sum onnx import (#1846) 2024-06-03 15:30:44 -05:00