Commit Graph

1398 Commits

Author SHA1 Message Date
Nathaniel Simard 978ac6c4ec
Chore: Update to newer cubecl version (#2181) 2024-08-25 15:33:16 -04:00
dependabot[bot] 9adf493305
Bump syn from 2.0.74 to 2.0.75 (#2173)
Bumps [syn](https://github.com/dtolnay/syn) from 2.0.74 to 2.0.75.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.74...2.0.75)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-25 14:48:09 -04:00
mepatrick73 0beec0e39e
Scatter kernel from cpa to cubecl (#2169) 2024-08-25 13:47:16 -04:00
mepatrick73 c94e743829
Tensor type indent fix (#2196)
* pad-input-fix: adding support for pads as attributes

* final fix

* undo pad changes
2024-08-23 12:46:31 -04:00
mepatrick73 2c12d58cd8
pad-input-fix: adding support for pads as attributes (#2195)
* pad-input-fix: adding support for pads as attributes

* fix: making asked changes

* clippy fix
2024-08-23 12:46:14 -04:00
Guillaume Lagrange f5a1eca3ce
Fix root-mean-square precision issue (#2193) 2024-08-23 11:56:26 -04:00
Guillaume Lagrange 4999421f6c
Add RoPE `init_with_frequency_scaling` (#2194)
* Add RoPE init_with_frequency_scaling

* Fix clippy
2024-08-23 10:30:23 -04:00
Bjorn Beishline 17de832c6e
Make compatible with thumbv6m-none-eabi + add raspberry pi pico example (#2096)
* Made compatible with thumbv6m-none-eabi

* Added example of no_std on rp2040

* Added documentation on usage in no_std

* Rename rp2040 example and add README.md
2024-08-23 07:39:39 -04:00
Guillaume Lagrange 48a64d3b8a
Add images and csv dataset source to book (#2179) 2024-08-22 15:47:05 -04:00
mepatrick73 e1fed792f7
Gather CPA to CubeCL (#2165)
* working version

* cleanup

* wip

* working version of gather

* testsetsetser

* Revert "testsetsetser"

This reverts commit f37b329697.

* Reapply "testsetsetser"

This reverts commit f8ada0044e.

* Revert "testsetsetser"

This reverts commit f37b329697.

* Revert "working version of gather"

This reverts commit f5047c27c8.

* Revert "wip"

This reverts commit abaaa2dd55.

* Revert "Merge branch 'main' into index-cpa-to-cubecl"

This reverts commit 05bed8ea74, reversing
changes made to 94954fc32c.

* Revert "cleanup"

This reverts commit 94954fc32c.

* Revert "working version"

This reverts commit a06933f029.

* gather test

* fix

* fix clippy

* cleanup
2024-08-22 13:44:26 -04:00
Guillaume Lagrange 73d4b11aa2
Add a dataset/dataloader/batcher usage section (#2161)
* Add a dataset/dataloader/batcher usage section

* Fix typos
2024-08-22 10:52:56 -05:00
Dilshod Tadjibaev 75a2850047
Add closeness tensor report (#2184)
* Add closeness tensor report

* Add documentation section

* Fix for no-std

* Fix epsilon formatting

* Update report.rs

* Fix import references

* Fix doc test

* Use colored crate instead of passing codes

* Small refactor to use iter directly

* Move colored dep to std

* Add missing

* Fix missing epsilon
2024-08-22 10:19:27 -05:00
tiruka 77f8121d44
modified burn module paths in example notebooks (#2188) 2024-08-22 09:34:36 -04:00
github-actions[bot] 58129d1c11
Combined PRs (#2177)
* Bump tokenizers from 0.19.1 to 0.20.0

Bumps [tokenizers](https://github.com/huggingface/tokenizers) from 0.19.1 to 0.20.0.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.19.1...v0.20.0)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump libc from 0.2.155 to 0.2.157

Bumps [libc](https://github.com/rust-lang/libc) from 0.2.155 to 0.2.157.
- [Release notes](https://github.com/rust-lang/libc/releases)
- [Changelog](https://github.com/rust-lang/libc/blob/0.2.157/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/libc/compare/0.2.155...0.2.157)

---
updated-dependencies:
- dependency-name: libc
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tempfile from 3.11.0 to 3.12.0

Bumps [tempfile](https://github.com/Stebalien/tempfile) from 3.11.0 to 3.12.0.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Stebalien/tempfile/commits)

---
updated-dependencies:
- dependency-name: tempfile
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.207 to 1.0.208

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.207 to 1.0.208.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.207...v1.0.208)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-20 08:36:48 -04:00
Guillaume Charifi 8053001306
Fix LayerNorm normalization. (#2186)
Fixes #2185.
2024-08-20 07:47:15 -04:00
Elias Rad c29ed43441
Docs fix spelling issues (#2183)
* fix README.md

* fix README.md

* fix mdbook.rs
2024-08-19 16:06:05 -05:00
Guillaume Lagrange 784f57bee4
Update model.bin mnist inference web + add cuda-jit flag for ag-news-infer (#2170)
* Update model.bin mnist inference web

* Add cuda-jit flag for ag-news-infer
2024-08-19 12:53:15 -04:00
Joseph Guhlin 2755c36ed7
Switches epoch and iteration to be in the proper order for the custom training loop (#2171) 2024-08-19 11:37:12 -04:00
Dilshod Tadjibaev d4a1d2026d
Fix equal/not-equal infinity numbers for burn-ndarray (#2166) 2024-08-15 12:33:54 -05:00
Guillaume Lagrange 31495f72c0
Fix class_index (#2167) 2024-08-15 12:33:41 -04:00
Guillaume Lagrange d2699022df
Add 0-dim tensor checks for creation ops and validate TensorData shape w/ num values (#2137) 2024-08-15 09:54:22 -04:00
Adrian Müller 16239db252
Fix ONNX Gather codegen for Shape input (#2148)
* Fix ONNX Gather codegen for Shape input

* Remove unneccessary cast, switch to slice for ownership
2024-08-15 07:36:13 -04:00
Periwink 0435721188
Convert `reduce_dim_naive` kernel to use the `#[cube]` derive macro (#2117) 2024-08-14 10:46:37 -04:00
Nathaniel Simard ff8d0308fb
Enable cuda-jit in burn-core + in text classification example (#2160) 2024-08-12 18:22:27 -04:00
Guillaume Lagrange 7c17e84a0e
Update cubecl rev (#2159) 2024-08-12 16:55:07 -04:00
github-actions[bot] be705466c9
Combined PRs (#2157)
* Bump syn from 2.0.72 to 2.0.74

Bumps [syn](https://github.com/dtolnay/syn) from 2.0.72 to 2.0.74.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.72...2.0.74)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.204 to 1.0.206

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.204 to 1.0.206.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.204...v1.0.206)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.12 to 4.5.15

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.12 to 4.5.15.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.12...v4.5.15)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde_json from 1.0.122 to 1.0.124

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-12 10:24:22 -04:00
dependabot[bot] 286f8174a8
Bump serde_json from 1.0.122 to 1.0.124 (#2156)
Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-12 10:23:43 -04:00
Guillaume Lagrange 0eec293e28
Fix indices dim check in gather_update_outputs (#2149) 2024-08-12 09:20:25 -04:00
Adrian Müller 12caca7909
Allow ONNX scalar greater/less with scalar (#2146) 2024-08-12 09:11:08 -04:00
Periwink e75eebfc31
Add comments for matmul kernel (#2138) 2024-08-12 09:09:24 -04:00
Guillaume Lagrange be5eb910d4
Remove CubeCL GELU kernel example reference (moved to CubeCL repo) (#2150) 2024-08-09 15:23:47 -04:00
Adrian Müller 5a0c1dcead
Implement ONNX Gather for scalar indices (#2141)
* Implement ONNX Gather for scalars

* Fix ONNX gather_scalar codegen test
2024-08-09 11:53:01 -04:00
Guillaume Lagrange 724bfbc73b
Add scientific notation formatting for small metric values (#2136) 2024-08-08 16:25:34 -04:00
Guillaume Lagrange 723c9d1a2e
Fix module derive with generics (#2127)
* Remove unnecessary ModuleDisplayDefault generic bound + duplicate ModuleDisplay

* Remove erroneous bound for autodiff module generic
2024-08-08 16:24:51 -04:00
Nathaniel Simard bb4a605ca6
Chore/integrate updated cubecl (#2142) 2024-08-08 16:19:39 -04:00
Dilshod Tadjibaev 1c681f46ec
Precision option for tensor display (#2139) 2024-08-08 15:01:42 -05:00
mepatrick73 27ca6cee95
feat: adding shape support for gather ONNX operation (#2128) 2024-08-08 13:18:03 -04:00
Guillaume Lagrange 0802d063d8
Fix inner backend typo in book guide (#2135) 2024-08-08 11:23:00 -05:00
tiruka 64b57792e0
modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
mepatrick73 d770b1f470
ONNX Tile operation (#2092)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error

* implementing tile onnx file

* temp

* working implementation and test

* working e2e test

* adding new supported onnx operation to the md file
2024-08-07 17:43:59 -04:00
Dilshod Tadjibaev 6b61ad5a61
Fix #2091 bug (in-place after expand) (#2114) 2024-08-07 17:37:20 -04:00
mepatrick73 e39485322d
bug fix: adding bounds checking to pad ONNX inputs (#2120)
* bug fix: adding bounds checking

Constant value is an optional value.
Adding a bounds check to make sure we've gotten enough inputs

* fixing pr

* quick little fix

* fixing constant_value cast

* fix clippy
2024-08-07 16:34:43 -04:00
Genna Wingert a01004dd4a
Add Hard sigmoid activation function (#2112)
* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
tiruka af8c3150c9
remove lto linker option to make build successful (#2123) 2024-08-07 12:59:53 -05:00
Periwink dad85e0709
Add onnx mean (#2119)
* make contacts deterministic across Worlds

* add top k acc

* add onnx mean

* fix

* push fix

* format

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-07 13:03:59 -04:00
Dilshod Tadjibaev cd848b1c94
Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
Guillaume Lagrange 27d42cdaad
Fix aggregation results slice (#2110)
* Fix aggregation results slice

* View aggregation results as 1d tensor instead
2024-08-06 12:02:11 -05:00
Periwink ade664d4d8
Add top-k accuracy (#2097)
* make contacts deterministic across Worlds

* add top k acc

* update book

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-06 12:01:28 -05:00
tiruka a53f459f20
Modify contributing md scripts to solve conflicts between doc and scripts (#2107)
* modified scripts comments and contributing.md to solve conflicts between them

* modified default value for checktypes and added NoArgs enum value

* added burn tch installation link

* Removed NoArgs enum value and use All as default
2024-08-05 13:32:17 -05:00
github-actions[bot] 9405713d2b
Combined PRs (#2108)
* Bump serde_json from 1.0.121 to 1.0.122

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.121 to 1.0.122.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.121...v1.0.122)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump regex from 1.10.5 to 1.10.6

Bumps [regex](https://github.com/rust-lang/regex) from 1.10.5 to 1.10.6.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.10.5...1.10.6)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump wgpu from 22.0.0 to 22.1.0

Bumps [wgpu](https://github.com/gfx-rs/wgpu) from 22.0.0 to 22.1.0.
- [Release notes](https://github.com/gfx-rs/wgpu/releases)
- [Changelog](https://github.com/gfx-rs/wgpu/blob/wgpu-v22.1.0/CHANGELOG.md)
- [Commits](https://github.com/gfx-rs/wgpu/compare/wgpu-v22.0.0...wgpu-v22.1.0)

---
updated-dependencies:
- dependency-name: wgpu
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tempfile from 3.10.1 to 3.11.0

Bumps [tempfile](https://github.com/Stebalien/tempfile) from 3.10.1 to 3.11.0.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.1...v3.11.0)

---
updated-dependencies:
- dependency-name: tempfile
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump flate2 from 1.0.30 to 1.0.31

Bumps [flate2](https://github.com/rust-lang/flate2-rs) from 1.0.30 to 1.0.31.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases)
- [Commits](https://github.com/rust-lang/flate2-rs/commits)

---
updated-dependencies:
- dependency-name: flate2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump EmbarkStudios/cargo-deny-action from 1 to 2

Bumps [EmbarkStudios/cargo-deny-action](https://github.com/embarkstudios/cargo-deny-action) from 1 to 2.
- [Release notes](https://github.com/embarkstudios/cargo-deny-action/releases)
- [Commits](https://github.com/embarkstudios/cargo-deny-action/compare/v1...v2)

---
updated-dependencies:
- dependency-name: EmbarkStudios/cargo-deny-action
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-05 11:43:58 -04:00