simplify scope tracking in burn-import (#2207)

* simplify scope tracking in burn-import

* removed unnecessary return statement
Joshua Ferguson 2024-09-09 11:19:26 -05:00 committed by GitHub
parent ccb5b2214e
commit 9e9451bb60
2 changed files with 18 additions and 35 deletions
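
At its core the change collapses the per-name bookkeeping from a list of candidate variables to a single one. A minimal sketch of the resulting shape (assumptions: `String` stands in for burn-import's `Ident`, and the derives are trimmed down):

use std::collections::HashMap;

// Stand-in for burn-import's `Ident` type.
type Ident = String;

#[derive(Clone, Debug)]
struct TensorVariable {
    references: usize,
    node_position: usize,
}

// Before this commit each name mapped to a Vec<TensorVariable> that had to be
// scanned on every lookup; afterward a name maps to exactly one variable.
#[derive(Clone, Debug, Default)]
struct Scope {
    variables: HashMap<Ident, TensorVariable>,
}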

@@ -314,21 +314,17 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
                     .flat_map(to_tensor)
                     .for_each(|tensor| {
                         self.scope
-                            .tensor_register_variable(&tensor, node_position + 1)
-                    })
-            });
-
-        self.nodes
-            .iter()
-            .enumerate()
-            .for_each(|(node_position, node)| {
+                            .tensor_register_variable(&tensor, node_position + 1);
+                    });
+                // Since the graph is guaranteed to be a DAG, we can safely register future uses
+                // of the inputs (which are the previous nodes' outputs)
                 node.input_types()
                     .into_iter()
                     .flat_map(to_tensor)
                     .for_each(|tensor| {
                         self.scope
                             .tensor_register_future_use(&tensor, node_position)
-                    })
+                    });
             });
     }
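
The two iterations over `self.nodes` are merged into one because the nodes arrive in topological order: by the time a node is visited, its inputs were declared either as graph inputs or as an earlier node's outputs. A self-contained sketch of that single-pass counting, with a hypothetical `Node` type in place of burn-import's node trait:

use std::collections::HashMap;

// Hypothetical minimal node: just the names of consumed and produced tensors.
struct Node {
    inputs: Vec<String>,
    outputs: Vec<String>,
}

// One pass over a topologically sorted DAG: declare each node's outputs,
// then immediately record future uses of its inputs, which were declared
// either as graph inputs or by an earlier node.
fn count_uses(graph_inputs: &[String], nodes: &[Node]) -> HashMap<String, usize> {
    let mut uses: HashMap<String, usize> = HashMap::new();
    for name in graph_inputs {
        uses.insert(name.clone(), 0); // graph inputs are declared up front
    }
    for node in nodes {
        for out in &node.outputs {
            uses.insert(out.clone(), 0); // declare each node output
        }
        for inp in &node.inputs {
            *uses
                .get_mut(inp)
                .expect("DAG order guarantees the producer ran first") += 1;
        }
    }
    uses
}

Registering outputs at `node_position + 1` before recording input uses at `node_position` preserves the ordering invariant that the `>=` checks in scope.rs rely on.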

@@ -7,7 +7,7 @@ use std::collections::HashMap;
 /// The scope struct ensures that ownership rules are respected during the forward pass.
 #[derive(Clone, Debug, Default)]
 pub struct Scope {
-    variables: HashMap<Ident, Vec<TensorVariable>>,
+    variables: HashMap<Ident, TensorVariable>,
 }

 #[derive(Clone, Debug, new)]
@@ -19,20 +19,13 @@ struct TensorVariable {
 impl Scope {
     /// Declare a new tensor variable.
     pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
-            for variable in variables.iter_mut() {
-                if variable.node_position == node_position {
-                    variable.references += 1;
-                    return;
-                }
-            }
-
-            variables.push(TensorVariable::new(0, node_position));
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
+            if variable.node_position == node_position {
+                variable.references += 1;
+            }
         } else {
-            self.variables.insert(
-                tensor.name.clone(),
-                vec![TensorVariable::new(0, node_position)],
-            );
+            self.variables
+                .insert(tensor.name.clone(), TensorVariable::new(0, node_position));
         }
     }
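
Taken together with `tensor_register_future_use` in the next hunk, the counting protocol is: declaring a variable inserts it with zero references, and each later consumer bumps the count. A tiny self-contained demo of that arithmetic, with a plain tuple standing in for `TensorVariable`:

use std::collections::HashMap;

fn main() {
    // (references, node_position) per tensor name, mirroring `TensorVariable`.
    let mut variables: HashMap<String, (usize, usize)> = HashMap::new();

    // tensor_register_variable("x", 1): first registration inserts zero references.
    variables.insert("x".into(), (0, 1));

    // tensor_register_future_use("x", 2) and ("x", 3): two later consumers.
    for use_position in [2usize, 3] {
        let (references, node_position) = variables.get_mut("x").unwrap();
        if use_position >= *node_position {
            *references += 1;
        }
    }

    assert_eq!(variables["x"], (2, 1)); // "x" will be consumed twice
}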
@@ -42,12 +35,9 @@ impl Scope {
     ///
     /// We need to know all future uses of a variable in advance.
     pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
-            for variable in variables.iter_mut().rev() {
-                if node_position >= variable.node_position {
-                    variable.references += 1;
-                    break;
-                }
-            }
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
+            if node_position >= variable.node_position {
+                variable.references += 1;
+            }
         } else {
             panic!("No variable with name {}", tensor.name);
@@ -56,16 +46,13 @@ impl Scope {
     /// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
     pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
-        if let Some(variables) = self.variables.get_mut(&tensor.name) {
+        if let Some(variable) = self.variables.get_mut(&tensor.name) {
             let mut count = 0;
             let name = &tensor.name;

-            for variable in variables.iter_mut().rev() {
-                if node_position >= variable.node_position {
-                    variable.references -= 1;
-                    count = variable.references;
-                    break;
-                }
-            }
+            if node_position >= variable.node_position {
+                variable.references -= 1;
+                count = variable.references;
+            }

             if count > 0 {
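
The hunk is truncated at `if count > 0 {`; per the doc comment above, the rest of the function clones the tensor while later uses remain and moves it on the last use. A self-contained sketch of that clone-or-move rule, returning a plain `String` where the real function returns a `TokenStream`:

use std::collections::HashMap;

struct TensorVariable {
    references: usize,
    node_position: usize,
}

// Decrement the remaining-use count and decide between cloning and moving.
// Sketch only: burn-import builds generated-code tokens instead of a String.
fn tensor_use_owned(
    variables: &mut HashMap<String, TensorVariable>,
    name: &str,
    node_position: usize,
) -> String {
    let variable = variables
        .get_mut(name)
        .unwrap_or_else(|| panic!("No variable with name {name}"));
    let mut count = 0;
    if node_position >= variable.node_position {
        variable.references -= 1;
        count = variable.references;
    }
    if count > 0 {
        format!("{name}.clone()") // later uses remain: clone
    } else {
        name.to_string() // last use: move ownership
    }
}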