Commit Graph

285 Commits

Author SHA1 Message Date
Louis Fortier-Dubois 2e4c82fa64
Fix repeat for dims > 1 (#1713) 2024-05-01 09:11:38 -04:00
Dilshod Tadjibaev 3a02a54e55
Update SUPPORTED-ONNX-OPS.md (#1717)
The Gather ONNX op was checked off, but it is actually GatherElements that should have been updated.
2024-05-01 08:02:59 -04:00
Dilshod Tadjibaev ff9e875321
ONNX debug improvements (#1712)
* Minor debug improvements

* Change warn to panic

* Log improvements
2024-04-30 16:36:55 -05:00
Nathaniel Simard 587b8f80b3
First draft CUDA runtime (#1685)
Initial cuda runtime crate with a WIP compiler.
2024-04-30 09:46:29 -04:00
Jonathan Merritt ab501431b1
Handle ndarray matmul broadcasting (#1679)
* Handle ndarray matmul broadcasting

- Use strides to map linear batch indices from
  the output back to the input arrays.

* Fix typos
2024-04-29 17:25:27 -05:00
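A minimal standalone sketch of the stride-mapping idea described in this commit (a hypothetical helper, not the code merged in #1679): a broadcast batch dimension of an input carries a stride of 0, so decomposing the output's linear batch index digit by digit recovers the matching input offset.

```rust
/// Map a linear batch index in the broadcast output back to an element
/// offset in one input array, whose broadcast batch dims have stride 0.
fn input_offset(mut linear: usize, out_dims: &[usize], in_strides: &[usize]) -> usize {
    let mut offset = 0;
    // Walk the batch dims from innermost to outermost.
    for (&dim, &stride) in out_dims.iter().zip(in_strides).rev() {
        offset += (linear % dim) * stride; // stride == 0 skips broadcast dims
        linear /= dim;
    }
    offset
}

fn main() {
    // Output batch shape [2, 3]; input batch shape [1, 3] -> strides [0, 1].
    // Output batch index 4 = (1, 1) maps to input offset 1.
    assert_eq!(input_offset(4, &[2, 3], &[0, 1]), 1);
}
```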
Dilshod Tadjibaev 1cdceb590f
Skip updating shape for linear if not present (#1700) 2024-04-29 14:53:18 -05:00
WU Chen b387829731
Implement bidirectional LSTM (#1035)
* resolve conflict

* move `gate_product` to `GateController`

* BiLstm needs to use its own initializer during initialization

* resolve conflicts

* add some comments

* improve doc

* correct the description of GateController

* fix fmt

* add `LstmState`

* add test for state

* set batch 2 in bilstm test

* resolve conflict

* fix

* fix doc

* change the batch size back to 1

* change the batch size back to 1

* modify docstring; delete dead comment
2024-04-26 13:28:36 -05:00
Louis Fortier-Dubois 6ae3926006
New autodiff graph memory management strategy (#1698)
---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-26 12:25:53 -04:00
Nathaniel Simard 2f294c5092
Fix lstm batch size bug (#1695) 2024-04-26 08:54:12 -04:00
Guillaume Lagrange ce2429eb10
Refactor element type to be decoupled from runtime (#1693) 2024-04-26 08:53:55 -04:00
Dilshod Tadjibaev 67ec06d5d8
ONNX support for scalar unsqueeze (#1690)
* Revert 1c639c8393

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
2024-04-25 16:05:28 -05:00
Nathaniel Simard 599a20d586
Upgrade wgpu (#1692) 2024-04-25 16:32:50 -04:00
Dilshod Tadjibaev a1bd14c5ae
Reshape bug fix (#1684)
* Revert 1c639c8393

* Refactor by @laggui

* Refactor unsqueeze
2024-04-24 19:31:53 -05:00
Nathaniel Simard 886a1de235
Refactor/burn compute (#1580) 2024-04-23 13:05:15 -04:00
Sylvain Benner c579686a8a
Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654)
* Move HandleContainer and Tensor Ops description to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now, the modules added to burn-tensor are excluded from no-std as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on docstring

* [burn-fusion] Require default features of burn-tensor
2024-04-23 11:27:54 -04:00
Guillaume Lagrange e6b1b7a317
Add layer norm onnx op support (#1680) 2024-04-23 11:19:07 -04:00
Dilshod Tadjibaev 1718da5210
Fix reshape bug (support for opset version 1) (#1667)
* Make reshape op version 1

* Refactor per PR feedback
2024-04-22 17:52:25 -05:00
Nathaniel Simard 29fa2ee76c
Support linear 1d (#1682) 2024-04-22 18:39:09 -04:00
Alex Errant d62b344d5b
`Arc<EventStoreClient>` to `Rc<EventStoreClient>` (#1668) 2024-04-22 18:21:53 -04:00
신희제(Heejae Shin/joel.barish) 2a7b296a1b
Add sign ONNX op import support (#1663)
* Add sign ONNX op support

* Update SUPPORTED-ONNX-OPS.md
2024-04-22 09:10:50 -04:00
Louis Fortier-Dubois 2140d9b568
remove JIT subsequent RNG tests (#1652) 2024-04-21 09:48:11 -04:00
Dilshod Tadjibaev 1433284a0f
Fix bug 1645 (Unsqueeze OpSet 11) (#1661)
* Add unsqueeze opset 16 test

* Fix for unsqueeze opset 11

* Remove println statement
2024-04-19 14:17:44 -05:00
Guillaume Lagrange b65a487300
Fix transpose onnx op (permute) (#1657) 2024-04-19 09:34:03 -04:00
Nico Zweifel ee12aee2e7
fix: `window` -> `pub window` in `dataset/mod.rs` (#1658)
* update dataset/mod.rs

* Update mod.rs

* Update window.rs
2024-04-19 09:33:21 -04:00
Guillaume Lagrange 9fbcbed20f
Add where onnx op support (#1653)
* Add where onnx op support

* Add broadcasting support

* Remove broadcasting limitation comment

* Fix broadcasting in mask where

* Forgot to reflect changes in codegen test

* Fix clippy
2024-04-18 15:46:02 -04:00
Guillaume Lagrange 7705fd9c25
Add matmul ONNX op support (#1638)
* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change superscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
2024-04-18 09:20:31 -04:00
Dilshod Tadjibaev 2a721a9d0c
Enable native sign operation for Candle backend (#1647)
* Enable native sign operation for Candle backend

* Use fixed revision
2024-04-17 09:07:56 -04:00
Guillaume Lagrange 424033283a
Add reduce max ONNX op support (#1636)
* Add reduce max onnx op support

* Fix comments on tensor rank 1 result
2024-04-17 08:26:46 -04:00
Nico Zweifel 5a3f345734
WindowDataset/windows function (#1553) 2024-04-17 07:51:53 -04:00
Guillaume Lagrange 35b36bbe62
Add shape ONNX op support (#1639)
* Add shape onnx op support

* Remove cast node from onnx graph

* Fix shape implementation

* Fix shape config error message

* Fix typo

* Fix clippy type complexity for generated code
2024-04-16 09:28:21 -04:00
Guillaume Lagrange 6d96e8d808
[ONNX] Add not op and extend cast support to tensors (#1634)
* Add not onnx op support

* Extend cast onnx support to tensors

* Fix clippy
2024-04-16 08:45:25 -04:00
Mathias Insley 7377bbe31c
Feat/remainder (#1597)
* Add remainder_scalar op to numeric trait and associated int/float functions

* Update burn-tch crate

* Update ndarray crate

* Update jit crate

* Update candle crate

* Update fusion crate

* Update autodiff crate

* Forgot float.rs for fusion

* Add burn-tensor tests

* Redirect to the pre-existing modulus op

* Fix sign

* Remove mut from burn-tch

* Use sign trick to make wgpu backend work

* Add more unit tests to cover bases

* Naming fix for burn-fusion

* Update tests w/PyTorch link

* Use different WGSL instructions for remainder

* Redirect to remainder Operator instead of modulo

* Revert Modulo in instruction.rs
2024-04-16 08:35:20 -04:00
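For context on the modulus redirect and the wgpu "sign trick" mentioned in this entry, a scalar sketch of floor-mod remainder semantics, where the result takes the sign of the divisor as in PyTorch's remainder (an illustration of the contract, not the merged kernels):

```rust
/// Floor-mod remainder: lhs - rhs * floor(lhs / rhs).
fn remainder_scalar(lhs: f32, rhs: f32) -> f32 {
    lhs - rhs * (lhs / rhs).floor()
}

fn main() {
    assert_eq!(remainder_scalar(-7.0, 3.0), 2.0);  // truncating % would give -1.0
    assert_eq!(remainder_scalar(7.0, -3.0), -2.0); // sign follows the divisor
}
```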
Mathias Insley 48c61ebb81
Docs/update contributor book (#1622)
* Update links to latest commit off main

* Some pedantry

* Update links and add jit

* Update instructions for burn-jit and wgpu

* Updated import section with more recent links

* Some grammar/typo/styling fixes

* Code added to burn-wgpu too
2024-04-16 08:33:59 -04:00
Guillaume Lagrange d5f20e2711
Add reduce mean ONNX op support (#1637)
* Add reduce mean onnx op support

* Fix comment
2024-04-16 07:59:35 -04:00
Dilshod Tadjibaev 340a12463a
Update SUPPORTED-ONNX-OPS.md (#1641) 2024-04-16 07:52:15 -04:00
Guillaume Lagrange 81a67b6a09
Add sin onnx op support (#1633) 2024-04-15 15:28:16 -04:00
Sylvain Benner e303e31c8b
Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
Guillaume Lagrange cf7b279e5e
Fix burn README symlink (#1617) 2024-04-12 16:00:47 -04:00
Guillaume Lagrange 9980db440d
Remove unused assets (#1616) 2024-04-12 15:48:16 -04:00
Guillaume Lagrange 264c167c11
Update licenses symlinks (#1613) 2024-04-12 14:43:58 -04:00
Nathaniel Simard ff844b1667
Fix candle backend sync (#1579)
* Fix candle backend sync

* tch mps sync

* clippy

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-04-12 12:15:50 -04:00
Aasheesh Singh fb1da53a38
Support for rotary positional encoding in transformer modules (#1604)
* add rotary positional encoding to transformer modules.

* fix f64 error

* use num_traits

* add panic condition
2024-04-12 11:45:49 -04:00
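The core of rotary positional encoding, for reference: each pair of channels is rotated by a position-dependent angle. A minimal sketch with a hypothetical free function, not the module API added in #1604:

```rust
/// Rotate one (even, odd) channel pair by theta = pos * inv_freq.
fn rope_pair(x: f32, y: f32, pos: f32, inv_freq: f32) -> (f32, f32) {
    let theta = pos * inv_freq;
    (
        x * theta.cos() - y * theta.sin(),
        x * theta.sin() + y * theta.cos(),
    )
}

fn main() {
    // Position 0 leaves the embedding unchanged.
    assert_eq!(rope_pair(1.0, 2.0, 0.0, 1.0), (1.0, 2.0));
}
```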
Louis Fortier-Dubois 23210f05f2
JIT: Autotune matmul tiling 2d unroll (#1601)
* autotune tiling 2d unroll

* clippy

* forgotten important stuff
2024-04-12 10:15:21 -04:00
Nathaniel Simard 07a61a1cec
Fix autodiff memory management graph cleaning (#1602) 2024-04-11 16:21:00 -04:00
Guillaume Lagrange 0cbe9a927d
Add learner training report summary (#1591)
* Add training report summary

* Fix LossMetric batch size state

* Add NumericEntry de/serialize

* Fix clippy suggestion

* Compact recorder does not use compression (anymore)

* Add learner summary expected results tests

* Add summary to learner builder and automatically display in fit

- Add LearnerSummaryConfig
- Keep track of summary metrics names
- Add model field when displaying from learner.fit()
2024-04-11 12:32:25 -04:00
Louis Fortier-Dubois bdb62fbcd0
Repeat ops autodiff & fusion + fix autodiff ones & zeros (#1600)
* added repeat to autodiff and fusion + zero one backend init in autodiff

* autodiff for repeat
2024-04-11 11:32:45 -04:00
Dilshod Tadjibaev 2f885480ed
Use num-traits for float ops (#1584) 2024-04-08 10:16:20 -05:00
Guillaume Lagrange f3e0aa6689
Add multi-label classification dataset and metric (#1572)
* Add multilabel classification dataset

- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods

* Fix custom-image-classification example deps

* Add image_folder_dataset_multilabel test

* Do not change class names order when provided

* Add hamming score and multi-label classification output

* Add new_classification_with_items test

* Fix clippy suggestions

* Implement default trait for hamming score

* Remove de/serialization and use AnnotationRaw as type

* Fix clippy

* Fix metric backend phantom data
2024-04-05 13:16:46 -04:00
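A sketch of the hamming score named in this entry, under the usual definition (the fraction of (sample, label) predictions that are correct, i.e. 1 minus the Hamming loss); the metric added in #1572 may differ in details such as thresholding and averaging:

```rust
/// Hamming score over one sample's multi-label prediction.
fn hamming_score(pred: &[bool], target: &[bool]) -> f64 {
    let correct = pred.iter().zip(target).filter(|(p, t)| p == t).count();
    correct as f64 / pred.len() as f64
}

fn main() {
    assert_eq!(hamming_score(&[true, false, true], &[true, true, true]), 2.0 / 3.0);
}
```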
Louis Fortier-Dubois f5159b6d22
Refactor: split JitKernel and SourceKernel (#1569)
* refactor execute_dynamic into Execution

* minor change

* extension cfg

* jitkernel and sourcekernel

* add todo statement

* cleanup and docs

* update book

* fix server dependency on compiler

* refactor into shader information

* refactor to compile shader once

* clippy

* clippy

* clippy

* fix doc

* fix doc

* fmt

* rename feature flag

* refactor

* All broken

* compile at the right time

* todo done

* all dynamic

* all dynamic in template too

* fmt

* fix ci

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-04-05 12:58:10 -04:00
Nathaniel Simard 1239d9bfa3
[Breaking] Make Tensor, Module, Optimizer !Sync + Refactor Autodiff (#1575) 2024-04-04 16:01:17 -04:00
Guillaume Lagrange ce898ff899
Fix pytorch recorder adapt_linear when using autodiff backend (#1576)
* Fix pytorch recorder adapt_linear when using autodiff backend

* Fix comment typo
2024-04-04 12:29:24 -04:00
Guillaume Lagrange 0978c8a586
Support multilabel binary cross entropy (#1571)
* Support multilabel binary cross entropy

* Add missing alloc Vec
2024-04-03 08:03:07 -04:00
Nathaniel Simard b0c5986d16
Feat/lazy init (#1539) 2024-04-02 10:13:35 -04:00
Guillaume Lagrange 8d210a152f
Move log_sigmoid to activation ops (#1558) 2024-04-02 09:25:40 -04:00
Louis Fortier-Dubois edcd92f13d
Refactor execute_dynamic with Execution struct (#1550) 2024-03-28 17:27:48 -04:00
Nathaniel Simard efc3b2d243
[Breaking] add runtime options in wgpu init methods (#1505) 2024-03-28 12:44:38 -04:00
Louis Fortier-Dubois 279be0496a
Conv Transpose: migration + decent speedup (#1541)
* convtranspose benchmark

* adjust bench

* conv transpose works

* Conv Transpose: migration + decent speedup

* delete template folder

* typos

* fix
2024-03-28 12:13:06 -04:00
Guillaume Lagrange b8fc3f141e
Numerically stable log_sigmoid (#1548) 2024-03-28 11:54:23 -04:00
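The standard numerically stable formulation, presumably the idea behind #1548 (an assumption, not a quote of the merged code): log_sigmoid(x) = min(x, 0) - ln(1 + e^{-|x|}), which never exponentiates a large positive value.

```rust
/// Numerically stable log-sigmoid.
fn log_sigmoid(x: f64) -> f64 {
    x.min(0.0) - (-x.abs()).exp().ln_1p()
}

fn main() {
    // The naive (1.0 / (1.0 + (-x).exp())).ln() returns -inf for x = -800.0.
    assert!(log_sigmoid(-800.0).is_finite());
}
```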
Dilshod Tadjibaev 70b92cb2fb
Update SUPPORTED-ONNX-OPS.md (#1547) 2024-03-28 10:38:53 -05:00
Karsten Becker c21d5a3207
Add LeakyReLu implementation (#1208)
* Implement LeakyReLu

* Cargo fmt

* Apply suggestions

* cargo fmt

* Use float_mul_scalar

* Should be grad

* Add to books module

* Move test files

* Update leaky relu to use activation function

* Update tensor.md

* Fix failing test due to approx

* Add back the function comment

* Fix comment per PR feedback

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-27 13:57:51 -05:00
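For reference, a scalar sketch of the LeakyReLU function this entry implements; the 0.01 slope below is illustrative, not necessarily the module's default:

```rust
/// LeakyReLU: identity for non-negative inputs, a small linear slope otherwise.
fn leaky_relu(x: f32, negative_slope: f32) -> f32 {
    if x >= 0.0 { x } else { negative_slope * x }
}

fn main() {
    assert_eq!(leaky_relu(2.0, 0.01), 2.0);
    assert_eq!(leaky_relu(-2.0, 0.01), -0.02);
}
```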
jcmullwh 626457e1c6
Provide Tensor Padding Helpers #960 (#1097)
* Initial padding approach

Create padding implementation for the last two dimensions of Float and Int Tensors.

Create PadMode Enum, allowing Constant padding.

Create Padding Struct with Uniform, Asymmetric, height, and width implementations.

Create tests for the padding implementation.

* Update padding.rs

remove unneeded import

* Update from Merge

Use crate Element

Swap from old from_data() to new from_data_devauto()

* Formatting Changes

Formatting changes from cargo fmt --all

* Additional Format Change

One more format change that cargo fmt didn't get the first time.

* Changes to Example

Modify Example to ensure it works.

* modify naming

better names for impl / input variables.

* Modify API

- Change Padding to PadSize.
- integrate padding value into PadMode.
- update tests and examples.

* Comments and print

Improve comments+naming and remove println

* Pad Fixes

Moved pad to numeric

Simplified PadMode Element

updated tensor creations

fixed doc example

* Fix test location

* Simplified pad API

* Fix for failed unit tests

* Remove bool_full

* Rename `pads` to `padding`

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-27 12:46:55 -05:00
Nathaniel Simard 40a26bd2ea
Feat/backend bridge (#1529) 2024-03-26 19:24:45 -04:00
Louis Fortier-Dubois 5bac300688
Migrate/jit/interpolate (#1528)
* separate forward backward

* refactor with pool strategy

* refactor further

* pooling refactored

* refactoring for adaptive wip

* wip adaptive

* adaptive

* delete some wgsl

* avg pool backward

* clippy

* refactor interpolate files

* nearest shader

* nearest

* some boilerplate

* wip

* bilinear

* nearest backward

* cubic

* cleanup

* minor refactor

* add some white space
2024-03-26 08:57:26 -04:00
Louis Fortier-Dubois 37b61ea646
Migrate/jit/adaptive avg pool backward (#1530)
* separate forward backward

* refactor with pool strategy

* refactor further

* pooling refactored

* refactoring for adaptive wip

* wip adaptive

* adaptive

* delete some wgsl

* avg pool backward

* clippy

* minor refactor

* works

* delete wgsl
2024-03-26 08:38:06 -04:00
Aasheesh Singh a77979e0b6
add rms norm layer (#1527) 2024-03-25 18:59:11 -04:00
Louis Fortier-Dubois da5b0438ec
Migrate/jit/pooling (#1509)
* separate forward backward

* refactor with pool strategy

* refactor further

* pooling refactored

* refactoring for adaptive wip

* wip adaptive

* adaptive

* delete some wgsl

* avg pool backward

* clippy

* minor refactor
2024-03-25 16:04:58 -04:00
Aasheesh Singh 613e698007
Feat/swiglu (#1507) 2024-03-25 15:55:27 -04:00
Louis Fortier-Dubois 4542ceddca
Migrate/jit/conv2d (#1493)
* conv2d but bug

* convolution done

* minor clean

* delete wgsl
2024-03-25 10:45:40 -04:00
Sylvain Benner 0adda72316
[backend-comparison] Add system information to benchmark results (#1495)
* Bump sysinfo crate to 0.30.7

* [backend-comparison] Add CPUs and GPUs system info to results

* [backend-comparison] Add integrated GPUs to gathered system info

* [backend-comparison] Use AutoGraphicsApi wgpu backend selection
2024-03-22 23:24:49 -04:00
Dilshod Tadjibaev 6feda90a8c
Tensor expand operator (#1508)
* Improve CI cache - remove burn-tch artifacts

* PyTorch config deserializer from .pt file

* Update pytorch-model.md

* WIP

* Rename broadcast_to to expand

* Rename broadcast_to expand file

* Implemented fusion backend and fix bugs

* Remove old files

* Remove unused state

* Rename to the correct op name

* Add missing comment

* Fix expand check function doc

* Rename the leftover names

* Rename leftover names
2024-03-22 16:33:53 -05:00
Guillaume Lagrange dc45cf1700
Add `topk` tensor operation (#1497)
* Add topk and topk_with_indices

* Change topk_with_indices test to guarantee order (previously equal elements)
2024-03-22 10:57:20 -04:00
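A plain-Rust sketch of the topk_with_indices contract added above (largest k values with their original indices); the real op works on Burn tensors along a dimension, this only illustrates the semantics, including the stable ordering of equal elements that the test change mentions:

```rust
fn topk_with_indices(values: &[f32], k: usize) -> (Vec<f32>, Vec<usize>) {
    let mut indexed: Vec<(usize, f32)> = values.iter().copied().enumerate().collect();
    // Stable descending sort keeps equal elements in input order.
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    indexed.truncate(k);
    (
        indexed.iter().map(|&(_, v)| v).collect(),
        indexed.iter().map(|&(i, _)| i).collect(),
    )
}

fn main() {
    let (vals, idx) = topk_with_indices(&[1.0, 5.0, 3.0, 5.0], 2);
    assert_eq!(vals, vec![5.0, 5.0]);
    assert_eq!(idx, vec![1, 3]);
}
```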
Louis Fortier-Dubois dd699a90a2
Migrate/jit/matmul tiling 2d (#1472)
* refactor matmul files

* wip refactor matmul

* everything is memco

* support local arrays

* advancing tiling2d

* advancing tiling2d

* advancing tiling2d

* tiling2d finished but buggy

* configurable unrolling

* not bugged

* fails on unroll

* stupid break

* tiling2d no assumption works

* clippy

* bounds check as bool

* lhs rhs as enum

* tiling 2d major refactor

* remove assign vec4

* variable declarations above loops

* fmt

* clippy

* Fix autotune + unroll

* move val

* clippy

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
2024-03-22 08:26:32 -04:00
Sylvain Benner 0a8a3cc9e9
[xtask] Add support for cargo metadata new workspace member format (#1500) 2024-03-21 16:04:52 -04:00
Guillaume Lagrange 3e4af41694
Fix sort descending for 1d case (#1494) 2024-03-21 07:45:37 -04:00
Guillaume Lagrange 47a84cc980
Add tensor sorting operations (#1488)
* Add sort, sort_with_indices and argsort

* Fix wasm build

* Add sort ops autodiff

* Fix TODO parallel comment

* Fix sort_with_indices 1d and add descending options

* Fix clippy

* Fix ascending comment (configurable)
2024-03-20 14:51:04 -04:00
Guillaume Lagrange 430f642394
Change assert_approx_eq precision from 3 to 2 (#1491) 2024-03-19 12:26:21 -04:00
Rubén J.R 69f1877754
New learning rate schedulers (#1481) 2024-03-19 08:28:42 -05:00
carrotflakes 8911093b88
Add `flip` tensor operator (#1468) 2024-03-18 20:33:39 -05:00
Dilshod Tadjibaev 8a8300c1fb
Add tril_mask, triu_mask and diag_mask ops (#1479) 2024-03-18 10:15:40 -05:00
Louis Fortier-Dubois c729401fb2
Remove unroll in shared reduce (#1480) 2024-03-17 13:56:54 -04:00
Arjun31415 d3af29c5b4
Missing `Debug` derive for Group Norm Config (#1482) 2024-03-17 13:12:50 -04:00
Louis Fortier-Dubois cf3c1ca80a
Migrate/jit/cat (#1457) 2024-03-17 11:37:36 -04:00
Louis Fortier-Dubois 41d01b8e19
Migrate/jit/prod (#1474) 2024-03-15 18:29:30 -04:00
Arjun31415 4de1272344
Feat: Add Leaky Relu Model (#1467) 2024-03-14 10:53:40 -05:00
WorldSEnder 53eb3ecfa9
Implement Huber loss (#1444)
* Implement Huber loss

Instead of using a sign or abs function, it uses clamping to compute
the loss outside the bounds. This is better for the autodiff backend
(see the sketch after this entry).

* mention Huber loss in the book

* unify naming of residuals in comments
2024-03-13 12:55:46 -05:00
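The clamping trick described in this commit, as a scalar sketch (assumed form of the idea, not the exact merged code): with r_c = clamp(r, -delta, delta), the expression r_c * (r - 0.5 * r_c) equals 0.5 * r^2 inside the bounds and delta * (|r| - 0.5 * delta) outside, with no call to sign() or abs().

```rust
/// Huber loss of a residual r via clamping only.
fn huber(r: f64, delta: f64) -> f64 {
    let r_c = r.clamp(-delta, delta);
    r_c * (r - 0.5 * r_c)
}

fn main() {
    assert!((huber(0.5, 1.0) - 0.125).abs() < 1e-12); // quadratic region
    assert!((huber(3.0, 1.0) - 2.5).abs() < 1e-12);   // linear region
}
```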
Dilshod Tadjibaev 7a98b2f663
Add prod and prod_dim tensor ops (#1460) 2024-03-12 14:00:02 -05:00
carrotflakes 80aac1dde4
Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 (#1442) 2024-03-12 13:08:20 -04:00
Kyle Chen c52c49785d
Add linear learning rate scheduler (#1443) 2024-03-12 13:04:12 -04:00
Louis Fortier-Dubois 278fcb3dad
Migrate/jit/mask (#1456) 2024-03-12 12:43:05 -04:00
Louis Fortier-Dubois 02d37011ab
Fix/main/print (#1459) 2024-03-11 18:52:36 -04:00
Dilshod Tadjibaev 0138e16af6
Add Enum module support in PyTorchFileRecorder (#1436)
* Add Enum module support in PyTorchFileRecorder

Fixes #1431

* Fix wording/typos per PR feedback
2024-03-11 11:21:01 -05:00
Dilshod Tadjibaev 9d4fbc5a35
Rename `diagonal` to `eye` tensor op and add missing entry for diagonal to Book tensor section (#1449)
* Update tensor.md

* Rename diagonal to eye

* Remove extra space per PR feedback
2024-03-11 11:00:36 -05:00
Louis Fortier-Dubois 093cbd397d
JIT Migration: PRNG (#1433)
* wip bernoulli

* wip

* bernoulli works

* uniform works

* done

* remove old

* refactor prng traits

* forgot to save file

* allow

* clippy

* clippy

* scalar commutativity

* array instead of vec
2024-03-11 11:40:27 -04:00
Dilshod Tadjibaev 3f7e6bd5bc
Add `sign` tensor operator (#1446) 2024-03-11 10:39:30 -05:00
github-actions[bot] 56f460295a
Combined PRs (#1439)
* Bump web-time from 1.0.0 to 1.1.0

Bumps [web-time](https://github.com/daxpedda/web-time) from 1.0.0 to 1.1.0.
- [Release notes](https://github.com/daxpedda/web-time/releases)
- [Changelog](https://github.com/daxpedda/web-time/blob/main/CHANGELOG.md)
- [Commits](https://github.com/daxpedda/web-time/compare/v1.0.0...v1.1.0)

---
updated-dependencies:
- dependency-name: web-time
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump rayon from 1.8.1 to 1.9.0

Bumps [rayon](https://github.com/rayon-rs/rayon) from 1.8.1 to 1.9.0.
- [Changelog](https://github.com/rayon-rs/rayon/blob/main/RELEASES.md)
- [Commits](https://github.com/rayon-rs/rayon/compare/rayon-core-v1.8.1...rayon-core-v1.9.0)

---
updated-dependencies:
- dependency-name: rayon
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tempfile from 3.10.0 to 3.10.1

Bumps [tempfile](https://github.com/Stebalien/tempfile) from 3.10.0 to 3.10.1.
- [Changelog](https://github.com/Stebalien/tempfile/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Stebalien/tempfile/compare/v3.10.0...v3.10.1)

---
updated-dependencies:
- dependency-name: tempfile
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump the cargo group group with 1 update

Bumps the cargo group group with 1 update: [mio](https://github.com/tokio-rs/mio).


Updates `mio` from 0.8.10 to 0.8.11
- [Release notes](https://github.com/tokio-rs/mio/releases)
- [Changelog](https://github.com/tokio-rs/mio/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tokio-rs/mio/compare/v0.8.10...v0.8.11)

---
updated-dependencies:
- dependency-name: mio
  dependency-type: indirect
  dependency-group: cargo-security-group
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump log from 0.4.20 to 0.4.21

Bumps [log](https://github.com/rust-lang/log) from 0.4.20 to 0.4.21.
- [Release notes](https://github.com/rust-lang/log/releases)
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/log/compare/0.4.20...0.4.21)

---
updated-dependencies:
- dependency-name: log
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-03-08 19:27:48 -05:00
Nathaniel Simard 2de270fe0e
Fix tch view data corruption (#1434) 2024-03-08 09:55:47 -05:00
Louis Fortier-Dubois 61c0474172
JIT migration: contiguous kernel (#1424)
* JIT migration: contiguous kernel

* delete wgsl
2024-03-08 08:01:29 -05:00
Louis Fortier-Dubois 040cd55f85
JIT migration: cast kernel (#1423) 2024-03-07 17:49:09 -05:00
Louis Fortier-Dubois 9eecc713a4
JIT: Fix min & max values (#1429)
* real min and max values

* fix

* fmt
2024-03-07 15:10:30 -05:00
Dilshod Tadjibaev c7d4c23f97
Support for non-contiguous indexes in PyTorchFileRecorder keys (#1432)
* Fix non-contiguous indexes

* Update pytorch-model.md

* Simplify multiple forwards
2024-03-07 13:40:57 -06:00
Dilshod Tadjibaev b12646de0a
Truncate debug display for NestedValue (#1428)
* Truncate debug display for NestedValue

* Fix failing tests
2024-03-07 08:06:31 -05:00
Dilshod Tadjibaev 545444c02a
PyTorchFileRecord print debug option (#1425)
* Add debug print option to PyTorchFileRecorder

* Updated documentation and improved print output

* Improve print wording

* Updated per PR feedback
2024-03-06 16:11:37 -06:00
Nathaniel Simard b429cc39c1
Split the JIT stuff from the Wgpu stuff (#1417) 2024-03-06 11:23:53 -05:00
jackdarlison 3ff6e7170e
Switched arguments in `reshape_args_usize` check (#1409) 2024-03-06 08:45:12 -05:00
Aasheesh Singh 0c92c8c8eb
Autodiff/training support for Nearest Interpolation (#1414)
Add training support for nearest interpolation

---------

Co-authored-by: yurzhang <yurzhang.oi@gmail.com>
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-06 00:12:05 -05:00
Dilshod Tadjibaev 0601dc778b
Update operation.rs (#1418) 2024-03-05 12:43:02 -06:00
Dilshod Tadjibaev c7834e4658
Tensor `permute` operator (#1410) 2024-03-05 12:29:13 -05:00
Dilshod Tadjibaev 4ed90a988e
Add `bool()` op for numerical tensor (#1402)
Fixes #1395
2024-03-04 12:39:17 -06:00
Nathaniel Simard efbe818465
Refactor wgpu max pooling (#1398) 2024-03-04 13:23:11 -05:00
Louis Fortier-Dubois 046d975b76
Migrate reduce dim + More JIT instructions + Major wgpu reduce dim refactor (#1396) 2024-03-04 10:48:52 -05:00
Guillaume Lagrange 16d7666611
Add `argwhere` and `nonzero` boolean tensor ops (#1394)
* Add argwhere and nonzero bool tensor ops

* Fix wasm build

* Add missing vec

* Fix wasm cfg placement

* Fix comment
2024-03-04 08:33:59 -05:00
yurzhang 7d44f0b2d7
Interpolate tensor operation (Inference Only) (#1246)
* squash

feat: bilinear interpolation for tch, ndarray and wgpu backend

fix: reduce test case size to avoid exceeding floating-point precision limits

feat: support nearest-neighbor interpolation for ndarray backend

feat: support nearest-neighbor interpolation for wgpu backend

feat: support fusion backend

fix: no-std support

build: upgrade dependencies

* feat: bicubic interpolation for ndarray backend

* fix: test case precision

* feat: bicubic interpolation for wgpu backend

* Update Cargo.lock

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Co-authored-by: Aasheesh Singh <20820983+ashdtu@users.noreply.github.com>
2024-03-02 12:01:35 -06:00
Dilshod Tadjibaev d43a0b3f90
Add is_close and all_close tensor operators (#1389)
* Add is_close and all_close tensor operators

* Fix broken build issues

* Fix the table

* Add tests to candle
2024-03-01 15:37:14 -06:00
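A scalar sketch of the usual is_close predicate (the NumPy/PyTorch convention); the tolerance values shown are assumptions, not necessarily the defaults chosen in #1389:

```rust
/// True when |a - b| <= atol + rtol * |b|.
fn is_close(a: f32, b: f32, rtol: f32, atol: f32) -> bool {
    (a - b).abs() <= atol + rtol * b.abs()
}

fn main() {
    assert!(is_close(1.0, 1.0 + 1e-6, 1e-5, 1e-8));
    assert!(!is_close(1.0, 1.1, 1e-5, 1e-8));
}
```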
Matt Thompson 201e7f87c9
Element wise min/max between a pair of tensors (#1385)
* Added min_pair() and max_pair() methods to numeric Tensors

* Update book with added max_pair() and min_pair() methods.

* Fix spelling typo in comments

* Update comments per change requests
2024-03-01 15:56:48 -05:00
Guillaume Lagrange 3d93e6dae9
Add `not_equal` and `not_equal_elem` tensor ops (#1374)
* Fix tensor.equal_elem usage in book

* Add not_equal and not_equal_elem tensor ops

* Fix "element-wise" usage for correctness and uniformity

* Add bool_not_equal test
2024-03-01 12:13:25 -06:00
Louis Fortier-Dubois 9a0b2d6e7e
bool cast kernel (#1391) 2024-03-01 12:52:28 -05:00
Dilshod Tadjibaev 688958ee74
Enhance PyTorchRecorder to pass top-level key to extract state_dict (#1300)
* Enhance PyTorchRecorder to pass top-level key to extract state_dict

This is needed for Whisper weight pt files.

* Fix missing hyphens

* Move top-level-key test under crates

* Add sub-crates as members of workspace

* Update Cargo.lock

* Add accidentally omitted line during merge
2024-02-29 12:57:27 -06:00
Guillaume Lagrange 4efc683df4
Upgrade to candle 0.4.1 (#1382)
* Fix python main entrypoint in book example

* Remove candle windows safeguards (#1178)

* Bump candle-core from 0.3.3 to 0.4.1

* Remove windows current known issue
2024-02-29 11:29:11 -06:00
Nathaniel Simard 40bf3927f0
Migrate wgsl index shaders to gpu representation (#1378) 2024-02-28 16:17:47 -05:00
Yu Sun 330552afb4
docs(book-&-examples): modify book and examples with new `prelude` module (#1372) 2024-02-28 13:25:25 -05:00
Nathaniel Simard 57887e7a47
Refactor/elemwise/kernel selection + dynamic fused inplace operations and broadcasting (#1359) 2024-02-27 08:41:31 -05:00
Nathaniel Simard bdec8d5813
[Refactor - JIT] Gather Scatter new implementations (#1356) 2024-02-26 17:20:09 -05:00
Louis Fortier-Dubois 576bb44bc8
Feat/autodiff/checkpoint ops (#1358) 2024-02-26 17:19:09 -05:00
Mathias Insley bb5e6faff2
Feat/autotune int ops (#1136)
* Add int_random to int tensor ops

* Int random for tch backend

* Int random for burn-fusion

* int random for autodiff

* Int random for candle backend

* Int random for ndarray backend

* Int random for wgpu backend

* Merge imports

* Typo

* Shader file for int uniform distribution

* Create AutotuneOperationSet and public int_sum_dim_autotune

* Adjust bounds to 0..10

* Create uniform_int_kernel, unit tests, use new kernel

* Reduction kernels for regular and shared memory sum_dim int operations

* Macro that accommodates wgpu IntElement

* Add autotuning to int_mean_dim

* Use correct macro for Int autotuning

* Add int_mean_dim_shared_memory

* Add int_mean_dim and unit test

* Create autotunables for mean_dim

* Run fmt

* Remove comment

* Finish resolving merge conflict, fix doc

* Make the element trait bound a parameter to reduce_tune_ops macro

* Update book

* Fix requested change

* Change range to [0, 255] and update test accordingly

* Forgot to include candle in last commit

* Fix comment

* Use correct int autotune for mean dim

* Fix typo, not sure how this passed earlier

* Resolve syntax issues from merge

* Fix cast_float

* Saving here

* Continue fixing merge conflicts, all tests pass locally

* Run fmt

* Change cast_float to cast_u32_to_float

* Make uniform_int_inner_loop safer

* Be even more explicit about u32 casts

* Skip an intermediate step and cast directly to u32

* Replace JitElement + Element with IntElement

* Run fmt

* This should fix the CI

* This time for sure
2024-02-26 14:53:21 -05:00
Joshua Ferguson 706e0ebce2
Parser rewrite (#1296)
* Running into issues with identity nodes

* Vec<RefCell<Node>> seems to work for this

* back to passing tests

* Reworked IO into separate struct

* working towards exploiting topological ordering and more informative ident errors

* the passing of an initializer to coalesce is temporary

* cleaning up dead code

* handled unsqueeze

* reworked node initialization and dim inference

* mainly cleanup

* changed how io use is tracked, moved unsqueeze remapping out of dim inference

* `cargo xtask run-checks all` now passes

* added a fixme and a few doc strings

* removing println and dead code

* spaces in doc strings

* altered top sort to work on node proto, moved prior to node gen

* Update ir.rs

* Update from_onnx.rs

removed dead code

* updated doc string

* CamelCased Onnx Graph Builder

* removed self import?
2024-02-24 10:51:58 -06:00
Arjun31415 8e23057c6b
Feature Addition: PRelu Module (#1328) 2024-02-24 10:24:22 -05:00
Yu Sun 1da47c9bf1
feat: add prelude module for convenience (#1335) 2024-02-24 10:17:30 -05:00
Tushushu 27f2095bcd
Implement Instance Normalization (#1321)
* config

* rename as instances, otherwise won't work

* refactor

* InstanceNormConfig

* remove unused var

* forward

* rename

* based on gn

* unit tests

* fix tests

* update doc

* update onnx doc

* renaming method

* add comment

---------

Co-authored-by: VungleTienan <tienan.liu@vungle.com>
2024-02-23 23:31:43 -06:00
Guillaume Lagrange f5bd2a474f
Check that pa_type is valid before checking if is_binary (#1354) 2024-02-23 09:44:03 -06:00
Dilshod Tadjibaev 08302e38fc
Fix broken test and run-checks script (#1347) 2024-02-23 10:06:51 -05:00
Aasheesh Singh c86db83fa9
Add support for Any, All operations to Tensor (#1342)
* add any, all op implementation for all tensor types

* add op to burn-book

* fix formatting

* refactor tensor operations from numeric to BaseOps.

* fix book doc

* comments fix and add more tests
2024-02-23 10:06:31 -05:00
Dilshod Tadjibaev d6e859330f
Pytorch message updates (#1344)
* Update pytorch-model.md

* Update error.rs
2024-02-22 12:12:50 -06:00
Nathaniel Simard b256c0404e
Refactor/wgpu/memco (#1340) 2024-02-22 07:59:54 -05:00
Guillaume Lagrange bff4961426
Add enum module support (#1337) 2024-02-21 17:03:34 -05:00
Sylvain Benner 4427768570
[refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00