louisfd 2024-02-07 10:04:15 -05:00
parent 7cade3caa3
commit e321dabd6d
7 changed files with 89 additions and 41 deletions

View File

@@ -46,10 +46,12 @@ impl RetroForwards {
         self.map.insert(node_id, retro_forward);
     }
-    pub(crate) fn merge(self, other: Self) -> Self {
-        Self {
-            map: self.map.into_iter().chain(other.map.into_iter()).collect(),
-        }
+    pub(crate) fn extend(&mut self, other: Self) {
+        self.map.extend(other.map);
+    }
+    pub(crate) fn len(&self) -> usize {
+        self.map.len()
     }
 }
@@ -70,10 +72,8 @@ impl NodeTree {
         self.map.insert(node_id, node_ref);
     }
-    pub(crate) fn merge(self, other: Self) -> Self {
-        Self {
-            map: self.map.into_iter().chain(other.map.into_iter()).collect(),
-        }
+    pub(crate) fn extend(&mut self, other: Self) {
+        self.map.extend(other.map);
     }
 }
@@ -144,13 +144,14 @@ impl Checkpointer {
         }
     }
-    pub fn merge(self, other: Self) -> Self {
-        // TODO mut with extend
-        Self {
-            backward_states: self.backward_states.merge(other.backward_states),
-            node_tree: self.node_tree.merge(other.node_tree),
-            retro_forwards: self.retro_forwards.merge(other.retro_forwards),
-        }
+    pub fn extend(&mut self, other: Self) {
+        self.backward_states.extend(other.backward_states);
+        self.node_tree.extend(other.node_tree);
+        self.retro_forwards.extend(other.retro_forwards);
+    }
+    pub fn len(&self) -> usize {
+        self.backward_states.len() + self.retro_forwards.len()
     }
     pub fn register_retro_forward(
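
The change above replaces the owned merge, which rebuilt each map by chaining and re-collecting every entry, with an in-place extend. A minimal sketch of the difference, using illustrative key/value types rather than the crate's NodeID and retro-forward types:

use std::collections::HashMap;

// Illustrative stand-in for one of the map-backed containers
// (RetroForwards, NodeTree, BackwardStates).
#[derive(Default)]
struct RetroForwards {
    map: HashMap<u64, String>,
}

impl RetroForwards {
    // Old shape: consumes both sides and re-collects every entry into a new map.
    fn merge(self, other: Self) -> Self {
        Self {
            map: self.map.into_iter().chain(other.map).collect(),
        }
    }

    // New shape: mutates in place and reuses `self.map`'s existing allocation.
    fn extend(&mut self, other: Self) {
        self.map.extend(other.map);
    }

    fn len(&self) -> usize {
        self.map.len()
    }
}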

View File

@@ -133,9 +133,11 @@ impl BackwardStates {
         );
     }
-    pub(crate) fn merge(self, other: Self) -> Self {
-        Self {
-            map: self.map.into_iter().chain(other.map.into_iter()).collect(),
-        }
+    pub(crate) fn extend(&mut self, other: Self) {
+        self.map.extend(other.map);
+    }
+    pub(crate) fn len(&self) -> usize {
+        self.map.len()
     }
 }

View File

@@ -6,9 +6,10 @@ use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed};
 pub fn backward<B: Backend, const D: usize>(root: AutodiffTensor<B, D>) -> Gradients {
     let grads = Gradients::new::<B, D>(root.node.clone(), root.primitive);
+    let checkpointer = root.graph.take_checkpointer();
     let tape = build_tape(root.node, root.graph);
-    execute_steps(tape, grads, root.graph.get_checkpointer())
+    execute_steps(tape, grads, checkpointer)
 }
 fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
@@ -32,12 +33,12 @@ fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
 fn execute_steps(
     tape: Vec<Vec<StepBoxed>>,
     mut grads: Gradients,
-    checkpointer: &mut Checkpointer,
+    mut checkpointer: Checkpointer,
 ) -> Gradients {
     tape.into_iter().rev().for_each(|steps| {
         steps
             .into_iter()
-            .for_each(|step| step.step(&mut grads, checkpointer))
+            .for_each(|step| step.step(&mut grads, &mut checkpointer))
     });
     grads
 }
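
backward() now takes the checkpointer out of the graph once and hands it to execute_steps by value, so no lock or long-lived borrow is held while the tape runs; each step still only sees a plain &mut. A self-contained sketch of that ownership flow, with minimal stand-ins rather than the crate's real Gradients/Checkpointer/Step types:

// Minimal stand-ins, just to show the ownership flow.
#[derive(Default)]
struct Gradients;
#[derive(Default)]
struct Checkpointer;

trait Step {
    fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
}
type StepBoxed = Box<dyn Step>;

fn execute_steps(
    tape: Vec<Vec<StepBoxed>>,
    mut grads: Gradients,
    mut checkpointer: Checkpointer, // owned for the whole pass, borrowed per step
) -> Gradients {
    for steps in tape.into_iter().rev() {
        for step in steps {
            step.step(&mut grads, &mut checkpointer);
        }
    }
    grads
}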

View File

@@ -45,7 +45,7 @@ impl Graph {
     /// keeping all the tensors alive for multiple backward calls is a heavy waste of resources.
     pub fn steps(self) -> NodeSteps {
         let mut map_drain = HashMap::new();
-        self.execute_mut(|map| {
+        self.execute_mut_steps(|map| {
             std::mem::swap(&mut *map, &mut map_drain);
         });
         map_drain
@@ -53,7 +53,7 @@ impl Graph {
     /// Register a new step into the graph.
     pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self {
-        self.execute_mut(|map| {
+        self.execute_mut_steps(|map| {
             map.insert(id.clone(), ops);
         })
     }
@@ -67,7 +67,7 @@ impl Graph {
         self.merge_different(other)
     }
-    fn execute_mut<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
+    fn execute_mut_steps<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
         match Arc::get_mut(&mut self.steps) {
             Some(mutex) => {
                 let map = mutex.get_mut();
@@ -84,12 +84,10 @@ impl Graph {
     }
     fn merge_different(self, other: Self) -> Self {
-        // TODO better follow the same pattern where largest extends itself with largest
-        let checkpointer = self.checkpointer.lock().merge(*other.checkpointer.lock());
-        let mut map2 = other.clone().steps();
+        let mut checkpointer2 = other.checkpointer_own();
+        let mut map2 = other.steps();
-        self.execute_mut(|map1| {
+        self.execute_mut_steps(|map1| {
             if map1.len() > map2.len() {
                 map1.extend(map2);
             } else {
@@ -99,6 +97,16 @@ impl Graph {
                 std::mem::swap(map1, &mut map2);
             }
         })
+        .execute_mut_checkpointer(|checkpointer1| {
+            if checkpointer1.len() > checkpointer2.len() {
+                checkpointer1.extend(checkpointer2);
+            } else {
+                let mut checkpointer_drain = Checkpointer::default();
+                std::mem::swap(checkpointer1, &mut checkpointer_drain);
+                checkpointer2.extend(checkpointer_drain);
+                std::mem::swap(checkpointer1, &mut checkpointer2);
+            }
+        })
     }
     pub fn checkpoint_register<T: Clone + Send + Sync + 'static>(
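
merge_different now applies the same balancing trick to both the step maps and the checkpointers: whichever side is larger stays in place and absorbs the smaller one, so extend copies as few entries as possible. A compact sketch of that pattern on a plain HashMap (names are illustrative, not the crate's API):

use std::collections::HashMap;
use std::hash::Hash;

// Keep the larger map in `a` and pour the smaller one into it, so `extend`
// moves the minimum number of entries. Mirrors the merge_different logic above.
fn balanced_extend<K: Eq + Hash, V>(a: &mut HashMap<K, V>, mut b: HashMap<K, V>) {
    if a.len() < b.len() {
        // Swap so that `a` holds the larger map before extending.
        std::mem::swap(a, &mut b);
    }
    a.extend(b);
}
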
@@ -118,7 +126,36 @@ impl Graph {
             .register_retro_forward(node_id, retro_forward)
     }
-    pub fn get_checkpointer(&self) -> &mut Checkpointer {
-        &mut self.checkpointer.lock()
+    /// # Notes
+    ///
+    /// This is an owned method, so the current checkpointer will be freed.
+    pub fn checkpointer_own(self) -> Checkpointer {
+        let mut checkpointer_drain = Checkpointer::default();
+        self.execute_mut_checkpointer(|checkpointer| {
+            std::mem::swap(&mut *checkpointer, &mut checkpointer_drain);
+        });
+        checkpointer_drain
+    }
+    fn execute_mut_checkpointer<F: FnOnce(&mut Checkpointer)>(mut self, func: F) -> Self {
+        match Arc::get_mut(&mut self.checkpointer) {
+            Some(mutex) => {
+                let map = mutex.get_mut();
+                func(map);
+            }
+            None => {
+                // Only lock when there are multiple references to the graph.
+                let mut checkpointer = self.checkpointer.lock();
+                func(&mut checkpointer);
+            }
+        };
+        self
+    }
+    pub fn take_checkpointer(&self) -> Checkpointer {
+        let mut guard = self.checkpointer.lock();
+        let owned: Checkpointer = std::mem::replace(&mut *guard, Checkpointer::default());
+        owned
     }
 }
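
checkpointer_own and execute_mut_checkpointer reuse the pattern already used for the steps map: mutate through Arc::get_mut when the graph holds the only reference, and fall back to locking only when the Arc is shared. A standalone sketch of that fast path, with an illustrative Checkpointer stand-in and assuming the spin crate the module already imports:

use spin::Mutex;
use std::sync::Arc;

// Illustrative stand-in for the crate's Checkpointer.
#[derive(Default)]
struct Checkpointer {
    states: std::collections::HashMap<u64, Vec<f32>>,
}

struct Graph {
    checkpointer: Arc<Mutex<Checkpointer>>,
}

impl Graph {
    // Fast path: when this graph holds the only Arc, mutate without locking.
    // Slow path: lock only if other references to the graph still exist.
    fn execute_mut_checkpointer<F: FnOnce(&mut Checkpointer)>(mut self, func: F) -> Self {
        match Arc::get_mut(&mut self.checkpointer) {
            Some(mutex) => func(mutex.get_mut()),
            None => func(&mut self.checkpointer.lock()),
        };
        self
    }

    // Drain the shared checkpointer into an owned value, leaving an empty one behind.
    fn checkpointer_own(self) -> Checkpointer {
        let mut drained = Checkpointer::default();
        self.execute_mut_checkpointer(|ckp| std::mem::swap(ckp, &mut drained));
        drained
    }
}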

View File

@@ -8,6 +8,7 @@ use crate::{
     tensor::AutodiffTensor,
 };
 use burn_tensor::{backend::Backend, Shape};
+use spin::Mutex;
 use std::{marker::PhantomData, sync::Arc};
 pub enum CheckpointStrategy {

View File

@@ -333,10 +333,10 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
             let (lhs, rhs, broadcast) = match ops.state {
                 Some(state) => state,
                 None => {
-                    let lhs: B::FloatTensorPrimitive<D> =
-                        checkpointer.retrieve_output(ops.parents[0].unwrap().id);
-                    let rhs: B::FloatTensorPrimitive<D> =
-                        checkpointer.retrieve_output(ops.parents[1].unwrap().id);
+                    let lhs: B::FloatTensorPrimitive<D> = checkpointer
+                        .retrieve_output(ops.parents[0].clone().unwrap().id.clone());
+                    let rhs: B::FloatTensorPrimitive<D> = checkpointer
+                        .retrieve_output(ops.parents[1].clone().unwrap().id.clone());
                     let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
                     // TODO None if not tracked
                     (Some(lhs), Some(rhs), broadcast)
@@ -371,9 +371,13 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
         let rhs_tracked = rhs.is_tracked();
         let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
+        // TODO had to add a lot of clone, for nodes then for their ids. same goes for parents in backward()
         match Div
-            .prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph])
-            .recomputed(Box::new(RetroDiv::<B, D>::new(lhs.node.id, rhs.node.id))) // the box is painful, try to remove from api
+            .prepare([lhs.node.clone(), rhs.node.clone()], [lhs.graph, rhs.graph])
+            .recomputed(Box::new(RetroDiv::<B, D>::new(
+                lhs.node.id.clone(),
+                rhs.node.id.clone(),
+            ))) // the box is painful, try to remove from api
             .state_lazy()
         {
             // TODO if state is lazy then we just ignore the state below. should not compute it for nothing
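
The TODO above notes that the recompute plumbing currently forces many .clone() calls on nodes and their ids. If the id were a small Copy type, most of those clones would disappear; the following is a purely hypothetical sketch, not the crate's actual NodeID definition:

// Hypothetical: a cheap, Copy-able id newtype. With something like this,
// `ops.parents[0].clone().unwrap().id.clone()` could become `ops.parents[0].unwrap().id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct NodeID(u64);

fn parent_id(parent: Option<&NodeID>) -> NodeID {
    // Copying instead of cloning: no ownership juggling needed.
    *parent.expect("node should have a parent")
}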

View File

@@ -38,7 +38,7 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
         let node = Node::new(vec![], 0, id, Requirement::None);
         let graph = Graph::new();
-        graph.checkpoint_register(node.id, primitive, 1); // n_required arbitrary
+        graph.checkpoint_register(node.id.clone(), primitive.clone(), 1); // n_required arbitrary
         Self {
             primitive,
@@ -98,9 +98,11 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
         );
         match checkpoint_strategy {
-            CheckpointStrategy::Computed => graph.checkpoint_register(node.id, output, 1),
+            CheckpointStrategy::Computed => {
+                graph.checkpoint_register(node.id.clone(), output.clone(), 1)
+            }
             CheckpointStrategy::Recompute { retro_forward } => {
-                graph.retro_register(node.id, retro_forward)
+                graph.retro_register(node.id.clone(), retro_forward)
             }
         }
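
The match above is where the two checkpoint strategies diverge: Computed stores the forward output eagerly, while Recompute stores only a retro-forward recipe and rebuilds the value during the backward pass. A conceptual sketch of that trade-off with illustrative types, not the crate's BackwardStates API:

use std::collections::HashMap;

// Illustrative only: a node's state is either saved eagerly or described by a recipe.
enum Saved {
    Value(f32),                   // CheckpointStrategy::Computed: pay memory now
    Recipe(Box<dyn Fn() -> f32>), // CheckpointStrategy::Recompute: pay compute later
}

fn retrieve(states: &HashMap<u64, Saved>, id: u64) -> f32 {
    match states.get(&id).expect("state registered for this node") {
        Saved::Value(v) => *v,
        Saved::Recipe(retro_forward) => retro_forward(),
    }
}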