From 9e9451bb60f2dda3c4904b1b28ea7e437390e7ef Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Mon, 9 Sep 2024 11:19:26 -0500
Subject: [PATCH] simplify scope tracking in burn-import (#2207)

* simplify scope tracking in burn-import

* removed unnecessary return statement
---
 crates/burn-import/src/burn/graph.rs | 14 ++++------
 crates/burn-import/src/burn/scope.rs | 39 ++++++++++------------------
 2 files changed, 18 insertions(+), 35 deletions(-)

diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs
index cb1d1572c..3ad633faf 100644
--- a/crates/burn-import/src/burn/graph.rs
+++ b/crates/burn-import/src/burn/graph.rs
@@ -314,21 +314,17 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
                     .flat_map(to_tensor)
                     .for_each(|tensor| {
                         self.scope
-                            .tensor_register_variable(&tensor, node_position + 1)
-                    })
-            });
-
-        self.nodes
-            .iter()
-            .enumerate()
-            .for_each(|(node_position, node)| {
+                            .tensor_register_variable(&tensor, node_position + 1);
+                    });
+                // Since the graph is guaranteed to be a DAG, we can safely register future uses
+                // of the inputs (which are the previous nodes' outputs)
                 node.input_types()
                     .into_iter()
                     .flat_map(to_tensor)
                     .for_each(|tensor| {
                         self.scope
                             .tensor_register_future_use(&tensor, node_position)
-                })
+                    });
             });
     }

diff --git a/crates/burn-import/src/burn/scope.rs b/crates/burn-import/src/burn/scope.rs
index 9e497c09a..caeaa19a4 100644
--- a/crates/burn-import/src/burn/scope.rs
+++ b/crates/burn-import/src/burn/scope.rs
@@ -7,7 +7,7 @@ use std::collections::HashMap;
 /// The scope struct ensures that ownership rules are respected during the forward pass.
 #[derive(Clone, Debug, Default)]
 pub struct Scope {
-    variables: HashMap<String, Vec<TensorVariable>>,
+    variables: HashMap<String, TensorVariable>,
 }
 
 #[derive(Clone, Debug, new)]
@@ -19,20 +19,13 @@ struct TensorVariable {
 impl Scope {
     /// Declare a new tensor variable.
     pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
-            for variable in variables.iter_mut() {
-                if variable.node_position == node_position {
-                    variable.references += 1;
-                    return;
-                }
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
+            if variable.node_position == node_position {
+                variable.references += 1;
             }
-
-            variables.push(TensorVariable::new(0, node_position));
         } else {
-            self.variables.insert(
-                tensor.name.clone(),
-                vec![TensorVariable::new(0, node_position)],
-            );
+            self.variables
+                .insert(tensor.name.clone(), TensorVariable::new(0, node_position));
         }
     }
 
@@ -42,12 +35,9 @@
     ///
     /// We need to know all futures use of a variable in advance.
     pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
-            for variable in variables.iter_mut().rev() {
-                if node_position >= variable.node_position {
-                    variable.references += 1;
-                    break;
-                }
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
+            if node_position >= variable.node_position {
+                variable.references += 1;
             }
         } else {
             panic!("No variable with name {}", tensor.name);
         }
     }
 
     /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
     pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
             let mut count = 0;
             let name = &tensor.name;
 
-            for variable in variables.iter_mut().rev() {
-                if node_position >= variable.node_position {
-                    variable.references -= 1;
-                    count = variable.references;
-                    break;
-                }
+            if node_position >= variable.node_position {
+                variable.references -= 1;
+                count = variable.references;
             }
 
             if count > 0 {