diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 90a8777fb..098868078 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -1,128 +1,50 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use cubecl::ir::{ - Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility, -}; -use cubecl::{ - cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion, - KernelIntegrator, KernelSettings, OutputInfo, -}; -use std::marker::PhantomData; +use cubecl::ir::KernelDefinition; +use cubecl::linalg::tensor::index_offset_with_layout; +use cubecl::prelude::*; +use cubecl::{calculate_cube_count_elemwise, CubeDim}; -#[derive(new)] -struct GatherEagerKernel { - dim: usize, - _runtime: PhantomData, - _elem: PhantomData, -} +#[cube(launch)] +fn gather_kernel( + input: &Tensor, + indices: &Tensor, + output: &mut Tensor, + dim: Comptime, + rank: Comptime, +) { + let dim_runtime = Comptime::runtime(dim); + let not_zeroth_dim = Comptime::map(dim, |d| d > 0); -struct GatherComputeShader { - tensor: Variable, - indices: Variable, - out: Variable, - dim: usize, -} + // The offset for the `dim` dimension is obtained by the indices tensor. + let index = indices[ABSOLUTE_POS]; + let stride = input.stride(dim_runtime); -impl GatherComputeShader { - pub fn expand(self, scope: &mut Scope) { - match self.tensor { - Variable::GlobalInputArray { .. } => (), - Variable::GlobalOutputArray { .. } => (), - _ => panic!("Tensor variable must be an global array."), - }; + let mut offset = index * stride; - let tensor = self.tensor; - let output = self.out; - - let stride = scope.create_local(Elem::UInt); - let offset = scope.create_local(Elem::UInt); - - // The offset of the `dim` dimension is obtained by the indices tensor. - cpa!(scope, offset = cast(self.indices)); - cpa!(scope, stride = stride(tensor, self.dim)); - cpa!(scope, offset = offset * stride); - - // We fetch the offset before the `dim` dimension. - if self.dim > 0 { - let offset_before = scope.create_local(Elem::UInt); - scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout { - tensors: vec![tensor], - indexes: vec![offset_before], - layout: Variable::AbsolutePos, // Will be updated. - position: Variable::AbsolutePos, - dim_start: 0u32.into(), - dim_end: self.dim.into(), - }); - cpa!(scope, offset += offset_before); - } - - let offset_after = scope.create_local(Elem::UInt); - scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout { - tensors: vec![tensor], - indexes: vec![offset_after], - layout: Variable::AbsolutePos, // Will be updated. - position: Variable::AbsolutePos, - dim_start: (self.dim + 1).into(), - dim_end: Variable::Rank, - }); - cpa!(scope, offset += offset_after); - - cpa!(scope, output = tensor[offset]); - } -} - -impl Kernel for GatherEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item_tensor = E::cube_elem().into(); - let item_indices: Item = Elem::Int(IntKind::I32).into(); - - let tensor = Variable::GlobalInputArray { - id: 0, - item: item_tensor, - }; - let indices = scope.read_array(1, item_indices, Variable::AbsolutePos); - - let output_array = Variable::GlobalOutputArray { - id: 0, - item: item_tensor, - }; - let output_local = scope.create_local(item_tensor); - - GatherComputeShader { - tensor, - indices, - out: output_local, - dim: self.dim, - } - .expand(&mut scope); - - scope.write_global(output_local, output_array, Variable::AbsolutePos); - - let tensor = InputInfo::Array { - item: item_tensor, - visibility: Visibility::Read, - }; - let indices = InputInfo::Array { - item: Elem::Int(IntKind::I32).into(), - visibility: Visibility::Read, - }; - let out = OutputInfo::Array { item: item_tensor }; - - let info = KernelExpansion { - inputs: vec![tensor, indices], - outputs: vec![out], - scope, - }; - - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) + // We fetch the offset before the `dim` dimension. + if Comptime::get(not_zeroth_dim) { + offset += index_offset_with_layout( + input, + output, + ABSOLUTE_POS, + UInt::new(0), + dim_runtime, + Comptime::new(true), + ); } - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info(self.dim) - } + offset += index_offset_with_layout( + input, + output, + ABSOLUTE_POS, + Comptime::runtime(Comptime::map(dim, |d| d + 1)), + Comptime::runtime(rank), + Comptime::new(true), + ); + + output[ABSOLUTE_POS] = input[offset]; } pub(crate) fn gather( @@ -131,20 +53,23 @@ pub(crate) fn gather, ) -> JitTensor { let shape_output = indices.shape.clone(); - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let kernel = GatherEagerKernel::::new(dim); - Execution::start(kernel, tensor.client) - .inputs(&[ - TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), - TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims), - ]) - .outputs(&[TensorHandleRef::new( - &output.handle, - &output.strides, - &output.shape.dims, - )]) - .execute(CubeCountSettings::Output { pos: 0 }); + let cube_dim = CubeDim::default(); + let cube_count = + calculate_cube_count_elemwise::(shape_output.num_elements(), cube_dim); + + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + + gather_kernel::launch::( + &tensor.client, + cube_count, + cube_dim, + TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorArg::new(&indices.handle, &indices.strides, &indices.shape.dims), + TensorArg::new(&output.handle, &output.strides, &output.shape.dims), + UInt::new(dim as u32), + UInt::new(D as u32), + ); output }