mirror of https://github.com/tracel-ai/burn.git
gather
This commit is contained in:
parent
def45f0b58
commit
4bc1a77819
|
@ -1,128 +1,50 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
|
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
|
||||||
};
|
};
|
||||||
use cubecl::ir::{
|
use cubecl::ir::KernelDefinition;
|
||||||
Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility,
|
use cubecl::linalg::tensor::index_offset_with_layout;
|
||||||
};
|
use cubecl::prelude::*;
|
||||||
use cubecl::{
|
use cubecl::{calculate_cube_count_elemwise, CubeDim};
|
||||||
cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion,
|
|
||||||
KernelIntegrator, KernelSettings, OutputInfo,
|
|
||||||
};
|
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
#[derive(new)]
|
#[cube(launch)]
|
||||||
struct GatherEagerKernel<R: JitRuntime, E: JitElement> {
|
fn gather_kernel<T: Numeric>(
|
||||||
dim: usize,
|
input: &Tensor<T>,
|
||||||
_runtime: PhantomData<R>,
|
indices: &Tensor<UInt>,
|
||||||
_elem: PhantomData<E>,
|
output: &mut Tensor<T>,
|
||||||
}
|
dim: Comptime<UInt>,
|
||||||
|
rank: Comptime<UInt>,
|
||||||
|
) {
|
||||||
|
let dim_runtime = Comptime::runtime(dim);
|
||||||
|
let not_zeroth_dim = Comptime::map(dim, |d| d > 0);
|
||||||
|
|
||||||
struct GatherComputeShader {
|
// The offset for the `dim` dimension is obtained by the indices tensor.
|
||||||
tensor: Variable,
|
let index = indices[ABSOLUTE_POS];
|
||||||
indices: Variable,
|
let stride = input.stride(dim_runtime);
|
||||||
out: Variable,
|
|
||||||
dim: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GatherComputeShader {
|
let mut offset = index * stride;
|
||||||
pub fn expand(self, scope: &mut Scope) {
|
|
||||||
match self.tensor {
|
|
||||||
Variable::GlobalInputArray { .. } => (),
|
|
||||||
Variable::GlobalOutputArray { .. } => (),
|
|
||||||
_ => panic!("Tensor variable must be an global array."),
|
|
||||||
};
|
|
||||||
|
|
||||||
let tensor = self.tensor;
|
// We fetch the offset before the `dim` dimension.
|
||||||
let output = self.out;
|
if Comptime::get(not_zeroth_dim) {
|
||||||
|
offset += index_offset_with_layout(
|
||||||
let stride = scope.create_local(Elem::UInt);
|
input,
|
||||||
let offset = scope.create_local(Elem::UInt);
|
output,
|
||||||
|
ABSOLUTE_POS,
|
||||||
// The offset of the `dim` dimension is obtained by the indices tensor.
|
UInt::new(0),
|
||||||
cpa!(scope, offset = cast(self.indices));
|
dim_runtime,
|
||||||
cpa!(scope, stride = stride(tensor, self.dim));
|
Comptime::new(true),
|
||||||
cpa!(scope, offset = offset * stride);
|
);
|
||||||
|
|
||||||
// We fetch the offset before the `dim` dimension.
|
|
||||||
if self.dim > 0 {
|
|
||||||
let offset_before = scope.create_local(Elem::UInt);
|
|
||||||
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
|
|
||||||
tensors: vec![tensor],
|
|
||||||
indexes: vec![offset_before],
|
|
||||||
layout: Variable::AbsolutePos, // Will be updated.
|
|
||||||
position: Variable::AbsolutePos,
|
|
||||||
dim_start: 0u32.into(),
|
|
||||||
dim_end: self.dim.into(),
|
|
||||||
});
|
|
||||||
cpa!(scope, offset += offset_before);
|
|
||||||
}
|
|
||||||
|
|
||||||
let offset_after = scope.create_local(Elem::UInt);
|
|
||||||
scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
|
|
||||||
tensors: vec![tensor],
|
|
||||||
indexes: vec![offset_after],
|
|
||||||
layout: Variable::AbsolutePos, // Will be updated.
|
|
||||||
position: Variable::AbsolutePos,
|
|
||||||
dim_start: (self.dim + 1).into(),
|
|
||||||
dim_end: Variable::Rank,
|
|
||||||
});
|
|
||||||
cpa!(scope, offset += offset_after);
|
|
||||||
|
|
||||||
cpa!(scope, output = tensor[offset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
|
|
||||||
fn define(&self) -> KernelDefinition {
|
|
||||||
let mut scope = Scope::root();
|
|
||||||
let item_tensor = E::cube_elem().into();
|
|
||||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray {
|
|
||||||
id: 0,
|
|
||||||
item: item_tensor,
|
|
||||||
};
|
|
||||||
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
|
|
||||||
|
|
||||||
let output_array = Variable::GlobalOutputArray {
|
|
||||||
id: 0,
|
|
||||||
item: item_tensor,
|
|
||||||
};
|
|
||||||
let output_local = scope.create_local(item_tensor);
|
|
||||||
|
|
||||||
GatherComputeShader {
|
|
||||||
tensor,
|
|
||||||
indices,
|
|
||||||
out: output_local,
|
|
||||||
dim: self.dim,
|
|
||||||
}
|
|
||||||
.expand(&mut scope);
|
|
||||||
|
|
||||||
scope.write_global(output_local, output_array, Variable::AbsolutePos);
|
|
||||||
|
|
||||||
let tensor = InputInfo::Array {
|
|
||||||
item: item_tensor,
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
};
|
|
||||||
let indices = InputInfo::Array {
|
|
||||||
item: Elem::Int(IntKind::I32).into(),
|
|
||||||
visibility: Visibility::Read,
|
|
||||||
};
|
|
||||||
let out = OutputInfo::Array { item: item_tensor };
|
|
||||||
|
|
||||||
let info = KernelExpansion {
|
|
||||||
inputs: vec![tensor, indices],
|
|
||||||
outputs: vec![out],
|
|
||||||
scope,
|
|
||||||
};
|
|
||||||
|
|
||||||
let settings = KernelSettings::default();
|
|
||||||
KernelIntegrator::new(info).integrate(settings)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn id(&self) -> cubecl::KernelId {
|
offset += index_offset_with_layout(
|
||||||
cubecl::KernelId::new::<Self>().info(self.dim)
|
input,
|
||||||
}
|
output,
|
||||||
|
ABSOLUTE_POS,
|
||||||
|
Comptime::runtime(Comptime::map(dim, |d| d + 1)),
|
||||||
|
Comptime::runtime(rank),
|
||||||
|
Comptime::new(true),
|
||||||
|
);
|
||||||
|
|
||||||
|
output[ABSOLUTE_POS] = input[offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
|
pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
|
||||||
|
@ -131,20 +53,23 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
|
||||||
indices: JitTensor<R, I, D>,
|
indices: JitTensor<R, I, D>,
|
||||||
) -> JitTensor<R, E, D> {
|
) -> JitTensor<R, E, D> {
|
||||||
let shape_output = indices.shape.clone();
|
let shape_output = indices.shape.clone();
|
||||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
|
||||||
let kernel = GatherEagerKernel::<R, E>::new(dim);
|
|
||||||
|
|
||||||
Execution::start(kernel, tensor.client)
|
let cube_dim = CubeDim::default();
|
||||||
.inputs(&[
|
let cube_count =
|
||||||
TensorHandleRef::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
calculate_cube_count_elemwise::<R::Server>(shape_output.num_elements(), cube_dim);
|
||||||
TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
|
||||||
])
|
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||||
.outputs(&[TensorHandleRef::new(
|
|
||||||
&output.handle,
|
gather_kernel::launch::<E::Primitive, R>(
|
||||||
&output.strides,
|
&tensor.client,
|
||||||
&output.shape.dims,
|
cube_count,
|
||||||
)])
|
cube_dim,
|
||||||
.execute(CubeCountSettings::Output { pos: 0 });
|
TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||||
|
TensorArg::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||||
|
TensorArg::new(&output.handle, &output.strides, &output.shape.dims),
|
||||||
|
UInt::new(dim as u32),
|
||||||
|
UInt::new(D as u32),
|
||||||
|
);
|
||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue