louisfd 2024-01-23 11:00:01 -05:00
parent 3552f17cb2
commit 1c0798952e
5 changed files with 109 additions and 61 deletions

View File

@ -106,7 +106,7 @@ impl Graph {
) {
self.states
.get_mut()
.register(node, Box::new(StateStruct::<B, D>::new(output)))
.register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
}
pub(crate) fn node_states(&self) -> &NodeStates {
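The hunk above now wraps the single output tensor in a one-element array, so one const-generic StateStruct type can hold one or several checkpointed tensors. A standalone toy (not burn code) of the same pattern:

fn checkpoint<const N: usize>(values: [f32; N]) -> usize {
    // A single value and a pair go through the same const-generic path.
    values.len()
}

fn main() {
    assert_eq!(checkpoint([1.0]), 1); // single output, as in the diff's N = 1
    assert_eq!(checkpoint([1.0, 2.0]), 2); // two tensors would use N = 2
}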

View File

@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Debug};
use burn_tensor::backend::Backend;
use crate::ops::Ops;
use crate::ops::{Ops, OpsSpec};
use super::{NodeID, NodeRef};
@ -27,14 +27,19 @@ pub enum Bottleneck {
// // Manual,
// // }
pub trait State: Sync + Send + Debug {}
pub trait State: Send + Sync + Debug + 'static {}
#[derive(new, Debug)]
pub struct StateStruct<B: Backend, const D: usize> {
tensor: B::TensorPrimitive<D>,
#[derive(new, Debug, Clone)]
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
pub tensors: [B::TensorPrimitive<D>; N],
}
impl<B: Backend, const D: usize> State for StateStruct<B, D> {}
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
// Not sure if this is necessary; delete if possible
#[derive(new, Debug, Clone)]
pub struct StateNull {}
impl State for StateNull {}
#[derive(Default, Debug)]
pub struct NodeStates {
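The widened bound (Send + Sync + Debug + 'static) is what lets states live inside the Box<dyn State> values stored in NodeStates. A standalone toy (not burn code) mirroring the pattern:

use std::fmt::Debug;

pub trait State: Send + Sync + Debug + 'static {}

#[derive(Debug, Clone)]
struct TensorsState<const N: usize> {
    // Stand-in for B::TensorPrimitive<D>.
    tensors: [f32; N],
}

impl<const N: usize> State for TensorsState<N> {}

fn main() {
    // 'static is required for the value to be owned by a trait object.
    let boxed: Box<dyn State> = Box::new(TensorsState { tensors: [1.0, 2.0] });
    println!("{boxed:?}");
}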
@ -48,21 +53,34 @@ impl NodeStates {
self.hashmap.insert(node_id, state);
}
pub fn get_input(&self, node: NodeRef) -> Vec<StateBoxed> {
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
node.parents
.iter()
.map(|parent| self.get_output(parent))
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
.collect()
}
pub fn get_output(&self, node_id: &NodeID) -> StateBoxed {
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
&self,
node_id: &NodeID,
) -> StateBoxed
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
match self.hashmap.remove(node_id) {
Some(state) => state,
None => {
// TODO not <_, _, 1>
let ops: Ops<_, _, 1> = self.get_ops_from_node_id(node_id);
let inputs = self.get_input(ops.node);
// node forward does not exist
let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
Box::new(ops.forward(inputs))
}
}
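The match above is the heart of the checkpointing scheme: a cached state is returned directly, and a missing one is recomputed by recursively resolving the parents' states and re-running the op's forward. A standalone toy (not burn code; the real code removes the cached entry rather than reading it) of that recursion:

use std::collections::HashMap;

struct Node {
    parents: Vec<u64>,
    forward: fn(&[f32]) -> f32,
}

fn get_output(id: u64, cache: &mut HashMap<u64, f32>, graph: &HashMap<u64, Node>) -> f32 {
    // Cached state: return it as-is.
    if let Some(state) = cache.get(&id) {
        return *state;
    }
    // Missing state: resolve the parents recursively, then re-run forward.
    let node = &graph[&id];
    let inputs: Vec<f32> = node
        .parents
        .iter()
        .map(|parent| get_output(*parent, cache, graph))
        .collect();
    (node.forward)(&inputs)
}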
@ -71,7 +89,22 @@ impl NodeStates {
// NodeStates must have access to a mapping from NodeRef/NodeID to Ops.
// Otherwise, how would we get the parents from ids alone?
// And how would we run the forward pass?
fn get_ops_from_node_id<I, O, const N: usize>(&self, node_id: &NodeID) -> Ops<I, O, N> {
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
&self,
node_id: &NodeID,
) -> Ops<B, OS, I, O, D, N>
where
OS: OpsSpec<B, D, N, Input = I, Output = O>,
B: Backend,
I: State,
O: State,
{
todo!()
}
}
// STILL TO DO
// - Collect several Os into an I
// - Keep a map of node_id -> Ops (a sketch of such a registry follows below);
//   when registering, pass a pointer to the ops too
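One possible shape for that registry, as a standalone toy (not burn code; OpsMap and its methods are hypothetical): ops are type-erased at registration time and downcast on retrieval, which is what get_ops_from_node_id would need.

use std::any::Any;
use std::collections::HashMap;

#[derive(Default)]
struct OpsMap {
    // Type-erased ops, keyed by node id (a stand-in for NodeID).
    map: HashMap<u64, Box<dyn Any + Send + Sync>>,
}

impl OpsMap {
    // Called at registration time, alongside the state itself.
    fn register(&mut self, node_id: u64, ops: Box<dyn Any + Send + Sync>) {
        self.map.insert(node_id, ops);
    }

    // Recovers the concrete Ops type from the erased box.
    fn get<T: 'static>(&self, node_id: &u64) -> Option<&T> {
        self.map.get(node_id).and_then(|ops| ops.downcast_ref::<T>())
    }
}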

View File

@ -1,7 +1,7 @@
use crate::{
grads::Gradients,
graph::{
checkpoint::NodeStates,
checkpoint::{NodeStates, State, StateNull},
NodeRef, Requirement, {Graph, Step},
},
tensor::AutodiffTensor,
@ -34,10 +34,12 @@ pub struct Tracked;
/// Untracked operation tag.
pub struct UnTracked;
impl<OS, B, const D: usize, const N: usize> OpsPrep<OS, B, (), (), D, N, Init>
impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Init>
where
B: Backend,
OS: OpsSpec<B, D, N, Input = (), Output = ()>,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
/// Prepare a stateless operation.
pub fn stateless(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
@ -52,8 +54,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Ini
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: Clone + Send + Sync + std::fmt::Debug + 'static,
O: Clone + Send + Sync + std::fmt::Debug + 'static,
I: State,
O: State,
{
/// Prepare an operation that requires a state during the backward pass.
pub fn stateful(self) -> OpsKind<OS, B, I, O, D, N> {
@ -78,8 +80,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, UnT
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: Clone + Send + Sync + std::fmt::Debug + 'static,
O: Clone + Send + Sync + std::fmt::Debug + 'static,
I: State,
O: State,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
pub fn finish(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
@ -96,8 +98,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Tra
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: Clone + Send + Sync + std::fmt::Debug + 'static,
O: Clone + Send + Sync + std::fmt::Debug + 'static,
I: State,
O: State,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
pub fn finish(
@ -112,24 +114,21 @@ where
self.graphs.into_iter(),
self.requirement,
);
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, autodiff_tensor.node.clone());
// probably should have ops_spec in order to recompute stuff
// let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
// Access autodiff_tensor.graph, ask it to register
// <nodeid, output>, nodeid inside tensor
match ops_spec {
Some(ops_spec) => match ops_spec.bottleneck() {
ComputeBound => {
autodiff_tensor.register_output(output);
Some(ops_spec) => {
let parents = self.nodes.map(|node| node.clone_if_require_grad());
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
match ops_spec.bottleneck() {
ComputeBound => {
autodiff_tensor.register_output(output);
}
MemoryBound => {}
}
MemoryBound => {}
},
None => {}
autodiff_tensor.register_step(OpsStep::new(ops, ops_spec))
}
None => autodiff_tensor,
}
autodiff_tensor.register_step(OpsStep::new(ops, self.ops_spec))
}
}
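The restructured finish() above encodes the checkpointing policy: a compute-bound op registers its output eagerly, while a memory-bound op stores nothing and is recomputed on demand during the backward pass. A standalone toy (not burn code) restating that decision:

enum Bottleneck {
    ComputeBound,
    MemoryBound,
}

// Whether the op's output should be checkpointed eagerly.
fn stores_output(bottleneck: &Bottleneck) -> bool {
    matches!(bottleneck, Bottleneck::ComputeBound)
}

fn main() {
    // Expensive to re-run: keep the output around.
    assert!(stores_output(&Bottleneck::ComputeBound));
    // Cheap to re-run: recompute it later instead of storing it.
    assert!(!stores_output(&Bottleneck::MemoryBound));
}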
@ -143,36 +142,51 @@ pub enum OpsKind<OS, B, I, O, const D: usize, const N: usize> {
/// Operation containing its parent nodes, its own node and the backward step state.
#[derive(new, Debug)]
pub struct Ops<I, O, const N: usize> {
// pub struct Ops<OS, I, O, const N: usize> {
pub struct Ops<B, OS, I, O, const D: usize, const N: usize>
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
/// Parent nodes.
pub parents: [Option<NodeRef>; N],
/// The node.
pub node: NodeRef,
pub ops_spec: OS,
pub _backend: PhantomData<B>,
pub _input: PhantomData<I>,
pub _output: PhantomData<O>,
}
impl<I, O, const N: usize> Ops<I, O, N> {
impl<B, OS, I, O, const D: usize, const N: usize> Ops<B, OS, I, O, D, N>
where
B: Backend,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
pub fn fetch_inputs(&self, states: &NodeStates) -> I {
states.get_input(self.node)
states.get_input::<B, OS, I, O, D, N>(self.node)
}
pub(crate) fn forward(&self, inputs: Vec<Box<dyn State>>) -> dyn State {
pub(crate) fn forward(&self, inputs: I) -> O {
self.ops_spec.forward(inputs)
}
}
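Ops now carries PhantomData markers for B, I, and O; these let the struct name type parameters that appear only in its bounds without storing values of those types. A standalone toy (not burn code) of the same idiom:

use std::marker::PhantomData;

// The type parameter tags the value but occupies no space at runtime.
struct Tagged<Unit> {
    value: u32,
    _unit: PhantomData<Unit>,
}

struct Celsius;

fn main() {
    let temp: Tagged<Celsius> = Tagged {
        value: 21,
        _unit: PhantomData,
    };
    println!("{} degrees", temp.value);
}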
/// Operation implementing backward [step](Step) with type erasing.
#[derive(new, Debug)]
struct OpsStep<B, T, I, O, const D: usize, const N: usize>
struct OpsStep<B, OS, I, O, const D: usize, const N: usize>
where
B: Backend,
T: OpsSpec<B, D, N, Input = I, Output = O>,
I: Clone + Send + Sync + std::fmt::Debug + 'static,
O: Clone + Send + Sync + std::fmt::Debug + 'static,
OS: OpsSpec<B, D, N, Input = I, Output = O>,
I: State,
O: State,
{
ops: Ops<I, O, N>,
backward: T,
ops: Ops<B, OS, I, O, D, N>,
backward: OS,
phantom: PhantomData<B>,
}
@ -180,8 +194,8 @@ impl<B, T, I, O, const D: usize, const N: usize> Step for OpsStep<B, T, I, O, D,
where
B: Backend,
T: OpsSpec<B, D, N, Input = I, Output = O>,
I: Clone + Send + Sync + std::fmt::Debug + 'static,
O: Clone + Send + Sync + std::fmt::Debug + 'static,
I: State,
O: State,
{
fn step(self: Box<Self>, grads: &mut Gradients, states: &NodeStates) {
self.backward.backward(self.ops, grads, states);

View File

@ -2,7 +2,7 @@ use super::{Ops, OpsPrep};
use crate::{
grads::Gradients,
graph::{
checkpoint::{Bottleneck, NodeStates},
checkpoint::{Bottleneck, NodeStates, State},
Graph, NodeRef, Requirement,
},
utils::duplicate,
@ -22,15 +22,15 @@ where
B: Backend,
{
/// Associated type to compute the backward pass.
type Input: Clone + Send + Sync + std::fmt::Debug + 'static;
type Output: Clone + Send + Sync + std::fmt::Debug + 'static;
type Input: State;
type Output: State;
fn forward(&self, input: Self::Input) -> Self::Output;
/// The backward pass.
fn backward(
self,
ops: Ops<Self::Input, Self::Output, N>,
ops: Ops<B, Self, Self::Input, Self::Output, D, N>,
grads: &mut Gradients,
states: &NodeStates,
);
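For contrast with the binary Div impl shown below, a minimal sketch of a unary OpsSpec implementation, assuming the trait shape in this diff and the imports already present in the file; Exp is hypothetical, B::exp is assumed to be the backend's elementwise exponential, and backward is left as a stub:

struct Exp;

impl<B: Backend, const D: usize> OpsSpec<B, D, 1> for Exp {
    type Input = StateStruct<B, D, 1>;
    type Output = StateStruct<B, D, 1>;

    fn forward(&self, input: Self::Input) -> Self::Output {
        let [tensor] = input.tensors;
        // Wrap the result back into a one-element state, as Div does.
        StateStruct::new([B::exp(tensor)])
    }

    fn backward(
        self,
        _ops: Ops<B, Self, Self::Input, Self::Output, D, 1>,
        _grads: &mut Gradients,
        _states: &NodeStates,
    ) {
        // d(exp(x))/dx = exp(x): a real implementation would fetch the
        // (re)computed output via the states and scale the incoming gradient.
        todo!()
    }

    fn bottleneck(&self) -> Bottleneck {
        // exp is cheap to recompute, so mark it memory-bound (assumption).
        Bottleneck::MemoryBound
    }
}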

View File

@ -3,7 +3,7 @@ use std::{marker::PhantomData, process::Output};
use crate::{
grads::Gradients,
graph::{
checkpoint::{Bottleneck, NodeStates},
checkpoint::{Bottleneck, NodeStates, StateStruct},
Graph, NodeRef, Requirement, Step,
},
ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
@ -295,17 +295,17 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
struct Div;
impl<B: Backend, const D: usize> OpsSpec<B, D, 2> for Div {
type Input = (B::TensorPrimitive<D>, B::TensorPrimitive<D>);
type Output = B::TensorPrimitive<D>;
type Input = StateStruct<B, D, 2>;
type Output = StateStruct<B, D, 1>;
fn backward(
self,
ops: Ops<Self::Input, Self::Output, 2>,
ops: Ops<B, Self, Self::Input, Self::Output, D, 2>,
grads: &mut Gradients,
states: &NodeStates,
) {
// let (lhs, rhs, broadcast) = ops.state;
let (lhs, rhs) = ops.fetch_inputs(states);
let [lhs, rhs] = ops.fetch_inputs(states).tensors;
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
@ -331,8 +331,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
}
fn forward(&self, input: Self::Input) -> Self::Output {
let (lhs, rhs) = input;
B::div(lhs, rhs)
let [lhs, rhs] = input.tensors;
let result = B::div(lhs, rhs);
StateStruct::new([result])
}
fn bottleneck(&self) -> Bottleneck {