Commit Graph

11 Commits

Laurent Mazare 30cdd769f9
Update the flash attn kernels. (#2333)
2024-07-15 20:37:36 +02:00
Nicolas Patry 30313c3081
Moving to a proper build crate `bindgen_cuda`. (#1531)
* Moving to a proper build crate `bindgen_cuda`.

* Fmt.
2024-01-07 12:29:24 +01:00
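
The move to `bindgen_cuda` replaces a hand-rolled nvcc invocation in `build.rs`. A minimal sketch of what such a build script can look like is below; the builder methods (`kernel_paths`, `out_dir`, `arg`, `build_lib`) and the kernel file name are assumptions based on how the crate is typically used, not a copy of the actual build script.

```rust
// build.rs sketch: compile the flash-attn CUDA kernels into a static library
// with the bindgen_cuda build crate. The builder method names below are
// assumed from common usage of the crate.
use std::path::PathBuf;

fn main() {
    println!("cargo:rerun-if-changed=build.rs");

    // Where the compiled artifacts should land.
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());

    // The .cu kernel sources shipped with the crate (hypothetical list).
    let kernels: Vec<PathBuf> = vec![PathBuf::from("kernels/flash_fwd_hdim64_fp16_sm80.cu")];

    let builder = bindgen_cuda::Builder::default()
        .kernel_paths(kernels)
        .out_dir(out_dir.clone())
        .arg("-std=c++17")
        .arg("-O3")
        .arg("--expt-relaxed-constexpr");

    // Build libflashattention.a and tell rustc how to link against it.
    builder.build_lib(out_dir.join("libflashattention.a"));
    println!("cargo:rustc-link-search={}", out_dir.display());
    println!("cargo:rustc-link-lib=flashattention");
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=stdc++");
}
```
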
OlivierDehaene 75629981bc
feat: parse Cuda compute cap from env (#1066)
* feat: add support for multiple compute caps

* Revert to one compute cap

* fmt

* fix
2023-10-16 15:37:38 +01:00
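
This commit lets the build take the CUDA compute capability from the environment rather than probing the installed GPU. A hedged sketch of that kind of lookup is shown below; the `CUDA_COMPUTE_CAP` variable name and the fallback value are assumptions for illustration.

```rust
// Sketch of a build-script helper that reads the target compute capability
// from the environment. The env var name and default are illustrative.
fn compute_cap() -> usize {
    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
    match std::env::var("CUDA_COMPUTE_CAP") {
        // e.g. CUDA_COMPUTE_CAP=80 targets sm_80 (Ampere).
        Ok(v) => v.parse().expect("CUDA_COMPUTE_CAP must be an integer such as 75, 80 or 90"),
        // Fall back to a conservative default when nothing is set; a real
        // build script might instead query the driver or fail loudly here.
        Err(_) => 80,
    }
}

fn main() {
    let cap = compute_cap();
    // The value would end up in an nvcc flag along the lines of
    //   --gpu-architecture=sm_80
    println!("cargo:warning=building flash-attn for sm_{cap}");
}
```
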
Laurent Mazare 0e250aee4f
Shape with holes (#770)
* Shape with holes.

* rustfmt.
2023-09-08 08:38:13 +01:00
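
"Shape with holes" refers to leaving one dimension unspecified in a reshape and letting it be inferred from the element count. A small sketch, assuming candle's `reshape` accepts `()` as the hole:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // `()` marks the single unknown dimension; it is inferred from the
    // element count, so this yields a (6, 4) tensor.
    let u = t.reshape(((), 4))?;
    assert_eq!(u.dims(), &[6, 4]);
    Ok(())
}
```
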
Zsombor cfcbec9fc7
Add small customization to the build (#768)
* Add ability to override the compiler used by NVCC from an environment variable

* Allow relative paths in CANDLE_FLASH_ATTN_BUILD_DIR

* Add the compilation failure to the readme, with a possible solution

* Adjust the error message, and remove the special handling of the relative paths
2023-09-08 08:15:14 +01:00
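
Both knobs this commit touches are environment driven: a host-compiler override forwarded to nvcc and a shared build directory for the compiled kernels. A rough `build.rs` fragment along those lines follows; `CANDLE_FLASH_ATTN_BUILD_DIR` comes from the commit itself, while `CANDLE_NVCC_CCBIN` is an assumed name used only for illustration.

```rust
use std::path::PathBuf;

fn main() {
    // Reuse one directory for the compiled kernels across projects instead of
    // rebuilding them into each target's OUT_DIR.
    println!("cargo:rerun-if-env-changed=CANDLE_FLASH_ATTN_BUILD_DIR");
    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
    let build_dir = std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from(std::env::var("OUT_DIR").unwrap()));

    // Forward a host-compiler override to nvcc through its -ccbin flag.
    // CANDLE_NVCC_CCBIN is an assumed variable name for illustration.
    let mut nvcc_args = vec!["-std=c++17".to_string()];
    if let Ok(ccbin) = std::env::var("CANDLE_NVCC_CCBIN") {
        nvcc_args.push(format!("-ccbin={ccbin}"));
    }

    println!("cargo:warning=flash-attn build dir: {}", build_dir.display());
    println!("cargo:warning=extra nvcc args: {}", nvcc_args.join(" "));
}
```
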
Laurent Mazare d0cdea95a5
Add back the bf16 flash-attn kernels. (#730)
2023-09-04 07:50:52 +01:00
Chengxu Yang ebcfd96d94
add c++17 flags (#452)
2023-08-15 15:29:34 +01:00
Laurent Mazare 4f92420132
Add some flash attn test (#253)
* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
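
For context on what these tests exercise, a hedged usage sketch follows. It assumes the crate exposes `flash_attn(q, k, v, softmax_scale, causal)` over `(batch, seq_len, heads, head_dim)` f16 tensors on a CUDA device, with the head dimension required to be a multiple of 8 as the commit notes.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn run() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    // (batch, seq_len, heads, head_dim); head_dim = 64 is a multiple of 8,
    // which the kernel requires according to the commit above.
    let q = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1.0 / (64f32).sqrt();
    candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal = */ true)
}

fn main() -> Result<()> {
    let out = run()?;
    println!("flash-attn output shape: {:?}", out.dims());
    Ok(())
}
```
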
Laurent Mazare 2ce5f12513
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
2023-07-26 14:16:37 +01:00
Laurent Mazare 471855e2ee
Specific cache dir for the flash attn build artifacts. (#242)
2023-07-26 08:04:02 +01:00
Laurent Mazare d9f9c859af
Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
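
The last two bullets (feature gate, use in llama) usually translate into conditional compilation on the consumer side. The sketch below shows the general shape of such a gate, assuming the feature is named `flash-attn`; it is an illustrative fragment, not candle's actual llama code.

```rust
use candle_core::{Result, Tensor, D};

// Fused CUDA path, only compiled when the `flash-attn` feature is enabled.
#[cfg(feature = "flash-attn")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // flash_attn expects (batch, seq_len, heads, head_dim) f16/bf16 tensors.
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, /* causal = */ true)
}

// Plain softmax(q k^T * scale) v fallback when the feature is off.
#[cfg(not(feature = "flash-attn"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Assumes a (batch, heads, seq_len, head_dim) layout, so real code
    // transposes when switching between the two paths; causal masking is
    // omitted for brevity.
    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(v)
}
```

Keeping the kernels behind a feature and out of the default workspace, as the earlier bullets describe, means CPU-only builds never have to invoke nvcc; users opt in with something like `cargo build --features flash-attn`.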