combine comptime and runtime in elsif

This commit is contained in:
louisfd 2024-07-03 16:18:39 -04:00
parent 2292b38778
commit 8e829da1b6
3 changed files with 156 additions and 123 deletions

View File

@ -128,15 +128,17 @@ impl VariableAnalyzer {
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
}
syn::Expr::If(expr) => {
self.find_occurrences_in_expr(&expr.cond, depth + 1);
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth + 1);
let depth = depth + 1;
self.find_occurrences_in_expr(&expr.cond, depth);
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
if let Some((_, expr)) = &expr.else_branch {
match &**expr {
syn::Expr::Block(expr_block) => {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth + 1);
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
}
syn::Expr::If(expr) => {
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth + 1);
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
}
_ => unreachable!(),
}

View File

@ -45,125 +45,29 @@ pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime
}
}
// #[allow(dead_code)]
// #[allow(clippy::too_many_arguments)]
// pub fn comptime_if_else_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
// if Comptime::get(cond1) {
// let _ = lhs + T::from_int(4);
// } else {
// if Comptime::get(cond2) {
// let _ = lhs + T::from_int(6);
// } else {
// let _ = lhs - T::from_int(5);
// }
// }
// }
// #[allow(unused_mut)]
// #[allow(clippy::too_many_arguments)]
// #[doc = r" Expanded Cube function"]
// pub fn comptime_if_else_if_expand<T: Numeric>(
// context: &mut burn_cube::frontend::CubeContext,
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// ) -> () {
// let _cond = cond1;
// burn_cube::frontend::branch::if_else_expand(
// context,
// Some(cond1),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 4;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _cond = cond2;
// burn_cube::frontend::branch::if_else_expand(
// context,
// Some(cond2),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 6;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 5;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// );
// },
// );
// }
#[cube]
pub fn comptime_elsif_with_runtime1<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
let runtime_cond = lhs >= T::from_int(2);
if Comptime::get(comptime_cond) {
let _ = lhs + T::from_int(4);
} else if runtime_cond {
let _ = lhs + T::from_int(5);
} else {
let _ = lhs - T::from_int(6);
}
}
// #[allow(unused_mut)]
// #[allow(clippy::too_many_arguments)]
// #[doc = r" Expanded Cube function"]
// pub fn comptime_if_else_if_expand<T: Numeric>(
// context: &mut burn_cube::frontend::CubeContext,
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
// ) -> () {
// let _cond = cond2;
// burn_cube::frontend::branch::if_else_expand(
// context,
// Some(cond2),
// _cond.into(),
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 6;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// |context| {
// let _ = {
// let _inner = {
// let _lhs = lhs.clone();
// let _rhs = {
// let _var_0 = 5;
// T::from_int_expand(context, _var_0)
// };
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
// };
// burn_cube::frontend::Init::init(_inner, context)
// };
// },
// );
// }
#[cube]
pub fn comptime_elsif_with_runtime2<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
let runtime_cond = lhs >= T::from_int(2);
if runtime_cond {
let _ = lhs + T::from_int(4);
} else if Comptime::get(comptime_cond) {
let _ = lhs + T::from_int(5);
} else {
let _ = lhs - T::from_int(6);
}
}
#[cube]
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
@ -206,7 +110,7 @@ mod tests {
use burn_cube::{
cpa,
frontend::{CubeContext, CubePrimitive, F32},
ir::{Item, Variable},
ir::{Elem, Item, Variable},
};
type ElemType = F32;
@ -278,6 +182,36 @@ mod tests {
}
}
#[test]
fn cube_comptime_elsif_runtime1_test() {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_elsif_runtime1(cond)
);
}
}
#[test]
fn cube_comptime_elsif_runtime2_test() {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_elsif_runtime2(cond)
);
}
}
#[test]
fn cube_comptime_map_bool_test() {
let mut context1 = CubeContext::root();
@ -337,4 +271,52 @@ mod tests {
format!("{:?}", scope.operations)
}
fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let x: Variable = x.into();
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
cpa!(scope, runtime_cond = x >= 2.0f32);
if comptime_cond {
cpa!(scope, y = x + 4.0f32);
} else {
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 5.0f32);
}).else(|scope| {
cpa!(scope, y = x - 6.0f32);
}));
};
format!("{:?}", scope.operations)
}
fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let x: Variable = x.into();
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
cpa!(scope, runtime_cond = x >= 2.0f32);
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 4.0f32);
}).else(|scope| {
if comptime_cond {
cpa!(scope, y = x + 5.0f32);
} else {
cpa!(scope, y = x - 6.0f32);
}
}));
format!("{:?}", scope.operations)
}
}

View File

@ -24,6 +24,17 @@ pub fn if_then_else<F: Float>(lhs: F) {
}
}
#[cube]
pub fn elsif<F: Float>(lhs: F) {
if lhs < F::new(0.) {
let _ = lhs + F::new(2.);
} else if lhs > F::new(0.) {
let _ = lhs + F::new(1.);
} else {
let _ = lhs + F::new(0.);
}
}
mod tests {
use burn_cube::{
cpa,
@ -62,6 +73,18 @@ mod tests {
);
}
#[test]
fn cube_elsif_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
elsif_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
}
fn inline_macro_ref_if() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
@ -99,4 +122,30 @@ mod tests {
format!("{:?}", scope.operations)
}
fn inline_macro_ref_elsif() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond1 = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);
let cond2 = scope.create_local(Item::new(Elem::Bool));
cpa!(scope, cond1 = lhs < 0f32);
cpa!(&mut scope, if(cond1).then(|scope| {
cpa!(scope, y = lhs + 2.0f32);
}).else(|mut scope|{
cpa!(scope, cond2 = lhs > 0f32);
cpa!(&mut scope, if(cond2).then(|scope| {
cpa!(scope, y = lhs + 1.0f32);
}).else(|scope|{
cpa!(scope, y = lhs + 0.0f32);
}));
}));
format!("{:?}", scope.operations)
}
}