mirror of https://github.com/tracel-ai/burn.git
gather
This commit is contained in:
parent
def45f0b58
commit
4bc1a77819
|
@ -1,128 +1,50 @@
|
|||
use crate::{
|
||||
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
|
||||
};
|
||||
use cubecl::ir::{
|
||||
Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility,
|
||||
};
|
||||
use cubecl::{
|
||||
cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion,
|
||||
KernelIntegrator, KernelSettings, OutputInfo,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
use cubecl::ir::KernelDefinition;
|
||||
use cubecl::linalg::tensor::index_offset_with_layout;
|
||||
use cubecl::prelude::*;
|
||||
use cubecl::{calculate_cube_count_elemwise, CubeDim};
|
||||
|
||||
#[derive(new)]
|
||||
struct GatherEagerKernel<R: JitRuntime, E: JitElement> {
|
||||
dim: usize,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
#[cube(launch)]
|
||||
fn gather_kernel<T: Numeric>(
|
||||
input: &Tensor<T>,
|
||||
indices: &Tensor<UInt>,
|
||||
output: &mut Tensor<T>,
|
||||
dim: Comptime<UInt>,
|
||||
rank: Comptime<UInt>,
|
||||
) {
|
||||
let dim_runtime = Comptime::runtime(dim);
|
||||
let not_zeroth_dim = Comptime::map(dim, |d| d > 0);
|
||||
|
||||
struct GatherComputeShader {
|
||||
tensor: Variable,
|
||||
indices: Variable,
|
||||
out: Variable,
|
||||
dim: usize,
|
||||
}
|
||||
// The offset for the `dim` dimension is obtained by the indices tensor.
|
||||
let index = indices[ABSOLUTE_POS];
|
||||
let stride = input.stride(dim_runtime);
|
||||
|
||||
impl GatherComputeShader {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
match self.tensor {
|
||||
Variable::GlobalInputArray { .. } => (),
|
||||
Variable::GlobalOutputArray { .. } => (),
|
||||
_ => panic!("Tensor variable must be an global array."),
|
||||
};
|
||||
|
||||
let tensor = self.tensor;
|
||||
let output = self.out;
|
||||
|
||||
let stride = scope.create_local(Elem::UInt);
|
||||
let offset = scope.create_local(Elem::UInt);
|
||||
|
||||
// The offset of the `dim` dimension is obtained by the indices tensor.
|
||||
cpa!(scope, offset = cast(self.indices));
|
||||
cpa!(scope, stride = stride(tensor, self.dim));
|
||||
cpa!(scope, offset = offset * stride);
|
||||
let mut offset = index * stride;
|
||||
|
||||
// We fetch the offset before the `dim` dimension.
|
||||
if self.dim > 0 {
|
||||
let offset_before = scope.create_local(Elem::UInt);
|
||||
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
|
||||
tensors: vec![tensor],
|
||||
indexes: vec![offset_before],
|
||||
layout: Variable::AbsolutePos, // Will be updated.
|
||||
position: Variable::AbsolutePos,
|
||||
dim_start: 0u32.into(),
|
||||
dim_end: self.dim.into(),
|
||||
});
|
||||
cpa!(scope, offset += offset_before);
|
||||
if Comptime::get(not_zeroth_dim) {
|
||||
offset += index_offset_with_layout(
|
||||
input,
|
||||
output,
|
||||
ABSOLUTE_POS,
|
||||
UInt::new(0),
|
||||
dim_runtime,
|
||||
Comptime::new(true),
|
||||
);
|
||||
}
|
||||
|
||||
let offset_after = scope.create_local(Elem::UInt);
|
||||
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
|
||||
tensors: vec![tensor],
|
||||
indexes: vec![offset_after],
|
||||
layout: Variable::AbsolutePos, // Will be updated.
|
||||
position: Variable::AbsolutePos,
|
||||
dim_start: (self.dim + 1).into(),
|
||||
dim_end: Variable::Rank,
|
||||
});
|
||||
cpa!(scope, offset += offset_after);
|
||||
offset += index_offset_with_layout(
|
||||
input,
|
||||
output,
|
||||
ABSOLUTE_POS,
|
||||
Comptime::runtime(Comptime::map(dim, |d| d + 1)),
|
||||
Comptime::runtime(rank),
|
||||
Comptime::new(true),
|
||||
);
|
||||
|
||||
cpa!(scope, output = tensor[offset]);
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
let mut scope = Scope::root();
|
||||
let item_tensor = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_tensor,
|
||||
};
|
||||
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
|
||||
|
||||
let output_array = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_tensor,
|
||||
};
|
||||
let output_local = scope.create_local(item_tensor);
|
||||
|
||||
GatherComputeShader {
|
||||
tensor,
|
||||
indices,
|
||||
out: output_local,
|
||||
dim: self.dim,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
scope.write_global(output_local, output_array, Variable::AbsolutePos);
|
||||
|
||||
let tensor = InputInfo::Array {
|
||||
item: item_tensor,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let indices = InputInfo::Array {
|
||||
item: Elem::Int(IntKind::I32).into(),
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let out = OutputInfo::Array { item: item_tensor };
|
||||
|
||||
let info = KernelExpansion {
|
||||
inputs: vec![tensor, indices],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = KernelSettings::default();
|
||||
KernelIntegrator::new(info).integrate(settings)
|
||||
}
|
||||
|
||||
fn id(&self) -> cubecl::KernelId {
|
||||
cubecl::KernelId::new::<Self>().info(self.dim)
|
||||
}
|
||||
output[ABSOLUTE_POS] = input[offset];
|
||||
}
|
||||
|
||||
pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
|
||||
|
@ -131,20 +53,23 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
|
|||
indices: JitTensor<R, I, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let shape_output = indices.shape.clone();
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
let kernel = GatherEagerKernel::<R, E>::new(dim);
|
||||
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[
|
||||
TensorHandleRef::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
])
|
||||
.outputs(&[TensorHandleRef::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.execute(CubeCountSettings::Output { pos: 0 });
|
||||
let cube_dim = CubeDim::default();
|
||||
let cube_count =
|
||||
calculate_cube_count_elemwise::<R::Server>(shape_output.num_elements(), cube_dim);
|
||||
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
|
||||
gather_kernel::launch::<E::Primitive, R>(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
TensorArg::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
|
||||
UInt::new(dim as u32),
|
||||
UInt::new(D as u32),
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue