Commit Graph

111 Commits

Author SHA1 Message Date
Nathaniel Simard 20ab5e31d7
Chore: Update CubeCL (#2292) 2024-09-21 13:28:07 -04:00
Guillaume Lagrange aa79e36a8d
Add more quantization support for burn-jit (#2275)
* Add cubecl quantization kernels and QTensorOps for burn-jit

* Fix typo

* Fix output vec factor

* Fix output dtype size_of

* Remove unused code in dequantize test

* Fix dequantize vectorization

* Handle tensors when number of elems is not a multiple of 4

* Support quantize for tensors with less than 4 elems (no vectorization)

* Fix equal 0 test

* Add quantize/dequantize tests

* Add q_to_device

* Refactor kernels for latest cubecl

* intermediate i32 cast

* Fix size_of output type

* Use strict=false to ignore floating point precision issues with qparams equality

* Only check that lhs & rhs strategies match (but not strict on qparams values)

* Use assert_approx_eq on dequant values

* Reduce precision for flaky test

* Remove todo comment

* Add comment for cast to unsigned

* More comment

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-09-17 10:08:20 -04:00
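The vectorization fix in this PR (handling tensors whose element count is not a multiple of 4) can be sketched outside CubeCL as plain Rust. This is a minimal illustration only: the function names and the u8 affine scheme are assumptions, not burn's actual kernel API.

```rust
// Illustrative sketch: affine quantization of an f32 slice into u8,
// processing chunks of 4 (standing in for vectorized loads/stores) with a
// scalar tail for tensors whose length is not a multiple of 4.
fn quantize_affine(values: &[f32], scale: f32, zero_point: u8) -> Vec<u8> {
    let q = |v: f32| {
        // round(v / scale) + zero_point, clamped to the u8 range
        ((v / scale).round() as i32 + zero_point as i32).clamp(0, 255) as u8
    };
    let mut out = Vec::with_capacity(values.len());
    let mut chunks = values.chunks_exact(4);
    for chunk in &mut chunks {
        // In the real kernel this path maps to a vectorized (line size 4) op.
        out.extend(chunk.iter().map(|&v| q(v)));
    }
    // Scalar fallback for the remaining 1-3 elements.
    out.extend(chunks.remainder().iter().map(|&v| q(v)));
    out
}

fn dequantize_affine(qs: &[u8], scale: f32, zero_point: u8) -> Vec<f32> {
    qs.iter()
        .map(|&q| (q as i32 - zero_point as i32) as f32 * scale)
        .collect()
}

fn main() {
    let data = [0.0_f32, 0.5, 1.0, 1.5, 2.0]; // 5 elems: not a multiple of 4
    let q = quantize_affine(&data, 0.5, 0);
    assert_eq!(q, vec![0, 1, 2, 3, 4]);
    assert_eq!(dequantize_affine(&q, 0.5, 0), data.to_vec());
}
```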
Nathaniel Simard 58ce502498
Fix (#2269) 2024-09-10 13:36:00 -04:00
Nathaniel Simard d3fbdeaa48
Fix CI (#2268) 2024-09-10 12:13:48 -04:00
Genna Wingert 17050db57e
Migrate cubecl macro (#2266) 2024-09-10 11:31:02 -04:00
Guillaume Lagrange eb899db16c
Add ops w/ default implementation for `QTensorOps` (#2125)
* Add q_* ops to match float ops

* Refactor q_* ops w/ dequant_op_quant macro

* Comparison ops are already implemented by default to compare dequantized values

* Add default arg min/max implementation and fix tch implementation

* Avoid division by zero scale

* Add default q_gather implementation (tch does not support on quantized tensor)

* Add warning instead for tch quantize_dynamic

* Call chunk backend implementation

* Add QFloat check for q_ ops

* Add tch q_min/max_dim_with_indices

* Add q_ ops tests

* Clippy fix

* Remove dead code/comments

* Fix quantization tests precision

* Set higher tolerance for ndarray backend

* Remove comment
2024-09-09 12:21:47 -04:00
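The `dequant_op_quant` refactor named in the bullets follows a standard pattern: dequantize to float, run the float op, requantize the result. A stand-alone sketch under assumptions (u8 affine scheme, hypothetical names; burn's macro operates on tensor primitives, not slices):

```rust
// Illustrative dequantize -> float op -> requantize wrapper, the shape of the
// default q_* implementations. Scheme and names are assumptions.
fn dequant_op_quant<F>(q: &[u8], scale: f32, zero_point: i32, op: F) -> Vec<u8>
where
    F: Fn(f32) -> f32,
{
    // Guard against a zero scale, per "Avoid division by zero scale".
    let scale = scale.max(f32::EPSILON);
    q.iter()
        .map(|&v| (v as i32 - zero_point) as f32 * scale) // dequantize
        .map(&op)                                         // float op
        .map(|v| ((v / scale).round() as i32 + zero_point).clamp(0, 255) as u8)
        .collect()
}

fn main() {
    // Double every value of a tensor quantized with scale 0.5, zero point 0.
    let doubled = dequant_op_quant(&[1, 2, 3], 0.5, 0, |x| x * 2.0);
    assert_eq!(doubled, vec![2, 4, 6]);
}
```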
Asher Jingkong Chen ccb5b2214e
Fix burn-jit conv2d excessive loop unrolling (#2263)
* Related to issue #2260
2024-09-09 11:16:13 -04:00
Nathaniel Simard 94cd8a2556
Perf/slice (#2252) 2024-09-09 11:08:39 -04:00
Nathaniel Simard a567c6e888
Fusion mix precision (#2247) 2024-09-05 10:53:26 -04:00
Guillaume Lagrange 09a15e7e15
Avoid 0 denominator in interpolate frac (#2224) 2024-09-01 16:37:32 -04:00
Nathaniel Simard 0dbb7f7e91
Chore: Update cubecl (#2219) 2024-08-30 15:28:00 -04:00
mepatrick73 795201dcfc
Select kernel from CPA to CubeCL (#2168)
---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2024-08-27 15:17:58 -04:00
syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
Nathaniel Simard 79cd3d5d21
Fix gather unchecked kernel (#2206) 2024-08-26 12:23:02 -04:00
Nathaniel Simard 978ac6c4ec
Chore: Update to newer cubecl version (#2181) 2024-08-25 15:33:16 -04:00
mepatrick73 0beec0e39e
Scatter kernel from cpa to cubecl (#2169) 2024-08-25 13:47:16 -04:00
mepatrick73 e1fed792f7
Gather CPA to CubeCL (#2165)
* working version

* cleanup

* wip

* working version of gather

* testsetsetser

* Revert "testsetsetser"

This reverts commit f37b329697.

* Reapply "testsetsetser"

This reverts commit f8ada0044e.

* Revert "testsetsetser"

This reverts commit f37b329697.

* Revert "working version of gather"

This reverts commit f5047c27c8.

* Revert "wip"

This reverts commit abaaa2dd55.

* Revert "Merge branch 'main' into index-cpa-to-cubecl"

This reverts commit 05bed8ea74, reversing
changes made to 94954fc32c.

* Revert "cleanup"

This reverts commit 94954fc32c.

* Revert "working version"

This reverts commit a06933f029.

* gather test

* fix

* fix clippy

* cleanup
2024-08-22 13:44:26 -04:00
Periwink 0435721188
Convert `reduce_dim_naive` kernel to use the `#[cube]` derive macro (#2117) 2024-08-14 10:46:37 -04:00
Nathaniel Simard ff8d0308fb
Enable cuda-jit in burn-core + in text classification example (#2160) 2024-08-12 18:22:27 -04:00
Periwink e75eebfc31
Add comments for matmul kernel (#2138) 2024-08-12 09:09:24 -04:00
Nathaniel Simard bb4a605ca6
Chore/integrate updated cubecl (#2142) 2024-08-08 16:19:39 -04:00
mepatrick73 f7639bd35a
Repeat operation (#2090)
* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
2024-08-02 20:33:47 -04:00
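The renaming in this PR frees the `repeat` name for PyTorch-style semantics: `repeat_dim` repeats along one dimension, while `torch.Tensor.repeat` takes a repeat count per dimension. A rough sketch of the two semantics on a `Vec`-backed 2D "tensor" (illustrative only, not burn's tensor API):

```rust
// Illustrative semantics: `repeat_dim` repeats along a single dimension;
// `repeat` takes one repeat count per dimension, like torch.Tensor.repeat.
fn repeat_dim(t: &[Vec<i32>], dim: usize, times: usize) -> Vec<Vec<i32>> {
    match dim {
        0 => {
            // Stack `times` copies of the rows.
            let mut out = Vec::new();
            for _ in 0..times {
                out.extend(t.iter().cloned());
            }
            out
        }
        _ => {
            // Tile each row `times` times.
            t.iter()
                .map(|row| row.iter().copied().cycle().take(row.len() * times).collect())
                .collect()
        }
    }
}

fn repeat(t: &[Vec<i32>], times: &[usize; 2]) -> Vec<Vec<i32>> {
    // Repeat along every dimension in turn.
    repeat_dim(&repeat_dim(t, 0, times[0]), 1, times[1])
}

fn main() {
    let t = vec![vec![1, 2]];
    assert_eq!(repeat_dim(&t, 0, 2), vec![vec![1, 2], vec![1, 2]]);
    assert_eq!(repeat(&t, &[2, 2]), vec![vec![1, 2, 1, 2], vec![1, 2, 1, 2]]);
}
```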
Nathaniel Simard 62a30e973c
Fix: fusion auto bound checks (#2087) 2024-08-01 09:13:10 -04:00
Nathaniel Simard f673721d27
Refactor binary op (#2085) 2024-07-31 16:18:21 -04:00
Louis Fortier-Dubois e68b9ab0cc
Refactor/jit cube/mask (#2075)
Co-authored-by: louisfd <louisfd@gmail.com>
2024-07-30 09:27:39 -04:00
Nathaniel Simard 096ec13c48
Chore/update/cubecl (#2067) 2024-07-28 12:15:02 -04:00
Nathaniel Simard 19cd67a9e2
Migration/cubecl (#2041) 2024-07-22 11:08:40 -04:00
Guillaume Lagrange 0d5025edbb
Refactor tensor quantization for q_* ops (#2025)
* Move QuantizationScheme to burn-tensor

* Refactor QuantizedTensorPrimitive to include the quantization strategy

* Fix QFloat tensor data display

* Refactor quantization methods to use scheme and qparams (on backend device)

* Fix clippy

* Fix fmt

* Add qtensor primitive tests
2024-07-19 10:39:50 -04:00
Guillaume Lagrange 7661deb258
Fix image-classification-web + autotune flag usage (#2011) 2024-07-15 09:31:54 -04:00
Guillaume Lagrange 3afff434bd
Module weight quantization (#2000)
* Add q_into_data and q_reshape

* Fix tch quantize f16 and q_into_data

* Convert to actual dtype/kind in dequantize

* Add module quantization and q_from_data

* Fix clippy

* Add documentation

* Handle deserialize data conversion

* Fix typo

* Add calibration tests

* Fix clippy precision

* Add QTensorOps require_grad methods to avoid dequantizing

* Add Dequantize mapper docs

* Remove dead code
2024-07-15 08:20:37 -04:00
Nathaniel Simard 19f5ad7be5
Refactor/cube/expand & fix double imports (#2009)
* Refactored function

* WIP

* Basic stuff done

* Fix traits

* Cleanup

* Cleanup

* Cleanup
2024-07-12 09:18:38 -04:00
Nathaniel Simard 35345de62a
Feat/cube/slice (#2004)
* Refactor Variable types

* Slice

* Implement slice wgsl

* handle lifetime correctly

* Add cuda impl

* Update cmma

* Cleanup

* Fix tests

* Fix slice signature
2024-07-11 11:28:53 -04:00
Louis Fortier-Dubois 69be99b802
Cube: Matmul tiling (#1994) 2024-07-09 12:43:13 -04:00
Nathaniel Simard 924e3578ee
Fix CI (#1993) 2024-07-08 15:55:05 -04:00
Guillaume Lagrange c0211e2f94
Add static tensor quantization (#1963)
* Add QuantizationBackend, QTensorOps and QTensor

* Refactor QTensorOps as part of Backend trait

* Add tensor dequantize, QFloat dtype and default affine/symmetric quant

* Add ndarray default quantization implementation

* Fix clippy

* Add rayon parallel iter

* Add quantization operations to book

* Add q_shape and q_device ops to avoid converting the tensor just to get attributes

* Implement autodiff grad ops

* Mark autodiff todo for QAT

* Remove note

* Add q_inner and q_from_inner
2024-07-08 10:16:58 -04:00
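The default affine/symmetric schemes this PR adds follow the usual per-tensor formulation. A hedged sketch of deriving the quantization parameters from a tensor's range (u8 affine, i8 symmetric; function names are illustrative, not burn's API):

```rust
// Illustrative per-tensor quantization parameters.
// Affine (u8): maps [min, max] onto [0, 255] via a scale and zero point.
// Symmetric (i8): zero point fixed at 0, scale from the largest magnitude.
fn affine_qparams(min: f32, max: f32) -> (f32, u8) {
    let scale = ((max - min) / 255.0).max(f32::EPSILON); // avoid zero scale
    let zero_point = (-min / scale).round().clamp(0.0, 255.0) as u8;
    (scale, zero_point)
}

fn symmetric_qparams(min: f32, max: f32) -> f32 {
    (min.abs().max(max.abs()) / 127.0).max(f32::EPSILON)
}

fn main() {
    let (scale, zp) = affine_qparams(-0.5, 1.5);
    // Range width 2.0 spread over 255 steps; zero point lands at ~63.75 -> 64.
    assert!((scale - 2.0 / 255.0).abs() < 1e-6);
    assert_eq!(zp, 64);
    assert!((symmetric_qparams(-0.5, 1.0) - 1.0 / 127.0).abs() < 1e-6);
}
```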
Nathaniel Simard 8af2b719a1
Feat: Support trait with CubeCL (#1980) 2024-07-07 10:07:51 -04:00
Arthur Brussee 3f9e97946f
Feat: Dynamic cube count dispatch (#1975) 2024-07-06 19:17:01 -04:00
Nathaniel Simard b331290f8a
Refactor/jit/unary (#1965) 2024-07-05 19:47:24 -04:00
nathaniel 882a27c52c Revert "Revert "Implement 3D and transposed 3D convolutions. (#1945)""
This reverts commit b8b47ea6e6.
2024-07-05 18:57:01 -04:00
nathaniel b8b47ea6e6 Revert "Implement 3D and transposed 3D convolutions. (#1945)"
This reverts commit d696d74e3d.
2024-07-05 09:40:32 -04:00
Nathaniel Simard 679cfd6dfb
Refactor cube launch + support inplace operation (#1961) 2024-07-03 11:58:35 -04:00
Guillaume Charifi d696d74e3d
Implement 3D and transposed 3D convolutions. (#1945)
* Implement 3D and transposed 3D convolutions.

* Merge changes from onnx-ir #1921 pr

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
2024-07-02 17:54:35 -05:00
Nathaniel Simard 82a883a57d
Feat/cube/fma (#1947) 2024-07-02 08:32:39 -04:00
Nathaniel Simard cb6b5e7183
Feat/cube/cooperative matrix-multiply and accumulate. (#1943) 2024-07-02 08:31:00 -04:00
Arthur Brussee 849c8f453b
Consistent sync/async handling, allow more functions to be async for wasm. (#1936) 2024-07-02 08:25:28 -04:00
Nathaniel Simard 1ae1c03b2d
Refactor/cube/mutability (#1934) 2024-06-27 16:03:23 -04:00
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is not generic over the element type anymore no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprinln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
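The key idea of this refactor, `TensorData` holding raw bytes plus a runtime dtype instead of being generic over the element type (which is why the tests must now cast expected results explicitly), can be sketched as follows. The struct and method names are simplified assumptions, not burn's exact definitions:

```rust
// Simplified sketch: TensorData stores raw bytes and a runtime dtype, so
// reading values back requires naming a concrete element type; the compiler
// can no longer infer it as with the old generic Data type.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum DType {
    F32,
    I64,
}

struct TensorData {
    bytes: Vec<u8>,
    shape: Vec<usize>,
    dtype: DType,
}

impl TensorData {
    fn from_f32(values: &[f32], shape: Vec<usize>) -> Self {
        let bytes = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self { bytes, shape, dtype: DType::F32 }
    }

    fn to_vec_f32(&self) -> Vec<f32> {
        // A runtime dtype check replaces the old compile-time guarantee.
        assert_eq!(self.dtype, DType::F32);
        self.bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()
    }
}

fn main() {
    let data = TensorData::from_f32(&[1.0, 2.0], vec![2]);
    assert_eq!(data.to_vec_f32(), vec![1.0, 2.0]);
    assert_eq!(data.shape, vec![2]);
}
```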
Louis Fortier-Dubois 8bf1cd60dc
Cube: variable reusability + refactor in cube macros (#1885) 2024-06-14 11:20:25 -04:00
Guillaume Lagrange 525244062f
Implement `Element` for `bool` (#1878)
* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
2024-06-14 09:02:38 -04:00
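The motivation in these bullets is that the `num-traits` conversion traits are not implemented for `bool`, so the crate defines its own. A minimal illustration of the pattern (the trait and method names here are assumptions, not the exact `ToElement` definition):

```rust
// Sketch of the pattern: a crate-local conversion trait that, unlike
// num-traits' ToPrimitive, can also cover bool.
trait ToElement {
    fn to_f32(&self) -> f32;
    fn to_i32(&self) -> i32;
}

impl ToElement for bool {
    fn to_f32(&self) -> f32 {
        if *self { 1.0 } else { 0.0 }
    }
    fn to_i32(&self) -> i32 {
        *self as i32
    }
}

impl ToElement for f32 {
    fn to_f32(&self) -> f32 {
        *self
    }
    fn to_i32(&self) -> i32 {
        *self as i32 // truncates toward zero
    }
}

fn main() {
    assert_eq!(true.to_f32(), 1.0);
    assert_eq!(false.to_i32(), 0);
    assert_eq!(2.7_f32.to_i32(), 2);
}
```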
Arthur Brussee c873d87ac8
Add option to flush queue instead of waiting for completion. (#1864)
* Make sync_type an option on sync instead of adding submit
2024-06-13 09:56:08 -04:00