mirror of https://github.com/tracel-ai/burn.git
compiles
This commit is contained in:
parent
7cade3caa3
commit
e321dabd6d
|
@ -46,10 +46,12 @@ impl RetroForwards {
|
|||
self.map.insert(node_id, retro_forward);
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
Self {
|
||||
map: self.map.into_iter().chain(other.map.into_iter()).collect(),
|
||||
}
|
||||
pub(crate) fn extend(&mut self, other: Self) {
|
||||
self.map.extend(other.map);
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.map.len()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,10 +72,8 @@ impl NodeTree {
|
|||
self.map.insert(node_id, node_ref);
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
Self {
|
||||
map: self.map.into_iter().chain(other.map.into_iter()).collect(),
|
||||
}
|
||||
pub(crate) fn extend(&mut self, other: Self) {
|
||||
self.map.extend(other.map);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -144,13 +144,14 @@ impl Checkpointer {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn merge(self, other: Self) -> Self {
|
||||
// TODO mut with extend
|
||||
Self {
|
||||
backward_states: self.backward_states.merge(other.backward_states),
|
||||
node_tree: self.node_tree.merge(other.node_tree),
|
||||
retro_forwards: self.retro_forwards.merge(other.retro_forwards),
|
||||
}
|
||||
pub fn extend(&mut self, other: Self) {
|
||||
self.backward_states.extend(other.backward_states);
|
||||
self.node_tree.extend(other.node_tree);
|
||||
self.retro_forwards.extend(other.retro_forwards);
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.backward_states.len() + self.retro_forwards.len()
|
||||
}
|
||||
|
||||
pub fn register_retro_forward(
|
||||
|
|
|
@ -133,9 +133,11 @@ impl BackwardStates {
|
|||
);
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
Self {
|
||||
map: self.map.into_iter().chain(other.map.into_iter()).collect(),
|
||||
}
|
||||
pub(crate) fn extend(&mut self, other: Self) {
|
||||
self.map.extend(other.map);
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.map.len()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,9 +6,10 @@ use super::{traversal::BreadthFirstSearch, Graph, NodeRef, StepBoxed};
|
|||
|
||||
pub fn backward<B: Backend, const D: usize>(root: AutodiffTensor<B, D>) -> Gradients {
|
||||
let grads = Gradients::new::<B, D>(root.node.clone(), root.primitive);
|
||||
let checkpointer = root.graph.take_checkpointer();
|
||||
let tape = build_tape(root.node, root.graph);
|
||||
|
||||
execute_steps(tape, grads, root.graph.get_checkpointer())
|
||||
execute_steps(tape, grads, checkpointer)
|
||||
}
|
||||
|
||||
fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
|
||||
|
@ -32,12 +33,12 @@ fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
|
|||
fn execute_steps(
|
||||
tape: Vec<Vec<StepBoxed>>,
|
||||
mut grads: Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
mut checkpointer: Checkpointer,
|
||||
) -> Gradients {
|
||||
tape.into_iter().rev().for_each(|steps| {
|
||||
steps
|
||||
.into_iter()
|
||||
.for_each(|step| step.step(&mut grads, checkpointer))
|
||||
.for_each(|step| step.step(&mut grads, &mut checkpointer))
|
||||
});
|
||||
grads
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ impl Graph {
|
|||
/// keeping all the tensors alive for multiple backward call is a heavy waste of resources.
|
||||
pub fn steps(self) -> NodeSteps {
|
||||
let mut map_drain = HashMap::new();
|
||||
self.execute_mut(|map| {
|
||||
self.execute_mut_steps(|map| {
|
||||
std::mem::swap(&mut *map, &mut map_drain);
|
||||
});
|
||||
map_drain
|
||||
|
@ -53,7 +53,7 @@ impl Graph {
|
|||
|
||||
/// Register a new step into the graph.
|
||||
pub fn register(self, id: &NodeID, ops: StepBoxed) -> Self {
|
||||
self.execute_mut(|map| {
|
||||
self.execute_mut_steps(|map| {
|
||||
map.insert(id.clone(), ops);
|
||||
})
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ impl Graph {
|
|||
self.merge_different(other)
|
||||
}
|
||||
|
||||
fn execute_mut<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
|
||||
fn execute_mut_steps<F: FnOnce(&mut NodeSteps)>(mut self, func: F) -> Self {
|
||||
match Arc::get_mut(&mut self.steps) {
|
||||
Some(mutex) => {
|
||||
let map = mutex.get_mut();
|
||||
|
@ -84,12 +84,10 @@ impl Graph {
|
|||
}
|
||||
|
||||
fn merge_different(self, other: Self) -> Self {
|
||||
// TODO better follow the same pattern where largest extends itself with largest
|
||||
let checkpointer = self.checkpointer.lock().merge(*other.checkpointer.lock());
|
||||
let mut map2 = other.clone().steps();
|
||||
let mut checkpointer2 = other.checkpointer_own();
|
||||
|
||||
let mut map2 = other.steps();
|
||||
|
||||
self.execute_mut(|map1| {
|
||||
self.execute_mut_steps(|map1| {
|
||||
if map1.len() > map2.len() {
|
||||
map1.extend(map2);
|
||||
} else {
|
||||
|
@ -99,6 +97,16 @@ impl Graph {
|
|||
std::mem::swap(map1, &mut map2);
|
||||
}
|
||||
})
|
||||
.execute_mut_checkpointer(|checkpointer1| {
|
||||
if checkpointer1.len() > checkpointer2.len() {
|
||||
checkpointer1.extend(checkpointer2);
|
||||
} else {
|
||||
let mut checkpointer_drain = Checkpointer::default();
|
||||
std::mem::swap(checkpointer1, &mut checkpointer_drain);
|
||||
checkpointer2.extend(checkpointer_drain);
|
||||
std::mem::swap(checkpointer1, &mut checkpointer2);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn checkpoint_register<T: Clone + Send + Sync + 'static>(
|
||||
|
@ -118,7 +126,36 @@ impl Graph {
|
|||
.register_retro_forward(node_id, retro_forward)
|
||||
}
|
||||
|
||||
pub fn get_checkpointer(&self) -> &mut Checkpointer {
|
||||
&mut self.checkpointer.lock()
|
||||
/// # Notes
|
||||
///
|
||||
/// This is a owned method, so the current checkpointer will be freed.
|
||||
pub fn checkpointer_own(self) -> Checkpointer {
|
||||
let mut checkpointer_drain = Checkpointer::default();
|
||||
self.execute_mut_checkpointer(|checkpointer| {
|
||||
std::mem::swap(&mut *checkpointer, &mut checkpointer_drain);
|
||||
});
|
||||
checkpointer_drain
|
||||
}
|
||||
|
||||
fn execute_mut_checkpointer<F: FnOnce(&mut Checkpointer)>(mut self, func: F) -> Self {
|
||||
match Arc::get_mut(&mut self.checkpointer) {
|
||||
Some(mutex) => {
|
||||
let map = mutex.get_mut();
|
||||
func(map);
|
||||
}
|
||||
None => {
|
||||
// Only lock when there are multiple references to the graph.
|
||||
let mut checkpointer = self.checkpointer.lock();
|
||||
func(&mut checkpointer);
|
||||
}
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn take_checkpointer(&self) -> Checkpointer {
|
||||
let mut guard = self.checkpointer.lock();
|
||||
let owned: Checkpointer = std::mem::replace(&mut *guard, Checkpointer::default());
|
||||
owned
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ use crate::{
|
|||
tensor::AutodiffTensor,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Shape};
|
||||
use spin::Mutex;
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
|
||||
pub enum CheckpointStrategy {
|
||||
|
|
|
@ -333,10 +333,10 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
|
|||
let (lhs, rhs, broadcast) = match ops.state {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
let lhs: B::FloatTensorPrimitive<D> =
|
||||
checkpointer.retrieve_output(ops.parents[0].unwrap().id);
|
||||
let rhs: B::FloatTensorPrimitive<D> =
|
||||
checkpointer.retrieve_output(ops.parents[1].unwrap().id);
|
||||
let lhs: B::FloatTensorPrimitive<D> = checkpointer
|
||||
.retrieve_output(ops.parents[0].clone().unwrap().id.clone());
|
||||
let rhs: B::FloatTensorPrimitive<D> = checkpointer
|
||||
.retrieve_output(ops.parents[1].clone().unwrap().id.clone());
|
||||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
|
||||
// TODO None if not tracked
|
||||
(Some(lhs), Some(rhs), broadcast)
|
||||
|
@ -371,9 +371,13 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
|
|||
let rhs_tracked = rhs.is_tracked();
|
||||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);
|
||||
|
||||
// TODO had to add a lot of clone, for nodes then for their ids. same goes for parents in backward()
|
||||
match Div
|
||||
.prepare([lhs.node, rhs.node], [lhs.graph, rhs.graph])
|
||||
.recomputed(Box::new(RetroDiv::<B, D>::new(lhs.node.id, rhs.node.id))) // the box is painful, try to remove from api
|
||||
.prepare([lhs.node.clone(), rhs.node.clone()], [lhs.graph, rhs.graph])
|
||||
.recomputed(Box::new(RetroDiv::<B, D>::new(
|
||||
lhs.node.id.clone(),
|
||||
rhs.node.id.clone(),
|
||||
))) // the box is painful, try to remove from api
|
||||
.state_lazy()
|
||||
{
|
||||
// TODO if state is lazy then we just ignore the state below. should not compute it for nothing
|
||||
|
|
|
@ -38,7 +38,7 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
let node = Node::new(vec![], 0, id, Requirement::None);
|
||||
|
||||
let graph = Graph::new();
|
||||
graph.checkpoint_register(node.id, primitive, 1); // n_required arbitrary
|
||||
graph.checkpoint_register(node.id.clone(), primitive.clone(), 1); // n_required arbitrary
|
||||
|
||||
Self {
|
||||
primitive,
|
||||
|
@ -98,9 +98,11 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
);
|
||||
|
||||
match checkpoint_strategy {
|
||||
CheckpointStrategy::Computed => graph.checkpoint_register(node.id, output, 1),
|
||||
CheckpointStrategy::Computed => {
|
||||
graph.checkpoint_register(node.id.clone(), output.clone(), 1)
|
||||
}
|
||||
CheckpointStrategy::Recompute { retro_forward } => {
|
||||
graph.retro_register(node.id, retro_forward)
|
||||
graph.retro_register(node.id.clone(), retro_forward)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue