mirror of https://github.com/tracel-ai/burn.git
works
This commit is contained in:
parent
fa72ed5d98
commit
88dfa3afaa
|
@ -200,7 +200,6 @@ impl CodeAnalysisBuilder {
|
|||
// Declaration of iterator
|
||||
if let syn::Pat::Ident(pat_ident) = &*expr.pat {
|
||||
let id = &pat_ident.ident;
|
||||
let is_mut = pat_ident.mutability.is_some();
|
||||
self.variable_ident_factory
|
||||
.analyze_declare(id.to_string(), depth);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,13 @@ pub(crate) fn codegen_for_loop(
|
|||
) -> TokenStream {
|
||||
let i = &for_loop.pat;
|
||||
|
||||
if let syn::Pat::Ident(pat_ident) = &*for_loop.pat {
|
||||
let id = &pat_ident.ident;
|
||||
variable_analyses
|
||||
.vif
|
||||
.codegen_declare(id.to_string(), loop_level as u8 + 1);
|
||||
}
|
||||
|
||||
match for_loop.expr.as_ref() {
|
||||
syn::Expr::Call(call) => {
|
||||
let func_name = match call.func.as_ref() {
|
||||
|
|
|
@ -60,6 +60,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
|||
cube.into()
|
||||
};
|
||||
|
||||
// panic!("{:?}", variable_analyses);
|
||||
|
||||
match mode {
|
||||
CubeMode::Default => code,
|
||||
CubeMode::Debug => panic!("{code}"),
|
||||
|
|
|
@ -159,7 +159,9 @@ impl VariableReuseAnalyzer {
|
|||
.get_mut(&ident)
|
||||
.ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;
|
||||
|
||||
analysis.num_used -= 1;
|
||||
// if analysis.num_used > 0 {
|
||||
analysis.num_used -= 1;
|
||||
// }
|
||||
Ok(analysis.should_clone() || should_clone_parent || scope_declared != scope)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,12 +20,23 @@ pub fn redeclare_same_scope_other_type<I: Int, F: Float>(mut x: I) -> F {
|
|||
pub fn redeclare_different_scope<I: Int>(mut x: I) {
|
||||
let y = I::new(1);
|
||||
x += y;
|
||||
for i in range(0u32, 2u32, Comptime::new(false)) {
|
||||
for _ in range(0u32, 2u32, Comptime::new(false)) {
|
||||
let y = I::new(2);
|
||||
x += y;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn redeclare_two_for_loops(mut x: UInt) {
|
||||
for i in range(0u32, 2u32, Comptime::new(false)) {
|
||||
x += i;
|
||||
}
|
||||
for i in range(0u32, 2u32, Comptime::new(false)) {
|
||||
x += i;
|
||||
x += i;
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
|
@ -75,6 +86,21 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_redeclare_two_for_loops_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::new(UInt::as_elem()));
|
||||
|
||||
redeclare_two_for_loops_expand(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_ref_different()
|
||||
);
|
||||
}
|
||||
|
||||
fn inline_macro_ref_different() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
|
|
|
@ -53,9 +53,9 @@ fn matmul_kernel<F: Float>(
|
|||
|
||||
k /= Comptime::runtime(vectorization_factor);
|
||||
|
||||
for j in range(0u32, k, Comptime::new(false)) {
|
||||
let lhs_index = row * k + j + offset_lhs;
|
||||
let rhs_index = col * k + j + offset_rhs;
|
||||
for i in range(0u32, k, Comptime::new(false)) {
|
||||
let lhs_index = row * k + i + offset_lhs;
|
||||
let rhs_index = col * k + i + offset_rhs;
|
||||
|
||||
sum += lhs[lhs_index] * rhs[rhs_index];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue