made it to clone errors

This commit is contained in:
louisfd 2024-01-23 12:42:27 -05:00
parent 1c0798952e
commit 9f63f9d308
6 changed files with 102 additions and 41 deletions

View File

@ -21,3 +21,6 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.0", opt
derive-new = { workspace = true }
spin = { workspace = true }
# TODO: temporary dependency for downcasting type-erased checkpoint states; remove once no longer needed
downcast-rs = "1.0.4"

View File

@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
use crate::grads::Gradients;
use super::{
checkpoint::{NodeStates, StateStruct},
checkpoint::{NodeStates, OperationBoxed, StateStruct},
NodeID, NodeRef,
};
@ -103,10 +103,13 @@ impl Graph {
&self,
output: B::TensorPrimitive<D>,
node: NodeID,
operation: OperationBoxed,
) {
self.states
.get_mut()
.register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
self.states.get_mut().register(
node,
Box::new(StateStruct::<B, D, 1>::new([output].into())),
operation,
)
}
pub(crate) fn node_states(&self) -> &NodeStates {

View File

@ -1,8 +1,9 @@
use std::{collections::HashMap, fmt::Debug};
use burn_tensor::backend::Backend;
use downcast_rs::Downcast;
use crate::ops::{Ops, OpsSpec};
use crate::ops::{Operation, Ops, OpsSpec};
use super::{NodeID, NodeRef};
@ -27,11 +28,12 @@ pub enum Bottleneck {
// // Manual,
// // }
pub trait State: Send + Sync + Debug + 'static {}
pub trait State: Send + Sync + Debug + 'static + Downcast {}
#[derive(new, Debug, Clone)]
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
pub tensors: [B::TensorPrimitive<D>; N],
pub tensors: Vec<B::TensorPrimitive<D>>, // size N
// pub tensors: [B::TensorPrimitive<D>; N],
}
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
@ -41,16 +43,20 @@ impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N>
pub struct StateNull {}
impl State for StateNull {}
#[derive(Default, Debug)]
pub struct NodeStates {
hashmap: HashMap<NodeID, StateBoxed>,
}
pub type StateBoxed = Box<dyn State>;
pub type OperationBoxed = Box<dyn Operation>;
#[derive(Default, Debug)]
pub struct NodeStates {
state_hashmap: HashMap<NodeID, StateBoxed>,
operation_hashmap: HashMap<NodeID, OperationBoxed>,
}
impl NodeStates {
pub fn register(mut self, node_id: NodeID, state: StateBoxed) {
self.hashmap.insert(node_id, state);
pub fn register(mut self, node_id: NodeID, state: StateBoxed, operation: OperationBoxed) {
self.state_hashmap.insert(node_id, state);
self.operation_hashmap.insert(node_id, operation);
}
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
@ -60,10 +66,16 @@ impl NodeStates {
I: State,
O: State,
{
node.parents
let x: Vec<Box<dyn State>> = node
.parents
.iter()
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
.collect()
.collect();
*outputs_to_input::<B, D, N>(x)
.as_any()
.downcast_ref::<I>()
.expect("Downcast failed")
}
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
@ -76,35 +88,48 @@ impl NodeStates {
I: State,
O: State,
{
match self.hashmap.remove(node_id) {
match self.state_hashmap.remove(node_id) {
Some(state) => state,
None => {
let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
Box::new(ops.forward(inputs))
// let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
let operation: &OperationBoxed =
self.get_ops_from_node_id::<B, OS, I, O, D, N>(node_id);
let inputs = self.get_input::<B, OS, I, O, D, N>(operation.node());
operation.forward(Box::new(inputs))
}
}
}
// NodeStates must have access to a mapping from NodeRef/NodeID to Ops
// Otherwise how to get parents just with ids?
// And how to do the forward pass ?
// maybe inline
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
&self,
node_id: &NodeID,
) -> Ops<B, OS, I, O, D, N>
) -> &OperationBoxed
// ) -> Ops<B, OS, I, O, D, N>
where
OS: OpsSpec<B, D, N, Input = I, Output = O>,
B: Backend,
I: State,
O: State,
{
todo!()
self.operation_hashmap.get(node_id).unwrap()
}
}
// STILL TO DO
// - Collect several Os into an I
// - node_id -> map of node_id -> Ops
// when registering, pass a pointer to the ops too
/// Merges the boxed outputs of a node's parents into a single boxed input state.
///
/// Each parent output is downcast to `StateStruct<B, D, N>` and its first
/// tensor is cloned into a fresh `StateStruct` holding one tensor per parent.
///
/// # Panics
/// Panics if any output is not a `StateStruct<B, D, N>`.
///
/// NOTE(review): only `tensors[0]` of each parent output is used — assumes each
/// parent contributes a single relevant tensor; confirm this holds for N > 1.
fn outputs_to_input<B: Backend, const D: usize, const N: usize>(
    outputs: Vec<StateBoxed>,
) -> StateBoxed {
    let tensors: Vec<B::TensorPrimitive<D>> = outputs
        .iter()
        .map(|out| {
            // Downcast by reference and clone the tensor out. The original
            // `*downcast_ref()` tried to move out of a shared borrow (E0507),
            // as did indexing `tensors[0]` on the borrowed struct.
            out.as_any()
                .downcast_ref::<StateStruct<B, D, N>>()
                .expect("Downcast failed")
                .tensors[0]
                .clone()
        })
        .collect();
    Box::new(StateStruct::<B, D, N>::new(tensors))
}

View File

@ -1,12 +1,13 @@
use crate::{
grads::Gradients,
graph::{
checkpoint::{NodeStates, State, StateNull},
checkpoint::{NodeStates, State, StateBoxed, StateNull},
NodeRef, Requirement, {Graph, Step},
},
tensor::AutodiffTensor,
};
use burn_tensor::{backend::Backend, ops, Shape};
use std::any::Any;
use std::{marker::PhantomData, process::Output};
use super::OpsSpec;
@ -121,7 +122,7 @@ where
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
match ops_spec.bottleneck() {
ComputeBound => {
autodiff_tensor.register_output(output);
autodiff_tensor.register_output(output, Box::new(ops));
}
MemoryBound => {}
}
@ -171,8 +172,34 @@ where
states.get_input::<B, OS, I, O, D, N>(self.node)
}
pub(crate) fn forward(&self, inputs: I) -> O {
self.ops_spec.forward(inputs)
// pub(crate) fn forward(&self, inputs: I) -> O {
//
// }
}
/// Type-erased interface over a concrete `Ops` so the graph can store
/// heterogeneous operations (see `OperationBoxed`) and replay their forward
/// pass from boxed states during checkpoint recomputation.
pub trait Operation: Send + Sync + std::fmt::Debug + 'static {
/// The autodiff graph node this operation produced its output for.
fn node(&self) -> NodeRef;
/// Recomputes the operation's output from a type-erased input state.
fn forward(&self, input: StateBoxed) -> StateBoxed;
}
impl<B, OS, I, O, const D: usize, const N: usize> Operation for Ops<B, OS, I, O, D, N>
where
    B: Backend,
    OS: OpsSpec<B, D, N, Input = I, Output = O>,
    I: State,
    O: State,
{
    /// Returns the node this operation belongs to.
    ///
    /// Cloned because `NodeRef` is shared data; the original `self.node`
    /// tried to move the field out of `&self` and would not compile.
    fn node(&self) -> NodeRef {
        self.node.clone()
    }

    /// Replays the typed forward pass from a type-erased input state.
    ///
    /// Consumes the boxed state and downcasts it to the concrete input type
    /// `I` via `into_any()` (provided by the `Downcast` supertrait of
    /// `State`). Taking ownership avoids the cannot-move-out-of-borrow error
    /// that the original `*input.as_any().downcast_ref::<I>()` produced.
    ///
    /// # Panics
    /// Panics if `input` is not of type `I`.
    fn forward(&self, input: StateBoxed) -> StateBoxed {
        let input = input
            .into_any()
            .downcast::<I>()
            .expect("Downcast did not work");
        Box::new(self.ops_spec.forward(*input))
    }
}

View File

@ -6,7 +6,7 @@ use crate::{
checkpoint::{Bottleneck, NodeStates, StateStruct},
Graph, NodeRef, Requirement, Step,
},
ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
ops::{binary, broadcast_shape, tensor, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
tensor::AutodiffTensor,
utils::duplicate,
Autodiff,
@ -305,7 +305,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
states: &NodeStates,
) {
// let (lhs, rhs, broadcast) = ops.state;
let [lhs, rhs] = ops.fetch_inputs(states).tensors;
let tensors = ops.fetch_inputs(states).tensors;
let lhs = tensors[0];
let rhs = tensors[1];
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
@ -331,9 +333,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
}
fn forward(&self, input: Self::Input) -> Self::Output {
let [lhs, rhs] = input.tensors;
let (lhs, rhs) = (input.tensors[0], input.tensors[1]);
let result = B::div(lhs, rhs);
StateStruct::new([result])
StateStruct::new([result].into())
}
fn bottleneck(&self) -> Bottleneck {

View File

@ -3,7 +3,7 @@ use burn_tensor::backend::Backend;
use crate::{
grads::Gradients,
graph::{
checkpoint::NodeStates,
checkpoint::{NodeStates, OperationBoxed},
Node, NodeID, NodeRef, Requirement, {Graph, Step},
},
};
@ -105,7 +105,8 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
self
}
pub fn register_output(&self, output: B::TensorPrimitive<D>) {
self.graph.register_output::<B, D>(output, self.node.id);
pub fn register_output(&self, output: B::TensorPrimitive<D>, operation: OperationBoxed) {
self.graph
.register_output::<B, D>(output, self.node.id, operation);
}
}