mirror of https://github.com/tracel-ai/burn.git
made it to clone errors
This commit is contained in:
parent
1c0798952e
commit
9f63f9d308
|
@ -21,3 +21,6 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.0", opt
|
||||||
|
|
||||||
derive-new = { workspace = true }
|
derive-new = { workspace = true }
|
||||||
spin = { workspace = true }
|
spin = { workspace = true }
|
||||||
|
|
||||||
|
# HOPEFULLY CAN BE REMOVED
|
||||||
|
downcast-rs = "1.0.4"
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||||
use crate::grads::Gradients;
|
use crate::grads::Gradients;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
checkpoint::{NodeStates, StateStruct},
|
checkpoint::{NodeStates, OperationBoxed, StateStruct},
|
||||||
NodeID, NodeRef,
|
NodeID, NodeRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -103,10 +103,13 @@ impl Graph {
|
||||||
&self,
|
&self,
|
||||||
output: B::TensorPrimitive<D>,
|
output: B::TensorPrimitive<D>,
|
||||||
node: NodeID,
|
node: NodeID,
|
||||||
|
operation: OperationBoxed,
|
||||||
) {
|
) {
|
||||||
self.states
|
self.states.get_mut().register(
|
||||||
.get_mut()
|
node,
|
||||||
.register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
|
Box::new(StateStruct::<B, D, 1>::new([output].into())),
|
||||||
|
operation,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn node_states(&self) -> &NodeStates {
|
pub(crate) fn node_states(&self) -> &NodeStates {
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
use std::{collections::HashMap, fmt::Debug};
|
use std::{collections::HashMap, fmt::Debug};
|
||||||
|
|
||||||
use burn_tensor::backend::Backend;
|
use burn_tensor::backend::Backend;
|
||||||
|
use downcast_rs::Downcast;
|
||||||
|
|
||||||
use crate::ops::{Ops, OpsSpec};
|
use crate::ops::{Operation, Ops, OpsSpec};
|
||||||
|
|
||||||
use super::{NodeID, NodeRef};
|
use super::{NodeID, NodeRef};
|
||||||
|
|
||||||
|
@ -27,11 +28,12 @@ pub enum Bottleneck {
|
||||||
// // Manual,
|
// // Manual,
|
||||||
// // }
|
// // }
|
||||||
|
|
||||||
pub trait State: Send + Sync + Debug + 'static {}
|
pub trait State: Send + Sync + Debug + 'static + Downcast {}
|
||||||
|
|
||||||
#[derive(new, Debug, Clone)]
|
#[derive(new, Debug, Clone)]
|
||||||
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
|
pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
|
||||||
pub tensors: [B::TensorPrimitive<D>; N],
|
pub tensors: Vec<B::TensorPrimitive<D>>, // size N
|
||||||
|
// pub tensors: [B::TensorPrimitive<D>; N],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
|
impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
|
||||||
|
@ -41,16 +43,20 @@ impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N>
|
||||||
pub struct StateNull {}
|
pub struct StateNull {}
|
||||||
impl State for StateNull {}
|
impl State for StateNull {}
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
|
||||||
pub struct NodeStates {
|
|
||||||
hashmap: HashMap<NodeID, StateBoxed>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type StateBoxed = Box<dyn State>;
|
pub type StateBoxed = Box<dyn State>;
|
||||||
|
|
||||||
|
pub type OperationBoxed = Box<dyn Operation>;
|
||||||
|
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
pub struct NodeStates {
|
||||||
|
state_hashmap: HashMap<NodeID, StateBoxed>,
|
||||||
|
operation_hashmap: HashMap<NodeID, OperationBoxed>,
|
||||||
|
}
|
||||||
|
|
||||||
impl NodeStates {
|
impl NodeStates {
|
||||||
pub fn register(mut self, node_id: NodeID, state: StateBoxed) {
|
pub fn register(mut self, node_id: NodeID, state: StateBoxed, operation: OperationBoxed) {
|
||||||
self.hashmap.insert(node_id, state);
|
self.state_hashmap.insert(node_id, state);
|
||||||
|
self.operation_hashmap.insert(node_id, operation);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
|
pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
|
||||||
|
@ -60,10 +66,16 @@ impl NodeStates {
|
||||||
I: State,
|
I: State,
|
||||||
O: State,
|
O: State,
|
||||||
{
|
{
|
||||||
node.parents
|
let x: Vec<Box<dyn State>> = node
|
||||||
|
.parents
|
||||||
.iter()
|
.iter()
|
||||||
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
|
.map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
|
||||||
.collect()
|
.collect();
|
||||||
|
|
||||||
|
*outputs_to_input::<B, D, N>(x)
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<I>()
|
||||||
|
.expect("Downcast failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
|
pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
|
||||||
|
@ -76,35 +88,48 @@ impl NodeStates {
|
||||||
I: State,
|
I: State,
|
||||||
O: State,
|
O: State,
|
||||||
{
|
{
|
||||||
match self.hashmap.remove(node_id) {
|
match self.state_hashmap.remove(node_id) {
|
||||||
Some(state) => state,
|
Some(state) => state,
|
||||||
None => {
|
None => {
|
||||||
let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
|
// let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
|
||||||
let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
|
let operation: &OperationBoxed =
|
||||||
Box::new(ops.forward(inputs))
|
self.get_ops_from_node_id::<B, OS, I, O, D, N>(node_id);
|
||||||
|
let inputs = self.get_input::<B, OS, I, O, D, N>(operation.node());
|
||||||
|
operation.forward(Box::new(inputs))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NodeStates must have access to a mapping from NodeRef/NodeID to Ops
|
// maybe inline
|
||||||
// Otherwise how to get parents just with ids?
|
|
||||||
// And how to do the forward pass ?
|
|
||||||
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
|
fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
|
||||||
&self,
|
&self,
|
||||||
node_id: &NodeID,
|
node_id: &NodeID,
|
||||||
) -> Ops<B, OS, I, O, D, N>
|
) -> &OperationBoxed
|
||||||
|
// ) -> Ops<B, OS, I, O, D, N>
|
||||||
where
|
where
|
||||||
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||||
B: Backend,
|
B: Backend,
|
||||||
I: State,
|
I: State,
|
||||||
O: State,
|
O: State,
|
||||||
{
|
{
|
||||||
todo!()
|
self.operation_hashmap.get(node_id).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// STILL TO DO
|
fn outputs_to_input<B: Backend, const D: usize, const N: usize>(
|
||||||
|
outputs: Vec<StateBoxed>,
|
||||||
// - Collect several Os into an I
|
) -> StateBoxed {
|
||||||
// - node_id -> map of node_id -> Ops
|
let x: Vec<StateStruct<B, D, N>> = outputs
|
||||||
// when registering, pass a pointer to the ops too
|
.iter()
|
||||||
|
.map(|out| {
|
||||||
|
*out.as_any()
|
||||||
|
.downcast_ref::<StateStruct<B, D, N>>()
|
||||||
|
.expect("Downcast failed")
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let y: Vec<B::TensorPrimitive<D>> = x
|
||||||
|
.iter()
|
||||||
|
.map(|state_struct| state_struct.tensors[0])
|
||||||
|
.collect();
|
||||||
|
Box::new(StateStruct::<B, D, N>::new(y))
|
||||||
|
}
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
grads::Gradients,
|
grads::Gradients,
|
||||||
graph::{
|
graph::{
|
||||||
checkpoint::{NodeStates, State, StateNull},
|
checkpoint::{NodeStates, State, StateBoxed, StateNull},
|
||||||
NodeRef, Requirement, {Graph, Step},
|
NodeRef, Requirement, {Graph, Step},
|
||||||
},
|
},
|
||||||
tensor::AutodiffTensor,
|
tensor::AutodiffTensor,
|
||||||
};
|
};
|
||||||
use burn_tensor::{backend::Backend, ops, Shape};
|
use burn_tensor::{backend::Backend, ops, Shape};
|
||||||
|
use std::any::Any;
|
||||||
use std::{marker::PhantomData, process::Output};
|
use std::{marker::PhantomData, process::Output};
|
||||||
|
|
||||||
use super::OpsSpec;
|
use super::OpsSpec;
|
||||||
|
@ -121,7 +122,7 @@ where
|
||||||
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
|
let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
|
||||||
match ops_spec.bottleneck() {
|
match ops_spec.bottleneck() {
|
||||||
ComputeBound => {
|
ComputeBound => {
|
||||||
autodiff_tensor.register_output(output);
|
autodiff_tensor.register_output(output, Box::new(ops));
|
||||||
}
|
}
|
||||||
MemoryBound => {}
|
MemoryBound => {}
|
||||||
}
|
}
|
||||||
|
@ -171,8 +172,34 @@ where
|
||||||
states.get_input::<B, OS, I, O, D, N>(self.node)
|
states.get_input::<B, OS, I, O, D, N>(self.node)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn forward(&self, inputs: I) -> O {
|
// pub(crate) fn forward(&self, inputs: I) -> O {
|
||||||
self.ops_spec.forward(inputs)
|
//
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Operation: Send + Sync + std::fmt::Debug + 'static {
|
||||||
|
fn node(&self) -> NodeRef;
|
||||||
|
fn forward(&self, input: StateBoxed) -> StateBoxed;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B, OS, I, O, const D: usize, const N: usize> Operation for Ops<B, OS, I, O, D, N>
|
||||||
|
where
|
||||||
|
B: Backend,
|
||||||
|
OS: OpsSpec<B, D, N, Input = I, Output = O>,
|
||||||
|
I: State,
|
||||||
|
O: State,
|
||||||
|
{
|
||||||
|
fn node(&self) -> NodeRef {
|
||||||
|
self.node
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, input: StateBoxed) -> StateBoxed {
|
||||||
|
// ouch
|
||||||
|
let x = *input
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<I>()
|
||||||
|
.expect("Downcast did not work");
|
||||||
|
Box::new(self.ops_spec.forward(x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
checkpoint::{Bottleneck, NodeStates, StateStruct},
|
checkpoint::{Bottleneck, NodeStates, StateStruct},
|
||||||
Graph, NodeRef, Requirement, Step,
|
Graph, NodeRef, Requirement, Step,
|
||||||
},
|
},
|
||||||
ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
|
ops::{binary, broadcast_shape, tensor, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
|
||||||
tensor::AutodiffTensor,
|
tensor::AutodiffTensor,
|
||||||
utils::duplicate,
|
utils::duplicate,
|
||||||
Autodiff,
|
Autodiff,
|
||||||
|
@ -305,7 +305,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
||||||
states: &NodeStates,
|
states: &NodeStates,
|
||||||
) {
|
) {
|
||||||
// let (lhs, rhs, broadcast) = ops.state;
|
// let (lhs, rhs, broadcast) = ops.state;
|
||||||
let [lhs, rhs] = ops.fetch_inputs(states).tensors;
|
let tensors = ops.fetch_inputs(states).tensors;
|
||||||
|
let lhs = tensors[0];
|
||||||
|
let rhs = tensors[1];
|
||||||
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
|
let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
|
||||||
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
|
let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
|
||||||
|
|
||||||
|
@ -331,9 +333,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, input: Self::Input) -> Self::Output {
|
fn forward(&self, input: Self::Input) -> Self::Output {
|
||||||
let [lhs, rhs] = input.tensors;
|
let (lhs, rhs) = (input.tensors[0], input.tensors[1]);
|
||||||
let result = B::div(lhs, rhs);
|
let result = B::div(lhs, rhs);
|
||||||
StateStruct::new([result])
|
StateStruct::new([result].into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn bottleneck(&self) -> Bottleneck {
|
fn bottleneck(&self) -> Bottleneck {
|
||||||
|
|
|
@ -3,7 +3,7 @@ use burn_tensor::backend::Backend;
|
||||||
use crate::{
|
use crate::{
|
||||||
grads::Gradients,
|
grads::Gradients,
|
||||||
graph::{
|
graph::{
|
||||||
checkpoint::NodeStates,
|
checkpoint::{NodeStates, OperationBoxed},
|
||||||
Node, NodeID, NodeRef, Requirement, {Graph, Step},
|
Node, NodeID, NodeRef, Requirement, {Graph, Step},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -105,7 +105,8 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register_output(&self, output: B::TensorPrimitive<D>) {
|
pub fn register_output(&self, output: B::TensorPrimitive<D>, operation: OperationBoxed) {
|
||||||
self.graph.register_output::<B, D>(output, self.node.id);
|
self.graph
|
||||||
|
.register_output::<B, D>(output, self.node.id, operation);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue