Commit Graph

1116 Commits

Author SHA1 Message Date
nathaniel 6f6b3bcc01 no wgpu validation layer when not(test) 2024-05-08 17:03:45 -04:00
Arjun31415 5bbc5ea944
Added ONNX AvgPool1d (#1744) 2024-05-07 16:10:18 -05:00
Nathaniel Simard a6e3b4e81e
Fix select assign backward (#1739) 2024-05-07 11:37:43 -04:00
Sébastien Boisvert bd06b38fac
Refactor: replace trait TemplateKernel by existing trait JitKernel (#1737)
* Refactor: replace trait TemplateKernel by existing trait JitKernel

* Refactor: implement trait JitKernel for struct Kernel
2024-05-06 20:59:00 -04:00
Jonathan Richard e233c38b0f
Add hidden code snippets to guide example in Burn book [redo] (#1742)
* added hidden code snippets in Burn book guide example

* Update backend.md

* Revert last commit
2024-05-06 20:29:28 -04:00
mepatrick73 adbe97dc4d
Fixing various syntax errors in the Burn book (#1740) 2024-05-06 17:25:22 -04:00
Thierry Cantin-Demers 1cde566317
Add indentation to project architecture in contributing book (#1738)
Now reflects the structure of the book
2024-05-06 13:43:21 -04:00
Arjun31415 7f94f4c219
Add MaxPool1d ONNX Op(#1725) 2024-05-06 10:51:00 -05:00
Anton Blomström fb13503fa9
Add reduce sum onnx ops to burn imports (#1723) 2024-05-06 10:49:17 -05:00
github-actions[bot] 0b919b6a58
Combined PRs (#1734)
* Bump serde from 1.0.199 to 1.0.200

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.199 to 1.0.200.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.199...v1.0.200)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump num-traits from 0.2.18 to 0.2.19

Bumps [num-traits](https://github.com/rust-num/num-traits) from 0.2.18 to 0.2.19.
- [Changelog](https://github.com/rust-num/num-traits/blob/master/RELEASES.md)
- [Commits](https://github.com/rust-num/num-traits/compare/num-traits-0.2.18...num-traits-0.2.19)

---
updated-dependencies:
- dependency-name: num-traits
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump image from 0.24.9 to 0.25.1

Bumps [image](https://github.com/image-rs/image) from 0.24.9 to 0.25.1.
- [Changelog](https://github.com/image-rs/image/blob/main/CHANGES.md)
- [Commits](https://github.com/image-rs/image/compare/v0.24.9...v0.25.1)

---
updated-dependencies:
- dependency-name: image
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump rmp-serde from 1.2.0 to 1.3.0

Bumps [rmp-serde](https://github.com/3Hren/msgpack-rust) from 1.2.0 to 1.3.0.
- [Release notes](https://github.com/3Hren/msgpack-rust/releases)
- [Commits](https://github.com/3Hren/msgpack-rust/compare/rmp-serde/v1.2.0...rmp-serde/v1.3.0)

---
updated-dependencies:
- dependency-name: rmp-serde
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump libc from 0.2.153 to 0.2.154

Bumps [libc](https://github.com/rust-lang/libc) from 0.2.153 to 0.2.154.
- [Release notes](https://github.com/rust-lang/libc/releases)
- [Commits](https://github.com/rust-lang/libc/compare/0.2.153...0.2.154)

---
updated-dependencies:
- dependency-name: libc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-05-06 10:20:31 -04:00
Anton Blomström f8994e044c
Fix unstable tests when run concurrently (#1724) 2024-05-05 15:27:42 -05:00
Arjun31415 152509c378
PReLu ONNX import (#1721)
* added prelu onnx operator

* bug fix

* added onnx tests and burn codegen tests

* fix tests

* added prelu to supported onnx ops and add prelu to dim_inference
2024-05-04 13:45:42 -05:00
Louis Fortier-Dubois a8661a2f53
Autodiff Memory Management: BFS (#1710) 2024-05-03 09:45:21 -04:00
Nathaniel Simard 5d959e2884
[Fusion] Support multi-precision fusion (#1718) 2024-05-02 18:22:56 -04:00
Anton Blomström 6b14bb8f01
Add info about enabling debugging for new contributors (#1719) 2024-05-02 17:42:18 -04:00
Louis Fortier-Dubois 2e4c82fa64
Fix repeat for dims > 1 (#1713) 2024-05-01 09:11:38 -04:00
Dilshod Tadjibaev 3a02a54e55
Update SUPPORTED-ONNX-OPS.md (#1717)
gather ONNX was checked off but actually GatherElements should have been updated.
2024-05-01 08:02:59 -04:00
Sylvain Benner 9f62094c07
Exclude burn-cuda from workspace to avoid build error for some users (#1716) 2024-05-01 07:33:02 -04:00
Dilshod Tadjibaev ff9e875321
ONNX debug improvements (#1712)
* Minor debug improvements

* Change warn to panic

* Log improvements
2024-04-30 16:36:55 -05:00
Nathaniel Simard 587b8f80b3
First draft CUDA runtime (#1685)
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
Jonathan Merritt ab501431b1
Handle ndarray matmul broadcasting (#1679)
* Handle ndarray matmul broadcasting

- Use strides to map linear batch indices from
  the output back to the input arrays.

* Fix typos
2024-04-29 17:25:27 -05:00
Dilshod Tadjibaev 1cdceb590f
Skip updating shape for linear if not present (#1700) 2024-04-29 14:53:18 -05:00
github-actions[bot] bb24f1be2a
Combined PRs (#1708)
* Bump serde from 1.0.198 to 1.0.199

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.198 to 1.0.199.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.198...v1.0.199)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serial_test from 3.1.0 to 3.1.1

Bumps [serial_test](https://github.com/palfrey/serial_test) from 3.1.0 to 3.1.1.
- [Release notes](https://github.com/palfrey/serial_test/releases)
- [Commits](https://github.com/palfrey/serial_test/compare/v3.1.0...v3.1.1)

---
updated-dependencies:
- dependency-name: serial_test
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump hashbrown from 0.14.3 to 0.14.5

Bumps [hashbrown](https://github.com/rust-lang/hashbrown) from 0.14.3 to 0.14.5.
- [Changelog](https://github.com/rust-lang/hashbrown/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/hashbrown/compare/v0.14.3...v0.14.5)

---
updated-dependencies:
- dependency-name: hashbrown
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tokenizers from 0.15.2 to 0.19.1

Bumps [tokenizers](https://github.com/huggingface/tokenizers) from 0.15.2 to 0.19.1.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.15.2...v0.19.1)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump flate2 from 1.0.28 to 1.0.29

Bumps [flate2](https://github.com/rust-lang/flate2-rs) from 1.0.28 to 1.0.29.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases)
- [Commits](https://github.com/rust-lang/flate2-rs/compare/1.0.28...1.0.29)

---
updated-dependencies:
- dependency-name: flate2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-04-29 07:52:56 -04:00
Sylvain Benner 1f8b5d3efb
[guide] Remove ambiguity lib vs. executable (#1649) 2024-04-26 15:42:02 -04:00
Guillaume Lagrange b7ab19ac71
Fix inverted epoch - iteration counts in valid progress (#1699) 2024-04-26 15:26:09 -04:00
WU Chen b387829731
Implement bidirectional LSTM (#1035)
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer when init

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Louis Fortier-Dubois 6ae3926006
New autodiff graph memory management strategy (#1698)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-26 12:25:53 -04:00
Guillaume Lagrange 13cd88f2e6
Remove leaky relu ONNX file (#1697) 2024-04-26 09:57:09 -04:00
wangxiaochuTHU 03dd7e0dce
Update README.md (#1696)
Fix the link to 'limited set of ONNX operators'
2024-04-26 09:24:47 -04:00
Nathaniel Simard 2f294c5092
Fix lstm batch size bug (#1695) 2024-04-26 08:54:12 -04:00
Guillaume Lagrange ce2429eb10
Refactor element type to be decoupled from runtime (#1693) 2024-04-26 08:53:55 -04:00
Dilshod Tadjibaev 67ec06d5d8
ONNX support for scalar unsqueeze (#1690)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
2024-04-25 16:05:28 -05:00
Nathaniel Simard 599a20d586
Upgrade wgpu (#1692) 2024-04-25 16:32:50 -04:00
Dilshod Tadjibaev a1bd14c5ae
Reshape bug fix (#1684)
* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze
2024-04-24 19:31:53 -05:00
Nathaniel Simard 886a1de235
Refactor/burn compute (#1580) 2024-04-23 13:05:15 -04:00
Sylvain Benner c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandlerContainer and Tensor Ops description to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now added modules to burn-tensor are excluded from no-std as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on dosctring

* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Guillaume Lagrange e6b1b7a317
Add layer norm onnx op support (#1680) 2024-04-23 11:19:07 -04:00
Dilshod Tadjibaev 1718da5210
Fix reshape bug (support for opset version 1) (#1667)
* Make reshape op version 1

* Refactor per PR feedback
2024-04-22 17:52:25 -05:00
Nathaniel Simard 29fa2ee76c
Support linear 1d (#1682) 2024-04-22 18:39:09 -04:00
Guillaume Lagrange fd26c1a241
Fix ONNX and PyTorch import section links in burn book (#1681) 2024-04-22 18:38:05 -04:00
Alex Errant d62b344d5b
`Arc<EventStoreClient>` to `Rc<EventStoreClient>` (#1668) 2024-04-22 18:21:53 -04:00
github-actions[bot] 6c708527b9
Combined PRs (#1678)
* Bump rustls from 0.22.2 to 0.22.4 in the cargo group

Bumps the cargo group with 1 update: [rustls](https://github.com/rustls/rustls).


Updates `rustls` from 0.22.2 to 0.22.4
- [Release notes](https://github.com/rustls/rustls/releases)
- [Changelog](https://github.com/rustls/rustls/blob/main/CHANGELOG.md)
- [Commits](https://github.com/rustls/rustls/compare/v/0.22.2...v/0.22.4)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: indirect
  dependency-group: cargo
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.197 to 1.0.198

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.197 to 1.0.198.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.197...v1.0.198)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump bytemuck from 1.14.3 to 1.15.0

Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.14.3 to 1.15.0.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md)
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.14.3...v1.15.0)

---
updated-dependencies:
- dependency-name: bytemuck
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump syn from 2.0.55 to 2.0.58

Bumps [syn](https://github.com/dtolnay/syn) from 2.0.55 to 2.0.58.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.55...2.0.58)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serial_test from 3.0.0 to 3.1.0

Bumps [serial_test](https://github.com/palfrey/serial_test) from 3.0.0 to 3.1.0.
- [Release notes](https://github.com/palfrey/serial_test/releases)
- [Commits](https://github.com/palfrey/serial_test/compare/v3.0.0...v3.1.0)

---
updated-dependencies:
- dependency-name: serial_test
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-04-22 10:22:49 -04:00
신희제(Heejae Shin/joel.barish) 2a7b296a1b
Add sign ONNX op import support (#1663)
* Add sign ONNX op support

* Update SUPPORTED-ONNX-OPS.md
2024-04-22 09:10:50 -04:00
Louis Fortier-Dubois 2140d9b568
remove JIT subsequent RNG tests (#1652) 2024-04-21 09:48:11 -04:00
Dilshod Tadjibaev 1433284a0f
Fix bug 1645 (Unsqueeze OpSet 11) (#1661)
* Add unsqueeze opset 16 test

* Fix for unsqueeze opset 11

* Remove println statement
2024-04-19 14:17:44 -05:00
Guillaume Lagrange b65a487300
Fix transpose onnx op (permute) (#1657) 2024-04-19 09:34:03 -04:00
Nico Zweifel ee12aee2e7
fix: `window` -> `pub window` in `dataset/mod.rs` (#1658)
* update dataset/mod.rs

* Update mod.rs

* Update window.rs
2024-04-19 09:33:21 -04:00
Guillaume Lagrange 9fbcbed20f
Add where onnx op support (#1653)
* Add where onnx op support

* Add broadcasting support

* Remove broadcasting limitation comment

* Fix broadcasting in mask where

* Forgot to reflect changes in codegen test

* Fix clippy
2024-04-18 15:46:02 -04:00
Guillaume Lagrange 7705fd9c25
Add matmul ONNX op support (#1638)
* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change supscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
2024-04-18 09:20:31 -04:00
Dilshod Tadjibaev 2a721a9d0c
Enable native sign operation for Candle backend (#1647)
* Enable native sign operation for Candle backend

* Use fixed revision
2024-04-17 09:07:56 -04:00