mirror of https://github.com/tracel-ai/burn.git
working version
This commit is contained in:
parent
be705466c9
commit
a06933f029
|
@ -1,119 +1,79 @@
|
|||
use crate::{
|
||||
element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
|
||||
};
|
||||
use cubecl::prelude::*;
|
||||
use cubecl::{
|
||||
cpa,
|
||||
frontend::TensorHandleRef,
|
||||
ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
calculate_cube_count_elemwise, frontend::TensorHandleRef, CubeCountSettings, CubeDim, Execution,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(new)]
|
||||
struct SelectEagerKernel<R: JitRuntime, E: JitElement> {
|
||||
dim: usize,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
#[cube(launch_unchecked)]
|
||||
fn select_kernel<T: Numeric>(
|
||||
input: &Tensor<T>,
|
||||
indices: &Tensor<I32>,
|
||||
output: &mut Tensor<T>,
|
||||
dim: &UInt,
|
||||
) {
|
||||
let id = ABSOLUTE_POS;
|
||||
let mut offset_input = UInt::new(0);
|
||||
let rank = output.rank();
|
||||
for i in range(UInt::new(0), rank, Comptime::new(false)) {
|
||||
let stride_input = input.stride(i);
|
||||
let stride_output = output.stride(i);
|
||||
let shape_output = output.shape(i);
|
||||
let mut offset_local = id / stride_output;
|
||||
offset_local = offset_local % shape_output;
|
||||
|
||||
pub struct SelectComputeShader {
|
||||
input: Variable,
|
||||
indices: Variable,
|
||||
output: Variable,
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl SelectComputeShader {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let indices = self.indices;
|
||||
let output = self.output;
|
||||
let id = Variable::AbsolutePos;
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
|
||||
cpa!(
|
||||
scope,
|
||||
range(0u32, Variable::Rank).for_each(|i, scope| {
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape_output = scope.create_local(Elem::UInt);
|
||||
|
||||
cpa!(scope, stride_input = stride(input, i));
|
||||
cpa!(scope, stride_output = stride(output, i));
|
||||
cpa!(scope, shape_output = shape(output, i));
|
||||
|
||||
let offset_local = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, offset_local = id / stride_output);
|
||||
cpa!(scope, offset_local = offset_local % shape_output);
|
||||
|
||||
let dim_index = scope.create_local(Elem::Bool);
|
||||
cpa!(scope, dim_index = i == self.dim);
|
||||
|
||||
cpa!(scope, if(dim_index).then(|scope| {
|
||||
cpa!(scope, offset_local = indices[offset_local]);
|
||||
cpa!(scope, offset_local = offset_local * stride_input);
|
||||
}).else(|scope| {
|
||||
cpa!(scope, offset_local = offset_local * stride_input);
|
||||
}));
|
||||
|
||||
cpa!(scope, offset_input += offset_local);
|
||||
})
|
||||
);
|
||||
|
||||
let value = scope.create_local(input.item());
|
||||
cpa!(scope, value = input[offset_input]);
|
||||
cpa!(scope, output[id] = value);
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let indices = Variable::GlobalInputArray {
|
||||
id: 1,
|
||||
item: item_indices,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
SelectComputeShader {
|
||||
input,
|
||||
indices,
|
||||
output,
|
||||
dim: self.dim,
|
||||
if i == *dim {
|
||||
offset_local = UInt::cast_from(indices[offset_local]);
|
||||
offset_local *= stride_input;
|
||||
} else {
|
||||
offset_local *= stride_input;
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let indices = InputInfo::Array {
|
||||
item: item_indices,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
||||
let info = KernelExpansion {
|
||||
inputs: vec![input, indices],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = KernelSettings::default();
|
||||
KernelIntegrator::new(info).integrate(settings)
|
||||
}
|
||||
|
||||
fn id(&self) -> cubecl::KernelId {
|
||||
cubecl::KernelId::new::<Self>().info(self.dim)
|
||||
offset_input += offset_local;
|
||||
}
|
||||
let value = input[offset_input];
|
||||
output[id] = value;
|
||||
}
|
||||
//pub fn expand(self, scope: &mut Scope) {
|
||||
// let input = self.input;
|
||||
// let indices = self.indices;
|
||||
// let output = self.output;
|
||||
// let id = Variable::AbsolutePos;
|
||||
// let offset_input = scope.zero(Elem::UInt);
|
||||
//
|
||||
// cpa!(
|
||||
// scope,
|
||||
// range(0u32, Variable::Rank).for_each(|i, scope| {
|
||||
// let stride_input = scope.create_local(Elem::UInt);
|
||||
// let stride_output = scope.create_local(Elem::UInt);
|
||||
// let shape_output = scope.create_local(Elem::UInt);
|
||||
//
|
||||
// cpa!(scope, stride_input = stride(input, i));
|
||||
// cpa!(scope, stride_output = stride(output, i));
|
||||
// cpa!(scope, shape_output = shape(output, i));
|
||||
//
|
||||
// let offset_local = scope.create_local(Elem::UInt);
|
||||
// cpa!(scope, offset_local = id / stride_output);
|
||||
// cpa!(scope, offset_local = offset_local % shape_output);
|
||||
//
|
||||
// let dim_index = scope.create_local(Elem::Bool);
|
||||
// cpa!(scope, dim_index = i == self.dim);
|
||||
//
|
||||
// cpa!(scope, if(dim_index).then(|scope| {
|
||||
// cpa!(scope, offset_local = indices[offset_local]);
|
||||
// cpa!(scope, offset_local = offset_local * stride_input);
|
||||
// }).else(|scope| {
|
||||
// cpa!(scope, offset_local = offset_local * stride_input);
|
||||
// }));
|
||||
//
|
||||
// cpa!(scope, offset_input += offset_local);
|
||||
// })
|
||||
// );
|
||||
//
|
||||
// let value = scope.create_local(input.item());
|
||||
// cpa!(scope, value = input[offset_input]);
|
||||
// cpa!(scope, output[id] = value);
|
||||
//}
|
||||
|
||||
pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
|
@ -122,26 +82,33 @@ pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement, const D: usize
|
|||
) -> JitTensor<R, E, D> {
|
||||
let mut shape_output = tensor.shape.clone();
|
||||
shape_output.dims[dim] = indices.shape.dims[0];
|
||||
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
let kernel = SelectEagerKernel::<R, E>::new(dim);
|
||||
|
||||
let num_elems = indices.shape.dims[0];
|
||||
let mut total_elem = 1;
|
||||
for dim_size in shape_output.dims.iter() {
|
||||
total_elem *= dim_size
|
||||
}
|
||||
let mut shapes = [1; D];
|
||||
let mut strides = [num_elems; D];
|
||||
shapes[D - 1] = num_elems;
|
||||
strides[D - 1] = 1;
|
||||
Execution::start(kernel, tensor.client.clone())
|
||||
.inputs(&[
|
||||
tensor.as_handle_ref(),
|
||||
// This is currently a hack because the info buffer that contains the strides and shapes is
|
||||
// hardcoded to only contain information about tensors of the same rank. However, since
|
||||
// we don't rely on the shape and stride of the indices tensor, it doesn't matter
|
||||
// which value we put, it just needs to be of the same rank.
|
||||
unsafe { TensorHandleRef::from_raw_parts(&indices.handle, &strides, &shapes) },
|
||||
])
|
||||
.outputs(&[output.as_handle_ref()])
|
||||
.execute(CubeCountSettings::Output { pos: 0 });
|
||||
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
|
||||
let cube_dim = CubeDim::default();
|
||||
|
||||
let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
|
||||
|
||||
unsafe {
|
||||
select_kernel::launch_unchecked::<E::Primitive, R>(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
tensor.as_tensor_arg(1),
|
||||
TensorArg::from_raw_parts(&indices.handle, &strides, &shapes, 1),
|
||||
output.as_tensor_arg(1),
|
||||
ScalarArg::new(dim as u32),
|
||||
)
|
||||
};
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -16,6 +16,9 @@ mod tests {
|
|||
let actual = tensor.select(1, indices);
|
||||
let expected = tensor_ref.select(1, indices_ref);
|
||||
|
||||
println!("{:?}", actual);
|
||||
println!("{:?}", expected);
|
||||
|
||||
expected
|
||||
.into_data()
|
||||
.assert_approx_eq(&actual.into_data(), 3);
|
||||
|
|
Loading…
Reference in New Issue