mirror of https://github.com/tracel-ai/burn.git
made it to clone errors
This commit is contained in:
parent
1c0798952e
commit
9f63f9d308
|
@ -21,3 +21,6 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.0", opt
|
|||
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
|
||||
# HOPEFULLY CAN BE REMOVED
|
||||
downcast-rs = "1.0.4"
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
|
|||
use crate::grads::Gradients;
|
||||
|
||||
use super::{
|
||||
checkpoint::{NodeStates, StateStruct},
|
||||
checkpoint::{NodeStates, OperationBoxed, StateStruct},
|
||||
NodeID, NodeRef,
|
||||
};
|
||||
|
||||
|
@ -103,10 +103,13 @@ impl Graph {
|
|||
&self,
|
||||
output: B::TensorPrimitive<D>,
|
||||
node: NodeID,
|
||||
operation: OperationBoxed,
|
||||
) {
|
||||
self.states
|
||||
.get_mut()
|
||||
.register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
|
||||
self.states.get_mut().register(
|
||||
node,
|
||||
Box::new(StateStruct::<B, D, 1>::new([output].into())),
|
||||
operation,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn node_states(&self) -> &NodeStates {
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use std::{collections::HashMap, fmt::Debug};
|
||||
|
||||
use burn_tensor::backend::Backend;
|
||||
use downcast_rs::Downcast;
|
||||
|
||||
use crate::ops::{Ops, OpsSpec};
|
||||
use crate::ops::{Operation, Ops, OpsSpec};
|
||||
|
||||
use super::{NodeID, NodeRef};
|
||||
|
||||
|
@ -27,11 +28,12 @@ pub enum Bottleneck {
|
|||
// // Manual,
|
||||
// // }
|
||||
|
||||
pub trait State: Send + Sync + Debug + 'static {}
|
||||
pub trait State: Send + Sync + Debug + 'static + Downcast {}
|
||||
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
|
||||
pub tensors: [B::TensorPrimitive<D>; N],
|
||||
pub tensors: Vec<B::TensorPrimitive<D>>, // size N
|
||||
// pub tensors: [B::TensorPrimitive<D>; N],
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
|
||||
|
@ -41,16 +43,20 @@ impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N>
|
|||
pub struct StateNull {}
|
||||
impl State for StateNull {}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct NodeStates {
|
||||
hashmap: HashMap<NodeID, StateBoxed>,
|
||||
}
|
||||
|
||||
pub type StateBoxed = Box<dyn State>;
|
||||
|
||||
pub type OperationBoxed = Box<dyn Operation>;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct NodeStates {
|
||||
state_hashmap: HashMap<NodeID, StateBoxed>,
|
||||
operation_hashmap: HashMap<NodeID, OperationBoxed>,
|
||||
}
|
||||
|
||||
impl NodeStates {
|
||||
pub fn register(mut self, node_id: NodeID, state: StateBoxed) {
|
||||
self.hashmap.insert(node_id, state);
|
||||
pub fn register(mut self, node_id: NodeID, state: StateBoxed, operation: OperationBoxed) {
|
||||
self.state_hashmap.insert(node_id, state);
|
||||
self.operation_hashmap.insert(node_id, operation);
|
||||
}
|
||||
|
||||
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
|
||||
|
@ -60,10 +66,16 @@ impl NodeStates {
|
|||
I: State,
|
||||
O: State,
|
||||
{
|
||||
node.parents
|
||||
let x: Vec<Box<dyn State>> = node
|
||||
.parents
|
||||
.iter()
|
||||
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
|
||||
.collect()
|
||||
.collect();
|
||||
|
||||
*outputs_to_input::<B, D, N>(x)
|
||||
.as_any()
|
||||
.downcast_ref::<I>()
|
||||
.expect("Downcast failed")
|
||||
}
|
||||
|
||||
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
|
||||
|
@ -76,35 +88,48 @@ impl NodeStates {
|
|||
I: State,
|
||||
O: State,
|
||||
{
|
||||
match self.hashmap.remove(node_id) {
|
||||
match self.state_hashmap.remove(node_id) {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
|
||||
let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
|
||||
Box::new(ops.forward(inputs))
|
||||
// let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
|
||||
let operation: &OperationBoxed =
|
||||
self.get_ops_from_node_id::<B, OS, I, O, D, N>(node_id);
|
||||
let inputs = self.get_input::<B, OS, I, O, D, N>(operation.node());
|
||||
operation.forward(Box::new(inputs))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NodeStates must have access to a mapping from NodeRef/NodeID to Ops
|
||||
// Otherwise how to get parents just with ids?
|
||||
// And how to do the forward pass ?
|
||||
// maybe inline
|
||||
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
|
||||
&self,
|
||||
node_id: &NodeID,
|
||||
) -> Ops<B, OS, I, O, D, N>
|
||||
) -> &OperationBoxed
|
||||
// ) -> Ops<B, OS, I, O, D, N>
|
||||
where
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
B: Backend,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
todo!()
|
||||
self.operation_hashmap.get(node_id).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// STILL TO DO
|
||||
|
||||
// - Collect several Os into an I
|
||||
// - node_id -> map of node_id -> Ops
|
||||
// when registering, pass a pointer to the ops too
|
||||
fn outputs_to_input<B: Backend, const D: usize, const N: usize>(
|
||||
outputs: Vec<StateBoxed>,
|
||||
) -> StateBoxed {
|
||||
let x: Vec<StateStruct<B, D, N>> = outputs
|
||||
.iter()
|
||||
.map(|out| {
|
||||
*out.as_any()
|
||||
.downcast_ref::<StateStruct<B, D, N>>()
|
||||
.expect("Downcast failed")
|
||||
})
|
||||
.collect();
|
||||
let y: Vec<B::TensorPrimitive<D>> = x
|
||||
.iter()
|
||||
.map(|state_struct| state_struct.tensors[0])
|
||||
.collect();
|
||||
Box::new(StateStruct::<B, D, N>::new(y))
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use crate::{
|
||||
grads::Gradients,
|
||||
graph::{
|
||||
checkpoint::{NodeStates, State, StateNull},
|
||||
checkpoint::{NodeStates, State, StateBoxed, StateNull},
|
||||
NodeRef, Requirement, {Graph, Step},
|
||||
},
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, ops, Shape};
|
||||
use std::any::Any;
|
||||
use std::{marker::PhantomData, process::Output};
|
||||
|
||||
use super::OpsSpec;
|
||||
|
@ -121,7 +122,7 @@ where
|
|||
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
|
||||
match ops_spec.bottleneck() {
|
||||
ComputeBound => {
|
||||
autodiff_tensor.register_output(output);
|
||||
autodiff_tensor.register_output(output, Box::new(ops));
|
||||
}
|
||||
MemoryBound => {}
|
||||
}
|
||||
|
@ -171,8 +172,34 @@ where
|
|||
states.get_input::<B, OS, I, O, D, N>(self.node)
|
||||
}
|
||||
|
||||
pub(crate) fn forward(&self, inputs: I) -> O {
|
||||
self.ops_spec.forward(inputs)
|
||||
// pub(crate) fn forward(&self, inputs: I) -> O {
|
||||
//
|
||||
// }
|
||||
}
|
||||
|
||||
pub trait Operation: Send + Sync + std::fmt::Debug + 'static {
|
||||
fn node(&self) -> NodeRef;
|
||||
fn forward(&self, input: StateBoxed) -> StateBoxed;
|
||||
}
|
||||
|
||||
impl<B, OS, I, O, const D: usize, const N: usize> Operation for Ops<B, OS, I, O, D, N>
|
||||
where
|
||||
B: Backend,
|
||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||
I: State,
|
||||
O: State,
|
||||
{
|
||||
fn node(&self) -> NodeRef {
|
||||
self.node
|
||||
}
|
||||
|
||||
fn forward(&self, input: StateBoxed) -> StateBoxed {
|
||||
// ouch
|
||||
let x = *input
|
||||
.as_any()
|
||||
.downcast_ref::<I>()
|
||||
.expect("Downcast did not work");
|
||||
Box::new(self.ops_spec.forward(x))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
checkpoint::{Bottleneck, NodeStates, StateStruct},
|
||||
Graph, NodeRef, Requirement, Step,
|
||||
},
|
||||
ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
|
||||
ops::{binary, broadcast_shape, tensor, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
|
||||
tensor::AutodiffTensor,
|
||||
utils::duplicate,
|
||||
Autodiff,
|
||||
|
@ -305,7 +305,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
|||
states: &NodeStates,
|
||||
) {
|
||||
// let (lhs, rhs, broadcast) = ops.state;
|
||||
let [lhs, rhs] = ops.fetch_inputs(states).tensors;
|
||||
let tensors = ops.fetch_inputs(states).tensors;
|
||||
let lhs = tensors[0];
|
||||
let rhs = tensors[1];
|
||||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
|
||||
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
|
||||
|
||||
|
@ -331,9 +333,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
|||
}
|
||||
|
||||
fn forward(&self, input: Self::Input) -> Self::Output {
|
||||
let [lhs, rhs] = input.tensors;
|
||||
let (lhs, rhs) = (input.tensors[0], input.tensors[1]);
|
||||
let result = B::div(lhs, rhs);
|
||||
StateStruct::new([result])
|
||||
StateStruct::new([result].into())
|
||||
}
|
||||
|
||||
fn bottleneck(&self) -> Bottleneck {
|
||||
|
|
|
@ -3,7 +3,7 @@ use burn_tensor::backend::Backend;
|
|||
use crate::{
|
||||
grads::Gradients,
|
||||
graph::{
|
||||
checkpoint::NodeStates,
|
||||
checkpoint::{NodeStates, OperationBoxed},
|
||||
Node, NodeID, NodeRef, Requirement, {Graph, Step},
|
||||
},
|
||||
};
|
||||
|
@ -105,7 +105,8 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn register_output(&self, output: B::TensorPrimitive<D>) {
|
||||
self.graph.register_output::<B, D>(output, self.node.id);
|
||||
pub fn register_output(&self, output: B::TensorPrimitive<D>, operation: OperationBoxed) {
|
||||
self.graph
|
||||
.register_output::<B, D>(output, self.node.id, operation);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue