mirror of https://github.com/tracel-ai/burn.git
Fix autodiff memory management graph cleaning (#1602)
This commit is contained in:
parent
0cbe9a927d
commit
07a61a1cec
|
@ -365,6 +365,7 @@ dependencies = [
|
|||
"burn-tensor",
|
||||
"burn-tensor-testgen",
|
||||
"derive-new",
|
||||
"log",
|
||||
"spin",
|
||||
]
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", opt
|
|||
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [
|
||||
|
|
|
@ -61,6 +61,7 @@ impl GraphMemoryManagement {
|
|||
|
||||
for node_id in graph.into_iter() {
|
||||
func(&node_id);
|
||||
self.graphs.remove(&GraphId::new(*node_id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,6 +259,9 @@ mod tests {
|
|||
assert!(node_ids.contains(&node_1));
|
||||
assert!(node_ids.contains(&node_2));
|
||||
|
||||
assert_eq!(graph_mm.graphs.len(), 0);
|
||||
assert_eq!(graph_mm.owned.len(), 0);
|
||||
|
||||
// Same but with free(node_2);
|
||||
graph_mm.register(node_1.clone(), vec![]);
|
||||
graph_mm.register(node_2.clone(), vec![*node_1]);
|
||||
|
@ -267,5 +271,8 @@ mod tests {
|
|||
|
||||
assert!(node_ids.contains(&node_1));
|
||||
assert!(node_ids.contains(&node_2));
|
||||
|
||||
assert_eq!(graph_mm.graphs.len(), 0);
|
||||
assert_eq!(graph_mm.owned.len(), 0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::{
|
|||
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
|
||||
grads::Gradients,
|
||||
graph::{traversal::BreadthFirstSearch, StepBoxed},
|
||||
runtime::memory_management::GraphId,
|
||||
tensor::NodeRefCount,
|
||||
NodeID,
|
||||
};
|
||||
|
@ -63,6 +64,10 @@ impl AutodiffServer {
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
|
||||
// We consume that node for the tape, so we should remove it from the
|
||||
// memory_management.
|
||||
self.memory_management.free_graph(GraphId::new(id), |_| {});
|
||||
|
||||
let order = step.order();
|
||||
if order == 0 {
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue