mirror of https://github.com/tracel-ai/burn.git
combine comptime and runtime in elsif
This commit is contained in:
parent
2292b38778
commit
8e829da1b6
|
@ -128,15 +128,17 @@ impl VariableAnalyzer {
|
|||
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
||||
}
|
||||
syn::Expr::If(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.cond, depth + 1);
|
||||
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth + 1);
|
||||
let depth = depth + 1;
|
||||
|
||||
self.find_occurrences_in_expr(&expr.cond, depth);
|
||||
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
|
||||
if let Some((_, expr)) = &expr.else_branch {
|
||||
match &**expr {
|
||||
syn::Expr::Block(expr_block) => {
|
||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth + 1);
|
||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
||||
}
|
||||
syn::Expr::If(expr) => {
|
||||
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth + 1);
|
||||
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
|
|
@ -45,125 +45,29 @@ pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime
|
|||
}
|
||||
}
|
||||
|
||||
// #[allow(dead_code)]
|
||||
// #[allow(clippy::too_many_arguments)]
|
||||
// pub fn comptime_if_else_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
|
||||
// if Comptime::get(cond1) {
|
||||
// let _ = lhs + T::from_int(4);
|
||||
// } else {
|
||||
// if Comptime::get(cond2) {
|
||||
// let _ = lhs + T::from_int(6);
|
||||
// } else {
|
||||
// let _ = lhs - T::from_int(5);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// #[allow(unused_mut)]
|
||||
// #[allow(clippy::too_many_arguments)]
|
||||
// #[doc = r" Expanded Cube function"]
|
||||
// pub fn comptime_if_else_if_expand<T: Numeric>(
|
||||
// context: &mut burn_cube::frontend::CubeContext,
|
||||
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// ) -> () {
|
||||
// let _cond = cond1;
|
||||
// burn_cube::frontend::branch::if_else_expand(
|
||||
// context,
|
||||
// Some(cond1),
|
||||
// _cond.into(),
|
||||
// |context| {
|
||||
// let _ = {
|
||||
// let _inner = {
|
||||
// let _lhs = lhs.clone();
|
||||
// let _rhs = {
|
||||
// let _var_0 = 4;
|
||||
// T::from_int_expand(context, _var_0)
|
||||
// };
|
||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
||||
// };
|
||||
// burn_cube::frontend::Init::init(_inner, context)
|
||||
// };
|
||||
// },
|
||||
// |context| {
|
||||
// let _cond = cond2;
|
||||
// burn_cube::frontend::branch::if_else_expand(
|
||||
// context,
|
||||
// Some(cond2),
|
||||
// _cond.into(),
|
||||
// |context| {
|
||||
// let _ = {
|
||||
// let _inner = {
|
||||
// let _lhs = lhs.clone();
|
||||
// let _rhs = {
|
||||
// let _var_0 = 6;
|
||||
// T::from_int_expand(context, _var_0)
|
||||
// };
|
||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
||||
// };
|
||||
// burn_cube::frontend::Init::init(_inner, context)
|
||||
// };
|
||||
// },
|
||||
// |context| {
|
||||
// let _ = {
|
||||
// let _inner = {
|
||||
// let _lhs = lhs.clone();
|
||||
// let _rhs = {
|
||||
// let _var_0 = 5;
|
||||
// T::from_int_expand(context, _var_0)
|
||||
// };
|
||||
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
|
||||
// };
|
||||
// burn_cube::frontend::Init::init(_inner, context)
|
||||
// };
|
||||
// },
|
||||
// );
|
||||
// },
|
||||
// );
|
||||
// }
|
||||
#[cube]
|
||||
pub fn comptime_elsif_with_runtime1<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
|
||||
let runtime_cond = lhs >= T::from_int(2);
|
||||
if Comptime::get(comptime_cond) {
|
||||
let _ = lhs + T::from_int(4);
|
||||
} else if runtime_cond {
|
||||
let _ = lhs + T::from_int(5);
|
||||
} else {
|
||||
let _ = lhs - T::from_int(6);
|
||||
}
|
||||
}
|
||||
|
||||
// #[allow(unused_mut)]
|
||||
// #[allow(clippy::too_many_arguments)]
|
||||
// #[doc = r" Expanded Cube function"]
|
||||
// pub fn comptime_if_else_if_expand<T: Numeric>(
|
||||
// context: &mut burn_cube::frontend::CubeContext,
|
||||
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
||||
// ) -> () {
|
||||
// let _cond = cond2;
|
||||
// burn_cube::frontend::branch::if_else_expand(
|
||||
// context,
|
||||
// Some(cond2),
|
||||
// _cond.into(),
|
||||
// |context| {
|
||||
// let _ = {
|
||||
// let _inner = {
|
||||
// let _lhs = lhs.clone();
|
||||
// let _rhs = {
|
||||
// let _var_0 = 6;
|
||||
// T::from_int_expand(context, _var_0)
|
||||
// };
|
||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
||||
// };
|
||||
// burn_cube::frontend::Init::init(_inner, context)
|
||||
// };
|
||||
// },
|
||||
// |context| {
|
||||
// let _ = {
|
||||
// let _inner = {
|
||||
// let _lhs = lhs.clone();
|
||||
// let _rhs = {
|
||||
// let _var_0 = 5;
|
||||
// T::from_int_expand(context, _var_0)
|
||||
// };
|
||||
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
|
||||
// };
|
||||
// burn_cube::frontend::Init::init(_inner, context)
|
||||
// };
|
||||
// },
|
||||
// );
|
||||
// }
|
||||
#[cube]
|
||||
pub fn comptime_elsif_with_runtime2<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
|
||||
let runtime_cond = lhs >= T::from_int(2);
|
||||
if runtime_cond {
|
||||
let _ = lhs + T::from_int(4);
|
||||
} else if Comptime::get(comptime_cond) {
|
||||
let _ = lhs + T::from_int(5);
|
||||
} else {
|
||||
let _ = lhs - T::from_int(6);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
|
||||
|
@ -206,7 +110,7 @@ mod tests {
|
|||
use burn_cube::{
|
||||
cpa,
|
||||
frontend::{CubeContext, CubePrimitive, F32},
|
||||
ir::{Item, Variable},
|
||||
ir::{Elem, Item, Variable},
|
||||
};
|
||||
|
||||
type ElemType = F32;
|
||||
|
@ -278,6 +182,36 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_comptime_elsif_runtime1_test() {
|
||||
for cond in [false, true] {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_ref_elsif_runtime1(cond)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_comptime_elsif_runtime2_test() {
|
||||
for cond in [false, true] {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_ref_elsif_runtime2(cond)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_comptime_map_bool_test() {
|
||||
let mut context1 = CubeContext::root();
|
||||
|
@ -337,4 +271,52 @@ mod tests {
|
|||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
|
||||
fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let x: Variable = x.into();
|
||||
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let y = scope.create_local(item);
|
||||
cpa!(scope, runtime_cond = x >= 2.0f32);
|
||||
|
||||
if comptime_cond {
|
||||
cpa!(scope, y = x + 4.0f32);
|
||||
} else {
|
||||
cpa!(&mut scope, if(runtime_cond).then(|scope| {
|
||||
cpa!(scope, y = x + 5.0f32);
|
||||
}).else(|scope| {
|
||||
cpa!(scope, y = x - 6.0f32);
|
||||
}));
|
||||
};
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
|
||||
fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let x: Variable = x.into();
|
||||
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let y = scope.create_local(item);
|
||||
cpa!(scope, runtime_cond = x >= 2.0f32);
|
||||
|
||||
cpa!(&mut scope, if(runtime_cond).then(|scope| {
|
||||
cpa!(scope, y = x + 4.0f32);
|
||||
}).else(|scope| {
|
||||
if comptime_cond {
|
||||
cpa!(scope, y = x + 5.0f32);
|
||||
} else {
|
||||
cpa!(scope, y = x - 6.0f32);
|
||||
}
|
||||
}));
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,17 @@ pub fn if_then_else<F: Float>(lhs: F) {
|
|||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn elsif<F: Float>(lhs: F) {
|
||||
if lhs < F::new(0.) {
|
||||
let _ = lhs + F::new(2.);
|
||||
} else if lhs > F::new(0.) {
|
||||
let _ = lhs + F::new(1.);
|
||||
} else {
|
||||
let _ = lhs + F::new(0.);
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
|
@ -62,6 +73,18 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_elsif_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
elsif_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
|
||||
}
|
||||
|
||||
fn inline_macro_ref_if() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
|
@ -99,4 +122,30 @@ mod tests {
|
|||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
|
||||
fn inline_macro_ref_elsif() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond1 = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let y = scope.create_local(item);
|
||||
let cond2 = scope.create_local(Item::new(Elem::Bool));
|
||||
|
||||
cpa!(scope, cond1 = lhs < 0f32);
|
||||
cpa!(&mut scope, if(cond1).then(|scope| {
|
||||
cpa!(scope, y = lhs + 2.0f32);
|
||||
}).else(|mut scope|{
|
||||
cpa!(scope, cond2 = lhs > 0f32);
|
||||
cpa!(&mut scope, if(cond2).then(|scope| {
|
||||
cpa!(scope, y = lhs + 1.0f32);
|
||||
}).else(|scope|{
|
||||
cpa!(scope, y = lhs + 0.0f32);
|
||||
}));
|
||||
}));
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue