mirror of https://github.com/tracel-ai/burn.git
Gather CPA to CubeCL (#2165)
* working version
* cleanup
* wip
* working version of gather
* testsetsetser
* Revert "testsetsetser". This reverts commit f37b329697.
* Reapply "testsetsetser". This reverts commit f8ada0044e.
* Revert "testsetsetser". This reverts commit f37b329697.
* Revert "working version of gather". This reverts commit f5047c27c8.
* Revert "wip". This reverts commit abaaa2dd55.
* Revert "Merge branch 'main' into index-cpa-to-cubecl". This reverts commit 05bed8ea74, reversing changes made to 94954fc32c.
* Revert "cleanup". This reverts commit 94954fc32c.
* Revert "working version". This reverts commit a06933f029.
* gather test
* fix
* fix clippy
* cleanup
This commit is contained in:
parent 73d4b11aa2
commit e1fed792f7
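For context on the kernel being ported: gather reads input elements along one dimension, using an indices tensor that has the output's shape. A minimal sketch of the semantics in plain Rust (the dim = 1 case and all names here are illustrative, not part of this diff):

fn main() {
    // Gather along dim = 1 of a 2-D tensor: output[i][j] = input[i][indices[i][j]].
    let input = [[10, 11, 12], [20, 21, 22]];
    let indices: [[usize; 2]; 2] = [[2, 0], [1, 1]];
    let output: Vec<Vec<i32>> = input
        .iter()
        .zip(indices.iter())
        .map(|(row, idx)| idx.iter().map(|&j| row[j]).collect())
        .collect();
    assert_eq!(output, vec![vec![12, 10], vec![21, 21]]);
}

Both the old CPA shader and the new CubeCL kernel below implement this by turning each flat output position into a linear offset into the input buffer.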
@@ -97,8 +97,8 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
 # Backwards compatibility with previous serialized data format.
 record-backward-compat = []

 test-tch = ["tch"] # To use tch during testing, default uses ndarray.
 test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
 test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.

 [dependencies]
@@ -1,128 +1,46 @@
 use crate::{
     element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
 };
-use cubecl::ir::{
-    Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility,
-};
-use cubecl::{
-    cpa, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator,
-    KernelSettings, OutputInfo,
-};
-use std::marker::PhantomData;
-
-#[derive(new)]
-struct GatherEagerKernel<R: JitRuntime, E: JitElement> {
-    dim: usize,
-    _runtime: PhantomData<R>,
-    _elem: PhantomData<E>,
-}
-
-struct GatherComputeShader {
-    tensor: Variable,
-    indices: Variable,
-    out: Variable,
-    dim: usize,
-}
-
-impl GatherComputeShader {
-    pub fn expand(self, scope: &mut Scope) {
-        match self.tensor {
-            Variable::GlobalInputArray { .. } => (),
-            Variable::GlobalOutputArray { .. } => (),
-            _ => panic!("Tensor variable must be an global array."),
-        };
-
-        let tensor = self.tensor;
-        let output = self.out;
-
-        let stride = scope.create_local(Elem::UInt);
-        let offset = scope.create_local(Elem::UInt);
-
-        // The offset of the `dim` dimension is obtained by the indices tensor.
-        cpa!(scope, offset = cast(self.indices));
-        cpa!(scope, stride = stride(tensor, self.dim));
-        cpa!(scope, offset = offset * stride);
-
-        // We fetch the offset before the `dim` dimension.
-        if self.dim > 0 {
-            let offset_before = scope.create_local(Elem::UInt);
-            scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
-                tensors: vec![tensor],
-                indexes: vec![offset_before],
-                layout: Variable::AbsolutePos, // Will be updated.
-                position: Variable::AbsolutePos,
-                dim_start: 0u32.into(),
-                dim_end: self.dim.into(),
-            });
-            cpa!(scope, offset += offset_before);
-        }
-
-        let offset_after = scope.create_local(Elem::UInt);
-        scope.index_offset_with_output_layout(IndexOffsetGlobalWithLayout {
-            tensors: vec![tensor],
-            indexes: vec![offset_after],
-            layout: Variable::AbsolutePos, // Will be updated.
-            position: Variable::AbsolutePos,
-            dim_start: (self.dim + 1).into(),
-            dim_end: Variable::Rank,
-        });
-        cpa!(scope, offset += offset_after);
-
-        cpa!(scope, output = tensor[offset]);
-    }
-}
-
-impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item_tensor = E::cube_elem().into();
-        let item_indices: Item = Elem::Int(IntKind::I32).into();
-
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_tensor,
-        };
-        let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
-
-        let output_array = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_tensor,
-        };
-        let output_local = scope.create_local(item_tensor);
-
-        GatherComputeShader {
-            tensor,
-            indices,
-            out: output_local,
-            dim: self.dim,
-        }
-        .expand(&mut scope);
-
-        scope.write_global(output_local, output_array, Variable::AbsolutePos);
-
-        let tensor = InputInfo::Array {
-            item: item_tensor,
-            visibility: Visibility::Read,
-        };
-        let indices = InputInfo::Array {
-            item: Elem::Int(IntKind::I32).into(),
-            visibility: Visibility::Read,
-        };
-        let out = OutputInfo::Array { item: item_tensor };
-
-        let info = KernelExpansion {
-            inputs: vec![tensor, indices],
-            outputs: vec![out],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>().info(self.dim)
-    }
-}
+use cubecl::frontend::{Numeric, Tensor, UInt, ABSOLUTE_POS};
+use cubecl::linalg::tensor::index_offset_with_layout;
+use cubecl::CubeDim;
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
+
+#[cube(launch_unchecked)]
+fn gather_kernel<T: Numeric, I: Numeric>(
+    input: &Tensor<T>,
+    indices: &Tensor<I>,
+    output: &mut Tensor<T>,
+    dim: &UInt,
+) {
+    let index = indices[ABSOLUTE_POS];
+
+    let stride = input.stride(*dim);
+    let mut offset = UInt::cast_from(index);
+    offset *= stride;
+
+    if *dim > 0 {
+        let offset_before = index_offset_with_layout(
+            input,
+            output,
+            ABSOLUTE_POS,
+            UInt::new(0),
+            *dim,
+            Comptime::new(false),
+        );
+        offset += offset_before;
+    }
+
+    let offset_after = index_offset_with_layout(
+        input,
+        output,
+        ABSOLUTE_POS,
+        *dim + 1,
+        input.rank(),
+        Comptime::new(false),
+    );
+    offset += offset_after;
+
+    output[ABSOLUTE_POS] = input[offset];
+}

 pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
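The new kernel assembles the input offset for each flat output position from three parts: the gathered index times stride(dim), plus the contribution of every coordinate before dim (offset_before), plus the contribution of every coordinate after dim (offset_after). A CPU sketch of the same arithmetic, assuming contiguous row-major tensors (gather_cpu and its flattened-slice signature are illustrative, not part of the crate):

fn gather_cpu(
    input: &[f32],
    in_shape: &[usize],
    indices: &[usize], // flattened; one entry per output element
    out_shape: &[usize],
    dim: usize,
) -> Vec<f32> {
    // Row-major strides for a contiguous tensor.
    fn strides(shape: &[usize]) -> Vec<usize> {
        let mut st = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            st[i] = st[i + 1] * shape[i + 1];
        }
        st
    }
    let in_st = strides(in_shape);
    let out_st = strides(out_shape);
    let total: usize = out_shape.iter().product();

    (0..total)
        .map(|pos| {
            // The coordinate along `dim` comes from the indices tensor.
            let mut offset = indices[pos] * in_st[dim];
            // Every other coordinate is recovered from the flat output
            // position, mirroring offset_before (d < dim) and
            // offset_after (d > dim) in the kernel.
            for d in (0..in_shape.len()).filter(|&d| d != dim) {
                let coord = pos / out_st[d] % out_shape[d];
                offset += coord * in_st[d];
            }
            input[offset]
        })
        .collect()
}

fn main() {
    // 2x3 input, gathered along dim = 1 with indices of the output shape.
    let input = [10., 11., 12., 20., 21., 22.];
    let indices = [2, 0, 1, 1, 0, 2];
    let out = gather_cpu(&input, &[2, 3], &indices, &[2, 3], 1);
    assert_eq!(out, [12., 10., 11., 21., 20., 22.]);
}

In the kernel itself the before/after parts come from index_offset_with_layout, which also handles non-contiguous layouts; the sketch only covers the contiguous case.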
@@ -131,13 +49,21 @@ pub(crate) fn gather<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
     indices: JitTensor<R, I, D>,
 ) -> JitTensor<R, E, D> {
     let shape_output = indices.shape.clone();
+    let total_elem = shape_output.num_elements();
     let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
-    let kernel = GatherEagerKernel::<R, E>::new(dim);
-
-    Execution::start(kernel, tensor.client.clone())
-        .inputs(&[tensor.as_handle_ref(), indices.as_handle_ref()])
-        .outputs(&[output.as_handle_ref()])
-        .execute(CubeCountSettings::Output { pos: 0 });
-
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
+    unsafe {
+        gather_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
+            &tensor.client,
+            cube_count,
+            cube_dim,
+            tensor.as_tensor_arg(1),
+            indices.as_tensor_arg(1),
+            output.as_tensor_arg(1),
+            ScalarArg::new(dim as u32),
+        )
+    }
     output
 }
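On the launch side the dispatch is elementwise: one unit per element of the output, rounded up to whole cubes. A back-of-the-envelope sketch of the arithmetic behind the calculate_cube_count_elemwise call in the hunk above (the flat helper below and the 256-unit cube size are assumptions; the real function takes the CubeDim and returns a CubeCount):

// Ceiling division: launch enough cubes to cover every output element;
// the last cube may be only partially occupied.
fn cubes_needed(total_elem: usize, units_per_cube: usize) -> usize {
    (total_elem + units_per_cube - 1) / units_per_cube
}

fn main() {
    // e.g. 1000 output elements with an assumed 256-unit CubeDim::default()
    assert_eq!(cubes_needed(1000, 256), 4); // 3 full cubes + 1 partial
}

The launch itself goes through gather_kernel::launch_unchecked, hence the unsafe block in the hunk above.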
@@ -20,7 +20,15 @@ cuda-jit = ["burn/cuda-jit"]

 [dependencies]
 # Burn
-burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "metrics", "autotune", "fusion", "default"], default-features = false}
+burn = { path = "../../crates/burn", features = [
+    "train",
+    "ndarray",
+    "std",
+    "metrics",
+    "autotune",
+    "fusion",
+    "default",
+], default-features = false }

 # Tokenizer
 tokenizers = { version = "0.20.0", default-features = false, features = [