made it to clone errors

louisfd 2024-01-23 12:42:27 -05:00
parent 1c0798952e
commit 9f63f9d308
6 changed files with 102 additions and 41 deletions

View File

@@ -21,3 +21,6 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.0", opt
 derive-new = { workspace = true }
 spin = { workspace = true }
+
+# HOPEFULLY CAN BE REMOVED
+downcast-rs = "1.0.4"
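
The downcast-rs dependency added above is what lets the boxed trait objects introduced later in this commit be recovered as concrete types. A minimal standalone sketch of the pattern, with a hypothetical Concrete type that is not from the burn codebase:

use downcast_rs::{impl_downcast, Downcast};
use std::fmt::Debug;

// A state must opt into downcasting via the Downcast supertrait.
trait State: Debug + Downcast {}
impl_downcast!(State); // generates as_any(), downcast_ref(), etc. on dyn State

#[derive(Debug)]
struct Concrete(i32);
impl State for Concrete {}

fn main() {
    let boxed: Box<dyn State> = Box::new(Concrete(7));
    // Recover the concrete type behind the type-erased trait object.
    let concrete = boxed.downcast_ref::<Concrete>().expect("wrong concrete type");
    assert_eq!(concrete.0, 7);
}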

View File

@@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
 use crate::grads::Gradients;

 use super::{
-    checkpoint::{NodeStates, StateStruct},
+    checkpoint::{NodeStates, OperationBoxed, StateStruct},
     NodeID, NodeRef,
 };
@@ -103,10 +103,13 @@ impl Graph {
         &self,
         output: B::TensorPrimitive<D>,
         node: NodeID,
+        operation: OperationBoxed,
     ) {
-        self.states
-            .get_mut()
-            .register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
+        self.states.get_mut().register(
+            node,
+            Box::new(StateStruct::<B, D, 1>::new([output].into())),
+            operation,
+        )
     }

     pub(crate) fn node_states(&self) -> &NodeStates {
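
register_output now stores the checkpointed tensor together with the boxed operation under the same node id. A toy sketch of that pairing; Registry, S, and Op are illustrative stand-ins, not the crate's types:

use std::collections::HashMap;

type NodeID = u64; // stand-in for the crate's NodeID

// Two maps keyed by the same node: the saved state, and the operation
// that can recompute the state if it was never saved.
#[derive(Default)]
struct Registry<S, Op> {
    state_hashmap: HashMap<NodeID, S>,
    operation_hashmap: HashMap<NodeID, Op>,
}

impl<S, Op> Registry<S, Op> {
    fn register(&mut self, node_id: NodeID, state: S, operation: Op) {
        self.state_hashmap.insert(node_id, state);
        self.operation_hashmap.insert(node_id, operation);
    }
}

fn main() {
    let mut registry: Registry<Vec<f32>, &str> = Registry::default();
    registry.register(0, vec![1.0, 2.0], "div");
    assert!(registry.state_hashmap.contains_key(&0));
}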

View File

@@ -1,8 +1,9 @@
 use std::{collections::HashMap, fmt::Debug};

 use burn_tensor::backend::Backend;
+use downcast_rs::Downcast;

-use crate::ops::{Ops, OpsSpec};
+use crate::ops::{Operation, Ops, OpsSpec};

 use super::{NodeID, NodeRef};
@@ -27,11 +28,12 @@ pub enum Bottleneck {
 // //     Manual,
 // // }

-pub trait State: Send + Sync + Debug + 'static {}
+pub trait State: Send + Sync + Debug + 'static + Downcast {}

 #[derive(new, Debug, Clone)]
 pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
-    pub tensors: [B::TensorPrimitive<D>; N],
+    pub tensors: Vec<B::TensorPrimitive<D>>, // size N
+    // pub tensors: [B::TensorPrimitive<D>; N],
 }

 impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
@@ -41,16 +43,20 @@ impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N>
 pub struct StateNull {}
 impl State for StateNull {}

-#[derive(Default, Debug)]
-pub struct NodeStates {
-    hashmap: HashMap<NodeID, StateBoxed>,
-}
-
 pub type StateBoxed = Box<dyn State>;
+pub type OperationBoxed = Box<dyn Operation>;
+
+#[derive(Default, Debug)]
+pub struct NodeStates {
+    state_hashmap: HashMap<NodeID, StateBoxed>,
+    operation_hashmap: HashMap<NodeID, OperationBoxed>,
+}

 impl NodeStates {
-    pub fn register(mut self, node_id: NodeID, state: StateBoxed) {
-        self.hashmap.insert(node_id, state);
+    pub fn register(mut self, node_id: NodeID, state: StateBoxed, operation: OperationBoxed) {
+        self.state_hashmap.insert(node_id, state);
+        self.operation_hashmap.insert(node_id, operation);
     }

     pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
@@ -60,10 +66,16 @@ impl NodeStates {
         I: State,
         O: State,
     {
-        node.parents
+        let x: Vec<Box<dyn State>> = node
+            .parents
             .iter()
             .map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
-            .collect()
+            .collect();
+        *outputs_to_input::<B, D, N>(x)
+            .as_any()
+            .downcast_ref::<I>()
+            .expect("Downcast failed")
     }

     pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
@@ -76,35 +88,48 @@ impl NodeStates {
         I: State,
         O: State,
     {
-        match self.hashmap.remove(node_id) {
+        match self.state_hashmap.remove(node_id) {
             Some(state) => state,
             None => {
-                let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
-                let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
-                Box::new(ops.forward(inputs))
+                // let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
+                let operation: &OperationBoxed =
+                    self.get_ops_from_node_id::<B, OS, I, O, D, N>(node_id);
+                let inputs = self.get_input::<B, OS, I, O, D, N>(operation.node());
+                operation.forward(Box::new(inputs))
             }
         }
     }

-    // NodeStates must have access to a mapping from NodeRef/NodeID to Ops
-    // Otherwise how to get parents just with ids?
-    // And how to do the forward pass ?
+    // maybe inline
     fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
         &self,
         node_id: &NodeID,
-    ) -> Ops<B, OS, I, O, D, N>
+    ) -> &OperationBoxed
+    // ) -> Ops<B, OS, I, O, D, N>
     where
         OS: OpsSpec<B, D, N, Input = I, Output = O>,
         B: Backend,
         I: State,
         O: State,
     {
-        todo!()
+        self.operation_hashmap.get(node_id).unwrap()
     }
 }

-// STILL TO DO
-
-// - Collect several Os into an I
-// - node_id -> map of node_id -> Ops
-// when registering, pass a pointer to the ops too
+fn outputs_to_input<B: Backend, const D: usize, const N: usize>(
+    outputs: Vec<StateBoxed>,
+) -> StateBoxed {
+    let x: Vec<StateStruct<B, D, N>> = outputs
+        .iter()
+        .map(|out| {
+            *out.as_any()
+                .downcast_ref::<StateStruct<B, D, N>>()
+                .expect("Downcast failed")
+        })
+        .collect();
+    let y: Vec<B::TensorPrimitive<D>> = x
+        .iter()
+        .map(|state_struct| state_struct.tensors[0])
+        .collect();
+    Box::new(StateStruct::<B, D, N>::new(y))
+}
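
get_output above implements a retrieve-or-recompute rule: take the checkpointed state if one exists, otherwise rebuild it from the parents' outputs through the stored operation. A self-contained sketch of that control flow with simplified stand-ins for NodeID, State, and Op:

use std::collections::HashMap;

type NodeID = u64;
type State = i32;

// Each operation records its parents and how to recompute its output
// from their outputs.
struct Op {
    parents: Vec<NodeID>,
    forward: fn(Vec<State>) -> State,
}

fn get_output(
    states: &mut HashMap<NodeID, State>,
    ops: &HashMap<NodeID, Op>,
    node: NodeID,
) -> State {
    // Checkpointed: take the saved state (removed, mirroring the `remove` above).
    if let Some(state) = states.remove(&node) {
        return state;
    }
    // Not checkpointed: recompute recursively from the parents.
    let op = ops.get(&node).expect("unknown node");
    let inputs: Vec<State> = op
        .parents
        .iter()
        .map(|p| get_output(states, ops, *p))
        .collect();
    (op.forward)(inputs)
}

fn main() {
    let mut states = HashMap::from([(0, 3), (1, 4)]);
    let ops = HashMap::from([(2, Op { parents: vec![0, 1], forward: |v| v[0] * v[1] })]);
    assert_eq!(get_output(&mut states, &ops, 2), 12); // recomputed from parents
}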

View File

@@ -1,12 +1,13 @@
 use crate::{
     grads::Gradients,
     graph::{
-        checkpoint::{NodeStates, State, StateNull},
+        checkpoint::{NodeStates, State, StateBoxed, StateNull},
         NodeRef, Requirement, {Graph, Step},
     },
     tensor::AutodiffTensor,
 };
 use burn_tensor::{backend::Backend, ops, Shape};
+use std::any::Any;
 use std::{marker::PhantomData, process::Output};

 use super::OpsSpec;
@@ -121,7 +122,7 @@ where
         let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
         match ops_spec.bottleneck() {
             ComputeBound => {
-                autodiff_tensor.register_output(output);
+                autodiff_tensor.register_output(output, Box::new(ops));
             }
             MemoryBound => {}
         }
@@ -171,8 +172,34 @@ where
         states.get_input::<B, OS, I, O, D, N>(self.node)
     }

-    pub(crate) fn forward(&self, inputs: I) -> O {
-        self.ops_spec.forward(inputs)
-    }
+    // pub(crate) fn forward(&self, inputs: I) -> O {
+    //     self.ops_spec.forward(inputs)
+    // }
 }
+
+pub trait Operation: Send + Sync + std::fmt::Debug + 'static {
+    fn node(&self) -> NodeRef;
+    fn forward(&self, input: StateBoxed) -> StateBoxed;
+}
+
+impl<B, OS, I, O, const D: usize, const N: usize> Operation for Ops<B, OS, I, O, D, N>
+where
+    B: Backend,
+    OS: OpsSpec<B, D, N, Input = I, Output = O>,
+    I: State,
+    O: State,
+{
+    fn node(&self) -> NodeRef {
+        self.node
+    }
+
+    fn forward(&self, input: StateBoxed) -> StateBoxed {
+        // ouch
+        let x = *input
+            .as_any()
+            .downcast_ref::<I>()
+            .expect("Downcast did not work");
+        Box::new(self.ops_spec.forward(x))
+    }
+}
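
The Operation trait introduced above is the standard object-safety workaround: the generic forward is wrapped so that inputs and outputs cross the trait-object boundary type-erased, and are downcast back on the other side. A minimal sketch of the same idea using std::any::Any instead of the crate's State machinery; AddOp is hypothetical:

use std::any::Any;

// Object-safe trait: the concrete input/output types are erased.
trait Operation {
    fn forward(&self, input: Box<dyn Any>) -> Box<dyn Any>;
}

struct AddOp;

impl Operation for AddOp {
    fn forward(&self, input: Box<dyn Any>) -> Box<dyn Any> {
        // Downcast back to the concrete input type this op expects.
        let (lhs, rhs) = *input.downcast::<(f32, f32)>().expect("type mismatch");
        Box::new(lhs + rhs)
    }
}

fn main() {
    // Heterogeneous ops can now live behind one pointer type,
    // e.g. in a HashMap<NodeID, Box<dyn Operation>>.
    let op: Box<dyn Operation> = Box::new(AddOp);
    let out = op.forward(Box::new((1.0f32, 2.0f32)));
    assert_eq!(*out.downcast::<f32>().unwrap(), 3.0);
}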

View File

@@ -6,7 +6,7 @@ use crate::{
         checkpoint::{Bottleneck, NodeStates, StateStruct},
         Graph, NodeRef, Requirement, Step,
     },
-    ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
+    ops::{binary, broadcast_shape, tensor, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
     tensor::AutodiffTensor,
     utils::duplicate,
     Autodiff,
@@ -305,7 +305,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
             states: &NodeStates,
         ) {
             // let (lhs, rhs, broadcast) = ops.state;
-            let [lhs, rhs] = ops.fetch_inputs(states).tensors;
+            let tensors = ops.fetch_inputs(states).tensors;
+            let lhs = tensors[0];
+            let rhs = tensors[1];
             let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);

             let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
@@ -331,9 +333,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
         }

         fn forward(&self, input: Self::Input) -> Self::Output {
-            let [lhs, rhs] = input.tensors;
+            let (lhs, rhs) = (input.tensors[0], input.tensors[1]);
             let result = B::div(lhs, rhs);
-            StateStruct::new([result])
+            StateStruct::new([result].into())
         }

         fn bottleneck(&self) -> Bottleneck {
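
The switch from [T; N] arrays to Vec<T> in this file is what produces the clone errors the commit title refers to: array patterns like let [lhs, rhs] = ... move the elements out, but indexing a Vec cannot move out a non-Copy element (error E0507). A sketch of the failure and the usual escapes; Tensor is a hypothetical non-Copy stand-in for B::TensorPrimitive<D>:

#[derive(Debug, Clone)]
struct Tensor(Vec<f32>); // non-Copy stand-in

fn main() {
    let tensors = vec![Tensor(vec![1.0]), Tensor(vec![2.0])];

    // let lhs = tensors[0]; // error[E0507]: cannot move out of index of Vec

    // Escape 1: clone the element.
    let lhs = tensors[0].clone();

    // Escape 2: consume the Vec and take ownership element by element.
    let mut iter = tensors.into_iter();
    let first = iter.next().unwrap();
    let second = iter.next().unwrap();

    println!("{:?} {:?} {:?}", lhs, first, second);
}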

View File

@@ -3,7 +3,7 @@ use burn_tensor::backend::Backend;
 use crate::{
     grads::Gradients,
     graph::{
-        checkpoint::NodeStates,
+        checkpoint::{NodeStates, OperationBoxed},
         Node, NodeID, NodeRef, Requirement, {Graph, Step},
     },
 };
@@ -105,7 +105,8 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
         self
     }

-    pub fn register_output(&self, output: B::TensorPrimitive<D>) {
-        self.graph.register_output::<B, D>(output, self.node.id);
+    pub fn register_output(&self, output: B::TensorPrimitive<D>, operation: OperationBoxed) {
+        self.graph
+            .register_output::<B, D>(output, self.node.id, operation);
     }
 }