Guillaume Lagrange
dc45cf1700
Add `topk` tensor operation ( #1497 )
...
* Add topk and topk_with_indices
* Change topk_with_indices test to guarantee order (previously equal elements)
2024-03-22 10:57:20 -04:00
Louis Fortier-Dubois
dd699a90a2
Migrate/jit/matmul tiling 2d ( #1472 )
...
* refactor matmul files
* wip refactor matmul
* everything is memco
* support local arrays
* advancing tiling2d
* advancing tiling2d
* advancing tiling2d
* tiling2d finished but buggy
* configurable unrolling
* not bugged
* fails on unroll
* stupid break
* tiling2d no assumption works
* clippy
* bounds check as bool
* lhs rhs as enum
* tiling 2d major refactor
* remove assign vec4
* variable declarations above loops
* fmt
* clippy
* Fix autotune + unroll
* move val
* clippy
* fmt
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-03-22 08:26:32 -04:00
Sylvain Benner
0a8a3cc9e9
[xtask] Add support for cargo metadata new workspace member format ( #1500 )
2024-03-21 16:04:52 -04:00
Guillaume Lagrange
3e4af41694
Fix sort descending for 1d case ( #1494 )
2024-03-21 07:45:37 -04:00
Guillaume Lagrange
47a84cc980
Add tensor sorting operations ( #1488 )
...
* Add sort, sort_with_indices and argsort
* Fix wasm build
* Add sort ops autodiff
* Fix TODO parallel comment
* Fix sort_with_indices 1d and add descending options
* Fix clippy
* Fix ascending comment (configurable)
2024-03-20 14:51:04 -04:00
Guillaume Lagrange
430f642394
Change assert_approx_eq precision from 3 to 2 ( #1491 )
2024-03-19 12:26:21 -04:00
Rubén J.R
69f1877754
New learning rate schedulers ( #1481 )
2024-03-19 08:28:42 -05:00
carrotflakes
8911093b88
Add `flip` tensor operator ( #1468 )
2024-03-18 20:33:39 -05:00
Dilshod Tadjibaev
8a8300c1fb
Add tril_mask, triu_mask and diag_mask ops ( #1479 )
2024-03-18 10:15:40 -05:00
Louis Fortier-Dubois
c729401fb2
Remove unroll in shared reduce ( #1480 )
2024-03-17 13:56:54 -04:00
Arjun31415
d3af29c5b4
Missing `Debug` derive for Group Norm Config ( #1482 )
2024-03-17 13:12:50 -04:00
Louis Fortier-Dubois
cf3c1ca80a
Migrate/jit/cat ( #1457 )
2024-03-17 11:37:36 -04:00
Louis Fortier-Dubois
41d01b8e19
Migrate/jit/prod ( #1474 )
2024-03-15 18:29:30 -04:00
Arjun31415
4de1272344
Feat: Add Leaky Relu Model ( #1467 )
2024-03-14 10:53:40 -05:00
WorldSEnder
53eb3ecfa9
Implement Huber loss ( #1444 )
...
* Implement Huber loss
Instead of using a sign or abs function, uses clamping to compute
it outside the bounds. This is better for the autodiff backend.
* mention Huber loss in the book
* unify naming of residuals in comments
2024-03-13 12:55:46 -05:00
Dilshod Tadjibaev
7a98b2f663
Add prod and prod_dim tensor ops ( #1460 )
2024-03-12 14:00:02 -05:00
carrotflakes
80aac1dde4
Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 ( #1442 )
2024-03-12 13:08:20 -04:00
Kyle Chen
c52c49785d
Add linear learning rate scheduler ( #1443 )
2024-03-12 13:04:12 -04:00
Louis Fortier-Dubois
278fcb3dad
Migrate/jit/mask ( #1456 )
2024-03-12 12:43:05 -04:00
Louis Fortier-Dubois
02d37011ab
Fix/main/print ( #1459 )
2024-03-11 18:52:36 -04:00
Dilshod Tadjibaev
0138e16af6
Add Enum module support in PyTorchFileRecorder ( #1436 )
...
* Add Enum module support in PyTorchFileRecorder
Fixes #1431
* Fix wording/typos per PR feedback
2024-03-11 11:21:01 -05:00
Dilshod Tadjibaev
9d4fbc5a35
Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section ( #1449 )
...
* Update tensor.md
* Rename diagonal to eye
* Remove extra space per PR feedback
2024-03-11 11:00:36 -05:00
Louis Fortier-Dubois
093cbd397d
JIT Migration: PRNG ( #1433 )
...
* wip bernoulli
* wip
* bernoulli works
* uniform works
* done
* remove old
* refactor prng traits
* forgot to save file
* allow
* clippy
* clippy
* scalar commutativity
* array instead of vec
2024-03-11 11:40:27 -04:00
Dilshod Tadjibaev
3f7e6bd5bc
Add `sign` tensor operator ( #1446 )
2024-03-11 10:39:30 -05:00
github-actions[bot]
56f460295a
Combined PRs ( #1439 )
...
* Bump web-time from 1.0.0 to 1.1.0
Bumps [web-time](https://github.com/daxpedda/web-time ) from 1.0.0 to 1.1.0.
- [Release notes](https://github.com/daxpedda/web-time/releases )
- [Changelog](https://github.com/daxpedda/web-time/blob/main/CHANGELOG.md )
- [Commits](https://github.com/daxpedda/web-time/compare/v1.0.0...v1.1.0 )
---
updated-dependencies:
- dependency-name: web-time
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump rayon from 1.8.1 to 1.9.0
Bumps [rayon](https://github.com/rayon-rs/rayon ) from 1.8.1 to 1.9.0.
- [Changelog](https://github.com/rayon-rs/rayon/blob/main/RELEASES.md )
- [Commits](https://github.com/rayon-rs/rayon/compare/rayon-core-v1.8.1...rayon-core-v1.9.0 )
---
updated-dependencies:
- dependency-name: rayon
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump tempfile from 3.10.0 to 3.10.1
Bumps [tempfile](https://github.com/Stebalien/tempfile ) from 3.10.0 to 3.10.1.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md )
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.0...v3.10.1 )
---
updated-dependencies:
- dependency-name: tempfile
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump the cargo group group with 1 update
Bumps the cargo group group with 1 update: [mio](https://github.com/tokio-rs/mio ).
Updates `mio` from 0.8.10 to 0.8.11
- [Release notes](https://github.com/tokio-rs/mio/releases )
- [Changelog](https://github.com/tokio-rs/mio/blob/master/CHANGELOG.md )
- [Commits](https://github.com/tokio-rs/mio/compare/v0.8.10...v0.8.11 )
---
updated-dependencies:
- dependency-name: mio
dependency-type: indirect
dependency-group: cargo-security-group
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump log from 0.4.20 to 0.4.21
Bumps [log](https://github.com/rust-lang/log ) from 0.4.20 to 0.4.21.
- [Release notes](https://github.com/rust-lang/log/releases )
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md )
- [Commits](https://github.com/rust-lang/log/compare/0.4.20...0.4.21 )
---
updated-dependencies:
- dependency-name: log
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-03-08 19:27:48 -05:00
Nathaniel Simard
2de270fe0e
Fix tch view data corruption ( #1434 )
2024-03-08 09:55:47 -05:00
Louis Fortier-Dubois
61c0474172
JIT migration: contiguous kernel ( #1424 )
...
* JIT migration: contiguous kernel
* delete wgsl
2024-03-08 08:01:29 -05:00
Louis Fortier-Dubois
040cd55f85
JIT migration: cast kernel ( #1423 )
2024-03-07 17:49:09 -05:00
Louis Fortier-Dubois
9eecc713a4
JIT: Fix min & max values ( #1429 )
...
* real min and max values
* fix
* fmt
2024-03-07 15:10:30 -05:00
Dilshod Tadjibaev
c7d4c23f97
Support for non-contiguous indexes in PyTorchFileRecorder keys ( #1432 )
...
* Fix non-contiguous indexes
* Update pytorch-model.md
* Simplify multiple forwards
2024-03-07 13:40:57 -06:00
Dilshod Tadjibaev
b12646de0a
Truncate debug display for NestedValue ( #1428 )
...
* Truncate debug display for NestedValue
* Fix failing tests
2024-03-07 08:06:31 -05:00
Dilshod Tadjibaev
545444c02a
PyTorchFileRecord print debug option ( #1425 )
...
* Add debug print option to PyTorchFileRecorder
* Updated documentation and improved print output
* Improve print wording
* Updated per PR feedback
2024-03-06 16:11:37 -06:00
Nathaniel Simard
b429cc39c1
Splitted the JIT stuff from the Wgpu stuff ( #1417 )
2024-03-06 11:23:53 -05:00
jackdarlison
3ff6e7170e
Switched arguments in `reshape_args_usize` check ( #1409 )
2024-03-06 08:45:12 -05:00
Aasheesh Singh
0c92c8c8eb
Autodiff/training support for Nearest Interpolation ( #1414 )
...
Add training support for nearest interpolation
---------
Co-authored-by: yurzhang <yurzhang.oi@gmail.com>
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-06 00:12:05 -05:00
Dilshod Tadjibaev
0601dc778b
Update operation.rs ( #1418 )
2024-03-05 12:43:02 -06:00
Dilshod Tadjibaev
c7834e4658
Tensor `permute` operator ( #1410 )
2024-03-05 12:29:13 -05:00
Dilshod Tadjibaev
4ed90a988e
Add `bool()` op for numerical tensor ( #1402 )
...
Fixes #1395
2024-03-04 12:39:17 -06:00
Nathaniel Simard
efbe818465
Refactor wgpu max pooling ( #1398 )
2024-03-04 13:23:11 -05:00
Louis Fortier-Dubois
046d975b76
Migrate reduce dim + More JIT instructions + Major wgpu reduce dim refactor ( #1396 )
2024-03-04 10:48:52 -05:00
Guillaume Lagrange
16d7666611
Add `argwhere` and `nonzero` boolean tensor ops ( #1394 )
...
* Add argwhere and nonzero bool tensor ops
* Fix wasm build
* Add missing vec
* Fix wasm cfg placement
* Fix comment
2024-03-04 08:33:59 -05:00
yurzhang
7d44f0b2d7
Interpolate tensor operation (Inference Only) ( #1246 )
...
* squash
feat: bilinear interpolation for tch, ndarray and wgpu backend
fix: reduce test case size to avoid exceeding floating-point precision limits
feat: support nearest-neighbor interpolation for ndarray backend
feat: support nearest-neighbor interpolation for wgpu backend
feat: support fusion backend
fix: no-std support
build: upgrade dependencies
* feat: bicubic interpolation for ndarray backend
* fix: test case precision
* feat: bicubic interpolation for wgpu backend
* Update Cargo.lock
---------
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Co-authored-by: Aasheesh Singh <20820983+ashdtu@users.noreply.github.com>
2024-03-02 12:01:35 -06:00
Dilshod Tadjibaev
d43a0b3f90
Add is_close and all_close tensor operators ( #1389 )
...
* Add is_close and all_close tensor operators
* Fix broken build issues
* Fix the table
* Add tests to candle
2024-03-01 15:37:14 -06:00
Matt Thompson
201e7f87c9
Element wise min/max between a pair of tensors ( #1385 )
...
* Added min_pair() and max_pair() methods to numeric Tensors
* Update book with added max_pair() and min_pair() methods.
* Fix spelling typo in comments
* Update comments per change requests
2024-03-01 15:56:48 -05:00
Guillaume Lagrange
3d93e6dae9
Add `not_equal` and `not_equal_elem` tensor ops ( #1374 )
...
* Fix tensor.equal_elem usage in book
* Add not_equal and not_equal_elem tensor ops
* Fix "element-wise" usage for correctness and uniformity
* Add bool_not_equal test
2024-03-01 12:13:25 -06:00
Louis Fortier-Dubois
9a0b2d6e7e
bool cast kernel ( #1391 )
2024-03-01 12:52:28 -05:00
Dilshod Tadjibaev
688958ee74
Enhance PyTorchRecorder to pass top-level key to extract state_dict ( #1300 )
...
* Enhance PyTorchRecorder to pass top level key to extract state_dict
This is needed for Whisper weight pt files.
* Fix missing hyphens
* Move top-level-key test under crates
* Add sub-crates as members of workspace
* Update Cargo.lock
* Add accidentally omitted line during merge
2024-02-29 12:57:27 -06:00
Guillaume Lagrange
4efc683df4
Upgrade to candle 0.4.1 ( #1382 )
...
* Fix python main entrypoint in book example
* Remove candle windows safeguards (#1178 )
* Bump candle-core from 0.3.3 to 0.4.1
* Remove windows current known issue
2024-02-29 11:29:11 -06:00
Nathaniel Simard
40bf3927f0
Migrate wgsl index shaders to gpu representation ( #1378 )
2024-02-28 16:17:47 -05:00
Yu Sun
330552afb4
docs(book-&-examples): modify book and examples with new `prelude` module ( #1372 )
2024-02-28 13:25:25 -05:00
Nathaniel Simard
57887e7a47
Refactor/elemwise/kernel selection + dynamic fused inplace operations and broadcasting ( #1359 )
2024-02-27 08:41:31 -05:00
Nathaniel Simard
bdec8d5813
[Refactor - JIT] Gather Scatter new implementations ( #1356 )
2024-02-26 17:20:09 -05:00
Louis Fortier-Dubois
576bb44bc8
Feat/autodiff/checkpoint ops ( #1358 )
2024-02-26 17:19:09 -05:00
Mathias Insley
bb5e6faff2
Feat/autotune int ops ( #1136 )
...
* Add int_random to int tensor ops
* Int random for tch backend
* Int random for burn-fusion
* int random for autodiff
* Int random for candle backend
* Int random for ndarray backend
* Int random for wgpu backend
* Merge imports
* Typo
* Shader file for int uniform distribution
* Create AutotuneOperationSet and public int_sum_dim_autotune
* Adjust bounds to 0..10
* Create uniform_int_kernel, unit tests, use new kernel
* Reduction kernels for regular and shared memory sum_dim int operations
* Macro that accomadates wgpu IntElement
* Add autotuning to int_mean_dim
* Use correct macro for Int autotuning
* Add int_mean_dim_shared_memory
* Add int_mean_dim and unit test
* Create autotunables for mean_dim
* Run fmt
* Remove comment
* Finish resolving merge conflict, fix doc
* Make the element trait bound a parameter to reduce_tune_ops macro
* Update book
* Fix requested change
* Change range to [0, 255] and update test accordingly
* Forgot to include candle in last commit
* Fix comment
* Use correct int autotune for mean dim
* Fix typo- not sure how this passed earlier
* Resolve syntax issues from merge
* Fix cast_float
* Saving here
* Continue fixing merge conflicts, all tests pass locally
* Run fmt
* Change cast_float to cast_u32_to_float
* Make uniform_int_inner_loop safer
* Be even more explicit about u32 casts
* Skip an intermediate step and cast directly to u32
* Replace JitElement + Element with IntElement
* Run fmt
* This should fix the CI
* This time for sure
2024-02-26 14:53:21 -05:00
Joshua Ferguson
706e0ebce2
Parser rewrite ( #1296 )
...
* Running into issues with identity nodes
* Vec<RefCell<Node>> seems to work for this
* back to passing tests
* Reworked IO into separate struct
* working towards exploiting topological ordering and more informative ident errors
* the passing of an initializer to coalesce is temporary
* cleaning up dead code
* handled unsqueeze
* reworked node initialization and dim inference
* mainly cleanup
* changed how io use is tracked, moved unsqueeze remapping out of dim inference
* `cargo xtask run-checks all` now passes
* added a fixme and a few doc strings
* removing println and dead code
* spaces in doc strings
* altered top sort to work on node proto, moved prior to node gen
* Update ir.rs
* Update from_onnx.rs
removed dead code
* updated doc string
* camalcased Onnx Graph Builder
* removed self import?
2024-02-24 10:51:58 -06:00
Arjun31415
8e23057c6b
Feature Addition: PRelu Module ( #1328 )
2024-02-24 10:24:22 -05:00
Yu Sun
1da47c9bf1
feat: add prelude module for convenience ( #1335 )
2024-02-24 10:17:30 -05:00
Tushushu
27f2095bcd
Implement Instance Normalization ( #1321 )
...
* config
* rename as instances, otherwise won't work
* refactor
* InstanceNormConfig
* remove unused var
* forward
* rename
* based on gn
* unit tests
* fix tests
* update doc
* update onnx doc
* renaming method
* add comment
---------
Co-authored-by: VungleTienan <tienan.liu@vungle.com>
2024-02-23 23:31:43 -06:00
Guillaume Lagrange
f5bd2a474f
Check that pa_type is valid before checking if is_binary ( #1354 )
2024-02-23 09:44:03 -06:00
Dilshod Tadjibaev
08302e38fc
Fix broken test and run-checks script ( #1347 )
2024-02-23 10:06:51 -05:00
Aasheesh Singh
c86db83fa9
Add support for Any, All operations to Tensor ( #1342 )
...
* add any, all op implementation for all tensor types
* add op to burn-book
* fix formatting
* refactor tensor operations from numeric to BaseOps.
* fix book doc
* comments fix and add more tests
2024-02-23 10:06:31 -05:00
Dilshod Tadjibaev
d6e859330f
Pytorch message updates ( #1344 )
...
* Update pytorch-model.md
* Update error.rs
2024-02-22 12:12:50 -06:00
Nathaniel Simard
b256c0404e
Refactor/wgpu/memco ( #1340 )
2024-02-22 07:59:54 -05:00
Guillaume Lagrange
bff4961426
Add enum module support ( #1337 )
2024-02-21 17:03:34 -05:00
Sylvain Benner
4427768570
[refactor] Move burn crates to their own crates directory ( #1336 )
2024-02-20 13:57:55 -05:00