mirror of https://github.com/tracel-ai/burn.git
combine comptime and runtime in elsif
This commit is contained in:
parent
2292b38778
commit
8e829da1b6
|
@ -128,15 +128,17 @@ impl VariableAnalyzer {
|
||||||
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
||||||
}
|
}
|
||||||
syn::Expr::If(expr) => {
|
syn::Expr::If(expr) => {
|
||||||
self.find_occurrences_in_expr(&expr.cond, depth + 1);
|
let depth = depth + 1;
|
||||||
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth + 1);
|
|
||||||
|
self.find_occurrences_in_expr(&expr.cond, depth);
|
||||||
|
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
|
||||||
if let Some((_, expr)) = &expr.else_branch {
|
if let Some((_, expr)) = &expr.else_branch {
|
||||||
match &**expr {
|
match &**expr {
|
||||||
syn::Expr::Block(expr_block) => {
|
syn::Expr::Block(expr_block) => {
|
||||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth + 1);
|
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
||||||
}
|
}
|
||||||
syn::Expr::If(expr) => {
|
syn::Expr::If(expr) => {
|
||||||
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth + 1);
|
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,125 +45,29 @@ pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[allow(dead_code)]
|
#[cube]
|
||||||
// #[allow(clippy::too_many_arguments)]
|
pub fn comptime_elsif_with_runtime1<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
|
||||||
// pub fn comptime_if_else_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
|
let runtime_cond = lhs >= T::from_int(2);
|
||||||
// if Comptime::get(cond1) {
|
if Comptime::get(comptime_cond) {
|
||||||
// let _ = lhs + T::from_int(4);
|
let _ = lhs + T::from_int(4);
|
||||||
// } else {
|
} else if runtime_cond {
|
||||||
// if Comptime::get(cond2) {
|
let _ = lhs + T::from_int(5);
|
||||||
// let _ = lhs + T::from_int(6);
|
} else {
|
||||||
// } else {
|
let _ = lhs - T::from_int(6);
|
||||||
// let _ = lhs - T::from_int(5);
|
}
|
||||||
// }
|
}
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// #[allow(unused_mut)]
|
|
||||||
// #[allow(clippy::too_many_arguments)]
|
|
||||||
// #[doc = r" Expanded Cube function"]
|
|
||||||
// pub fn comptime_if_else_if_expand<T: Numeric>(
|
|
||||||
// context: &mut burn_cube::frontend::CubeContext,
|
|
||||||
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
|
|
||||||
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
|
||||||
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
|
||||||
// ) -> () {
|
|
||||||
// let _cond = cond1;
|
|
||||||
// burn_cube::frontend::branch::if_else_expand(
|
|
||||||
// context,
|
|
||||||
// Some(cond1),
|
|
||||||
// _cond.into(),
|
|
||||||
// |context| {
|
|
||||||
// let _ = {
|
|
||||||
// let _inner = {
|
|
||||||
// let _lhs = lhs.clone();
|
|
||||||
// let _rhs = {
|
|
||||||
// let _var_0 = 4;
|
|
||||||
// T::from_int_expand(context, _var_0)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::Init::init(_inner, context)
|
|
||||||
// };
|
|
||||||
// },
|
|
||||||
// |context| {
|
|
||||||
// let _cond = cond2;
|
|
||||||
// burn_cube::frontend::branch::if_else_expand(
|
|
||||||
// context,
|
|
||||||
// Some(cond2),
|
|
||||||
// _cond.into(),
|
|
||||||
// |context| {
|
|
||||||
// let _ = {
|
|
||||||
// let _inner = {
|
|
||||||
// let _lhs = lhs.clone();
|
|
||||||
// let _rhs = {
|
|
||||||
// let _var_0 = 6;
|
|
||||||
// T::from_int_expand(context, _var_0)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::Init::init(_inner, context)
|
|
||||||
// };
|
|
||||||
// },
|
|
||||||
// |context| {
|
|
||||||
// let _ = {
|
|
||||||
// let _inner = {
|
|
||||||
// let _lhs = lhs.clone();
|
|
||||||
// let _rhs = {
|
|
||||||
// let _var_0 = 5;
|
|
||||||
// T::from_int_expand(context, _var_0)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::Init::init(_inner, context)
|
|
||||||
// };
|
|
||||||
// },
|
|
||||||
// );
|
|
||||||
// },
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
|
|
||||||
// #[allow(unused_mut)]
|
#[cube]
|
||||||
// #[allow(clippy::too_many_arguments)]
|
pub fn comptime_elsif_with_runtime2<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
|
||||||
// #[doc = r" Expanded Cube function"]
|
let runtime_cond = lhs >= T::from_int(2);
|
||||||
// pub fn comptime_if_else_if_expand<T: Numeric>(
|
if runtime_cond {
|
||||||
// context: &mut burn_cube::frontend::CubeContext,
|
let _ = lhs + T::from_int(4);
|
||||||
// lhs: <T as burn_cube::frontend::CubeType>::ExpandType,
|
} else if Comptime::get(comptime_cond) {
|
||||||
// cond1: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
let _ = lhs + T::from_int(5);
|
||||||
// cond2: <Comptime<bool> as burn_cube::frontend::CubeType>::ExpandType,
|
} else {
|
||||||
// ) -> () {
|
let _ = lhs - T::from_int(6);
|
||||||
// let _cond = cond2;
|
}
|
||||||
// burn_cube::frontend::branch::if_else_expand(
|
}
|
||||||
// context,
|
|
||||||
// Some(cond2),
|
|
||||||
// _cond.into(),
|
|
||||||
// |context| {
|
|
||||||
// let _ = {
|
|
||||||
// let _inner = {
|
|
||||||
// let _lhs = lhs.clone();
|
|
||||||
// let _rhs = {
|
|
||||||
// let _var_0 = 6;
|
|
||||||
// T::from_int_expand(context, _var_0)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::Init::init(_inner, context)
|
|
||||||
// };
|
|
||||||
// },
|
|
||||||
// |context| {
|
|
||||||
// let _ = {
|
|
||||||
// let _inner = {
|
|
||||||
// let _lhs = lhs.clone();
|
|
||||||
// let _rhs = {
|
|
||||||
// let _var_0 = 5;
|
|
||||||
// T::from_int_expand(context, _var_0)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::sub::expand(context, _lhs, _rhs)
|
|
||||||
// };
|
|
||||||
// burn_cube::frontend::Init::init(_inner, context)
|
|
||||||
// };
|
|
||||||
// },
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
|
|
||||||
#[cube]
|
#[cube]
|
||||||
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
|
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
|
||||||
|
@ -206,7 +110,7 @@ mod tests {
|
||||||
use burn_cube::{
|
use burn_cube::{
|
||||||
cpa,
|
cpa,
|
||||||
frontend::{CubeContext, CubePrimitive, F32},
|
frontend::{CubeContext, CubePrimitive, F32},
|
||||||
ir::{Item, Variable},
|
ir::{Elem, Item, Variable},
|
||||||
};
|
};
|
||||||
|
|
||||||
type ElemType = F32;
|
type ElemType = F32;
|
||||||
|
@ -278,6 +182,36 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cube_comptime_elsif_runtime1_test() {
|
||||||
|
for cond in [false, true] {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||||
|
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
|
||||||
|
let scope = context.into_scope();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
format!("{:?}", scope.operations),
|
||||||
|
inline_macro_ref_elsif_runtime1(cond)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cube_comptime_elsif_runtime2_test() {
|
||||||
|
for cond in [false, true] {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||||
|
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
|
||||||
|
let scope = context.into_scope();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
format!("{:?}", scope.operations),
|
||||||
|
inline_macro_ref_elsif_runtime2(cond)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cube_comptime_map_bool_test() {
|
fn cube_comptime_map_bool_test() {
|
||||||
let mut context1 = CubeContext::root();
|
let mut context1 = CubeContext::root();
|
||||||
|
@ -337,4 +271,52 @@ mod tests {
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
format!("{:?}", scope.operations)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
let item = Item::new(ElemType::as_elem());
|
||||||
|
let x = context.create_local(item);
|
||||||
|
|
||||||
|
let mut scope = context.into_scope();
|
||||||
|
let x: Variable = x.into();
|
||||||
|
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
|
||||||
|
let y = scope.create_local(item);
|
||||||
|
cpa!(scope, runtime_cond = x >= 2.0f32);
|
||||||
|
|
||||||
|
if comptime_cond {
|
||||||
|
cpa!(scope, y = x + 4.0f32);
|
||||||
|
} else {
|
||||||
|
cpa!(&mut scope, if(runtime_cond).then(|scope| {
|
||||||
|
cpa!(scope, y = x + 5.0f32);
|
||||||
|
}).else(|scope| {
|
||||||
|
cpa!(scope, y = x - 6.0f32);
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
format!("{:?}", scope.operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
let item = Item::new(ElemType::as_elem());
|
||||||
|
let x = context.create_local(item);
|
||||||
|
|
||||||
|
let mut scope = context.into_scope();
|
||||||
|
let x: Variable = x.into();
|
||||||
|
let runtime_cond = scope.create_local(Item::new(Elem::Bool));
|
||||||
|
let y = scope.create_local(item);
|
||||||
|
cpa!(scope, runtime_cond = x >= 2.0f32);
|
||||||
|
|
||||||
|
cpa!(&mut scope, if(runtime_cond).then(|scope| {
|
||||||
|
cpa!(scope, y = x + 4.0f32);
|
||||||
|
}).else(|scope| {
|
||||||
|
if comptime_cond {
|
||||||
|
cpa!(scope, y = x + 5.0f32);
|
||||||
|
} else {
|
||||||
|
cpa!(scope, y = x - 6.0f32);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
format!("{:?}", scope.operations)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,17 @@ pub fn if_then_else<F: Float>(lhs: F) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
pub fn elsif<F: Float>(lhs: F) {
|
||||||
|
if lhs < F::new(0.) {
|
||||||
|
let _ = lhs + F::new(2.);
|
||||||
|
} else if lhs > F::new(0.) {
|
||||||
|
let _ = lhs + F::new(1.);
|
||||||
|
} else {
|
||||||
|
let _ = lhs + F::new(0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mod tests {
|
mod tests {
|
||||||
use burn_cube::{
|
use burn_cube::{
|
||||||
cpa,
|
cpa,
|
||||||
|
@ -62,6 +73,18 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cube_elsif_test() {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
|
||||||
|
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||||
|
|
||||||
|
elsif_expand::<ElemType>(&mut context, lhs);
|
||||||
|
let scope = context.into_scope();
|
||||||
|
|
||||||
|
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
|
||||||
|
}
|
||||||
|
|
||||||
fn inline_macro_ref_if() -> String {
|
fn inline_macro_ref_if() -> String {
|
||||||
let mut context = CubeContext::root();
|
let mut context = CubeContext::root();
|
||||||
let item = Item::new(ElemType::as_elem());
|
let item = Item::new(ElemType::as_elem());
|
||||||
|
@ -99,4 +122,30 @@ mod tests {
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
format!("{:?}", scope.operations)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn inline_macro_ref_elsif() -> String {
|
||||||
|
let mut context = CubeContext::root();
|
||||||
|
let item = Item::new(ElemType::as_elem());
|
||||||
|
let lhs = context.create_local(item);
|
||||||
|
|
||||||
|
let mut scope = context.into_scope();
|
||||||
|
let cond1 = scope.create_local(Item::new(Elem::Bool));
|
||||||
|
let lhs: Variable = lhs.into();
|
||||||
|
let y = scope.create_local(item);
|
||||||
|
let cond2 = scope.create_local(Item::new(Elem::Bool));
|
||||||
|
|
||||||
|
cpa!(scope, cond1 = lhs < 0f32);
|
||||||
|
cpa!(&mut scope, if(cond1).then(|scope| {
|
||||||
|
cpa!(scope, y = lhs + 2.0f32);
|
||||||
|
}).else(|mut scope|{
|
||||||
|
cpa!(scope, cond2 = lhs > 0f32);
|
||||||
|
cpa!(&mut scope, if(cond2).then(|scope| {
|
||||||
|
cpa!(scope, y = lhs + 1.0f32);
|
||||||
|
}).else(|scope|{
|
||||||
|
cpa!(scope, y = lhs + 0.0f32);
|
||||||
|
}));
|
||||||
|
}));
|
||||||
|
|
||||||
|
format!("{:?}", scope.operations)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue