Feat: Support trait with CubeCL (#1980)

This commit is contained in:
Nathaniel Simard 2024-07-07 10:07:51 -04:00 committed by GitHub
parent c9e9054167
commit 8af2b719a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 341 additions and 154 deletions

View File

@ -0,0 +1 @@
pub(crate) mod signature;

View File

@ -0,0 +1,60 @@
use quote::ToTokens;
use crate::tracker::VariableTracker;
pub fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
mut variable_tracker: Option<&mut VariableTracker>,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();
for input in &sig.inputs {
match input {
syn::FnArg::Typed(pat) => {
let ident = pat.pat.clone();
if let syn::Pat::Ident(ident) = ident.as_ref() {
if let Some(vars) = &mut variable_tracker {
vars.codegen_declare(ident.ident.to_string(), 0);
}
}
let ty = no_ref(pat.ty.as_ref());
inputs.extend(quote::quote! {
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
});
}
_ => todo!("Only Typed inputs are supported"),
}
}
let mut output = quote::quote!();
match &sig.output {
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
syn::ReturnType::Type(_, ty) => {
let ty = no_ref(ty.as_ref());
output.extend(quote::quote! {
<#ty as burn_cube::frontend::CubeType>::ExpandType
});
}
}
let ident = &sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let generics = sig.generics.clone().into_token_stream();
quote::quote! {
/// Expanded Cube function
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
}
}
pub fn no_ref(ty: &syn::Type) -> &syn::Type {
match ty {
syn::Type::Reference(val) => &val.elem,
_ => ty,
}
}

View File

@ -0,0 +1,98 @@
use crate::codegen_common::signature::expand_sig;
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
let mut expand_items = Vec::new();
for item in tr.items.iter() {
match item {
syn::TraitItem::Fn(func) => {
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
expand_items.push(syn::parse_quote!(#expand;));
}
_ => continue,
}
}
tr.items.append(&mut expand_items);
quote::quote! {
#tr
}
}
pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
let mut expand_items = Vec::new();
for item in tr.items.iter() {
match item {
syn::ImplItem::Fn(func) => {
let ident = &func.sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let mut inputs = quote::quote!();
for input in &func.sig.inputs {
match input {
syn::FnArg::Typed(pat) => {
let ident = pat.pat.clone();
inputs.extend(quote::quote! {
#ident,
});
}
_ => todo!("Only Typed inputs are supported"),
}
}
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
let tokens = if !tr.generics.params.is_empty() {
let mut func = func.clone();
for param in tr.generics.params.iter() {
func.sig.generics.params.push(param.clone());
}
register_expand(&func, &ident, expand, inputs)
} else {
register_expand(func, &ident, expand, inputs)
};
expand_items.push(syn::parse2(tokens).unwrap());
}
_ => continue,
}
}
tr.items.append(&mut expand_items);
quote::quote! {
#tr
}
}
fn register_expand(
func: &syn::ImplItemFn,
name: &syn::Ident,
expand: proc_macro2::TokenStream,
inputs: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
let (func, func_expand) = if func.sig.generics.params.is_empty() {
(
quote::quote! { #func },
quote::quote! {
#name(context, #inputs)
},
)
} else {
let (_, gen, _) = &func.sig.generics.split_for_impl();
(
quote::quote! { #func },
quote::quote! {
#name::#gen(context, #inputs)
},
)
};
quote::quote! (
#expand {
#[cube]
#func
#func_expand
}
)
}

View File

@ -3,14 +3,18 @@ extern crate derive_new;
mod analyzer;
mod codegen_function;
mod codegen_trait;
mod codegen_type;
mod tracker;
pub(crate) mod codegen_common;
use analyzer::VariableAnalyzer;
use codegen_common::signature::expand_sig;
use codegen_function::{codegen_launch, codegen_statement};
use codegen_trait::{expand_trait_def, expand_trait_impl};
use codegen_type::generate_cube_type;
use proc_macro::TokenStream;
use quote::ToTokens;
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta};
use tracker::VariableTracker;
@ -38,20 +42,36 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
generate_cube_type(&input, false)
}
struct SupportedAttributes {
mode: CubeMode,
launch: bool,
}
/// Derive macro for the module.
#[proc_macro_attribute]
pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
let (mode, launch) = parse_attributes(&args);
let attrs = parse_attributes(&args);
let func: syn::ItemFn =
syn::parse(tokens).expect("Cube annotations only supported for functions");
let code: TokenStream = match syn::parse::<syn::Item>(tokens).unwrap() {
syn::Item::Fn(func) => cube_fn(func, &attrs),
syn::Item::Impl(item) => expand_trait_impl(item).into(),
syn::Item::Trait(item) => expand_trait_def(item).into(),
_ => panic!("Cube annotations only supported for functions"),
};
match attrs.mode {
CubeMode::Default => code,
CubeMode::Debug => panic!("{code}"),
}
}
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
let code: TokenStream = match codegen_cube(&func, &mut variable_tracker) {
match codegen_cube(&func, &mut variable_tracker) {
Ok(code) => {
if launch {
if attrs.launch {
let launch = codegen_launch(&func.sig);
quote::quote! {
@ -64,15 +84,10 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
}
}
Err(err) => err.into(),
};
match mode {
CubeMode::Default => code,
CubeMode::Debug => panic!("{code}"),
}
}
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
let mut mode = CubeMode::Default;
let mut launch = false;
@ -98,7 +113,7 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
}
}
(mode, launch)
SupportedAttributes { mode, launch }
}
/// Generate the expanded version of a function marked with the cube macro
@ -106,7 +121,7 @@ fn codegen_cube(
func: &syn::ItemFn,
variable_tracker: &mut VariableTracker,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
let signature = expand_sig(&func.sig, &func.vis, variable_tracker);
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
let mut body = quote::quote! {};
for statement in func.block.stmts.iter() {
@ -145,58 +160,3 @@ fn codegen_cube(
}
})
}
fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
variable_tracker: &mut VariableTracker,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();
for input in &sig.inputs {
match input {
syn::FnArg::Typed(pat) => {
let ident = pat.pat.clone();
if let syn::Pat::Ident(ident) = ident.as_ref() {
variable_tracker.codegen_declare(ident.ident.to_string(), 0);
}
let ty = no_ref(pat.ty.as_ref());
inputs.extend(quote::quote! {
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
});
}
_ => todo!("Only Typed inputs are supported"),
}
}
let mut output = quote::quote!();
match &sig.output {
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
syn::ReturnType::Type(_, ty) => {
let ty = no_ref(ty.as_ref());
output.extend(quote::quote! {
<#ty as burn_cube::frontend::CubeType>::ExpandType
});
}
}
let ident = &sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let generics = sig.generics.clone().into_token_stream();
quote::quote! {
/// Expanded Cube function
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
}
}
fn no_ref(ty: &syn::Type) -> &syn::Type {
match ty {
syn::Type::Reference(val) => &val.elem,
_ => ty,
}
}

View File

@ -0,0 +1,109 @@
use burn_cube::prelude::*;
#[cube]
trait FunctionGeneric {
#[allow(unused)]
fn test<C: Float>(lhs: C, rhs: C) -> C;
}
#[cube]
trait TraitGeneric<C: Float> {
#[allow(unused)]
fn test(lhs: C, rhs: C) -> C;
}
#[cube]
trait CombinedTraitFunctionGeneric<C: Float> {
#[allow(unused)]
fn test<O: Numeric>(lhs: C, rhs: C) -> O;
}
struct Test;
#[cube]
impl FunctionGeneric for Test {
fn test<C: Float>(lhs: C, rhs: C) -> C {
lhs + rhs
}
}
#[cube]
impl<C: Float> TraitGeneric<C> for Test {
fn test(lhs: C, rhs: C) -> C {
lhs + rhs
}
}
#[cube]
impl<C: Float> CombinedTraitFunctionGeneric<C> for Test {
fn test<O: Numeric>(lhs: C, rhs: C) -> O {
O::cast_from(lhs + rhs)
}
}
#[cube]
fn simple<C: Float>(lhs: C, rhs: C) -> C {
lhs + rhs
}
#[cube]
fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
O::cast_from(lhs + rhs)
}
mod tests {
use burn_cube::ir::{Item, Scope};
use super::*;
#[test]
fn test_function_generic() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as FunctionGeneric>::test_expand::<F32>(&mut context, lhs, rhs);
assert_eq!(simple_scope(), context.into_scope());
}
#[test]
fn test_trait_generic() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as TraitGeneric<F32>>::test_expand(&mut context, lhs, rhs);
assert_eq!(simple_scope(), context.into_scope());
}
#[test]
fn test_combined_function_generic() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as CombinedTraitFunctionGeneric<F32>>::test_expand::<UInt>(&mut context, lhs, rhs);
assert_eq!(with_cast_scope(), context.into_scope());
}
fn simple_scope() -> Scope {
let mut context_ref = CubeContext::root();
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
simple_expand::<F32>(&mut context_ref, lhs, rhs);
context_ref.into_scope()
}
fn with_cast_scope() -> Scope {
let mut context_ref = CubeContext::root();
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
with_cast_expand::<F32, UInt>(&mut context_ref, lhs, rhs);
context_ref.into_scope()
}
}

View File

@ -3,6 +3,7 @@ mod assign;
mod cast_elem;
mod cast_kind;
mod comptime;
mod cube_trait;
mod for_loop;
mod function_call;
mod generic_kernel;

View File

@ -1,22 +1,56 @@
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_cube::{
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, unexpanded, Runtime,
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,
SUBCUBE_DIM_APPROX,
};
use burn_tensor::Shape;
#[cube]
pub(crate) trait ComparisonOp<C: Numeric>: 'static + Send + Sync {
/// Execute a comparison operation.
#[allow(unused_variables)]
fn execute(lhs: C, rhs: C) -> bool {
unexpanded!();
fn execute(lhs: C, rhs: C) -> bool;
}
struct EqualOp;
struct GreaterEqualOp;
struct LowerEqualOp;
struct GreaterOp;
struct LowerOp;
#[cube]
impl<N: Numeric> ComparisonOp<N> for EqualOp {
fn execute(lhs: N, rhs: N) -> bool {
lhs == rhs
}
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
fn execute(lhs: N, rhs: N) -> bool {
lhs >= rhs
}
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
fn execute(lhs: N, rhs: N) -> bool {
lhs <= rhs
}
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
fn execute(lhs: N, rhs: N) -> bool {
lhs > rhs
}
}
#[cube]
impl<N: Numeric> ComparisonOp<N> for LowerOp {
fn execute(lhs: N, rhs: N) -> bool {
lhs < rhs
}
fn execute_expand(
context: &mut CubeContext,
lhs: C::ExpandType,
rhs: C::ExpandType,
) -> <bool as CubeType>::ExpandType;
}
#[cube(launch)]
@ -271,82 +305,6 @@ pub(crate) fn launch_scalar_cmp<
}
}
struct EqualOp;
struct GreaterEqualOp;
struct LowerEqualOp;
struct GreaterOp;
struct LowerOp;
impl<N: Numeric> ComparisonOp<N> for EqualOp {
fn execute_expand(
context: &mut CubeContext,
lhs: <N>::ExpandType,
rhs: <N>::ExpandType,
) -> <bool as CubeType>::ExpandType {
#[cube]
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
lhs == rhs
}
cmp_expand::<N>(context, lhs, rhs)
}
}
impl<N: Numeric> ComparisonOp<N> for GreaterEqualOp {
fn execute_expand(
context: &mut CubeContext,
lhs: <N>::ExpandType,
rhs: <N>::ExpandType,
) -> <bool as CubeType>::ExpandType {
#[cube]
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
lhs >= rhs
}
cmp_expand::<N>(context, lhs, rhs)
}
}
impl<N: Numeric> ComparisonOp<N> for LowerEqualOp {
fn execute_expand(
context: &mut CubeContext,
lhs: <N>::ExpandType,
rhs: <N>::ExpandType,
) -> <bool as CubeType>::ExpandType {
#[cube]
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
lhs <= rhs
}
cmp_expand::<N>(context, lhs, rhs)
}
}
impl<N: Numeric> ComparisonOp<N> for GreaterOp {
fn execute_expand(
context: &mut CubeContext,
lhs: <N>::ExpandType,
rhs: <N>::ExpandType,
) -> <bool as CubeType>::ExpandType {
#[cube]
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
lhs > rhs
}
cmp_expand::<N>(context, lhs, rhs)
}
}
impl<N: Numeric> ComparisonOp<N> for LowerOp {
fn execute_expand(
context: &mut CubeContext,
lhs: <N>::ExpandType,
rhs: <N>::ExpandType,
) -> <bool as CubeType>::ExpandType {
#[cube]
fn cmp<N: Numeric>(lhs: N, rhs: N) -> bool {
lhs < rhs
}
cmp_expand::<N>(context, lhs, rhs)
}
}
pub fn equal<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,