mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
3552f17cb2
commit
1c0798952e
|
@ -106,7 +106,7 @@ impl Graph {
|
|||
) {
|
||||
self.states
|
||||
.get_mut()
|
||||
.register(node, Box::new(StateStruct::<B, D>::new(output)))
|
||||
.register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
|
||||
}
|
||||
|
||||
pub(crate) fn node_states(&self) -> &NodeStates {
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Debug};
|
|||
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
use crate::ops::Ops;
|
||||
use crate::ops::{Ops, OpsSpec};
|
||||
|
||||
use super::{NodeID, NodeRef};
|
||||
|
||||
|
@ -27,14 +27,19 @@ pub enum Bottleneck {
|
|||
// // Manual,
|
||||
// // }
|
||||
|
||||
pub trait State: Sync + Send + Debug {}
|
||||
pub trait State: Send + Sync + Debug + 'static {}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
pub struct StateStruct<B: Backend, const D: usize> {
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
|
||||
pub tensors: [B::TensorPrimitive<D>; N],
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> State for StateStruct<B, D> {}
|
||||
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
|
||||
|
||||
// Not sure necessary, delete if possible
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct StateNull {}
|
||||
impl State for StateNull {}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct NodeStates {
|
||||
|
@ -48,21 +53,34 @@ impl NodeStates {
|
|||
self.hashmap.insert(node_id, state);
|
||||
}
|
||||
|
||||
pub fn get_input(&self, node: NodeRef) -> Vec<StateBoxed> {
|
||||
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
node.parents
|
||||
.iter()
|
||||
.map(|parent| self.get_output(parent))
|
||||
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_output(&self, node_id: &NodeID) -> StateBoxed {
|
||||
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
|
||||
&self,
|
||||
node_id: &NodeID,
|
||||
) -> StateBoxed
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
match self.hashmap.remove(node_id) {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
// TODO not <_, _, 1>
|
||||
let ops: Ops<_, _, 1> = self.get_ops_from_node_id(node_id);
|
||||
let inputs = self.get_input(ops.node);
|
||||
//node forward does not exist
|
||||
let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
|
||||
let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
|
||||
Box::new(ops.forward(inputs))
|
||||
}
|
||||
}
|
||||
|
@ -71,7 +89,22 @@ impl NodeStates {
|
|||
// NodeStates must have access to a mapping from NodeRef/NodeID to Ops
|
||||
// Otherwise how to get parents just with ids?
|
||||
// And how to do the forward pass ?
|
||||
fn get_ops_from_node_id<I, O, const N: usize>(&self, node_id: &NodeID) -> Ops<I, O, N> {
|
||||
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
|
||||
&self,
|
||||
node_id: &NodeID,
|
||||
) -> Ops<B, OS, I, O, D, N>
|
||||
where
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
B: Backend,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
// STILL TO DO
|
||||
|
||||
// - Collect several Os into an I
|
||||
// - node_id -> map of node_id -> Ops
|
||||
// when registering, pass a pointer to the ops too
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
grads::Gradients,
|
||||
graph::{
|
||||
checkpoint::NodeStates,
|
||||
checkpoint::{NodeStates, State, StateNull},
|
||||
NodeRef, Requirement, {Graph, Step},
|
||||
},
|
||||
tensor::AutodiffTensor,
|
||||
|
@ -34,10 +34,12 @@ pub struct Tracked;
|
|||
/// Untracked operation tag.
|
||||
pub struct UnTracked;
|
||||
|
||||
impl<OS, B, const D: usize, const N: usize> OpsPrep<OS, B, (), (), D, N, Init>
|
||||
impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Init>
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = (), Output = ()>,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
/// Prepare a stateless operation.
|
||||
pub fn stateless(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
|
||||
|
@ -52,8 +54,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Ini
|
|||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
O: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
/// Prepare an operation that requires a state during the backward pass.
|
||||
pub fn stateful(self) -> OpsKind<OS, B, I, O, D, N> {
|
||||
|
@ -78,8 +80,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, UnT
|
|||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
O: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
/// Finish the preparation of an untracked operation and returns the output tensor.
|
||||
pub fn finish(self, output: <B as Backend>::TensorPrimitive<D>) -> AutodiffTensor<B, D> {
|
||||
|
@ -96,8 +98,8 @@ impl<OS, B, const D: usize, const N: usize, I, O> OpsPrep<OS, B, I, O, D, N, Tra
|
|||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
O: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
/// Finish the preparation of a tracked operation and returns the output tensor.
|
||||
pub fn finish(
|
||||
|
@ -112,24 +114,21 @@ where
|
|||
self.graphs.into_iter(),
|
||||
self.requirement,
|
||||
);
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, autodiff_tensor.node.clone());
|
||||
// probably should have ops_spec in order to recompute stuff
|
||||
// let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
|
||||
|
||||
// Access autodiff_tensor.graph, ask it to register
|
||||
// <nodeid, output>, nodeid inside tensor
|
||||
match ops_spec {
|
||||
Some(ops_spec) => match ops_spec.bottleneck() {
|
||||
ComputeBound => {
|
||||
autodiff_tensor.register_output(output);
|
||||
Some(ops_spec) => {
|
||||
let parents = self.nodes.map(|node| node.clone_if_require_grad());
|
||||
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
|
||||
match ops_spec.bottleneck() {
|
||||
ComputeBound => {
|
||||
autodiff_tensor.register_output(output);
|
||||
}
|
||||
MemoryBound => {}
|
||||
}
|
||||
MemoryBound => {}
|
||||
},
|
||||
None => {}
|
||||
autodiff_tensor.register_step(OpsStep::new(ops, ops_spec))
|
||||
}
|
||||
None => autodiff_tensor,
|
||||
}
|
||||
|
||||
autodiff_tensor.register_step(OpsStep::new(ops, self.ops_spec))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,36 +142,51 @@ pub enum OpsKind<OS, B, I, O, const D: usize, const N: usize> {
|
|||
|
||||
/// Operation containing its parent nodes, its own node and the backward step state.
|
||||
#[derive(new, Debug)]
|
||||
pub struct Ops<I, O, const N: usize> {
|
||||
// pub struct Ops<OS, I, O, const N: usize> {
|
||||
pub struct Ops<B, OS, I, O, const D: usize, const N: usize>
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
/// Parents nodes.
|
||||
pub parents: [Option<NodeRef>; N],
|
||||
/// The node.
|
||||
pub node: NodeRef,
|
||||
pub ops_spec: OS,
|
||||
pub _backend: PhantomData<B>,
|
||||
pub _input: PhantomData<I>,
|
||||
pub _output: PhantomData<O>,
|
||||
}
|
||||
|
||||
impl<I, O, const N: usize> Ops<I, O, N> {
|
||||
impl<B, OS, I, O, const D: usize, const N: usize> Ops<B, OS, I, O, D, N>
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
pub fn fetch_inputs(&self, states: &NodeStates) -> I {
|
||||
states.get_input(self.node)
|
||||
states.get_input::<B, OS, I, O, D, N>(self.node)
|
||||
}
|
||||
|
||||
pub(crate) fn forward(&self, inputs: Vec<Box<dyn State>>) -> dyn State {
|
||||
pub(crate) fn forward(&self, inputs: I) -> O {
|
||||
self.ops_spec.forward(inputs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Operation implementing backward [step](Step) with type erasing.
|
||||
#[derive(new, Debug)]
|
||||
struct OpsStep<B, T, I, O, const D: usize, const N: usize>
|
||||
struct OpsStep<B, OS, I, O, const D: usize, const N: usize>
|
||||
where
|
||||
B: Backend,
|
||||
T: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
O: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
ops: Ops<I, O, N>,
|
||||
backward: T,
|
||||
ops: Ops<B, OS, I, O, D, N>,
|
||||
backward: OS,
|
||||
phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
|
@ -180,8 +194,8 @@ impl<B, T, I, O, const D: usize, const N: usize> Step for OpsStep<B, T, I, O, D,
|
|||
where
|
||||
B: Backend,
|
||||
T: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
O: Clone + Send + Sync + std::fmt::Debug + 'static,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
fn step(self: Box<Self>, grads: &mut Gradients, states: &NodeStates) {
|
||||
self.backward.backward(self.ops, grads, states);
|
||||
|
|
|
@ -2,7 +2,7 @@ use super::{Ops, OpsPrep};
|
|||
use crate::{
|
||||
grads::Gradients,
|
||||
graph::{
|
||||
checkpoint::{Bottleneck, NodeStates},
|
||||
checkpoint::{Bottleneck, NodeStates, State},
|
||||
Graph, NodeRef, Requirement,
|
||||
},
|
||||
utils::duplicate,
|
||||
|
@ -22,15 +22,15 @@ where
|
|||
B: Backend,
|
||||
{
|
||||
/// Associated type to compute the backward pass.
|
||||
type Input: Clone + Send + Sync + std::fmt::Debug + 'static;
|
||||
type Output: Clone + Send + Sync + std::fmt::Debug + 'static;
|
||||
type Input: State;
|
||||
type Output: State;
|
||||
|
||||
fn forward(&self, input: Self::Input) -> Self::Output;
|
||||
|
||||
/// The backward pass.
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::Input, Self::Output, N>,
|
||||
ops: Ops<B, Self, Self::Input, Self::Output, D, N>,
|
||||
grads: &mut Gradients,
|
||||
states: &NodeStates,
|
||||
);
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::{marker::PhantomData, process::Output};
|
|||
use crate::{
|
||||
grads::Gradients,
|
||||
graph::{
|
||||
checkpoint::{Bottleneck, NodeStates},
|
||||
checkpoint::{Bottleneck, NodeStates, StateStruct},
|
||||
Graph, NodeRef, Requirement, Step,
|
||||
},
|
||||
ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
|
||||
|
@ -295,17 +295,17 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
|||
struct Div;
|
||||
|
||||
impl<B: Backend, const D: usize> OpsSpec<B, D, 2> for Div {
|
||||
type Input = (B::TensorPrimitive<D>, B::TensorPrimitive<D>);
|
||||
type Output = B::TensorPrimitive<D>;
|
||||
type Input = StateStruct<B, D, 2>;
|
||||
type Output = StateStruct<B, D, 1>;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::Input, Self::Output, 2>,
|
||||
ops: Ops<B, Self, Self::Input, Self::Output, D, 2>,
|
||||
grads: &mut Gradients,
|
||||
states: &NodeStates,
|
||||
) {
|
||||
// let (lhs, rhs, broadcast) = ops.state;
|
||||
let (lhs, rhs) = ops.fetch_inputs(states);
|
||||
let [lhs, rhs] = ops.fetch_inputs(states).tensors;
|
||||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
|
||||
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
|
||||
|
||||
|
@ -331,8 +331,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
|||
}
|
||||
|
||||
fn forward(&self, input: Self::Input) -> Self::Output {
|
||||
let (lhs, rhs) = input;
|
||||
B::div(lhs, rhs)
|
||||
let [lhs, rhs] = input.tensors;
|
||||
let result = B::div(lhs, rhs);
|
||||
StateStruct::new([result])
|
||||
}
|
||||
|
||||
fn bottleneck(&self) -> Bottleneck {
|
||||
|
|
Loading…
Reference in New Issue