Commit Graph

336 Commits

Author SHA1 Message Date
Dilshod Tadjibaev c7d4c23f97
Support for non-contiguous indexes in PyTorchFileRecorder keys (#1432)
* Fix non-contiguous indexes

* Update pytorch-model.md

* Simplify multiple forwards
2024-03-07 13:40:57 -06:00
Dilshod Tadjibaev b12646de0a
Truncate debug display for NestedValue (#1428)
* Truncate debug display for NestedValue

* Fix failing tests
2024-03-07 08:06:31 -05:00
Dilshod Tadjibaev 545444c02a
PyTorchFileRecord print debug option (#1425)
* Add debug print option to PyTorchFileRecorder

* Updated documentation and improved print output

* Improve print wording

* Updated per PR feedback
2024-03-06 16:11:37 -06:00
Nathaniel Simard b429cc39c1
Split the JIT stuff from the Wgpu stuff (#1417) 2024-03-06 11:23:53 -05:00
jackdarlison 3ff6e7170e
Switched arguments in `reshape_args_usize` check (#1409) 2024-03-06 08:45:12 -05:00
Aasheesh Singh 0c92c8c8eb
Autodiff/training support for Nearest Interpolation (#1414)
Add training support for nearest interpolation

---------

Co-authored-by: yurzhang <yurzhang.oi@gmail.com>
Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-03-06 00:12:05 -05:00
Dilshod Tadjibaev 0601dc778b
Update operation.rs (#1418) 2024-03-05 12:43:02 -06:00
Dilshod Tadjibaev c7834e4658
Tensor `permute` operator (#1410) 2024-03-05 12:29:13 -05:00
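The `permute` operator added in #1410 reorders a tensor's axes. As a rough illustration of the semantics only (plain Rust on a flat row-major buffer, not Burn's actual implementation), permuting a 2-D shape with axes `[1, 0]` is a transpose:

```rust
// Hypothetical sketch: permute a 2-D tensor (stored row-major in a flat
// Vec) with axes [1, 0], i.e. a transpose. Not the Burn kernel.
fn permute_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element at [r, c] moves to [c, r] in the [cols, rows] output.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

fn main() {
    // 2x3 matrix [[1, 2, 3], [4, 5, 6]] becomes 3x2 [[1, 4], [2, 5], [3, 6]].
    let m = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    println!("{:?}", permute_2d(&m, 2, 3));
}
```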
Dilshod Tadjibaev 4ed90a988e
Add `bool()` op for numerical tensor (#1402)
Fixes #1395
2024-03-04 12:39:17 -06:00
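The `bool()` op from #1402 casts a numeric tensor to a boolean one. A minimal sketch of the semantics, assuming the usual convention that any nonzero element maps to `true` (illustrative only, not the actual kernel):

```rust
// Hypothetical sketch of bool() semantics on a flat integer buffer:
// each element becomes `value != 0`.
fn to_bool(data: &[i32]) -> Vec<bool> {
    data.iter().map(|&v| v != 0).collect()
}

fn main() {
    println!("{:?}", to_bool(&[0, 3, -1, 0]));
}
```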
Nathaniel Simard efbe818465
Refactor wgpu max pooling (#1398) 2024-03-04 13:23:11 -05:00
Louis Fortier-Dubois 046d975b76
Migrate reduce dim + More JIT instructions + Major wgpu reduce dim refactor (#1396) 2024-03-04 10:48:52 -05:00
Guillaume Lagrange 16d7666611
Add `argwhere` and `nonzero` boolean tensor ops (#1394)
* Add argwhere and nonzero bool tensor ops

* Fix wasm build

* Add missing vec

* Fix wasm cfg placement

* Fix comment
2024-03-04 08:33:59 -05:00
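The `argwhere` op from #1394 returns the indices of true elements. A 1-D sketch of that semantics in plain Rust (the real ops work on N-D boolean tensors and return index tensors):

```rust
// Hypothetical sketch: argwhere over a flat boolean mask yields the
// positions of the true elements.
fn argwhere(mask: &[bool]) -> Vec<usize> {
    mask.iter()
        .enumerate()
        .filter_map(|(i, &b)| if b { Some(i) } else { None })
        .collect()
}

fn main() {
    println!("{:?}", argwhere(&[false, true, false, true]));
}
```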
yurzhang 7d44f0b2d7
Interpolate tensor operation (Inference Only) (#1246)
* squash

feat: bilinear interpolation for tch, ndarray and wgpu backend

fix: reduce test case size to avoid exceeding floating-point precision limits

feat: support nearest-neighbor interpolation for ndarray backend

feat: support nearest-neighbor interpolation for wgpu backend

feat: support fusion backend

fix: no-std support

build: upgrade dependencies

* feat: bicubic interpolation for ndarray backend

* fix: test case precision

* feat: bicubic interpolation for wgpu backend

* Update Cargo.lock

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Co-authored-by: Aasheesh Singh <20820983+ashdtu@users.noreply.github.com>
2024-03-02 12:01:35 -06:00
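Of the interpolation modes added in #1246, nearest-neighbor is the simplest: each output sample copies the closest input sample. A 1-D sketch under that assumption (the actual ops are 2-D and backend-specific):

```rust
// Hypothetical 1-D nearest-neighbor interpolation sketch, using
// pixel-center alignment. Not the Burn implementation.
fn interpolate_nearest(input: &[f32], out_len: usize) -> Vec<f32> {
    let scale = input.len() as f32 / out_len as f32;
    (0..out_len)
        .map(|i| {
            // Map each output index to the nearest source index.
            let src = ((i as f32 + 0.5) * scale) as usize;
            input[src.min(input.len() - 1)]
        })
        .collect()
}

fn main() {
    // Upsampling [1, 2] to length 4 repeats each sample.
    println!("{:?}", interpolate_nearest(&[1.0, 2.0], 4));
}
```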
Dilshod Tadjibaev d43a0b3f90
Add is_close and all_close tensor operators (#1389)
* Add is_close and all_close tensor operators

* Fix broken build issues

* Fix the table

* Add tests to candle
2024-03-01 15:37:14 -06:00
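The `is_close`/`all_close` operators from #1389 compare tensors within tolerances. Assuming NumPy-style semantics (`|a - b| <= atol + rtol * |b|`), a minimal sketch on slices:

```rust
// Hypothetical sketch of is_close / all_close semantics: element-wise
// comparison within absolute (atol) and relative (rtol) tolerances.
fn is_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> Vec<bool> {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs() <= atol + rtol * y.abs())
        .collect()
}

fn all_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    is_close(a, b, rtol, atol).into_iter().all(|c| c)
}

fn main() {
    println!("{}", all_close(&[1.0, 2.0], &[1.0, 2.000001], 1e-5, 1e-8));
}
```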
Matt Thompson 201e7f87c9
Element wise min/max between a pair of tensors (#1385)
* Added min_pair() and max_pair() methods to numeric Tensors

* Update book with added max_pair() and min_pair() methods.

* Fix spelling typo in comments

* Update comments per change requests
2024-03-01 15:56:48 -05:00
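The `min_pair()`/`max_pair()` methods from #1385 take the element-wise minimum or maximum of two equally-shaped tensors. A sketch of that semantics on slices (illustrative, not the Burn API):

```rust
// Hypothetical sketch: element-wise max and min between two
// equally-sized buffers.
fn max_pair(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(&x, &y)| x.max(y)).collect()
}

fn min_pair(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(&x, &y)| x.min(y)).collect()
}

fn main() {
    println!("{:?}", max_pair(&[1.0, 5.0], &[3.0, 2.0]));
}
```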
Guillaume Lagrange 3d93e6dae9
Add `not_equal` and `not_equal_elem` tensor ops (#1374)
* Fix tensor.equal_elem usage in book

* Add not_equal and not_equal_elem tensor ops

* Fix "element-wise" usage for correctness and uniformity

* Add bool_not_equal test
2024-03-01 12:13:25 -06:00
Louis Fortier-Dubois 9a0b2d6e7e
bool cast kernel (#1391) 2024-03-01 12:52:28 -05:00
Dilshod Tadjibaev 688958ee74
Enhance PyTorchRecorder to pass top-level key to extract state_dict (#1300)
* Enhance PyTorchRecorder to pass top level key to extract state_dict

This is needed for Whisper weight pt files.

* Fix missing hyphens

* Move top-level-key test under crates

* Add sub-crates as members of workspace

* Update Cargo.lock

* Add accidentally omitted line during merge
2024-02-29 12:57:27 -06:00
Guillaume Lagrange 4efc683df4
Upgrade to candle 0.4.1 (#1382)
* Fix python main entrypoint in book example

* Remove candle windows safeguards (#1178)

* Bump candle-core from 0.3.3 to 0.4.1

* Remove windows current known issue
2024-02-29 11:29:11 -06:00
Nathaniel Simard 40bf3927f0
Migrate wgsl index shaders to gpu representation (#1378) 2024-02-28 16:17:47 -05:00
Yu Sun 330552afb4
docs(book-&-examples): modify book and examples with new `prelude` module (#1372) 2024-02-28 13:25:25 -05:00
Nathaniel Simard 57887e7a47
Refactor/elemwise/kernel selection + dynamic fused inplace operations and broadcasting (#1359) 2024-02-27 08:41:31 -05:00
Nathaniel Simard bdec8d5813
[Refactor - JIT] Gather Scatter new implementations (#1356) 2024-02-26 17:20:09 -05:00
Louis Fortier-Dubois 576bb44bc8
Feat/autodiff/checkpoint ops (#1358) 2024-02-26 17:19:09 -05:00
Mathias Insley bb5e6faff2
Feat/autotune int ops (#1136)
* Add int_random to int tensor ops

* Int random for tch backend

* Int random for burn-fusion

* int random for autodiff

* Int random for candle backend

* Int random for ndarray backend

* Int random for wgpu backend

* Merge imports

* Typo

* Shader file for int uniform distribution

* Create AutotuneOperationSet and public int_sum_dim_autotune

* Adjust bounds to 0..10

* Create uniform_int_kernel, unit tests, use new kernel

* Reduction kernels for regular and shared memory sum_dim int operations

* Macro that accommodates wgpu IntElement

* Add autotuning to int_mean_dim

* Use correct macro for Int autotuning

* Add int_mean_dim_shared_memory

* Add int_mean_dim and unit test

* Create autotunables for mean_dim

* Run fmt

* Remove comment

* Finish resolving merge conflict, fix doc

* Make the element trait bound a parameter to reduce_tune_ops macro

* Update book

* Fix requested change

* Change range to [0, 255] and update test accordingly

* Forgot to include candle in last commit

* Fix comment

* Use correct int autotune for mean dim

* Fix typo - not sure how this passed earlier

* Resolve syntax issues from merge

* Fix cast_float

* Saving here

* Continue fixing merge conflicts, all tests pass locally

* Run fmt

* Change cast_float to cast_u32_to_float

* Make uniform_int_inner_loop safer

* Be even more explicit about u32 casts

* Skip an intermediate step and cast directly to u32

* Replace JitElement + Element with IntElement

* Run fmt

* This should fix the CI

* This time for sure
2024-02-26 14:53:21 -05:00
Joshua Ferguson 706e0ebce2
Parser rewrite (#1296)
* Running into issues with identity nodes

* Vec<RefCell<Node>> seems to work for this

* back to passing tests

* Reworked IO into separate struct

* working towards exploiting topological ordering and more informative ident errors

* the passing of an initializer to coalesce is temporary

* cleaning up dead code

* handled unsqueeze

* reworked node initialization and dim inference

* mainly cleanup

* changed how io use is tracked, moved unsqueeze remapping out of dim inference

* `cargo xtask run-checks all` now passes

* added a fixme and a few doc strings

* removing println and dead code

* spaces in doc strings

* altered top sort to work on node proto, moved prior to node gen

* Update ir.rs

* Update from_onnx.rs

removed dead code

* updated doc string

* camelCased Onnx Graph Builder

* removed self import?
2024-02-24 10:51:58 -06:00
Arjun31415 8e23057c6b
Feature Addition: PRelu Module (#1328) 2024-02-24 10:24:22 -05:00
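The PReLU module from #1328 implements the parametric ReLU activation, f(x) = x for x > 0 and alpha * x otherwise, where alpha is a learned slope. A scalar sketch of the formula (the module applies it element-wise):

```rust
// Hypothetical sketch of the PReLU activation function:
// identity for positive inputs, a learned slope alpha for negatives.
fn prelu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 { x } else { alpha * x }
}

fn main() {
    println!("{} {}", prelu(2.0, 0.25), prelu(-4.0, 0.25));
}
```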
Yu Sun 1da47c9bf1
feat: add prelude module for convenience (#1335) 2024-02-24 10:17:30 -05:00
Tushushu 27f2095bcd
Implement Instance Normalization (#1321)
* config

* rename as instances, otherwise won't work

* refactor

* InstanceNormConfig

* remove unused var

* forward

* rename

* based on gn

* unit tests

* fix tests

* update doc

* update onnx doc

* renaming method

* add comment

---------

Co-authored-by: VungleTienan <tienan.liu@vungle.com>
2024-02-23 23:31:43 -06:00
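Instance normalization (#1321) normalizes each channel of each sample independently: subtract the channel mean and divide by the square root of the channel variance plus a small epsilon. A single-channel sketch, ignoring the learnable scale and bias:

```rust
// Hypothetical sketch of instance normalization over one channel:
// (x - mean) / sqrt(variance + eps). Omits gamma/beta parameters.
fn instance_norm(channel: &[f32], eps: f32) -> Vec<f32> {
    let n = channel.len() as f32;
    let mean = channel.iter().sum::<f32>() / n;
    let var = channel.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    channel.iter().map(|v| (v - mean) / denom).collect()
}

fn main() {
    println!("{:?}", instance_norm(&[1.0, 3.0], 0.0));
}
```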
Guillaume Lagrange f5bd2a474f
Check that pa_type is valid before checking if is_binary (#1354) 2024-02-23 09:44:03 -06:00
Dilshod Tadjibaev 08302e38fc
Fix broken test and run-checks script (#1347) 2024-02-23 10:06:51 -05:00
Aasheesh Singh c86db83fa9
Add support for Any, All operations to Tensor (#1342)
* add any, all op implementation for all tensor types

* add op to burn-book

* fix formatting

* refactor tensor operations from numeric to BaseOps.

* fix book doc

* comments fix and add more tests
2024-02-23 10:06:31 -05:00
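The `any`/`all` ops from #1342 reduce a boolean tensor to a single truth value. A flat-buffer sketch of that semantics (the real ops also support per-dimension reduction):

```rust
// Hypothetical sketch: any() is true if at least one element is true,
// all() is true only if every element is true.
fn any(mask: &[bool]) -> bool {
    mask.iter().any(|&b| b)
}

fn all(mask: &[bool]) -> bool {
    mask.iter().all(|&b| b)
}

fn main() {
    println!("{} {}", any(&[false, true]), all(&[false, true]));
}
```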
Dilshod Tadjibaev d6e859330f
Pytorch message updates (#1344)
* Update pytorch-model.md

* Update error.rs
2024-02-22 12:12:50 -06:00
Nathaniel Simard b256c0404e
Refactor/wgpu/memco (#1340) 2024-02-22 07:59:54 -05:00
Guillaume Lagrange bff4961426
Add enum module support (#1337) 2024-02-21 17:03:34 -05:00
Sylvain Benner 4427768570
[refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00