Cube: CubeType (no launch) and Comptime::map (#1853)

This commit is contained in:
Louis Fortier-Dubois 2024-06-04 13:43:43 -04:00 committed by GitHub
parent a5af19b959
commit c42abadfe9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 209 additions and 69 deletions

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use syn::{Member, PathArguments, Stmt};
use syn::{Member, Pat, PathArguments, Stmt};
use crate::variable_key::VariableKey;
@ -310,12 +310,25 @@ impl CodeAnalysisBuilder {
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Closure(expr) => {
assert!(
expr.inputs.is_empty(),
"Analysis: closure with args not supported"
);
let depth = depth + 1;
self.find_occurrences_in_expr(&expr.body, depth + 1)
for path in expr.inputs.iter() {
let ident = match path {
Pat::Ident(pat_ident) => &pat_ident.ident,
Pat::Type(pat_type) => {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
&pat_ident.ident
} else {
todo!("Analysis: {:?} not supported in closure inputs. ", path);
}
}
_ => todo!("Analysis: {:?} not supported in closure inputs. ", path),
};
self.declarations.push(((ident).into(), depth));
}
self.find_occurrences_in_expr(&expr.body, depth)
}
syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Field(expr) => {

View File

@ -32,13 +32,28 @@ pub(crate) fn codegen_closure(
) -> TokenStream {
let mut inputs = quote::quote! {};
for input in closure.inputs.iter() {
let ident = match input {
syn::Pat::Ident(ident) => &ident.ident,
let (ident, ty) = match input {
syn::Pat::Ident(ident) => (&ident.ident, None),
syn::Pat::Type(pat_type) => (
if let syn::Pat::Ident(ident) = &*pat_type.pat {
&ident.ident
} else {
panic!("Codegen: Unsupported {:?}", input);
},
Some(pat_type.ty.clone()),
),
_ => panic!("Codegen: Unsupported {:?}", input),
};
inputs.extend(quote::quote! {
#ident,
});
if let Some(ty) = ty {
inputs.extend(quote::quote! {
#ident : #ty,
});
} else {
inputs.extend(quote::quote! {
#ident,
});
}
}
let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses);
@ -124,6 +139,14 @@ pub(crate) fn parse_function_call(
let code = call.args.first().unwrap();
quote::quote! {#code}
}
"map" => {
let args = codegen_args(&call.args, loop_level, variable_analyses);
// Codegen
quote::quote! {
Comptime::map_expand(#args)
}
}
"unwrap_or_else" => {
let args = codegen_args(&call.args, loop_level, variable_analyses);

View File

@ -176,7 +176,7 @@ impl TypeCodegen {
}
}
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
let name = ast.ident.clone();
let generics = ast.generics.clone();
let name_string = name.to_string();
@ -210,14 +210,22 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
let arg_settings_impl = codegen.arg_settings_impl();
let launch_arg_impl = codegen.launch_arg_impl();
quote! {
#expand_ty
#launch_ty
#launch_new
if with_launch {
quote! {
#expand_ty
#launch_ty
#launch_new
#cube_type_impl
#arg_settings_impl
#launch_arg_impl
#cube_type_impl
#arg_settings_impl
#launch_arg_impl
}
.into()
} else {
quote! {
#expand_ty
#cube_type_impl
}
.into()
}
.into()
}

View File

@ -18,12 +18,20 @@ enum CubeMode {
Debug,
}
// Derive macro to define a cube type.
#[proc_macro_derive(Cube)]
pub fn module_derive(input: TokenStream) -> TokenStream {
// Derive macro to define a cube type that is launched with a kernel
#[proc_macro_derive(CubeLaunch)]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
generate_cube_type(&input)
generate_cube_type(&input, true)
}
// Derive macro to define a cube type that is not launched
#[proc_macro_derive(CubeType)]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
generate_cube_type(&input, false)
}
/// Derive macro for the module.

View File

@ -25,6 +25,18 @@ impl<T> Comptime<T> {
pub fn get(_comptime: Self) -> T {
unexpanded!()
}
pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
unexpanded!()
}
pub fn map_expand<R, F: Fn(&mut CubeContext, T) -> R>(
context: &mut CubeContext,
inner: T,
closure: F,
) -> R {
closure(context, inner)
}
}
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {

View File

@ -21,7 +21,8 @@ pub use pod::*;
pub use runtime::*;
pub use burn_cube_macros::cube;
pub use burn_cube_macros::Cube;
pub use burn_cube_macros::CubeLaunch;
pub use burn_cube_macros::CubeType;
/// An approximation of the subcube dimension.
pub const SUBCUBE_DIM_APPROX: usize = 16;

View File

@ -1,4 +1,4 @@
pub use crate::{cube, Cube, Kernel, RuntimeArg};
pub use crate::{cube, CubeLaunch, CubeType, Kernel, RuntimeArg};
pub use crate::codegen::{KernelExpansion, KernelIntegrator, KernelSettings};
pub use crate::compute::{

View File

@ -1,12 +1,9 @@
use burn_cube::prelude::*;
#[cube]
pub fn if_then_else<F: Float>(lhs: F) {
if lhs < F::from_int(0) {
let _ = lhs + F::from_int(4);
} else {
let _ = lhs - F::from_int(5);
}
#[derive(Clone)]
pub struct State {
cond: bool,
bound: u32,
}
#[cube]
@ -18,29 +15,41 @@ pub fn comptime_if_else<T: Numeric>(lhs: T, cond: Comptime<bool>) {
}
}
#[cube]
pub fn comptime_with_map_bool<T: Numeric>(state: Comptime<State>) -> T {
let cond = Comptime::map(state, |s: State| s.cond);
let mut x = T::from_int(3);
if Comptime::get(cond) {
x += T::from_int(4);
} else {
x -= T::from_int(4);
}
x
}
#[cube]
pub fn comptime_with_map_uint<T: Numeric>(state: Comptime<State>) -> T {
let bound = Comptime::map(state, |s: State| s.bound);
let mut x = T::from_int(3);
for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) {
x += T::from_int(4);
}
x
}
mod tests {
use super::*;
use burn_cube::{
cpa,
frontend::{CubeContext, CubeElem, F32},
ir::{Elem, Item, Variable},
ir::{Item, Variable},
};
use super::{comptime_if_else_expand, if_then_else_expand};
type ElemType = F32;
#[test]
fn cube_if_else_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_then_else_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}
#[test]
fn cube_comptime_if_test() {
let mut context = CubeContext::root();
@ -71,24 +80,46 @@ mod tests {
);
}
fn inline_macro_ref() -> String {
#[test]
fn cube_comptime_map_bool_test() {
let mut context1 = CubeContext::root();
let mut context2 = CubeContext::root();
let comptime_state_true = State {
cond: true,
bound: 4,
};
let comptime_state_false = State {
cond: false,
bound: 4,
};
comptime_with_map_bool_expand::<ElemType>(&mut context1, comptime_state_true);
comptime_with_map_bool_expand::<ElemType>(&mut context2, comptime_state_false);
let scope1 = context1.into_scope();
let scope2 = context2.into_scope();
assert_ne!(
format!("{:?}", scope1.operations),
format!("{:?}", scope2.operations)
);
}
#[test]
fn cube_comptime_map_uint_test() {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);
let comptime_state = State {
cond: true,
bound: 4,
};
cpa!(scope, cond = lhs < 0f32);
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
}).else(|scope|{
cpa!(scope, y = lhs - 5.0f32);
}));
comptime_with_map_uint_expand::<ElemType>(&mut context, comptime_state);
format!("{:?}", scope.operations)
let scope = context.into_scope();
assert!(!format!("{:?}", scope.operations).contains("RangeLoop"));
}
fn inline_macro_ref_comptime(cond: bool) -> String {

View File

@ -1,4 +1,4 @@
use burn_cube::{cube, frontend::Numeric};
use burn_cube::prelude::*;
#[cube]
pub fn if_greater<T: Numeric>(lhs: T) {
@ -15,6 +15,15 @@ pub fn if_greater_var<T: Numeric>(lhs: T) {
}
}
#[cube]
pub fn if_then_else<F: Float>(lhs: F) {
if lhs < F::from_int(0) {
let _ = lhs + F::from_int(4);
} else {
let _ = lhs - F::from_int(5);
}
}
mod tests {
use burn_cube::{
cpa,
@ -22,7 +31,7 @@ mod tests {
ir::{Elem, Item, Variable},
};
use super::if_greater_expand;
use super::*;
type ElemType = F32;
@ -35,10 +44,25 @@ mod tests {
if_greater_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if());
}
fn inline_macro_ref() -> String {
#[test]
fn cube_if_else_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_then_else_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_if_else()
);
}
fn inline_macro_ref_if() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
@ -55,4 +79,24 @@ mod tests {
format!("{:?}", scope.operations)
}
fn inline_macro_ref_if_else() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);
cpa!(scope, cond = lhs < 0f32);
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
}).else(|scope|{
cpa!(scope, y = lhs - 5.0f32);
}));
format!("{:?}", scope.operations)
}
}

View File

@ -1,10 +1,10 @@
mod cast_elem;
mod cast_kind;
mod comptime;
mod for_loop;
mod function_call;
mod generic_kernel;
mod r#if;
mod if_else;
mod literal;
mod r#loop;
mod module_import;

View File

@ -1,6 +1,6 @@
use burn_cube::prelude::*;
#[derive(Cube)]
#[derive(CubeType)]
struct State<T: Numeric> {
first: T,
second: T,

View File

@ -15,7 +15,7 @@ use crate::{
FloatElement, JitRuntime,
};
#[derive(Cube)]
#[derive(CubeLaunch)]
struct Conv2dArgs {
conv_stride_0: UInt,
conv_stride_1: UInt,