rm comments

This commit is contained in:
louisfd 2024-02-29 23:57:39 -05:00
parent 3a1cd95900
commit f13dbebdf1
1 changed files with 99 additions and 16 deletions

View File

@ -1,7 +1,18 @@
use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
use super::{
DynamicKernelSource, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
};
use crate::{
compute::StaticKernel, element::JitElement, kernel::elemwise_workgroup, kernel_wgsl,
tensor::JitTensor, Runtime,
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
compute::StaticKernel,
element::JitElement,
kernel::elemwise_workgroup,
kernel_wgsl,
tensor::{self, JitTensor},
Runtime,
};
use std::{any::TypeId, marker::PhantomData};
@ -76,31 +87,103 @@ pub fn cast<R: Runtime, InputElem: JitElement, OutputElem: JitElement, const D:
/// where any non-zero value means true. Depending how it was created
/// it may hold an uncanny bit combination. Naively casting it would not
/// necessarily yield 0 or 1.
pub fn bool_cast<R: Runtime, OutputElem: JitElement, const D: usize>(
pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> JitTensor<R, OutputElem, D> {
) -> JitTensor<R, EO, D> {
let kernel = BoolCastEagerKernel::new();
let num_elems = tensor.shape.num_elements();
let kernel = StaticKernel::<
KernelSettings<BoolCast<OutputElem>, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let handle = tensor
.client
.empty(num_elems * core::mem::size_of::<OutputElem>());
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
tensor.client.clone(),
tensor.device,
tensor.shape.clone(),
handle,
buffer,
);
tensor
.client
.execute(Box::new(kernel), &[&tensor.handle, &output.handle]);
execute_dynamic::<R, BoolCastEagerKernel<R, EO>, u32>(
&[EagerHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
output
}
pub(crate) struct BoolCastShader {
tensor: Variable,
output: Variable,
}
#[derive(new)]
pub(crate) struct BoolCastEagerKernel<R: Runtime, EO: JitElement> {
_runtime: PhantomData<R>,
_elem_out: PhantomData<EO>,
}
impl<R: Runtime, EO: JitElement> DynamicKernelSource for BoolCastEagerKernel<R, EO> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item_input = Item::Scalar(Elem::Bool);
let item_output = EO::gpu_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
BoolCastShader { tensor, output }.expand(&mut scope);
scope.write_global_custom(output);
let tensor = InputInfo::Array {
item: item_input,
visibility: Visibility::Read,
};
let out = OutputInfo::Array { item: item_output };
let info = CompilationInfo {
inputs: vec![tensor],
outputs: vec![out],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<Self>())
}
}
impl BoolCastShader {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let id = Variable::Id;
let output = self.output;
let represents_true = scope.create_local(Elem::Bool);
gpu!(scope, represents_true = tensor[id]);
gpu!(scope, if(represents_true).then(|scope|{
gpu!(scope, output[id] = 1);
}).else(|scope|{
gpu!(scope, output[id] = 0);
}));
}
}
#[cfg(test)]
mod tests {
use super::*;