[WIP] Migrate to `cubecl` IR refactor (#2418)

This commit is contained in:
Genna Wingert 2024-10-30 17:50:41 +01:00 committed by GitHub
parent 69856a97db
commit 5730f022fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 302 additions and 382 deletions

128
Cargo.lock generated
View File

@ -524,7 +524,7 @@ dependencies = [
name = "burn-common" name = "burn-common"
version = "0.16.0" version = "0.16.0"
dependencies = [ dependencies = [
"cubecl-common", "cubecl-common 0.4.0",
"dashmap", "dashmap",
"getrandom", "getrandom",
"indicatif", "indicatif",
@ -1459,15 +1459,14 @@ dependencies = [
[[package]] [[package]]
name = "cubecl" name = "cubecl"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "75e75c7e982b943380665c5901fe0b69d5df2627644e0e50199c52b64d8d5a1c"
dependencies = [ dependencies = [
"cubecl-core", "cubecl-core",
"cubecl-cuda", "cubecl-cuda",
"cubecl-hip", "cubecl-hip",
"cubecl-linalg", "cubecl-linalg",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"cubecl-wgpu", "cubecl-wgpu",
] ]
@ -1489,16 +1488,32 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "cubecl-common"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
dependencies = [
"derive-new",
"embassy-futures",
"futures-lite",
"getrandom",
"log",
"portable-atomic",
"rand",
"serde",
"spin",
"web-time",
]
[[package]] [[package]]
name = "cubecl-core" name = "cubecl-core"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "ec33b64139d1dfc747df8aed5834d10c3c55c716f5219041c6eb17241c96c929"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-macros", "cubecl-macros",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"derive-new", "derive-new",
"half", "half",
"log", "log",
@ -1509,14 +1524,13 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-cpp" name = "cubecl-cpp"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "4ded461feb0ff342a4f675131dc0ae8ad94e58f66bad11e57f852cb7f190a731"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"derive-new", "derive-new",
"half", "half",
"log", "log",
@ -1524,15 +1538,14 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-cuda" name = "cubecl-cuda"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "88dfdfe616124d2abe5e82052ff56f86843c369440e181d6936f7409e161dd82"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"cubecl-cpp", "cubecl-cpp",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"cudarc 0.12.1", "cudarc 0.12.1",
"derive-new", "derive-new",
"half", "half",
@ -1541,16 +1554,15 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-hip" name = "cubecl-hip"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "409e0e176152ab51a60bbebb940b7a72aba210cd42b5f8cd2e87e7d7e674a13a"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"cubecl-cpp", "cubecl-cpp",
"cubecl-hip-sys", "cubecl-hip-sys",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"derive-new", "derive-new",
"half", "half",
"log", "log",
@ -1567,23 +1579,21 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-linalg" name = "cubecl-linalg"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "3c5634782d790e9b6562fc267ffd15e9a510b4d6ec32c144cd2b166af2ba0cfb"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-core", "cubecl-core",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"half", "half",
] ]
[[package]] [[package]]
name = "cubecl-macros" name = "cubecl-macros"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "d2d22663257d9cdbcd67f5048d6f4e6eb965dd87104c3a173a7b0ea0d720e99b"
dependencies = [ dependencies = [
"cubecl-common", "cubecl-common 0.4.0",
"darling", "darling",
"derive-new", "derive-new",
"ident_case", "ident_case",
@ -1595,11 +1605,10 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-opt" name = "cubecl-opt"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "59fba2561a6ceb99e9c5fe7313db0aeead02b848dc9cdacf8373e0fe98d3c247"
dependencies = [ dependencies = [
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"float-ord", "float-ord",
"log", "log",
@ -1618,7 +1627,28 @@ dependencies = [
"async-channel", "async-channel",
"async-lock", "async-lock",
"cfg_aliases 0.2.1", "cfg_aliases 0.2.1",
"cubecl-common", "cubecl-common 0.3.0",
"derive-new",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
"md5",
"sanitize-filename",
"serde",
"serde_json",
"spin",
"wasm-bindgen-futures",
]
[[package]]
name = "cubecl-runtime"
version = "0.4.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
dependencies = [
"async-channel",
"async-lock",
"cfg_aliases 0.2.1",
"cubecl-common 0.4.0",
"derive-new", "derive-new",
"dirs 5.0.1", "dirs 5.0.1",
"hashbrown 0.14.5", "hashbrown 0.14.5",
@ -1633,31 +1663,29 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-spirv" name = "cubecl-spirv"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "835bc234cdd40fbb5e3e5e41bfb4a6e2ee2d7fd899b66d44dcbcb3a825ca1a59"
dependencies = [ dependencies = [
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"cubecl-opt", "cubecl-opt",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"hashbrown 0.14.5", "hashbrown 0.14.5",
"rspirv", "rspirv",
] ]
[[package]] [[package]]
name = "cubecl-wgpu" name = "cubecl-wgpu"
version = "0.3.0" version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
checksum = "6779f1072d70923758421c6214fd0cd19a6f25b91035a522f9cd9407d03b5cae"
dependencies = [ dependencies = [
"ash", "ash",
"async-channel", "async-channel",
"bytemuck", "bytemuck",
"cfg_aliases 0.2.1", "cfg_aliases 0.2.1",
"cubecl-common", "cubecl-common 0.4.0",
"cubecl-core", "cubecl-core",
"cubecl-runtime", "cubecl-runtime 0.4.0",
"cubecl-spirv", "cubecl-spirv",
"derive-new", "derive-new",
"hashbrown 0.14.5", "hashbrown 0.14.5",
@ -3265,7 +3293,7 @@ dependencies = [
"burn-candle", "burn-candle",
"burn-import", "burn-import",
"console_error_panic_hook", "console_error_panic_hook",
"cubecl-runtime", "cubecl-runtime 0.3.0",
"js-sys", "js-sys",
"log", "log",
"serde", "serde",
@ -3793,7 +3821,7 @@ version = "0.16.0"
dependencies = [ dependencies = [
"burn", "burn",
"console_error_panic_hook", "console_error_panic_hook",
"cubecl-runtime", "cubecl-runtime 0.3.0",
"js-sys", "js-sys",
"serde", "serde",
"wasm-bindgen", "wasm-bindgen",

View File

@ -152,14 +152,14 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] } portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ### ### For the main burn branch. ###
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" } cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" } cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
### For local development. ### ### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ### ### For the release. ###
cubecl = { version="0.3.0", default-features = false } # cubecl = { version = "0.3.0", default-features = false }
cubecl-common = { version="0.3.0", default-features = false } # cubecl-common = { version = "0.3.0", default-features = false }
### For xtask crate ### ### For xtask crate ###
tracel-xtask = { version = "~1.1" } tracel-xtask = { version = "~1.1" }

View File

@ -1,6 +1,8 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -39,7 +41,7 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
let weight = self.weight; let weight = self.weight;
let bias = self.bias; let bias = self.bias;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt); let input_stride_1 = scope.create_local(Elem::UInt);
@ -92,34 +94,13 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, kernel_size_0 = shape(weight, 2u32)); cpa!(scope, kernel_size_0 = shape(weight, 2u32));
cpa!(scope, kernel_size_1 = shape(weight, 3u32)); cpa!(scope, kernel_size_1 = shape(weight, 3u32));
let conv_stride_0 = Variable::GlobalScalar { let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
}; let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let conv_stride_1 = Variable::GlobalScalar { let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
id: 1, let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
elem: Elem::UInt, let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let groups = Variable::GlobalScalar {
id: 6,
elem: Elem::UInt,
};
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@ -222,9 +203,9 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, index_input_b = b * input_stride_0); cpa!(scope, index_input_b = b * input_stride_0);
cpa!(scope, index_weight_oc = oc * weight_stride_1); cpa!(scope, index_weight_oc = oc * weight_stride_1);
let prod = scope.create_local(output.item()); let prod = scope.create_local(output.item);
let prod_tmp = scope.create_local(output.item()); let prod_tmp = scope.create_local(output.item);
let sum = scope.create_local(output.item()); let sum = scope.create_local(output.item);
cpa!(scope, sum = bias[oc_out]); cpa!(scope, sum = bias[oc_out]);
let kh = scope.create_local(Elem::UInt); let kh = scope.create_local(Elem::UInt);
@ -314,10 +295,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let weight = Variable::GlobalInputArray { id: 1, item }; let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
let bias = Variable::GlobalInputArray { id: 2, item }; let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -1,6 +1,8 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -39,7 +41,7 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
let weight = self.weight; let weight = self.weight;
let bias = self.bias; let bias = self.bias;
let output = self.output; let output = self.output;
let idx = Variable::AbsolutePos; let idx = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt); let input_stride_1 = scope.create_local(Elem::UInt);
@ -104,46 +106,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, kernel_size_1 = shape(weight, 3u32)); cpa!(scope, kernel_size_1 = shape(weight, 3u32));
cpa!(scope, kernel_size_2 = shape(weight, 4u32)); cpa!(scope, kernel_size_2 = shape(weight, 4u32));
let conv_stride_0 = Variable::GlobalScalar { let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
}; let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let conv_stride_1 = Variable::GlobalScalar { let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
id: 1, let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
elem: Elem::UInt, let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
}; let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt));
let conv_stride_2 = Variable::GlobalScalar { let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt));
id: 2, let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt));
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let dilation_2 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 6,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 7,
elem: Elem::UInt,
};
let padding_2 = Variable::GlobalScalar {
id: 8,
elem: Elem::UInt,
};
let groups = Variable::GlobalScalar {
id: 9,
elem: Elem::UInt,
};
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@ -273,9 +245,9 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, index_input_b = b * input_stride_0); cpa!(scope, index_input_b = b * input_stride_0);
cpa!(scope, index_weight_oc = oc * weight_stride_1); cpa!(scope, index_weight_oc = oc * weight_stride_1);
let prod = scope.create_local(output.item()); let prod = scope.create_local(output.item);
let prod_tmp = scope.create_local(output.item()); let prod_tmp = scope.create_local(output.item);
let sum = scope.create_local(output.item()); let sum = scope.create_local(output.item);
cpa!(scope, sum = bias[oc_out]); cpa!(scope, sum = bias[oc_out]);
let kd = scope.create_local(Elem::UInt); let kd = scope.create_local(Elem::UInt);
@ -391,10 +363,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let weight = Variable::GlobalInputArray { id: 1, item }; let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
let bias = Variable::GlobalInputArray { id: 2, item }; let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -4,7 +4,7 @@ use crate::{
use burn_tensor::ElementConversion; use burn_tensor::ElementConversion;
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -27,7 +27,7 @@ impl FlipComputeShader {
pub fn expand(self, scope: &mut Scope) { pub fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt); let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.create_local(Elem::UInt); let offset_local = scope.create_local(Elem::UInt);
@ -42,10 +42,10 @@ impl FlipComputeShader {
cpa!(scope, shape = shape(output, i)); cpa!(scope, shape = shape(output, i));
cpa!( cpa!(
scope, scope,
flip = cast(Variable::GlobalScalar { flip = cast(Variable::new(
id: i as u16, VariableKind::GlobalScalar(i as u16),
elem: Elem::UInt Item::new(Elem::UInt)
}) ))
); );
cpa!(scope, flip_bool = flip == 1u32); cpa!(scope, flip_bool = flip == 1u32);
@ -61,7 +61,7 @@ impl FlipComputeShader {
cpa!(scope, offset_input += offset_local); cpa!(scope, offset_input += offset_local);
} }
let result = scope.create_local(input.item()); let result = scope.create_local(input.item);
cpa!(scope, result = input[offset_input]); cpa!(scope, result = input[offset_input]);
cpa!(scope, output[id] = result); cpa!(scope, output[id] = result);
} }
@ -72,8 +72,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -1,7 +1,7 @@
use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -26,7 +26,7 @@ impl RepeatComputeShader {
pub fn expand(self, scope: &mut Scope) { pub fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt); let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.zero(Elem::UInt); let offset_local = scope.zero(Elem::UInt);
@ -50,7 +50,7 @@ impl RepeatComputeShader {
cpa!(scope, offset_input += offset_local); cpa!(scope, offset_input += offset_local);
} }
let result = scope.create_local(input.item()); let result = scope.create_local(input.item);
cpa!(scope, result = input[offset_input]); cpa!(scope, result = input[offset_input]);
cpa!(scope, output[id] = result); cpa!(scope, output[id] = result);
} }
@ -60,8 +60,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -4,7 +4,7 @@ use crate::{
use burn_tensor::{ElementConversion, Shape}; use burn_tensor::{ElementConversion, Shape};
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -27,7 +27,7 @@ impl SliceComputeShader {
pub fn expand(self, scope: &mut Scope) { pub fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt); let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.create_local(Elem::UInt); let offset_local = scope.create_local(Elem::UInt);
@ -43,10 +43,10 @@ impl SliceComputeShader {
cpa!(scope, shape_output = shape(output, i)); cpa!(scope, shape_output = shape(output, i));
cpa!( cpa!(
scope, scope,
range_start = cast(Variable::GlobalScalar { range_start = cast(Variable::new(
id: i as u16, VariableKind::GlobalScalar(i as u16),
elem: Elem::UInt Item::new(Elem::UInt)
}) ))
); );
cpa!(scope, offset_local = id / stride_output); cpa!(scope, offset_local = id / stride_output);
@ -57,7 +57,7 @@ impl SliceComputeShader {
cpa!(scope, offset_input += offset_local); cpa!(scope, offset_input += offset_local);
} }
let result = scope.create_local(input.item()); let result = scope.create_local(input.item);
cpa!(scope, result = input[offset_input]); cpa!(scope, result = input[offset_input]);
cpa!(scope, output[id] = result); cpa!(scope, output[id] = result);
} }
@ -68,8 +68,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -2,7 +2,7 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
use burn_tensor::ElementConversion; use burn_tensor::ElementConversion;
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
}; };
use std::{marker::PhantomData, ops::Range}; use std::{marker::PhantomData, ops::Range};
@ -24,7 +24,7 @@ impl SliceAssignComputeShader {
pub fn expand(self, scope: &mut Scope) { pub fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let value = self.value; let value = self.value;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let offset_input = scope.zero(Elem::UInt); let offset_input = scope.zero(Elem::UInt);
let offset_value = scope.zero(Elem::UInt); let offset_value = scope.zero(Elem::UInt);
@ -46,10 +46,10 @@ impl SliceAssignComputeShader {
cpa!(scope, shape_input = shape(input, i)); cpa!(scope, shape_input = shape(input, i));
cpa!( cpa!(
scope, scope,
range_start = cast(Variable::GlobalScalar { range_start = cast(Variable::new(
id: i as u16, VariableKind::GlobalScalar(i as u16),
elem: Elem::UInt Item::new(Elem::UInt)
}) ))
); );
cpa!(scope, offset_local = id / stride_value); cpa!(scope, offset_local = id / stride_value);
@ -66,7 +66,7 @@ impl SliceAssignComputeShader {
cpa!(scope, offset_input += offset_local_input); cpa!(scope, offset_input += offset_local_input);
} }
let result = scope.create_local(input.item()); let result = scope.create_local(input.item);
cpa!(scope, result = value[offset_value]); cpa!(scope, result = value[offset_value]);
cpa!(scope, input[offset_input] = result); cpa!(scope, input[offset_input] = result);
} }
@ -77,8 +77,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let value = Variable::GlobalInputArray { id: 1, item }; let value = Variable::new(VariableKind::GlobalInputArray(1), item);
scope.write_global_custom(input); scope.write_global_custom(input);

View File

@ -1,6 +1,6 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
pub(crate) fn expand(self, scope: &mut Scope) { pub(crate) fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let elem = E::cube_elem(); let elem = E::cube_elem();
let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_0 = scope.create_local(Elem::UInt);
@ -181,10 +181,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
let index_1 = scope.create_local(Elem::UInt); let index_1 = scope.create_local(Elem::UInt);
let index_2 = scope.create_local(Elem::UInt); let index_2 = scope.create_local(Elem::UInt);
let index_3 = scope.create_local(Elem::UInt); let index_3 = scope.create_local(Elem::UInt);
let inp_0 = scope.create_local(input.item()); let inp_0 = scope.create_local(input.item);
let inp_1 = scope.create_local(input.item()); let inp_1 = scope.create_local(input.item);
let inp_2 = scope.create_local(input.item()); let inp_2 = scope.create_local(input.item);
let inp_3 = scope.create_local(input.item()); let inp_3 = scope.create_local(input.item);
cpa!(scope, index_0 = index_base); cpa!(scope, index_0 = index_base);
cpa!(scope, index_0 += y0_stride); cpa!(scope, index_0 += y0_stride);
@ -276,7 +276,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable { fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable {
let cond = scope.create_local(Elem::Bool); let cond = scope.create_local(Elem::Bool);
let res = scope.create_local(a.item()); let res = scope.create_local(a.item);
cpa!(scope, cond = a < b); cpa!(scope, cond = a < b);
cpa!(scope, if(cond).then(|scope|{ cpa!(scope, if(cond).then(|scope|{
@ -296,7 +296,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
x3: Variable, x3: Variable,
t: Variable, t: Variable,
) -> Variable { ) -> Variable {
let item = x0.item(); let item = x0.item;
let x = scope.create_local(item); let x = scope.create_local(item);
let a: Variable = scope.create_with_value(-0.75, item); let a: Variable = scope.create_with_value(-0.75, item);
let one: Variable = scope.create_with_value(1, item); let one: Variable = scope.create_with_value(1, item);
@ -327,7 +327,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
} }
fn cubic_convolution1(scope: &mut Scope, x: Variable, a: Variable) -> Variable { fn cubic_convolution1(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
let item = x.item(); let item = x.item;
let conv = scope.create_local(item); let conv = scope.create_local(item);
let tmp = scope.create_local(item); let tmp = scope.create_local(item);
let one = scope.create_with_value(1, item); let one = scope.create_with_value(1, item);
@ -346,7 +346,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
} }
fn cubic_convolution2(scope: &mut Scope, x: Variable, a: Variable) -> Variable { fn cubic_convolution2(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
let item = x.item(); let item = x.item;
let conv = scope.create_local(item); let conv = scope.create_local(item);
let tmp = scope.create_local(item); let tmp = scope.create_local(item);
let four = scope.create_with_value(4, item); let four = scope.create_with_value(4, item);
@ -372,8 +372,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBicubicEagerKernel<R, E
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
InterpolateBicubicShader { InterpolateBicubicShader {
input, input,

View File

@ -1,6 +1,6 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -23,7 +23,7 @@ impl InterpolateBilinearShader {
pub(crate) fn expand(self, scope: &mut Scope) { pub(crate) fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt); let input_stride_1 = scope.create_local(Elem::UInt);
@ -78,26 +78,26 @@ impl InterpolateBilinearShader {
cpa!(scope, w = id / output_stride_3); cpa!(scope, w = id / output_stride_3);
cpa!(scope, w = w % output_shape_3); cpa!(scope, w = w % output_shape_3);
let factor_float = scope.create_local(input.item()); let factor_float = scope.create_local(input.item);
let numerator_float = scope.create_local(input.item()); let numerator_float = scope.create_local(input.item);
let numerator_int = scope.create_local(Elem::UInt); let numerator_int = scope.create_local(Elem::UInt);
let denominator_float = scope.create_local(input.item()); let denominator_float = scope.create_local(input.item);
let denominator_int = scope.create_local(Elem::UInt); let denominator_int = scope.create_local(Elem::UInt);
let frac = scope.create_local(input.item()); let frac = scope.create_local(input.item);
let v0 = scope.create_local(input.item()); let v0 = scope.create_local(input.item);
let v1 = scope.create_local(input.item()); let v1 = scope.create_local(input.item);
let one = scope.create_with_value(1f32, input.item()); let one = scope.create_with_value(1f32, input.item);
let y0 = scope.create_local(Elem::UInt); let y0 = scope.create_local(Elem::UInt);
let y1 = scope.create_local(Elem::UInt); let y1 = scope.create_local(Elem::UInt);
let yw = scope.create_local(input.item()); let yw = scope.create_local(input.item);
let yw_ = scope.create_local(input.item()); let yw_ = scope.create_local(input.item);
let x0 = scope.create_local(Elem::UInt); let x0 = scope.create_local(Elem::UInt);
let x1 = scope.create_local(Elem::UInt); let x1 = scope.create_local(Elem::UInt);
let xw = scope.create_local(input.item()); let xw = scope.create_local(input.item);
let xw_ = scope.create_local(input.item()); let xw_ = scope.create_local(input.item);
cpa!(scope, numerator_int = input_shape_2 - 1u32); cpa!(scope, numerator_int = input_shape_2 - 1u32);
cpa!(scope, denominator_int = output_shape_2 - 1u32); cpa!(scope, denominator_int = output_shape_2 - 1u32);
@ -136,10 +136,10 @@ impl InterpolateBilinearShader {
let y1_stride = scope.create_local(Elem::UInt); let y1_stride = scope.create_local(Elem::UInt);
let x0_stride = scope.create_local(Elem::UInt); let x0_stride = scope.create_local(Elem::UInt);
let x1_stride = scope.create_local(Elem::UInt); let x1_stride = scope.create_local(Elem::UInt);
let p_a = scope.create_local(input.item()); let p_a = scope.create_local(input.item);
let p_b = scope.create_local(input.item()); let p_b = scope.create_local(input.item);
let p_c = scope.create_local(input.item()); let p_c = scope.create_local(input.item);
let p_d = scope.create_local(input.item()); let p_d = scope.create_local(input.item);
cpa!(scope, index_base = b * input_stride_0); cpa!(scope, index_base = b * input_stride_0);
cpa!(scope, index_tmp = c * input_stride_1); cpa!(scope, index_tmp = c * input_stride_1);
@ -177,7 +177,7 @@ impl InterpolateBilinearShader {
cpa!(scope, p_d *= xw); cpa!(scope, p_d *= xw);
cpa!(scope, p_d *= yw); cpa!(scope, p_d *= yw);
let sum = scope.create_local(input.item()); let sum = scope.create_local(input.item);
cpa!(scope, sum = p_a + p_b); cpa!(scope, sum = p_a + p_b);
cpa!(scope, sum += p_c); cpa!(scope, sum += p_c);
cpa!(scope, sum += p_d); cpa!(scope, sum += p_d);
@ -190,8 +190,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBilinearEagerKernel<R,
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
InterpolateBilinearShader { input, output }.expand(&mut scope); InterpolateBilinearShader { input, output }.expand(&mut scope);

View File

@ -1,6 +1,6 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateNearestShader<E> {
pub(crate) fn expand(self, scope: &mut Scope) { pub(crate) fn expand(self, scope: &mut Scope) {
let input = self.input; let input = self.input;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let elem = E::cube_elem(); let elem = E::cube_elem();
let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_0 = scope.create_local(Elem::UInt);
@ -106,7 +106,7 @@ impl<E: JitElement> InterpolateNearestShader<E> {
let index = scope.create_local(Elem::UInt); let index = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt);
let val = scope.create_local(output.item()); let val = scope.create_local(output.item);
cpa!(scope, index = b * input_stride_0); cpa!(scope, index = b * input_stride_0);
cpa!(scope, index_tmp = c * input_stride_1); cpa!(scope, index_tmp = c * input_stride_1);
@ -126,8 +126,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestEagerKernel<R, E
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let input = Variable::GlobalInputArray { id: 0, item }; let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
InterpolateNearestShader { InterpolateNearestShader {
input, input,

View File

@ -1,6 +1,6 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
fn expand(self, scope: &mut Scope) { fn expand(self, scope: &mut Scope) {
let grad = self.out_grad; let grad = self.out_grad;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt); let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt); let grad_stride_1 = scope.create_local(Elem::UInt);
@ -88,7 +88,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
let gw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3); let gw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
let gw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3); let gw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);
let result = scope.create_local(grad.item()); let result = scope.create_local(grad.item);
let index_grad = scope.create_local(Elem::UInt); let index_grad = scope.create_local(Elem::UInt);
let index_grad_0 = scope.create_local(Elem::UInt); let index_grad_0 = scope.create_local(Elem::UInt);
@ -99,7 +99,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
cpa!(scope, index_grad_0 = b * grad_stride_0); cpa!(scope, index_grad_0 = b * grad_stride_0);
cpa!(scope, index_grad_1 = c * grad_stride_1); cpa!(scope, index_grad_1 = c * grad_stride_1);
let sum = scope.zero(output.item()); let sum = scope.zero(output.item);
cpa!( cpa!(
scope, scope,
@ -184,8 +184,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestBackwardEagerKer
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let out_grad = Variable::GlobalInputArray { id: 0, item }; let out_grad = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
InterpolateNearestBackwardShader { InterpolateNearestBackwardShader {
out_grad, out_grad,

View File

@ -7,7 +7,9 @@ use crate::{
}; };
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -32,7 +34,7 @@ impl AvgPool2dBackwardComputeShader {
fn expand(self, scope: &mut Scope) { fn expand(self, scope: &mut Scope) {
let grad = self.grad; let grad = self.grad;
let output = self.output; let output = self.output;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt); let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt); let grad_stride_1 = scope.create_local(Elem::UInt);
@ -70,22 +72,10 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32)); cpa!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::GlobalScalar { let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
}; let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let pool_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size; let [kernel_size_0, kernel_size_1] = self.kernel_size;
let b = scope.create_local(Elem::UInt); let b = scope.create_local(Elem::UInt);
@ -116,9 +106,9 @@ impl AvgPool2dBackwardComputeShader {
let index_tmp = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt);
let index_base = scope.create_local(Elem::UInt); let index_base = scope.create_local(Elem::UInt);
let grad_accumulation = scope.zero(grad.item()); let grad_accumulation = scope.zero(grad.item);
let result = scope.create_local(grad.item()); let result = scope.create_local(grad.item);
let count = scope.create_local(grad.item()); let count = scope.create_local(grad.item);
let count_include_pad = self.count_include_pad; let count_include_pad = self.count_include_pad;
if count_include_pad { if count_include_pad {
@ -226,30 +216,12 @@ impl AvgPool2dBackwardComputeShader {
output_stride_2: Variable, output_stride_2: Variable,
output_stride_3: Variable, output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) { ) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar { let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
}; let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let pool_stride_1 = Variable::GlobalScalar { let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
id: 1, let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size; let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -350,8 +322,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let grad = Variable::GlobalInputArray { id: 0, item }; let grad = Variable::new(VariableKind::GlobalInputArray(0), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -7,7 +7,9 @@ use crate::{
}; };
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo, OutputInfo,
}; };
@ -32,7 +34,7 @@ impl MaxPool2dBackwardComputeShader {
let grad = self.grad; let grad = self.grad;
let output = self.output; let output = self.output;
let indices = self.indices; let indices = self.indices;
let id = Variable::AbsolutePos; let id = Variable::builtin(Builtin::AbsolutePos);
let grad_stride_0 = scope.create_local(Elem::UInt); let grad_stride_0 = scope.create_local(Elem::UInt);
let grad_stride_1 = scope.create_local(Elem::UInt); let grad_stride_1 = scope.create_local(Elem::UInt);
@ -103,8 +105,8 @@ impl MaxPool2dBackwardComputeShader {
let index_base = scope.create_local(Elem::UInt); let index_base = scope.create_local(Elem::UInt);
let index_tmp = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt);
let grad_accumulation = scope.zero(grad.item()); let grad_accumulation = scope.zero(grad.item);
let result = scope.create_local(grad.item()); let result = scope.create_local(grad.item);
let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges( let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
scope, scope,
@ -160,30 +162,12 @@ impl MaxPool2dBackwardComputeShader {
output_stride_2: Variable, output_stride_2: Variable,
output_stride_3: Variable, output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) { ) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar { let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
}; let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let pool_stride_1 = Variable::GlobalScalar { let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
id: 1, let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size; let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -284,12 +268,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let indices = Variable::GlobalInputArray { let indices = Variable::new(
id: 0, VariableKind::GlobalInputArray(0),
item: Item::new(Elem::Int(IntKind::I32)), Item::new(Elem::Int(IntKind::I32)),
}; );
let grad = Variable::GlobalInputArray { id: 1, item }; let grad = Variable::new(VariableKind::GlobalInputArray(1), item);
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
scope.write_global_custom(output); scope.write_global_custom(output);

View File

@ -1,6 +1,6 @@
use cubecl::{ use cubecl::{
cpa, cpa,
ir::{Elem, Scope, Variable}, ir::{Builtin, Elem, Item, Scope, Variable, VariableKind},
prelude::*, prelude::*,
CubeCountSettings, Execution, InputInfo, OutputInfo, CubeCountSettings, Execution, InputInfo, OutputInfo,
}; };
@ -58,32 +58,17 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
let mut scope = Scope::root(); let mut scope = Scope::root();
let item = E::cube_elem().into(); let item = E::cube_elem().into();
let output = Variable::GlobalOutputArray { id: 0, item }; let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
let seed0 = Variable::GlobalScalar { let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
id: 0, let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
elem: Elem::UInt, let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
}; let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let seed1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let seed2 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let seed3 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let seeds = [seed0, seed1, seed2, seed3]; let seeds = [seed0, seed1, seed2, seed3];
let mut args = Vec::<Variable>::new(); let mut args = Vec::<Variable>::new();
for i in 0..P::args_length() { for i in 0..P::args_length() {
args.push(Variable::GlobalScalar { args.push(Variable::new(VariableKind::GlobalScalar(i as u16), item));
id: i as u16,
elem: item.elem(),
});
} }
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope); PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
@ -174,12 +159,12 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let n_values_per_thread: Variable = self.n_values_per_thread.into(); let n_values_per_thread: Variable = self.n_values_per_thread.into();
let args = self.args; let args = self.args;
let cube_dim_x = Variable::CubeDimX; let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
let cube_dim_y = Variable::CubeDimY; let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
let cube_pos_x = Variable::CubePosX; let cube_pos_x = Variable::builtin(Builtin::CubePosX);
let cube_pos_y = Variable::CubePosY; let cube_pos_y = Variable::builtin(Builtin::CubePosY);
let cube_count_y = Variable::CubeCountY; let cube_count_y = Variable::builtin(Builtin::CubeCountY);
let local_index = Variable::UnitPos; let local_index = Variable::builtin(Builtin::UnitPos);
let n_invocations = scope.create_local(Elem::UInt); let n_invocations = scope.create_local(Elem::UInt);
cpa!(scope, n_invocations = cube_dim_x); cpa!(scope, n_invocations = cube_dim_x);

View File

@ -37,7 +37,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
output: Variable, output: Variable,
) { ) {
let float_elem = Elem::Float(FloatKind::F32); let float_elem = Elem::Float(FloatKind::F32);
let item = output.item(); let item = output.item;
let mean = args[0]; let mean = args[0];
let std = args[1]; let std = args[1];
let two_pi = scope.create_with_value(2. * PI, float_elem); let two_pi = scope.create_with_value(2. * PI, float_elem);

View File

@ -35,7 +35,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
output: Variable, output: Variable,
) { ) {
let float_elem = Elem::Float(FloatKind::F32); let float_elem = Elem::Float(FloatKind::F32);
let item = output.item(); let item = output.item;
let lower_bound = args[0]; let lower_bound = args[0];
let upper_bound = args[1]; let upper_bound = args[1];
let scale = scope.create_local(item); let scale = scope.create_local(item);

View File

@ -32,7 +32,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
(value, index): Self::Accumulator, (value, index): Self::Accumulator,
) { ) {
let (value_shared_memory, index_shared_memory) = shared_memory; let (value_shared_memory, index_shared_memory) = shared_memory;
let current_value = scope.create_local(value.item()); let current_value = scope.create_local(value.item);
cpa!(scope, current_value = value_shared_memory[write_position]); cpa!(scope, current_value = value_shared_memory[write_position]);
let condition = scope.create_local(Elem::Bool); let condition = scope.create_local(Elem::Bool);
@ -49,7 +49,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
read_position: Variable, read_position: Variable,
i: Variable, i: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let value = scope.create_local(input.item()); let value = scope.create_local(input.item);
cpa!(scope, value = input[read_position]); cpa!(scope, value = input[read_position]);
(value, i) (value, i)
} }
@ -60,9 +60,9 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
read_position: Variable, read_position: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let (value_shared_memory, index_shared_memory) = shared_memory; let (value_shared_memory, index_shared_memory) = shared_memory;
let value = scope.create_local(value_shared_memory.item()); let value = scope.create_local(value_shared_memory.item);
cpa!(scope, value = value_shared_memory[read_position]); cpa!(scope, value = value_shared_memory[read_position]);
let index = scope.create_local(index_shared_memory.item()); let index = scope.create_local(index_shared_memory.item);
cpa!(scope, index = index_shared_memory[read_position]); cpa!(scope, index = index_shared_memory[read_position]);
(value, index) (value, index)
} }
@ -75,7 +75,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
_shape_reduce_dim: Variable, _shape_reduce_dim: Variable,
) { ) {
let (_, index_shared_memory) = shared_memory; let (_, index_shared_memory) = shared_memory;
let final_value = scope.create_local(output.item()); let final_value = scope.create_local(index_shared_memory.item);
cpa!(scope, final_value = index_shared_memory[0]); cpa!(scope, final_value = index_shared_memory[0]);
cpa!(scope, output[write_position] = final_value); cpa!(scope, output[write_position] = final_value);
} }

View File

@ -33,7 +33,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
(value, index): Self::Accumulator, (value, index): Self::Accumulator,
) { ) {
let (value_shared_memory, index_shared_memory) = shared_memory; let (value_shared_memory, index_shared_memory) = shared_memory;
let current_value = scope.create_local(value.item()); let current_value = scope.create_local(value.item);
cpa!(scope, current_value = value_shared_memory[write_position]); cpa!(scope, current_value = value_shared_memory[write_position]);
let condition = scope.create_local(Elem::Bool); let condition = scope.create_local(Elem::Bool);
@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
read_position: Variable, read_position: Variable,
i: Variable, i: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let value = scope.create_local(input.item()); let value = scope.create_local(input.item);
cpa!(scope, value = input[read_position]); cpa!(scope, value = input[read_position]);
(value, i) (value, i)
} }
@ -61,9 +61,9 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
read_position: Variable, read_position: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let (value_shared_memory, index_shared_memory) = shared_memory; let (value_shared_memory, index_shared_memory) = shared_memory;
let value = scope.create_local(value_shared_memory.item()); let value = scope.create_local(value_shared_memory.item);
cpa!(scope, value = value_shared_memory[read_position]); cpa!(scope, value = value_shared_memory[read_position]);
let index = scope.create_local(index_shared_memory.item()); let index = scope.create_local(index_shared_memory.item);
cpa!(scope, index = index_shared_memory[read_position]); cpa!(scope, index = index_shared_memory[read_position]);
(value, index) (value, index)
} }
@ -76,7 +76,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
_shape_reduce_dim: Variable, _shape_reduce_dim: Variable,
) { ) {
let (_, index_shared_memory) = shared_memory; let (_, index_shared_memory) = shared_memory;
let final_value = scope.create_local(output.item()); let final_value = scope.create_local(index_shared_memory.item);
cpa!(scope, final_value = index_shared_memory[0]); cpa!(scope, final_value = index_shared_memory[0]);
cpa!(scope, output[write_position] = final_value); cpa!(scope, output[write_position] = final_value);
} }

View File

@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
input_item: Item, input_item: Item,
) -> Self::Accumulator { ) -> Self::Accumulator {
let shared_memory = scope.create_shared(input_item, shared_memory_size); let shared_memory = scope.create_shared(input_item, shared_memory_size);
let neutral_element = scope.zero(shared_memory.item()); let neutral_element = scope.zero(shared_memory.item);
cpa!(scope, shared_memory[write_position] = neutral_element); cpa!(scope, shared_memory[write_position] = neutral_element);
shared_memory shared_memory
} }
@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
write_position: Variable, write_position: Variable,
value: Self::Accumulator, value: Self::Accumulator,
) { ) {
let current_value = scope.create_local(value.item()); let current_value = scope.create_local(value.item);
let computed = scope.create_local(value.item()); let computed = scope.create_local(value.item);
cpa!(scope, current_value = shared_memory[write_position]); cpa!(scope, current_value = shared_memory[write_position]);
cpa!(scope, computed = current_value + value); cpa!(scope, computed = current_value + value);
cpa!(scope, shared_memory[write_position] = computed); cpa!(scope, shared_memory[write_position] = computed);
@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
read_position: Variable, read_position: Variable,
_i: Variable, _i: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let value = scope.create_local(input.item()); let value = scope.create_local(input.item);
cpa!(scope, value = input[read_position]); cpa!(scope, value = input[read_position]);
value value
} }
@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
shared_memory: Self::Accumulator, shared_memory: Self::Accumulator,
read_position: Variable, read_position: Variable,
) -> Variable { ) -> Variable {
let read_value = scope.create_local(shared_memory.item()); let read_value = scope.create_local(shared_memory.item);
cpa!(scope, read_value = shared_memory[read_position]); cpa!(scope, read_value = shared_memory[read_position]);
read_value read_value
} }
@ -62,10 +62,10 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
write_position: Variable, write_position: Variable,
shape_reduce_dim: Variable, shape_reduce_dim: Variable,
) { ) {
let final_value = scope.create_local(output.item()); let final_value = scope.create_local(output.item);
cpa!(scope, final_value = shared_memory[0]); cpa!(scope, final_value = shared_memory[0]);
let denominator = scope.create_local(output.item()); let denominator = scope.create_local(output.item);
cpa!(scope, denominator = cast(shape_reduce_dim)); cpa!(scope, denominator = cast(shape_reduce_dim));
cpa!(scope, final_value = final_value / denominator); cpa!(scope, final_value = final_value / denominator);
cpa!(scope, output[write_position] = final_value); cpa!(scope, output[write_position] = final_value);

View File

@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
input_item: Item, input_item: Item,
) -> Self::Accumulator { ) -> Self::Accumulator {
let shared_memory = scope.create_shared(input_item, shared_memory_size); let shared_memory = scope.create_shared(input_item, shared_memory_size);
let neutral_element = scope.create_with_value(1, shared_memory.item()); let neutral_element = scope.create_with_value(1, shared_memory.item);
cpa!(scope, shared_memory[write_position] = neutral_element); cpa!(scope, shared_memory[write_position] = neutral_element);
shared_memory shared_memory
} }
@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
write_position: Variable, write_position: Variable,
value: Self::Accumulator, value: Self::Accumulator,
) { ) {
let current_value = scope.create_local(value.item()); let current_value = scope.create_local(value.item);
let computed = scope.create_local(value.item()); let computed = scope.create_local(value.item);
cpa!(scope, current_value = shared_memory[write_position]); cpa!(scope, current_value = shared_memory[write_position]);
cpa!(scope, computed = current_value * value); cpa!(scope, computed = current_value * value);
cpa!(scope, shared_memory[write_position] = computed); cpa!(scope, shared_memory[write_position] = computed);
@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
read_position: Variable, read_position: Variable,
_i: Variable, _i: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let value = scope.create_local(input.item()); let value = scope.create_local(input.item);
cpa!(scope, value = input[read_position]); cpa!(scope, value = input[read_position]);
value value
} }
@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
shared_memory: Self::Accumulator, shared_memory: Self::Accumulator,
read_position: Variable, read_position: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let read_value = scope.create_local(shared_memory.item()); let read_value = scope.create_local(shared_memory.item);
cpa!(scope, read_value = shared_memory[read_position]); cpa!(scope, read_value = shared_memory[read_position]);
read_value read_value
} }
@ -62,7 +62,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
write_position: Variable, write_position: Variable,
_shape_reduce_dim: Variable, _shape_reduce_dim: Variable,
) { ) {
let final_value = scope.create_local(output.item()); let final_value = scope.create_local(output.item);
cpa!(scope, final_value = shared_memory[0]); cpa!(scope, final_value = shared_memory[0]);
cpa!(scope, output[write_position] = final_value); cpa!(scope, output[write_position] = final_value);
} }

View File

@ -1,6 +1,9 @@
use cubecl::{ use cubecl::{
cpa, ir::KernelDefinition, prelude::CubeCount, CubeCountSettings, Execution, InputInfo, cpa,
KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, ir::{Builtin, KernelDefinition, VariableKind},
prelude::CubeCount,
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
}; };
use std::marker::PhantomData; use std::marker::PhantomData;
@ -51,14 +54,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
let item_input = EI::cube_elem().into(); let item_input = EI::cube_elem().into();
let item_output = EO::cube_elem().into(); let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray { let tensor = Variable::new(VariableKind::GlobalInputArray(0), item_input);
id: 0, let output = Variable::new(VariableKind::GlobalOutputArray(0), item_output);
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
// Reduce groups are elements that are aligned along the reduce dim // Reduce groups are elements that are aligned along the reduce dim
SharedReduceDimComputeShader { SharedReduceDimComputeShader {
@ -112,16 +109,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
let tensor = self.tensor; let tensor = self.tensor;
let output = self.output; let output = self.output;
let rank = Variable::Rank; let rank = Variable::builtin(Builtin::Rank);
let dim: Variable = self.dim.into(); let dim: Variable = self.dim.into();
let cube_pos_x = Variable::CubePosX; let cube_pos_x = Variable::builtin(Builtin::CubePosX);
let cube_pos_y = Variable::CubePosY; let cube_pos_y = Variable::builtin(Builtin::CubePosY);
let cube_count_x = Variable::CubeCountX; let cube_count_x = Variable::builtin(Builtin::CubeCountX);
let local_invocation_id_x = Variable::UnitPosX; let local_invocation_id_x = Variable::builtin(Builtin::UnitPosX);
let local_invocation_id_y = Variable::UnitPosY; let local_invocation_id_y = Variable::builtin(Builtin::UnitPosY);
let cube_dim_x = Variable::CubeDimX; let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
let cube_dim_y = Variable::CubeDimY; let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
let stride_reduce_dim_input = scope.create_local(Elem::UInt); let stride_reduce_dim_input = scope.create_local(Elem::UInt);
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim)); cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
@ -162,12 +159,8 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
}) })
); );
let shared_memory = RD::initialize_shared( let shared_memory =
scope, RD::initialize_shared(scope, self.shared_memory_size as u32, local_id, tensor.item);
self.shared_memory_size as u32,
local_id,
tensor.item(),
);
// Load to shared memory, unrolled // Load to shared memory, unrolled
cpa!( cpa!(

View File

@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
input_item: Item, input_item: Item,
) -> Self::Accumulator { ) -> Self::Accumulator {
let shared_memory = scope.create_shared(input_item, shared_memory_size); let shared_memory = scope.create_shared(input_item, shared_memory_size);
let neutral_element = scope.zero(shared_memory.item()); let neutral_element = scope.zero(shared_memory.item);
cpa!(scope, shared_memory[write_position] = neutral_element); cpa!(scope, shared_memory[write_position] = neutral_element);
shared_memory shared_memory
} }
@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
write_position: Variable, write_position: Variable,
value: Self::Accumulator, value: Self::Accumulator,
) { ) {
let current_value = scope.create_local(value.item()); let current_value = scope.create_local(value.item);
let computed = scope.create_local(value.item()); let computed = scope.create_local(value.item);
cpa!(scope, current_value = shared_memory[write_position]); cpa!(scope, current_value = shared_memory[write_position]);
cpa!(scope, computed = current_value + value); cpa!(scope, computed = current_value + value);
cpa!(scope, shared_memory[write_position] = computed); cpa!(scope, shared_memory[write_position] = computed);
@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
read_position: Variable, read_position: Variable,
_i: Variable, _i: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let value = scope.create_local(input.item()); let value = scope.create_local(input.item);
cpa!(scope, value = input[read_position]); cpa!(scope, value = input[read_position]);
value value
} }
@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
shared_memory: Self::Accumulator, shared_memory: Self::Accumulator,
read_position: Variable, read_position: Variable,
) -> Self::Accumulator { ) -> Self::Accumulator {
let read_value = scope.create_local(shared_memory.item()); let read_value = scope.create_local(shared_memory.item);
cpa!(scope, read_value = shared_memory[read_position]); cpa!(scope, read_value = shared_memory[read_position]);
read_value read_value
} }
@ -62,7 +62,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
write_position: Variable, write_position: Variable,
_shape_reduce_dim: Variable, _shape_reduce_dim: Variable,
) { ) {
let final_value = scope.create_local(output.item()); let final_value = scope.create_local(output.item);
cpa!(scope, final_value = shared_memory[0]); cpa!(scope, final_value = shared_memory[0]);
cpa!(scope, output[write_position] = final_value); cpa!(scope, output[write_position] = final_value);
} }

View File

@ -73,6 +73,7 @@ mod cube_wgpu {
WgpuDevice::Existing(id) => { WgpuDevice::Existing(id) => {
DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32) DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32)
} }
WgpuDevice::DefaultDevice => DeviceId::new(6, 0),
} }
} }
} }

View File

@ -8,9 +8,13 @@ use core::convert::Into;
use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel}; use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};
use burn::{backend::NdArray, prelude::*, tensor::activation::softmax}; use burn::{
backend::{wgpu::init_device, NdArray},
prelude::*,
tensor::activation::softmax,
};
use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn_candle::Candle; use burn_candle::Candle;
use serde::Serialize; use serde::Serialize;
@ -106,7 +110,7 @@ impl ImageClassifier {
log::info!("Loading the model to the Wgpu backend"); log::info!("Loading the model to the Wgpu backend");
let start = Instant::now(); let start = Instant::now();
let device = WgpuDevice::default(); let device = WgpuDevice::default();
init_async::<AutoGraphicsApi>(&device, Default::default()).await; init_device::<AutoGraphicsApi>(&device, Default::default()).await;
self.model = ModelType::WithWgpuBackend(Model::new(&device)); self.model = ModelType::WithWgpuBackend(Model::new(&device));
let duration = start.elapsed(); let duration = start.elapsed();
log::debug!("Model is loaded to the Wgpu backend in {:?}", duration); log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);