mirror of https://github.com/tracel-ai/burn.git
Feat: Support trait with CubeCL (#1980)
This commit is contained in:
parent
c9e9054167
commit
8af2b719a1
|
@ -0,0 +1 @@
|
|||
pub(crate) mod signature;
|
|
@ -0,0 +1,60 @@
|
|||
use quote::ToTokens;
|
||||
|
||||
use crate::tracker::VariableTracker;
|
||||
|
||||
pub fn expand_sig(
|
||||
sig: &syn::Signature,
|
||||
visibility: &syn::Visibility,
|
||||
mut variable_tracker: Option<&mut VariableTracker>,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &sig.inputs {
|
||||
match input {
|
||||
syn::FnArg::Typed(pat) => {
|
||||
let ident = pat.pat.clone();
|
||||
|
||||
if let syn::Pat::Ident(ident) = ident.as_ref() {
|
||||
if let Some(vars) = &mut variable_tracker {
|
||||
vars.codegen_declare(ident.ident.to_string(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
let ty = no_ref(pat.ty.as_ref());
|
||||
inputs.extend(quote::quote! {
|
||||
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
|
||||
});
|
||||
}
|
||||
_ => todo!("Only Typed inputs are supported"),
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = quote::quote!();
|
||||
|
||||
match &sig.output {
|
||||
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
|
||||
syn::ReturnType::Type(_, ty) => {
|
||||
let ty = no_ref(ty.as_ref());
|
||||
output.extend(quote::quote! {
|
||||
<#ty as burn_cube::frontend::CubeType>::ExpandType
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let ident = &sig.ident;
|
||||
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
|
||||
let generics = sig.generics.clone().into_token_stream();
|
||||
|
||||
quote::quote! {
|
||||
/// Expanded Cube function
|
||||
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
|
||||
}
|
||||
}
|
||||
|
||||
pub fn no_ref(ty: &syn::Type) -> &syn::Type {
|
||||
match ty {
|
||||
syn::Type::Reference(val) => &val.elem,
|
||||
_ => ty,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
use crate::codegen_common::signature::expand_sig;
|
||||
|
||||
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
|
||||
let mut expand_items = Vec::new();
|
||||
|
||||
for item in tr.items.iter() {
|
||||
match item {
|
||||
syn::TraitItem::Fn(func) => {
|
||||
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
|
||||
expand_items.push(syn::parse_quote!(#expand;));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
tr.items.append(&mut expand_items);
|
||||
|
||||
quote::quote! {
|
||||
#tr
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
|
||||
let mut expand_items = Vec::new();
|
||||
|
||||
for item in tr.items.iter() {
|
||||
match item {
|
||||
syn::ImplItem::Fn(func) => {
|
||||
let ident = &func.sig.ident;
|
||||
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &func.sig.inputs {
|
||||
match input {
|
||||
syn::FnArg::Typed(pat) => {
|
||||
let ident = pat.pat.clone();
|
||||
inputs.extend(quote::quote! {
|
||||
#ident,
|
||||
});
|
||||
}
|
||||
_ => todo!("Only Typed inputs are supported"),
|
||||
}
|
||||
}
|
||||
|
||||
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
|
||||
|
||||
let tokens = if !tr.generics.params.is_empty() {
|
||||
let mut func = func.clone();
|
||||
for param in tr.generics.params.iter() {
|
||||
func.sig.generics.params.push(param.clone());
|
||||
}
|
||||
register_expand(&func, &ident, expand, inputs)
|
||||
} else {
|
||||
register_expand(func, &ident, expand, inputs)
|
||||
};
|
||||
|
||||
expand_items.push(syn::parse2(tokens).unwrap());
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
tr.items.append(&mut expand_items);
|
||||
|
||||
quote::quote! {
|
||||
#tr
|
||||
}
|
||||
}
|
||||
|
||||
fn register_expand(
|
||||
func: &syn::ImplItemFn,
|
||||
name: &syn::Ident,
|
||||
expand: proc_macro2::TokenStream,
|
||||
inputs: proc_macro2::TokenStream,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let (func, func_expand) = if func.sig.generics.params.is_empty() {
|
||||
(
|
||||
quote::quote! { #func },
|
||||
quote::quote! {
|
||||
#name(context, #inputs)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let (_, gen, _) = &func.sig.generics.split_for_impl();
|
||||
(
|
||||
quote::quote! { #func },
|
||||
quote::quote! {
|
||||
#name::#gen(context, #inputs)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
quote::quote! (
|
||||
#expand {
|
||||
#[cube]
|
||||
#func
|
||||
#func_expand
|
||||
}
|
||||
)
|
||||
}
|
|
@ -3,14 +3,18 @@ extern crate derive_new;
|
|||
|
||||
mod analyzer;
|
||||
mod codegen_function;
|
||||
mod codegen_trait;
|
||||
mod codegen_type;
|
||||
mod tracker;
|
||||
|
||||
pub(crate) mod codegen_common;
|
||||
|
||||
use analyzer::VariableAnalyzer;
|
||||
use codegen_common::signature::expand_sig;
|
||||
use codegen_function::{codegen_launch, codegen_statement};
|
||||
use codegen_trait::{expand_trait_def, expand_trait_impl};
|
||||
use codegen_type::generate_cube_type;
|
||||
use proc_macro::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta};
|
||||
use tracker::VariableTracker;
|
||||
|
||||
|
@ -38,20 +42,36 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
|
|||
generate_cube_type(&input, false)
|
||||
}
|
||||
|
||||
struct SupportedAttributes {
|
||||
mode: CubeMode,
|
||||
launch: bool,
|
||||
}
|
||||
|
||||
/// Derive macro for the module.
|
||||
#[proc_macro_attribute]
|
||||
pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
||||
let args = parse_macro_input!(attr with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
|
||||
let (mode, launch) = parse_attributes(&args);
|
||||
let attrs = parse_attributes(&args);
|
||||
|
||||
let func: syn::ItemFn =
|
||||
syn::parse(tokens).expect("Cube annotations only supported for functions");
|
||||
let code: TokenStream = match syn::parse::<syn::Item>(tokens).unwrap() {
|
||||
syn::Item::Fn(func) => cube_fn(func, &attrs),
|
||||
syn::Item::Impl(item) => expand_trait_impl(item).into(),
|
||||
syn::Item::Trait(item) => expand_trait_def(item).into(),
|
||||
_ => panic!("Cube annotations only supported for functions"),
|
||||
};
|
||||
|
||||
match attrs.mode {
|
||||
CubeMode::Default => code,
|
||||
CubeMode::Debug => panic!("{code}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
|
||||
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
||||
|
||||
let code: TokenStream = match codegen_cube(&func, &mut variable_tracker) {
|
||||
match codegen_cube(&func, &mut variable_tracker) {
|
||||
Ok(code) => {
|
||||
if launch {
|
||||
if attrs.launch {
|
||||
let launch = codegen_launch(&func.sig);
|
||||
|
||||
quote::quote! {
|
||||
|
@ -64,15 +84,10 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
|||
}
|
||||
}
|
||||
Err(err) => err.into(),
|
||||
};
|
||||
|
||||
match mode {
|
||||
CubeMode::Default => code,
|
||||
CubeMode::Debug => panic!("{code}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
|
||||
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
|
||||
let mut mode = CubeMode::Default;
|
||||
let mut launch = false;
|
||||
|
||||
|
@ -98,7 +113,7 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
|
|||
}
|
||||
}
|
||||
|
||||
(mode, launch)
|
||||
SupportedAttributes { mode, launch }
|
||||
}
|
||||
|
||||
/// Generate the expanded version of a function marked with the cube macro
|
||||
|
@ -106,7 +121,7 @@ fn codegen_cube(
|
|||
func: &syn::ItemFn,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
|
||||
let signature = expand_sig(&func.sig, &func.vis, variable_tracker);
|
||||
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
|
||||
let mut body = quote::quote! {};
|
||||
|
||||
for statement in func.block.stmts.iter() {
|
||||
|
@ -145,58 +160,3 @@ fn codegen_cube(
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn expand_sig(
|
||||
sig: &syn::Signature,
|
||||
visibility: &syn::Visibility,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &sig.inputs {
|
||||
match input {
|
||||
syn::FnArg::Typed(pat) => {
|
||||
let ident = pat.pat.clone();
|
||||
|
||||
if let syn::Pat::Ident(ident) = ident.as_ref() {
|
||||
variable_tracker.codegen_declare(ident.ident.to_string(), 0);
|
||||
}
|
||||
|
||||
let ty = no_ref(pat.ty.as_ref());
|
||||
inputs.extend(quote::quote! {
|
||||
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
|
||||
});
|
||||
}
|
||||
_ => todo!("Only Typed inputs are supported"),
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = quote::quote!();
|
||||
|
||||
match &sig.output {
|
||||
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
|
||||
syn::ReturnType::Type(_, ty) => {
|
||||
let ty = no_ref(ty.as_ref());
|
||||
output.extend(quote::quote! {
|
||||
<#ty as burn_cube::frontend::CubeType>::ExpandType
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let ident = &sig.ident;
|
||||
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
|
||||
let generics = sig.generics.clone().into_token_stream();
|
||||
|
||||
quote::quote! {
|
||||
/// Expanded Cube function
|
||||
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
|
||||
}
|
||||
}
|
||||
|
||||
fn no_ref(ty: &syn::Type) -> &syn::Type {
|
||||
match ty {
|
||||
syn::Type::Reference(val) => &val.elem,
|
||||
_ => ty,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
trait FunctionGeneric {
|
||||
#[allow(unused)]
|
||||
fn test<C: Float>(lhs: C, rhs: C) -> C;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
trait TraitGeneric<C: Float> {
|
||||
#[allow(unused)]
|
||||
fn test(lhs: C, rhs: C) -> C;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
trait CombinedTraitFunctionGeneric<C: Float> {
|
||||
#[allow(unused)]
|
||||
fn test<O: Numeric>(lhs: C, rhs: C) -> O;
|
||||
}
|
||||
|
||||
struct Test;
|
||||
|
||||
#[cube]
|
||||
impl FunctionGeneric for Test {
|
||||
fn test<C: Float>(lhs: C, rhs: C) -> C {
|
||||
lhs + rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<C: Float> TraitGeneric<C> for Test {
|
||||
fn test(lhs: C, rhs: C) -> C {
|
||||
lhs + rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<C: Float> CombinedTraitFunctionGeneric<C> for Test {
|
||||
fn test<O: Numeric>(lhs: C, rhs: C) -> O {
|
||||
O::cast_from(lhs + rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn simple<C: Float>(lhs: C, rhs: C) -> C {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
|
||||
O::cast_from(lhs + rhs)
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use burn_cube::ir::{Item, Scope};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_function_generic() {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as FunctionGeneric>::test_expand::<F32>(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(simple_scope(), context.into_scope());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trait_generic() {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as TraitGeneric<F32>>::test_expand(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(simple_scope(), context.into_scope());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_combined_function_generic() {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as CombinedTraitFunctionGeneric<F32>>::test_expand::<UInt>(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(with_cast_scope(), context.into_scope());
|
||||
}
|
||||
|
||||
fn simple_scope() -> Scope {
|
||||
let mut context_ref = CubeContext::root();
|
||||
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
simple_expand::<F32>(&mut context_ref, lhs, rhs);
|
||||
context_ref.into_scope()
|
||||
}
|
||||
|
||||
fn with_cast_scope() -> Scope {
|
||||
let mut context_ref = CubeContext::root();
|
||||
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
with_cast_expand::<F32, UInt>(&mut context_ref, lhs, rhs);
|
||||
context_ref.into_scope()
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ mod assign;
|
|||
mod cast_elem;
|
||||
mod cast_kind;
|
||||
mod comptime;
|
||||
mod cube_trait;
|
||||
mod for_loop;
|
||||
mod function_call;
|
||||
mod generic_kernel;
|
||||
|
|
|
@ -1,22 +1,56 @@
|
|||
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
|
||||
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
|
||||
use burn_cube::{
|
||||
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, unexpanded, Runtime,
|
||||
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,
|
||||
SUBCUBE_DIM_APPROX,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
#[cube]
|
||||
pub(crate) trait ComparisonOp<C: Numeric>: 'static + Send + Sync {
|
||||
/// Execute a comparison operation.
|
||||
#[allow(unused_variables)]
|
||||
fn execute(lhs: C, rhs: C) -> bool {
|
||||
unexpanded!();
|
||||
fn execute(lhs: C, rhs: C) -> bool;
|
||||
}
|
||||
|
||||
struct EqualOp;
|
||||
struct GreaterEqualOp;
|
||||
struct LowerEqualOp;
|
||||
struct GreaterOp;
|
||||
struct LowerOp;
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for EqualOp {
|
||||
fn execute(lhs: N, rhs: N) -> bool {
|
||||
lhs == rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
|
||||
fn execute(lhs: N, rhs: N) -> bool {
|
||||
lhs >= rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
|
||||
fn execute(lhs: N, rhs: N) -> bool {
|
||||
lhs <= rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
|
||||
fn execute(lhs: N, rhs: N) -> bool {
|
||||
lhs > rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerOp {
|
||||
fn execute(lhs: N, rhs: N) -> bool {
|
||||
lhs < rhs
|
||||
}
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: C::ExpandType,
|
||||
rhs: C::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType;
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
|
@ -271,82 +305,6 @@ pub(crate) fn launch_scalar_cmp<
|
|||
}
|
||||
}
|
||||
|
||||
struct EqualOp;
|
||||
struct GreaterEqualOp;
|
||||
struct LowerEqualOp;
|
||||
struct GreaterOp;
|
||||
struct LowerOp;
|
||||
|
||||
impl<N: Numeric> ComparisonOp<N> for EqualOp {
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: <N>::ExpandType,
|
||||
rhs: <N>::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType {
|
||||
#[cube]
|
||||
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
|
||||
lhs == rhs
|
||||
}
|
||||
cmp_expand::<N>(context, lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: <N>::ExpandType,
|
||||
rhs: <N>::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType {
|
||||
#[cube]
|
||||
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
|
||||
lhs >= rhs
|
||||
}
|
||||
cmp_expand::<N>(context, lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: <N>::ExpandType,
|
||||
rhs: <N>::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType {
|
||||
#[cube]
|
||||
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
|
||||
lhs <= rhs
|
||||
}
|
||||
cmp_expand::<N>(context, lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: <N>::ExpandType,
|
||||
rhs: <N>::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType {
|
||||
#[cube]
|
||||
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
|
||||
lhs > rhs
|
||||
}
|
||||
cmp_expand::<N>(context, lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Numeric> ComparisonOp<N> for LowerOp {
|
||||
fn execute_expand(
|
||||
context: &mut CubeContext,
|
||||
lhs: <N>::ExpandType,
|
||||
rhs: <N>::ExpandType,
|
||||
) -> <bool as CubeType>::ExpandType {
|
||||
#[cube]
|
||||
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
|
||||
lhs < rhs
|
||||
}
|
||||
cmp_expand::<N>(context, lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn equal<R: JitRuntime, E: JitElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
|
|
Loading…
Reference in New Issue