combine comptime and runtime in elsif

This commit is contained in:
louisfd 2024-07-03 16:18:39 -04:00
parent 2292b38778
commit 8e829da1b6
3 changed files with 156 additions and 123 deletions

View File

@ -128,15 +128,17 @@ impl VariableAnalyzer {
self.find_occurrences_in_stmts(&expr.body.stmts, depth); self.find_occurrences_in_stmts(&expr.body.stmts, depth);
} }
syn::Expr::If(expr) => { syn::Expr::If(expr) => {
self.find_occurrences_in_expr(&expr.cond, depth + 1); let depth = depth + 1;
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth + 1);
self.find_occurrences_in_expr(&expr.cond, depth);
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
if let Some((_, expr)) = &expr.else_branch { if let Some((_, expr)) = &expr.else_branch {
match &**expr { match &**expr {
syn::Expr::Block(expr_block) => { syn::Expr::Block(expr_block) => {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth + 1); self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
} }
syn::Expr::If(expr) => { syn::Expr::If(expr) => {
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth + 1); self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
} }
_ => unreachable!(), _ => unreachable!(),
} }

View File

@ -45,125 +45,29 @@ pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime
} }
} }
// #[allow(dead_code)] #[cube]
// #[allow(clippy::too_many_arguments)] pub fn comptime_elsif_with_runtime1<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
// pub fn comptime_if_else_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) { let runtime_cond = lhs >= T::from_int(2);
// if Comptime::get(cond1) { if Comptime::get(comptime_cond) {
// let _ = lhs + T::from_int(4); let _ = lhs + T::from_int(4);
// } else { } else if runtime_cond {
// if Comptime::get(cond2) { let _ = lhs + T::from_int(5);
// let _ = lhs + T::from_int(6); } else {
// } else { let _ = lhs - T::from_int(6);
// let _ = lhs - T::from_int(5); }
// } }
// }
// }
// #[allow(unused_mut)]
// #[allow(clippy::too_many_arguments)]
// #[doc = r" Expanded Cube function"]
// pub fn comptime_if_else_if_expand<T: Numeric>(
// context: &mut burn_cube::frontend::CubeContext,
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// ) -> () {
// let _cond = cond1;
// burn_cube::frontend::branch::if_else_expand(
// context,
// Some(cond1),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 4;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _cond = cond2;
// burn_cube::frontend::branch::if_else_expand(
// context,
// Some(cond2),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 6;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 5;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// );
// },
// );
// }
// #[allow(unused_mut)] #[cube]
// #[allow(clippy::too_many_arguments)] pub fn comptime_elsif_with_runtime2<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
// #[doc = r" Expanded Cube function"] let runtime_cond = lhs >= T::from_int(2);
// pub fn comptime_if_else_if_expand<T: Numeric>( if runtime_cond {
// context: &mut burn_cube::frontend::CubeContext, let _ = lhs + T::from_int(4);
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType, } else if Comptime::get(comptime_cond) {
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType, let _ = lhs + T::from_int(5);
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType, } else {
// ) -> () { let _ = lhs - T::from_int(6);
// let _cond = cond2; }
// burn_cube::frontend::branch::if_else_expand( }
// context,
// Some(cond2),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 6;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 5;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// );
// }
#[cube] #[cube]
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) { pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
@ -206,7 +110,7 @@ mod tests {
use burn_cube::{ use burn_cube::{
cpa, cpa,
frontend::{CubeContext, CubePrimitive, F32}, frontend::{CubeContext, CubePrimitive, F32},
ir::{Item, Variable}, ir::{Elem, Item, Variable},
}; };
type ElemType = F32; type ElemType = F32;
@ -278,6 +182,36 @@ mod tests {
} }
} }
#[test]
fn cube_comptime_elsif_runtime1_test() {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_elsif_runtime1(cond)
);
}
}
#[test]
fn cube_comptime_elsif_runtime2_test() {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_elsif_runtime2(cond)
);
}
}
#[test] #[test]
fn cube_comptime_map_bool_test() { fn cube_comptime_map_bool_test() {
let mut context1 = CubeContext::root(); let mut context1 = CubeContext::root();
@ -337,4 +271,52 @@ mod tests {
format!("{:?}", scope.operations) format!("{:?}", scope.operations)
} }
fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let x: Variable = x.into();
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
cpa!(scope, runtime_cond = x >= 2.0f32);
if comptime_cond {
cpa!(scope, y = x + 4.0f32);
} else {
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 5.0f32);
}).else(|scope| {
cpa!(scope, y = x - 6.0f32);
}));
};
format!("{:?}", scope.operations)
}
fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let x: Variable = x.into();
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
cpa!(scope, runtime_cond = x >= 2.0f32);
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 4.0f32);
}).else(|scope| {
if comptime_cond {
cpa!(scope, y = x + 5.0f32);
} else {
cpa!(scope, y = x - 6.0f32);
}
}));
format!("{:?}", scope.operations)
}
} }

View File

@ -24,6 +24,17 @@ pub fn if_then_else<F: Float>(lhs: F) {
} }
} }
#[cube]
pub fn elsif<F: Float>(lhs: F) {
if lhs < F::new(0.) {
let _ = lhs + F::new(2.);
} else if lhs > F::new(0.) {
let _ = lhs + F::new(1.);
} else {
let _ = lhs + F::new(0.);
}
}
mod tests { mod tests {
use burn_cube::{ use burn_cube::{
cpa, cpa,
@ -62,6 +73,18 @@ mod tests {
); );
} }
#[test]
fn cube_elsif_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
elsif_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
}
fn inline_macro_ref_if() -> String { fn inline_macro_ref_if() -> String {
let mut context = CubeContext::root(); let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem()); let item = Item::new(ElemType::as_elem());
@ -99,4 +122,30 @@ mod tests {
format!("{:?}", scope.operations) format!("{:?}", scope.operations)
} }
fn inline_macro_ref_elsif() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond1 = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);
let cond2 = scope.create_local(Item::new(Elem::Bool));
cpa!(scope, cond1 = lhs < 0f32);
cpa!(&mut scope, if(cond1).then(|scope| {
cpa!(scope, y = lhs + 2.0f32);
}).else(|mut scope|{
cpa!(scope, cond2 = lhs > 0f32);
cpa!(&mut scope, if(cond2).then(|scope| {
cpa!(scope, y = lhs + 1.0f32);
}).else(|scope|{
cpa!(scope, y = lhs + 0.0f32);
}));
}));
format!("{:?}", scope.operations)
}
} }