diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e969ee7d4..710206ee4 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -97,8 +97,8 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] # Backwards compatibility with previous serialized data format. record-backward-compat = [] -test-tch = ["tch"] # To use tch during testing, default uses ndarray. -test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. +test-tch = ["tch"] # To use tch during testing, default uses ndarray. +test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. [dependencies] diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index c7f4ff685..2de20d39a 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -1,128 +1,46 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use cubecl::ir::{ - Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility, -}; -use cubecl::{ - cpa, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, - KernelSettings, OutputInfo, -}; -use std::marker::PhantomData; +use cubecl::frontend::{Numeric, Tensor, UInt, ABSOLUTE_POS}; +use cubecl::linalg::tensor::index_offset_with_layout; +use cubecl::CubeDim; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; -#[derive(new)] -struct GatherEagerKernel { - dim: usize, - _runtime: PhantomData, - _elem: PhantomData, -} +#[cube(launch_unchecked)] +fn gather_kernel( + input: &Tensor, + indices: &Tensor, + output: &mut Tensor, + dim: &UInt, +) { + let index = indices[ABSOLUTE_POS]; -struct GatherComputeShader { - tensor: Variable, - indices: Variable, - out: Variable, - dim: usize, -} + let stride = input.stride(*dim); + let mut offset = UInt::cast_from(index); + offset *= stride; -impl GatherComputeShader { - pub fn expand(self, scope: &mut Scope) { - match self.tensor { - Variable::GlobalInputArray { .. } => (), - Variable::GlobalOutputArray { .. } => (), - _ => panic!("Tensor variable must be an global array."), - }; - - let tensor = self.tensor; - let output = self.out; - - let stride = scope.create_local(Elem::UInt); - let offset = scope.create_local(Elem::UInt); - - // The offset of the `dim` dimension is obtained by the indices tensor. - cpa!(scope, offset = cast(self.indices)); - cpa!(scope, stride = stride(tensor, self.dim)); - cpa!(scope, offset = offset * stride); - - // We fetch the offset before the `dim` dimension. - if self.dim > 0 { - let offset_before = scope.create_local(Elem::UInt); - scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout { - tensors: vec![tensor], - indexes: vec![offset_before], - layout: Variable::AbsolutePos, // Will be updated. - position: Variable::AbsolutePos, - dim_start: 0u32.into(), - dim_end: self.dim.into(), - }); - cpa!(scope, offset += offset_before); - } - - let offset_after = scope.create_local(Elem::UInt); - scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout { - tensors: vec![tensor], - indexes: vec![offset_after], - layout: Variable::AbsolutePos, // Will be updated. - position: Variable::AbsolutePos, - dim_start: (self.dim + 1).into(), - dim_end: Variable::Rank, - }); - cpa!(scope, offset += offset_after); - - cpa!(scope, output = tensor[offset]); - } -} - -impl Kernel for GatherEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item_tensor = E::cube_elem().into(); - let item_indices: Item = Elem::Int(IntKind::I32).into(); - - let tensor = Variable::GlobalInputArray { - id: 0, - item: item_tensor, - }; - let indices = scope.read_array(1, item_indices, Variable::AbsolutePos); - - let output_array = Variable::GlobalOutputArray { - id: 0, - item: item_tensor, - }; - let output_local = scope.create_local(item_tensor); - - GatherComputeShader { - tensor, - indices, - out: output_local, - dim: self.dim, - } - .expand(&mut scope); - - scope.write_global(output_local, output_array, Variable::AbsolutePos); - - let tensor = InputInfo::Array { - item: item_tensor, - visibility: Visibility::Read, - }; - let indices = InputInfo::Array { - item: Elem::Int(IntKind::I32).into(), - visibility: Visibility::Read, - }; - let out = OutputInfo::Array { item: item_tensor }; - - let info = KernelExpansion { - inputs: vec![tensor, indices], - outputs: vec![out], - scope, - }; - - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) + if *dim > 0 { + let offset_before = index_offset_with_layout( + input, + output, + ABSOLUTE_POS, + UInt::new(0), + *dim, + Comptime::new(false), + ); + offset += offset_before; } - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info(self.dim) - } + let offset_after = index_offset_with_layout( + input, + output, + ABSOLUTE_POS, + *dim + 1, + input.rank(), + Comptime::new(false), + ); + offset += offset_after; + output[ABSOLUTE_POS] = input[offset]; } pub(crate) fn gather( @@ -131,13 +49,21 @@ pub(crate) fn gather, ) -> JitTensor { let shape_output = indices.shape.clone(); + let total_elem = shape_output.num_elements(); let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let kernel = GatherEagerKernel::::new(dim); - - Execution::start(kernel, tensor.client.clone()) - .inputs(&[tensor.as_handle_ref(), indices.as_handle_ref()]) - .outputs(&[output.as_handle_ref()]) - .execute(CubeCountSettings::Output { pos: 0 }); + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); + unsafe { + gather_kernel::launch_unchecked::( + &tensor.client, + cube_count, + cube_dim, + tensor.as_tensor_arg(1), + indices.as_tensor_arg(1), + output.as_tensor_arg(1), + ScalarArg::new(dim as u32), + ) + } output } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 985641829..7d01a6b36 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -20,7 +20,15 @@ cuda-jit = ["burn/cuda-jit"] [dependencies] # Burn -burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false} +burn = { path = "../../crates/burn", features = [ + "train", + "ndarray", + "std", + "metrics", + "autotune", + "fusion", + "default", +], default-features = false } # Tokenizer tokenizers = { version = "0.20.0", default-features = false, features = [