mirror of https://github.com/tracel-ai/burn.git
Fix: launch without generics (#1932)
This commit is contained in:
parent
4c9097030f
commit
d772a1cfd5
|
@ -153,8 +153,6 @@ impl Codegen {
|
|||
})
|
||||
}
|
||||
|
||||
let generics = self.generics.split_for_impl().1;
|
||||
|
||||
let mut format_str = "{:?}-{}".to_string();
|
||||
for _ in 0..self.state_comptimes.len() {
|
||||
format_str.push_str("-{:?}");
|
||||
|
@ -166,6 +164,14 @@ impl Codegen {
|
|||
format_args.extend(quote::quote! { self.#ident, });
|
||||
}
|
||||
|
||||
let expand_func = match self.generics.params.is_empty() {
|
||||
true => quote::quote! { #expand },
|
||||
false => {
|
||||
let generics = self.generics.split_for_impl().1;
|
||||
quote::quote! { #expand::#generics }
|
||||
}
|
||||
};
|
||||
|
||||
quote::quote! {
|
||||
impl #impl_gen Kernel for #ident #ty_gen #where_gen {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
|
@ -173,7 +179,7 @@ impl Codegen {
|
|||
|
||||
#variables
|
||||
|
||||
#expand::#generics(#expand_args);
|
||||
#expand_func(#expand_args);
|
||||
|
||||
builder.build(self.settings.clone())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
use crate as burn_cube;
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube(launch)]
|
||||
pub fn kernel_with_generics<F: Float>(mut output: Array<F>) {
|
||||
if UNIT_POS == UInt::new(0) {
|
||||
output[0] = F::new(5.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
pub fn kernel_without_generics(mut output: Array<F32>) {
|
||||
if UNIT_POS == UInt::new(0) {
|
||||
output[0] = F32::new(5.0);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
|
||||
|
||||
kernel_with_generics_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
KernelSettings::default(),
|
||||
ArrayHandle::new(&handle, 2),
|
||||
);
|
||||
|
||||
let actual = client.read(handle.binding()).read_sync().unwrap();
|
||||
let actual = f32::from_bytes(&actual);
|
||||
|
||||
assert_eq!(actual[0], 5.0);
|
||||
}
|
||||
|
||||
pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
|
||||
|
||||
kernel_without_generics_launch::<R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
KernelSettings::default(),
|
||||
ArrayHandle::new(&handle, 2),
|
||||
);
|
||||
|
||||
let actual = client.read(handle.binding()).read_sync().unwrap();
|
||||
let actual = f32::from_bytes(&actual);
|
||||
|
||||
assert_eq!(actual[0], 5.0);
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export]
|
||||
macro_rules! testgen_launch {
|
||||
() => {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_launch_with_generics() {
|
||||
let client = TestRuntime::client(&Default::default());
|
||||
burn_cube::runtime_tests::launch::test_kernel_with_generics::<TestRuntime>(client);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_launch_without_generics() {
|
||||
let client = TestRuntime::client(&Default::default());
|
||||
burn_cube::runtime_tests::launch::test_kernel_without_generics::<TestRuntime>(client);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
pub mod launch;
|
||||
pub mod subcube;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
|
@ -7,5 +8,6 @@ macro_rules! testgen_all {
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
burn_cube::testgen_subcube!();
|
||||
burn_cube::testgen_launch!();
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue