mirror of https://github.com/tracel-ai/burn.git
[WIP] Migrate to `cubecl` IR refactor (#2418)
This commit is contained in:
parent 69856a97db
commit 5730f022fd
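The hunks below are large but mechanical, so a summary of the API change helps when reading them. In the old cubecl IR (0.3), `Variable` was an enum whose variants carried ids and element types inline; after the refactor (the pinned 0.4 git revision), a variable is built from a `VariableKind` plus an `Item`, builtins get a dedicated `Variable::builtin` constructor, and the `item()` accessor becomes a public `item` field. A minimal sketch of the three patterns, using only constructors that appear verbatim in the hunks below (a sketch assuming the pinned cubecl revision; it does not compile against 0.3):

```rust
use cubecl::ir::{Builtin, Elem, Item, Variable, VariableKind};

fn migrated(output: Variable) {
    // Old: Variable::GlobalScalar { id: 0, elem: Elem::UInt }
    let scalar = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));

    // Old: Variable::AbsolutePos (builtins were plain enum variants)
    let pos = Variable::builtin(Builtin::AbsolutePos);

    // Old: output.item() (a method); the item is now a public field.
    let item = output.item;

    let _ = (scalar, pos, item);
}
```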
Cargo.lock:

@@ -524,7 +524,7 @@ dependencies = [
 name = "burn-common"
 version = "0.16.0"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "dashmap",
  "getrandom",
  "indicatif",

@@ -1459,15 +1459,14 @@ dependencies = [

 [[package]]
 name = "cubecl"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "75e75c7e982b943380665c5901fe0b69d5df2627644e0e50199c52b64d8d5a1c"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
  "cubecl-hip",
  "cubecl-linalg",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cubecl-wgpu",
 ]

@@ -1489,16 +1488,32 @@ dependencies = [
  "web-time",
 ]

+[[package]]
+name = "cubecl-common"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
+dependencies = [
+ "derive-new",
+ "embassy-futures",
+ "futures-lite",
+ "getrandom",
+ "log",
+ "portable-atomic",
+ "rand",
+ "serde",
+ "spin",
+ "web-time",
+]
+
 [[package]]
 name = "cubecl-core"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec33b64139d1dfc747df8aed5834d10c3c55c716f5219041c6eb17241c96c929"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-macros",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",

@@ -1509,14 +1524,13 @@ dependencies = [

 [[package]]
 name = "cubecl-cpp"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ded461feb0ff342a4f675131dc0ae8ad94e58f66bad11e57f852cb7f190a731"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",

@@ -1524,15 +1538,14 @@ dependencies = [

 [[package]]
 name = "cubecl-cuda"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "88dfdfe616124d2abe5e82052ff56f86843c369440e181d6936f7409e161dd82"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-cpp",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cudarc 0.12.1",
  "derive-new",
  "half",

@@ -1541,16 +1554,15 @@ dependencies = [

 [[package]]
 name = "cubecl-hip"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "409e0e176152ab51a60bbebb940b7a72aba210cd42b5f8cd2e87e7d7e674a13a"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-cpp",
  "cubecl-hip-sys",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",

@@ -1567,23 +1579,21 @@ dependencies = [

 [[package]]
 name = "cubecl-linalg"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c5634782d790e9b6562fc267ffd15e9a510b4d6ec32c144cd2b166af2ba0cfb"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "half",
 ]

 [[package]]
 name = "cubecl-macros"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2d22663257d9cdbcd67f5048d6f4e6eb965dd87104c3a173a7b0ea0d720e99b"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "darling",
  "derive-new",
  "ident_case",

@@ -1595,11 +1605,10 @@ dependencies = [

 [[package]]
 name = "cubecl-opt"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59fba2561a6ceb99e9c5fe7313db0aeead02b848dc9cdacf8373e0fe98d3c247"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "float-ord",
  "log",

@@ -1618,7 +1627,28 @@ dependencies = [
  "async-channel",
  "async-lock",
  "cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.3.0",
  "derive-new",
  "dirs 5.0.1",
  "hashbrown 0.14.5",
+ "log",
+ "md5",
+ "sanitize-filename",
+ "serde",
+ "serde_json",
+ "spin",
+ "wasm-bindgen-futures",
+]
+
+[[package]]
+name = "cubecl-runtime"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
+dependencies = [
+ "async-channel",
+ "async-lock",
+ "cfg_aliases 0.2.1",
+ "cubecl-common 0.4.0",
+ "derive-new",
+ "dirs 5.0.1",
+ "hashbrown 0.14.5",

@@ -1633,31 +1663,29 @@ dependencies = [

 [[package]]
 name = "cubecl-spirv"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "835bc234cdd40fbb5e3e5e41bfb4a6e2ee2d7fd899b66d44dcbcb3a825ca1a59"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-opt",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "hashbrown 0.14.5",
  "rspirv",
 ]

 [[package]]
 name = "cubecl-wgpu"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6779f1072d70923758421c6214fd0cd19a6f25b91035a522f9cd9407d03b5cae"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "ash",
  "async-channel",
  "bytemuck",
  "cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cubecl-spirv",
  "derive-new",
  "hashbrown 0.14.5",

@@ -3265,7 +3293,7 @@ dependencies = [
  "burn-candle",
  "burn-import",
  "console_error_panic_hook",
- "cubecl-runtime",
+ "cubecl-runtime 0.3.0",
  "js-sys",
  "log",
  "serde",

@@ -3793,7 +3821,7 @@ version = "0.16.0"
 dependencies = [
  "burn",
  "console_error_panic_hook",
- "cubecl-runtime",
+ "cubecl-runtime 0.3.0",
  "js-sys",
  "serde",
  "wasm-bindgen",
Cargo.toml:

@@ -152,14 +152,14 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

 ### For the main burn branch. ###
-# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
-# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
 ### For the release. ###
-cubecl = { version="0.3.0", default-features = false }
-cubecl-common = { version="0.3.0", default-features = false }
+# cubecl = { version = "0.3.0", default-features = false }
+# cubecl-common = { version = "0.3.0", default-features = false }

 ### For xtask crate ###
 tracel-xtask = { version = "~1.1" }
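The workspace `Cargo.toml` keeps three interchangeable declarations for the cubecl dependencies, one per source (main-branch git revision, local checkout, crates.io release), with exactly one pair active at a time. This commit flips the active pair from the `0.3.0` release to the git revision pinned above, which is why every `cubecl-*` package in the `Cargo.lock` hunks moves from the crates.io registry to the git source.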
@@ -1,6 +1,8 @@
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -39,7 +41,7 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
         let weight = self.weight;
         let bias = self.bias;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let input_stride_0 = scope.create_local(Elem::UInt);
         let input_stride_1 = scope.create_local(Elem::UInt);

@@ -92,34 +94,13 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
         cpa!(scope, kernel_size_0 = shape(weight, 2u32));
         cpa!(scope, kernel_size_1 = shape(weight, 3u32));

-        let conv_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let conv_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
-        let groups = Variable::GlobalScalar {
-            id: 6,
-            elem: Elem::UInt,
-        };
+        let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
+        let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));

         let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
         let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));

@@ -222,9 +203,9 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
         cpa!(scope, index_input_b = b * input_stride_0);
         cpa!(scope, index_weight_oc = oc * weight_stride_1);

-        let prod = scope.create_local(output.item());
-        let prod_tmp = scope.create_local(output.item());
-        let sum = scope.create_local(output.item());
+        let prod = scope.create_local(output.item);
+        let prod_tmp = scope.create_local(output.item);
+        let sum = scope.create_local(output.item);
         cpa!(scope, sum = bias[oc_out]);

         let kh = scope.create_local(Elem::UInt);

@@ -314,10 +295,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let weight = Variable::GlobalInputArray { id: 1, item };
-        let bias = Variable::GlobalInputArray { id: 2, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -1,6 +1,8 @@
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -39,7 +41,7 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
         let weight = self.weight;
         let bias = self.bias;
         let output = self.output;
-        let idx = Variable::AbsolutePos;
+        let idx = Variable::builtin(Builtin::AbsolutePos);

         let input_stride_0 = scope.create_local(Elem::UInt);
         let input_stride_1 = scope.create_local(Elem::UInt);

@@ -104,46 +106,16 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
         cpa!(scope, kernel_size_1 = shape(weight, 3u32));
         cpa!(scope, kernel_size_2 = shape(weight, 4u32));

-        let conv_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let conv_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let conv_stride_2 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let dilation_2 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 6,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 7,
-            elem: Elem::UInt,
-        };
-        let padding_2 = Variable::GlobalScalar {
-            id: 8,
-            elem: Elem::UInt,
-        };
-        let groups = Variable::GlobalScalar {
-            id: 9,
-            elem: Elem::UInt,
-        };
+        let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt));
+        let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt));
+        let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt));

         let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
         let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));

@@ -273,9 +245,9 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
         cpa!(scope, index_input_b = b * input_stride_0);
         cpa!(scope, index_weight_oc = oc * weight_stride_1);

-        let prod = scope.create_local(output.item());
-        let prod_tmp = scope.create_local(output.item());
-        let sum = scope.create_local(output.item());
+        let prod = scope.create_local(output.item);
+        let prod_tmp = scope.create_local(output.item);
+        let sum = scope.create_local(output.item);
         cpa!(scope, sum = bias[oc_out]);

         let kd = scope.create_local(Elem::UInt);

@@ -391,10 +363,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let weight = Variable::GlobalInputArray { id: 1, item };
-        let bias = Variable::GlobalInputArray { id: 2, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -4,7 +4,7 @@ use crate::{
 use burn_tensor::ElementConversion;
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -27,7 +27,7 @@ impl FlipComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.create_local(Elem::UInt);

@@ -42,10 +42,10 @@ impl FlipComputeShader {
             cpa!(scope, shape = shape(output, i));
             cpa!(
                 scope,
-                flip = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                flip = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );
             cpa!(scope, flip_bool = flip == 1u32);

@@ -61,7 +61,7 @@ impl FlipComputeShader {
             cpa!(scope, offset_input += offset_local);
         }

-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }

@@ -72,8 +72,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -1,7 +1,7 @@
 use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -26,7 +26,7 @@ impl RepeatComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.zero(Elem::UInt);

@@ -50,7 +50,7 @@ impl RepeatComputeShader {
             cpa!(scope, offset_input += offset_local);
         }

-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }

@@ -60,8 +60,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -4,7 +4,7 @@ use crate::{
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -27,7 +27,7 @@ impl SliceComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.create_local(Elem::UInt);

@@ -43,10 +43,10 @@ impl SliceComputeShader {
             cpa!(scope, shape_output = shape(output, i));
             cpa!(
                 scope,
-                range_start = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                range_start = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );

             cpa!(scope, offset_local = id / stride_output);

@@ -57,7 +57,7 @@ impl SliceComputeShader {
             cpa!(scope, offset_input += offset_local);
         }

-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }

@@ -68,8 +68,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -2,7 +2,7 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use burn_tensor::ElementConversion;
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
 };
 use std::{marker::PhantomData, ops::Range};

@@ -24,7 +24,7 @@ impl SliceAssignComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let value = self.value;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let offset_input = scope.zero(Elem::UInt);
         let offset_value = scope.zero(Elem::UInt);

@@ -46,10 +46,10 @@ impl SliceAssignComputeShader {
             cpa!(scope, shape_input = shape(input, i));
             cpa!(
                 scope,
-                range_start = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                range_start = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );

             cpa!(scope, offset_local = id / stride_value);

@@ -66,7 +66,7 @@ impl SliceAssignComputeShader {
             cpa!(scope, offset_input += offset_local_input);
         }

-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = value[offset_value]);
         cpa!(scope, input[offset_input] = result);
     }

@@ -77,8 +77,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let value = Variable::GlobalInputArray { id: 1, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let value = Variable::new(VariableKind::GlobalInputArray(1), item);

         scope.write_global_custom(input);
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
         let elem = E::cube_elem();

         let input_stride_0 = scope.create_local(Elem::UInt);

@@ -181,10 +181,10 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
         let index_1 = scope.create_local(Elem::UInt);
         let index_2 = scope.create_local(Elem::UInt);
         let index_3 = scope.create_local(Elem::UInt);
-        let inp_0 = scope.create_local(input.item());
-        let inp_1 = scope.create_local(input.item());
-        let inp_2 = scope.create_local(input.item());
-        let inp_3 = scope.create_local(input.item());
+        let inp_0 = scope.create_local(input.item);
+        let inp_1 = scope.create_local(input.item);
+        let inp_2 = scope.create_local(input.item);
+        let inp_3 = scope.create_local(input.item);

         cpa!(scope, index_0 = index_base);
         cpa!(scope, index_0 += y0_stride);

@@ -276,7 +276,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {

     fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable {
         let cond = scope.create_local(Elem::Bool);
-        let res = scope.create_local(a.item());
+        let res = scope.create_local(a.item);

         cpa!(scope, cond = a < b);
         cpa!(scope, if(cond).then(|scope|{

@@ -296,7 +296,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
         x3: Variable,
         t: Variable,
     ) -> Variable {
-        let item = x0.item();
+        let item = x0.item;
         let x = scope.create_local(item);
         let a: Variable = scope.create_with_value(-0.75, item);
         let one: Variable = scope.create_with_value(1, item);

@@ -327,7 +327,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
     }

     fn cubic_convolution1(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
-        let item = x.item();
+        let item = x.item;
         let conv = scope.create_local(item);
         let tmp = scope.create_local(item);
         let one = scope.create_with_value(1, item);

@@ -346,7 +346,7 @@ impl<E: JitElement> InterpolateBicubicShader<E> {
     }

     fn cubic_convolution2(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
-        let item = x.item();
+        let item = x.item;
         let conv = scope.create_local(item);
         let tmp = scope.create_local(item);
         let four = scope.create_with_value(4, item);

@@ -372,8 +372,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBicubicEagerKernel<R, E
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         InterpolateBicubicShader {
             input,
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -23,7 +23,7 @@ impl InterpolateBilinearShader {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let input_stride_0 = scope.create_local(Elem::UInt);
         let input_stride_1 = scope.create_local(Elem::UInt);

@@ -78,26 +78,26 @@ impl InterpolateBilinearShader {
         cpa!(scope, w = id / output_stride_3);
         cpa!(scope, w = w % output_shape_3);

-        let factor_float = scope.create_local(input.item());
-        let numerator_float = scope.create_local(input.item());
+        let factor_float = scope.create_local(input.item);
+        let numerator_float = scope.create_local(input.item);
         let numerator_int = scope.create_local(Elem::UInt);
-        let denominator_float = scope.create_local(input.item());
+        let denominator_float = scope.create_local(input.item);
         let denominator_int = scope.create_local(Elem::UInt);

-        let frac = scope.create_local(input.item());
-        let v0 = scope.create_local(input.item());
-        let v1 = scope.create_local(input.item());
-        let one = scope.create_with_value(1f32, input.item());
+        let frac = scope.create_local(input.item);
+        let v0 = scope.create_local(input.item);
+        let v1 = scope.create_local(input.item);
+        let one = scope.create_with_value(1f32, input.item);

         let y0 = scope.create_local(Elem::UInt);
         let y1 = scope.create_local(Elem::UInt);
-        let yw = scope.create_local(input.item());
-        let yw_ = scope.create_local(input.item());
+        let yw = scope.create_local(input.item);
+        let yw_ = scope.create_local(input.item);

         let x0 = scope.create_local(Elem::UInt);
         let x1 = scope.create_local(Elem::UInt);
-        let xw = scope.create_local(input.item());
-        let xw_ = scope.create_local(input.item());
+        let xw = scope.create_local(input.item);
+        let xw_ = scope.create_local(input.item);

         cpa!(scope, numerator_int = input_shape_2 - 1u32);
         cpa!(scope, denominator_int = output_shape_2 - 1u32);

@@ -136,10 +136,10 @@ impl InterpolateBilinearShader {
         let y1_stride = scope.create_local(Elem::UInt);
         let x0_stride = scope.create_local(Elem::UInt);
         let x1_stride = scope.create_local(Elem::UInt);
-        let p_a = scope.create_local(input.item());
-        let p_b = scope.create_local(input.item());
-        let p_c = scope.create_local(input.item());
-        let p_d = scope.create_local(input.item());
+        let p_a = scope.create_local(input.item);
+        let p_b = scope.create_local(input.item);
+        let p_c = scope.create_local(input.item);
+        let p_d = scope.create_local(input.item);

         cpa!(scope, index_base = b * input_stride_0);
         cpa!(scope, index_tmp = c * input_stride_1);

@@ -177,7 +177,7 @@ impl InterpolateBilinearShader {
         cpa!(scope, p_d *= xw);
         cpa!(scope, p_d *= yw);

-        let sum = scope.create_local(input.item());
+        let sum = scope.create_local(input.item);
         cpa!(scope, sum = p_a + p_b);
         cpa!(scope, sum += p_c);
         cpa!(scope, sum += p_d);

@@ -190,8 +190,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBilinearEagerKernel<R,
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         InterpolateBilinearShader { input, output }.expand(&mut scope);
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateNearestShader<E> {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
         let elem = E::cube_elem();

         let input_stride_0 = scope.create_local(Elem::UInt);

@@ -106,7 +106,7 @@ impl<E: JitElement> InterpolateNearestShader<E> {

         let index = scope.create_local(Elem::UInt);
         let index_tmp = scope.create_local(Elem::UInt);
-        let val = scope.create_local(output.item());
+        let val = scope.create_local(output.item);

         cpa!(scope, index = b * input_stride_0);
         cpa!(scope, index_tmp = c * input_stride_1);

@@ -126,8 +126,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestEagerKernel<R, E
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         InterpolateNearestShader {
             input,
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -24,7 +24,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
     fn expand(self, scope: &mut Scope) {
         let grad = self.out_grad;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);

@@ -88,7 +88,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
         let gw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
         let gw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);

-        let result = scope.create_local(grad.item());
+        let result = scope.create_local(grad.item);

         let index_grad = scope.create_local(Elem::UInt);
         let index_grad_0 = scope.create_local(Elem::UInt);

@@ -99,7 +99,7 @@ impl<E: JitElement> InterpolateNearestBackwardShader<E> {
         cpa!(scope, index_grad_0 = b * grad_stride_0);
         cpa!(scope, index_grad_1 = c * grad_stride_1);

-        let sum = scope.zero(output.item());
+        let sum = scope.zero(output.item);

         cpa!(
             scope,

@@ -184,8 +184,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestBackwardEagerKer
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let out_grad = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let out_grad = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         InterpolateNearestBackwardShader {
             out_grad,
@@ -7,7 +7,9 @@ use crate::{
 };
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -32,7 +34,7 @@ impl AvgPool2dBackwardComputeShader {
     fn expand(self, scope: &mut Scope) {
         let grad = self.grad;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);

@@ -70,22 +72,10 @@ impl AvgPool2dBackwardComputeShader {
         cpa!(scope, output_shape_2 = shape(output, 2u32));
         cpa!(scope, output_shape_3 = shape(output, 3u32));

-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
         let [kernel_size_0, kernel_size_1] = self.kernel_size;

         let b = scope.create_local(Elem::UInt);

@@ -116,9 +106,9 @@ impl AvgPool2dBackwardComputeShader {
         let index_tmp = scope.create_local(Elem::UInt);
         let index_base = scope.create_local(Elem::UInt);

-        let grad_accumulation = scope.zero(grad.item());
-        let result = scope.create_local(grad.item());
-        let count = scope.create_local(grad.item());
+        let grad_accumulation = scope.zero(grad.item);
+        let result = scope.create_local(grad.item);
+        let count = scope.create_local(grad.item);

         let count_include_pad = self.count_include_pad;
         if count_include_pad {

@@ -226,30 +216,12 @@ impl AvgPool2dBackwardComputeShader {
         output_stride_2: Variable,
         output_stride_3: Variable,
     ) -> (Variable, Variable, Variable, Variable) {
-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));

         let [kernel_size_0, kernel_size_1] = self.kernel_size;

@@ -350,8 +322,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let grad = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let grad = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -7,7 +7,9 @@ use crate::{
 };
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };

@@ -32,7 +34,7 @@ impl MaxPool2dBackwardComputeShader {
         let grad = self.grad;
         let output = self.output;
         let indices = self.indices;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);

         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);

@@ -103,8 +105,8 @@ impl MaxPool2dBackwardComputeShader {
         let index_base = scope.create_local(Elem::UInt);
         let index_tmp = scope.create_local(Elem::UInt);

-        let grad_accumulation = scope.zero(grad.item());
-        let result = scope.create_local(grad.item());
+        let grad_accumulation = scope.zero(grad.item);
+        let result = scope.create_local(grad.item);

         let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
             scope,

@@ -160,30 +162,12 @@ impl MaxPool2dBackwardComputeShader {
         output_stride_2: Variable,
         output_stride_3: Variable,
     ) -> (Variable, Variable, Variable, Variable) {
-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));

         let [kernel_size_0, kernel_size_1] = self.kernel_size;

@@ -284,12 +268,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let indices = Variable::GlobalInputArray {
-            id: 0,
-            item: Item::new(Elem::Int(IntKind::I32)),
-        };
-        let grad = Variable::GlobalInputArray { id: 1, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let indices = Variable::new(
+            VariableKind::GlobalInputArray(0),
+            Item::new(Elem::Int(IntKind::I32)),
+        );
+        let grad = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

         scope.write_global_custom(output);
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, Scope, Variable},
+    ir::{Builtin, Elem, Item, Scope, Variable, VariableKind},
     prelude::*,
     CubeCountSettings, Execution, InputInfo, OutputInfo,
 };

@@ -58,32 +58,17 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
         let mut scope = Scope::root();
         let item = E::cube_elem().into();

-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

-        let seed0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let seed1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let seed2 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let seed3 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
+        let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
        let seeds = [seed0, seed1, seed2, seed3];

         let mut args = Vec::<Variable>::new();
         for i in 0..P::args_length() {
-            args.push(Variable::GlobalScalar {
-                id: i as u16,
-                elem: item.elem(),
-            });
+            args.push(Variable::new(VariableKind::GlobalScalar(i as u16), item));
         }

         PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);

@@ -174,12 +159,12 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
         let n_values_per_thread: Variable = self.n_values_per_thread.into();
         let args = self.args;

-        let cube_dim_x = Variable::CubeDimX;
-        let cube_dim_y = Variable::CubeDimY;
-        let cube_pos_x = Variable::CubePosX;
-        let cube_pos_y = Variable::CubePosY;
-        let cube_count_y = Variable::CubeCountY;
-        let local_index = Variable::UnitPos;
+        let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
+        let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
+        let cube_pos_x = Variable::builtin(Builtin::CubePosX);
+        let cube_pos_y = Variable::builtin(Builtin::CubePosY);
+        let cube_count_y = Variable::builtin(Builtin::CubeCountY);
+        let local_index = Variable::builtin(Builtin::UnitPos);

         let n_invocations = scope.create_local(Elem::UInt);
         cpa!(scope, n_invocations = cube_dim_x);
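The PRNG kernel exercises more of the `Builtin` enum than the other files, so it is a convenient reference for the launch-geometry side of the migration. Note also that the scalar args are now created with the full `item` rather than `item.elem()`, matching the new constructor's `Item` parameter. A compact sketch of the mapping (a sketch assuming the pinned cubecl revision; the commented names are the pre-refactor `Variable` variants):

```rust
use cubecl::ir::{Builtin, Variable};

// Pre-refactor variant    ->  post-refactor constructor
// Variable::CubeDimX      ->  Variable::builtin(Builtin::CubeDimX)
// Variable::CubePosY      ->  Variable::builtin(Builtin::CubePosY)
// Variable::CubeCountY    ->  Variable::builtin(Builtin::CubeCountY)
// Variable::UnitPos       ->  Variable::builtin(Builtin::UnitPos)
fn launch_geometry() -> [Variable; 4] {
    [
        Variable::builtin(Builtin::CubeDimX),
        Variable::builtin(Builtin::CubePosY),
        Variable::builtin(Builtin::CubeCountY),
        Variable::builtin(Builtin::UnitPos),
    ]
}
```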
@@ -37,7 +37,7 @@ impl<E: JitElement> Prng<E> for Normal<E> {
         output: Variable,
     ) {
         let float_elem = Elem::Float(FloatKind::F32);
-        let item = output.item();
+        let item = output.item;
         let mean = args[0];
         let std = args[1];
         let two_pi = scope.create_with_value(2. * PI, float_elem);
@@ -35,7 +35,7 @@ impl<E: JitElement> Prng<E> for Uniform<E> {
         output: Variable,
     ) {
         let float_elem = Elem::Float(FloatKind::F32);
-        let item = output.item();
+        let item = output.item;
         let lower_bound = args[0];
         let upper_bound = args[1];
         let scale = scope.create_local(item);
@@ -32,7 +32,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
         (value, index): Self::Accumulator,
     ) {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let current_value = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
         cpa!(scope, current_value = value_shared_memory[write_position]);

         let condition = scope.create_local(Elem::Bool);

@@ -49,7 +49,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
         read_position: Variable,
         i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         (value, i)
     }

@@ -60,9 +60,9 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
         read_position: Variable,
     ) -> Self::Accumulator {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let value = scope.create_local(value_shared_memory.item());
+        let value = scope.create_local(value_shared_memory.item);
         cpa!(scope, value = value_shared_memory[read_position]);
-        let index = scope.create_local(index_shared_memory.item());
+        let index = scope.create_local(index_shared_memory.item);
         cpa!(scope, index = index_shared_memory[read_position]);
         (value, index)
     }

@@ -75,7 +75,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
         _shape_reduce_dim: Variable,
     ) {
         let (_, index_shared_memory) = shared_memory;
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(index_shared_memory.item);
         cpa!(scope, final_value = index_shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
@@ -33,7 +33,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
         (value, index): Self::Accumulator,
     ) {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let current_value = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
         cpa!(scope, current_value = value_shared_memory[write_position]);

         let condition = scope.create_local(Elem::Bool);

@@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
         read_position: Variable,
         i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         (value, i)
     }

@@ -61,9 +61,9 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
         read_position: Variable,
     ) -> Self::Accumulator {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let value = scope.create_local(value_shared_memory.item());
+        let value = scope.create_local(value_shared_memory.item);
         cpa!(scope, value = value_shared_memory[read_position]);
-        let index = scope.create_local(index_shared_memory.item());
+        let index = scope.create_local(index_shared_memory.item);
         cpa!(scope, index = index_shared_memory[read_position]);
         (value, index)
     }

@@ -76,7 +76,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
         _shape_reduce_dim: Variable,
     ) {
         let (_, index_shared_memory) = shared_memory;
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(index_shared_memory.item);
         cpa!(scope, final_value = index_shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
@@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.zero(shared_memory.item());
+        let neutral_element = scope.zero(shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }

@@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value + value);
         cpa!(scope, shared_memory[write_position] = computed);

@@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }

@@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Variable {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }

@@ -62,10 +62,10 @@ impl<E: JitElement> ReduceDimShared<E> for MeanDim {
         write_position: Variable,
         shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);

-        let denominator = scope.create_local(output.item());
+        let denominator = scope.create_local(output.item);
         cpa!(scope, denominator = cast(shape_reduce_dim));
         cpa!(scope, final_value = final_value / denominator);
         cpa!(scope, output[write_position] = final_value);
@@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.create_with_value(1, shared_memory.item());
+        let neutral_element = scope.create_with_value(1, shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }

@@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value * value);
         cpa!(scope, shared_memory[write_position] = computed);

@@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }

@@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Self::Accumulator {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }

@@ -62,7 +62,7 @@ impl<E: JitElement> ReduceDimShared<E> for ProdDim {
         write_position: Variable,
         _shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
@@ -1,6 +1,9 @@
 use cubecl::{
-    cpa, ir::KernelDefinition, prelude::CubeCount, CubeCountSettings, Execution, InputInfo,
-    KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo,
+    cpa,
+    ir::{Builtin, KernelDefinition, VariableKind},
+    prelude::CubeCount,
+    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
+    OutputInfo,
 };
 use std::marker::PhantomData;

@@ -51,14 +54,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
         let item_input = EI::cube_elem().into();
         let item_output = EO::cube_elem().into();

-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
+        let tensor = Variable::new(VariableKind::GlobalInputArray(0), item_input);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item_output);

         // Reduce groups are elements that are aligned along the reduce dim
         SharedReduceDimComputeShader {

@@ -112,16 +109,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
         let tensor = self.tensor;
         let output = self.output;

-        let rank = Variable::Rank;
+        let rank = Variable::builtin(Builtin::Rank);
         let dim: Variable = self.dim.into();

-        let cube_pos_x = Variable::CubePosX;
-        let cube_pos_y = Variable::CubePosY;
-        let cube_count_x = Variable::CubeCountX;
-        let local_invocation_id_x = Variable::UnitPosX;
-        let local_invocation_id_y = Variable::UnitPosY;
-        let cube_dim_x = Variable::CubeDimX;
-        let cube_dim_y = Variable::CubeDimY;
+        let cube_pos_x = Variable::builtin(Builtin::CubePosX);
+        let cube_pos_y = Variable::builtin(Builtin::CubePosY);
+        let cube_count_x = Variable::builtin(Builtin::CubeCountX);
+        let local_invocation_id_x = Variable::builtin(Builtin::UnitPosX);
+        let local_invocation_id_y = Variable::builtin(Builtin::UnitPosY);
+        let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
+        let cube_dim_y = Variable::builtin(Builtin::CubeDimY);

         let stride_reduce_dim_input = scope.create_local(Elem::UInt);
         cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));

@@ -162,12 +159,8 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
             })
         );

-        let shared_memory = RD::initialize_shared(
-            scope,
-            self.shared_memory_size as u32,
-            local_id,
-            tensor.item(),
-        );
+        let shared_memory =
+            RD::initialize_shared(scope, self.shared_memory_size as u32, local_id, tensor.item);

         // Load to shared memory, unrolled
         cpa!(
@@ -16,7 +16,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.zero(shared_memory.item());
+        let neutral_element = scope.zero(shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }

@@ -27,8 +27,8 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value + value);
         cpa!(scope, shared_memory[write_position] = computed);

@@ -40,7 +40,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }

@@ -50,7 +50,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Self::Accumulator {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }

@@ -62,7 +62,7 @@ impl<E: JitElement> ReduceDimShared<E> for SumDim {
         write_position: Variable,
         _shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
@@ -73,6 +73,7 @@ mod cube_wgpu {
             WgpuDevice::Existing(id) => {
                 DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32)
             }
+            WgpuDevice::DefaultDevice => DeviceId::new(6, 0),
         }
     }
 }
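The new `DefaultDevice` arm gives the default wgpu device its own `DeviceId` namespace (type id 6, index 0), alongside the existing `Existing(id)` mapping under type id 5.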
@@ -8,9 +8,13 @@ use core::convert::Into;

 use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};

-use burn::{backend::NdArray, prelude::*, tensor::activation::softmax};
+use burn::{
+    backend::{wgpu::init_device, NdArray},
+    prelude::*,
+    tensor::activation::softmax,
+};

-use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
 use burn_candle::Candle;

 use serde::Serialize;

@@ -106,7 +110,7 @@ impl ImageClassifier {
         log::info!("Loading the model to the Wgpu backend");
         let start = Instant::now();
         let device = WgpuDevice::default();
-        init_async::<AutoGraphicsApi>(&device, Default::default()).await;
+        init_device::<AutoGraphicsApi>(&device, Default::default()).await;
         self.model = ModelType::WithWgpuBackend(Model::new(&device));
         let duration = start.elapsed();
         log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);
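The wasm image-classification example tracks a backend rename: device initialization moves from `burn::backend::wgpu::init_async` to `init_device`, now imported through the restructured `use burn::{...}` block; the call site is otherwise unchanged and still awaited.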