Commit Graph

368 Commits

Author SHA1 Message Date
Guillaume Lagrange 21a0dee4af Fix clippy 2024-09-25 08:01:45 -04:00
Guillaume Lagrange f6b538fc84 Remove unused field 2024-09-24 15:55:52 -04:00
Guillaume Lagrange 15f9d2642f Add note on handles/streams order 2024-09-24 15:53:47 -04:00
Guillaume Lagrange 93d137ddd9 Add comment 2024-09-24 15:51:04 -04:00
Guillaume Lagrange e04c8ad894 Add q_to_device 2024-09-24 15:46:56 -04:00
Guillaume Lagrange d06533d382 Fix jit q_to_device 2024-09-24 14:58:10 -04:00
Guillaume Lagrange 561841b439 Add quantization tests for fusion 2024-09-24 14:56:38 -04:00
Guillaume Lagrange f94944612b Add fusion quantize/dequantize and from/into data 2024-09-24 14:48:11 -04:00
Guillaume Lagrange a6f7a5e532
Remove const D generic (#2298)
* Remove const D generic

* Missing merge conflicts
2024-09-24 08:35:52 -04:00
Genna Wingert 97af8c6d28
Introduce autotuning to `conv2d` and `conv_transpose2d` with a new `im2col`/`GEMM` algorithm (#2287) 2024-09-23 15:54:50 -04:00
Genna Wingert 2c8514ce7f
Add deform_conv2d as implemented in torchvision (#2147) 2024-09-23 15:17:23 -04:00
王翼翔 13ad4d285d
doc: improve doc for burn-tch (#2288)
* doc: improve doc for burn-tch

* improve doc about config.toml

* improve doc about config.toml
2024-09-23 07:52:10 -04:00
Nathaniel Simard 20ab5e31d7
Chore: Update CubeCL (#2292) 2024-09-21 13:28:07 -04:00
Guillaume Lagrange aa79e36a8d
Add more quantization support for burn-jit (#2275)
* Add cubecl quantization kernels and QTensorOps for burn-jit

* Fix typo

* Fix output vec factor

* Fix output dtype size_of

* Remove unused code in dequantize test

* Fix dequantize vectorization

* Handle tensors when number of elems is not a multiple of 4

* Support quantize for tensors with less than 4 elems (no vectorization)

* Fix equal 0 test

* Add quantize/dequantize tests

* Add q_to_device

* Refactor kernels for latest cubecl

* intermediate i32 cast

* Fix size_of output type

* Use strict=false to ignore floating point precision issues with qparams equality

* Only check that lhs & rhs strategies match (but not strict on qparams values)

* Use assert_approx_eq on dequant values

* Reduce precision for flaky test

* Remove todo comment

* Add comment for cast to unsigned

* More comment

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-09-17 10:08:20 -04:00
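Several bullets above deal with tensors whose element count is not a multiple of 4 (falling back to no vectorization for small tensors). A minimal sketch of that idea, with a hypothetical helper name rather than the actual burn-jit code:

```rust
// Hypothetical sketch: pick the widest vectorization factor that evenly
// divides the element count, falling back to scalar (1) otherwise.
fn vectorization_factor(num_elems: usize) -> usize {
    for factor in [4, 2] {
        if num_elems % factor == 0 {
            return factor;
        }
    }
    1
}

fn main() {
    assert_eq!(vectorization_factor(8), 4);
    assert_eq!(vectorization_factor(6), 2);
    assert_eq!(vectorization_factor(7), 1); // not a multiple of 4 or 2
    println!("ok");
}
```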
Asher Jingkong Chen 7ac5deebe2
Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278) 2024-09-16 10:15:27 -04:00
nathaniel 395d84ce71 Fix comments 2024-09-16 09:10:34 -04:00
Periwink a1d2b13e3e
add comments to burn fusion (#2130) 2024-09-16 09:02:12 -04:00
Guillaume Lagrange 6f0e61aa4f
Change ndarray mask_where implementation to correctly deal with NaNs (#2272)
* Change ndarray mask_where implementation to correctly deal with NaNs

* Add test
2024-09-13 15:16:39 -04:00
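The NaN problem this PR fixes arises because arithmetic and comparisons involving NaN do not behave like a true select. A minimal sketch of an elementwise select that lets NaNs pass through; the free-function signature is illustrative, not the actual ndarray backend API:

```rust
// Sketch: select strictly by the boolean mask rather than via arithmetic
// tricks, so NaN values in the input are preserved where the mask is false.
fn mask_where(x: &[f32], mask: &[bool], value: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(mask)
        .zip(value)
        .map(|((&x, &m), &v)| if m { v } else { x })
        .collect()
}

fn main() {
    let x = [f32::NAN, 2.0, 3.0];
    let mask = [false, true, false];
    let value = [9.0, 9.0, 9.0];
    let out = mask_where(&x, &mask, &value);
    assert!(out[0].is_nan()); // NaN passes through where the mask is false
    assert_eq!(out[1], 9.0);  // masked position takes the replacement value
    assert_eq!(out[2], 3.0);
    println!("ok");
}
```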
Nathaniel Simard 58ce502498
Fix (#2269) 2024-09-10 13:36:00 -04:00
Nathaniel Simard d3fbdeaa48
Fix CI (#2268) 2024-09-10 12:13:48 -04:00
Genna Wingert 17050db57e
Migrate cubecl macro (#2266) 2024-09-10 11:31:02 -04:00
Guillaume Lagrange eb899db16c
Add ops w/ default implementation for `QTensorOps` (#2125)
* Add q_* ops to match float ops

* Refactor q_* ops w/ dequant_op_quant macro

* Comparison ops are already implemented by default to compare dequantized values

* Add default arg min/max implementation and fix tch implementation

* Avoid division by zero scale

* Add default q_gather implementation (tch does not support on quantized tensor)

* Add warning instead for tch quantize_dynamic

* Call chunk backend implementation

* Add QFloat check for q_ ops

* Add tch q_min/max_dim_with_indices

* Add q_ ops tests

* Clippy fix

* Remove dead code/comments

* Fix quantization tests precision

* Set higher tolerance for ndarray backend

* Remove comment
2024-09-09 12:21:47 -04:00
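The `dequant_op_quant` pattern mentioned above — dequantize, run the float op, requantize — can be sketched like this. The scale and zero-point values are made up for illustration; this is not Burn's actual quantization API:

```rust
// Illustrative dequant-op-quant: affine int8 quantization with a per-tensor
// scale and zero point. The float op runs on dequantized values.
fn dequant_op_quant(q: &[i8], scale: f32, zero_point: i32, op: impl Fn(f32) -> f32) -> Vec<i8> {
    q.iter()
        .map(|&v| (v as i32 - zero_point) as f32 * scale) // dequantize
        .map(op)                                          // float op
        .map(|v| ((v / scale).round() as i32 + zero_point).clamp(-128, 127) as i8) // requantize
        .collect()
}

fn main() {
    // Adding 1.0 with scale 0.1 shifts each quantized value by 10 steps.
    let out = dequant_op_quant(&[5, 13, 23], 0.1, 3, |x| x + 1.0);
    assert_eq!(out, vec![15, 23, 33]);
    println!("ok");
}
```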
Joshua Ferguson 9e9451bb60
simplify scope tracking in burn-import (#2207)
* simplify scope tracking in burn-import

* removed unnecessary return statement
2024-09-09 12:19:26 -04:00
Asher Jingkong Chen ccb5b2214e
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
2024-09-09 11:16:13 -04:00
Nathaniel Simard 94cd8a2556
Perf/slice (#2252) 2024-09-09 11:08:39 -04:00
Paul Wagener 0f191e67aa
Fix panic messages being invisible in tui mode (#2226)
* Fix panic messages being invisible in tui mode

Currently, when a panic happens, the message gets printed to the alternate screen, which gets erased when the terminal is restored from raw mode in the TuiMetricsRenderer drop code.

That leaves users unable to see the panic message (issue #2062).

This commit changes TuiMetricsRenderer to reset the terminal first during a panic and then run the panic handler.

* Use PanicInfo to support Rust version < 1.82
2024-09-06 16:22:00 -04:00
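The fix described above — restore the terminal before the panic message is printed — is the standard panic-hook chaining pattern. A hedged sketch, where `restore_terminal` is a placeholder for the renderer's real raw-mode/alternate-screen teardown:

```rust
fn restore_terminal() {
    // In the real renderer this would disable raw mode and leave the
    // alternate screen; here it is just a placeholder.
    println!("terminal restored");
}

fn main() {
    // Chain the default hook: restore the terminal first, then let the
    // original hook print the panic message to the normal screen.
    let default_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |info| {
        restore_terminal();
        default_hook(info);
    }));

    // Demonstrate that the hook fires before the panic message is printed.
    let _ = std::panic::catch_unwind(|| panic!("boom"));
}
```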
Nathaniel Simard a567c6e888
Fusion mix precision (#2247) 2024-09-05 10:53:26 -04:00
Asher Jingkong Chen fc311323d9
[burn-autodiff] Fix abs NaN when output is 0 (#2249) 2024-09-05 09:03:24 -04:00
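For context on the abs fix: the backward pass of `abs` multiplies the incoming gradient by the sign of the input, and computing the sign as `x / |x|` yields NaN at zero. A minimal sketch of the special-casing idea (not the actual burn-autodiff code):

```rust
// Sketch: gradient of abs(x) is sign(x) * grad. The naive sign x / x.abs()
// is NaN at x == 0, so zero is handled explicitly.
fn abs_backward(x: f32, grad: f32) -> f32 {
    if x == 0.0 { 0.0 } else { (x / x.abs()) * grad }
}

fn main() {
    assert_eq!(abs_backward(-2.0, 1.0), -1.0);
    assert_eq!(abs_backward(3.0, 0.5), 0.5);
    assert_eq!(abs_backward(0.0, 1.0), 0.0); // naive x / x.abs() would be NaN
    println!("ok");
}
```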
Adrian Müller 6b51b73a5f
Fix ONNX where op for scalar inputs (#2218)
* Fix ONNX where op dim_inference for scalar inputs

* Rewrite ONNX Where codegen to support scalars

* ONNX Where: Add tests for all_scalar inputs

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
2024-09-03 11:17:18 -04:00
Guillaume Lagrange 59d41bd4b2
Remove copy restriction for const generic modules (#2222) 2024-09-03 09:39:12 -04:00
Guillaume Lagrange cc214d366c
Nonzero should return an empty vec for zero tensors (#2212)
* Nonzero should return an empty vec for zero tensors

* Add nonzero empty test

* Add missing import

---------

Co-authored-by: Nathaniel Simard <nathaniel.simard.42@gmail.com>
2024-09-03 09:00:58 -04:00
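The behavior this PR pins down can be sketched for the 1-D case; the helper below is illustrative, not Burn's tensor API:

```rust
// Sketch: nonzero returns one index list per dimension, but an all-zero
// input yields an empty Vec rather than a Vec containing an empty list.
fn nonzero_1d(values: &[i32]) -> Vec<Vec<usize>> {
    let indices: Vec<usize> = values
        .iter()
        .enumerate()
        .filter(|(_, &v)| v != 0)
        .map(|(i, _)| i)
        .collect();
    if indices.is_empty() { vec![] } else { vec![indices] }
}

fn main() {
    assert_eq!(nonzero_1d(&[0, 3, 0, 5]), vec![vec![1, 3]]);
    assert!(nonzero_1d(&[0, 0]).is_empty()); // zero tensor: empty vec
    println!("ok");
}
```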
Paul Wagener c1b61033f4
Fix compile for dataset crate with vision feature (#2228)
This fixes the compile error when burn is compiled with only the `dataset` and `vision` features enabled:

burn = { default-features = false, features = ["dataset", "vision"] }
2024-09-01 17:03:37 -04:00
Guillaume Lagrange 09a15e7e15
Avoid 0 denominator in interpolate frac (#2224) 2024-09-01 16:37:32 -04:00
Paul Wagener 23622d765d
Don't panic when the progress is > 1.0 (#2229)
Ratatui asserts that gauges don't have a progress greater than 1.0.
This can happen if a dataset reports a lower len() than it actually provides.

This change prevents a panic when the `Progress::items_processed` is greater than the `Progress::items_total`
2024-09-01 16:33:25 -04:00
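The clamping idea behind this fix can be sketched as follows; the field names mirror the commit message, and the function is a hypothetical helper rather than Burn's actual `Progress` code:

```rust
// Sketch: clamp the progress ratio before handing it to ratatui's Gauge,
// which asserts that progress never exceeds 1.0.
fn gauge_ratio(items_processed: usize, items_total: usize) -> f64 {
    if items_total == 0 {
        return 0.0;
    }
    (items_processed as f64 / items_total as f64).min(1.0)
}

fn main() {
    assert_eq!(gauge_ratio(5, 10), 0.5);
    // A dataset that under-reports len() can push processed past total:
    assert_eq!(gauge_ratio(12, 10), 1.0);
    println!("ok");
}
```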
王翼翔 66ee3bb3bc
Update huber.rs (#2232) 2024-09-01 16:31:07 -04:00
Nathaniel Simard 0dbb7f7e91
Chore: Update cubecl (#2219) 2024-08-30 15:28:00 -04:00
Guillaume Lagrange a9abd8f746
Add missing output padding to conv transpose ONNX (#2216)
* Add output_padding support for ONNX ConvTranspose

* Add missing codegen

* Fix output padding codegen test
2024-08-29 14:07:00 -04:00
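For context on why `output_padding` matters: the output length of a transposed convolution (dilation 1) is `(in - 1) * stride + kernel - 2 * padding + output_padding`, and without the last term a strided transpose cannot distinguish between several possible input sizes. A small illustrative helper:

```rust
// 1-D transposed-convolution output length (dilation = 1):
// out = (in - 1) * stride + kernel - 2 * padding + output_padding
fn conv_transpose_out(input: usize, kernel: usize, stride: usize, padding: usize, output_padding: usize) -> usize {
    (input - 1) * stride + kernel - 2 * padding + output_padding
}

fn main() {
    // Without output_padding, stride-2 upsampling of length 4 gives 7 ...
    assert_eq!(conv_transpose_out(4, 3, 2, 1, 0), 7);
    // ... and output_padding = 1 restores the even length 8.
    assert_eq!(conv_transpose_out(4, 3, 2, 1, 1), 8);
    println!("ok");
}
```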
Dilshod Tadjibaev 28c2d4e3cd
Update SUPPORTED-ONNX-OPS.md (#2217) 2024-08-29 14:06:42 -04:00
Adrian Müller e8ea9e27c2
Improve ONNX import tensor shape tracking (#2213)
- Calculate result of broadcasting in dim_inference
- Keep Shape info when converting from Argument to TensorType
- Remove a few sources of Dim = 0 Tensors, create Scalars instead
- Clean up dim_inference a bit
2024-08-29 14:06:30 -04:00
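The broadcasting result computed in dim_inference follows the usual right-aligned rule: trailing dimensions must match or one of them must be 1. A minimal sketch (hypothetical helper, not the onnx-ir code):

```rust
// Sketch of NumPy/ONNX-style broadcasting: align shapes from the right,
// dims must be equal or 1; None signals incompatible shapes.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible dimensions
        }
    }
    out.reverse();
    Some(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[4, 1, 3], &[2, 3]), Some(vec![4, 2, 3]));
    assert_eq!(broadcast_shape(&[2], &[3]), None);
    println!("ok");
}
```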
Adrian Müller 2f4c5ac0a1
Feat: Allow onnx-import expand op with non-const shapes (#2189)
* Feat: Allow onnx-import expand op with non-const shapes

* Generalize ONNX Expand across IntElem
2024-08-29 13:15:44 -04:00
Sylvain Benner a88c69af4a
Refactor xtask to use tracel-xtask and refactor CI workflow (#2063)
* Migrate to xtask-common crate

* Fix example crate name for simple-regression

* Refactor CI workflows

* Flatten linux workflows

* Install grcov and typos from binaries

Although xtask-common supports auto-installation of these tools via cargo,
it is a lot faster to install them via the distributed binaries

* [CI] Update Rust caches on failure

* [CI] Add shell bash to jobs steps

* [CI] Try cache all crates

* Fix no-std tests not executing

* [CI] Add CARGO_INCREMENTAL 0

* Exclude tch and cuda from tests and merge crates and examples steps

* Fix some typos found with typos cli

* Add Windows and MacOS jobs

* Only test no-std with default rust target

* Fix syntax in composite action setup-windows

* Enable incremental build

* Update cargo alias for xtask

* Bump to github action checkout v4

* Revert to tch 0.15 and disable WGPU on windows

* Fix color in output

* Add Test command

* Test long output erroring

* Build and test workspace before additional builds and tests

* Disable wgpu tests on windows

* Remove tests- prefix in CI workflow jobs name

* Add Checks command

* Rename ci workflow jobs

* Execute windows and macos CI tests on rust stable only

* Rename integration test files with a test_ prefix

* Fix format

* Don't auto-correct "arange" with typos

* Fix typos in code

* Merge unit and integration tests steps

* Fix macos tests

* Fix coverage step

* Name publish-crate workflow

* Fix bad cache name for macos

* Reorganize commands and get rid of the ci command

* Fix dispatch to customized commands for Burn

* Update to last version of tracel-xtask

* Remove unnecessary shell bash in ci workflow

* Update cargo.lock

* Fix format

* Bump tracel-xtask

* Simplify dispatch of base commands using updated macro

* Update to last version of tracel-xtask

* Adapt legacy run_checks script with new xtask commands

* Run xtask in debug for faster compilation time

* Ditch build step in ci and enable coverage for stable linux only

* Freeze tracel-xtask to specific commit rev

* Update cargo.lock

* Update Step 6 of CONTRIBUTING guidelines about run-checks script

* Remove unneeded CI and CD paragraphs in CONTRIBUTING.md

* Change cache version

* Fix typos

* Use centralized actions and workflows

* Update to last version of tracel-xtask

* Update CONTRIBUTING file to mention integration tests

* Add custom build for thumbv6m-none-eabi

* Ignore onnx files for typos check

* Fix action and workflow paths in github workflows

* Fix custom builds on MacOS

* Bump tracel-xtask crate to last version

* Update Cargo.lock

* Update publish workflow to use reusable workflow in tracel repo

* Add --ci flag for build and test commands
2024-08-28 15:57:13 -04:00
AlteredOxide 0292967000
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented multi-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases has been added
    - test axes 0 and 1
    - both use 2D index tensors

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to avoid confusion with the existing `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant to the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen inlined code
  after translating from public api (removed in previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural
- include a 1d index gather onnx test
- remove superfluous tests

* delete unused gather.onnx
2024-08-28 07:51:19 -04:00
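The ONNX `Gather` semantics this PR implements differ from a plain `select`: each entry of an arbitrary-rank `indices` tensor picks a full slice along `axis`, so the output rank is `indices.rank() + data.rank() - 1`. A minimal sketch for axis 0 on a 2-D matrix, using a hypothetical helper rather than the generated codegen:

```rust
// Sketch of ONNX Gather on a 2-D matrix with axis = 0: every index in the
// 2-D `indices` tensor selects a whole row, so the output shape is
// indices.shape ++ data.shape[1..].
fn gather_axis0(data: &[Vec<i32>], indices: &[Vec<usize>]) -> Vec<Vec<Vec<i32>>> {
    indices
        .iter()
        .map(|row| row.iter().map(|&i| data[i].clone()).collect())
        .collect()
}

fn main() {
    let data = vec![vec![1, 2], vec![3, 4], vec![5, 6]];
    let indices = vec![vec![0, 2], vec![1, 1]];
    let out = gather_axis0(&data, &indices);
    // Output rank is 2 + 2 - 1 = 3.
    assert_eq!(out, vec![
        vec![vec![1, 2], vec![5, 6]],
        vec![vec![3, 4], vec![3, 4]],
    ]);
    println!("ok");
}
```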
mepatrick73 795201dcfc
Select kernel from CPA to CubeCL (#2168)
---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-08-27 15:17:58 -04:00
syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
nathaniel 4e99ddecdf Fix burn-import version to onnx-ir 2024-08-27 12:58:08 -04:00
Nathaniel Simard 79cd3d5d21
Fix gather unchecked kernel (#2206) 2024-08-26 12:23:02 -04:00
Nathaniel Simard 978ac6c4ec
Chore: Update to newer cubecl version (#2181) 2024-08-25 15:33:16 -04:00
mepatrick73 0beec0e39e
Scatter kernel from cpa to cubecl (#2169) 2024-08-25 13:47:16 -04:00
mepatrick73 c94e743829
Tensor type indent fix (#2196)
* pad-input-fix: adding support for pads as attributes

* final fix

* undo pad changes
2024-08-23 12:46:31 -04:00
mepatrick73 2c12d58cd8
pad-input-fix: adding support for pads as attributes (#2195)
* pad-input-fix: adding support for pads as attributes

* fix: making asked changes

* clippy fix
2024-08-23 12:46:14 -04:00