Commit Graph

81 Commits

Author SHA1 Message Date
Guillaume Lagrange f5a1eca3ce
Fix root-mean-square precision issue (#2193) 2024-08-23 11:56:26 -04:00
Guillaume Lagrange 4999421f6c
Add RoPE `init_with_frequency_scaling` (#2194)
* Add RoPE init_with_frequency_scaling

* Fix clippy
2024-08-23 10:30:23 -04:00
Bjorn Beishline 17de832c6e
Make compatible with thumbv6m-none-eabi + add raspberry pi pico example (#2096)
* Made compatible with thumbv6m-none-eabi

* Added example of no_std on rp2040

* Added documentation on usage in no_std

* Rename rp2040 example and add README.md
2024-08-23 07:39:39 -04:00
mepatrick73 e1fed792f7
Gather CPA to CubeCL (#2165)
* working version

* cleanup

* wip

* working version of gather

* testsetsetser

* Revert "testsetsetser"

This reverts commit f37b329697.

* Reapply "testsetsetser"

This reverts commit f8ada0044e.

* Revert "testsetsetser"

This reverts commit f37b329697.

* Revert "working version of gather"

This reverts commit f5047c27c8.

* Revert "wip"

This reverts commit abaaa2dd55.

* Revert "Merge branch 'main' into index-cpa-to-cubecl"

This reverts commit 05bed8ea74, reversing
changes made to 94954fc32c.

* Revert "cleanup"

This reverts commit 94954fc32c.

* Revert "working version"

This reverts commit a06933f029.

* gather test

* fix

* fix clippy

* cleanup
2024-08-22 13:44:26 -04:00
Guillaume Charifi 8053001306
Fix LayerNorm normalization. (#2186)
Fixes #2185.
2024-08-20 07:47:15 -04:00
Nathaniel Simard ff8d0308fb
Enable cuda-jit in burn-core + in text classification example (#2160) 2024-08-12 18:22:27 -04:00
Nathaniel Simard bb4a605ca6
Chore/integrate updated cubecl (#2142) 2024-08-08 16:19:39 -04:00
Genna Wingert a01004dd4a
Add Hard sigmoid activation function (#2112)
* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
mepatrick73 f7639bd35a
Repeat operation (#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
2024-08-02 20:33:47 -04:00
Dilshod Tadjibaev 297173124f
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081)
* Add interpolate module

* Update module.md

* Add interpolate 1d and 2d modules

* Consolidated InterpolateMode for 1d and 2d

* Remove CoordinateTransformationMode

* Add 1d tests for interpolate

* Refactor and fixes of ONNX Resize OP

* Fix clippy

* Fix docs

* Fix no_std
2024-07-31 12:08:26 -05:00
Nathaniel Simard 096ec13c48
Chore/update/cubecl (#2067) 2024-07-28 12:15:02 -04:00
Guillaume Lagrange 4c7353230e
Fix checks_channels_div_groups condition and ONNX conv import with groups (#2051)
* Fix checks_channels_div_groups condition

* Fix conv channels config w/ groups
2024-07-22 12:53:48 -05:00
Guillaume Lagrange 0d5025edbb
Refactor tensor quantization for q_* ops (#2025)
* Move QuantizationScheme to burn-tensor

* Refactor QuantizedTensorPrimitive to include the quantization strategy

* Fix QFloat tensor data display

* Refactor quantization methods to use scheme and qparams (on backend device)

* Fix clippy

* Fix fmt

* Add qtensor primitive tests
2024-07-19 10:39:50 -04:00
RuelYasa 9804bf81b2
Adding burn::nn::Sigmoid (#2031) 2024-07-17 14:34:44 -04:00
Guillaume Lagrange 3afff434bd
Module weight quantization (#2000)
* Add q_into_data and q_reshape

* Fix tch quantize f16 and q_into_data

* Convert to actual dtype/kind in dequantize

* Add module quantization and q_from_data

* Fix clippy

* Add documentation

* Handle deserialize data conversion

* Fix typo

* Add calibration tests

* Fix clippy precision

* Add QTensorOps require_grad methods to avoid dequantizing

* Add Dequantize mapper docs

* Remove dead code
2024-07-15 08:20:37 -04:00
Guillaume Lagrange c30ffcf6ac
Enable optimized handling of bytes (#2003)
* Enable optimized handling of bytes

* Implement byte buffer de/serialization

* Use serde_bytes w/ alloc (no_std compatible)
2024-07-11 07:48:43 -04:00
Guillaume Lagrange 6f158af4b1
Fix warnings when using `record-backward-compat` (#1977) 2024-07-08 07:58:50 -04:00
nathaniel 882a27c52c Revert "Revert "Implement 3D and transposed 3D convolutions. (#1945)""
This reverts commit b8b47ea6e6.
2024-07-05 18:57:01 -04:00
nathaniel b8b47ea6e6 Revert "Implement 3D and transposed 3D convolutions. (#1945)"
This reverts commit d696d74e3d.
2024-07-05 09:40:32 -04:00
Guillaume Charifi d696d74e3d
Implement 3D and transposed 3D convolutions. (#1945)
* Implement 3D and transposed 3D convolutions.

* Merge changes from onnx-ir #1921 pr

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-02 17:54:35 -05:00
Dilshod Tadjibaev 2bb76283ff
Improve pickle (CandleTensor) conversions to NestedValue (#1944)
* Manually serialize tensor - fixes #1773

* Rename `value` to `bytes`
2024-07-02 08:34:19 -04:00
Arthur Brussee 849c8f453b
Consistent sync/async handling, allow more functions to be async for wasm. (#1936) 2024-07-02 08:25:28 -04:00
Dilshod Tadjibaev 98a58c867d
Print module - implement module display for remaining modules (part2) (#1933) 2024-06-28 08:37:40 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is not generic over the element type anymore no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprinln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
Dilshod Tadjibaev 2c51615471
Print model structure like with PyTorch - Part 1 (#1912) 2024-06-25 09:23:10 -04:00
Nathaniel Simard 560d77d154
Doc: Improve module to_device/fork docs (#1901) 2024-06-18 16:45:38 -04:00
Nathaniel Simard e758fd43db
Fix: constant record loading (#1902) 2024-06-18 16:45:21 -04:00
Justin Restivo 263add23a0
Tanh nn wrapper (#1903) 2024-06-18 16:45:04 -04:00
Jonathan Richard 5de1517232
Add documentation to burn core nn (#1746)
* Updated documentation for unfold4d

Added links between the struct and the config. Added a link to the related burn_tensor function in the documentation for the forward function.

* Changing nn relu module documentation to functional api

Removing the formula for relu from the module API to the functional API,
citing a paper relevant to relu
and mentionning the functional API in the module API

* Linking gelu module API documentation to functional API documentation

* Linear module : adding documentation

Adding documentation to the Linear module
mentionning that LinearConfig struct
should be used when creating a Linear Layer

Also adding links to the documentation that points people toward
the right path

* Updated documentation for dropout

Added links between the struct and the config. Added a link to the struct in the forward function for more info.

* embedding + swiglu

* RotaryEncodying : adding documentation

Adding documentation stating the RotaryEncoding should be created using a RotaryEncodingConfig

* prelu: adding documentation

Adding documentation to the prelu module:
- Linking forward function documentation to the functional API
- Citing the first paper to mention prelu
- Adding documentation saying that prelu layer should be created using PReluConfig

* pos_encoding: adding documentation

* Updated documentation for mha

Added links for more info. Added shape info at some places.

* docs: Add documentation for Gru module

Provide documentation for the Gru module, including its configuration and usage. Include a link to the paper that introduced the Gated Recurrent Unit (GRU) and specify that the module should be created using GruConfig. Also, mention that the forward function returns a state tensor with specific dimensions.

* burn-core-nn-transformers: adding documentation

Adding documentation:
- Says to use config to create the layers
- Add mathematical formula to the pwff forward pass
- Add citation in the pwff to the "Attention is all you need" paper

* Updated documentation: ConvTranspose1d and ConvTranspose2d

* docs: Add documentation for Lstm and BiLstm modules

Provide documentation for the Lstm and BiLstm modules, including their configurations and usage. Include links to the papers that introduced Long Short-Term Memory (LSTM) and Bidirectional LSTM. Specify that the modules should be created using LstmConfig and BiLstmConfig respectively.

* docs: Update documentation for ConvTranspose1d and ConvTranspose2d modules

* loss: Adding documenntation to the loss layers

Adding documentation stating to use the config to create the layer

* chore: Refactor Conv1d module imports and update documentation

* docs: Add documentation for AdaptiveAvgPool1d and AdaptiveAvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Refactor Conv1d module imports and update documentation

* chore: Refactor Conv2d module imports and update documentation

* Add documentation for AvgPool1d and AvgPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for MaxPool1d and MaxPool2d modules

Added references to the burn_tensor associated functions. Added links between the struct and the config.

* Add documentation for leaky_relu and removed Config generic

Added references to the burn_tensor associated functions. Added links between the struct and the config. Removed the backend generic from the config since it's not needed (might be a breaking change).

* refactor: Update BatchNormConfig initialization and add documentation.

* Added link to config in embedding struct documentation

* refactor: Update GroupNormConfig initialization and add documentation

* refactor: Update InstanceNormConfig initialization and add documentation

* feat: Update LayerNormConfig initialization and add documentation

* refactor: Update RmsNormConfig initialization and add documentation

* fixed: removed #derive accidentally

* Added missing backticks in pools' shapes

* Format nn doc

* Make config fields public in nn modules

* Update import statements in nn modules

Changed burn_tensor imports to crate::tensor

* Update import statements in nn modules' tests

Changed burn_tensor imports to crate::tensor

* breaking change refactor: Update GroupNormConfig and InstanceNormConfig initialization

* Make SwiGlu fields public

* grammar

* slashes

* input tensors grouping

* copy-pasta mistake

* a not an >:I

* Capitalization

* better desc

* math 'n ticks

* group_norm functional implementation

* removed the ... struct

* decoder typo

* fmt

* referring to private fn in docs

---------

Co-authored-by: Thierry Cantin-Demers <piertcd@gmail.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-06-13 12:50:21 -04:00
Arthur Brussee 675f6b3280
Make Param.id public (#1859)
* Make Param.id public

* Remove extra comment.
2024-06-06 11:03:14 -04:00
Guillaume Lagrange e4836241e1
Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
Guillaume Lagrange b466fd7606
Add seq start position when applying RoPE encoding (#1796) 2024-05-22 13:18:31 -04:00
Guillaume Lagrange 550086a5c1
Fix record nested value de/serialization (#1751) 2024-05-22 09:15:32 -04:00
getumen e823338750
Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers (#1770) 2024-05-15 09:18:09 -04:00
Ben Barber d3cd6c4928
Replace opaque return types in optim (#1767)
* update ARCHITECTURE.md links to project architecture section in contributor book

* replace opaque return type in optim
2024-05-13 22:21:20 -04:00
Ahmed Yarub Hani Al Nuaimi 10737527d8
#1747 Upgrade Rust dependencies (#1748)
* #1747
Upgrade Rust dependencies

* Revert upgrade for tch

The update of tch on windows gives an error:

INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.

* Keep only .cargo/config.toml file which works with rust > 1.75

---------

Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers b09d8431df
Fix Cargo.toml repository links (#1749)
* Fix wgpu github link

* Fix burn-train repo link

* Fix burn-tensor github repo

* Fix burn-tensor repo link

* Fix remaining repo links in crates Cargo.toml

---------

Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Arjun31415 5bbc5ea944
Added ONNX AvgPool1d (#1744) 2024-05-07 16:10:18 -05:00
Arjun31415 7f94f4c219
Add MaxPool1d ONNX Op(#1725) 2024-05-06 10:51:00 -05:00
Anton Blomström f8994e044c
Fix unstable tests when run concurrently (#1724) 2024-05-05 15:27:42 -05:00
Nathaniel Simard 5d959e2884
[Fusion] Support multi-precision fusion (#1718) 2024-05-02 18:22:56 -04:00
Nathaniel Simard 587b8f80b3
First draft CUDA runtime (#1685)
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
WU Chen b387829731
Implement bidirectional LSTM (#1035)
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer when init

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Nathaniel Simard 2f294c5092
Fix lstm batch size bug (#1695) 2024-04-26 08:54:12 -04:00
Dilshod Tadjibaev 67ec06d5d8
ONNX support for scalar unsqueeze (#1690)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
2024-04-25 16:05:28 -05:00
Nathaniel Simard 29fa2ee76c
Support linear 1d (#1682) 2024-04-22 18:39:09 -04:00
Sylvain Benner e303e31c8b
Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
Guillaume Lagrange 9980db440d
Remove unused assets (#1616) 2024-04-12 15:48:16 -04:00
Guillaume Lagrange 264c167c11
Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
Aasheesh Singh fb1da53a38
support for rotary positional encoding to transformer modules. (#1604)
* add rotary positional encoding to transformer modules.

* fix f64 error

* use num_traits

* add panic condition
2024-04-12 11:45:49 -04:00