louisfd 2024-07-29 14:14:55 -04:00
parent def45f0b58
commit 4bc1a77819
1 changed file with 54 additions and 129 deletions


@@ -1,128 +1,50 @@
 use crate::{
     element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
 };
-use cubecl::ir::{
-    Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility,
-};
-use cubecl::{
-    cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion,
-    KernelIntegrator, KernelSettings, OutputInfo,
-};
-use std::marker::PhantomData;
+use cubecl::ir::KernelDefinition;
+use cubecl::linalg::tensor::index_offset_with_layout;
+use cubecl::prelude::*;
+use cubecl::{calculate_cube_count_elemwise, CubeDim};
 
-#[derive(new)]
-struct GatherEagerKernel<R: JitRuntime, E: JitElement> {
-    dim: usize,
-    _runtime: PhantomData<R>,
-    _elem: PhantomData<E>,
-}
-
-struct GatherComputeShader {
-    tensor: Variable,
-    indices: Variable,
-    out: Variable,
-    dim: usize,
-}
-
-impl GatherComputeShader {
-    pub fn expand(self, scope: &mut Scope) {
-        match self.tensor {
-            Variable::GlobalInputArray { .. } => (),
-            Variable::GlobalOutputArray { .. } => (),
-            _ => panic!("Tensor variable must be an global array."),
-        };
-
-        let tensor = self.tensor;
-        let output = self.out;
-
-        let stride = scope.create_local(Elem::UInt);
-        let offset = scope.create_local(Elem::UInt);
-
-        // The offset of the `dim` dimension is obtained by the indices tensor.
-        cpa!(scope, offset = cast(self.indices));
-        cpa!(scope, stride = stride(tensor, self.dim));
-        cpa!(scope, offset = offset * stride);
-
-        // We fetch the offset before the `dim` dimension.
-        if self.dim > 0 {
-            let offset_before = scope.create_local(Elem::UInt);
-            scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
-                tensors: vec![tensor],
-                indexes: vec![offset_before],
-                layout: Variable::AbsolutePos, // Will be updated.
-                position: Variable::AbsolutePos,
-                dim_start: 0u32.into(),
-                dim_end: self.dim.into(),
-            });
-            cpa!(scope, offset += offset_before);
-        }
-
-        let offset_after = scope.create_local(Elem::UInt);
-        scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
-            tensors: vec![tensor],
-            indexes: vec![offset_after],
-            layout: Variable::AbsolutePos, // Will be updated.
-            position: Variable::AbsolutePos,
-            dim_start: (self.dim + 1).into(),
-            dim_end: Variable::Rank,
-        });
-        cpa!(scope, offset += offset_after);
-
-        cpa!(scope, output = tensor[offset]);
-    }
-}
-
-impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_tensor = E::cube_elem().into();
-        let item_indices: Item = Elem::Int(IntKind::I32).into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_tensor,
-        };
-        let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
-
-        let output_array = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_tensor,
-        };
-        let output_local = scope.create_local(item_tensor);
-
-        GatherComputeShader {
-            tensor,
-            indices,
-            out: output_local,
-            dim: self.dim,
-        }
-        .expand(&mut scope);
-
-        scope.write_global(output_local, output_array, Variable::AbsolutePos);
-
-        let tensor = InputInfo::Array {
-            item: item_tensor,
-            visibility: Visibility::Read,
-        };
-        let indices = InputInfo::Array {
-            item: Elem::Int(IntKind::I32).into(),
-            visibility: Visibility::Read,
-        };
-        let out = OutputInfo::Array { item: item_tensor };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor, indices],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>().info(self.dim)
-    }
-}
+#[cube(launch)]
+fn gather_kernel<T: Numeric>(
+    input: &Tensor<T>,
+    indices: &Tensor<UInt>,
+    output: &mut Tensor<T>,
+    dim: Comptime<UInt>,
+    rank: Comptime<UInt>,
+) {
+    let dim_runtime = Comptime::runtime(dim);
+    let not_zeroth_dim = Comptime::map(dim, |d| d > 0);
+
+    // The offset for the `dim` dimension is obtained by the indices tensor.
+    let index = indices[ABSOLUTE_POS];
+    let stride = input.stride(dim_runtime);
+
+    let mut offset = index * stride;
+
+    // We fetch the offset before the `dim` dimension.
+    if Comptime::get(not_zeroth_dim) {
+        offset += index_offset_with_layout(
+            input,
+            output,
+            ABSOLUTE_POS,
+            UInt::new(0),
+            dim_runtime,
+            Comptime::new(true),
+        );
+    }
+
+    offset += index_offset_with_layout(
+        input,
+        output,
+        ABSOLUTE_POS,
+        Comptime::runtime(Comptime::map(dim, |d| d + 1)),
+        Comptime::runtime(rank),
+        Comptime::new(true),
+    );
+
+    output[ABSOLUTE_POS] = input[offset];
+}
 
 pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
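The new kernel composes each input offset from two sources: the coordinate along `dim`, read from the indices tensor, and every other coordinate, recovered from the output position via `index_offset_with_layout` (once for dims 0..dim, once for dims dim+1..rank). A minimal CPU sketch of the same arithmetic, illustrative only and not part of the commit; `gather_ref` is a hypothetical name, and contiguous row-major layouts are assumed:

// Illustrative CPU reference for gather along `dim`; not part of the commit.
// Assumes a contiguous row-major output with the same shape as `indices`.
fn gather_ref<T: Copy>(
    dim: usize,
    input: &[T],
    input_strides: &[usize],
    indices: &[usize],
    out_shape: &[usize],
) -> Vec<T> {
    let rank = out_shape.len();
    // Row-major strides of the output.
    let mut out_strides = vec![1usize; rank];
    for d in (0..rank - 1).rev() {
        out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
    }
    (0..out_shape.iter().product::<usize>())
        .map(|pos| {
            // The coordinate along `dim` comes from the indices tensor,
            // mirroring `index * stride` in the kernel.
            let mut offset = indices[pos] * input_strides[dim];
            // Every other coordinate is recovered from the output position,
            // mirroring the two `index_offset_with_layout` calls.
            for d in (0..rank).filter(|&d| d != dim) {
                offset += pos / out_strides[d] % out_shape[d] * input_strides[d];
            }
            input[offset]
        })
        .collect()
}

fn main() {
    // 2 x 3 input, row-major strides [3, 1], gathering along dim 1:
    // each row picks its own columns.
    let input = [10, 11, 12, 20, 21, 22];
    let indices = [2, 0, 1, 0, 0, 2];
    let out = gather_ref(1, &input, &[3, 1], &indices, &[2, 3]);
    assert_eq!(out, vec![12, 10, 11, 20, 20, 22]);
}

When dim = 0 there are no dimensions before `dim`, so the first offset contribution is skipped entirely; that is what the `not_zeroth_dim` comptime branch resolves at expansion time in the kernel above.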
@@ -131,20 +53,23 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
     indices: JitTensor<R, I, D>,
 ) -> JitTensor<R, E, D> {
     let shape_output = indices.shape.clone();
-    let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
-    let kernel = GatherEagerKernel::<R, E>::new(dim);
-
-    Execution::start(kernel, tensor.client)
-        .inputs(&[
-            TensorHandleRef::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
-            TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims),
-        ])
-        .outputs(&[TensorHandleRef::new(
-            &output.handle,
-            &output.strides,
-            &output.shape.dims,
-        )])
-        .execute(CubeCountSettings::Output { pos: 0 });
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise::<R::Server>(shape_output.num_elements(), cube_dim);
+
+    let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
+
+    gather_kernel::launch::<E::Primitive, R>(
+        &tensor.client,
+        cube_count,
+        cube_dim,
+        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
+        TensorArg::new(&indices.handle, &indices.strides, &indices.shape.dims),
+        TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
+        UInt::new(dim as u32),
+        UInt::new(D as u32),
+    );
 
     output
 }
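On the launch side, the grid is sized element-wise: one unit per output element, with the output shape taken from `indices`. A sketch of the count arithmetic, assuming `CubeDim::default()` resolves to 16 x 16 x 1 (256 units per cube) — an assumption about cubecl's default, not a quote of its implementation:

// Sketch of the element-wise launch arithmetic; the 256-units-per-cube
// figure below is an assumption (CubeDim::default() as 16 x 16 x 1).
fn cube_count_elemwise(num_elems: usize, units_per_cube: usize) -> usize {
    // One unit per output element, rounded up to whole cubes.
    num_elems.div_ceil(units_per_cube)
}

fn main() {
    // A 2 x 3 `indices` tensor has 6 elements: a single cube suffices.
    assert_eq!(cube_count_elemwise(6, 256), 1);
    // 100_000 output elements need ceil(100_000 / 256) = 391 cubes.
    assert_eq!(cube_count_elemwise(100_000, 256), 391);
}

Passing `UInt::new(dim as u32)` and `UInt::new(D as u32)` as comptime arguments means the kernel is specialized per (dim, rank) pair, so the `dim > 0` branch costs nothing at runtime.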