This commit is contained in:
louisfd 2024-07-29 14:14:55 -04:00
parent def45f0b58
commit 4bc1a77819
1 changed file with 54 additions and 129 deletions

View File

@ -1,128 +1,50 @@
use crate::{
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
};
use cubecl::ir::{
Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility,
};
use cubecl::{
cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion,
KernelIntegrator, KernelSettings, OutputInfo,
};
use std::marker::PhantomData;
use cubecl::ir::KernelDefinition;
use cubecl::linalg::tensor::index_offset_with_layout;
use cubecl::prelude::*;
use cubecl::{calculate_cube_count_elemwise, CubeDim};
// NOTE(review): part of the implementation removed by this commit. It bundled the
// gather dimension with zero-sized markers for the runtime (`R`) and element
// type (`E`); the dim also feeds the kernel id (see `fn id` below in the diff).
#[derive(new)]
struct GatherEagerKernel<R: JitRuntime, E: JitElement> {
// Dimension along which the gather is performed.
dim: usize,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
// Added in this commit: CubeCL `#[cube(launch)]` gather kernel. Each invocation
// reads one index from `indices` at `ABSOLUTE_POS` and (further down in the diff)
// copies the matching element of `input` into `output[ABSOLUTE_POS]`.
#[cube(launch)]
fn gather_kernel<T: Numeric>(
input: &Tensor<T>,
indices: &Tensor<UInt>,
output: &mut Tensor<T>,
// Gather dimension, fixed at compile time.
dim: Comptime<UInt>,
// Tensor rank, fixed at compile time.
rank: Comptime<UInt>,
) {
// Runtime copy of the comptime `dim`, needed for the stride lookup below.
let dim_runtime = Comptime::runtime(dim);
// Comptime flag: true when gathering along a non-leading dimension.
let not_zeroth_dim = Comptime::map(dim, |d| d > 0);
// NOTE(review): the struct below is *removed* code from the old implementation;
// the diff rendering interleaves it here. It is not part of `gather_kernel`.
struct GatherComputeShader {
tensor: Variable,
indices: Variable,
out: Variable,
dim: usize,
}
// The offset for the `dim` dimension is obtained by the indices tensor.
let index = indices[ABSOLUTE_POS];
let stride = input.stride(dim_runtime);
// NOTE(review): removed in this commit — the old hand-built IR expansion using
// `cpa!`. The diff view interleaves lines of the *new* `gather_kernel` body in
// between; those are flagged with "NEW" comments below.
impl GatherComputeShader {
pub fn expand(self, scope: &mut Scope) {
// The gathered tensor must be a global input or output array.
match self.tensor {
Variable::GlobalInputArray { .. } => (),
Variable::GlobalOutputArray { .. } => (),
_ => panic!("Tensor variable must be an global array."),
};
let tensor = self.tensor;
let output = self.out;
let stride = scope.create_local(Elem::UInt);
let offset = scope.create_local(Elem::UInt);
// The offset of the `dim` dimension is obtained by the indices tensor.
cpa!(scope, offset = cast(self.indices));
cpa!(scope, stride = stride(tensor, self.dim));
cpa!(scope, offset = offset * stride);
// NEW (gather_kernel): base offset = gathered index times the dim stride.
let mut offset = index * stride;
// We fetch the offset before the `dim` dimension.
// OLD: accumulate the linear offset contributed by dimensions before `dim`.
if self.dim > 0 {
let offset_before = scope.create_local(Elem::UInt);
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
tensors: vec![tensor],
indexes: vec![offset_before],
layout: Variable::AbsolutePos, // Will be updated.
position: Variable::AbsolutePos,
dim_start: 0u32.into(),
dim_end: self.dim.into(),
});
cpa!(scope, offset += offset_before);
// NEW (gather_kernel): same "dims before `dim`" contribution, via the cubecl
// linalg helper instead of hand-built IR.
if Comptime::get(not_zeroth_dim) {
offset += index_offset_with_layout(
input,
output,
ABSOLUTE_POS,
UInt::new(0),
dim_runtime,
Comptime::new(true),
);
}
// OLD: accumulate the offset contributed by dimensions after `dim`.
let offset_after = scope.create_local(Elem::UInt);
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
tensors: vec![tensor],
indexes: vec![offset_after],
layout: Variable::AbsolutePos, // Will be updated.
position: Variable::AbsolutePos,
dim_start: (self.dim + 1).into(),
dim_end: Variable::Rank,
});
cpa!(scope, offset += offset_after);
// NEW (gather_kernel): contribution of the dimensions after `dim` (dim + 1 .. rank).
offset += index_offset_with_layout(
input,
output,
ABSOLUTE_POS,
Comptime::runtime(Comptime::map(dim, |d| d + 1)),
Comptime::runtime(rank),
Comptime::new(true),
);
// OLD: write the gathered element into the output local.
cpa!(scope, output = tensor[offset]);
}
}
// NOTE(review): removed in this commit — manual kernel definition/integration for
// the eager gather kernel. The final store statement near the end of this span is
// actually the tail of the *new* `gather_kernel` (the diff view merges it in).
impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
fn define(&self) -> KernelDefinition {
let mut scope = Scope::root();
let item_tensor = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
// Input 0: the tensor being gathered from.
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_tensor,
};
// Input 1: the index tensor, read at the absolute position.
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
// Output 0: same item type as the input tensor.
let output_array = Variable::GlobalOutputArray {
id: 0,
item: item_tensor,
};
let output_local = scope.create_local(item_tensor);
// Expand the gather logic into the scope, then write the local result out.
GatherComputeShader {
tensor,
indices,
out: output_local,
dim: self.dim,
}
.expand(&mut scope);
scope.write_global(output_local, output_array, Variable::AbsolutePos);
let tensor = InputInfo::Array {
item: item_tensor,
visibility: Visibility::Read,
};
let indices = InputInfo::Array {
item: Elem::Int(IntKind::I32).into(),
visibility: Visibility::Read,
};
let out = OutputInfo::Array { item: item_tensor };
let info = KernelExpansion {
inputs: vec![tensor, indices],
outputs: vec![out],
scope,
};
let settings = KernelSettings::default();
KernelIntegrator::new(info).integrate(settings)
}
// The kernel id includes the gather dim, so each dim compiles to its own kernel.
fn id(&self) -> cubecl::KernelId {
cubecl::KernelId::new::<Self>().info(self.dim)
}
// NEW (gather_kernel): final store and closing brace of the added kernel.
output[ABSOLUTE_POS] = input[offset];
}
// Launch entry point for gather. The diff view shows both the removed
// `Execution`-based launch and the added direct `gather_kernel::launch` call;
// both allocate `output` from the indices shape, hence the duplicated
// `empty_device` line below (diff artifact, not a double allocation).
pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
// (line below is the diff hunk header, not code)
@ -131,20 +53,23 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
indices: JitTensor<R, I, D>,
) -> JitTensor<R, E, D> {
// The output takes the shape of the indices tensor.
let shape_output = indices.shape.clone();
// OLD (removed): eager-kernel launch via the `Execution` builder.
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let kernel = GatherEagerKernel::<R, E>::new(dim);
Execution::start(kernel, tensor.client)
.inputs(&[
TensorHandleRef::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims),
])
.outputs(&[TensorHandleRef::new(
&output.handle,
&output.strides,
&output.shape.dims,
)])
.execute(CubeCountSettings::Output { pos: 0 });
// NEW (added): elementwise cube count sized from the output element count,
// then a direct launch of `gather_kernel`.
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise::<R::Server>(shape_output.num_elements(), cube_dim);
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
gather_kernel::launch::<E::Primitive, R>(
&tensor.client,
cube_count,
cube_dim,
TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
TensorArg::new(&indices.handle, &indices.strides, &indices.shape.dims),
TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
// Comptime arguments: gather dim and tensor rank.
UInt::new(dim as u32),
UInt::new(D as u32),
);
output
}