mirror of https://github.com/tracel-ai/burn.git
Cube: CubeType (no launch) and Comptime::map (#1853)
This commit is contained in:
parent
a5af19b959
commit
c42abadfe9
|
@ -1,6 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use syn::{Member, PathArguments, Stmt};
|
||||
use syn::{Member, Pat, PathArguments, Stmt};
|
||||
|
||||
use crate::variable_key::VariableKey;
|
||||
|
||||
|
@ -310,12 +310,25 @@ impl CodeAnalysisBuilder {
|
|||
}
|
||||
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Closure(expr) => {
|
||||
assert!(
|
||||
expr.inputs.is_empty(),
|
||||
"Analysis: closure with args not supported"
|
||||
);
|
||||
let depth = depth + 1;
|
||||
|
||||
self.find_occurrences_in_expr(&expr.body, depth + 1)
|
||||
for path in expr.inputs.iter() {
|
||||
let ident = match path {
|
||||
Pat::Ident(pat_ident) => &pat_ident.ident,
|
||||
Pat::Type(pat_type) => {
|
||||
if let Pat::Ident(pat_ident) = &*pat_type.pat {
|
||||
&pat_ident.ident
|
||||
} else {
|
||||
todo!("Analysis: {:?} not supported in closure inputs. ", path);
|
||||
}
|
||||
}
|
||||
_ => todo!("Analysis: {:?} not supported in closure inputs. ", path),
|
||||
};
|
||||
|
||||
self.declarations.push(((ident).into(), depth));
|
||||
}
|
||||
|
||||
self.find_occurrences_in_expr(&expr.body, depth)
|
||||
}
|
||||
syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Field(expr) => {
|
||||
|
|
|
@ -32,13 +32,28 @@ pub(crate) fn codegen_closure(
|
|||
) -> TokenStream {
|
||||
let mut inputs = quote::quote! {};
|
||||
for input in closure.inputs.iter() {
|
||||
let ident = match input {
|
||||
syn::Pat::Ident(ident) => &ident.ident,
|
||||
let (ident, ty) = match input {
|
||||
syn::Pat::Ident(ident) => (&ident.ident, None),
|
||||
syn::Pat::Type(pat_type) => (
|
||||
if let syn::Pat::Ident(ident) = &*pat_type.pat {
|
||||
&ident.ident
|
||||
} else {
|
||||
panic!("Codegen: Unsupported {:?}", input);
|
||||
},
|
||||
Some(pat_type.ty.clone()),
|
||||
),
|
||||
_ => panic!("Codegen: Unsupported {:?}", input),
|
||||
};
|
||||
inputs.extend(quote::quote! {
|
||||
#ident,
|
||||
});
|
||||
|
||||
if let Some(ty) = ty {
|
||||
inputs.extend(quote::quote! {
|
||||
#ident : #ty,
|
||||
});
|
||||
} else {
|
||||
inputs.extend(quote::quote! {
|
||||
#ident,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses);
|
||||
|
@ -124,6 +139,14 @@ pub(crate) fn parse_function_call(
|
|||
let code = call.args.first().unwrap();
|
||||
quote::quote! {#code}
|
||||
}
|
||||
"map" => {
|
||||
let args = codegen_args(&call.args, loop_level, variable_analyses);
|
||||
|
||||
// Codegen
|
||||
quote::quote! {
|
||||
Comptime::map_expand(#args)
|
||||
}
|
||||
}
|
||||
"unwrap_or_else" => {
|
||||
let args = codegen_args(&call.args, loop_level, variable_analyses);
|
||||
|
||||
|
|
|
@ -176,7 +176,7 @@ impl TypeCodegen {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
|
||||
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
|
||||
let name = ast.ident.clone();
|
||||
let generics = ast.generics.clone();
|
||||
let name_string = name.to_string();
|
||||
|
@ -210,14 +210,22 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
|
|||
let arg_settings_impl = codegen.arg_settings_impl();
|
||||
let launch_arg_impl = codegen.launch_arg_impl();
|
||||
|
||||
quote! {
|
||||
#expand_ty
|
||||
#launch_ty
|
||||
#launch_new
|
||||
if with_launch {
|
||||
quote! {
|
||||
#expand_ty
|
||||
#launch_ty
|
||||
#launch_new
|
||||
|
||||
#cube_type_impl
|
||||
#arg_settings_impl
|
||||
#launch_arg_impl
|
||||
#cube_type_impl
|
||||
#arg_settings_impl
|
||||
#launch_arg_impl
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
quote! {
|
||||
#expand_ty
|
||||
#cube_type_impl
|
||||
}
|
||||
.into()
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
|
|
@ -18,12 +18,20 @@ enum CubeMode {
|
|||
Debug,
|
||||
}
|
||||
|
||||
// Derive macro to define a cube type.
|
||||
#[proc_macro_derive(Cube)]
|
||||
pub fn module_derive(input: TokenStream) -> TokenStream {
|
||||
// Derive macro to define a cube type that is launched with a kernel
|
||||
#[proc_macro_derive(CubeLaunch)]
|
||||
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
|
||||
generate_cube_type(&input)
|
||||
generate_cube_type(&input, true)
|
||||
}
|
||||
|
||||
// Derive macro to define a cube type that is not launched
|
||||
#[proc_macro_derive(CubeType)]
|
||||
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
|
||||
generate_cube_type(&input, false)
|
||||
}
|
||||
|
||||
/// Derive macro for the module.
|
||||
|
|
|
@ -25,6 +25,18 @@ impl<T> Comptime<T> {
|
|||
pub fn get(_comptime: Self) -> T {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn map_expand<R, F: Fn(&mut CubeContext, T) -> R>(
|
||||
context: &mut CubeContext,
|
||||
inner: T,
|
||||
closure: F,
|
||||
) -> R {
|
||||
closure(context, inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {
|
||||
|
|
|
@ -21,7 +21,8 @@ pub use pod::*;
|
|||
pub use runtime::*;
|
||||
|
||||
pub use burn_cube_macros::cube;
|
||||
pub use burn_cube_macros::Cube;
|
||||
pub use burn_cube_macros::CubeLaunch;
|
||||
pub use burn_cube_macros::CubeType;
|
||||
|
||||
/// An approximation of the subcube dimension.
|
||||
pub const SUBCUBE_DIM_APPROX: usize = 16;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
pub use crate::{cube, Cube, Kernel, RuntimeArg};
|
||||
pub use crate::{cube, CubeLaunch, CubeType, Kernel, RuntimeArg};
|
||||
|
||||
pub use crate::codegen::{KernelExpansion, KernelIntegrator, KernelSettings};
|
||||
pub use crate::compute::{
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
pub fn if_then_else<F: Float>(lhs: F) {
|
||||
if lhs < F::from_int(0) {
|
||||
let _ = lhs + F::from_int(4);
|
||||
} else {
|
||||
let _ = lhs - F::from_int(5);
|
||||
}
|
||||
#[derive(Clone)]
|
||||
pub struct State {
|
||||
cond: bool,
|
||||
bound: u32,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
|
@ -18,29 +15,41 @@ pub fn comptime_if_else<T: Numeric>(lhs: T, cond: Comptime<bool>) {
|
|||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn comptime_with_map_bool<T: Numeric>(state: Comptime<State>) -> T {
|
||||
let cond = Comptime::map(state, |s: State| s.cond);
|
||||
|
||||
let mut x = T::from_int(3);
|
||||
if Comptime::get(cond) {
|
||||
x += T::from_int(4);
|
||||
} else {
|
||||
x -= T::from_int(4);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn comptime_with_map_uint<T: Numeric>(state: Comptime<State>) -> T {
|
||||
let bound = Comptime::map(state, |s: State| s.bound);
|
||||
|
||||
let mut x = T::from_int(3);
|
||||
for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) {
|
||||
x += T::from_int(4);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
frontend::{CubeContext, CubeElem, F32},
|
||||
ir::{Elem, Item, Variable},
|
||||
ir::{Item, Variable},
|
||||
};
|
||||
|
||||
use super::{comptime_if_else_expand, if_then_else_expand};
|
||||
|
||||
type ElemType = F32;
|
||||
|
||||
#[test]
|
||||
fn cube_if_else_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_then_else_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_comptime_if_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
@ -71,24 +80,46 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
fn inline_macro_ref() -> String {
|
||||
#[test]
|
||||
fn cube_comptime_map_bool_test() {
|
||||
let mut context1 = CubeContext::root();
|
||||
let mut context2 = CubeContext::root();
|
||||
|
||||
let comptime_state_true = State {
|
||||
cond: true,
|
||||
bound: 4,
|
||||
};
|
||||
let comptime_state_false = State {
|
||||
cond: false,
|
||||
bound: 4,
|
||||
};
|
||||
|
||||
comptime_with_map_bool_expand::<ElemType>(&mut context1, comptime_state_true);
|
||||
comptime_with_map_bool_expand::<ElemType>(&mut context2, comptime_state_false);
|
||||
|
||||
let scope1 = context1.into_scope();
|
||||
let scope2 = context2.into_scope();
|
||||
|
||||
assert_ne!(
|
||||
format!("{:?}", scope1.operations),
|
||||
format!("{:?}", scope2.operations)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_comptime_map_uint_test() {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let y = scope.create_local(item);
|
||||
let comptime_state = State {
|
||||
cond: true,
|
||||
bound: 4,
|
||||
};
|
||||
|
||||
cpa!(scope, cond = lhs < 0f32);
|
||||
cpa!(&mut scope, if(cond).then(|scope| {
|
||||
cpa!(scope, y = lhs + 4.0f32);
|
||||
}).else(|scope|{
|
||||
cpa!(scope, y = lhs - 5.0f32);
|
||||
}));
|
||||
comptime_with_map_uint_expand::<ElemType>(&mut context, comptime_state);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert!(!format!("{:?}", scope.operations).contains("RangeLoop"));
|
||||
}
|
||||
|
||||
fn inline_macro_ref_comptime(cond: bool) -> String {
|
|
@ -1,4 +1,4 @@
|
|||
use burn_cube::{cube, frontend::Numeric};
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
pub fn if_greater<T: Numeric>(lhs: T) {
|
||||
|
@ -15,6 +15,15 @@ pub fn if_greater_var<T: Numeric>(lhs: T) {
|
|||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn if_then_else<F: Float>(lhs: F) {
|
||||
if lhs < F::from_int(0) {
|
||||
let _ = lhs + F::from_int(4);
|
||||
} else {
|
||||
let _ = lhs - F::from_int(5);
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
|
@ -22,7 +31,7 @@ mod tests {
|
|||
ir::{Elem, Item, Variable},
|
||||
};
|
||||
|
||||
use super::if_greater_expand;
|
||||
use super::*;
|
||||
|
||||
type ElemType = F32;
|
||||
|
||||
|
@ -35,10 +44,25 @@ mod tests {
|
|||
if_greater_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if());
|
||||
}
|
||||
|
||||
fn inline_macro_ref() -> String {
|
||||
#[test]
|
||||
fn cube_if_else_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_then_else_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_ref_if_else()
|
||||
);
|
||||
}
|
||||
|
||||
fn inline_macro_ref_if() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
@ -55,4 +79,24 @@ mod tests {
|
|||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
|
||||
fn inline_macro_ref_if_else() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let y = scope.create_local(item);
|
||||
|
||||
cpa!(scope, cond = lhs < 0f32);
|
||||
cpa!(&mut scope, if(cond).then(|scope| {
|
||||
cpa!(scope, y = lhs + 4.0f32);
|
||||
}).else(|scope|{
|
||||
cpa!(scope, y = lhs - 5.0f32);
|
||||
}));
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
mod cast_elem;
|
||||
mod cast_kind;
|
||||
mod comptime;
|
||||
mod for_loop;
|
||||
mod function_call;
|
||||
mod generic_kernel;
|
||||
mod r#if;
|
||||
mod if_else;
|
||||
mod literal;
|
||||
mod r#loop;
|
||||
mod module_import;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[derive(Cube)]
|
||||
#[derive(CubeType)]
|
||||
struct State<T: Numeric> {
|
||||
first: T,
|
||||
second: T,
|
||||
|
|
|
@ -15,7 +15,7 @@ use crate::{
|
|||
FloatElement, JitRuntime,
|
||||
};
|
||||
|
||||
#[derive(Cube)]
|
||||
#[derive(CubeLaunch)]
|
||||
struct Conv2dArgs {
|
||||
conv_stride_0: UInt,
|
||||
conv_stride_1: UInt,
|
||||
|
|
Loading…
Reference in New Issue