mirror of https://github.com/tracel-ai/burn.git
Autodiff Memory Management: BFS (#1710)
This commit is contained in:
parent
5d959e2884
commit
a8661a2f53
|
@ -12,19 +12,13 @@ pub struct GraphMemoryManagement {
|
|||
statuses: HashMap<NodeID, NodeMemoryStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NodeMemoryStatus {
|
||||
Useful,
|
||||
Unavailable,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Mode {
|
||||
TagAsUseful,
|
||||
Explore,
|
||||
}
|
||||
|
||||
impl GraphMemoryManagement {
|
||||
/// Register a new node with its parent.
|
||||
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
|
||||
|
@ -66,9 +60,7 @@ impl GraphMemoryManagement {
|
|||
// available node with a tensor reference exist in their descendance.
|
||||
// But some may seem useless from some leaf but be useful from another one,
|
||||
// hence the need to iterate on all leaves.
|
||||
for leaf in leaves.clone() {
|
||||
self.useful_propagation(leaf, Mode::Explore);
|
||||
}
|
||||
self.useful_propagation(leaves.clone());
|
||||
|
||||
// New leaves are the roots of a useful backward sub-tree.
|
||||
// Deletables are everything not marked as useful.
|
||||
|
@ -115,49 +107,112 @@ impl GraphMemoryManagement {
|
|||
}
|
||||
}
|
||||
|
||||
fn useful_propagation(&mut self, node_id: NodeID, mode: Mode) {
|
||||
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
|
||||
fn useful_propagation(&mut self, leaves: HashSet<NodeID>) {
|
||||
// Accumulate visited nodes
|
||||
let mut explored = HashSet::new();
|
||||
let mut tagged_useful = HashSet::new();
|
||||
|
||||
match mode {
|
||||
Mode::TagAsUseful => {
|
||||
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
|
||||
for parent in parents {
|
||||
self.useful_propagation(parent, Mode::TagAsUseful)
|
||||
// Queue of nodes to visit
|
||||
let mut to_tag_useful = PopNodeSet::default();
|
||||
let mut to_explore = PopNodeSet::new(leaves);
|
||||
|
||||
// Utilitary function to iterate over a node's parents
|
||||
let parents = |node_id| {
|
||||
self.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
};
|
||||
|
||||
loop {
|
||||
// Pop a node id, greedily looking at tag_useful ones first
|
||||
let (node_id, status) = match to_tag_useful.pop() {
|
||||
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
|
||||
None => match to_explore.pop() {
|
||||
Some(node_id) => {
|
||||
let node_status = self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("All nodes should have received a status during unavailable_propagation")
|
||||
.to_owned();
|
||||
|
||||
if let NodeMemoryStatus::Unknown = node_status {
|
||||
match self.is_referenced(node_id) {
|
||||
true => (node_id, NodeMemoryStatus::Useful),
|
||||
false => (node_id, NodeMemoryStatus::Unknown),
|
||||
}
|
||||
} else {
|
||||
(node_id, node_status)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// There are no nodes in the queues anymore
|
||||
break;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
match status {
|
||||
NodeMemoryStatus::Useful => {
|
||||
tagged_useful.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
// The node can be explored, as long as it's not already tagged useful
|
||||
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
|
||||
to_tag_useful.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Mode::Explore => {
|
||||
let node_status = self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("All nodes should have received a status at this point")
|
||||
.clone();
|
||||
|
||||
match node_status {
|
||||
NodeMemoryStatus::Useful => {
|
||||
// Nothing to do, was already tagged through some other path
|
||||
}
|
||||
NodeMemoryStatus::Unavailable => {
|
||||
// Even if this node is unavailable, it is still possible that an ancestor is useful if referenced
|
||||
for parent in parents {
|
||||
self.useful_propagation(parent, Mode::Explore);
|
||||
}
|
||||
}
|
||||
NodeMemoryStatus::Unknown => {
|
||||
// If this node is referenced and not unavailable,
|
||||
// then it is useful and we must retain all ancestors
|
||||
|
||||
let mut mode = Mode::Explore;
|
||||
if self.is_referenced(node_id) {
|
||||
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
|
||||
mode = Mode::TagAsUseful;
|
||||
}
|
||||
|
||||
for parent in parents {
|
||||
self.useful_propagation(parent, mode.clone());
|
||||
_ => {
|
||||
explored.insert(node_id);
|
||||
for parent in parents(node_id) {
|
||||
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
|
||||
to_explore.insert(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.statuses.insert(node_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
fn identify_leaves_and_deletables(
|
||||
&self,
|
||||
leaf_id: NodeID,
|
||||
new_leaves: &mut HashSet<NodeID>,
|
||||
to_delete: &mut Vec<NodeID>,
|
||||
) {
|
||||
let mut visited = HashSet::new();
|
||||
let mut to_visit = vec![leaf_id];
|
||||
|
||||
while let Some(node_id) = to_visit.pop() {
|
||||
visited.insert(node_id);
|
||||
|
||||
match self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("Node should have status")
|
||||
{
|
||||
NodeMemoryStatus::Useful => {
|
||||
new_leaves.insert(node_id);
|
||||
}
|
||||
_ => {
|
||||
to_delete.push(node_id);
|
||||
|
||||
for parent in self
|
||||
.nodes
|
||||
.get(&node_id)
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
{
|
||||
if !visited.contains(&parent) {
|
||||
to_visit.push(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,29 +222,31 @@ impl GraphMemoryManagement {
|
|||
None => panic!("Node should be in the nodes map"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn identify_leaves_and_deletables(
|
||||
&self,
|
||||
node_id: NodeID,
|
||||
new_leaves: &mut HashSet<NodeID>,
|
||||
to_delete: &mut Vec<NodeID>,
|
||||
) {
|
||||
let current_status = self
|
||||
.statuses
|
||||
.get(&node_id)
|
||||
.expect("Node should have status");
|
||||
/// Wrapper over hash set for fast popping of any node
|
||||
#[derive(new, Default)]
|
||||
struct PopNodeSet {
|
||||
hash_set: HashSet<NodeID>,
|
||||
}
|
||||
|
||||
match current_status {
|
||||
NodeMemoryStatus::Useful => {
|
||||
new_leaves.insert(node_id);
|
||||
}
|
||||
_ => {
|
||||
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
|
||||
for parent in parents {
|
||||
self.identify_leaves_and_deletables(parent, new_leaves, to_delete)
|
||||
}
|
||||
to_delete.push(node_id);
|
||||
}
|
||||
}
|
||||
impl PopNodeSet {
|
||||
#[inline(always)]
|
||||
fn pop(&mut self) -> Option<NodeID> {
|
||||
self.hash_set
|
||||
.iter()
|
||||
.next()
|
||||
.copied()
|
||||
.and_then(|node_id| self.hash_set.take(&node_id))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn contains(&self, node_id: &NodeID) -> bool {
|
||||
self.hash_set.contains(node_id)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn insert(&mut self, node_id: NodeID) {
|
||||
self.hash_set.insert(node_id);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -238,4 +238,56 @@ mod tests {
|
|||
assert!(tensor_2.grad(&grads).is_some());
|
||||
assert!(tensor_3.grad(&grads).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mm_deletables_propagate_well() {
|
||||
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
let tensor_0 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_2 = tensor_0 * tensor_1;
|
||||
let tensor_3 = tensor_2.clone().exp();
|
||||
let tensor_4 = tensor_3.clone().log();
|
||||
|
||||
let grads = tensor_2.backward();
|
||||
|
||||
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
|
||||
// the intermediate tensor_3 as well
|
||||
let grads = tensor_3.backward();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
|
||||
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
let device = Default::default();
|
||||
|
||||
// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
|
||||
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
|
||||
for _ in 0..12 {
|
||||
let tensor_0 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
let tensor_1 =
|
||||
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
|
||||
|
||||
let tensor_2 = tensor_1.clone().exp();
|
||||
let tensor_3 = tensor_0.exp();
|
||||
let tensor_4 = tensor_3.clone() * tensor_2.clone();
|
||||
let tensor_5 = tensor_2.exp();
|
||||
let tensor_6 = tensor_5.exp();
|
||||
let tensor_7 = tensor_6.exp();
|
||||
let tensor_8 = tensor_7.exp();
|
||||
|
||||
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
|
||||
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
|
||||
tensor_3.backward();
|
||||
let grads = tensor_8.backward();
|
||||
|
||||
assert!(tensor_1.grad(&grads).is_some());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue