mirror of https://github.com/tracel-ai/burn.git
simplify scope tracking in burn-import (#2207)
* simplify scope tracking in burn-import * removed unecessary return statement
This commit is contained in:
parent
ccb5b2214e
commit
9e9451bb60
|
@ -314,21 +314,17 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
||||||
.flat_map(to_tensor)
|
.flat_map(to_tensor)
|
||||||
.for_each(|tensor| {
|
.for_each(|tensor| {
|
||||||
self.scope
|
self.scope
|
||||||
.tensor_register_variable(&tensor, node_position + 1)
|
.tensor_register_variable(&tensor, node_position + 1);
|
||||||
})
|
});
|
||||||
});
|
// Since the graph is guaranteed to be a DAG, we can safely register future uses
|
||||||
|
// of the inputs (which are the previous nodes' outputs)
|
||||||
self.nodes
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.for_each(|(node_position, node)| {
|
|
||||||
node.input_types()
|
node.input_types()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(to_tensor)
|
.flat_map(to_tensor)
|
||||||
.for_each(|tensor| {
|
.for_each(|tensor| {
|
||||||
self.scope
|
self.scope
|
||||||
.tensor_register_future_use(&tensor, node_position)
|
.tensor_register_future_use(&tensor, node_position)
|
||||||
})
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ use std::collections::HashMap;
|
||||||
/// The scope struct ensures that ownership rules are respected during the forward pass.
|
/// The scope struct ensures that ownership rules are respected during the forward pass.
|
||||||
#[derive(Clone, Debug, Default)]
|
#[derive(Clone, Debug, Default)]
|
||||||
pub struct Scope {
|
pub struct Scope {
|
||||||
variables: HashMap<Ident, Vec<TensorVariable>>,
|
variables: HashMap<Ident, TensorVariable>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, new)]
|
#[derive(Clone, Debug, new)]
|
||||||
|
@ -19,20 +19,13 @@ struct TensorVariable {
|
||||||
impl Scope {
|
impl Scope {
|
||||||
/// Declare a new tensor variable.
|
/// Declare a new tensor variable.
|
||||||
pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
|
pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
|
||||||
if let Some(variables) = self.variables.get_mut(&tensor.name) {
|
if let Some(variable) = self.variables.get_mut(&tensor.name) {
|
||||||
for variable in variables.iter_mut() {
|
if variable.node_position == node_position {
|
||||||
if variable.node_position == node_position {
|
variable.references += 1;
|
||||||
variable.references += 1;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
variables.push(TensorVariable::new(0, node_position));
|
|
||||||
} else {
|
} else {
|
||||||
self.variables.insert(
|
self.variables
|
||||||
tensor.name.clone(),
|
.insert(tensor.name.clone(), TensorVariable::new(0, node_position));
|
||||||
vec![TensorVariable::new(0, node_position)],
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,12 +35,9 @@ impl Scope {
|
||||||
///
|
///
|
||||||
/// We need to know all futures use of a variable in advance.
|
/// We need to know all futures use of a variable in advance.
|
||||||
pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
|
pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
|
||||||
if let Some(variables) = self.variables.get_mut(&tensor.name) {
|
if let Some(variable) = self.variables.get_mut(&tensor.name) {
|
||||||
for variable in variables.iter_mut().rev() {
|
if node_position >= variable.node_position {
|
||||||
if node_position >= variable.node_position {
|
variable.references += 1;
|
||||||
variable.references += 1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
panic!("No variable with name {}", tensor.name);
|
panic!("No variable with name {}", tensor.name);
|
||||||
|
@ -56,16 +46,13 @@ impl Scope {
|
||||||
|
|
||||||
/// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
|
/// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
|
||||||
pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
|
pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
|
||||||
if let Some(variables) = self.variables.get_mut(&tensor.name) {
|
if let Some(variable) = self.variables.get_mut(&tensor.name) {
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
let name = &tensor.name;
|
let name = &tensor.name;
|
||||||
|
|
||||||
for variable in variables.iter_mut().rev() {
|
if node_position >= variable.node_position {
|
||||||
if node_position >= variable.node_position {
|
variable.references -= 1;
|
||||||
variable.references -= 1;
|
count = variable.references;
|
||||||
count = variable.references;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
|
|
Loading…
Reference in New Issue