Commit Graph

354 Commits

Author SHA1 Message Date
Eric Buehler e2b6b367fa
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
Laurent Mazare 3fba2b5fc4
Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.

* More SmolLM2 support.
2024-11-03 17:11:12 +01:00
Czxck001 530ab96036
Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-01 18:10:40 +01:00
Laurent Mazare 7ac0de15a9
Lazy upcasting for t5. (#2589) 2024-10-30 18:08:51 +01:00
Czxck001 d232e132f6
Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
2024-10-30 06:19:07 +01:00
Laurent Mazare 37e0ab8c64
Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
2024-10-27 10:01:04 +01:00
Zack Angelo a2e9d41b20
use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572) 2024-10-23 20:07:09 +02:00
Laurent Mazare 3d1dc06cdb
Enable stable-diffusion 3 on metal. (#2560) 2024-10-14 08:59:12 +02:00
Anubhab Bandyopadhyay f553ab5eb4
Adds support for Stella_en_v5 embedding model - 1.5B variant (#2551)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README
2024-10-13 23:09:12 +02:00
Mikarific 41ade774e8
fix: Allow marian configs to deserialize from json. (#2556) 2024-10-13 23:05:50 +02:00
Czxck001 ca7cf5cb3b
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-13 22:08:40 +02:00
SethWen 0d96ec31e8
feat: intergrate chinese clip and add example (#2555)
* start to impl chinese clip

* impl vision model

* copy code from bert

* refactor use

* refactor use again

* fix text model

* refactor

* try to fix text model

* tuning

* tuning chinese clip

* delete useless code

* revert code

* Clippy fixes.

* Also apply cargo fmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-10-10 15:18:55 +02:00
Akshay Ballal 937e8eda74
Add BertForMaskedLM to support SPLADE Models (#2550)
* add bert for masked lm

* working example

* add example readme

* Clippy fix.

* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-07 23:28:21 +02:00
Akshay Ballal 888d886dd8
Add ColPali (#2524)
* add colpali

* cleanup

* fix clippy
2024-10-01 11:48:39 +02:00
Laurent Mazare dfe9a00683
Pixtral polishing. (#2522)
* Pixtral polishing.

* Clippy fix.
2024-09-30 21:23:54 +02:00
Laurent Mazare 683ab698de
Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
Laurent Mazare 2f49e1b534
Add PaliGemma. (#2519)
* Add PaliGemma.

* PaliGemma inference loop.

* Running PaliGemma example.

* Tweak the prompt.
2024-09-29 19:56:56 +02:00
Laurent Mazare 0ebb38813b
Paligemma siglip vision config (#2518)
* Add the paligemma siglip vision config.

* More paligemma configs.
2024-09-29 17:53:52 +02:00
Laurent Mazare 261ed65f36
Add the SigLIP model. (#2515)
* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
2024-09-28 23:48:00 +02:00
Laurent Mazare 62525e8352
Remove some extra whitelines. (#2513) 2024-09-28 14:41:28 +02:00
Laurent Mazare ad8a4c5e5a
Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.

* Support tie-word-embeddings for llama.
2024-09-26 21:00:18 +02:00
Laurent Mazare 10d47183c0
Quantized version of flux. (#2500)
* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
2024-09-26 10:23:43 +02:00
Laurent Mazare d01207dbf3
Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
Juan Gomez 5fc4f17727
Adding Granite 7b Instruct model example (#2487)
* Adding Granite 7b Instruct model example

* Minor refactoring to make it a little more idiomatic

* Clippy fixes.

* * Adding a README with some information about supported Granite models
* Changing the default prompt to accomodate better the Language
  modality of the Granite 7b Instruct model

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-21 11:52:01 +02:00
Laurent Mazare c58c5d5b01
Add the mimi audio-tokenizer. (#2488)
* Add the mimi audio-tokenizer.

* Formatting tweaks.

* Add a full example.

* Use the transformers names.

* More renamings.

* Get encoding and decoding to work.

* Clippy fixes.
2024-09-20 14:31:20 -06:00
Laurent Mazare e3261216b1
Clippy fixes for 1.81.0. (#2461)
* Clippy fixes for 1.81.0.

* Another fix.
2024-09-05 23:46:55 +02:00
Jani Monoses 86613c00e2
MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-29 15:38:58 +02:00
Jani Monoses 29e25c458d
FastViT fixes. (#2452)
* correct optional SE layer dimensions.
 * head_dim instead of num_heads is 32.
 * update test example output.
2024-08-28 11:20:09 +02:00
ilookee fdc2622686
fix: qwen2 lm_head loading #2443 (#2445)
Co-authored-by: Yi Xu <xuyi@me.com>
2024-08-23 16:50:02 +02:00
Jani Monoses ccdbe87639
Add FastViT model. (#2444) 2024-08-23 16:06:54 +02:00
Laurent Mazare 2ec8729d51
Fix for parler-tts, do not add the last slice of padding tokens. (#2442)
* Fix for parler-tts, do not add the last slice of padding tokens.

* Support for the mini model.
2024-08-22 23:22:03 +02:00
Laurent Mazare 236b29ff15
Add the DAC model. (#2433)
* Add the DAC model.

* More quantization support.

* Handle DAC decoding.

* Plug the DAC decoding in parler-tts.
2024-08-19 08:59:51 +02:00
Laurent Mazare 58197e1896
parler-tts support (#2431)
* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description + t5 encode it.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the python decoder.

* Proper causality mask.
2024-08-18 20:42:08 +02:00
Laurent Mazare c1b9e07e35
Add support for gemma-2. (#2425)
* Add gemma-2.

* Support a couple more models.

* Sliding window support.

* Example + readme updates.

* Update the main readme.
2024-08-17 20:31:23 +02:00
Laurent Mazare 68aa9c7320
Fix the device for the bert attention mask. (#2414) 2024-08-14 10:01:12 +02:00
Jani Monoses 35e5f31397
Add Based LLM from Hazy Research. (#2411) 2024-08-12 21:21:19 +02:00
Matthew O'Malley-Nichols 14db029494
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
2024-08-10 07:57:52 +02:00
Czxck001 dfdce2b602
Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3

lint

add comments

* correct a misplaced comment

* fix cargo fmt

* fix clippy error

* use bail! instead of assert!

* use get_on_dim in splitting qkv
2024-08-05 19:26:15 +02:00
唐璜 500c9f2882
add models support and example for THUDM/glm-4 (#2362)
* add models support and example for THUDM/glm-4

* fix the ci report

* fmt

* fix

* Update README.org

* Update README.org

* fmt

* Update README.org

* README.md add codegeex4

* README.md add glm4

* Typo.

* change expect into ?

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-08-05 17:48:09 +02:00
Laurent Mazare 2be9bd211e
Support for mistral-nemo. (#2396) 2024-08-04 19:52:40 +02:00
Laurent Mazare aa7ac1832d
Simplify handling of flux modulations. (#2394) 2024-08-04 11:09:54 +02:00
Laurent Mazare 19db6b9723
Add the flux model for image generation. (#2390)
* Add the flux autoencoder.

* Add the encoder down-blocks.

* Upsampling in the decoder.

* Sketch the flow matching model.

* More flux model.

* Add some of the positional embeddings.

* Add the rope embeddings.

* Add the sampling functions.

* Add the flux example.

* Fix the T5 bits.

* Proper T5 tokenizer.

* Clip encoder path fix.

* Get the clip embeddings.

* No configurable weights in layer norm.

* More weights related fixes.

* Yet another shape fix.

* DType fix.

* Fix a couple more shape issues.

* DType fixes.

* Fix the latent dims.

* Fix more shape issues.

* Autoencoder fixes.

* Get some generations out.

* Bugfix.

* T5 padding.

* Clippy fix.

* Add the decode only mode.

* Fix.

* More fixes.

* Finally get some generations to work.

* Add readme.
2024-08-04 08:14:33 +02:00
Laurent Mazare 9ca277a9d7
Fix cargo fmt. (#2383)
* Fix cargo fmt.

* Clippy fix.

* Cosmetic tweaks.
2024-08-01 14:19:41 +02:00
Joan Fontanals 2e9c010609
Jina Bert Example fix and more configuration (#2191)
* fix: fix jina bert example logic

* feat: enable jina embeddings de

* feat: allow more flexibility on Jina Bert
2024-08-01 13:59:20 +02:00
Jani Monoses ac51f477eb
Add Hiera vision model. (#2382) 2024-08-01 11:59:22 +02:00
Zheng Li 4a52aeb437
bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-01 08:26:19 +02:00
Eric Buehler 0f5cbb08b3
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple eos tokens:

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
donjuanplatinum 2489a606fe
feat(candle-transformers/models/codegeex4-9b): add codegeex4-9 (#2334)
* feat(candle-transformers/models/codegeex4-9b): add codegeex4-9b transoformers

* change mod.rs

* feat(candle-examples/codegeex4-9b)

* Update codegeex4_9b.rs

* Update main.rs

* Update codegeex4_9b.rs

* Update main.rs

* fmt

* fix

* fmt

* Clippy fix.

* Remove some print statements.

* Avoid using unwrap.

* 1. add README
2. change the print fmt

* Another clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-07-21 13:00:41 +02:00
Zhuo Jinggang c63048d374
add quantized qwen2 (#2329)
* add quantized version of qwen2 and corresponding example for qwen2-instruct

* fix quantized qwen2 clippy error
2024-07-12 10:00:03 +02:00
Jani Monoses a226a9736b
Add Mobilenet v4 (#2325)
* Support different resolutions in load_image()

* Added MobilenetV4 model.

* Add MobileNetv4 to README
2024-07-09 13:52:20 +02:00