From a06933f0297cccd5cd1f89656749c15f50867e65 Mon Sep 17 00:00:00 2001 From: mepatrick73 Date: Mon, 12 Aug 2024 17:13:53 -0400 Subject: [PATCH] working version --- crates/burn-jit/src/kernel/index/select.rs | 209 +++++++++------------ crates/burn-jit/src/tests/select.rs | 3 + 2 files changed, 91 insertions(+), 121 deletions(-) diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 979adafe3..99fbf84b1 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -1,119 +1,79 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; +use cubecl::prelude::*; use cubecl::{ - cpa, - frontend::TensorHandleRef, - ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, - CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, - OutputInfo, + calculate_cube_count_elemwise, frontend::TensorHandleRef, CubeCountSettings, CubeDim, Execution, }; -use std::marker::PhantomData; -#[derive(new)] -struct SelectEagerKernel { - dim: usize, - _runtime: PhantomData, - _elem: PhantomData, -} +#[cube(launch_unchecked)] +fn select_kernel( + input: &Tensor, + indices: &Tensor, + output: &mut Tensor, + dim: &UInt, +) { + let id = ABSOLUTE_POS; + let mut offset_input = UInt::new(0); + let rank = output.rank(); + for i in range(UInt::new(0), rank, Comptime::new(false)) { + let stride_input = input.stride(i); + let stride_output = output.stride(i); + let shape_output = output.shape(i); + let mut offset_local = id / stride_output; + offset_local = offset_local % shape_output; -pub struct SelectComputeShader { - input: Variable, - indices: Variable, - output: Variable, - dim: usize, -} - -impl SelectComputeShader { - pub fn expand(self, scope: &mut Scope) { - let input = self.input; - let indices = self.indices; - let output = self.output; - let id = Variable::AbsolutePos; - let offset_input = scope.zero(Elem::UInt); - - cpa!( - scope, - range(0u32, Variable::Rank).for_each(|i, scope| { - let stride_input = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_output = scope.create_local(Elem::UInt); - - cpa!(scope, stride_input = stride(input, i)); - cpa!(scope, stride_output = stride(output, i)); - cpa!(scope, shape_output = shape(output, i)); - - let offset_local = scope.create_local(Elem::UInt); - cpa!(scope, offset_local = id / stride_output); - cpa!(scope, offset_local = offset_local % shape_output); - - let dim_index = scope.create_local(Elem::Bool); - cpa!(scope, dim_index = i == self.dim); - - cpa!(scope, if(dim_index).then(|scope| { - cpa!(scope, offset_local = indices[offset_local]); - cpa!(scope, offset_local = offset_local * stride_input); - }).else(|scope| { - cpa!(scope, offset_local = offset_local * stride_input); - })); - - cpa!(scope, offset_input += offset_local); - }) - ); - - let value = scope.create_local(input.item()); - cpa!(scope, value = input[offset_input]); - cpa!(scope, output[id] = value); - } -} - -impl Kernel for SelectEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item = E::cube_elem().into(); - let item_indices: Item = Elem::Int(IntKind::I32).into(); - - let input = Variable::GlobalInputArray { id: 0, item }; - let indices = Variable::GlobalInputArray { - id: 1, - item: item_indices, - }; - let output = Variable::GlobalOutputArray { id: 0, item }; - - scope.write_global_custom(output); - - SelectComputeShader { - input, - indices, - output, - dim: self.dim, + if i == *dim { + offset_local = UInt::cast_from(indices[offset_local]); + offset_local *= stride_input; + } else { + offset_local *= stride_input; } - .expand(&mut scope); - - let input = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - let indices = InputInfo::Array { - item: item_indices, - visibility: Visibility::Read, - }; - let output = OutputInfo::Array { item }; - - let info = KernelExpansion { - inputs: vec![input, indices], - outputs: vec![output], - scope, - }; - - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) - } - - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info(self.dim) + offset_input += offset_local; } + let value = input[offset_input]; + output[id] = value; } +//pub fn expand(self, scope: &mut Scope) { +// let input = self.input; +// let indices = self.indices; +// let output = self.output; +// let id = Variable::AbsolutePos; +// let offset_input = scope.zero(Elem::UInt); +// +// cpa!( +// scope, +// range(0u32, Variable::Rank).for_each(|i, scope| { +// let stride_input = scope.create_local(Elem::UInt); +// let stride_output = scope.create_local(Elem::UInt); +// let shape_output = scope.create_local(Elem::UInt); +// +// cpa!(scope, stride_input = stride(input, i)); +// cpa!(scope, stride_output = stride(output, i)); +// cpa!(scope, shape_output = shape(output, i)); +// +// let offset_local = scope.create_local(Elem::UInt); +// cpa!(scope, offset_local = id / stride_output); +// cpa!(scope, offset_local = offset_local % shape_output); +// +// let dim_index = scope.create_local(Elem::Bool); +// cpa!(scope, dim_index = i == self.dim); +// +// cpa!(scope, if(dim_index).then(|scope| { +// cpa!(scope, offset_local = indices[offset_local]); +// cpa!(scope, offset_local = offset_local * stride_input); +// }).else(|scope| { +// cpa!(scope, offset_local = offset_local * stride_input); +// })); +// +// cpa!(scope, offset_input += offset_local); +// }) +// ); +// +// let value = scope.create_local(input.item()); +// cpa!(scope, value = input[offset_input]); +// cpa!(scope, output[id] = value); +//} pub(crate) fn select( tensor: JitTensor, @@ -122,26 +82,33 @@ pub(crate) fn select JitTensor { let mut shape_output = tensor.shape.clone(); shape_output.dims[dim] = indices.shape.dims[0]; - - let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let kernel = SelectEagerKernel::::new(dim); - let num_elems = indices.shape.dims[0]; + let mut total_elem = 1; + for dim_size in shape_output.dims.iter() { + total_elem *= dim_size + } let mut shapes = [1; D]; let mut strides = [num_elems; D]; shapes[D - 1] = num_elems; strides[D - 1] = 1; - Execution::start(kernel, tensor.client.clone()) - .inputs(&[ - tensor.as_handle_ref(), - // This is a current hacks because the info buffer that contains the strides and shapes is - // hardcoded to only contains information about tensors of the same rank. However, since - // we don't rely on the shape and stride of the indices tensors, it doesn't matter - // which value we put, it just needs to be of the same rank. - unsafe { TensorHandleRef::from_raw_parts(&indices.handle, &strides, &shapes) }, - ]) - .outputs(&[output.as_handle_ref()]) - .execute(CubeCountSettings::Output { pos: 0 }); + + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); + + let cube_dim = CubeDim::default(); + + let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); + + unsafe { + select_kernel::launch_unchecked::( + &tensor.client, + cube_count, + cube_dim, + tensor.as_tensor_arg(1), + TensorArg::from_raw_parts(&indices.handle, &strides, &shapes, 1), + output.as_tensor_arg(1), + ScalarArg::new(dim as u32), + ) + }; output } diff --git a/crates/burn-jit/src/tests/select.rs b/crates/burn-jit/src/tests/select.rs index 6ede6e89b..7b6b72e9d 100644 --- a/crates/burn-jit/src/tests/select.rs +++ b/crates/burn-jit/src/tests/select.rs @@ -16,6 +16,9 @@ mod tests { let actual = tensor.select(1, indices); let expected = tensor_ref.select(1, indices_ref); + println!("{:?}", actual); + println!("{:?}", expected); + expected .into_data() .assert_approx_eq(&actual.into_data(), 3);