Jonas Kantic
fba1e27e0c
Remainder operator ( #1726 )
...
* Adds remainder ops implementation for Tensor.
* Adds test for % operator.
2024-06-01 16:47:02 -05:00
jachym.putta
99e1ba4864
feat: expand onnx import ( #1813 )
...
* feat: added expand to import
2024-05-31 16:48:02 -05:00
jachym.putta
44f1053219
feat: added range onnx import ( #1834 )
...
* feat: added range onnx import
* fix: range input types
2024-05-31 16:40:54 -05:00
Nathaniel Simard
36d4bcd705
[Refactor - Breaking] Refactor cube operations with better names & Support subgroup operations ( #1839 )
2024-05-31 17:07:21 -04:00
will-maclean
13a6f84bc3
Feature/onnx argmax ( #1814 )
...
* pre-test
* implementing argmax for burn-import from onnx
* tidying
* fixing return types and tests
* addressing feedback
* only warn when select_last_index!=0
2024-05-31 14:46:09 -04:00
Louis Fortier-Dubois
de0b49e4a3
Cube: Topology constants ( #1838 )
...
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-30 12:03:30 -04:00
Louis Fortier-Dubois
61c9fdbbc8
Cube: cleaner use of topology values ( #1835 )
...
* constant keyword parsing
* works
2024-05-29 09:08:10 -04:00
McArthur
a2ad424fc8
Indices Operator ( #1735 )
2024-05-29 09:05:31 -04:00
Louis Fortier-Dubois
cacc764205
Cube: support for shared memory ( #1831 )
2024-05-29 08:22:04 -04:00
Guillaume Lagrange
e4836241e1
Fix `DataSerialize` conversion for elements of the same type ( #1832 )
2024-05-28 18:12:44 -04:00
Louis Fortier-Dubois
e61b026918
Cube: support method call + prettier tensor metadata ( #1829 )
2024-05-27 15:18:17 -04:00
Nathaniel Simard
fd54a8b470
Add vectorization support into cube ( #1830 )
2024-05-27 14:21:29 -04:00
Louis Fortier-Dubois
dc85daa1c6
Cube: support for return + conv2d early return ( #1828 )
2024-05-27 13:19:00 -04:00
Nathaniel Simard
15d2055de8
Feat/cube/launch ( #1827 )
2024-05-27 12:15:06 -04:00
Adrian Müller
cccd96de48
Feat: Implement ONNX RandomUniform + RandomNormal in burn-import ( #1806 )
2024-05-27 10:07:04 -04:00
github-actions[bot]
85ba167582
Combined PRs ( #1823 )
...
* Bump cudarc from 0.10.0 to 0.11.0
Bumps [cudarc](https://github.com/coreylowman/cudarc ) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/coreylowman/cudarc/releases )
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.11.0 )
---
updated-dependencies:
- dependency-name: cudarc
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump libc from 0.2.154 to 0.2.155
Bumps [libc](https://github.com/rust-lang/libc ) from 0.2.154 to 0.2.155.
- [Release notes](https://github.com/rust-lang/libc/releases )
- [Commits](https://github.com/rust-lang/libc/compare/0.2.154...0.2.155 )
---
updated-dependencies:
- dependency-name: libc
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump ratatui from 0.26.2 to 0.26.3
Bumps [ratatui](https://github.com/ratatui-org/ratatui ) from 0.26.2 to 0.26.3.
- [Release notes](https://github.com/ratatui-org/ratatui/releases )
- [Changelog](https://github.com/ratatui-org/ratatui/blob/main/CHANGELOG.md )
- [Commits](https://github.com/ratatui-org/ratatui/compare/v0.26.2...v0.26.3 )
---
updated-dependencies:
- dependency-name: ratatui
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump bytemuck from 1.15.0 to 1.16.0
Bumps [bytemuck](https://github.com/Lokathor/bytemuck ) from 1.15.0 to 1.16.0.
- [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md )
- [Commits](https://github.com/Lokathor/bytemuck/compare/v1.15.0...v1.16.0 )
---
updated-dependencies:
- dependency-name: bytemuck
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump proc-macro2 from 1.0.83 to 1.0.84
Bumps [proc-macro2](https://github.com/dtolnay/proc-macro2 ) from 1.0.83 to 1.0.84.
- [Release notes](https://github.com/dtolnay/proc-macro2/releases )
- [Commits](https://github.com/dtolnay/proc-macro2/compare/1.0.83...1.0.84 )
---
updated-dependencies:
- dependency-name: proc-macro2
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-05-27 10:04:27 -04:00
Ikko Eltociear Ashimine
1c5e65ab26
docs: update README.md ( #1810 )
...
minor fix
2024-05-27 09:21:03 -04:00
Nathaniel Simard
c7ad25ab60
Update cuda-jit ( #1799 )
2024-05-24 11:31:47 -04:00
Louis Fortier-Dubois
23c622a9f8
Feat/cube/remaining ops ( #1807 )
2024-05-24 09:48:34 -04:00
Justin Restivo
1670a71711
Fix burn-jit compile error ( #1803 )
2024-05-23 20:24:42 -04:00
jachym.putta
ef4646c90f
feat: Greater + GreaterOrEqual onnx import ( #1801 )
2024-05-23 08:59:15 -04:00
jachym.putta
1f31e20ce8
feat: Less + LessOrEqual onnx import ( #1800 )
2024-05-23 08:04:44 -04:00
Louis Fortier-Dubois
e39b4d2da0
refactor reduce into separate traits ( #1798 )
2024-05-22 16:01:27 -04:00
Louis Fortier-Dubois
033171920c
Cube: first ported kernel + comptime support + variable reuse + cleanup ( #1797 )
2024-05-22 14:08:21 -04:00
Guillaume Lagrange
b466fd7606
Add seq start position when applying RoPE encoding ( #1796 )
2024-05-22 13:18:31 -04:00
jachym.putta
0918cf00c6
feat: added min onnx import ( #1778 )
2024-05-22 10:52:19 -04:00
Guillaume Lagrange
550086a5c1
Fix record nested value de/serialization ( #1751 )
2024-05-22 09:15:32 -04:00
Louis Fortier-Dubois
6137d42c10
fix prng bug during autotune ( #1791 )
2024-05-22 09:11:13 -04:00
jachym.putta
8c01444fc5
Adding max import ( #1769 )
...
* feat: add max import
* feat: implement the right max operation (hopefully)
2024-05-22 08:31:55 -04:00
Mathias Insley
81ecd14f83
Feat/squeeze dims ( #1779 )
2024-05-22 07:53:51 -04:00
Louis Fortier-Dubois
76fe0ed881
Refactor/cube/vectorization ( #1781 )
2024-05-19 13:20:55 -04:00
Louis Fortier-Dubois
499ff0dd26
Feat/enable cube cl ( #1777 )
...
* Ben WIP
* Compile burn-jit
* WGPU works
* Remove old code
* move language cube stuff
* cleaning up
* some import reworking
* remove cube reexport
* template feature flag in cube
* ci
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-05-19 10:55:04 -04:00
Mathias Insley
9c5b07c833
Squeeze Onnx Import ( #1753 )
2024-05-17 12:00:34 -04:00
Jonathan Richard
8de05e1419
Add configurable application logger to learner builder ( #1774 )
...
* refactor: add TracingSubscriberLogger trait and FileTracingSubscriberLogger struct
* Remove unused log module and renames, fmt
* Renamed tracing subscriber logger
* renamed to application logger installer
* book learner configuration update
* fix typo
* unused import
2024-05-16 16:25:33 -04:00
Nathaniel Simard
7ab2ba1809
Feat/cubecl ir ( #1776 )
...
---------
Co-authored-by: louisfd <louisfd94@gmail.com>
2024-05-16 15:08:53 -04:00
Louis Fortier-Dubois
542790e17e
CubeCL first iteration ( #1756 )
2024-05-15 10:24:37 -04:00
getumen
e823338750
Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers ( #1770 )
2024-05-15 09:18:09 -04:00
Ben Barber
d3cd6c4928
Replace opaque return types in optim ( #1767 )
...
* update ARCHITECTURE.md links to project architecture section in contributor book
* replace opaque return type in optim
2024-05-13 22:21:20 -04:00
Nathaniel Simard
9dcec0b998
Refactor/jit fusion ( #1750 )
...
* Reads & Writes with index_ref
* WIP
* Fix operations
* Cleanup
2024-05-13 12:48:23 -04:00
Ahmed Yarub Hani Al Nuaimi
10737527d8
#1747 Upgrade Rust dependencies ( #1748 )
...
* #1747
Upgrade Rust dependencies
* Revert upgrade for tch
The update of tch on windows gives an error:
INTEL MKL ERROR: The specified module could not be found. mkl_vml_avx2.1.dll.
Intel MKL FATAL ERROR: cannot load mkl_vml_avx2.1.dll or mkl_vml_def.1.dll.
* Keep only .cargo/config.toml file which works with rust > 1.75
---------
Co-authored-by: Sylvain Benner <sylvain@benner.online>
2024-05-10 16:25:19 -04:00
Thierry Cantin-Demers
b09d8431df
Fix Cargo.toml repository links ( #1749 )
...
* Fix wgpu github link
* Fix burn-train repo link
* Fix burn-tensor github repo
* Fix burn-tensor repo link
* Fix remaining repo links in crates Cargo.toml
---------
Co-authored-by: Jonathan Richard <47578360+jwric@users.noreply.github.com>
2024-05-09 15:40:05 -04:00
Arjun31415
5bbc5ea944
Added ONNX AvgPool1d ( #1744 )
2024-05-07 16:10:18 -05:00
Nathaniel Simard
a6e3b4e81e
Fix select assign backward ( #1739 )
2024-05-07 11:37:43 -04:00
Sébastien Boisvert
bd06b38fac
Refactor: replace trait TemplateKernel by existing trait JitKernel ( #1737 )
...
* Refactor: replace trait TemplateKernel by existing trait JitKernel
* Refactor: implement trait JitKernel for struct Kernel
2024-05-06 20:59:00 -04:00
Arjun31415
7f94f4c219
Add MaxPool1d ONNX Op ( #1725 )
2024-05-06 10:51:00 -05:00
Anton Blomström
fb13503fa9
Add reduce sum onnx ops to burn imports ( #1723 )
2024-05-06 10:49:17 -05:00
Anton Blomström
f8994e044c
Fix unstable tests when run concurrently ( #1724 )
2024-05-05 15:27:42 -05:00
Arjun31415
152509c378
PReLu ONNX import ( #1721 )
...
* added prelu onnx operator
* bug fix
* added onnx tests and burn codegen tests
* fix tests
* added prelu to supported onnx ops and add prelu to dim_inference
2024-05-04 13:45:42 -05:00
Louis Fortier-Dubois
a8661a2f53
Autodiff Memory Management: BFS ( #1710 )
2024-05-03 09:45:21 -04:00
Nathaniel Simard
5d959e2884
[Fusion] Support multi-precision fusion ( #1718 )
2024-05-02 18:22:56 -04:00
Louis Fortier-Dubois
2e4c82fa64
Fix repeat for dims > 1 ( #1713 )
2024-05-01 09:11:38 -04:00
Dilshod Tadjibaev
3a02a54e55
Update SUPPORTED-ONNX-OPS.md ( #1717 )
...
gather ONNX was checked off but actually GatherElements should have been updated.
2024-05-01 08:02:59 -04:00
Dilshod Tadjibaev
ff9e875321
ONNX debug improvements ( #1712 )
...
* Minor debug improvements
* Change warn to panic
* Log improvements
2024-04-30 16:36:55 -05:00
Nathaniel Simard
587b8f80b3
First draft CUDA runtime ( #1685 )
...
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
Jonathan Merritt
ab501431b1
Handle ndarray matmul broadcasting ( #1679 )
...
* Handle ndarray matmul broadcasting
- Use strides to map linear batch indices from
the output back to the input arrays.
* Fix typos
2024-04-29 17:25:27 -05:00
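The stride trick described in that commit body can be sketched outside of ndarray. This is an illustrative example, not burn's code: a broadcast dimension gets stride 0, so decomposing the output's linear batch index into a multi-index and dotting it with each input's strides maps it back to the right input batch.

```rust
// Illustrative sketch (not burn's implementation): row-major strides for
// `shape`, zeroed wherever the dimension is broadcast against `broadcast_shape`.
fn broadcast_strides(shape: &[usize], broadcast_shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut acc = 1;
    for i in (0..shape.len()).rev() {
        if shape[i] == broadcast_shape[i] {
            strides[i] = acc;
        } // else: size-1 dim expanded by broadcasting, stride stays 0
        acc *= shape[i];
    }
    strides
}

// Decompose the output's linear batch index into a multi-index over
// `broadcast_shape`, then dot with the input's (possibly zeroed) strides.
fn map_batch_index(linear: usize, broadcast_shape: &[usize], in_strides: &[usize]) -> usize {
    let mut rem = linear;
    let mut offset = 0;
    for i in (0..broadcast_shape.len()).rev() {
        let idx = rem % broadcast_shape[i];
        rem /= broadcast_shape[i];
        offset += idx * in_strides[i];
    }
    offset
}

fn main() {
    // Input batch shape [1, 3] broadcast against output batch shape [2, 3]:
    let strides = broadcast_strides(&[1, 3], &[2, 3]);
    assert_eq!(strides, vec![0, 1]);
    // The 6 output batches wrap back onto the input's 3 batches.
    let mapped: Vec<usize> = (0..6).map(|i| map_batch_index(i, &[2, 3], &strides)).collect();
    assert_eq!(mapped, vec![0, 1, 2, 0, 1, 2]);
}
```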
Dilshod Tadjibaev
1cdceb590f
Skip updating shape for linear if not present ( #1700 )
2024-04-29 14:53:18 -05:00
WU Chen
b387829731
Implement bidirectional LSTM ( #1035 )
...
* resolve conflict
* move `gate_product` to `GateController`
* BiLstm needs to use its own initializer when init
* resolve conflicts
* add some comments
* improve doc
* correct the description of GateController
* fix fmt
* add `LstmState`
* add test for state
* set batch 2 in bilstm test
* resolve conflict
* fix
* fix doc
* change the batch size back to 1
* change the batch size back to 1
* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Louis Fortier-Dubois
6ae3926006
New autodiff graph memory management strategy ( #1698 )
...
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-26 12:25:53 -04:00
Nathaniel Simard
2f294c5092
Fix lstm batch size bug ( #1695 )
2024-04-26 08:54:12 -04:00
Guillaume Lagrange
ce2429eb10
Refactor element type to be decoupled from runtime ( #1693 )
2024-04-26 08:53:55 -04:00
Dilshod Tadjibaev
67ec06d5d8
ONNX support for scalar unsqueeze ( #1690 )
...
* Revert 1c639c8393
* Refactor by @laggui
* Refactor unsqueeze
* Add support for scalar unsqueeze
* Removed dead comment
2024-04-25 16:05:28 -05:00
Nathaniel Simard
599a20d586
Upgrade wgpu ( #1692 )
2024-04-25 16:32:50 -04:00
Dilshod Tadjibaev
a1bd14c5ae
Reshape bug fix ( #1684 )
...
* Revert 1c639c8393
* Refactor by @laggui
* Refactor unsqueeze
2024-04-24 19:31:53 -05:00
Nathaniel Simard
886a1de235
Refactor/burn compute ( #1580 )
2024-04-23 13:05:15 -04:00
Sylvain Benner
c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor ( #1654 )
...
* Move HandleContainer and Tensor Ops descriptions to burn-tensor
Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.
For now added modules to burn-tensor are excluded from no-std as they rely on Arc.
* [burn-tensor] Flatten module hierarchy for tensor representation
+ Add new repr feature to cargo file.
* Remove prefix on docstring
* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Guillaume Lagrange
e6b1b7a317
Add layer norm onnx op support ( #1680 )
2024-04-23 11:19:07 -04:00
Dilshod Tadjibaev
1718da5210
Fix reshape bug (support for opset version 1) ( #1667 )
...
* Make reshape op version 1
* Refactor per PR feedback
2024-04-22 17:52:25 -05:00
Nathaniel Simard
29fa2ee76c
Support linear 1d ( #1682 )
2024-04-22 18:39:09 -04:00
Alex Errant
d62b344d5b
`Arc<EventStoreClient>` to `Rc<EventStoreClient>` ( #1668 )
2024-04-22 18:21:53 -04:00
신희제(Heejae Shin/joel.barish)
2a7b296a1b
Add sign ONNX op import support ( #1663 )
...
* Add sign ONNX op support
* Update SUPPORTED-ONNX-OPS.md
2024-04-22 09:10:50 -04:00
Louis Fortier-Dubois
2140d9b568
remove JIT subsequent RNG tests ( #1652 )
2024-04-21 09:48:11 -04:00
Dilshod Tadjibaev
1433284a0f
Fix bug 1645 (Unsqueeze OpSet 11) ( #1661 )
...
* Add unsqueeze opset 16 test
* Fix for unsqueeze opset 11
* Remove println statement
2024-04-19 14:17:44 -05:00
Guillaume Lagrange
b65a487300
Fix transpose onnx op (permute) ( #1657 )
2024-04-19 09:34:03 -04:00
Nico Zweifel
ee12aee2e7
fix: `window` -> `pub window` in `dataset/mod.rs` ( #1658 )
...
* update dataset/mod.rs
* Update mod.rs
* Update window.rs
2024-04-19 09:33:21 -04:00
Guillaume Lagrange
9fbcbed20f
Add where onnx op support ( #1653 )
...
* Add where onnx op support
* Add broadcasting support
* Remove broadcasting limitation comment
* Fix broadcasting in mask where
* Forgot to reflect changes in codegen test
* Fix clippy
2024-04-18 15:46:02 -04:00
Guillaume Lagrange
7705fd9c25
Add matmul ONNX op support ( #1638 )
...
* Mul onnx op already supported
* Add matmul onnx op checks and tests
* Add missing eq derives
* Change superscript symbol
* Remove dead code
* Add support for matmul broadcast
* No more broadcasting restrictions
* Add results comment for mm, mv and vm
2024-04-18 09:20:31 -04:00
Dilshod Tadjibaev
2a721a9d0c
Enable native sign operation for Candle backend ( #1647 )
...
* Enable native sign operation for Candle backend
* Use fixed revision
2024-04-17 09:07:56 -04:00
Guillaume Lagrange
424033283a
Add reduce max ONNX op support ( #1636 )
...
* Add reduce max onnx op support
* Fix comments on tensor rank 1 result
2024-04-17 08:26:46 -04:00
Nico Zweifel
5a3f345734
WindowDataset/windows function ( #1553 )
2024-04-17 07:51:53 -04:00
Guillaume Lagrange
35b36bbe62
Add shape ONNX op support ( #1639 )
...
* Add shape onnx op support
* Remove cast node from onnx graph
* Fix shape implementation
* Fix shape config error message
* Fix typo
* Fix clippy type complexity for generated code
2024-04-16 09:28:21 -04:00
Guillaume Lagrange
6d96e8d808
[ONNX] Add not op and extend cast support to tensors ( #1634 )
...
* Add not onnx op support
* Extend cast onnx support to tensors
* Fix clippy
2024-04-16 08:45:25 -04:00
Mathias Insley
7377bbe31c
Feat/remainder ( #1597 )
...
* Add remainder_scalar op to numeric trait and associated int/float functions
* Update burn-tch crate
* Update ndarray crate
* Update jit crate
* Update candle crate
* Update fusion crate
* Update autodiff crate
* Forgot float.rs for fusion
* Add burn-tensor tests
* Redirect to the pre-existing modulus op
* Fix sign
* Remove mut from burn-tch
* Use sign trick to make wgpu backend work
* Add more unit tests in to cover bases
* Naming fix for burn-fusion
* Update tests w/PyTorch link
* Use different WGSL instructions for remainder
* Redirect to remainder Operator instead of modulo
* Revert Modulo in instruction.rs
2024-04-16 08:35:20 -04:00
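The "sign trick" and the PyTorch-linked tests mentioned above come down to floored-remainder semantics. A minimal sketch of those semantics (illustrative only, not burn's implementation): Rust's native `%` truncates toward zero, so the result follows the dividend's sign, while a PyTorch-style `remainder` follows the divisor's sign and can be recovered without `floor` via the "add the divisor, take `%` again" trick.

```rust
// Floored remainder, equivalent to a - b * (a / b).floor(), written with the
// sign trick usable on backends that only expose a truncated modulo.
fn remainder(a: f64, b: f64) -> f64 {
    ((a % b) + b) % b
}

fn main() {
    assert_eq!(-7.0_f64 % 3.0, -1.0);      // truncated: sign follows the dividend
    assert_eq!(remainder(-7.0, 3.0), 2.0); // floored: sign follows the divisor
    assert_eq!(remainder(7.0, 3.0), 1.0);  // both agree for same-sign operands
}
```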
Mathias Insley
48c61ebb81
Docs/update contributor book ( #1622 )
...
* Update links to latest commit off main
* Some pedantry
* Update links and add jit
* Update instructions for burn-jit and wgpu
* Updated import section with more recent links
* Some grammar/typo/styling fixes
* Code added to burn-wgpu too
2024-04-16 08:33:59 -04:00
Guillaume Lagrange
d5f20e2711
Add reduce mean ONNX op support ( #1637 )
...
* Add reduce mean onnx op support
* Fix comment
2024-04-16 07:59:35 -04:00
Dilshod Tadjibaev
340a12463a
Update SUPPORTED-ONNX-OPS.md ( #1641 )
2024-04-16 07:52:15 -04:00
Guillaume Lagrange
81a67b6a09
Add sin onnx op support ( #1633 )
2024-04-15 15:28:16 -04:00
Sylvain Benner
e303e31c8b
Bump next version of Burn to 0.14.0 ( #1618 )
2024-04-12 17:14:45 -04:00
Guillaume Lagrange
cf7b279e5e
Fix burn README symlink ( #1617 )
2024-04-12 16:00:47 -04:00
Guillaume Lagrange
9980db440d
Remove unused assets ( #1616 )
2024-04-12 15:48:16 -04:00
Guillaume Lagrange
264c167c11
Update licenses symlinks ( #1613 )
2024-04-12 14:43:58 -04:00
Nathaniel Simard
ff844b1667
Fix candle backend sync ( #1579 )
...
* Fix candle backend sync
* tch mps sync
* clippy
---------
Co-authored-by: louisfd <louisfd94@gmail.com>
2024-04-12 12:15:50 -04:00
Aasheesh Singh
fb1da53a38
Support for rotary positional encoding in transformer modules ( #1604 )
...
* add rotary positional encoding to transformer modules.
* fix f64 error
* use num_traits
* add panic condition
2024-04-12 11:45:49 -04:00
Louis Fortier-Dubois
23210f05f2
JIT: Autotune matmul tiling 2d unroll ( #1601 )
...
* autotune tiling 2d unroll
* clippy
* forgotten important stuff
2024-04-12 10:15:21 -04:00
Nathaniel Simard
07a61a1cec
Fix autodiff memory management graph cleaning ( #1602 )
2024-04-11 16:21:00 -04:00
Guillaume Lagrange
0cbe9a927d
Add learner training report summary ( #1591 )
...
* Add training report summary
* Fix LossMetric batch size state
* Add NumericEntry de/serialize
* Fix clippy suggestion
* Compact recorder does not use compression (anymore)
* Add learner summary expected results tests
* Add summary to learner builder and automatically display in fit
- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
2024-04-11 12:32:25 -04:00
Louis Fortier-Dubois
bdb62fbcd0
Repeat ops autodiff & fusion + fix autodiff ones & zeros ( #1600 )
...
* added repeat to autodiff and fusion + zero one backend init in autodiff
* autodiff for repeat
2024-04-11 11:32:45 -04:00
Dilshod Tadjibaev
2f885480ed
Use num-traits for float ops ( #1584 )
2024-04-08 10:16:20 -05:00
Guillaume Lagrange
f3e0aa6689
Add multi-label classification dataset and metric ( #1572 )
...
* Add multilabel classification dataset
- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods
* Fix custom-image-classification example deps
* Add image_folder_dataset_multilabel test
* Do not change class names order when provided
* Add hamming score and multi-label classification output
* Add new_classification_with_items test
* Fix clippy suggestions
* Implement default trait for hamming score
* Remove de/serialization and use AnnotationRaw as type
* Fix clippy
* Fix metric backend phantom data
2024-04-05 13:16:46 -04:00
Louis Fortier-Dubois
f5159b6d22
Refactor: split JitKernel and SourceKernel ( #1569 )
...
* refactor execute_dynamic into Execution
* minor change
* extension cfg
* jitkernel and sourcekernel
* add todo statement
* cleanup and docs
* update book
* fix server dependency on compiler
* refactor into shader information
* refactor to compile shader once
* clippy
* clippy
* clippy
* fix doc
* fix doc
* fmt
* rename feature flag
* refactor
* All broked
* compile at the right time
* todo done
* all dynamic
* all dynamic in template too
* fmt
* fix ci
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-05 12:58:10 -04:00
Nathaniel Simard
1239d9bfa3
[Breaking] Make Tensor, Module, Optimizer !Sync + Refactor Autodiff ( #1575 )
2024-04-04 16:01:17 -04:00
Guillaume Lagrange
ce898ff899
Fix pytorch recorder adapt_linear when using autodiff backend ( #1576 )
...
* Fix pytorch recorder adapt_linear when using autodiff backend
* Fix comment typo
2024-04-04 12:29:24 -04:00
Guillaume Lagrange
0978c8a586
Support multilabel binary cross entropy ( #1571 )
...
* Support multilabel binary cross entropy
* Add missing alloc Vec
2024-04-03 08:03:07 -04:00
Nathaniel Simard
b0c5986d16
Feat/lazy init ( #1539 )
2024-04-02 10:13:35 -04:00
Guillaume Lagrange
8d210a152f
Move log_sigmoid to activation ops ( #1558 )
2024-04-02 09:25:40 -04:00
Louis Fortier-Dubois
edcd92f13d
Refactor execute_dynamic with Execution struct ( #1550 )
2024-03-28 17:27:48 -04:00
Nathaniel Simard
efc3b2d243
[Breaking] add runtime options in wgpu init methods ( #1505 )
2024-03-28 12:44:38 -04:00
Louis Fortier-Dubois
279be0496a
Conv Transpose: migration + decent speedup ( #1541 )
...
* convtranspose benchmark
* adjust bench
* conv transpose works
* Conv Transpose: migration + decent speedup
* delete template folder
* typos
* fix
2024-03-28 12:13:06 -04:00
Guillaume Lagrange
b8fc3f141e
Numerically stable log_sigmoid ( #1548 )
2024-03-28 11:54:23 -04:00
Dilshod Tadjibaev
70b92cb2fb
Update SUPPORTED-ONNX-OPS.md ( #1547 )
2024-03-28 10:38:53 -05:00
Karsten Becker
c21d5a3207
Add LeakyReLu implementation ( #1208 )
...
* Implement LeakyReLu
* Cargo fmt
* Apply suggestions
* cargo fmt
* Use float_mul_scalar
* Should be grad
* Add to books module
* Move test files
* Update leaky relu to use activation function
* Update tensor.md
* Fix failing test due to approx
* Add back the function comment
* Fix comment per PR feedback
---------
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-27 13:57:51 -05:00
jcmullwh
626457e1c6
Provide Tensor Padding Helpers #960 ( #1097 )
...
* Initial padding approach
Create padding implementation for the last two dimensions of Float and Int Tensors.
Create PadMode Enum, allowing Constant padding.
Create Padding Struct with Uniform, Asymmetric, height, and width implementations.
Create tests for the padding implementation.
* Update padding.rs
remove unneeded import
* Update from Merge
Use crate Element
Swap from old from_data() to new from_data_devauto()
* Formatting Changes
Formatting changes from cargo fmt --all
* Additional Format Change
One more format change that cargo fmt didn't get the first time.
* Changes to Example
Modify Example to ensure it works.
* modify naming
better names for impl / input variables.
* Modify API
- Change Padding to PadSize.
- integrate padding value into PadMode.
- update tests and examples.
* Comments and print
Improve comments+naming and remove println
* Pad Fixes
Moved pad to numeric
Simplified PadMode Element
updated tensor creations
fixed doc example
* Fix test location
* Simplified pad API
* Fix for failed unit tests
* Remove bool_full
* Rename `pads` to `padding`
---------
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-27 12:46:55 -05:00
Nathaniel Simard
40a26bd2ea
Feat/backend bridge ( #1529 )
2024-03-26 19:24:45 -04:00
Louis Fortier-Dubois
5bac300688
Migrate/jit/interpolate ( #1528 )
...
* separate forward backward
* refactor with pool strategy
* refactor further
* pooling refactored
* refactoring for adaptive wip
* wip adaptive
* adaptive
* delete some wgsl
* avg pool backward
* clippy
* refactor interpolate files
* nearest shader
* nearest
* some boilerplate
* wip
* bilinear
* nearest backward
* cubic
* cleanup
* minor refactor
* add some white space
2024-03-26 08:57:26 -04:00
Louis Fortier-Dubois
37b61ea646
Migrate/jit/adaptive avg pool backward ( #1530 )
...
* separate forward backward
* refactor with pool strategy
* refactor further
* pooling refactored
* refactoring for adaptive wip
* wip adaptive
* adaptive
* delete some wgsl
* avg pool backward
* clippy
* minor refactor
* works
* delete wgsl
2024-03-26 08:38:06 -04:00
Aasheesh Singh
a77979e0b6
add rms norm layer ( #1527 )
2024-03-25 18:59:11 -04:00
Louis Fortier-Dubois
da5b0438ec
Migrate/jit/pooling ( #1509 )
...
* separate forward backward
* refactor with pool strategy
* refactor further
* pooling refactored
* refactoring for adaptive wip
* wip adaptive
* adaptive
* delete some wgsl
* avg pool backward
* clippy
* minor refactor
2024-03-25 16:04:58 -04:00
Aasheesh Singh
613e698007
Feat/swiglu ( #1507 )
2024-03-25 15:55:27 -04:00
Louis Fortier-Dubois
4542ceddca
Migrate/jit/conv2d ( #1493 )
...
* conv2d but bug
* convolution done
* minor clean
* delete wgsl
2024-03-25 10:45:40 -04:00
Sylvain Benner
0adda72316
[backend-comparison] Add system information to benchmark results ( #1495 )
...
* Bump sysinfo crate to 0.30.7
* [backend-comparison] Add CPUs and GPUs system info to results
* [backend-comparison] Add integrated GPUs to gathered system info
* [backend-comparison] Use AutoGraphicsApi wgpu backend selection
2024-03-22 23:24:49 -04:00
Dilshod Tadjibaev
6feda90a8c
Tensor expand operator ( #1508 )
...
* Improve CI cache - remove burn-tch artifacts
* PyTorch config deserializer from .pt file
* Update pytorch-model.md
* WIP
* Rename broadcast_to to expand
* Rename broadcast_to expand file
* Implemented fusion backend and fix bugs
* Remove old files
* Remove unused state
* Rename to the correct op name
* Add missing comment
* Fix expand check function doc
* Rename the leftover names
* Rename leftover names
2024-03-22 16:33:53 -05:00
Guillaume Lagrange
dc45cf1700
Add `topk` tensor operation ( #1497 )
...
* Add topk and topk_with_indices
* Change topk_with_indices test to guarantee order (previously equal elements)
2024-03-22 10:57:20 -04:00
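The second bullet above concerns ordering guarantees when elements are equal. A small sketch of `topk_with_indices` semantics on a 1-D slice (illustrative only, not burn's implementation): sort index/value pairs by value descending, break ties by index so the output order is deterministic, then truncate to `k`.

```rust
// Illustrative topk_with_indices on a 1-D slice: ties between equal values
// are resolved by the smaller index, giving a guaranteed, testable order.
fn topk_with_indices(values: &[f32], k: usize) -> (Vec<f32>, Vec<usize>) {
    let mut pairs: Vec<(usize, f32)> = values.iter().copied().enumerate().collect();
    pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap().then(a.0.cmp(&b.0)));
    pairs.truncate(k);
    (
        pairs.iter().map(|p| p.1).collect(),
        pairs.iter().map(|p| p.0).collect(),
    )
}

fn main() {
    let (vals, idx) = topk_with_indices(&[1.0, 3.0, 3.0, 2.0], 2);
    assert_eq!(vals, vec![3.0, 3.0]);
    assert_eq!(idx, vec![1, 2]); // equal values: lower index wins
}
```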
Louis Fortier-Dubois
dd699a90a2
Migrate/jit/matmul tiling 2d ( #1472 )
...
* refactor matmul files
* wip refactor matmul
* everything is memco
* support local arrays
* advancing tiling2d
* advancing tiling2d
* advancing tiling2d
* tiling2d finished but buggy
* configurable unrolling
* not bugged
* fails on unroll
* stupid break
* tiling2d no assumption works
* clippy
* bounds check as bool
* lhs rhs as enum
* tiling 2d major refactor
* remove assign vec4
* variable declarations above loops
* fmt
* clippy
* Fix autotune + unroll
* move val
* clippy
* fmt
---------
Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-03-22 08:26:32 -04:00
Sylvain Benner
0a8a3cc9e9
[xtask] Add support for cargo metadata new workspace member format ( #1500 )
2024-03-21 16:04:52 -04:00
Guillaume Lagrange
3e4af41694
Fix sort descending for 1d case ( #1494 )
2024-03-21 07:45:37 -04:00
Guillaume Lagrange
47a84cc980
Add tensor sorting operations ( #1488 )
...
* Add sort, sort_with_indices and argsort
* Fix wasm build
* Add sort ops autodiff
* Fix TODO parallel comment
* Fix sort_with_indices 1d and add descending options
* Fix clippy
* Fix ascending comment (configurable)
2024-03-20 14:51:04 -04:00
Guillaume Lagrange
430f642394
Change assert_approx_eq precision from 3 to 2 ( #1491 )
2024-03-19 12:26:21 -04:00
Rubén J.R
69f1877754
New learning rate schedulers ( #1481 )
2024-03-19 08:28:42 -05:00
carrotflakes
8911093b88
Add `flip` tensor operator ( #1468 )
2024-03-18 20:33:39 -05:00
Dilshod Tadjibaev
8a8300c1fb
Add tril_mask, triu_mask and diag_mask ops ( #1479 )
2024-03-18 10:15:40 -05:00
Louis Fortier-Dubois
c729401fb2
Remove unroll in shared reduce ( #1480 )
2024-03-17 13:56:54 -04:00
Arjun31415
d3af29c5b4
Missing `Debug` derive for Group Norm Config ( #1482 )
2024-03-17 13:12:50 -04:00
Louis Fortier-Dubois
cf3c1ca80a
Migrate/jit/cat ( #1457 )
2024-03-17 11:37:36 -04:00
Louis Fortier-Dubois
41d01b8e19
Migrate/jit/prod ( #1474 )
2024-03-15 18:29:30 -04:00
Arjun31415
4de1272344
Feat: Add Leaky Relu Model ( #1467 )
2024-03-14 10:53:40 -05:00
WorldSEnder
53eb3ecfa9
Implement Huber loss ( #1444 )
...
* Implement Huber loss
Instead of using a sign or abs function, uses clamping to compute
it outside the bounds. This is better for the autodiff backend.
* mention Huber loss in the book
* unify naming of residuals in comments
2024-03-13 12:55:46 -05:00
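The clamping formulation described in that commit body can be written in a few lines. A sketch under the usual Huber definition (illustrative, not burn's code): with the residual clamped to `[-delta, delta]`, `r_c * (r - 0.5 * r_c)` reproduces both the quadratic and the linear branch without calling `abs()` or `sign()`, which is friendlier to autodiff.

```rust
// Huber loss via clamping: for |r| <= delta this is 0.5 * r^2, otherwise
// delta * (|r| - delta / 2) -- both branches fall out of one expression.
fn huber(residual: f64, delta: f64) -> f64 {
    let r_c = residual.clamp(-delta, delta);
    r_c * (residual - 0.5 * r_c)
}

fn main() {
    let delta = 1.0;
    assert_eq!(huber(0.5, delta), 0.125); // quadratic region: 0.5 * 0.5^2
    assert_eq!(huber(3.0, delta), 2.5);   // linear region: 1.0 * (3.0 - 0.5)
    assert_eq!(huber(-3.0, delta), 2.5);  // symmetric in the residual
}
```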
Dilshod Tadjibaev
7a98b2f663
Add prod and prod_dim tensor ops ( #1460 )
2024-03-12 14:00:02 -05:00
carrotflakes
80aac1dde4
Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 ( #1442 )
2024-03-12 13:08:20 -04:00
Kyle Chen
c52c49785d
Add linear learning rate scheduler ( #1443 )
2024-03-12 13:04:12 -04:00
Louis Fortier-Dubois
278fcb3dad
Migrate/jit/mask ( #1456 )
2024-03-12 12:43:05 -04:00
Louis Fortier-Dubois
02d37011ab
Fix/main/print ( #1459 )
2024-03-11 18:52:36 -04:00
Dilshod Tadjibaev
0138e16af6
Add Enum module support in PyTorchFileRecorder ( #1436 )
...
* Add Enum module support in PyTorchFileRecorder
Fixes #1431
* Fix wording/typos per PR feedback
2024-03-11 11:21:01 -05:00
Dilshod Tadjibaev
9d4fbc5a35
Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section ( #1449 )
...
* Update tensor.md
* Rename diagonal to eye
* Remove extra space per PR feedback
2024-03-11 11:00:36 -05:00
Louis Fortier-Dubois
093cbd397d
JIT Migration: PRNG ( #1433 )
...
* wip bernoulli
* wip
* bernoulli works
* uniform works
* done
* remove old
* refactor prng traits
* forgot to save file
* allow
* clippy
* clippy
* scalar commutativity
* array instead of vec
2024-03-11 11:40:27 -04:00
Dilshod Tadjibaev
3f7e6bd5bc
Add `sign` tensor operator ( #1446 )
2024-03-11 10:39:30 -05:00
github-actions[bot]
56f460295a
Combined PRs ( #1439 )
...
* Bump web-time from 1.0.0 to 1.1.0
Bumps [web-time](https://github.com/daxpedda/web-time ) from 1.0.0 to 1.1.0.
- [Release notes](https://github.com/daxpedda/web-time/releases )
- [Changelog](https://github.com/daxpedda/web-time/blob/main/CHANGELOG.md )
- [Commits](https://github.com/daxpedda/web-time/compare/v1.0.0...v1.1.0 )
---
updated-dependencies:
- dependency-name: web-time
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump rayon from 1.8.1 to 1.9.0
Bumps [rayon](https://github.com/rayon-rs/rayon ) from 1.8.1 to 1.9.0.
- [Changelog](https://github.com/rayon-rs/rayon/blob/main/RELEASES.md )
- [Commits](https://github.com/rayon-rs/rayon/compare/rayon-core-v1.8.1...rayon-core-v1.9.0 )
---
updated-dependencies:
- dependency-name: rayon
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump tempfile from 3.10.0 to 3.10.1
Bumps [tempfile](https://github.com/Stebalien/tempfile ) from 3.10.0 to 3.10.1.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md )
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.0...v3.10.1 )
---
updated-dependencies:
- dependency-name: tempfile
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump the cargo group group with 1 update
Bumps the cargo group group with 1 update: [mio](https://github.com/tokio-rs/mio ).
Updates `mio` from 0.8.10 to 0.8.11
- [Release notes](https://github.com/tokio-rs/mio/releases )
- [Changelog](https://github.com/tokio-rs/mio/blob/master/CHANGELOG.md )
- [Commits](https://github.com/tokio-rs/mio/compare/v0.8.10...v0.8.11 )
---
updated-dependencies:
- dependency-name: mio
dependency-type: indirect
dependency-group: cargo-security-group
...
Signed-off-by: dependabot[bot] <support@github.com>
* Bump log from 0.4.20 to 0.4.21
Bumps [log](https://github.com/rust-lang/log ) from 0.4.20 to 0.4.21.
- [Release notes](https://github.com/rust-lang/log/releases )
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md )
- [Commits](https://github.com/rust-lang/log/compare/0.4.20...0.4.21 )
---
updated-dependencies:
- dependency-name: log
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-03-08 19:27:48 -05:00
Nathaniel Simard
2de270fe0e
Fix tch view data corruption ( #1434 )
2024-03-08 09:55:47 -05:00
Louis Fortier-Dubois
61c0474172
JIT migration: contiguous kernel ( #1424 )
...
* JIT migration: contiguous kernel
* delete wgsl
2024-03-08 08:01:29 -05:00
Louis Fortier-Dubois
040cd55f85
JIT migration: cast kernel ( #1423 )
2024-03-07 17:49:09 -05:00
Louis Fortier-Dubois
9eecc713a4
JIT: Fix min & max values ( #1429 )
...
* real min and max values
* fix
* fmt
2024-03-07 15:10:30 -05:00
Dilshod Tadjibaev
c7d4c23f97
Support for non-contiguous indexes in PyTorchFileRecorder keys ( #1432 )
...
* Fix non-contiguous indexes
* Update pytorch-model.md
* Simplify multiple forwards
2024-03-07 13:40:57 -06:00