mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
2046831df6
commit
2a2a08e0d1
|
@ -1,13 +1,13 @@
|
|||
use cubecl::{
|
||||
cpa,
|
||||
frontend::TensorHandleRef,
|
||||
ir::{KernelDefinition, Scope, Variable, Visibility},
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
use std::{any::TypeId, marker::PhantomData};
|
||||
use std::any::TypeId;
|
||||
|
||||
use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*, CubeDim};
|
||||
|
||||
use crate::{tensor::JitTensor, JitElement, JitRuntime};
|
||||
|
||||
#[cube(launch)]
|
||||
fn cast_kernel<T1: Numeric, T2: Numeric>(input: &Tensor<T1>, output: &mut Tensor<T2>) {
|
||||
output[ABSOLUTE_POS] = T2::cast_from(input[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
/// Cast a tensor to the given element type.
|
||||
///
|
||||
|
@ -24,7 +24,6 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
|
|||
);
|
||||
}
|
||||
|
||||
let kernel = CastEagerKernel::<R, EI, EO>::new();
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
|
||||
let output = JitTensor::new_contiguous(
|
||||
|
@ -34,83 +33,36 @@ pub fn cast<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
|
|||
buffer,
|
||||
);
|
||||
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[TensorHandleRef::<R>::new(
|
||||
let vectorization = |shape: usize| {
|
||||
[4, 2]
|
||||
.into_iter()
|
||||
.filter(|v| shape % v == 0)
|
||||
.map(|v| v as u8)
|
||||
.next()
|
||||
.unwrap_or(1)
|
||||
};
|
||||
|
||||
let vectorization = vectorization(num_elems);
|
||||
let cube_dim = CubeDim::default();
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim.x as usize);
|
||||
|
||||
cast_kernel::launch::<EI::Primitive, EO::Primitive, R>(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
vectorization,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)])
|
||||
.outputs(&[TensorHandleRef::new(
|
||||
),
|
||||
TensorArg::vectorized(
|
||||
vectorization,
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.execute(CubeCountSettings::Output { pos: 0 });
|
||||
),
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) struct CastShader {
|
||||
tensor: Variable,
|
||||
output: Variable,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct CastEagerKernel<R: JitRuntime, EI: JitElement, EO: JitElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_elem_in: PhantomData<EI>,
|
||||
_elem_out: PhantomData<EO>,
|
||||
}
|
||||
|
||||
impl<R: JitRuntime, EI: JitElement, EO: JitElement> Kernel for CastEagerKernel<R, EI, EO> {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
let mut scope = Scope::root();
|
||||
let item_input = EI::cube_elem().into();
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
CastShader { tensor, output }.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
let tensor = InputInfo::Array {
|
||||
item: item_input,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let out = OutputInfo::Array { item: item_output };
|
||||
|
||||
let info = KernelExpansion {
|
||||
inputs: vec![tensor],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = KernelSettings::default();
|
||||
KernelIntegrator::new(info).integrate(settings)
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
impl CastShader {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let tensor = self.tensor;
|
||||
let id = Variable::AbsolutePos;
|
||||
let output = self.output;
|
||||
|
||||
let value = scope.create_local(output.item());
|
||||
cpa!(scope, value = tensor[id]);
|
||||
cpa!(scope, output[id] = value);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime};
|
||||
use cubecl::{
|
||||
cpa,
|
||||
frontend::TensorHandleRef,
|
||||
ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility},
|
||||
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
|
||||
OutputInfo,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
use crate::{tensor::JitTensor, JitElement, JitRuntime};
|
||||
use cubecl::{calculate_cube_count_elemwise, cpa, prelude::*, CubeDim};
|
||||
|
||||
#[cube(launch)]
|
||||
fn bool_cast_kernel<T: Numeric>(input: &Tensor<Bool>, output: &mut Tensor<T>) {
|
||||
output[ABSOLUTE_POS] = T::cast_from(input[ABSOLUTE_POS]);
|
||||
}
|
||||
|
||||
/// Cast a bool tensor to the given element type.
|
||||
///
|
||||
|
@ -17,7 +15,6 @@ use std::marker::PhantomData;
|
|||
pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, u32, D>,
|
||||
) -> JitTensor<R, EO, D> {
|
||||
let kernel = BoolCastEagerKernel::<R, EO>::new();
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
|
||||
let output = JitTensor::new_contiguous(
|
||||
|
@ -27,86 +24,36 @@ pub fn bool_cast<R: JitRuntime, EO: JitElement, const D: usize>(
|
|||
buffer,
|
||||
);
|
||||
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[TensorHandleRef::<R>::new(
|
||||
let vectorization = |shape: usize| {
|
||||
[4, 2]
|
||||
.into_iter()
|
||||
.filter(|v| shape % v == 0)
|
||||
.map(|v| v as u8)
|
||||
.next()
|
||||
.unwrap_or(1)
|
||||
};
|
||||
|
||||
let vectorization = vectorization(num_elems);
|
||||
let cube_dim = CubeDim::default();
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim.x as usize);
|
||||
|
||||
bool_cast_kernel::launch::<EO::Primitive, R>(
|
||||
&tensor.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
vectorization,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)])
|
||||
.outputs(&[TensorHandleRef::new(
|
||||
),
|
||||
TensorArg::vectorized(
|
||||
vectorization,
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)])
|
||||
.execute(CubeCountSettings::Output { pos: 0 });
|
||||
),
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) struct BoolCastShader {
|
||||
tensor: Variable,
|
||||
output: Variable,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_elem_out: PhantomData<EO>,
|
||||
}
|
||||
|
||||
impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
let mut scope = Scope::root();
|
||||
let item_input = Item::new(Elem::Bool);
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
BoolCastShader { tensor, output }.expand(&mut scope);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
let tensor = InputInfo::Array {
|
||||
item: item_input,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
let out = OutputInfo::Array { item: item_output };
|
||||
|
||||
let info = KernelExpansion {
|
||||
inputs: vec![tensor],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = KernelSettings::default();
|
||||
KernelIntegrator::new(info).integrate(settings)
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
impl BoolCastShader {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let tensor = self.tensor;
|
||||
let id = Variable::AbsolutePos;
|
||||
let output = self.output;
|
||||
|
||||
let represents_true = scope.create_local(Elem::Bool);
|
||||
cpa!(scope, represents_true = tensor[id]);
|
||||
cpa!(scope, if(represents_true).then(|scope|{
|
||||
cpa!(scope, output[id] = 1);
|
||||
}).else(|scope|{
|
||||
cpa!(scope, output[id] = 0);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -161,11 +161,11 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
};
|
||||
|
||||
let num_elems_output = output.shape.num_elements();
|
||||
let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
|
||||
|
||||
conv2d_kernel::launch::<E::FloatPrimitive, R>(
|
||||
&input.client,
|
||||
cube_dim,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
TensorArg::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
TensorArg::new(&weight.handle, &weight.strides, &weight.shape.dims),
|
||||
|
|
|
@ -46,5 +46,6 @@ mod tests {
|
|||
tensor_2
|
||||
.to_data()
|
||||
.assert_eq(&TensorData::from([[1., 0., 1.], [0., 0., 1.]]), false);
|
||||
assert!(false);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue