diff --git a/crates/burn-cube-macros/src/codegen_common/signature.rs b/crates/burn-cube-macros/src/codegen_common/signature.rs index c19e990ef..9f9654f47 100644 --- a/crates/burn-cube-macros/src/codegen_common/signature.rs +++ b/crates/burn-cube-macros/src/codegen_common/signature.rs @@ -2,10 +2,17 @@ use quote::ToTokens; use crate::tracker::VariableTracker; +#[derive(Copy, Clone, Debug)] +pub enum ExpandMode { + FuncImpl, + MethodImpl, +} + pub fn expand_sig( sig: &syn::Signature, visibility: &syn::Visibility, mut variable_tracker: Option<&mut VariableTracker>, + mode: ExpandMode, ) -> proc_macro2::TokenStream { let mut inputs = quote::quote!(); @@ -42,7 +49,10 @@ pub fn expand_sig( } let ident = &sig.ident; - let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); + let ident = match mode { + ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()), + _ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()), + }; let generics = sig.generics.clone().into_token_stream(); diff --git a/crates/burn-cube-macros/src/codegen_function/function.rs b/crates/burn-cube-macros/src/codegen_function/function.rs index 88542aec0..436ab4a6c 100644 --- a/crates/burn-cube-macros/src/codegen_function/function.rs +++ b/crates/burn-cube-macros/src/codegen_function/function.rs @@ -106,22 +106,42 @@ pub(crate) fn codegen_call( // Path let mut path_tokens = TokenStream::new(); let mut is_comptime = false; + let mut is_plain_func = true; let mut comptime_func: Option = None; for (i, (ident, generics)) in path.iter().enumerate() { - if *ident == "Comptime" { + let name = ident.to_string(); + + if name == "Comptime" { is_comptime = true; continue; } + + if let Some(first_char) = name.chars().next() { + if first_char.is_uppercase() { + is_plain_func = false; + } + } + if i == path.len() - 1 { if is_comptime { comptime_func = Some(ident.to_string()); break; } - let func_name_expand = syn::Ident::new( - format!("{ident}_expand").as_str(), - proc_macro2::Span::call_site(), - ); + + let func_name_expand = if is_plain_func { + quote::quote! { + #ident::__expand + } + } else { + let ident = syn::Ident::new( + format!("__expand_{ident}").as_str(), + proc_macro2::Span::call_site(), + ); + quote::quote! { + #ident + } + }; path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand }); if let Some(generics) = generics { path_tokens.extend(quote_spanned! {generics.span() => #generics }); diff --git a/crates/burn-cube-macros/src/codegen_function/launch.rs b/crates/burn-cube-macros/src/codegen_function/launch.rs index f87de3a3c..e37b52c45 100644 --- a/crates/burn-cube-macros/src/codegen_function/launch.rs +++ b/crates/burn-cube-macros/src/codegen_function/launch.rs @@ -211,7 +211,7 @@ impl Codegen { } } - fn gen_define_impl(&self, expand: &Ident) -> TokenStream { + fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream { let mut expand_args = quote::quote! { &mut builder.context, }; let mut variables = quote::quote! {}; @@ -340,7 +340,7 @@ impl Codegen { tokens } - fn gen_compile_impl(&self, expand: &Ident) -> TokenStream { + fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream { let ident = Ident::new(&self.name, Span::call_site()); let generics = add_runtime(self.generics.clone()); let (impl_gen, ty_gen, where_gen) = generics.split_for_impl(); @@ -453,22 +453,27 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { let codegen = Codegen::from_sig(sig); let ident = &sig.ident; - let ident_expand = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); - let ident = syn::Ident::new(format!("{ident}_launch").as_str(), ident.span()); + + let ident_expand = quote::quote! { + __expand + }; let generics = add_runtime(add_lifetime(sig.generics.clone())); let body = codegen.gen_launch_body(); let kernel = codegen.gen_kernel_struct(); let compile = codegen.gen_compile_impl(&ident_expand); let (inputs, output) = (codegen.fn_inputs, codegen.fn_output); + let doc = + format!("Launch the kernel [{ident}] with the provided argument on the given runtime."); quote::quote! { #kernel #compile #[allow(clippy::too_many_arguments)] - /// Launch - pub fn #ident #generics ( + #[doc = #doc] + /// Launch the kernel. + pub fn launch #generics ( client: ComputeClient, cube_count: CubeCount, cube_dim: CubeDim, diff --git a/crates/burn-cube-macros/src/codegen_trait/mod.rs b/crates/burn-cube-macros/src/codegen_trait/mod.rs index a6d18b4e5..4810bc1c6 100644 --- a/crates/burn-cube-macros/src/codegen_trait/mod.rs +++ b/crates/burn-cube-macros/src/codegen_trait/mod.rs @@ -1,4 +1,6 @@ -use crate::codegen_common::signature::expand_sig; +use proc_macro2::TokenStream; + +use crate::codegen_common::signature::{expand_sig, ExpandMode}; pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream { let mut expand_items = Vec::new(); @@ -6,7 +8,12 @@ pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream { for item in tr.items.iter() { match item { syn::TraitItem::Fn(func) => { - let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None); + let expand = expand_sig( + &func.sig, + &syn::Visibility::Inherited, + None, + ExpandMode::MethodImpl, + ); expand_items.push(syn::parse_quote!(#expand;)); } _ => continue, @@ -26,7 +33,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream { match item { syn::ImplItem::Fn(func) => { let ident = &func.sig.ident; - let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); + let ident = quote::quote! {#ident::__expand}; let mut inputs = quote::quote!(); for input in &func.sig.inputs { @@ -41,7 +48,12 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream { } } - let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None); + let expand = expand_sig( + &func.sig, + &syn::Visibility::Inherited, + None, + ExpandMode::MethodImpl, + ); let tokens = if !tr.generics.params.is_empty() { let mut func = func.clone(); @@ -67,7 +79,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream { fn register_expand( func: &syn::ImplItemFn, - name: &syn::Ident, + name: &TokenStream, expand: proc_macro2::TokenStream, inputs: proc_macro2::TokenStream, ) -> proc_macro2::TokenStream { @@ -91,7 +103,7 @@ fn register_expand( quote::quote! ( #expand { #[cube] - #func + pub #func #func_expand } ) diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 6ce49e318..616fdbbcc 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -10,7 +10,7 @@ mod tracker; pub(crate) mod codegen_common; use analyzer::VariableAnalyzer; -use codegen_common::signature::expand_sig; +use codegen_common::signature::{expand_sig, ExpandMode}; use codegen_function::{codegen_launch, codegen_statement}; use codegen_trait::{expand_trait_def, expand_trait_impl}; use codegen_type::generate_cube_type; @@ -69,20 +69,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { let mut variable_tracker = VariableAnalyzer::create_tracker(&func); - match codegen_cube(&func, &mut variable_tracker) { - Ok(code) => { - if attrs.launch { - let launch = codegen_launch(&func.sig); - - quote::quote! { - #code - #launch - } - .into() - } else { - code.into() - } - } + match codegen_cube(&func, &mut variable_tracker, attrs.launch) { + Ok(code) => code.into(), Err(err) => err.into(), } } @@ -120,8 +108,15 @@ fn parse_attributes(args: &Punctuated) -> SupportedAttributes { fn codegen_cube( func: &syn::ItemFn, variable_tracker: &mut VariableTracker, + launch: bool, ) -> Result { - let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker)); + let signature = expand_sig( + &func.sig, + &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import + // it from an outside module. + Some(variable_tracker), + ExpandMode::FuncImpl, + ); let mut body = quote::quote! {}; for statement in func.block.stmts.iter() { @@ -148,15 +143,36 @@ fn codegen_cube( return Err(code); } + let launch_doc = if launch { "and launch function " } else { "" }; + + let launch = if launch { + codegen_launch(&func.sig) + } else { + quote::quote! {} + }; + + let mod_name = &func.sig.ident; + let vis = &func.vis; + let doc = format!("Module containing the expand method {launch_doc}of {mod_name}."); + Ok(quote::quote! { #[allow(dead_code)] #[allow(clippy::too_many_arguments)] #func - #[allow(unused_mut)] - #[allow(clippy::too_many_arguments)] - #signature { - #body + + #[doc = #doc] + #vis mod #mod_name { + use super::*; + + #launch + + #[allow(unused_mut)] + #[allow(clippy::too_many_arguments)] + #signature { + #body + } + } }) } diff --git a/crates/burn-cube/src/frontend/cmma.rs b/crates/burn-cube/src/frontend/cmma.rs index 81f3a5c46..5490ea130 100644 --- a/crates/burn-cube/src/frontend/cmma.rs +++ b/crates/burn-cube/src/frontend/cmma.rs @@ -86,7 +86,7 @@ impl Init for MatrixExpand { impl Matrix { /// Create a new matrix that is going to be used in the - /// [matrix-multiply and accumulate](execute) function. + /// [matrix-multiply and accumulate](execute()) function. /// /// You have to declare the shape used for the execution. /// The shape of the current matrix is determined using the [MatrixIdent]. @@ -103,7 +103,7 @@ impl Matrix { Matrix { _c: PhantomData } } - pub fn new_expand( + pub fn __expand_new( context: &mut CubeContext, ident: MatrixIdent, m: u8, @@ -129,16 +129,21 @@ pub fn fill(mat: &Matrix, value: C) { unexpanded!() } -/// Expand method of [fill]. -pub fn fill_expand( - context: &mut CubeContext, - mat: MatrixExpand, - value: ExpandElement, -) { - context.register(Operation::CoopMma(ir::CoopMma::Fill { - mat: *mat.elem, - value: *value, - })); +/// Module containing the expand function for [fill()]. +pub mod fill { + use super::*; + + /// Expand method of [fill()]. + pub fn __expand( + context: &mut CubeContext, + mat: MatrixExpand, + value: ExpandElement, + ) { + context.register(Operation::CoopMma(ir::CoopMma::Fill { + mat: *mat.elem, + value: *value, + })); + } } /// Load the matrix with the provided array using the stride. @@ -147,19 +152,24 @@ pub fn load(mat: &Matrix, value: &Slice<'_, C>, stride: UInt) { unexpanded!() } -/// Expand method of [load]. -#[allow(unused_variables)] -pub fn load_expand( - context: &mut CubeContext, - mat: MatrixExpand, - value: ExpandElementTyped>, - stride: ExpandElement, -) { - context.register(Operation::CoopMma(ir::CoopMma::Load { - mat: *mat.elem, - value: *value.expand, - stride: *stride, - })); +/// Module containing the expand function for [load()]. +pub mod load { + use super::*; + + /// Expand method of [load()]. + #[allow(unused_variables)] + pub fn __expand( + context: &mut CubeContext, + mat: MatrixExpand, + value: ExpandElementTyped>, + stride: ExpandElement, + ) { + context.register(Operation::CoopMma(ir::CoopMma::Load { + mat: *mat.elem, + value: *value.expand, + stride: *stride, + })); + } } /// Store the matrix in the given array following the given stride and layout. @@ -173,21 +183,26 @@ pub fn store( unexpanded!() } -/// Expand method of [store]. -#[allow(unused_variables)] -pub fn store_expand( - context: &mut CubeContext, - output: ExpandElementTyped>, - mat: MatrixExpand, - stride: ExpandElement, - layout: MatrixLayout, -) { - context.register(Operation::CoopMma(ir::CoopMma::Store { - output: *output.expand, - mat: *mat.elem, - stride: *stride, - layout, - })); +/// Module containing the expand function for [store()]. +pub mod store { + use super::*; + + /// Expand method of [store()]. + #[allow(unused_variables)] + pub fn __expand( + context: &mut CubeContext, + output: ExpandElementTyped>, + mat: MatrixExpand, + stride: ExpandElement, + layout: MatrixLayout, + ) { + context.register(Operation::CoopMma(ir::CoopMma::Store { + output: *output.expand, + mat: *mat.elem, + stride: *stride, + layout, + })); + } } /// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix). @@ -201,18 +216,23 @@ pub fn execute( - context: &mut CubeContext, - mat_a: MatrixExpand, - mat_b: MatrixExpand, - mat_c: MatrixExpand, - mat_d: MatrixExpand, -) { - context.register(Operation::CoopMma(ir::CoopMma::Execute { - mat_a: *mat_a.elem, - mat_b: *mat_b.elem, - mat_c: *mat_c.elem, - mat_d: *mat_d.elem, - })); +/// Module containing the expand function for [execute()]. +pub mod execute { + use super::*; + + /// Expand method of [execute()]. + pub fn __expand( + context: &mut CubeContext, + mat_a: MatrixExpand, + mat_b: MatrixExpand, + mat_c: MatrixExpand, + mat_d: MatrixExpand, + ) { + context.register(Operation::CoopMma(ir::CoopMma::Execute { + mat_a: *mat_a.elem, + mat_b: *mat_b.elem, + mat_c: *mat_c.elem, + mat_d: *mat_d.elem, + })); + } } diff --git a/crates/burn-cube/src/frontend/element/array.rs b/crates/burn-cube/src/frontend/element/array.rs index aaca60a21..5831a8d47 100644 --- a/crates/burn-cube/src/frontend/element/array.rs +++ b/crates/burn-cube/src/frontend/element/array.rs @@ -30,7 +30,11 @@ impl Array { Array { _val: PhantomData } } - pub fn new_expand( + pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { + Array { _val: PhantomData } + } + + pub fn __expand_new( context: &mut CubeContext, size: S, ) -> ::ExpandType { @@ -44,11 +48,7 @@ impl Array { .into() } - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { - Array { _val: PhantomData } - } - - pub fn vectorized_expand( + pub fn __expand_vectorized( context: &mut CubeContext, size: S, vectorization_factor: UInt, diff --git a/crates/burn-cube/src/frontend/element/bool.rs b/crates/burn-cube/src/frontend/element/bool.rs index a36e40913..e5e2675b5 100644 --- a/crates/burn-cube/src/frontend/element/bool.rs +++ b/crates/burn-cube/src/frontend/element/bool.rs @@ -3,6 +3,10 @@ use crate::ir::Elem; use super::Vectorized; +// To be consistent with other primitive type. +/// Boolean type. +pub type Bool = bool; + impl CubeType for bool { type ExpandType = ExpandElement; } diff --git a/crates/burn-cube/src/frontend/element/cast.rs b/crates/burn-cube/src/frontend/element/cast.rs index fb411b040..4187b510a 100644 --- a/crates/burn-cube/src/frontend/element/cast.rs +++ b/crates/burn-cube/src/frontend/element/cast.rs @@ -6,7 +6,7 @@ use crate::{frontend::ExpandElement, unexpanded}; pub trait Cast: CubePrimitive { fn cast_from(value: From) -> Self; - fn cast_from_expand( + fn __expand_cast_from( context: &mut CubeContext, value: From, ) -> ::ExpandType diff --git a/crates/burn-cube/src/frontend/element/float.rs b/crates/burn-cube/src/frontend/element/float.rs index a73023665..c8dec62ab 100644 --- a/crates/burn-cube/src/frontend/element/float.rs +++ b/crates/burn-cube/src/frontend/element/float.rs @@ -29,15 +29,15 @@ pub trait Float: + core::ops::IndexMut { fn new(val: f32) -> Self; - fn new_expand(context: &mut CubeContext, val: f32) -> ::ExpandType; fn vectorized(val: f32, vectorization: UInt) -> Self; - fn vectorized_expand( + fn vectorized_empty(vectorization: UInt) -> Self; + fn __expand_new(context: &mut CubeContext, val: f32) -> ::ExpandType; + fn __expand_vectorized( context: &mut CubeContext, val: f32, vectorization: UInt, ) -> ::ExpandType; - fn vectorized_empty(vectorization: UInt) -> Self; - fn vectorized_empty_expand( + fn __expand_vectorized_empty( context: &mut CubeContext, vectorization: UInt, ) -> ::ExpandType; @@ -74,14 +74,6 @@ macro_rules! impl_float { } } - fn new_expand(_context: &mut CubeContext, val: f32) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - fn vectorized(val: f32, vectorization: UInt) -> Self { if vectorization.val == 1 { Self::new(val) @@ -93,13 +85,28 @@ macro_rules! impl_float { } } - fn vectorized_expand( + fn vectorized_empty(vectorization: UInt) -> Self { + Self::vectorized(0., vectorization) + } + + fn __expand_new( + _context: &mut CubeContext, + val: f32, + ) -> ::ExpandType { + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; + ExpandElement::Plain(new_var) + } + + fn __expand_vectorized( context: &mut CubeContext, val: f32, vectorization: UInt, ) -> ::ExpandType { if vectorization.val == 1 { - Self::new_expand(context, val) + Self::__expand_new(context, val) } else { let mut new_var = context .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); @@ -111,16 +118,12 @@ macro_rules! impl_float { } } - fn vectorized_empty(vectorization: UInt) -> Self { - Self::vectorized(0., vectorization) - } - - fn vectorized_empty_expand( + fn __expand_vectorized_empty( context: &mut CubeContext, vectorization: UInt, ) -> ::ExpandType { if vectorization.val == 1 { - Self::new_expand(context, 0.) + Self::__expand_new(context, 0.) } else { context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)) } diff --git a/crates/burn-cube/src/frontend/element/int.rs b/crates/burn-cube/src/frontend/element/int.rs index cac606f37..d5a92f73c 100644 --- a/crates/burn-cube/src/frontend/element/int.rs +++ b/crates/burn-cube/src/frontend/element/int.rs @@ -9,9 +9,9 @@ use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized}; /// Signed integer. Used as input in int kernels pub trait Int: Numeric + std::ops::Rem { fn new(val: i64) -> Self; - fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; fn vectorized(val: i64, vectorization: UInt) -> Self; - fn vectorized_expand( + fn __expand_new(context: &mut CubeContext, val: i64) -> ::ExpandType; + fn __expand_vectorized( context: &mut CubeContext, val: i64, vectorization: UInt, @@ -48,14 +48,6 @@ macro_rules! impl_int { } } - fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - fn vectorized(val: i64, vectorization: UInt) -> Self { if vectorization.val == 1 { Self::new(val) @@ -67,13 +59,24 @@ macro_rules! impl_int { } } - fn vectorized_expand( + fn __expand_new( + _context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; + ExpandElement::Plain(new_var) + } + + fn __expand_vectorized( context: &mut CubeContext, val: i64, vectorization: UInt, ) -> ::ExpandType { if vectorization.val == 1 { - Self::new_expand(context, val) + Self::__expand_new(context, val) } else { let mut new_var = context .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); diff --git a/crates/burn-cube/src/frontend/element/mod.rs b/crates/burn-cube/src/frontend/element/mod.rs index b2938035f..d67fb8647 100644 --- a/crates/burn-cube/src/frontend/element/mod.rs +++ b/crates/burn-cube/src/frontend/element/mod.rs @@ -14,6 +14,7 @@ mod vectorized; pub use array::*; pub use base::*; +pub use bool::*; pub use cast::*; pub use cube_elem::*; pub use float::*; diff --git a/crates/burn-cube/src/frontend/element/numeric.rs b/crates/burn-cube/src/frontend/element/numeric.rs index edc8daf9b..3c92a2eeb 100644 --- a/crates/burn-cube/src/frontend/element/numeric.rs +++ b/crates/burn-cube/src/frontend/element/numeric.rs @@ -45,8 +45,11 @@ pub trait Numeric: type Primitive: ScalarArgSettings; - /// Expand version of from_int - fn from_int_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { + fn from_vec(_vec: [i64; D]) -> Self { + unexpanded!() + } + + fn __expand_from_int(_context: &mut CubeContext, val: i64) -> ::ExpandType { let new_var = Variable::ConstantScalar { value: val as f64, elem: Self::as_elem(), @@ -54,11 +57,7 @@ pub trait Numeric: ExpandElement::Plain(new_var) } - fn from_vec(_vec: [i64; D]) -> Self { - unexpanded!() - } - - fn from_vec_expand( + fn __expand_from_vec( context: &mut CubeContext, vec: [i64; D], ) -> ::ExpandType { diff --git a/crates/burn-cube/src/frontend/element/shared_memory.rs b/crates/burn-cube/src/frontend/element/shared_memory.rs index fe06ca3cd..3ad49c330 100644 --- a/crates/burn-cube/src/frontend/element/shared_memory.rs +++ b/crates/burn-cube/src/frontend/element/shared_memory.rs @@ -27,24 +27,11 @@ impl SharedMemory { SharedMemory { _val: PhantomData } } - pub fn new_expand( - context: &mut CubeContext, - size: S, - ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar { value, .. } => value as u32, - _ => panic!("Shared memory need constant initialization value"), - }; - let var = context.create_shared(Item::new(T::as_elem()), size); - ExpandElementTyped::new(var) - } - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { SharedMemory { _val: PhantomData } } - pub fn vectorized_expand( + pub fn __expand_vectorized( context: &mut CubeContext, size: S, vectorization_factor: UInt, @@ -60,4 +47,17 @@ impl SharedMemory { ); ExpandElementTyped::new(var) } + + pub fn __expand_new( + context: &mut CubeContext, + size: S, + ) -> ::ExpandType { + let size = size.value(); + let size = match size { + crate::ir::Variable::ConstantScalar { value, .. } => value as u32, + _ => panic!("Shared memory need constant initialization value"), + }; + let var = context.create_shared(Item::new(T::as_elem()), size); + ExpandElementTyped::new(var) + } } diff --git a/crates/burn-cube/src/frontend/element/uint.rs b/crates/burn-cube/src/frontend/element/uint.rs index 0e1e2ec7c..4df2516b6 100644 --- a/crates/burn-cube/src/frontend/element/uint.rs +++ b/crates/burn-cube/src/frontend/element/uint.rs @@ -48,7 +48,7 @@ impl UInt { } } - pub fn new_expand(_context: &mut CubeContext, val: u32) -> ::ExpandType { + pub fn __expand_new(_context: &mut CubeContext, val: u32) -> ::ExpandType { let new_var = Variable::ConstantScalar { value: val as f64, elem: Self::as_elem(), @@ -67,13 +67,13 @@ impl UInt { } } - pub fn vectorized_expand( + pub fn __expand_vectorized( context: &mut CubeContext, val: u32, vectorization: UInt, ) -> ::ExpandType { if vectorization.val == 1 { - Self::new_expand(context, val) + Self::__expand_new(context, val) } else { let mut new_var = context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); diff --git a/crates/burn-cube/src/frontend/operation/binary.rs b/crates/burn-cube/src/frontend/operation/binary.rs index 136ac60a2..08cf56244 100644 --- a/crates/burn-cube/src/frontend/operation/binary.rs +++ b/crates/burn-cube/src/frontend/operation/binary.rs @@ -280,11 +280,20 @@ macro_rules! impl_binary_func { } } -impl_binary_func!(Powf, powf, powf_expand, Operator::Powf, F16, BF16, F32, F64); +impl_binary_func!( + Powf, + powf, + __expand_powf, + Operator::Powf, + F16, + BF16, + F32, + F64 +); impl_binary_func!( Max, max, - max_expand, + __expand_max, Operator::Max, F16, BF16, @@ -297,7 +306,7 @@ impl_binary_func!( impl_binary_func!( Min, min, - min_expand, + __expand_min, Operator::Min, F16, BF16, @@ -310,7 +319,7 @@ impl_binary_func!( impl_binary_func!( Remainder, rem, - rem_expand, + __expand_rem, Operator::Remainder, F16, BF16, diff --git a/crates/burn-cube/src/frontend/operation/clamp.rs b/crates/burn-cube/src/frontend/operation/clamp.rs index c1d56ebfb..3765c5aa9 100644 --- a/crates/burn-cube/src/frontend/operation/clamp.rs +++ b/crates/burn-cube/src/frontend/operation/clamp.rs @@ -12,7 +12,7 @@ pub trait Clamp: CubePrimitive + Sized { fn clamp(input: Self, min_value: Self, max_value: Self) -> Self { unexpanded!() } - fn clamp_expand( + fn __expand_clamp( context: &mut CubeContext, input: Self::ExpandType, min_value: Self::ExpandType, diff --git a/crates/burn-cube/src/frontend/operation/unary.rs b/crates/burn-cube/src/frontend/operation/unary.rs index 8b36f1518..ec780652f 100644 --- a/crates/burn-cube/src/frontend/operation/unary.rs +++ b/crates/burn-cube/src/frontend/operation/unary.rs @@ -33,7 +33,7 @@ macro_rules! impl_unary_func { impl_unary_func!( Abs, abs, - abs_expand, + __expand_abs, Operator::Abs, F16, BF16, @@ -43,38 +43,65 @@ impl_unary_func!( I64, UInt ); -impl_unary_func!(Exp, exp, exp_expand, Operator::Exp, F16, BF16, F32, F64); -impl_unary_func!(Log, log, log_expand, Operator::Log, F16, BF16, F32, F64); +impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64); +impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64); impl_unary_func!( Log1p, log1p, - log1p_expand, + __expand_log1p, Operator::Log1p, F16, BF16, F32, F64 ); -impl_unary_func!(Cos, cos, cos_expand, Operator::Cos, F16, BF16, F32, F64); -impl_unary_func!(Sin, sin, sin_expand, Operator::Sin, F16, BF16, F32, F64); -impl_unary_func!(Tanh, tanh, tanh_expand, Operator::Tanh, F16, BF16, F32, F64); -impl_unary_func!(Sqrt, sqrt, sqrt_expand, Operator::Sqrt, F16, BF16, F32, F64); +impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64); +impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64); +impl_unary_func!( + Tanh, + tanh, + __expand_tanh, + Operator::Tanh, + F16, + BF16, + F32, + F64 +); +impl_unary_func!( + Sqrt, + sqrt, + __expand_sqrt, + Operator::Sqrt, + F16, + BF16, + F32, + F64 +); impl_unary_func!( Floor, floor, - floor_expand, + __expand_floor, Operator::Floor, F16, BF16, F32, F64 ); -impl_unary_func!(Ceil, ceil, ceil_expand, Operator::Ceil, F16, BF16, F32, F64); -impl_unary_func!(Erf, erf, erf_expand, Operator::Erf, F16, BF16, F32, F64); +impl_unary_func!( + Ceil, + ceil, + __expand_ceil, + Operator::Ceil, + F16, + BF16, + F32, + F64 +); +impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64); impl_unary_func!( Recip, recip, - recip_expand, + __expand_recip, Operator::Recip, F16, BF16, diff --git a/crates/burn-cube/src/frontend/subcube.rs b/crates/burn-cube/src/frontend/subcube.rs index c285ba1d9..3f50154ab 100644 --- a/crates/burn-cube/src/frontend/subcube.rs +++ b/crates/burn-cube/src/frontend/subcube.rs @@ -19,107 +19,143 @@ pub fn subcube_elect_expand(context: &mut CubeContext) -> Expa output } -pub fn subcube_sum(_elem: E) -> E { +/// Perform a reduce sum operation across all units in a subcube. +#[allow(unused_variables)] +pub fn subcube_sum(value: E) -> E { unexpanded!() } -pub fn subcube_sum_expand( - context: &mut CubeContext, - elem: ExpandElement, -) -> ExpandElement { - let output = context.create_local(elem.item()); +/// Module containing the expand function for [subcube_sum()]. +pub mod subcube_sum { + use super::*; - let out = *output; - let input = *elem; + /// Expand method of [subcube_sum()]. + pub fn __expand( + context: &mut CubeContext, + elem: ExpandElement, + ) -> ExpandElement { + let output = context.create_local(elem.item()); - context.register(Operation::Subcube(Subcube::Sum(UnaryOperator { - input, - out, - }))); + let out = *output; + let input = *elem; - output + context.register(Operation::Subcube(Subcube::Sum(UnaryOperator { + input, + out, + }))); + + output + } } +/// Perform a reduce prod operation across all units in a subcube. pub fn subcube_prod(_elem: E) -> E { unexpanded!() } -pub fn subcube_prod_expand( - context: &mut CubeContext, - elem: ExpandElement, -) -> ExpandElement { - let output = context.create_local(elem.item()); +/// Module containing the expand function for [subcube_prod()]. +pub mod subcube_prod { + use super::*; - let out = *output; - let input = *elem; + /// Expand method of [subcube_prod()]. + pub fn __expand( + context: &mut CubeContext, + elem: ExpandElement, + ) -> ExpandElement { + let output = context.create_local(elem.item()); - context.register(Operation::Subcube(Subcube::Prod(UnaryOperator { - input, - out, - }))); + let out = *output; + let input = *elem; - output + context.register(Operation::Subcube(Subcube::Prod(UnaryOperator { + input, + out, + }))); + + output + } } +/// Perform a reduce max operation across all units in a subcube. pub fn subcube_max(_elem: E) -> E { unexpanded!() } -pub fn subcube_max_expand( - context: &mut CubeContext, - elem: ExpandElement, -) -> ExpandElement { - let output = context.create_local(elem.item()); +/// Module containing the expand function for [subcube_max()]. +pub mod subcube_max { + use super::*; - let out = *output; - let input = *elem; + /// Expand method of [subcube_max()]. + pub fn __expand( + context: &mut CubeContext, + elem: ExpandElement, + ) -> ExpandElement { + let output = context.create_local(elem.item()); - context.register(Operation::Subcube(Subcube::Max(UnaryOperator { - input, - out, - }))); + let out = *output; + let input = *elem; - output + context.register(Operation::Subcube(Subcube::Max(UnaryOperator { + input, + out, + }))); + + output + } } +/// Perform a reduce min operation across all units in a subcube. pub fn subcube_min(_elem: E) -> E { unexpanded!() } -pub fn subcube_min_expand( - context: &mut CubeContext, - elem: ExpandElement, -) -> ExpandElement { - let output = context.create_local(elem.item()); +/// Module containing the expand function for [subcube_min()]. +pub mod subcube_min { + use super::*; - let out = *output; - let input = *elem; + /// Expand method of [subcube_min()]. + pub fn __expand( + context: &mut CubeContext, + elem: ExpandElement, + ) -> ExpandElement { + let output = context.create_local(elem.item()); - context.register(Operation::Subcube(Subcube::Min(UnaryOperator { - input, - out, - }))); + let out = *output; + let input = *elem; - output + context.register(Operation::Subcube(Subcube::Min(UnaryOperator { + input, + out, + }))); + + output + } } +/// Perform a reduce all operation across all units in a subcube. pub fn subcube_all(_elem: E) -> E { unexpanded!() } -pub fn subcube_all_expand( - context: &mut CubeContext, - elem: ExpandElement, -) -> ExpandElement { - let output = context.create_local(elem.item()); +/// Module containing the expand function for [subcube_all()]. +pub mod subcube_all { + use super::*; - let out = *output; - let input = *elem; + /// Expand method of [subcube_all()]. + pub fn __expand( + context: &mut CubeContext, + elem: ExpandElement, + ) -> ExpandElement { + let output = context.create_local(elem.item()); - context.register(Operation::Subcube(Subcube::All(UnaryOperator { - input, - out, - }))); + let out = *output; + let input = *elem; - output + context.register(Operation::Subcube(Subcube::All(UnaryOperator { + input, + out, + }))); + + output + } } diff --git a/crates/burn-cube/src/frontend/synchronization.rs b/crates/burn-cube/src/frontend/synchronization.rs index 53027e91b..a47967e4a 100644 --- a/crates/burn-cube/src/frontend/synchronization.rs +++ b/crates/burn-cube/src/frontend/synchronization.rs @@ -3,6 +3,10 @@ use crate::ir::Synchronization; pub fn sync_units() {} -pub fn sync_units_expand(context: &mut CubeContext) { - context.register(Synchronization::SyncUnits) +pub mod sync_units { + use super::*; + + pub fn __expand(context: &mut CubeContext) { + context.register(Synchronization::SyncUnits) + } } diff --git a/crates/burn-cube/src/prelude.rs b/crates/burn-cube/src/prelude.rs index 00be35c1e..027a7f866 100644 --- a/crates/burn-cube/src/prelude.rs +++ b/crates/burn-cube/src/prelude.rs @@ -11,7 +11,8 @@ pub use crate::runtime::Runtime; /// Elements pub use crate::frontend::{ - Array, ArrayHandle, Float, LaunchArg, Slice, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64, + Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32, + F64, I32, I64, }; pub use crate::pod::CubeElement; @@ -23,10 +24,7 @@ pub use crate::frontend::{ }; /// Export subcube operations. -pub use crate::frontend::{ - subcube_all, subcube_all_expand, subcube_max, subcube_max_expand, subcube_min, - subcube_min_expand, subcube_prod, subcube_prod_expand, subcube_sum, subcube_sum_expand, -}; +pub use crate::frontend::{subcube_all, subcube_max, subcube_min, subcube_prod, subcube_sum}; pub use burn_compute::client::ComputeClient; pub use crate::frontend::*; diff --git a/crates/burn-cube/src/runtime_tests/cmma.rs b/crates/burn-cube/src/runtime_tests/cmma.rs index c6c253bc9..6c0668cbb 100644 --- a/crates/burn-cube/src/runtime_tests/cmma.rs +++ b/crates/burn-cube/src/runtime_tests/cmma.rs @@ -64,7 +64,7 @@ pub fn test_simple_1(client: ComputeClient) { let rhs = client.create(f16::as_bytes(&rhs)); let out = client.empty(core::mem::size_of::() * 256); - kernel_simple_1_launch::( + kernel_simple_1::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::new(16, 16, 1), diff --git a/crates/burn-cube/src/runtime_tests/launch.rs b/crates/burn-cube/src/runtime_tests/launch.rs index 158ca490b..fe0646d97 100644 --- a/crates/burn-cube/src/runtime_tests/launch.rs +++ b/crates/burn-cube/src/runtime_tests/launch.rs @@ -18,7 +18,7 @@ pub fn kernel_without_generics(output: &mut Array) { pub fn test_kernel_with_generics(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[0.0, 1.0])); - kernel_with_generics_launch::( + kernel_with_generics::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::default(), @@ -34,7 +34,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[0.0, 1.0])); - kernel_without_generics_launch::( + kernel_without_generics::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::default(), diff --git a/crates/burn-cube/src/runtime_tests/slice.rs b/crates/burn-cube/src/runtime_tests/slice.rs index a8f4ab96c..c46c07fe3 100644 --- a/crates/burn-cube/src/runtime_tests/slice.rs +++ b/crates/burn-cube/src/runtime_tests/slice.rs @@ -30,7 +30,7 @@ pub fn test_slice_select(client: ComputeClient()); - slice_select_launch::( + slice_select::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), @@ -48,7 +48,7 @@ pub fn test_slice_len(client: ComputeClient) let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); let output = client.empty(core::mem::size_of::()); - slice_len_launch::( + slice_len::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), @@ -66,7 +66,7 @@ pub fn test_slice_assign(client: ComputeClient( + slice_assign::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), diff --git a/crates/burn-cube/src/runtime_tests/subcube.rs b/crates/burn-cube/src/runtime_tests/subcube.rs index 6cfe6f5db..c3c9cd53c 100644 --- a/crates/burn-cube/src/runtime_tests/subcube.rs +++ b/crates/burn-cube/src/runtime_tests/subcube.rs @@ -49,7 +49,7 @@ pub fn test_subcube_sum( &[17.0, 5.0, 7.0, 1.0], client.clone(), |cube_count, cube_dim, handle| { - kernel_sum_launch::(client.clone(), cube_count, cube_dim, handle) + kernel_sum::launch::(client.clone(), cube_count, cube_dim, handle) }, ); } @@ -62,7 +62,7 @@ pub fn test_subcube_prod( &[140.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_prod_launch::(client.clone(), cube_dim, settings, handle) + kernel_prod::launch::(client.clone(), cube_dim, settings, handle) }, ); } @@ -74,7 +74,7 @@ pub fn test_subcube_max( &[7.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_max_launch::(client.clone(), cube_dim, settings, handle) + kernel_max::launch::(client.clone(), cube_dim, settings, handle) }, ); } @@ -87,7 +87,7 @@ pub fn test_subcube_min( &[1.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_min_launch::(client.clone(), cube_dim, settings, handle) + kernel_min::launch::(client.clone(), cube_dim, settings, handle) }, ); } diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs index 8f6deb731..502ee2b24 100644 --- a/crates/burn-cube/tests/frontend/array.rs +++ b/crates/burn-cube/tests/frontend/array.rs @@ -1,14 +1,14 @@ use burn_cube::prelude::*; #[cube] -fn array_read_write(array_size: Comptime) { +pub fn array_read_write(array_size: Comptime) { let mut array = Array::::new(array_size); array[0] = T::from_int(3); let _ = array[0]; } #[cube] -fn array_to_vectorized_variable() -> T { +pub fn array_to_vectorized_variable() -> T { let mut array = Array::::new(2); array[0] = T::from_int(0); array[1] = T::from_int(1); @@ -16,19 +16,19 @@ fn array_to_vectorized_variable() -> T { } #[cube] -fn array_of_one_to_vectorized_variable() -> T { +pub fn array_of_one_to_vectorized_variable() -> T { let mut array = Array::::new(1); array[0] = T::from_int(3); array.to_vectorized(Comptime::new(UInt::new(1))) } #[cube] -fn array_add_assign_simple(array: &mut Array) { +pub fn array_add_assign_simple(array: &mut Array) { array[UInt::new(1)] += UInt::new(1); } #[cube] -fn array_add_assign_expr(array: &mut Array) { +pub fn array_add_assign_expr(array: &mut Array) { array[UInt::new(1) + UInt::new(5)] += UInt::new(1); } @@ -45,7 +45,7 @@ mod tests { fn cube_support_array() { let mut context = CubeContext::root(); - array_read_write_expand::(&mut context, 512); + array_read_write::__expand::(&mut context, 512); assert_eq!( context.into_scope().operations, inline_macro_ref_read_write() @@ -57,7 +57,7 @@ mod tests { let mut context = CubeContext::root(); let array = context.input(0, Item::new(Elem::UInt)); - array_add_assign_simple_expand(&mut context, array.into()); + array_add_assign_simple::__expand(&mut context, array.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_array_add_assign_simple()); @@ -67,7 +67,7 @@ mod tests { fn cube_array_to_vectorized() { let mut context = CubeContext::root(); - array_to_vectorized_variable_expand::(&mut context); + array_to_vectorized_variable::__expand::(&mut context); assert_eq!( context.into_scope().operations, inline_macro_ref_to_vectorized() @@ -78,7 +78,7 @@ mod tests { fn cube_array_of_one_to_vectorized() { let mut context = CubeContext::root(); - array_of_one_to_vectorized_variable_expand::(&mut context); + array_of_one_to_vectorized_variable::__expand::(&mut context); assert_eq!( context.into_scope().operations, inline_macro_ref_one_to_vectorized() @@ -110,7 +110,7 @@ mod tests { let mut context = CubeContext::root(); let array = context.input(0, Item::new(Elem::UInt)); - array_add_assign_expr_expand(&mut context, array.into()); + array_add_assign_expr::__expand(&mut context, array.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_array_add_assign_expr()); diff --git a/crates/burn-cube/tests/frontend/assign.rs b/crates/burn-cube/tests/frontend/assign.rs index 988ea65cd..b720a3b87 100644 --- a/crates/burn-cube/tests/frontend/assign.rs +++ b/crates/burn-cube/tests/frontend/assign.rs @@ -1,27 +1,27 @@ use burn_cube::prelude::*; #[cube] -fn mut_assign() { +pub fn mut_assign() { let mut x = UInt::new(0); x += UInt::new(1); } #[cube] -fn mut_assign_input(y: UInt) -> UInt { +pub fn mut_assign_input(y: UInt) -> UInt { let mut x = y; x += UInt::new(1); y + UInt::new(2) } #[cube] -fn assign_mut_input(mut y: UInt) -> UInt { +pub fn assign_mut_input(mut y: UInt) -> UInt { let x = y; y += UInt::new(1); x + UInt::new(2) } #[cube] -fn assign_vectorized(y: UInt) -> UInt { +pub fn assign_vectorized(y: UInt) -> UInt { let vectorization_factor = Comptime::vectorization(&y); let x = UInt::vectorized(1, Comptime::get(vectorization_factor)); x + y @@ -38,7 +38,7 @@ mod tests { fn cube_mut_assign_test() { let mut context = CubeContext::root(); - mut_assign_expand(&mut context); + mut_assign::__expand(&mut context); let scope = context.into_scope(); assert_eq!( @@ -53,7 +53,7 @@ mod tests { let y = context.create_local(Item::new(UInt::as_elem())); - mut_assign_input_expand(&mut context, y); + mut_assign_input::__expand(&mut context, y); let scope = context.into_scope(); assert_eq!( @@ -68,7 +68,7 @@ mod tests { let y = context.create_local(Item::new(UInt::as_elem())); - assign_mut_input_expand(&mut context, y); + assign_mut_input::__expand(&mut context, y); let scope = context.into_scope(); assert_eq!( @@ -83,7 +83,7 @@ mod tests { let y = context.create_local(Item::vectorized(UInt::as_elem(), 4)); - assign_vectorized_expand(&mut context, y); + assign_vectorized::__expand(&mut context, y); let scope = context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/cast_elem.rs b/crates/burn-cube/tests/frontend/cast_elem.rs index 8e92630a3..8ceaf164f 100644 --- a/crates/burn-cube/tests/frontend/cast_elem.rs +++ b/crates/burn-cube/tests/frontend/cast_elem.rs @@ -1,6 +1,6 @@ use burn_cube::{ cube, - frontend::{Cast, Numeric, UInt, F32, I32}, + frontend::{Bool, Cast, Numeric, UInt, F32, I32}, }; // From float @@ -26,7 +26,7 @@ pub fn float_to_uint(x: F32) { #[allow(clippy::overly_complex_bool_expr)] pub fn float_to_bool(x: F32) { let y = x + F32::from_int(2); - let _ = bool::cast_from(y) || true; + let _ = Bool::cast_from(y) || true; } // From int @@ -53,7 +53,7 @@ pub fn int_to_uint(x: I32) { #[allow(clippy::overly_complex_bool_expr)] pub fn int_to_bool(x: I32) { let y = x + I32::from_int(2); - let _ = bool::cast_from(y) || true; + let _ = Bool::cast_from(y) || true; } // // From uint @@ -80,27 +80,27 @@ pub fn uint_to_uint(x: UInt) { #[allow(clippy::overly_complex_bool_expr)] pub fn uint_to_bool(x: UInt) { let y = x + UInt::from_int(2); - let _ = bool::cast_from(y) || true; + let _ = Bool::cast_from(y) || true; } // From bool #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_float(x: bool) { +pub fn bool_to_float(x: Bool) { let y = x && false; let _ = F32::cast_from(y) + F32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_int(x: bool) { +pub fn bool_to_int(x: Bool) { let y = x && false; let _ = I32::cast_from(y) + I32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_uint(x: bool) { +pub fn bool_to_uint(x: Bool) { let y = x && false; let _ = UInt::cast_from(y) + UInt::from_int(34); } @@ -108,9 +108,9 @@ pub fn bool_to_uint(x: bool) { #[cube] #[allow(clippy::overly_complex_bool_expr)] #[allow(clippy::useless_conversion)] -pub fn bool_to_bool(x: bool) { +pub fn bool_to_bool(x: Bool) { let y = x && false; - let _ = bool::cast_from(y) || true; + let _ = Bool::cast_from(y) || true; } mod tests { @@ -122,7 +122,7 @@ mod tests { }; macro_rules! cast_test { - ($name:ident, $module:ident, $from:expr, $to:expr) => { + ($name:ident, $module:expr, $from:expr, $to:expr) => { #[test] fn $name() { let mut context = CubeContext::root(); @@ -142,112 +142,112 @@ mod tests { cast_test!( cube_float_to_float_test, - float_to_float_expand, + float_to_float::__expand, Item::new(F32::as_elem()), Item::new(F32::as_elem()) ); cast_test!( cube_float_to_int_test, - float_to_int_expand, + float_to_int::__expand, Item::new(F32::as_elem()), Item::new(I32::as_elem()) ); cast_test!( cube_float_to_uint_test, - float_to_uint_expand, + float_to_uint::__expand, Item::new(F32::as_elem()), Item::new(Elem::UInt) ); cast_test!( cube_float_to_bool_test, - float_to_bool_expand, + float_to_bool::__expand, Item::new(F32::as_elem()), Item::new(Elem::Bool) ); cast_test!( cube_int_to_float_test, - int_to_float_expand, + int_to_float::__expand, Item::new(I32::as_elem()), Item::new(F32::as_elem()) ); cast_test!( cube_int_to_int_test, - int_to_int_expand, + int_to_int::__expand, Item::new(I32::as_elem()), Item::new(I32::as_elem()) ); cast_test!( cube_int_to_uint_test, - int_to_uint_expand, + int_to_uint::__expand, Item::new(I32::as_elem()), Item::new(Elem::UInt) ); cast_test!( cube_int_to_bool_test, - int_to_bool_expand, + int_to_bool::__expand, Item::new(I32::as_elem()), Item::new(Elem::Bool) ); cast_test!( cube_uint_to_float_test, - uint_to_float_expand, + uint_to_float::__expand, Item::new(Elem::UInt), Item::new(F32::as_elem()) ); cast_test!( cube_uint_to_int_test, - uint_to_int_expand, + uint_to_int::__expand, Item::new(Elem::UInt), Item::new(I32::as_elem()) ); cast_test!( cube_uint_to_uint_test, - uint_to_uint_expand, + uint_to_uint::__expand, Item::new(Elem::UInt), Item::new(Elem::UInt) ); cast_test!( cube_uint_to_bool_test, - uint_to_bool_expand, + uint_to_bool::__expand, Item::new(Elem::UInt), Item::new(Elem::Bool) ); cast_test!( cube_bool_to_float_test, - bool_to_float_expand, + bool_to_float::__expand, Item::new(Elem::Bool), Item::new(F32::as_elem()) ); cast_test!( cube_bool_to_int_test, - bool_to_int_expand, + bool_to_int::__expand, Item::new(Elem::Bool), Item::new(I32::as_elem()) ); cast_test!( cube_bool_to_uint_test, - bool_to_uint_expand, + bool_to_uint::__expand, Item::new(Elem::Bool), Item::new(Elem::UInt) ); cast_test!( cube_bool_to_bool_test, - bool_to_bool_expand, + bool_to_bool::__expand, Item::new(Elem::Bool), Item::new(Elem::Bool) ); diff --git a/crates/burn-cube/tests/frontend/cast_kind.rs b/crates/burn-cube/tests/frontend/cast_kind.rs index 48743b399..26e54b6f6 100644 --- a/crates/burn-cube/tests/frontend/cast_kind.rs +++ b/crates/burn-cube/tests/frontend/cast_kind.rs @@ -46,7 +46,7 @@ mod tests { let input = context.create_local(item); - cast_float_kind_expand::(&mut context, input); + cast_float_kind::__expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); @@ -59,7 +59,7 @@ mod tests { let input = context.create_local(item); - cast_int_kind_expand::(&mut context, input); + cast_int_kind::__expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -72,7 +72,7 @@ mod tests { let input = context.create_local(item); - cast_numeric_to_kind_expand::(&mut context, input); + cast_numeric_to_kind::__expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -85,7 +85,7 @@ mod tests { let input = context.create_local(item); - cast_int_to_numeric_expand::(&mut context, input); + cast_int_to_numeric::__expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); diff --git a/crates/burn-cube/tests/frontend/comptime.rs b/crates/burn-cube/tests/frontend/comptime.rs index e0790a75e..3e4be79bb 100644 --- a/crates/burn-cube/tests/frontend/comptime.rs +++ b/crates/burn-cube/tests/frontend/comptime.rs @@ -122,7 +122,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_else_expand::(&mut context, lhs, true); + comptime_if_else::__expand::(&mut context, lhs, true); let scope = context.into_scope(); assert_eq!( @@ -137,7 +137,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_expr_expand::(&mut context, lhs, UInt::new(4), UInt::new(5)); + comptime_if_expr::__expand::(&mut context, lhs, UInt::new(4), UInt::new(5)); let scope = context.into_scope(); assert_eq!( @@ -152,7 +152,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_else_expand::(&mut context, lhs, false); + comptime_if_else::__expand::(&mut context, lhs, false); let scope = context.into_scope(); assert_eq!( @@ -167,12 +167,12 @@ mod tests { for cond2 in [false, true] { let mut context1 = CubeContext::root(); let lhs = context1.create_local(Item::new(ElemType::as_elem())); - comptime_else_then_if_expand::(&mut context1, lhs, cond1, cond2); + comptime_else_then_if::__expand::(&mut context1, lhs, cond1, cond2); let scope1 = context1.into_scope(); let mut context2 = CubeContext::root(); let lhs = context2.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_expand::(&mut context2, lhs, cond1, cond2); + comptime_elsif::__expand::(&mut context2, lhs, cond1, cond2); let scope2 = context2.into_scope(); assert_eq!( @@ -188,7 +188,7 @@ mod tests { for cond in [false, true] { let mut context = CubeContext::root(); let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime1_expand::(&mut context, lhs, cond); + comptime_elsif_with_runtime1::__expand::(&mut context, lhs, cond); let scope = context.into_scope(); assert_eq!( @@ -203,7 +203,7 @@ mod tests { for cond in [false, true] { let mut context = CubeContext::root(); let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime2_expand::(&mut context, lhs, cond); + comptime_elsif_with_runtime2::__expand::(&mut context, lhs, cond); let scope = context.into_scope(); assert_eq!( @@ -227,8 +227,8 @@ mod tests { bound: 4, }; - comptime_with_map_bool_expand::(&mut context1, comptime_state_true); - comptime_with_map_bool_expand::(&mut context2, comptime_state_false); + comptime_with_map_bool::__expand::(&mut context1, comptime_state_true); + comptime_with_map_bool::__expand::(&mut context2, comptime_state_false); let scope1 = context1.into_scope(); let scope2 = context2.into_scope(); @@ -248,7 +248,7 @@ mod tests { bound: 4, }; - comptime_with_map_uint_expand::(&mut context, comptime_state); + comptime_with_map_uint::__expand::(&mut context, comptime_state); let scope = context.into_scope(); diff --git a/crates/burn-cube/tests/frontend/cube_trait.rs b/crates/burn-cube/tests/frontend/cube_trait.rs index 018ac5d5c..d74814e16 100644 --- a/crates/burn-cube/tests/frontend/cube_trait.rs +++ b/crates/burn-cube/tests/frontend/cube_trait.rs @@ -42,12 +42,12 @@ impl CombinedTraitFunctionGeneric for Test { } #[cube] -fn simple(lhs: C, rhs: C) -> C { +pub fn simple(lhs: C, rhs: C) -> C { lhs + rhs } #[cube] -fn with_cast(lhs: C, rhs: C) -> O { +pub fn with_cast(lhs: C, rhs: C) -> O { O::cast_from(lhs + rhs) } @@ -62,7 +62,7 @@ mod tests { let lhs = context.create_local(Item::new(F32::as_elem())); let rhs = context.create_local(Item::new(F32::as_elem())); - ::test_expand::(&mut context, lhs, rhs); + ::__expand_test::(&mut context, lhs, rhs); assert_eq!(simple_scope(), context.into_scope()); } @@ -73,7 +73,7 @@ mod tests { let lhs = context.create_local(Item::new(F32::as_elem())); let rhs = context.create_local(Item::new(F32::as_elem())); - >::test_expand(&mut context, lhs, rhs); + >::__expand_test(&mut context, lhs, rhs); assert_eq!(simple_scope(), context.into_scope()); } @@ -84,7 +84,7 @@ mod tests { let lhs = context.create_local(Item::new(F32::as_elem())); let rhs = context.create_local(Item::new(F32::as_elem())); - >::test_expand::(&mut context, lhs, rhs); + >::__expand_test::(&mut context, lhs, rhs); assert_eq!(with_cast_scope(), context.into_scope()); } @@ -94,7 +94,7 @@ mod tests { let lhs = context_ref.create_local(Item::new(F32::as_elem())); let rhs = context_ref.create_local(Item::new(F32::as_elem())); - simple_expand::(&mut context_ref, lhs, rhs); + simple::__expand::(&mut context_ref, lhs, rhs); context_ref.into_scope() } @@ -103,7 +103,7 @@ mod tests { let lhs = context_ref.create_local(Item::new(F32::as_elem())); let rhs = context_ref.create_local(Item::new(F32::as_elem())); - with_cast_expand::(&mut context_ref, lhs, rhs); + with_cast::__expand::(&mut context_ref, lhs, rhs); context_ref.into_scope() } } diff --git a/crates/burn-cube/tests/frontend/for_loop.rs b/crates/burn-cube/tests/frontend/for_loop.rs index 75e888109..dcc1c9e0d 100644 --- a/crates/burn-cube/tests/frontend/for_loop.rs +++ b/crates/burn-cube/tests/frontend/for_loop.rs @@ -30,7 +30,7 @@ mod tests { let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); - for_loop_expand::(&mut context, lhs.into(), rhs, end, unroll); + for_loop::__expand::(&mut context, lhs.into(), rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); @@ -45,7 +45,7 @@ mod tests { let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); - for_loop_expand::(&mut context, lhs.into(), rhs, end, unroll); + for_loop::__expand::(&mut context, lhs.into(), rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); diff --git a/crates/burn-cube/tests/frontend/function_call.rs b/crates/burn-cube/tests/frontend/function_call.rs index 875a45515..88d5d7f87 100644 --- a/crates/burn-cube/tests/frontend/function_call.rs +++ b/crates/burn-cube/tests/frontend/function_call.rs @@ -59,12 +59,12 @@ mod tests { fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_no_arg_expand(&mut caller_context, x); + caller_no_arg::__expand(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_no_arg_expand(&mut no_call_context, x); + no_call_no_arg::__expand(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); assert_eq!( @@ -78,12 +78,12 @@ mod tests { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_with_arg_expand(&mut caller_context, x); + caller_with_arg::__expand(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_with_arg_expand(&mut no_call_context, x); + no_call_with_arg::__expand(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); assert_eq!( @@ -97,12 +97,12 @@ mod tests { let mut caller_context = CubeContext::root(); type ElemType = I64; let x = caller_context.create_local(Item::new(ElemType::as_elem())); - caller_with_generics_expand::(&mut caller_context, x); + caller_with_generics::__expand::(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - no_call_with_generics_expand::(&mut no_call_context, x); + no_call_with_generics::__expand::(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/generic_kernel.rs b/crates/burn-cube/tests/frontend/generic_kernel.rs index 387b3c3e4..d2879c408 100644 --- a/crates/burn-cube/tests/frontend/generic_kernel.rs +++ b/crates/burn-cube/tests/frontend/generic_kernel.rs @@ -20,7 +20,7 @@ mod tests { let lhs = context.create_local(Item::new(F32::as_elem())); - generic_kernel_expand::(&mut context, lhs); + generic_kernel::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); @@ -32,7 +32,7 @@ mod tests { let lhs = context.create_local(Item::new(I32::as_elem())); - generic_kernel_expand::(&mut context, lhs); + generic_kernel::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); diff --git a/crates/burn-cube/tests/frontend/if.rs b/crates/burn-cube/tests/frontend/if.rs index 200113d56..895e11e02 100644 --- a/crates/burn-cube/tests/frontend/if.rs +++ b/crates/burn-cube/tests/frontend/if.rs @@ -52,7 +52,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - if_greater_expand::(&mut context, lhs); + if_greater::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if()); @@ -64,7 +64,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - if_then_else_expand::(&mut context, lhs); + if_then_else::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!( @@ -79,7 +79,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - elsif_expand::(&mut context, lhs); + elsif::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif()); diff --git a/crates/burn-cube/tests/frontend/literal.rs b/crates/burn-cube/tests/frontend/literal.rs index 9bc6ad53e..d825b0c73 100644 --- a/crates/burn-cube/tests/frontend/literal.rs +++ b/crates/burn-cube/tests/frontend/literal.rs @@ -25,7 +25,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - literal_expand::(&mut context, lhs); + literal::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); @@ -37,7 +37,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - literal_float_no_decimals_expand::(&mut context, lhs); + literal_float_no_decimals::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/burn-cube/tests/frontend/loop.rs b/crates/burn-cube/tests/frontend/loop.rs index a9d943c9e..5c0d318c6 100644 --- a/crates/burn-cube/tests/frontend/loop.rs +++ b/crates/burn-cube/tests/frontend/loop.rs @@ -42,7 +42,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - while_not_expand::(&mut context, lhs); + while_not::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false)); @@ -54,7 +54,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - manual_loop_break_expand::(&mut context, lhs); + manual_loop_break::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false)); @@ -66,7 +66,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - loop_with_return_expand::(&mut context, lhs); + loop_with_return::__expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(true)); diff --git a/crates/burn-cube/tests/frontend/module_import.rs b/crates/burn-cube/tests/frontend/module_import.rs index 3da61d271..fdd3c4bca 100644 --- a/crates/burn-cube/tests/frontend/module_import.rs +++ b/crates/burn-cube/tests/frontend/module_import.rs @@ -33,12 +33,12 @@ mod tests { fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(ElemType::as_elem())); - here::caller_expand::(&mut caller_context, x); + here::caller::__expand::(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - here::no_call_ref_expand::(&mut no_call_context, x); + here::no_call_ref::__expand::(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/ops.rs b/crates/burn-cube/tests/frontend/ops.rs index 21ef221ae..e53f664c1 100644 --- a/crates/burn-cube/tests/frontend/ops.rs +++ b/crates/burn-cube/tests/frontend/ops.rs @@ -1,192 +1,192 @@ use burn_cube::prelude::*; #[cube] -fn add_op(a: T, b: T) -> T { +pub fn add_op(a: T, b: T) -> T { a + b } #[cube] -fn sub_op(a: T, b: T) -> T { +pub fn sub_op(a: T, b: T) -> T { a - b } #[cube] -fn mul_op(a: T, b: T) -> T { +pub fn mul_op(a: T, b: T) -> T { a * b } #[cube] -fn div_op(a: T, b: T) -> T { +pub fn div_op(a: T, b: T) -> T { a / b } #[cube] -fn abs_op(a: T) -> T { +pub fn abs_op(a: T) -> T { T::abs(a) } #[cube] -fn exp_op(a: F) -> F { +pub fn exp_op(a: F) -> F { F::exp(a) } #[cube] -fn log_op(a: F) -> F { +pub fn log_op(a: F) -> F { F::log(a) } #[cube] -fn log1p_op(a: F) -> F { +pub fn log1p_op(a: F) -> F { F::log1p(a) } #[cube] -fn cos_op(a: F) -> F { +pub fn cos_op(a: F) -> F { F::cos(a) } #[cube] -fn sin_op(a: F) -> F { +pub fn sin_op(a: F) -> F { F::sin(a) } #[cube] -fn tanh_op(a: F) -> F { +pub fn tanh_op(a: F) -> F { F::tanh(a) } #[cube] -fn powf_op(a: F, b: F) -> F { +pub fn powf_op(a: F, b: F) -> F { F::powf(a, b) } #[cube] -fn sqrt_op(a: F) -> F { +pub fn sqrt_op(a: F) -> F { F::sqrt(a) } #[cube] -fn floor_op(a: F) -> F { +pub fn floor_op(a: F) -> F { F::floor(a) } #[cube] -fn ceil_op(a: F) -> F { +pub fn ceil_op(a: F) -> F { F::ceil(a) } #[cube] -fn erf_op(a: F) -> F { +pub fn erf_op(a: F) -> F { F::erf(a) } #[cube] -fn recip_op(a: F) -> F { +pub fn recip_op(a: F) -> F { F::recip(a) } #[cube] -fn equal_op(a: T, b: T) -> bool { +pub fn equal_op(a: T, b: T) -> bool { a == b } #[cube] -fn not_equal_op(a: T, b: T) -> bool { +pub fn not_equal_op(a: T, b: T) -> bool { a != b } #[cube] -fn lower_op(a: T, b: T) -> bool { +pub fn lower_op(a: T, b: T) -> bool { a < b } #[cube] -fn greater_op(a: T, b: T) -> bool { +pub fn greater_op(a: T, b: T) -> bool { a > b } #[cube] -fn lower_equal_op(a: T, b: T) -> bool { +pub fn lower_equal_op(a: T, b: T) -> bool { a <= b } #[cube] -fn greater_equal_op(a: T, b: T) -> bool { +pub fn greater_equal_op(a: T, b: T) -> bool { a >= b } #[cube] -fn modulo_op(a: UInt, b: UInt) -> UInt { +pub fn modulo_op(a: UInt, b: UInt) -> UInt { a % b } #[cube] -fn remainder_op(a: T, b: T) -> T { +pub fn remainder_op(a: T, b: T) -> T { T::rem(a, b) } #[cube] -fn max_op(a: T, b: T) -> T { +pub fn max_op(a: T, b: T) -> T { T::max(a, b) } #[cube] -fn min_op(a: T, b: T) -> T { +pub fn min_op(a: T, b: T) -> T { T::min(a, b) } #[cube] -fn and_op(a: bool, b: bool) -> bool { +pub fn and_op(a: bool, b: bool) -> bool { a && b } #[cube] -fn or_op(a: bool, b: bool) -> bool { +pub fn or_op(a: bool, b: bool) -> bool { a || b } #[cube] -fn not_op(a: bool) -> bool { +pub fn not_op(a: bool) -> bool { !a } #[cube] -fn bitand_op(a: UInt, b: UInt) -> UInt { +pub fn bitand_op(a: UInt, b: UInt) -> UInt { a & b } #[cube] -fn bitxor_op(a: UInt, b: UInt) -> UInt { +pub fn bitxor_op(a: UInt, b: UInt) -> UInt { a ^ b } #[cube] -fn shl_op(a: UInt, b: UInt) -> UInt { +pub fn shl_op(a: UInt, b: UInt) -> UInt { a << b } #[cube] -fn shr_op(a: UInt, b: UInt) -> UInt { +pub fn shr_op(a: UInt, b: UInt) -> UInt { a >> b } #[cube] -fn add_assign_op(mut a: T, b: T) { +pub fn add_assign_op(mut a: T, b: T) { a += b; } #[cube] -fn sub_assign_op(mut a: T, b: T) { +pub fn sub_assign_op(mut a: T, b: T) { a -= b; } #[cube] -fn mul_assign_op(mut a: T, b: T) { +pub fn mul_assign_op(mut a: T, b: T) { a *= b; } #[cube] -fn div_assign_op(mut a: T, b: T) { +pub fn div_assign_op(mut a: T, b: T) { a /= b; } @@ -195,14 +195,14 @@ mod tests { use burn_cube::ir::{Elem, FloatKind, Item}; macro_rules! binary_test { - ($test_name:ident, $op_expand:ident, $op_name:expr, $func:ident) => { + ($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => { #[test] fn $test_name() { let mut context = CubeContext::root(); let x = context.create_local(Item::new(Elem::Float(FloatKind::F32))); let y = context.create_local(Item::new(Elem::Float(FloatKind::F32))); - $op_expand::(&mut context, x, y); + $op_expand(&mut context, x, y); assert_eq!( format!("{:?}", context.into_scope().operations), @@ -213,13 +213,13 @@ mod tests { } macro_rules! unary_test { - ($test_name:ident, $op_expand:ident, $op_name:expr) => { + ($test_name:ident, $op_expand:expr, $op_name:expr) => { #[test] fn $test_name() { let mut context = CubeContext::root(); let x = context.create_local(Item::new(Elem::Float(FloatKind::F32))); - $op_expand::(&mut context, x); + $op_expand(&mut context, x); assert_eq!( format!("{:?}", context.into_scope().operations), @@ -230,7 +230,7 @@ mod tests { } macro_rules! binary_boolean_test { - ($test_name:ident, $op_expand:ident, $op_name:expr) => { + ($test_name:ident, $op_expand:expr, $op_name:expr) => { #[test] fn $test_name() { let mut context = CubeContext::root(); @@ -248,7 +248,7 @@ mod tests { } macro_rules! binary_uint_test { - ($test_name:ident, $op_expand:ident, $op_name:expr) => { + ($test_name:ident, $op_expand:expr, $op_name:expr) => { #[test] fn $test_name() { let mut context = CubeContext::root(); @@ -265,75 +265,90 @@ mod tests { }; } - binary_test!(cube_can_add, add_op_expand, "Add", ref_ops_binary); - binary_test!(cube_can_sub, sub_op_expand, "Sub", ref_ops_binary); - binary_test!(cube_can_mul, mul_op_expand, "Mul", ref_ops_binary); - binary_test!(cube_can_div, div_op_expand, "Div", ref_ops_binary); - unary_test!(cube_can_abs, abs_op_expand, "Abs"); - unary_test!(cube_can_exp, exp_op_expand, "Exp"); - unary_test!(cube_can_log, log_op_expand, "Log"); - unary_test!(cube_can_log1p, log1p_op_expand, "Log1p"); - unary_test!(cube_can_cos, cos_op_expand, "Cos"); - unary_test!(cube_can_sin, sin_op_expand, "Sin"); - unary_test!(cube_can_tanh, tanh_op_expand, "Tanh"); - binary_test!(cube_can_powf, powf_op_expand, "Powf", ref_ops_binary); - unary_test!(cube_can_sqrt, sqrt_op_expand, "Sqrt"); - unary_test!(cube_can_erf, erf_op_expand, "Erf"); - unary_test!(cube_can_recip, recip_op_expand, "Recip"); - unary_test!(cube_can_floor, floor_op_expand, "Floor"); - unary_test!(cube_can_ceil, ceil_op_expand, "Ceil"); - binary_test!(cube_can_eq, equal_op_expand, "Equal", ref_ops_cmp); - binary_test!(cube_can_ne, not_equal_op_expand, "NotEqual", ref_ops_cmp); - binary_test!(cube_can_lt, lower_op_expand, "Lower", ref_ops_cmp); + binary_test!(cube_can_add, add_op::__expand::, "Add", ref_ops_binary); + binary_test!(cube_can_sub, sub_op::__expand::, "Sub", ref_ops_binary); + binary_test!(cube_can_mul, mul_op::__expand::, "Mul", ref_ops_binary); + binary_test!(cube_can_div, div_op::__expand::, "Div", ref_ops_binary); + unary_test!(cube_can_abs, abs_op::__expand::, "Abs"); + unary_test!(cube_can_exp, exp_op::__expand::, "Exp"); + unary_test!(cube_can_log, log_op::__expand::, "Log"); + unary_test!(cube_can_log1p, log1p_op::__expand::, "Log1p"); + unary_test!(cube_can_cos, cos_op::__expand::, "Cos"); + unary_test!(cube_can_sin, sin_op::__expand::, "Sin"); + unary_test!(cube_can_tanh, tanh_op::__expand::, "Tanh"); + binary_test!( + cube_can_powf, + powf_op::__expand::, + "Powf", + ref_ops_binary + ); + unary_test!(cube_can_sqrt, sqrt_op::__expand::, "Sqrt"); + unary_test!(cube_can_erf, erf_op::__expand::, "Erf"); + unary_test!(cube_can_recip, recip_op::__expand::, "Recip"); + unary_test!(cube_can_floor, floor_op::__expand::, "Floor"); + unary_test!(cube_can_ceil, ceil_op::__expand::, "Ceil"); + binary_test!(cube_can_eq, equal_op::__expand::, "Equal", ref_ops_cmp); + binary_test!( + cube_can_ne, + not_equal_op::__expand::, + "NotEqual", + ref_ops_cmp + ); + binary_test!(cube_can_lt, lower_op::__expand::, "Lower", ref_ops_cmp); binary_test!( cube_can_le, - lower_equal_op_expand, + lower_equal_op::__expand::, "LowerEqual", ref_ops_cmp ); binary_test!( cube_can_ge, - greater_equal_op_expand, + greater_equal_op::__expand::, "GreaterEqual", ref_ops_cmp ); - binary_test!(cube_can_gt, greater_op_expand, "Greater", ref_ops_cmp); - binary_test!(cube_can_max, max_op_expand, "Max", ref_ops_binary); - binary_test!(cube_can_min, min_op_expand, "Min", ref_ops_binary); + binary_test!( + cube_can_gt, + greater_op::__expand::, + "Greater", + ref_ops_cmp + ); + binary_test!(cube_can_max, max_op::__expand::, "Max", ref_ops_binary); + binary_test!(cube_can_min, min_op::__expand::, "Min", ref_ops_binary); binary_test!( cube_can_add_assign, - add_assign_op_expand, + add_assign_op::__expand::, "Add", ref_ops_binary ); binary_test!( cube_can_sub_assign, - sub_assign_op_expand, + sub_assign_op::__expand::, "Sub", ref_ops_binary ); binary_test!( cube_can_mul_assign, - mul_assign_op_expand, + mul_assign_op::__expand::, "Mul", ref_ops_binary ); binary_test!( cube_can_div_assign, - div_assign_op_expand, + div_assign_op::__expand::, "Div", ref_ops_binary ); - binary_boolean_test!(cube_can_and, and_op_expand, "And"); - binary_boolean_test!(cube_can_or, or_op_expand, "Or"); - binary_uint_test!(cube_can_bitand, bitand_op_expand, "BitwiseAnd"); - binary_uint_test!(cube_can_bitxor, bitxor_op_expand, "BitwiseXor"); - binary_uint_test!(cube_can_shl, shl_op_expand, "ShiftLeft"); - binary_uint_test!(cube_can_shr, shr_op_expand, "ShiftRight"); - binary_uint_test!(cube_can_mod, modulo_op_expand, "Modulo"); + binary_boolean_test!(cube_can_and, and_op::__expand, "And"); + binary_boolean_test!(cube_can_or, or_op::__expand, "Or"); + binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd"); + binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor"); + binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft"); + binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight"); + binary_uint_test!(cube_can_mod, modulo_op::__expand, "Modulo"); binary_test!( cube_can_rem, - remainder_op_expand, + remainder_op::__expand::, "Remainder", ref_ops_binary ); @@ -343,7 +358,7 @@ mod tests { let mut context = CubeContext::root(); let x = context.create_local(Item::new(Elem::Bool)); - not_op_expand(&mut context, x); + not_op::__expand(&mut context, x); assert_eq!( format!("{:?}", context.into_scope().operations), diff --git a/crates/burn-cube/tests/frontend/parenthesis.rs b/crates/burn-cube/tests/frontend/parenthesis.rs index 446433ec3..bb522cc1e 100644 --- a/crates/burn-cube/tests/frontend/parenthesis.rs +++ b/crates/burn-cube/tests/frontend/parenthesis.rs @@ -22,7 +22,7 @@ mod tests { let y = context.create_local(Item::new(ElemType::as_elem())); let z = context.create_local(Item::new(ElemType::as_elem())); - parenthesis_expand::(&mut context, x, y, z); + parenthesis::__expand::(&mut context, x, y, z); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/burn-cube/tests/frontend/redeclare.rs b/crates/burn-cube/tests/frontend/redeclare.rs index 2fdf9b4c4..4fb786721 100644 --- a/crates/burn-cube/tests/frontend/redeclare.rs +++ b/crates/burn-cube/tests/frontend/redeclare.rs @@ -53,7 +53,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_same_scope_expand::(&mut context, x); + redeclare_same_scope::__expand::(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -68,7 +68,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_same_scope_other_type_expand::(&mut context, x); + redeclare_same_scope_other_type::__expand::(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -83,7 +83,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_different_scope_expand::(&mut context, x); + redeclare_different_scope::__expand::(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -98,7 +98,7 @@ mod tests { let x = context.create_local(Item::new(UInt::as_elem())); - redeclare_two_for_loops_expand(&mut context, x); + redeclare_two_for_loops::__expand(&mut context, x); let scope = context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/reuse.rs b/crates/burn-cube/tests/frontend/reuse.rs index ee8bf9519..53063f798 100644 --- a/crates/burn-cube/tests/frontend/reuse.rs +++ b/crates/burn-cube/tests/frontend/reuse.rs @@ -32,7 +32,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - reuse_expand::(&mut context, x); + reuse::__expand::(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); @@ -44,7 +44,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - reuse_incr_expand::(&mut context, x); + reuse_incr::__expand::(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); diff --git a/crates/burn-cube/tests/frontend/shared_memory.rs b/crates/burn-cube/tests/frontend/shared_memory.rs index 93f90f608..5018630cc 100644 --- a/crates/burn-cube/tests/frontend/shared_memory.rs +++ b/crates/burn-cube/tests/frontend/shared_memory.rs @@ -1,7 +1,7 @@ use burn_cube::prelude::*; #[cube] -fn shared_memory_read_write(sm_size: Comptime) { +pub fn shared_memory_read_write(sm_size: Comptime) { let mut shared = SharedMemory::::new(sm_size); shared[0] = T::from_int(3); let _ = shared[0]; @@ -20,7 +20,7 @@ mod tests { fn cube_support_shared_memory() { let mut context = CubeContext::root(); - shared_memory_read_write_expand::(&mut context, 512); + shared_memory_read_write::__expand::(&mut context, 512); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/burn-cube/tests/frontend/struct.rs b/crates/burn-cube/tests/frontend/struct.rs index eccfe4045..d5539dc07 100644 --- a/crates/burn-cube/tests/frontend/struct.rs +++ b/crates/burn-cube/tests/frontend/struct.rs @@ -1,25 +1,25 @@ use burn_cube::prelude::*; #[derive(CubeType)] -struct State { +pub struct State { first: T, second: T, } #[cube] -fn state_receiver_with_reuse(state: State) -> T { +pub fn state_receiver_with_reuse(state: State) -> T { let x = state.first + state.second; state.second + x + state.first } #[cube] -fn attribute_modifier_reuse_field(mut state: State) -> T { +pub fn attribute_modifier_reuse_field(mut state: State) -> T { state.first = T::from_int(4); state.first } #[cube] -fn attribute_modifier_reuse_struct(mut state: State) -> State { +pub fn attribute_modifier_reuse_struct(mut state: State) -> State { state.first = T::from_int(4); state } @@ -48,7 +48,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - creator_expand::(&mut context, x, y); + creator::__expand::(&mut context, x, y); let scope = context.into_scope(); assert_eq!( @@ -68,7 +68,7 @@ mod tests { first: x, second: y, }; - state_receiver_with_reuse_expand::(&mut context, expanded_state); + state_receiver_with_reuse::__expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -88,7 +88,7 @@ mod tests { first: x, second: y, }; - attribute_modifier_reuse_field_expand::(&mut context, expanded_state); + attribute_modifier_reuse_field::__expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -108,7 +108,7 @@ mod tests { first: x, second: y, }; - attribute_modifier_reuse_struct_expand::(&mut context, expanded_state); + attribute_modifier_reuse_struct::__expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/tensor.rs b/crates/burn-cube/tests/frontend/tensor.rs index 4a909b8bc..2d27c3ad7 100644 --- a/crates/burn-cube/tests/frontend/tensor.rs +++ b/crates/burn-cube/tests/frontend/tensor.rs @@ -1,7 +1,7 @@ use burn_cube::prelude::*; #[cube] -fn kernel(input: &Tensor) { +pub fn kernel(input: &Tensor) { let _shape = input.shape(1); let _stride = input.stride(1); let _length = input.len(); @@ -21,7 +21,7 @@ mod tests { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - kernel_expand::(&mut context, input.into()); + kernel::__expand::(&mut context, input.into()); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/burn-cube/tests/frontend/topology.rs b/crates/burn-cube/tests/frontend/topology.rs index afc83c061..9a6548249 100644 --- a/crates/burn-cube/tests/frontend/topology.rs +++ b/crates/burn-cube/tests/frontend/topology.rs @@ -1,7 +1,7 @@ use burn_cube::prelude::*; #[cube] -fn topology_kernel(input: Tensor) { +pub fn topology_kernel(input: Tensor) { let x = ABSOLUTE_POS + UInt::new(4); let _ = input[x]; } @@ -20,7 +20,7 @@ mod tests { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - topology_kernel_expand::(&mut context, input.into()); + topology_kernel::__expand::(&mut context, input.into()); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/burn-cube/tests/frontend/trait.rs b/crates/burn-cube/tests/frontend/trait.rs index 65fd55df5..b85c43c21 100644 --- a/crates/burn-cube/tests/frontend/trait.rs +++ b/crates/burn-cube/tests/frontend/trait.rs @@ -4,7 +4,7 @@ use burn_cube::prelude::*; /// for all their methods. However, one does not need to provide its /// implementation, see examples below. #[cube] -trait Strategy { +pub trait Strategy { fn operation(input_1: T, input_2: T) -> T; } @@ -13,7 +13,7 @@ struct AddStrategy; #[cube] /// The actual implementation of AddStrategy's operation /// Automatically generated an _expand variant -fn add_strategy_operation(input_1: T, input_2: T) -> T { +pub fn add_strategy_operation(input_1: T, input_2: T) -> T { input_1 + input_2 } @@ -34,19 +34,19 @@ impl Strategy for SubStrategy { } #[cube] -fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { +pub fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { S::operation(x, y) } #[cube] -fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { +pub fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { let z = S1::operation(x, y); S2::operation(z, y) } -trait MethodTypedStrategy { +pub trait MethodTypedStrategy { fn operation(input_1: T, input_2: T) -> T; - fn operation_expand( + fn __expand_operation( _context: &mut CubeContext, input_1: ::ExpandType, input_2: ::ExpandType, @@ -58,17 +58,17 @@ impl MethodTypedStrategy for AddStrategy { add_strategy_operation(input_1, input_2) } - fn operation_expand( + fn __expand_operation( context: &mut CubeContext, input_1: ::ExpandType, input_2: ::ExpandType, ) -> ::ExpandType { - add_strategy_operation_expand::(context, input_1, input_2) + add_strategy_operation::__expand::(context, input_1, input_2) } } #[cube] -fn with_trait_generic_method(x: T, y: T) -> T { +pub fn with_trait_generic_method(x: T, y: T) -> T { S::operation::(x, y) } @@ -87,7 +87,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait_expand::(&mut context, x, y); + with_strategy_trait::__expand::(&mut context, x, y); let scope = context.into_scope(); assert_eq!( @@ -103,7 +103,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait_expand::(&mut context, x, y); + with_strategy_trait::__expand::(&mut context, x, y); let scope = context.into_scope(); assert_eq!( @@ -119,7 +119,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - two_strategy_traits_expand::(&mut context, x, y); + two_strategy_traits::__expand::(&mut context, x, y); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two()); @@ -132,7 +132,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_trait_generic_method_expand::(&mut context, x, y); + with_trait_generic_method::__expand::(&mut context, x, y); let scope = context.into_scope(); assert_eq!( diff --git a/crates/burn-cube/tests/frontend/vectorization.rs b/crates/burn-cube/tests/frontend/vectorization.rs index 56dd5b1a6..4dd0a8dcf 100644 --- a/crates/burn-cube/tests/frontend/vectorization.rs +++ b/crates/burn-cube/tests/frontend/vectorization.rs @@ -22,7 +22,7 @@ mod tests { let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); - vectorization_binary_expand::(&mut context, lhs); + vectorization_binary::__expand::(&mut context, lhs); } #[test] @@ -32,7 +32,7 @@ mod tests { let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); - vectorization_binary_expand::(&mut context, lhs); + vectorization_binary::__expand::(&mut context, lhs); } #[test] @@ -41,7 +41,7 @@ mod tests { let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); - vectorization_cmp_expand::(&mut context, lhs); + vectorization_cmp::__expand::(&mut context, lhs); } #[test] @@ -51,7 +51,7 @@ mod tests { let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); - vectorization_cmp_expand::(&mut context, lhs); + vectorization_cmp::__expand::(&mut context, lhs); } #[test] @@ -60,6 +60,6 @@ mod tests { let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1)); - vectorization_cmp_expand::(&mut context, lhs); + vectorization_cmp::__expand::(&mut context, lhs); } } diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 4cb31fbbe..3a77d3741 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -3,23 +3,23 @@ use burn_cube::prelude::*; use crate::kernel::{launch_unary, UnaryOp}; use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +#[derive(CubeLaunch)] +struct Options { + min_value: C, + max_value: C, +} + pub(crate) fn clamp( input: JitTensor, min_value: E, max_value: E, ) -> JitTensor { - #[derive(CubeLaunch)] - struct Options { - min_value: C, - max_value: C, - } - struct ClampOp; impl UnaryOp for ClampOp { type Options = Options; - fn execute_expand( + fn __expand_execute( context: &mut CubeContext, input: C::ExpandType, options: OptionsExpand, @@ -29,7 +29,7 @@ pub(crate) fn clamp( C::clamp(input, options.min_value, options.max_value) } - execute_expand(context, input, options) + execute::__expand(context, input, options) } } diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index 51465060c..8f3bf7539 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -1,4 +1,4 @@ -use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel}; +use super::{index_offset_with_layout, Kernel}; use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; use burn_cube::{ calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime, @@ -146,7 +146,7 @@ pub(crate) fn launch_cmp< let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && lhs.can_mut_broadcast(&rhs) { - kernel_cmp_launch::( + kernel_cmp::launch::( client, cube_count, CubeDim::default(), @@ -170,7 +170,7 @@ pub(crate) fn launch_cmp< JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { - kernel_cmp_launch::( + kernel_cmp::launch::( client, cube_count, CubeDim::default(), @@ -199,7 +199,7 @@ pub(crate) fn launch_cmp< let to_contiguous_rhs = !rhs.is_contiguous(); let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer); - kernel_cmp_launch::( + kernel_cmp::launch::( client, cube_count, CubeDim::default(), @@ -251,7 +251,7 @@ pub(crate) fn launch_scalar_cmp< let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && tensor.can_mut() { - kernel_scalar_cmp_launch::( + kernel_scalar_cmp::launch::( client, cube_count, CubeDim::default(), @@ -282,7 +282,7 @@ pub(crate) fn launch_scalar_cmp< tensor.strides, ); - kernel_scalar_cmp_launch::( + kernel_scalar_cmp::launch::( client, cube_count, CubeDim::default(), diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 03333a9e7..6e4576ff8 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -87,7 +87,7 @@ pub fn into_contiguous( SUBCUBE_DIM_APPROX, ); - into_contiguous_kernel_launch::( + into_contiguous_kernel::launch::( client, cube_count, CubeDim::default(), diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 271124cd7..106979ebd 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -163,7 +163,7 @@ pub(crate) fn conv2d( let num_elems_output = output.shape.num_elements(); let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX); - conv2d_kernel_launch::( + conv2d_kernel::launch::( input.client, cube_dim, CubeDim::default(), diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index f0fe8cc36..da1e3aea0 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -188,7 +188,7 @@ pub(crate) fn conv3d( } }; - conv3d_kernel_launch::( + conv3d_kernel::launch::( input.client, calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX), CubeDim::default(), diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index 16d92516d..63c68f58d 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -116,7 +116,7 @@ pub fn matmul_simple( false => 1, }; - matmul_kernel_launch::( + matmul_kernel::launch::( lhs.client, cube_count, CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs index 6268da11a..e46f9a8c1 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs @@ -2,11 +2,11 @@ use burn_cube::prelude::*; use crate::kernel::matmul::config::CubeTiling2dConfig; -use super::block_loop::{block_loop, block_loop_expand}; +use super::block_loop::block_loop; #[cube(launch)] #[allow(unused_mut)] -fn tiling2d_cube( +pub fn tiling2d_cube_kernel( lhs: &Tensor, rhs: &Tensor, out: &mut Tensor, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs index 123f991fe..110902db0 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs @@ -4,10 +4,10 @@ use crate::kernel::matmul::config::CubeTiling2dConfig; use super::{ base::{BatchOffsets, Coordinates, Dimensions, SharedMemories}, - compute_loop::{compute_loop, compute_loop_expand}, - load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand}, + compute_loop::compute_loop, + load_shared_memory::load_to_shared_memories, tile::{loader::TileLoader, writer::TileWriter}, - write_output::{write_to_output, write_to_output_expand}, + write_output::write_to_output, }; #[cube] diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs index 082421b25..29fe627f2 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs @@ -2,10 +2,7 @@ use burn_cube::prelude::*; use crate::kernel::matmul::config::CubeTiling2dConfig; -use super::{ - base::Coordinates, - outer_product::{tile_outer_product, tile_outer_product_expand}, -}; +use super::{base::Coordinates, outer_product::tile_outer_product}; #[cube] #[allow(unused_mut)] @@ -105,7 +102,7 @@ pub mod tests { const SOME_DIM: usize = 12; let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - compute_loop_test_launch::( + compute_loop_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -134,7 +131,7 @@ pub mod tests { let config = make_config(4, 8, 4); - compute_loop_test_launch::( + compute_loop_test::launch::( lhs.client.clone(), cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs index b196ff4f4..9f2606f09 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs @@ -7,6 +7,7 @@ use crate::{ into_contiguous, matmul::{ config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig}, + tiling2d_cube::base::tiling2d_cube_kernel, Tiling2dConfig, }, }, @@ -14,8 +15,6 @@ use crate::{ FloatElement, JitRuntime, }; -use super::base::tiling2d_cube_launch; - /// Matrix multiplication using tiling 2d algorithm pub fn matmul_tiling_2d_cube( lhs: JitTensor, @@ -69,7 +68,7 @@ pub fn matmul_tiling_2d_cube( let cube_dim = tiling2d_cube_dim(&config); let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); - tiling2d_cube_launch::( + tiling2d_cube_kernel::launch::( client, cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs index aae7e8804..dc9ce5e9a 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs @@ -385,7 +385,7 @@ pub mod tests { let config = make_config(16, 16, 8); - load_tensor_test_launch::( + load_tensor_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -417,7 +417,7 @@ pub mod tests { let config = make_config(5, 1, 1); - load_tensor_multiple_tiles_test_launch::( + load_tensor_multiple_tiles_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -451,7 +451,7 @@ pub mod tests { let config = make_config(8, 8, 8); - load_tensor_multiple_tiles_test_launch::( + load_tensor_multiple_tiles_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -481,7 +481,7 @@ pub mod tests { let config = make_config(8, 8, 16); - load_tensor_multiple_tiles_test_launch::( + load_tensor_multiple_tiles_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -511,7 +511,7 @@ pub mod tests { let config = make_config(8, 16, 16); - load_tensor_test_launch::( + load_tensor_test::launch::( rhs.client.clone(), cube_count, cube_dim, @@ -543,7 +543,7 @@ pub mod tests { let config = make_config(8, 8, 8); - load_tensor_multiple_tiles_test_launch::( + load_tensor_multiple_tiles_test::launch::( rhs.client.clone(), cube_count, cube_dim, @@ -573,7 +573,7 @@ pub mod tests { let config = make_config(16, 16, 8); - load_tensor_multiple_tiles_test_launch::( + load_tensor_multiple_tiles_test::launch::( rhs.client.clone(), cube_count, cube_dim, @@ -603,7 +603,7 @@ pub mod tests { let config = make_config(16, 16, 8); - load_tensor_permuted_test_launch::( + load_tensor_permuted_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -636,7 +636,7 @@ pub mod tests { let config = make_config(m, k, 8); - load_tensor_permuted_test_launch::( + load_tensor_permuted_test::launch::( lhs.client.clone(), cube_count, cube_dim, @@ -667,7 +667,7 @@ pub mod tests { let config = make_config(16, 16, 8); - load_tensor_permuted_test_launch::( + load_tensor_permuted_test::launch::( rhs.client.clone(), cube_count, cube_dim, @@ -699,7 +699,7 @@ pub mod tests { let config = make_config(8, k, n); - load_tensor_permuted_test_launch::( + load_tensor_permuted_test::launch::( rhs.client.clone(), cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs index fb7512539..2ab90e611 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs @@ -70,7 +70,7 @@ pub mod tests { const SOME_DIM: usize = 12; let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test_launch::( + tile_outer_product_test::launch::( client.clone(), cube_count, cube_dim, @@ -99,7 +99,7 @@ pub mod tests { const SOME_DIM: usize = 12; let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - tile_outer_product_test_launch::( + tile_outer_product_test::launch::( client.clone(), cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs index 5bd140fb4..8b09877fd 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs @@ -14,10 +14,7 @@ use crate::kernel::matmul::{ }, }; -use super::base::{ - all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand, - BlockLoader, BlockWriter, -}; +use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter}; pub(crate) struct HorizontalCheckBlockIO; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs index 677978e7f..1e91a32ac 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs @@ -14,7 +14,7 @@ use crate::kernel::matmul::{ }, }; -use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter}; +use super::base::{all_zeros_runtime, BlockLoader, BlockWriter}; pub(crate) struct VerticalCheckBlockIO; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs index 274c79181..e868b1633 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs @@ -14,10 +14,7 @@ use crate::kernel::matmul::{ }, }; -use super::base::{ - all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand, - BlockLoader, BlockWriter, -}; +use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter}; pub(crate) struct WholeCheckBlockIO; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs index 95a12697e..42d2ee8d1 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs @@ -130,7 +130,7 @@ pub mod tests { let config = make_config(6, 8, 8); - write_to_output_test_launch::( + write_to_output_test::launch::( out.client.clone(), cube_count, cube_dim, @@ -156,7 +156,7 @@ pub mod tests { let config = make_config(8, 8, 4); - write_to_output_test_launch::( + write_to_output_test::launch::( out.client.clone(), cube_count, cube_dim, @@ -182,7 +182,7 @@ pub mod tests { let config = make_config(8, 8, 8); - write_to_output_test_launch::( + write_to_output_test::launch::( out.client.clone(), cube_count, cube_dim, @@ -215,7 +215,7 @@ pub mod tests { let config = make_config(8, 8, 8); - write_to_output_test_launch::( + write_to_output_test::launch::( out.client.clone(), cube_count, cube_dim, @@ -248,7 +248,7 @@ pub mod tests { let config = make_config(5, 8, 1); - write_results_to_output_out_of_bounds_test_launch::( + write_results_to_output_out_of_bounds_test::launch::( out.client.clone(), cube_count, cube_dim, diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 0e508b944..867b92a28 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -11,7 +11,7 @@ pub use binary::*; pub use cast::*; pub use contiguous::*; pub use mask::*; -pub use unary::*; +pub(crate) use unary::*; pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX}; diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs index e14759f37..b2ebf220f 100644 --- a/crates/burn-jit/src/kernel/unary.rs +++ b/crates/burn-jit/src/kernel/unary.rs @@ -4,7 +4,7 @@ use burn_cube::{ SUBCUBE_DIM_APPROX, }; -use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel}; +use super::{index_offset_with_layout, Kernel}; pub(crate) trait UnaryOp: 'static + Send + Sync { type Options: LaunchArg; @@ -13,7 +13,7 @@ pub(crate) trait UnaryOp: 'static + Send + Sync { fn execute(_input: C, _options: &Self::Options) -> C { unexpanded!(); } - fn execute_expand( + fn __expand_execute( context: &mut CubeContext, input: C::ExpandType, options: ::ExpandType, @@ -78,7 +78,7 @@ where let is_contiguous = tensor.is_contiguous(); if tensor.can_mut() && is_contiguous { - unary_kernel_launch::( + unary_kernel::launch::( client, cube_count, CubeDim::default(), @@ -104,7 +104,7 @@ where buffer, ); - unary_kernel_launch::( + unary_kernel::launch::( client, cube_count, CubeDim::default(), @@ -136,7 +136,7 @@ macro_rules! unary_op { type Options = (); #[allow(clippy::redundant_closure_call)] - fn execute_expand( + fn __expand_execute( context: &mut CubeContext, input: C::ExpandType, _options: ::ExpandType, @@ -152,7 +152,7 @@ macro_rules! unary_op { type Options = C; #[allow(clippy::redundant_closure_call)] - fn execute_expand( + fn __expand_execute( context: &mut CubeContext, input: C::ExpandType, scalar: C::ExpandType, diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index b53e1b4e3..128c1195d 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -334,7 +334,7 @@ where fn execute(input: C) -> C { C::exp(input) } - execute_expand::(context, input) + execute::__expand::(context, input) }) } @@ -344,7 +344,7 @@ where fn execute(input: C) -> C { C::log(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -354,7 +354,7 @@ where fn execute(input: C) -> C { C::log1p(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -367,7 +367,7 @@ where fn execute(input: C, scalar: C) -> C { C::powf(input, scalar) } - execute_expand::(context, tensor, scalar) + execute::__expand::(context, tensor, scalar) }) } @@ -377,7 +377,7 @@ where fn execute(input: C) -> C { C::sqrt(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -387,7 +387,7 @@ where fn execute(input: C) -> C { C::abs(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -397,7 +397,7 @@ where fn execute(input: C) -> C { C::cos(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -407,7 +407,7 @@ where fn execute(input: C) -> C { C::sin(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -417,7 +417,7 @@ where fn execute(input: C) -> C { C::tanh(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -427,7 +427,7 @@ where fn execute(input: C) -> C { C::erf(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } @@ -463,7 +463,7 @@ where fn execute(input: C) -> C { C::recip(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 0bb62474d..e5fdc50ec 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -299,7 +299,7 @@ where fn execute(input: C) -> C { C::abs(input) } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 1f101e243..a334ad9c6 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -26,7 +26,7 @@ pub fn full_device( let empty = empty_device(client, device, shape); #[cube(launch)] - pub(crate) fn full_kernel(tensor: &mut Tensor, value: C) { + pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { return; } @@ -42,7 +42,7 @@ pub fn full_device( SUBCUBE_DIM_APPROX, ); - full_kernel_launch::( + full_kernel::launch::( empty.client.clone(), cube_count, CubeDim::default(), @@ -127,7 +127,7 @@ pub fn add_scalar( fn execute(lhs: C, rhs: C) -> C { lhs + rhs } - execute_expand::(context, lhs, rhs) + execute::__expand::(context, lhs, rhs) }) } @@ -156,7 +156,7 @@ pub fn sub_scalar( fn execute(lhs: C, rhs: C) -> C { lhs - rhs } - execute_expand::(context, lhs, rhs) + execute::__expand::(context, lhs, rhs) }) } @@ -185,7 +185,7 @@ pub fn mul_scalar( fn execute(lhs: C, rhs: C) -> C { lhs * rhs } - execute_expand::(context, lhs, rhs) + execute::__expand::(context, lhs, rhs) }) } @@ -214,7 +214,7 @@ pub fn div_scalar( fn execute(lhs: C, rhs: C) -> C { lhs / rhs } - execute_expand::(context, lhs, rhs) + execute::__expand::(context, lhs, rhs) }) } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 5e7efa2fa..8fb678c28 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -146,7 +146,7 @@ where fn execute(input: C) -> C { input } - execute_expand::(context, tensor) + execute::__expand::(context, tensor) }) } diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 9854fd4e3..a4c7d8040 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -20,7 +20,7 @@ pub fn launch(device: &R::Device) { let input_handle = client.create(f32::as_bytes(input)); let output_handle = client.empty(input.len() * core::mem::size_of::()); - gelu_launch::( + gelu::launch::( client.clone(), CubeCount::Static(1, 1, 1), CubeDim::default(),