Commit Graph

1337 Commits

Author SHA1 Message Date
mepatrick73 c9eb8d817d test 2024-08-14 19:57:03 -04:00
mepatrick73 9e23cc4227 trying part 2 2024-08-14 19:01:48 -04:00
mepatrick73 c111d9dd61 perhaps fix 2024-08-14 18:49:55 -04:00
mepatrick73 44053277fb scatter cleanup 2024-08-14 18:27:33 -04:00
mepatrick73 ab5d437adf Merge branch 'main' into index-cpa-to-cubecl 2024-08-14 18:26:38 -04:00
mepatrick73 7899cee125 Working version ! 2024-08-14 18:26:26 -04:00
Periwink 0435721188
Convert `reduce_dim_naive` kernel to use the `#[cube]` derive macro (#2117) 2024-08-14 10:46:37 -04:00
mepatrick73 f5047c27c8 working version of gather 2024-08-13 20:30:46 -04:00
mepatrick73 abaaa2dd55 wip 2024-08-13 19:56:04 -04:00
mepatrick73 05bed8ea74 Merge branch 'main' into index-cpa-to-cubecl 2024-08-13 18:24:36 -04:00
mepatrick73 94954fc32c cleanup 2024-08-13 18:24:29 -04:00
Nathaniel Simard ff8d0308fb
Enable cuda-jit in burn-core + in text classification example (#2160) 2024-08-12 18:22:27 -04:00
mepatrick73 a06933f029 working version 2024-08-12 17:13:53 -04:00
Guillaume Lagrange 7c17e84a0e
Update cubecl rev (#2159) 2024-08-12 16:55:07 -04:00
github-actions[bot] be705466c9
Combined PRs (#2157)
* Bump syn from 2.0.72 to 2.0.74

Bumps [syn](https://github.com/dtolnay/syn) from 2.0.72 to 2.0.74.
- [Release notes](https://github.com/dtolnay/syn/releases)
- [Commits](https://github.com/dtolnay/syn/compare/2.0.72...2.0.74)

---
updated-dependencies:
- dependency-name: syn
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde from 1.0.204 to 1.0.206

Bumps [serde](https://github.com/serde-rs/serde) from 1.0.204 to 1.0.206.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.204...v1.0.206)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.12 to 4.5.15

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.12 to 4.5.15.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.12...v4.5.15)

---
updated-dependencies:
- dependency-name: clap
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump serde_json from 1.0.122 to 1.0.124

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-12 10:24:22 -04:00
dependabot[bot] 286f8174a8
Bump serde_json from 1.0.122 to 1.0.124 (#2156)
Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-12 10:23:43 -04:00
Guillaume Lagrange 0eec293e28
Fix indices dim check in gather_update_outputs (#2149) 2024-08-12 09:20:25 -04:00
Adrian Müller 12caca7909
Allow ONNX scalar greater/less with scalar (#2146) 2024-08-12 09:11:08 -04:00
Periwink e75eebfc31
Add comments for matmul kernel (#2138) 2024-08-12 09:09:24 -04:00
Guillaume Lagrange be5eb910d4
Remove CubeCL GELU kernel example reference (moved to CubeCL repo) (#2150) 2024-08-09 15:23:47 -04:00
Adrian Müller 5a0c1dcead
Implement ONNX Gather for scalar indices (#2141)
* Implement ONNX Gather for scalars

* Fix ONNX gather_scalar codegen test
2024-08-09 11:53:01 -04:00
Guillaume Lagrange 724bfbc73b
Add scientific notation formatting for small metric values (#2136) 2024-08-08 16:25:34 -04:00
Guillaume Lagrange 723c9d1a2e
Fix module derive with generics (#2127)
* Remove unnecessary ModuleDisplayDefault generic bound + duplicate ModuleDisplay

* Remove erroneous bound for autodiff module generic
2024-08-08 16:24:51 -04:00
Nathaniel Simard bb4a605ca6
Chore/integrate updated cubecl (#2142) 2024-08-08 16:19:39 -04:00
Dilshod Tadjibaev 1c681f46ec
Precision option for tensor display (#2139) 2024-08-08 15:01:42 -05:00
mepatrick73 27ca6cee95
feat: adding shape support for gather ONNX operation (#2128) 2024-08-08 13:18:03 -04:00
Guillaume Lagrange 0802d063d8
Fix inner backend typo in book guide (#2135) 2024-08-08 11:23:00 -05:00
tiruka 64b57792e0
modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
mepatrick73 d770b1f470
ONNX Tile operation (#2092)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error

* implementing tile onnx file

* temp

* working implementation and test

* working e2e test

* adding new supported onnx operation to the md file
2024-08-07 17:43:59 -04:00
Dilshod Tadjibaev 6b61ad5a61
Fix #2091 bug (in-place after expand) (#2114) 2024-08-07 17:37:20 -04:00
mepatrick73 e39485322d
bug fix: adding bounds checking to pad ONNX inputs (#2120)
* bug fix: adding bounds checking

Constant value is an optional value.
Adding a bounds check to make sure we've gotten enough inputs

* fixing pr

* quick little fix

* fixing constant_value cast

* fix clippy
2024-08-07 16:34:43 -04:00
Genna Wingert a01004dd4a
Add Hard sigmoid activation function (#2112)
* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
tiruka af8c3150c9
remove lto linker option to make build successful (#2123) 2024-08-07 12:59:53 -05:00
Periwink dad85e0709
Add onnx mean (#2119)
* make contacts deterministic across Worlds

* add top k acc

* add onnx mean

* fix

* push fix

* format

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-07 13:03:59 -04:00
Dilshod Tadjibaev cd848b1c94
Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops

* Enable nan test for burn-candle

* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
Guillaume Lagrange 27d42cdaad
Fix aggregation results slice (#2110)
* Fix aggregation results slice

* View aggregation results as 1d tensor instead
2024-08-06 12:02:11 -05:00
Periwink ade664d4d8
Add top-k accuracy (#2097)
* make contacts deterministic across Worlds

* add top k acc

* update book

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-06 12:01:28 -05:00
tiruka a53f459f20
Modify contributing md scripts to solve conflicts between doc and scripts (#2107)
* modified scripts comments and contributing.md to solve conflicts between them

* modified default value for checktypes and added NoArgs enum value

* added burn tch installation link

* Removed NoArgs enum value and use All as default
2024-08-05 13:32:17 -05:00
github-actions[bot] 9405713d2b
Combined PRs (#2108)
* Bump serde_json from 1.0.121 to 1.0.122

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.121 to 1.0.122.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.121...v1.0.122)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump regex from 1.10.5 to 1.10.6

Bumps [regex](https://github.com/rust-lang/regex) from 1.10.5 to 1.10.6.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.10.5...1.10.6)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump wgpu from 22.0.0 to 22.1.0

Bumps [wgpu](https://github.com/gfx-rs/wgpu) from 22.0.0 to 22.1.0.
- [Release notes](https://github.com/gfx-rs/wgpu/releases)
- [Changelog](https://github.com/gfx-rs/wgpu/blob/wgpu-v22.1.0/CHANGELOG.md)
- [Commits](https://github.com/gfx-rs/wgpu/compare/wgpu-v22.0.0...wgpu-v22.1.0)

---
updated-dependencies:
- dependency-name: wgpu
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tempfile from 3.10.1 to 3.11.0

Bumps [tempfile](https://github.com/Stebalien/tempfile) from 3.10.1 to 3.11.0.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.1...v3.11.0)

---
updated-dependencies:
- dependency-name: tempfile
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump flate2 from 1.0.30 to 1.0.31

Bumps [flate2](https://github.com/rust-lang/flate2-rs) from 1.0.30 to 1.0.31.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases)
- [Commits](https://github.com/rust-lang/flate2-rs/commits)

---
updated-dependencies:
- dependency-name: flate2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump EmbarkStudios/cargo-deny-action from 1 to 2

Bumps [EmbarkStudios/cargo-deny-action](https://github.com/embarkstudios/cargo-deny-action) from 1 to 2.
- [Release notes](https://github.com/embarkstudios/cargo-deny-action/releases)
- [Commits](https://github.com/embarkstudios/cargo-deny-action/compare/v1...v2)

---
updated-dependencies:
- dependency-name: EmbarkStudios/cargo-deny-action
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-05 11:43:58 -04:00
Dilshod Tadjibaev 80cc6d4eb5
Fix bug: Filling tensor containing f32::NEG_INFINITY will result in NaN for burn-ndarray (#2095)
* Fix #2094 bug

* Fix typo

* Fix mask broadcasting
2024-08-05 08:32:31 -04:00
Noah Schiro 52d896cd27
Fix broken links in contributor book (#2061) 2024-08-04 14:38:18 -05:00
omahs 9721b92dae
Fix typos (#2098)
* fix typos

* fix typo

* fix typo
2024-08-04 14:32:01 -05:00
mepatrick73 f7639bd35a
Repeat operation (#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
2024-08-02 20:33:47 -04:00
Dilshod Tadjibaev bb13729b20
Improve ONNX import book section (#2059)
* Improve ONNX importing section

* Update onnx-model.md
2024-08-01 17:10:18 -05:00
Nathaniel Simard 62a30e973c
Fix: fusion auto bound checks (#2087) 2024-08-01 09:13:10 -04:00
Ragy Abraham 04d7ff24f2
Add polars DataFrame support for Dataset (#2029)
* initial commit to try implement from_dataframes for a burn dataset

* added the beginnings of tests. removed ref to self in utility method

* added unit test for dataframe module. added utility methods to convert polars rows to burn dataset values

* putting polars and dataframe mod behind a fearure flag

* testing both methods

* added a if let OK so that it doesn't panic. if we can't convert serde map to json string. added comments

* using polars serializer, renaming vars

* removed prints. just unwrapping

* setting feature flags back

* return Value::Null rather than panic if we can't serialize list value. no longer convert to object before converting to string. no longer using serde_json to_string method

* Use native deserializer instead of serde_json

* added support for lazyframes. added support to deserialize a few more data. added a few more tests

* Remove lazy, add more testing and other fixes

* Update the book

* Remove lazy feature

* Put back lazy feature for polars

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-31 17:22:49 -05:00
Nathaniel Simard f673721d27
Refactor binary op (#2085) 2024-07-31 16:18:21 -04:00
Sylvain Benner 88656d24ad
Rename revision key to rev for cubecl dependencies in Cargo.toml (#2086) 2024-07-31 15:34:41 -04:00
Dilshod Tadjibaev 297173124f
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081)
* Add interpolate module

* Update module.md

* Add interpolate 1d and 2d modules

* Consolidated InterpolateMode for 1d and 2d

* Remove CoordinateTransformationMode

* Add 1d tests for interpolate

* Refactor and fixes of ONNX Resize OP

* Fix clippy

* Fix docs

* Fix no_std
2024-07-31 12:08:26 -05:00
tiruka bc24bf3c14
modify broken link src of ide image (#2079) 2024-07-31 12:38:04 -04:00