Fix: launch without generics (#1932)

This commit is contained in:
Nathaniel Simard 2024-06-26 12:57:32 -04:00 committed by GitHub
parent 4c9097030f
commit d772a1cfd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 3 deletions

View File

@ -153,8 +153,6 @@ impl Codegen {
})
}
let generics = self.generics.split_for_impl().1;
let mut format_str = "{:?}-{}".to_string();
for _ in 0..self.state_comptimes.len() {
format_str.push_str("-{:?}");
@ -166,6 +164,14 @@ impl Codegen {
format_args.extend(quote::quote! { self.#ident, });
}
let expand_func = match self.generics.params.is_empty() {
true => quote::quote! { #expand },
false => {
let generics = self.generics.split_for_impl().1;
quote::quote! { #expand::#generics }
}
};
quote::quote! {
impl #impl_gen Kernel for #ident #ty_gen #where_gen {
fn define(&self) -> KernelDefinition {
@ -173,7 +179,7 @@ impl Codegen {
#variables
#expand::#generics(#expand_args);
#expand_func(#expand_args);
builder.build(self.settings.clone())
}

View File

@ -0,0 +1,68 @@
use crate as burn_cube;
use burn_cube::prelude::*;
#[cube(launch)]
pub fn kernel_with_generics<F: Float>(mut output: Array<F>) {
if UNIT_POS == UInt::new(0) {
output[0] = F::new(5.0);
}
}
#[cube(launch)]
pub fn kernel_without_generics(mut output: Array<F32>) {
if UNIT_POS == UInt::new(0) {
output[0] = F32::new(5.0);
}
}
pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
kernel_with_generics_launch::<F32, R>(
client.clone(),
CubeCount::new(1, 1, 1),
KernelSettings::default(),
ArrayHandle::new(&handle, 2),
);
let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
assert_eq!(actual[0], 5.0);
}
pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
kernel_without_generics_launch::<R>(
client.clone(),
CubeCount::new(1, 1, 1),
KernelSettings::default(),
ArrayHandle::new(&handle, 2),
);
let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
assert_eq!(actual[0], 5.0);
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_launch {
() => {
use super::*;
#[test]
fn test_launch_with_generics() {
let client = TestRuntime::client(&Default::default());
burn_cube::runtime_tests::launch::test_kernel_with_generics::<TestRuntime>(client);
}
#[test]
fn test_launch_without_generics() {
let client = TestRuntime::client(&Default::default());
burn_cube::runtime_tests::launch::test_kernel_without_generics::<TestRuntime>(client);
}
};
}

View File

@ -1,3 +1,4 @@
pub mod launch;
pub mod subcube;
#[allow(missing_docs)]
@ -7,5 +8,6 @@ macro_rules! testgen_all {
use burn_cube::prelude::*;
burn_cube::testgen_subcube!();
burn_cube::testgen_launch!();
};
}