Commit Graph

95 Commits

Author SHA1 Message Date
Laurent Mazare b3484e7a5e
Fix for the RWKV models. (#1955)
* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
2024-03-28 10:17:38 +01:00
Laurent Mazare ab86cd37c8
Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal.

* Add some testing of index-select for all dtypes.
2024-03-27 16:30:07 +01:00
Laurent Mazare a9abde5f93
More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
Thomas Santerre f5dfe883d7
Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations

* update dtypes for upsample
2024-03-26 06:48:56 +01:00
Laurent Mazare e7f8e72588
Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
2024-03-25 09:11:20 +01:00
Laurent Mazare 1b98f84a2b
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.

* Add a test for the fast CPU kernel.

* Rope cuda bindings.

* Cuda kernel.

* Metal kernel (part 1).

* Cuda kernels.

* Finish the metal kernel.

* Use the new kernels in the quantized example.

* Fix warning.
2024-03-24 22:48:52 +01:00
Thomas Santerre fee33b45c2
Add support for strided index-select on Metal (#1909)
* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
2024-03-22 07:30:02 +01:00
Thomas Santerre 9563a5fee4
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types

* update bench calculation

* enable testing all conv operations on metal
2024-03-21 18:08:45 +01:00
Laurent Mazare 0fddec762e
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
2024-03-21 09:48:56 +01:00
Thomas Santerre 2a8679509e
Add support for conv_transpose1d for metal backend (#1874)
* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
2024-03-19 08:46:58 +01:00
Thomas Santerre 04a61a9c72
Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d

* fixX

* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
Thomas Santerre 754fa1e813
Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d

* Add definitions for other dtypes

* add tests for other dtypes

* Cosmetic tweaks + re-enable maxpool2d tests for metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00
Thomas Santerre 184105792f
add test for index add and add missing match statements (#1862) 2024-03-17 22:19:12 +01:00
Thomas Santerre e316cb6997
add support for casting between all datatypes (#1860) 2024-03-17 20:55:11 +01:00
Laurent Mazare ce9fbc3682
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
2024-03-17 10:49:13 +01:00
Thomas Santerre db8b24ae92
Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)
* add support and tests for scatter add on metal

* add support for all datatypes
2024-03-17 08:09:43 +01:00
Laurent Mazare e7fc1daa21
Bump the crate versions to 0.4.2. (#1821) 2024-03-08 22:01:51 +01:00
Niklas Hallqvist be5b68cd0b
Metal random-generation bug fixes (#1811)
* use_resource API misunderstood. It is not additive. Several usages must be bit-ORed together.

* The seeding was incorrect and used the address instead of the value of the passed in seed.

* Add a check that likely exhibits failure to update the seed between generation of random tensors.

* Buffer overrun, the length given to the std::ptr::copy call was in bytes, and not 32-bit units.

* By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine.
Use device.set_seed if determinism is warranted.

* Revert "By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine. Use device.set_seed if determinism is warranted."

This reverts commit d7302de9

Discussion in https://github.com/huggingface/candle/pull/1811#issuecomment-1983079119

* The Metal random kernel failed to set element N/2 of tensors with N elements, N being even.  The reason was that all threads but thread 0 all created 2 random samples, but thread 0 only one, i.e. an odd number.  In order to produce an even number of samples, the early termination of thread 0 should only everr occur for odd sized tensors.

* Add a test catching any deterministic tensor element in rand and randn output.

---------

Co-authored-by: niklas <niklas@appli.se>
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-03-08 16:11:50 +01:00
Laurent Mazare 5e526abc8c
Bump the version number to 0.4.1. (#1768)
* Fix the block size for some cuda kernels.

* Bump the version number to 0.4.1.
2024-02-27 14:19:59 +01:00
OlivierDehaene b60064780d
feat: add silu activation function (#1706)
* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
2024-02-14 10:27:22 +01:00
Laurent Mazare a83ca2ece0
Bump the crate version to 0.4.0. (#1658) 2024-02-04 19:08:01 +01:00
Christopher Fleetwood 6d83d42efb
Merge pull request #1606 from FL33TW00D/feature/larger-batches
fix: larger batches
2024-01-29 15:31:10 +00:00
FL33TW00D b6afb46601
chore: final 2024-01-22 15:15:19 +00:00
ivarflakstad fd7c856564
Merge pull request #1533 from huggingface/ivarflakstad/metal-prng 2024-01-22 07:30:20 +01:00
FL33TW00D 73d79e6092
chore: actual fix 2024-01-19 09:35:42 +00:00
FL33TW00D b1879f17f6
chore: switch to buffer 2024-01-19 08:57:49 +00:00
FL33TW00D 4f79f5df8a
fix: larger batches 2024-01-18 14:30:14 +00:00
ivarflakstad 1cf34368b7
Merge pull request #1602 from mimiquate/fix-metal-kernel-type
Metal: Use uint8_t as output type in int64_t binary op kernel
2024-01-18 08:40:34 +01:00
Gonzalo 17e6e2d7ee
Fixes metal kernel u8 type 2024-01-17 15:47:08 -03:00
Ivar Flakstad 80b1c689f9 Revert public EncoderParam 2024-01-17 18:09:28 +01:00
Ivar Flakstad db923517b3 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-17 18:03:57 +01:00
Nicolas Patry 403680f17d
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
2024-01-17 10:27:58 +01:00
Ivar Flakstad 86a8e58897 Update metal random kernel and set_seed method
* set_seed via buffer content pointer copy + did_modify_range

* ensure random.metal kernel does not write outside of buffer range when tid==0
2024-01-17 09:12:44 +01:00
Ivar Flakstad 79478ff5a1 Seed should be updated by random kernel result. 2024-01-15 11:58:25 +01:00
Ivar Flakstad ecf88a6d38 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-14 17:10:54 +01:00
ivarflakstad a3d92ab226
Metal: Activate bfloat affine and add benchmark (#1543)
* Use cfg to seperate benchmark results based on features

* Add bfloat affine and benchmarks

* Fix flops calculation

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-01-12 11:19:49 +01:00
ivarflakstad e90bcdcc7c
Metal: f16 and bf16 where_cond + benchmark (#1545)
* Use cfg to seperate benchmark results based on features

* Add metal where_cond for f16 and bf16. Add benchmark

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

* Updated feature separated benchmarks

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-01-12 11:18:11 +01:00
Ivar Flakstad e06e8d0dbe fmt 2024-01-12 07:26:42 +01:00
Ivar Flakstad e63bb8661b Merge branch 'main' into ivarflakstad/metal-prng 2024-01-12 07:19:58 +01:00
Baye Dieng 85e5680277 remove metal version check 2024-01-11 21:02:03 +00:00
Baye Dieng 1327419776 close ifdef 2024-01-11 17:14:12 +00:00
Kyle McCarthy 402349d120
feat(bf16): add cast support + tests for cast + bin ops (#1524) 2024-01-11 15:49:13 +01:00
ivarflakstad d3bdd788cf
Use __HAVE_BFLOAT__ to check for bfloat support instead of metal version check (#1540) 2024-01-10 18:50:30 +01:00
Juarez Bochi ae06cb74bb
Add relu kernel for metal (#1488)
* Add relu kernel for metal

* Copy error messages proposed in #1491

* Revert non relu changes

* Fix name changes

* Fix the last of us (:

* Fix copy and paste mistakes

* Fix typo

* Revert order changes

* Revert order change

* Add deleted functions back

* Run rustfmt
2024-01-10 18:27:17 +01:00
Ivar Flakstad 6ebe043273 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-07 11:52:03 +01:00
Ivar Flakstad 6bf52b9fdf Gaussian normal distribution of PRNG via Box-Muller transform 2024-01-07 11:39:46 +01:00
Ivar Flakstad 955e63c803 Implement hybrid Tausworthe + LCG psuedo random number generator in metal 2024-01-05 13:27:59 +01:00
Nicolas Patry fa3ea98ba9
Adding bfloat16 support for the cast kernels. (#1520) 2024-01-04 12:12:56 +01:00
Gonzalo 0a245e6fa4
Metal: support unary abs (#1503)
* Metal: support unary abs

* cargo fmt
2023-12-30 00:00:12 +01:00
Gonzalo 87d7f81b43
Metal: more u8/u32 (#1502)
* Adds more metal u8

* Metal: more u32
2023-12-29 23:56:21 +01:00