simplify scope tracking in burn-import (#2207)

* simplify scope tracking in burn-import

* removed unecessary return statement
This commit is contained in:
Joshua Ferguson 2024-09-09 11:19:26 -05:00 committed by GitHub
parent ccb5b2214e
commit 9e9451bb60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 35 deletions

View File

@ -314,21 +314,17 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
.flat_map(to_tensor) .flat_map(to_tensor)
.for_each(|tensor| { .for_each(|tensor| {
self.scope self.scope
.tensor_register_variable(&tensor, node_position + 1) .tensor_register_variable(&tensor, node_position + 1);
})
}); });
// Since the graph is guaranteed to be a DAG, we can safely register future uses
self.nodes // of the inputs (which are the previous nodes' outputs)
.iter()
.enumerate()
.for_each(|(node_position, node)| {
node.input_types() node.input_types()
.into_iter() .into_iter()
.flat_map(to_tensor) .flat_map(to_tensor)
.for_each(|tensor| { .for_each(|tensor| {
self.scope self.scope
.tensor_register_future_use(&tensor, node_position) .tensor_register_future_use(&tensor, node_position)
}) });
}); });
} }

View File

@ -7,7 +7,7 @@ use std::collections::HashMap;
/// The scope struct ensures that ownership rules are respected during the forward pass. /// The scope struct ensures that ownership rules are respected during the forward pass.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct Scope { pub struct Scope {
variables: HashMap<Ident, Vec<TensorVariable>>, variables: HashMap<Ident, TensorVariable>,
} }
#[derive(Clone, Debug, new)] #[derive(Clone, Debug, new)]
@ -19,20 +19,13 @@ struct TensorVariable {
impl Scope { impl Scope {
/// Declare a new tensor variable. /// Declare a new tensor variable.
pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) { pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
if let Some(variables) = self.variables.get_mut(&tensor.name) { if let Some(variable) = self.variables.get_mut(&tensor.name) {
for variable in variables.iter_mut() {
if variable.node_position == node_position { if variable.node_position == node_position {
variable.references += 1; variable.references += 1;
return;
} }
}
variables.push(TensorVariable::new(0, node_position));
} else { } else {
self.variables.insert( self.variables
tensor.name.clone(), .insert(tensor.name.clone(), TensorVariable::new(0, node_position));
vec![TensorVariable::new(0, node_position)],
);
} }
} }
@ -42,12 +35,9 @@ impl Scope {
/// ///
/// We need to know all futures use of a variable in advance. /// We need to know all futures use of a variable in advance.
pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) { pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
if let Some(variables) = self.variables.get_mut(&tensor.name) { if let Some(variable) = self.variables.get_mut(&tensor.name) {
for variable in variables.iter_mut().rev() {
if node_position >= variable.node_position { if node_position >= variable.node_position {
variable.references += 1; variable.references += 1;
break;
}
} }
} else { } else {
panic!("No variable with name {}", tensor.name); panic!("No variable with name {}", tensor.name);
@ -56,16 +46,13 @@ impl Scope {
/// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward. /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream { pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
if let Some(variables) = self.variables.get_mut(&tensor.name) { if let Some(variable) = self.variables.get_mut(&tensor.name) {
let mut count = 0; let mut count = 0;
let name = &tensor.name; let name = &tensor.name;
for variable in variables.iter_mut().rev() {
if node_position >= variable.node_position { if node_position >= variable.node_position {
variable.references -= 1; variable.references -= 1;
count = variable.references; count = variable.references;
break;
}
} }
if count > 0 { if count > 0 {