louisfd 2024-07-25 11:22:41 -04:00
parent 2046831df6
commit 2a2a08e0d1
4 changed files with 67 additions and 167 deletions

View File

@@ -1,13 +1,13 @@
-use cubecl::{
-cpa,
-frontend::TensorHandleRef,
-ir::{KernelDefinition, Scope, Variable, Visibility},
-CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-OutputInfo,
-};
-use std::{any::TypeId, marker::PhantomData};
+use std::any::TypeId;
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+#[cube(launch)]
+fn cast_kernel<T1: Numeric, T2: Numeric>(input: &Tensor<T1>, output: &mut Tensor<T2>) {
+output[ABSOLUTE_POS] = T2::cast_from(input[ABSOLUTE_POS]);
+}
/// Cast a tensor to the given element type.
///
@@ -24,7 +24,6 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
);
}
-let kernel = CastEagerKernel::<R, EI, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new_contiguous(
@@ -34,83 +33,36 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
buffer,
);
-Execution::start(kernel, tensor.client)
-.inputs(&[TensorHandleRef::<R>::new(
+let vectorization = |shape: usize| {
+[4, 2]
+.into_iter()
+.filter(|v| shape % v == 0)
+.map(|v| v as u8)
+.next()
+.unwrap_or(1)
+};
+let vectorization = vectorization(num_elems);
+let cube_dim = CubeDim::default();
+let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim.x as usize);
+cast_kernel::launch::<EI::Primitive, EO::Primitive, R>(
+&tensor.client,
+cube_count,
+cube_dim,
+TensorArg::vectorized(
+vectorization,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
-)])
-.outputs(&[TensorHandleRef::new(
+),
+TensorArg::vectorized(
+vectorization,
&output.handle,
&output.strides,
&output.shape.dims,
-)])
-.execute(CubeCountSettings::Output { pos: 0 });
+),
+);
output
}
-pub(crate) struct CastShader {
-tensor: Variable,
-output: Variable,
-}
-#[derive(new)]
-pub(crate) struct CastEagerKernel<R: JitRuntime, EI: JitElement, EO: JitElement> {
-_runtime: PhantomData<R>,
-_elem_in: PhantomData<EI>,
-_elem_out: PhantomData<EO>,
-}
-impl<R: JitRuntime, EI: JitElement, EO: JitElement> Kernel for CastEagerKernel<R, EI, EO> {
-fn define(&self) -> KernelDefinition {
-let mut scope = Scope::root();
-let item_input = EI::cube_elem().into();
-let item_output = EO::cube_elem().into();
-let tensor = Variable::GlobalInputArray {
-id: 0,
-item: item_input,
-};
-let output = Variable::GlobalOutputArray {
-id: 0,
-item: item_output,
-};
-CastShader { tensor, output }.expand(&mut scope);
-scope.write_global_custom(output);
-let tensor = InputInfo::Array {
-item: item_input,
-visibility: Visibility::Read,
-};
-let out = OutputInfo::Array { item: item_output };
-let info = KernelExpansion {
-inputs: vec![tensor],
-outputs: vec![out],
-scope,
-};
-let settings = KernelSettings::default();
-KernelIntegrator::new(info).integrate(settings)
-}
-fn id(&self) -> String {
-format!("{:?}", core::any::TypeId::of::<Self>())
-}
-}
-impl CastShader {
-pub(crate) fn expand(self, scope: &mut Scope) {
-let tensor = self.tensor;
-let id = Variable::AbsolutePos;
-let output = self.output;
-let value = scope.create_local(output.item());
-cpa!(scope, value = tensor[id]);
-cpa!(scope, output[id] = value);
-}
-}
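Both cast paths now pick a vectorization factor before launching: the closure in the new code walks the candidates [4, 2] and keeps the first one that evenly divides the element count, falling back to scalar (1). A minimal, dependency-free sketch of that selection logic; the helper name and the test values are illustrative, not part of the commit:

// Prefer the widest supported line size that divides the element count.
fn pick_vectorization(num_elems: usize) -> u8 {
    [4usize, 2]
        .into_iter()
        .find(|v| num_elems % v == 0)
        .map(|v| v as u8)
        .unwrap_or(1)
}

fn main() {
    assert_eq!(pick_vectorization(16), 4); // divisible by 4
    assert_eq!(pick_vectorization(6), 2); // not by 4, but even
    assert_eq!(pick_vectorization(7), 1); // odd length: scalar fallback
}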

View File

@@ -1,12 +1,10 @@
-use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
-use cubecl::{
-cpa,
-frontend::TensorHandleRef,
-ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
-CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-OutputInfo,
-};
-use std::marker::PhantomData;
+use crate::{tensor::JitTensor, JitElement, JitRuntime};
+use cubecl::{calculate_cube_count_elemwise, cpa, prelude::*, CubeDim};
+#[cube(launch)]
+fn bool_cast_kernel<T: Numeric>(input: &Tensor<Bool>, output: &mut Tensor<T>) {
+output[ABSOLUTE_POS] = T::cast_from(input[ABSOLUTE_POS]);
+}
/// Cast a bool tensor to the given element type.
///
@@ -17,7 +15,6 @@ use std::marker::PhantomData;
pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> JitTensor<R, EO, D> {
-let kernel = BoolCastEagerKernel::<R, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new_contiguous(
@@ -27,86 +24,36 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
buffer,
);
-Execution::start(kernel, tensor.client)
-.inputs(&[TensorHandleRef::<R>::new(
+let vectorization = |shape: usize| {
+[4, 2]
+.into_iter()
+.filter(|v| shape % v == 0)
+.map(|v| v as u8)
+.next()
+.unwrap_or(1)
+};
+let vectorization = vectorization(num_elems);
+let cube_dim = CubeDim::default();
+let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim.x as usize);
+bool_cast_kernel::launch::<EO::Primitive, R>(
+&tensor.client,
+cube_count,
+cube_dim,
+TensorArg::vectorized(
+vectorization,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
-)])
-.outputs(&[TensorHandleRef::new(
+),
+TensorArg::vectorized(
+vectorization,
&output.handle,
&output.strides,
&output.shape.dims,
-)])
-.execute(CubeCountSettings::Output { pos: 0 });
+),
+);
output
}
-pub(crate) struct BoolCastShader {
-tensor: Variable,
-output: Variable,
-}
-#[derive(new)]
-pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
-_runtime: PhantomData<R>,
-_elem_out: PhantomData<EO>,
-}
-impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
-fn define(&self) -> KernelDefinition {
-let mut scope = Scope::root();
-let item_input = Item::new(Elem::Bool);
-let item_output = EO::cube_elem().into();
-let tensor = Variable::GlobalInputArray {
-id: 0,
-item: item_input,
-};
-let output = Variable::GlobalOutputArray {
-id: 0,
-item: item_output,
-};
-BoolCastShader { tensor, output }.expand(&mut scope);
-scope.write_global_custom(output);
-let tensor = InputInfo::Array {
-item: item_input,
-visibility: Visibility::Read,
-};
-let out = OutputInfo::Array { item: item_output };
-let info = KernelExpansion {
-inputs: vec![tensor],
-outputs: vec![out],
-scope,
-};
-let settings = KernelSettings::default();
-KernelIntegrator::new(info).integrate(settings)
-}
-fn id(&self) -> String {
-format!("{:?}", core::any::TypeId::of::<Self>())
-}
-}
-impl BoolCastShader {
-pub(crate) fn expand(self, scope: &mut Scope) {
-let tensor = self.tensor;
-let id = Variable::AbsolutePos;
-let output = self.output;
-let represents_true = scope.create_local(Elem::Bool);
-cpa!(scope, represents_true = tensor[id]);
-cpa!(scope, if(represents_true).then(|scope|{
-cpa!(scope, output[id] = 1);
-}).else(|scope|{
-cpa!(scope, output[id] = 0);
-}));
-}
-}
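The removed BoolCastShader spelled the conversion out with an explicit branch (write 1 when the flag is set, 0 otherwise), while the new bool_cast_kernel relies on cast_from. As a reference for the per-element semantics only, here is a plain-CPU sketch, assuming the bool tensor is backed by u32 values holding 0 or 1; the helper name and sample data are illustrative, not part of the commit:

// Per-element behaviour of the bool cast, for flag values stored as 0 or 1.
fn bool_cast_cpu(input: &[u32]) -> Vec<f32> {
    input
        .iter()
        .map(|&flag| if flag != 0 { 1.0 } else { 0.0 })
        .collect()
}

fn main() {
    assert_eq!(bool_cast_cpu(&[1, 0, 1, 0, 0, 1]), vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0]);
}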

View File

@@ -161,11 +161,11 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
};
let num_elems_output = output.shape.num_elements();
-let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
+let cube_count = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
conv2d_kernel::launch::<E::FloatPrimitive, R>(
&input.client,
-cube_dim,
+cube_count,
CubeDim::default(),
TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
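The conv2d change is a rename for clarity: calculate_cube_count_elemwise returns how many cubes to launch (a cube count), while the cube dimension itself is passed separately as CubeDim::default(). A hedged sketch of the arithmetic behind the renamed variable, assuming the helper is conceptually a ceiling division of the element count by the threads handled per cube (the real cubecl function may also spread the count across several grid axes); the function below is an illustrative stand-in, not the cubecl API:

// Conceptual cube count for an element-wise kernel: ceil(num_elems / cube_dim_x).
fn cube_count_elemwise(num_elems: usize, cube_dim_x: usize) -> usize {
    (num_elems + cube_dim_x - 1) / cube_dim_x
}

fn main() {
    // 1000 elements handled 256 at a time require 4 cubes.
    assert_eq!(cube_count_elemwise(1000, 256), 4);
}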

View File

@@ -46,5 +46,6 @@ mod tests {
tensor_2
.to_data()
.assert_eq(&TensorData::from([[1., 0., 1.], [0., 0., 1.]]), false);
}
}