mepatrick73
d770b1f470
ONNX Tile operation ( #2092 )
...
* renaming repeat to repeat_dim
* implementing repeat function
* renaming repeat files to repeat_dim
* renaming part 2
* renaming part 3
* renaming part 4
* renaming part 5
* adding test file
* adding unit test
* adding rust book documentation
* adding function args doc
* fixing tests
* changing repeat api to match pytorch equivalent
* fixing clippy error
* implementing tile onnx file
* temp
* working implementation and test
* working e2e test
* adding new supported onnx operation to the md file
2024-08-07 17:43:59 -04:00
Dilshod Tadjibaev
6b61ad5a61
Fix #2091 bug (in-place after expand) ( #2114 )
2024-08-07 17:37:20 -04:00
mepatrick73
e39485322d
bug fix: adding bounds checking to pad ONNX inputs ( #2120 )
...
* bug fix: adding bounds checking
Constant value is an optional value.
Adding a bounds check to make sure we've gotten enough inputs
* fixing pr
* quick little fix
* fixing constant_value cast
* fix clippy
2024-08-07 16:34:43 -04:00
Genna Wingert
a01004dd4a
Add Hard sigmoid activation function ( #2112 )
...
* Add Hard Sigmoid activation function
* Add ONNX import conversion for HardSigmoid
* Update supported operators list
* Update book
* Make test comparison approximate to eliminate precision issues
* Add burn-candle test
* Fix name in E2E test generator
2024-08-07 13:01:42 -05:00
tiruka
af8c3150c9
remove lto linker option to make build successful ( #2123 )
2024-08-07 12:59:53 -05:00
Periwink
dad85e0709
Add onnx mean ( #2119 )
...
* make contacts deterministic across Worlds
* add top k acc
* add onnx mean
* fix
* push fix
* format
---------
Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-07 13:03:59 -04:00
Dilshod Tadjibaev
cd848b1c94
Add is_nan and contains_nan tensor ops ( #2088 )
...
* Add is_nan and contains_nan tensor ops
* Enable nan test for burn-candle
* Disabling tests due to #2089
2024-08-06 12:16:12 -05:00
Guillaume Lagrange
27d42cdaad
Fix aggregation results slice ( #2110 )
...
* Fix aggregation results slice
* View aggregation results as 1d tensor instead
2024-08-06 12:02:11 -05:00
Periwink
ade664d4d8
Add top-k accuracy ( #2097 )
...
* make contacts deterministic across Worlds
* add top k acc
* update book
---------
Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
2024-08-06 12:01:28 -05:00
tiruka
a53f459f20
Modify contributing md scripts to solve conflicts between doc and scripts ( #2107 )
...
* modified scripts comments and contributing.md to solve conflicts between them
* modified default value for checktypes and added NoArgs enum value
* added burn tch installation link
* Removed NoArgs enum value and use All as default
2024-08-05 13:32:17 -05:00
github-actions[bot]
9405713d2b
Combined PRs ( #2108 )
...
* Bump serde_json from 1.0.121 to 1.0.122
Bumps [serde_json](https://github.com/serde-rs/json ) from 1.0.121 to 1.0.122.
- [Release notes](https://github.com/serde-rs/json/releases )
- [Commits](https://github.com/serde-rs/json/compare/v1.0.121...v1.0.122 )
---
updated-dependencies:
- dependency-name: serde_json
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump regex from 1.10.5 to 1.10.6
Bumps [regex](https://github.com/rust-lang/regex ) from 1.10.5 to 1.10.6.
- [Release notes](https://github.com/rust-lang/regex/releases )
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md )
- [Commits](https://github.com/rust-lang/regex/compare/1.10.5...1.10.6 )
---
updated-dependencies:
- dependency-name: regex
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump wgpu from 22.0.0 to 22.1.0
Bumps [wgpu](https://github.com/gfx-rs/wgpu ) from 22.0.0 to 22.1.0.
- [Release notes](https://github.com/gfx-rs/wgpu/releases )
- [Changelog](https://github.com/gfx-rs/wgpu/blob/wgpu-v22.1.0/CHANGELOG.md )
- [Commits](https://github.com/gfx-rs/wgpu/compare/wgpu-v22.0.0...wgpu-v22.1.0 )
---
updated-dependencies:
- dependency-name: wgpu
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump tempfile from 3.10.1 to 3.11.0
Bumps [tempfile](https://github.com/Stebalien/tempfile ) from 3.10.1 to 3.11.0.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md )
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.1...v3.11.0 )
---
updated-dependencies:
- dependency-name: tempfile
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump flate2 from 1.0.30 to 1.0.31
Bumps [flate2](https://github.com/rust-lang/flate2-rs ) from 1.0.30 to 1.0.31.
- [Release notes](https://github.com/rust-lang/flate2-rs/releases )
- [Commits](https://github.com/rust-lang/flate2-rs/commits )
---
updated-dependencies:
- dependency-name: flate2
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump EmbarkStudios/cargo-deny-action from 1 to 2
Bumps [EmbarkStudios/cargo-deny-action](https://github.com/embarkstudios/cargo-deny-action ) from 1 to 2.
- [Release notes](https://github.com/embarkstudios/cargo-deny-action/releases )
- [Commits](https://github.com/embarkstudios/cargo-deny-action/compare/v1...v2 )
---
updated-dependencies:
- dependency-name: EmbarkStudios/cargo-deny-action
dependency-type: direct:production
update-type: version-update:semver-major
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-08-05 11:43:58 -04:00
Dilshod Tadjibaev
80cc6d4eb5
Fix bug: Filling tensor containing f32::NEG_INFINITY will result in NaN for burn-ndarray ( #2095 )
...
* Fix #2094 bug
* Fix typo
* Fix mask broadcasting
2024-08-05 08:32:31 -04:00
Noah Schiro
52d896cd27
Fix broken links in contributor book ( #2061 )
2024-08-04 14:38:18 -05:00
omahs
9721b92dae
Fix typos ( #2098 )
...
* fix typos
* fix typo
* fix typo
2024-08-04 14:32:01 -05:00
mepatrick73
f7639bd35a
Repeat operation ( #2090 )
...
* renaming repeat to repeat_dim
* implementing repeat function
* renaming repeat files to repeat_dim
* renaming part 2
* renaming part 3
* renaming part 4
* renaming part 5
* adding test file
* adding unit test
* adding rust book documentation
* adding function args doc
* fixing tests
* changing repeat api to match pytorch equivalent
* fixing clippy error
2024-08-02 20:33:47 -04:00
Dilshod Tadjibaev
bb13729b20
Improve ONNX import book section ( #2059 )
...
* Improve ONNX importing section
* Update onnx-model.md
2024-08-01 17:10:18 -05:00
Nathaniel Simard
62a30e973c
Fix: fusion auto bound checks ( #2087 )
2024-08-01 09:13:10 -04:00
Ragy Abraham
04d7ff24f2
Add polars DataFrame support for Dataset ( #2029 )
...
* initial commit to try implement from_dataframes for a burn dataset
* added the beginnings of tests. removed ref to self in utility method
* added unit test for dataframe module. added utility methods to convert polars rows to burn dataset values
* putting polars and dataframe mod behind a fearure flag
* testing both methods
* added a if let OK so that it doesn't panic. if we can't convert serde map to json string. added comments
* using polars serializer, renaming vars
* removed prints. just unwrapping
* setting feature flags back
* return Value::Null rather than panic if we can't serialize list value. no longer convert to object before converting to string. no longer using serde_json to_string method
* Use native deserializer instead of serde_json
* added support for lazyframes. added support to deserialize a few more data. added a few more tests
* Remove lazy, add more testing and other fixes
* Update the book
* Remove lazy feature
* Put back lazy feature for polars
---------
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-31 17:22:49 -05:00
Nathaniel Simard
f673721d27
Refactor binary op ( #2085 )
2024-07-31 16:18:21 -04:00
Sylvain Benner
88656d24ad
Rename revision key to rev for cubecl dependencies in Cargo.toml ( #2086 )
2024-07-31 15:34:41 -04:00
Dilshod Tadjibaev
297173124f
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) ( #2081 )
...
* Add interpolate module
* Update module.md
* Add interpolate 1d and 2d modules
* Consolidated InterpolateMode for 1d and 2d
* Remove CoordinateTransformationMode
* Add 1d tests for interpolate
* Refactor and fixes of ONNX Resize OP
* Fix clippy
* Fix docs
* Fix no_std
2024-07-31 12:08:26 -05:00
tiruka
bc24bf3c14
modify broken link src of ide image ( #2079 )
2024-07-31 12:38:04 -04:00
nathaniel
eb9e822832
Update cubecl version
2024-07-31 09:50:14 -04:00
Louis Fortier-Dubois
e68b9ab0cc
Refactor/jit cube/mask ( #2075 )
...
Co-authored-by: louisfd <louisfd@gmail.com>
2024-07-30 09:27:39 -04:00
syl20bnr
47d4139b96
Update Cargo.lock
2024-07-29 16:19:26 -04:00
syl20bnr
a72a533855
Fix cubecl version in Cargo.toml to correctly fecth the version tag
2024-07-29 16:19:26 -04:00
github-actions[bot]
e3649770fe
Combined PRs ( #2073 )
...
* Bump gix-tempfile from 14.0.0 to 14.0.1
Bumps [gix-tempfile](https://github.com/Byron/gitoxide ) from 14.0.0 to 14.0.1.
- [Release notes](https://github.com/Byron/gitoxide/releases )
- [Changelog](https://github.com/Byron/gitoxide/blob/main/CHANGELOG.md )
- [Commits](https://github.com/Byron/gitoxide/compare/gix-tempfile-v14.0.0...gix-tempfile-v14.0.1 )
---
updated-dependencies:
- dependency-name: gix-tempfile
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump tokio from 1.38.1 to 1.39.2
Bumps [tokio](https://github.com/tokio-rs/tokio ) from 1.38.1 to 1.39.2.
- [Release notes](https://github.com/tokio-rs/tokio/releases )
- [Commits](https://github.com/tokio-rs/tokio/compare/tokio-1.38.1...tokio-1.39.2 )
---
updated-dependencies:
- dependency-name: tokio
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump env_logger from 0.11.3 to 0.11.5
Bumps [env_logger](https://github.com/rust-cli/env_logger ) from 0.11.3 to 0.11.5.
- [Release notes](https://github.com/rust-cli/env_logger/releases )
- [Changelog](https://github.com/rust-cli/env_logger/blob/main/CHANGELOG.md )
- [Commits](https://github.com/rust-cli/env_logger/compare/v0.11.3...v0.11.5 )
---
updated-dependencies:
- dependency-name: env_logger
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump matrixmultiply from 0.3.8 to 0.3.9
Bumps [matrixmultiply](https://github.com/bluss/matrixmultiply ) from 0.3.8 to 0.3.9.
- [Commits](https://github.com/bluss/matrixmultiply/compare/0.3.8...0.3.9 )
---
updated-dependencies:
- dependency-name: matrixmultiply
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump clap from 4.5.9 to 4.5.11
Bumps [clap](https://github.com/clap-rs/clap ) from 4.5.9 to 4.5.11.
- [Release notes](https://github.com/clap-rs/clap/releases )
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md )
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.9...clap_complete-v4.5.11 )
---
updated-dependencies:
- dependency-name: clap
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-07-29 13:49:23 -04:00
Nathaniel Simard
096ec13c48
Chore/update/cubecl ( #2067 )
2024-07-28 12:15:02 -04:00
dependabot[bot]
2046831df6
Bump github/combine-prs from 5.0.0 to 5.1.0 ( #2039 )
...
Bumps [github/combine-prs](https://github.com/github/combine-prs ) from 5.0.0 to 5.1.0.
- [Release notes](https://github.com/github/combine-prs/releases )
- [Commits](https://github.com/github/combine-prs/compare/v5.0.0...v5.1.0 )
---
updated-dependencies:
- dependency-name: github/combine-prs
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-24 11:03:45 -04:00
Guillaume Lagrange
64a2f12827
Extend [min, max] range to ensure zero-point ( #2055 )
2024-07-24 09:55:11 -04:00
dependabot[bot]
dea33e88d4
Bump zip from 2.1.3 to 2.1.5 ( #2047 )
...
Bumps [zip](https://github.com/zip-rs/zip2 ) from 2.1.3 to 2.1.5.
- [Release notes](https://github.com/zip-rs/zip2/releases )
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md )
- [Commits](https://github.com/zip-rs/zip2/compare/v2.1.3...v2.1.5 )
---
updated-dependencies:
- dependency-name: zip
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-24 09:47:04 -04:00
dependabot[bot]
03ddf55831
Bump image from 0.25.1 to 0.25.2 ( #2045 )
...
Bumps [image](https://github.com/image-rs/image ) from 0.25.1 to 0.25.2.
- [Changelog](https://github.com/image-rs/image/blob/main/CHANGES.md )
- [Commits](https://github.com/image-rs/image/compare/v0.25.1...v0.25.2 )
---
updated-dependencies:
- dependency-name: image
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-24 09:45:26 -04:00
johnhuichen
4a3fc9d4a0
Implement ONNX Pad Operator ( #2007 )
...
* Implement ONNX pad
* ONNX pad arguments fix
pad now requires 2 or more arguments
if the third argument is not given, it will default to 0
* fixing bug in input len fix
* change panic comment
Change panic comment from needing two inputs. This comes from the fact that the ONNX spec requires two necessary inputs but could have more two more optional argument.
---------
Co-authored-by: JC <you@example.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
2024-07-23 13:50:20 -04:00
Guillaume Lagrange
53c77ae646
Convert compatible prelu weights to rank 1 ( #2054 )
2024-07-23 10:58:20 -04:00
Guillaume Lagrange
4c7353230e
Fix checks_channels_div_groups condition and ONNX conv import with groups ( #2051 )
...
* Fix checks_channels_div_groups condition
* Fix conv channels config w/ groups
2024-07-22 12:53:48 -05:00
Mathias Insley
0bbc1ed30f
Bug/Remove Squeeze Panic for Multiple Dimensions ( #2035 )
...
* Remove panic for squeeze when more than one axis is specified
* Remove extra Model()
* Change script to squeeze all singleton dimensions
* Revert change since burn requires axes to be specified
* Fix input tensor
* Try updating ONNX files again
* Add script for testing multiple axes along with new ONNX file
* Update squeeze.py comments
* Add squeeze_multiple model to tests
* Fix dim_inference
2024-07-22 12:13:07 -05:00
Nathaniel Simard
19cd67a9e2
Migration/cubecl ( #2041 )
2024-07-22 11:08:40 -04:00
Guillaume Lagrange
0d5025edbb
Refactor tensor quantization for q_* ops ( #2025 )
...
* Move QuantizationScheme to burn-tensor
* Refactor QuantizedTensorPrimitive to include the quantization strategy
* Fix QFloat tensor data display
* Refactor quantization methods to use scheme and qparams (on backend device)
* Fix clippy
* Fix fmt
* Add qtensor primitive tests
2024-07-19 10:39:50 -04:00
Sylvain Benner
3204cbe345
Update cargo.lock
2024-07-18 09:15:49 -04:00
Sylvain Benner
b6784684a1
Bump rust minimal version to 1.79
...
That's because bitstream-io, a dependency of rav1e, started using
a feature only in Rust 1.79.
2024-07-18 09:15:49 -04:00
José Manuel
befe6c1601
Added parameter trust_remote_code to hf dataset call. ( #2013 )
...
* Added parameter trust_remote_code to hf dataset call.
* Removed test modul as it may break causing false negatives.
Set default trust_remote_code to false.
Added an example that highlights the usecase.
2024-07-17 16:40:23 -05:00
RuelYasa
9804bf81b2
Adding burn::nn::Sigmoid ( #2031 )
2024-07-17 14:34:44 -04:00
Dilshod Tadjibaev
ed8a91d48a
Update slice documentation ( #2024 )
2024-07-16 11:59:02 -05:00
Sylvain Benner
1ed62f36f8
Bump gix-tempfile to fix security audit on gix-fs ( #2022 )
2024-07-16 11:41:11 -04:00
github-actions[bot]
2a5d175e14
Combined PRs ( #2021 )
...
* Bump clap from 4.5.8 to 4.5.9
Bumps [clap](https://github.com/clap-rs/clap ) from 4.5.8 to 4.5.9.
- [Release notes](https://github.com/clap-rs/clap/releases )
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md )
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.8...v4.5.9 )
---
updated-dependencies:
- dependency-name: clap
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump serde_json from 1.0.119 to 1.0.120
Bumps [serde_json](https://github.com/serde-rs/json ) from 1.0.119 to 1.0.120.
- [Release notes](https://github.com/serde-rs/json/releases )
- [Commits](https://github.com/serde-rs/json/compare/v1.0.119...v1.0.120 )
---
updated-dependencies:
- dependency-name: serde_json
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump sysinfo from 0.30.12 to 0.30.13
Bumps [sysinfo](https://github.com/GuillaumeGomez/sysinfo ) from 0.30.12 to 0.30.13.
- [Changelog](https://github.com/GuillaumeGomez/sysinfo/blob/master/CHANGELOG.md )
- [Commits](https://github.com/GuillaumeGomez/sysinfo/commits/v0.30.13 )
---
updated-dependencies:
- dependency-name: sysinfo
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump syn from 2.0.69 to 2.0.71
Bumps [syn](https://github.com/dtolnay/syn ) from 2.0.69 to 2.0.71.
- [Release notes](https://github.com/dtolnay/syn/releases )
- [Commits](https://github.com/dtolnay/syn/compare/2.0.69...2.0.71 )
---
updated-dependencies:
- dependency-name: syn
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-07-15 09:35:36 -04:00
Sylvain Benner
0e77e19635
Remove mention of example in backend section of the book ( #2014 )
2024-07-15 09:34:40 -04:00
Guillaume Lagrange
7661deb258
Fix image-classsification-web + autotune flag usage ( #2011 )
2024-07-15 09:31:54 -04:00
Guillaume Lagrange
3afff434bd
Module weight quantization ( #2000 )
...
* Add q_into_data and q_reshape
* Fix tch quantize f16 and q_into_data
* Convert to actual dtype/kind in dequantize
* Add module quantization and q_from_data
* Fix clippy
* Add documentation
* Handle deserialize data conversion
* Fix typo
* Add calibration tests
* Fix clippy precision
* Add QTensorOps require_grad methods to avoid dequantizing
* Add Dequantize mapper docs
* Remove dead code
2024-07-15 08:20:37 -04:00
Nathaniel Simard
a4123f6c2e
Cube/doc/readme ( #1904 )
...
---------
Co-authored-by: louisfd <louisfd94@gmail.com>
2024-07-12 10:15:17 -04:00
nathaniel
0a33aa363d
Fix cube docs
2024-07-12 09:25:45 -04:00