mepatrick73
c9eb8d817d
test
2024-08-14 19:57:03 -04:00
mepatrick73
9e23cc4227
trying part 2
2024-08-14 19:01:48 -04:00
mepatrick73
c111d9dd61
perhaps fix
2024-08-14 18:49:55 -04:00
mepatrick73
44053277fb
scatter cleanup
2024-08-14 18:27:33 -04:00
mepatrick73
ab5d437adf
Merge branch 'main' into index-cpa-to-cubecl
2024-08-14 18:26:38 -04:00
mepatrick73
7899cee125
Working version !
2024-08-14 18:26:26 -04:00
Periwink
0435721188
Convert `reduce_dim_naive` kernel to use the `#[cube]` derive macro ( #2117 )
2024-08-14 10:46:37 -04:00
mepatrick73
f5047c27c8
working version of gather
2024-08-13 20:30:46 -04:00
mepatrick73
abaaa2dd55
wip
2024-08-13 19:56:04 -04:00
mepatrick73
05bed8ea74
Merge branch 'main' into index-cpa-to-cubecl
2024-08-13 18:24:36 -04:00
mepatrick73
94954fc32c
cleanup
2024-08-13 18:24:29 -04:00
Nathaniel Simard
ff8d0308fb
Enable cuda-jit in burn-core + in text classification example ( #2160 )
2024-08-12 18:22:27 -04:00
mepatrick73
a06933f029
working version
2024-08-12 17:13:53 -04:00
Guillaume Lagrange
7c17e84a0e
Update cubecl rev ( #2159 )
2024-08-12 16:55:07 -04:00
github-actions[bot]
be705466c9
Combined PRs ( #2157 )
...
* Bump syn from 2.0.72 to 2.0.74
Bumps [syn](https://github.com/dtolnay/syn ) from 2.0.72 to 2.0.74.
- [Release notes](https://github.com/dtolnay/syn/releases )
- [Commits](https://github.com/dtolnay/syn/compare/2.0.72...2.0.74 )
---
updated-dependencies:
- dependency-name: syn
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump serde from 1.0.204 to 1.0.206
Bumps [serde](https://github.com/serde-rs/serde ) from 1.0.204 to 1.0.206.
- [Release notes](https://github.com/serde-rs/serde/releases )
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.204...v1.0.206 )
---
updated-dependencies:
- dependency-name: serde
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump clap from 4.5.12 to 4.5.15
Bumps [clap](https://github.com/clap-rs/clap ) from 4.5.12 to 4.5.15.
- [Release notes](https://github.com/clap-rs/clap/releases )
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md )
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.12...v4.5.15 )
---
updated-dependencies:
- dependency-name: clap
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump serde_json from 1.0.122 to 1.0.124
Bumps [serde_json](https://github.com/serde-rs/json ) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases )
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124 )
---
updated-dependencies:
- dependency-name: serde_json
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-12 10:24:22 -04:00
dependabot[bot]
286f8174a8
Bump serde_json from 1.0.122 to 1.0.124 ( #2156 )
...
Bumps [serde_json](https://github.com/serde-rs/json ) from 1.0.122 to 1.0.124.
- [Release notes](https://github.com/serde-rs/json/releases )
- [Commits](https://github.com/serde-rs/json/compare/v1.0.122...v1.0.124 )
---
updated-dependencies:
- dependency-name: serde_json
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-08-12 10:23:43 -04:00
Guillaume Lagrange
0eec293e28
Fix indices dim check in gather_update_outputs ( #2149 )
2024-08-12 09:20:25 -04:00
Adrian Müller
12caca7909
Allow ONNX scalar greater/less with scalar ( #2146 )
2024-08-12 09:11:08 -04:00
Periwink
e75eebfc31
Add comments for matmul kernel ( #2138 )
2024-08-12 09:09:24 -04:00
Guillaume Lagrange
be5eb910d4
Remove CubeCL GELU kernel example reference (moved to CubeCL repo) ( #2150 )
2024-08-09 15:23:47 -04:00
Adrian Müller
5a0c1dcead
Implement ONNX Gather for scalar indices ( #2141 )
...
* Implement ONNX Gather for scalars
* Fix ONNX gather_scalar codegen test
2024-08-09 11:53:01 -04:00
Guillaume Lagrange
724bfbc73b
Add scientific notation formatting for small metric values ( #2136 )
2024-08-08 16:25:34 -04:00
Guillaume Lagrange
723c9d1a2e
Fix module derive with generics ( #2127 )
...
* Remove unnecessary ModuleDisplayDefault generic bound + duplicate ModuleDisplay
* Remove erroneous bound for autodiff module generic
2024-08-08 16:24:51 -04:00
Nathaniel Simard
bb4a605ca6
Chore/integrate updated cubecl ( #2142 )
2024-08-08 16:19:39 -04:00
Dilshod Tadjibaev
1c681f46ec
Precision option for tensor display ( #2139 )
2024-08-08 15:01:42 -05:00
mepatrick73
27ca6cee95
feat: adding shape support for gather ONNX operation ( #2128 )
2024-08-08 13:18:03 -04:00
Guillaume Lagrange
0802d063d8
Fix inner backend typo in book guide ( #2135 )
2024-08-08 11:23:00 -05:00
tiruka
64b57792e0
modified mnist image link in the Hugging face ( #2134 )
2024-08-08 11:15:08 -05:00
mepatrick73
d770b1f470
ONNX Tile operation ( #2092 )
...
* renaming repeat to repeat_dim
* implementing repeat function
* renaming repeat files to repeat_dim
* renaming part 2
* renaming part 3
* renaming part 4
* renaming part 5
* adding test file
* adding unit test
* adding rust book documentation
* adding function args doc
* fixing tests
* changing repeat api to match pytorch equivalent
* fixing clippy error
* implementing tile onnx file
* temp
* working implementation and test
* working e2e test
* adding new supported onnx operation to the md file
2024-08-07 17:43:59 -04:00
Dilshod Tadjibaev
6b61ad5a61
Fix #2091 bug (in-place after expand) ( #2114 )
2024-08-07 17:37:20 -04:00
mepatrick73
e39485322d
bug fix: adding bounds checking to pad ONNX inputs ( #2120 )
...
* bug fix: adding bounds checking
Constant value is an optional value.
Adding a bounds check to make sure we've gotten enough inputs
* fixing pr
* quick little fix
* fixing constant_value cast
* fix clippy
2024-08-07 16:34:43 -04:00
Genna Wingert
a01004dd4a
Add Hard sigmoid activation function ( #2112 )
...
* Add Hard Sigmoid activation function
* Add ONNX import conversion for HardSigmoid
* Update supported operators list
* Update book
* Make test comparison approximate to eliminate precision issues
* Add burn-candle test
* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
tiruka
af8c3150c9
remove lto linker option to make build successful ( #2123 )
2024-08-07 12:59:53 -05:00
Periwink
dad85e0709
Add onnx mean ( #2119 )
...
* make contacts deterministic across Worlds
* add top k acc
* add onnx mean
* fix
* push fix
* format
---------
Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-07 13:03:59 -04:00
Dilshod Tadjibaev
cd848b1c94
Add is_nan and contains_nan tensor ops ( #2088 )
...
* Add is_nan and contains_nan tensor ops
* Enable nan test for burn-candle
* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
Guillaume Lagrange
27d42cdaad
Fix aggregation results slice ( #2110 )
...
* Fix aggregation results slice
* View aggregation results as 1d tensor instead
2024-08-06 12:02:11 -05:00
Periwink
ade664d4d8
Add top-k accuracy ( #2097 )
...
* make contacts deterministic across Worlds
* add top k acc
* update book
---------
Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-06 12:01:28 -05:00
tiruka
a53f459f20
Modify contributing md scripts to solve conflicts between doc and scripts ( #2107 )
...
* modified scripts comments and contributing.md to solve conflicts between them
* modified default value for checktypes and added NoArgs enum value
* added burn tch installation link
* Removed NoArgs enum value and use All as default
2024-08-05 13:32:17 -05:00
github-actions[bot]
9405713d2b
Combined PRs ( #2108 )
...
* Bump serde_json from 1.0.121 to 1.0.122
Bumps [serde_json](https://github.com/serde-rs/json ) from 1.0.121 to 1.0.122.
- [Release notes](https://github.com/serde-rs/json/releases )
- [Commits](https://github.com/serde-rs/json/compare/v1.0.121...v1.0.122 )
---
updated-dependencies:
- dependency-name: serde_json
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump regex from 1.10.5 to 1.10.6
Bumps [regex](https://github.com/rust-lang/regex ) from 1.10.5 to 1.10.6.
- [Release notes](https://github.com/rust-lang/regex/releases )
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md )
- [Commits](https://github.com/rust-lang/regex/compare/1.10.5...1.10.6 )
---
updated-dependencies:
- dependency-name: regex
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump wgpu from 22.0.0 to 22.1.0
Bumps [wgpu](https://github.com/gfx-rs/wgpu ) from 22.0.0 to 22.1.0.
- [Release notes](https://github.com/gfx-rs/wgpu/releases )
- [Changelog](https://github.com/gfx-rs/wgpu/blob/wgpu-v22.1.0/CHANGELOG.md )
- [Commits](https://github.com/gfx-rs/wgpu/compare/wgpu-v22.0.0...wgpu-v22.1.0 )
---
updated-dependencies:
- dependency-name: wgpu
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump tempfile from 3.10.1 to 3.11.0
Bumps [tempfile](https://github.com/Stebalien/tempfile ) from 3.10.1 to 3.11.0.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md )
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.1...v3.11.0 )
---
updated-dependencies:
- dependency-name: tempfile
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump flate2 from 1.0.30 to 1.0.31
Bumps [flate2](https://github.com/rust-lang/flate2-rs ) from 1.0.30 to 1.0.31.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases )
- [Commits](https://github.com/rust-lang/flate2-rs/commits )
---
updated-dependencies:
- dependency-name: flate2
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump EmbarkStudios/cargo-deny-action from 1 to 2
Bumps [EmbarkStudios/cargo-deny-action](https://github.com/embarkstudios/cargo-deny-action ) from 1 to 2.
- [Release notes](https://github.com/embarkstudios/cargo-deny-action/releases )
- [Commits](https://github.com/embarkstudios/cargo-deny-action/compare/v1...v2 )
---
updated-dependencies:
- dependency-name: EmbarkStudios/cargo-deny-action
dependency-type: direct:production
update-type: version-update:semver-major
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-05 11:43:58 -04:00
Dilshod Tadjibaev
80cc6d4eb5
Fix bug: Filling tensor containing f32::NEG_INFINITY will result in NaN for burn-ndarray ( #2095 )
...
* Fix #2094 bug
* Fix typo
* Fix mask broadcasting
2024-08-05 08:32:31 -04:00
Noah Schiro
52d896cd27
Fix broken links in contributor book ( #2061 )
2024-08-04 14:38:18 -05:00
omahs
9721b92dae
Fix typos ( #2098 )
...
* fix typos
* fix typo
* fix typo
2024-08-04 14:32:01 -05:00
mepatrick73
f7639bd35a
Repeat operation ( #2090 )
...
* renaming repeat to repeat_dim
* implementing repeat function
* renaming repeat files to repeat_dim
* renaming part 2
* renaming part 3
* renaming part 4
* renaming part 5
* adding test file
* adding unit test
* adding rust book documentation
* adding function args doc
* fixing tests
* changing repeat api to match pytorch equivalent
* fixing clippy error
2024-08-02 20:33:47 -04:00
Dilshod Tadjibaev
bb13729b20
Improve ONNX import book section ( #2059 )
...
* Improve ONNX importing section
* Update onnx-model.md
2024-08-01 17:10:18 -05:00
Nathaniel Simard
62a30e973c
Fix: fusion auto bound checks ( #2087 )
2024-08-01 09:13:10 -04:00
Ragy Abraham
04d7ff24f2
Add polars DataFrame support for Dataset ( #2029 )
...
* initial commit to try implement from_dataframes for a burn dataset
* added the beginnings of tests. removed ref to self in utility method
* added unit test for dataframe module. added utility methods to convert polars rows to burn dataset values
* putting polars and dataframe mod behind a fearure flag
* testing both methods
* added a if let OK so that it doesn't panic. if we can't convert serde map to json string. added comments
* using polars serializer, renaming vars
* removed prints. just unwrapping
* setting feature flags back
* return Value::Null rather than panic if we can't serialize list value. no longer convert to object before converting to string. no longer using serde_json to_string method
* Use native deserializer instead of serde_json
* added support for lazyframes. added support to deserialize a few more data. added a few more tests
* Remove lazy, add more testing and other fixes
* Update the book
* Remove lazy feature
* Put back lazy feature for polars
---------
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-31 17:22:49 -05:00
Nathaniel Simard
f673721d27
Refactor binary op ( #2085 )
2024-07-31 16:18:21 -04:00
Sylvain Benner
88656d24ad
Rename revision key to rev for cubecl dependencies in Cargo.toml ( #2086 )
2024-07-31 15:34:41 -04:00
Dilshod Tadjibaev
297173124f
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) ( #2081 )
...
* Add interpolate module
* Update module.md
* Add interpolate 1d and 2d modules
* Consolidated InterpolateMode for 1d and 2d
* Remove CoordinateTransformationMode
* Add 1d tests for interpolate
* Refactor and fixes of ONNX Resize OP
* Fix clippy
* Fix docs
* Fix no_std
2024-07-31 12:08:26 -05:00
tiruka
bc24bf3c14
modify broken link src of ide image ( #2079 )
2024-07-31 12:38:04 -04:00