f16 and local cubecl

This commit is contained in:
louisfd 2024-09-05 11:02:26 -04:00
parent a567c6e888
commit 0f1843a2d7
3 changed files with 12 additions and 18 deletions

18
Cargo.lock generated
View File

@ -1396,7 +1396,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl" name = "cubecl"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"cubecl-core", "cubecl-core",
"cubecl-cuda", "cubecl-cuda",
@ -1407,7 +1406,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-common" name = "cubecl-common"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"derive-new", "derive-new",
"getrandom", "getrandom",
@ -1422,7 +1420,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-core" name = "cubecl-core"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-macros", "cubecl-macros",
@ -1437,7 +1434,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-cuda" name = "cubecl-cuda"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common",
@ -1452,7 +1448,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-linalg" name = "cubecl-linalg"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-core", "cubecl-core",
@ -1463,7 +1458,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-macros" name = "cubecl-macros"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"derive-new", "derive-new",
"proc-macro2", "proc-macro2",
@ -1474,7 +1468,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-runtime" name = "cubecl-runtime"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"async-channel", "async-channel",
"cfg_aliases 0.2.1", "cfg_aliases 0.2.1",
@ -1494,7 +1487,6 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-wgpu" name = "cubecl-wgpu"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=b009bcc2ac97ca0d46d9990d5681081f1ebc09cd#b009bcc2ac97ca0d46d9990d5681081f1ebc09cd"
dependencies = [ dependencies = [
"async-channel", "async-channel",
"bytemuck", "bytemuck",
@ -6665,8 +6657,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]] [[package]]
name = "tracel-xtask" name = "tracel-xtask"
version = "1.0.0" version = "1.0.8"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63f307c8a22d3c67bb2a0678290243e4917d235c507f0c8b35e8612919978a08"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap 4.5.16", "clap 4.5.16",
@ -6683,8 +6676,9 @@ dependencies = [
[[package]] [[package]]
name = "tracel-xtask-macros" name = "tracel-xtask-macros"
version = "1.0.0" version = "1.0.8"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9b7ee23c050536c8c932ca7daaebbf45ff6c1d57f1bd65fc084833ac6a8d419"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

View File

@ -151,11 +151,11 @@ systemstat = "0.2.3"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] } portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ### ### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" } # cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" } # cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b009bcc2ac97ca0d46d9990d5681081f1ebc09cd" }
### For local development. ### ### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" } cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" } cubecl-common = { path = "../cubecl/crates/cubecl-common" }
### For the release. ### ### For the release. ###
# cubecl = { version="0.2.0", default-features = false } # cubecl = { version="0.2.0", default-features = false }
# cubecl-common = { version="0.2.0", default-features = false } # cubecl-common = { version="0.2.0", default-features = false }

View File

@ -5,10 +5,10 @@ pub use cubecl::cuda::CudaDevice;
use cubecl::cuda::CudaRuntime; use cubecl::cuda::CudaRuntime;
#[cfg(not(feature = "fusion"))] #[cfg(not(feature = "fusion"))]
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>; pub type Cuda<F = half::f16, I = i32> = JitBackend<CudaRuntime, F, I>;
#[cfg(feature = "fusion")] #[cfg(feature = "fusion")]
pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I>>; pub type Cuda<F = half::f16, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I>>;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {