Autodiff Memory Management: BFS (#1710)

This commit is contained in:
Louis Fortier-Dubois 2024-05-03 09:45:21 -04:00 committed by GitHub
parent 5d959e2884
commit a8661a2f53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 178 additions and 69 deletions

View File

@ -12,19 +12,13 @@ pub struct GraphMemoryManagement {
statuses: HashMap<NodeID, NodeMemoryStatus>,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
/// Memory status given to each node of the graph while deciding which nodes can be freed.
enum NodeMemoryStatus {
/// The node must be kept. `identify_leaves_and_deletables` records such a node as a new
/// leaf instead of scheduling it for deletion.
Useful,
/// NOTE(review): presumably a node that can no longer take part in a backward pass; it is
/// never kept itself, but its ancestors are still explored because one of them may be
/// useful if referenced — confirm against `unavailable_propagation`.
Unavailable,
/// No decision has been made for this node yet. During `useful_propagation` an unknown
/// node that is referenced becomes `Useful`; one that stays unknown is treated as deletable.
Unknown,
}
/// Mode carried along the recursive `useful_propagation` walk.
#[derive(Clone)]
enum Mode {
/// Mark the visited node as `Useful` and keep tagging all of its parents.
TagAsUseful,
/// Look at the visited node's current status to decide how to continue with its parents.
Explore,
}
impl GraphMemoryManagement {
/// Register a new node with its parent.
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
@ -66,9 +60,7 @@ impl GraphMemoryManagement {
// available node with a tensor reference exist in their descendance.
// But some may seem useless from some leaf but be useful from another one,
// hence the need to iterate on all leaves.
for leaf in leaves.clone() {
self.useful_propagation(leaf, Mode::Explore);
}
self.useful_propagation(leaves.clone());
// New leaves are the roots of a useful backward sub-tree.
// Deletables are everything not marked as useful.
@ -115,49 +107,112 @@ impl GraphMemoryManagement {
}
}
fn useful_propagation(&mut self, node_id: NodeID, mode: Mode) {
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
fn useful_propagation(&mut self, leaves: HashSet<NodeID>) {
// Accumulate visited nodes
let mut explored = HashSet::new();
let mut tagged_useful = HashSet::new();
match mode {
Mode::TagAsUseful => {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
for parent in parents {
self.useful_propagation(parent, Mode::TagAsUseful)
}
}
Mode::Explore => {
// Queue of nodes to visit
let mut to_tag_useful = PopNodeSet::default();
let mut to_explore = PopNodeSet::new(leaves);
// Helper closure returning an iterator over a node's parents
let parents = |node_id| {
self.nodes
.get(&node_id)
.cloned()
.unwrap_or_default()
.into_iter()
};
loop {
// Pop a node id, greedily looking at tag_useful ones first
let (node_id, status) = match to_tag_useful.pop() {
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
None => match to_explore.pop() {
Some(node_id) => {
let node_status = self
.statuses
.get(&node_id)
.expect("All nodes should have received a status at this point")
.clone();
.expect("All nodes should have received a status during unavailable_propagation")
.to_owned();
match node_status {
if let NodeMemoryStatus::Unknown = node_status {
match self.is_referenced(node_id) {
true => (node_id, NodeMemoryStatus::Useful),
false => (node_id, NodeMemoryStatus::Unknown),
}
} else {
(node_id, node_status)
}
}
None => {
// There are no nodes in the queues anymore
break;
}
},
};
match status {
NodeMemoryStatus::Useful => {
// Nothing to do, was already tagged through some other path
}
NodeMemoryStatus::Unavailable => {
// Even if this node is unavailable, it is still possible that an ancestor is useful if referenced
for parent in parents {
self.useful_propagation(parent, Mode::Explore);
tagged_useful.insert(node_id);
for parent in parents(node_id) {
// The node can be explored, as long as it's not already tagged useful
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
to_tag_useful.insert(parent);
}
}
}
_ => {
explored.insert(node_id);
for parent in parents(node_id) {
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
to_explore.insert(parent);
}
}
}
NodeMemoryStatus::Unknown => {
// If this node is referenced and not unavailable,
// then it is useful and we must retain all ancestors
let mut mode = Mode::Explore;
if self.is_referenced(node_id) {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
mode = Mode::TagAsUseful;
}
for parent in parents {
self.useful_propagation(parent, mode.clone());
}
self.statuses.insert(node_id, status);
}
}
/// Walks the graph from `leaf_id` towards its ancestors and sorts every reached node into
/// either `new_leaves` or `to_delete`.
///
/// * A node whose status is `Useful` is inserted into `new_leaves`; the walk does not
///   continue to its parents through that node.
/// * Any other node is pushed onto `to_delete` and its parents are scheduled for a visit.
///
/// The traversal is iterative (explicit stack) rather than recursive. Each node is handled
/// at most once per call, even when it is reachable through several paths, so a single call
/// never pushes the same id onto `to_delete` twice.
///
/// # Panics
///
/// Panics if a visited node has no entry in `self.statuses`.
fn identify_leaves_and_deletables(
    &self,
    leaf_id: NodeID,
    new_leaves: &mut HashSet<NodeID>,
    to_delete: &mut Vec<NodeID>,
) {
    let mut visited = HashSet::new();
    let mut to_visit = vec![leaf_id];

    while let Some(node_id) = to_visit.pop() {
        // A node shared by several children can be stacked more than once before its first
        // visit; `insert` returning `false` means it has already been handled.
        if !visited.insert(node_id) {
            continue;
        }

        let status = self
            .statuses
            .get(&node_id)
            .expect("Node should have status");

        match status {
            NodeMemoryStatus::Useful => {
                new_leaves.insert(node_id);
            }
            NodeMemoryStatus::Unavailable | NodeMemoryStatus::Unknown => {
                to_delete.push(node_id);
                // A node without an entry in `nodes` simply has no parents to follow.
                if let Some(parents) = self.nodes.get(&node_id) {
                    to_visit.extend(
                        parents
                            .iter()
                            .copied()
                            .filter(|parent| !visited.contains(parent)),
                    );
                }
            }
        }
    }
}
@ -167,29 +222,31 @@ impl GraphMemoryManagement {
None => panic!("Node should be in the nodes map"),
}
}
}
fn identify_leaves_and_deletables(
&self,
node_id: NodeID,
new_leaves: &mut HashSet<NodeID>,
to_delete: &mut Vec<NodeID>,
) {
let current_status = self
.statuses
.get(&node_id)
.expect("Node should have status");
/// Wrapper over a hash set of node ids that allows fast popping of an arbitrary node.
///
/// `useful_propagation` uses it for its pending-work queues (`to_tag_useful` and
/// `to_explore`), where membership tests matter and the visiting order does not.
#[derive(new, Default)]
struct PopNodeSet {
// Node ids currently waiting in the set.
hash_set: HashSet<NodeID>,
}
match current_status {
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
}
_ => {
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
for parent in parents {
self.identify_leaves_and_deletables(parent, new_leaves, to_delete)
}
to_delete.push(node_id);
impl PopNodeSet {
    /// Removes one node id from the set and returns it, or `None` when the set is empty.
    ///
    /// The id that comes out is whichever one the hash set's iterator yields first.
    #[inline(always)]
    fn pop(&mut self) -> Option<NodeID> {
        // Copy the id out so the borrow taken by `iter()` ends before the removal below.
        let picked = *self.hash_set.iter().next()?;
        self.hash_set.take(&picked)
    }

    /// Tells whether `node_id` is currently waiting in the set.
    #[inline(always)]
    fn contains(&self, node_id: &NodeID) -> bool {
        self.hash_set.get(node_id).is_some()
    }

    /// Adds `node_id` to the set; nothing changes if it is already present.
    #[inline(always)]
    fn insert(&mut self, node_id: NodeID) {
        let _newly_added = self.hash_set.insert(node_id);
    }
}

View File

@ -238,4 +238,56 @@ mod tests {
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}
/// Checks that nodes made useless by a backward pass are freed all the way up the graph,
/// not only at the leaf.
#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
// Chain under test: tensor_2 -> tensor_3 (exp) -> tensor_4 (log), with tensor_4 as the leaf.
let tensor_2 = tensor_0 * tensor_1;
let tensor_3 = tensor_2.clone().exp();
let tensor_4 = tensor_3.clone().log();
let grads = tensor_2.backward();
// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted,
// but the intermediate tensor_3 as well.
// NOTE(review): the panic required by `#[should_panic]` is presumably raised by the call
// below, because tensor_3's node is no longer in the graph — confirm.
let grads = tensor_3.backward();
}
/// Checks that a node first reached (and left untagged) from one leaf can still be tagged as
/// useful when it is reached again, deeper, from another leaf.
#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();
// There is a 50% chance that the exploration starts with leaf tensor_8 instead of tensor_4,
// in which case the run is not informative. Repeating the scenario many times makes it
// almost impossible for the test to pass when it should not.
for _ in 0..12 {
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().exp();
let tensor_3 = tensor_0.exp();
// tensor_4 is the short branch: it has tensor_2 as a direct parent.
let tensor_4 = tensor_3.clone() * tensor_2.clone();
// tensor_5..tensor_8 form the long branch: tensor_2 is four steps above tensor_8.
let tensor_5 = tensor_2.exp();
let tensor_6 = tensor_5.exp();
let tensor_7 = tensor_6.exp();
let tensor_8 = tensor_7.exp();
// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the
// leaf tensor_8. The latter should happen afterwards, because tensor_2 is deeper from
// tensor_8's point of view and the search is breadth-first.
tensor_3.backward();
let grads = tensor_8.backward();
// A gradient for tensor_1 can only be produced if the path through tensor_2 was kept.
assert!(tensor_1.grad(&grads).is_some());
}
}
}