This commit is contained in:
louisfd 2024-06-12 09:45:01 -04:00
parent fa72ed5d98
commit 88dfa3afaa
6 changed files with 42 additions and 6 deletions

View File

@ -200,7 +200,6 @@ impl CodeAnalysisBuilder {
// Declaration of iterator
if let syn::Pat::Ident(pat_ident) = &*expr.pat {
let id = &pat_ident.ident;
let is_mut = pat_ident.mutability.is_some();
self.variable_ident_factory
.analyze_declare(id.to_string(), depth);
}

View File

@ -19,6 +19,13 @@ pub(crate) fn codegen_for_loop(
) -> TokenStream {
let i = &for_loop.pat;
if let syn::Pat::Ident(pat_ident) = &*for_loop.pat {
let id = &pat_ident.ident;
variable_analyses
.vif
.codegen_declare(id.to_string(), loop_level as u8 + 1);
}
match for_loop.expr.as_ref() {
syn::Expr::Call(call) => {
let func_name = match call.func.as_ref() {

View File

@ -60,6 +60,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
cube.into()
};
// panic!("{:?}", variable_analyses);
match mode {
CubeMode::Default => code,
CubeMode::Debug => panic!("{code}"),

View File

@ -159,7 +159,9 @@ impl VariableReuseAnalyzer {
.get_mut(&ident)
.ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;
// if analysis.num_used > 0 {
analysis.num_used -= 1;
// }
Ok(analysis.should_clone() || should_clone_parent || scope_declared != scope)
}
}

View File

@ -20,12 +20,23 @@ pub fn redeclare_same_scope_other_type<I: Int, F: Float>(mut x: I) -> F {
pub fn redeclare_different_scope<I: Int>(mut x: I) {
let y = I::new(1);
x += y;
for i in range(0u32, 2u32, Comptime::new(false)) {
for _ in range(0u32, 2u32, Comptime::new(false)) {
let y = I::new(2);
x += y;
}
}
#[cube]
pub fn redeclare_two_for_loops(mut x: UInt) {
for i in range(0u32, 2u32, Comptime::new(false)) {
x += i;
}
for i in range(0u32, 2u32, Comptime::new(false)) {
x += i;
x += i;
}
}
mod tests {
use burn_cube::{
cpa,
@ -75,6 +86,21 @@ mod tests {
);
}
#[test]
fn cube_redeclare_two_for_loops_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::new(UInt::as_elem()));
redeclare_two_for_loops_expand(&mut context, x);
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_different()
);
}
fn inline_macro_ref_different() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());

View File

@ -53,9 +53,9 @@ fn matmul_kernel<F: Float>(
k /= Comptime::runtime(vectorization_factor);
for j in range(0u32, k, Comptime::new(false)) {
let lhs_index = row * k + j + offset_lhs;
let rhs_index = col * k + j + offset_rhs;
for i in range(0u32, k, Comptime::new(false)) {
let lhs_index = row * k + i + offset_lhs;
let rhs_index = col * k + i + offset_rhs;
sum += lhs[lhs_index] * rhs[rhs_index];
}