Gather CPA to CubeCL (#2165)

* working version

* cleanup

* wip

* working version of gather

* testsetsetser

* Revert "testsetsetser"

This reverts commit f37b329697.

* Reapply "testsetsetser"

This reverts commit f8ada0044e.

* Revert "testsetsetser"

This reverts commit f37b329697.

* Revert "working version of gather"

This reverts commit f5047c27c8.

* Revert "wip"

This reverts commit abaaa2dd55.

* Revert "Merge branch 'main' into index-cpa-to-cubecl"

This reverts commit 05bed8ea74, reversing
changes made to 94954fc32c.

* Revert "cleanup"

This reverts commit 94954fc32c.

* Revert "working version"

This reverts commit a06933f029.

* gather test

* fix

* fix clippy

* cleanup
This commit is contained in:
mepatrick73 2024-08-22 13:44:26 -04:00 committed by GitHub
parent 73d4b11aa2
commit e1fed792f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 126 deletions

View File

@ -1,128 +1,46 @@
use crate::{ use crate::{
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
}; };
use cubecl::ir::{ use cubecl::frontend::{Numeric, Tensor, UInt, ABSOLUTE_POS};
Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility, use cubecl::linalg::tensor::index_offset_with_layout;
}; use cubecl::CubeDim;
use cubecl::{ use cubecl::{calculate_cube_count_elemwise, prelude::*};
cpa, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator,
KernelSettings, OutputInfo,
};
use std::marker::PhantomData;
#[derive(new)] #[cube(launch_unchecked)]
struct GatherEagerKernel<R: JitRuntime, E: JitElement> { fn gather_kernel<T: Numeric, I: Numeric>(
dim: usize, input: &Tensor<T>,
_runtime: PhantomData<R>, indices: &Tensor<I>,
_elem: PhantomData<E>, output: &mut Tensor<T>,
dim: &UInt,
) {
let index = indices[ABSOLUTE_POS];
let stride = input.stride(*dim);
let mut offset = UInt::cast_from(index);
offset *= stride;
if *dim > 0 {
let offset_before = index_offset_with_layout(
input,
output,
ABSOLUTE_POS,
UInt::new(0),
*dim,
Comptime::new(false),
);
offset += offset_before;
} }
struct GatherComputeShader { let offset_after = index_offset_with_layout(
tensor: Variable, input,
indices: Variable, output,
out: Variable, ABSOLUTE_POS,
dim: usize, *dim + 1,
} input.rank(),
Comptime::new(false),
impl GatherComputeShader { );
pub fn expand(self, scope: &mut Scope) { offset += offset_after;
match self.tensor { output[ABSOLUTE_POS] = input[offset];
Variable::GlobalInputArray { .. } => (),
Variable::GlobalOutputArray { .. } => (),
_ => panic!("Tensor variable must be an global array."),
};
let tensor = self.tensor;
let output = self.out;
let stride = scope.create_local(Elem::UInt);
let offset = scope.create_local(Elem::UInt);
// The offset of the `dim` dimension is obtained by the indices tensor.
cpa!(scope, offset = cast(self.indices));
cpa!(scope, stride = stride(tensor, self.dim));
cpa!(scope, offset = offset * stride);
// We fetch the offset before the `dim` dimension.
if self.dim > 0 {
let offset_before = scope.create_local(Elem::UInt);
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
tensors: vec![tensor],
indexes: vec![offset_before],
layout: Variable::AbsolutePos, // Will be updated.
position: Variable::AbsolutePos,
dim_start: 0u32.into(),
dim_end: self.dim.into(),
});
cpa!(scope, offset += offset_before);
}
let offset_after = scope.create_local(Elem::UInt);
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
tensors: vec![tensor],
indexes: vec![offset_after],
layout: Variable::AbsolutePos, // Will be updated.
position: Variable::AbsolutePos,
dim_start: (self.dim + 1).into(),
dim_end: Variable::Rank,
});
cpa!(scope, offset += offset_after);
cpa!(scope, output = tensor[offset]);
}
}
impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
fn define(&self) -> KernelDefinition {
let mut scope = Scope::root();
let item_tensor = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_tensor,
};
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
let output_array = Variable::GlobalOutputArray {
id: 0,
item: item_tensor,
};
let output_local = scope.create_local(item_tensor);
GatherComputeShader {
tensor,
indices,
out: output_local,
dim: self.dim,
}
.expand(&mut scope);
scope.write_global(output_local, output_array, Variable::AbsolutePos);
let tensor = InputInfo::Array {
item: item_tensor,
visibility: Visibility::Read,
};
let indices = InputInfo::Array {
item: Elem::Int(IntKind::I32).into(),
visibility: Visibility::Read,
};
let out = OutputInfo::Array { item: item_tensor };
let info = KernelExpansion {
inputs: vec![tensor, indices],
outputs: vec![out],
scope,
};
let settings = KernelSettings::default();
KernelIntegrator::new(info).integrate(settings)
}
fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>().info(self.dim)
}
} }
pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>( pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
@ -131,13 +49,21 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
indices: JitTensor<R, I, D>, indices: JitTensor<R, I, D>,
) -> JitTensor<R, E, D> { ) -> JitTensor<R, E, D> {
let shape_output = indices.shape.clone(); let shape_output = indices.shape.clone();
let total_elem = shape_output.num_elements();
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let kernel = GatherEagerKernel::<R, E>::new(dim);
Execution::start(kernel, tensor.client.clone())
.inputs(&[tensor.as_handle_ref(), indices.as_handle_ref()])
.outputs(&[output.as_handle_ref()])
.execute(CubeCountSettings::Output { pos: 0 });
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
unsafe {
gather_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
&tensor.client,
cube_count,
cube_dim,
tensor.as_tensor_arg(1),
indices.as_tensor_arg(1),
output.as_tensor_arg(1),
ScalarArg::new(dim as u32),
)
}
output output
} }

View File

@ -20,7 +20,15 @@ cuda-jit = ["burn/cuda-jit"]
[dependencies] [dependencies]
# Burn # Burn
burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false} burn = { path = "../../crates/burn", features = [
"train",
"ndarray",
"std",
"metrics",
"autotune",
"fusion",
"default",
], default-features = false }
# Tokenizer # Tokenizer
tokenizers = { version = "0.20.0", default-features = false, features = [ tokenizers = { version = "0.20.0", default-features = false, features = [