diff --git a/crates/burn-wgpu/src/kernel/cast.rs b/crates/burn-wgpu/src/kernel/cast.rs index 51d7c4863..102882276 100644 --- a/crates/burn-wgpu/src/kernel/cast.rs +++ b/crates/burn-wgpu/src/kernel/cast.rs @@ -1,7 +1,18 @@ -use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}; +use super::{ + DynamicKernelSource, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT, +}; use crate::{ - compute::StaticKernel, element::JitElement, kernel::elemwise_workgroup, kernel_wgsl, - tensor::JitTensor, Runtime, + codegen::{ + dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility}, + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, + InputInfo, OutputInfo, WorkgroupLaunch, + }, + compute::StaticKernel, + element::JitElement, + kernel::elemwise_workgroup, + kernel_wgsl, + tensor::{self, JitTensor}, + Runtime, }; use std::{any::TypeId, marker::PhantomData}; @@ -76,31 +87,103 @@ pub fn cast( +pub fn bool_cast( tensor: JitTensor, -) -> JitTensor { +) -> JitTensor { + let kernel = BoolCastEagerKernel::new(); let num_elems = tensor.shape.num_elements(); - let kernel = StaticKernel::< - KernelSettings, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - let handle = tensor - .client - .empty(num_elems * core::mem::size_of::()); + let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); let output = JitTensor::new( tensor.client.clone(), tensor.device, tensor.shape.clone(), - handle, + buffer, ); - tensor - .client - .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); + execute_dynamic::, u32>( + &[EagerHandle::new( + &tensor.handle, + &tensor.strides, + &tensor.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + None, + kernel, + WorkgroupLaunch::Output { pos: 0 }, + tensor.client, + ); output } +pub(crate) struct BoolCastShader { + tensor: Variable, + output: Variable, +} + +#[derive(new)] +pub(crate) struct BoolCastEagerKernel { + _runtime: PhantomData, + _elem_out: PhantomData, +} + +impl DynamicKernelSource for BoolCastEagerKernel { + fn source(&self) -> crate::kernel::SourceTemplate { + let mut scope = Scope::root(); + let item_input = Item::Scalar(Elem::Bool); + let item_output = EO::gpu_elem().into(); + + let tensor = Variable::GlobalInputArray(0, item_input); + let output = Variable::GlobalOutputArray(0, item_output); + + BoolCastShader { tensor, output }.expand(&mut scope); + + scope.write_global_custom(output); + + let tensor = InputInfo::Array { + item: item_input, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item: item_output }; + + let info = CompilationInfo { + inputs: vec![tensor], + outputs: vec![out], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } +} + +impl BoolCastShader { + pub(crate) fn expand(self, scope: &mut Scope) { + let tensor = self.tensor; + let id = Variable::Id; + let output = self.output; + + let represents_true = scope.create_local(Elem::Bool); + gpu!(scope, represents_true = tensor[id]); + gpu!(scope, if(represents_true).then(|scope|{ + gpu!(scope, output[id] = 1); + }).else(|scope|{ + gpu!(scope, output[id] = 0); + })); + } +} + #[cfg(test)] mod tests { use super::*;