From 9f63f9d308fd980151b046d5dc2d7986047123df Mon Sep 17 00:00:00 2001
From: louisfd
Date: Tue, 23 Jan 2024 12:42:27 -0500
Subject: [PATCH] made it to clone errors

---
 burn-autodiff/Cargo.toml              |  3 ++
 burn-autodiff/src/graph/base.rs       | 11 ++--
 burn-autodiff/src/graph/checkpoint.rs | 77 ++++++++++++++++++---------
 burn-autodiff/src/ops/base.rs         | 35 ++++++++++--
 burn-autodiff/src/ops/tensor.rs       | 10 ++--
 burn-autodiff/src/tensor.rs           |  7 +--
 6 files changed, 102 insertions(+), 41 deletions(-)

diff --git a/burn-autodiff/Cargo.toml b/burn-autodiff/Cargo.toml
index 8057c70b7..ed427a8b8 100644
--- a/burn-autodiff/Cargo.toml
+++ b/burn-autodiff/Cargo.toml
@@ -21,3 +21,6 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.12.0", opt
 
 derive-new = { workspace = true }
 spin = { workspace = true }
+
+# HOPEFULLY CAN BE REMOVED
+downcast-rs = "1.0.4"
diff --git a/burn-autodiff/src/graph/base.rs b/burn-autodiff/src/graph/base.rs
index 43b7c4fde..2fb2da936 100644
--- a/burn-autodiff/src/graph/base.rs
+++ b/burn-autodiff/src/graph/base.rs
@@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
 use crate::grads::Gradients;
 
 use super::{
-    checkpoint::{NodeStates, StateStruct},
+    checkpoint::{NodeStates, OperationBoxed, StateStruct},
     NodeID, NodeRef,
 };
 
@@ -103,10 +103,13 @@ impl Graph {
         &self,
         output: B::TensorPrimitive<D>,
         node: NodeID,
+        operation: OperationBoxed,
     ) {
-        self.states
-            .get_mut()
-            .register(node, Box::new(StateStruct::<B, D, 1>::new([output])))
+        self.states.get_mut().register(
+            node,
+            Box::new(StateStruct::<B, D, 1>::new([output].into())),
+            operation,
+        )
     }
 
     pub(crate) fn node_states(&self) -> &NodeStates {
diff --git a/burn-autodiff/src/graph/checkpoint.rs b/burn-autodiff/src/graph/checkpoint.rs
index 99a62a64a..1bf58bec6 100644
--- a/burn-autodiff/src/graph/checkpoint.rs
+++ b/burn-autodiff/src/graph/checkpoint.rs
@@ -1,8 +1,9 @@
 use std::{collections::HashMap, fmt::Debug};
 
 use burn_tensor::backend::Backend;
+use downcast_rs::Downcast;
 
-use crate::ops::{Ops, OpsSpec};
+use crate::ops::{Operation, Ops, OpsSpec};
 
 use super::{NodeID, NodeRef};
 
@@ -27,11 +28,12 @@ pub enum Bottleneck {
 // //     Manual,
 // // }
 
-pub trait State: Send + Sync + Debug + 'static {}
+pub trait State: Send + Sync + Debug + 'static + Downcast {}
 
 #[derive(new, Debug, Clone)]
 pub struct StateStruct<B: Backend, const D: usize, const N: usize> {
-    pub tensors: [B::TensorPrimitive<D>; N],
+    pub tensors: Vec<B::TensorPrimitive<D>>, // size N
+    // pub tensors: [B::TensorPrimitive<D>; N],
 }
 
 impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N> {}
@@ -41,16 +43,20 @@ impl<B: Backend, const D: usize, const N: usize> State for StateStruct<B, D, N>
 pub struct StateNull {}
 impl State for StateNull {}
 
-#[derive(Default, Debug)]
-pub struct NodeStates {
-    hashmap: HashMap<NodeID, StateBoxed>,
-}
-
 pub type StateBoxed = Box<dyn State>;
 
+pub type OperationBoxed = Box<dyn Operation>;
+
+#[derive(Default, Debug)]
+pub struct NodeStates {
+    state_hashmap: HashMap<NodeID, StateBoxed>,
+    operation_hashmap: HashMap<NodeID, OperationBoxed>,
+}
+
 impl NodeStates {
-    pub fn register(mut self, node_id: NodeID, state: StateBoxed) {
-        self.hashmap.insert(node_id, state);
+    pub fn register(mut self, node_id: NodeID, state: StateBoxed, operation: OperationBoxed) {
+        self.state_hashmap.insert(node_id, state);
+        self.operation_hashmap.insert(node_id, operation);
     }
 
     pub fn get_input<B, OS, I, O, const D: usize, const N: usize>(&self, node: NodeRef) -> I
@@ -60,10 +66,16 @@ impl NodeStates {
         I: State,
         O: State,
     {
-        node.parents
+        let x: Vec<Box<dyn State>> = node
+            .parents
             .iter()
             .map(|parent| self.get_output::<B, OS, I, O, D, N>(parent))
-            .collect()
+            .collect();
+
+        *outputs_to_input::<B, D, N>(x)
+            .as_any()
+            .downcast_ref::<I>()
+            .expect("Downcast failed")
     }
 
     pub fn get_output<B, OS, I, O, const D: usize, const N: usize>(
@@ -76,35 +88,48 @@ impl NodeStates {
         I: State,
         O: State,
     {
-        match self.hashmap.remove(node_id) {
+        match self.state_hashmap.remove(node_id) {
             Some(state) => state,
             None => {
-                let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
-                let inputs = self.get_input::<B, OS, I, O, D, N>(ops.node);
-                Box::new(ops.forward(inputs))
+                // let ops: Ops<B, OS, I, O, D, N> = self.get_ops_from_node_id(node_id);
+                let operation: &OperationBoxed =
+                    self.get_ops_from_node_id::<B, OS, I, O, D, N>(node_id);
+                let inputs = self.get_input::<B, OS, I, O, D, N>(operation.node());
+                operation.forward(Box::new(inputs))
             }
         }
     }
 
-    // NodeStates must have access to a mapping from NodeRef/NodeID to Ops
-    // Otherwise how to get parents just with ids?
-    // And how to do the forward pass ?
+    // maybe inline
     fn get_ops_from_node_id<B, OS, I, O, const D: usize, const N: usize>(
         &self,
         node_id: &NodeID,
-    ) -> Ops<B, OS, I, O, D, N>
+    ) -> &OperationBoxed
+    // ) -> Ops<B, OS, I, O, D, N>
     where
        OS: OpsSpec<B, I, O, D, N>,
        B: Backend,
        I: State,
        O: State,
    {
-        todo!()
+        self.operation_hashmap.get(node_id).unwrap()
    }
 }
 
-// STILL TO DO
-
-// - Collect several Os into an I
-// - node_id -> map of node_id -> Ops
-//   when registering, pass a pointer to the ops too
+fn outputs_to_input<B: Backend, const D: usize, const N: usize>(
+    outputs: Vec<StateBoxed>,
+) -> StateBoxed {
+    let x: Vec<StateStruct<B, D, 1>> = outputs
+        .iter()
+        .map(|out| {
+            *out.as_any()
+                .downcast_ref::<StateStruct<B, D, 1>>()
+                .expect("Downcast failed")
+        })
+        .collect();
+    let y: Vec<B::TensorPrimitive<D>> = x
+        .iter()
+        .map(|state_struct| state_struct.tensors[0])
+        .collect();
+    Box::new(StateStruct::<B, D, N>::new(y))
+}
diff --git a/burn-autodiff/src/ops/base.rs b/burn-autodiff/src/ops/base.rs
index a49620c3a..cc270a521 100644
--- a/burn-autodiff/src/ops/base.rs
+++ b/burn-autodiff/src/ops/base.rs
@@ -1,12 +1,13 @@
 use crate::{
     grads::Gradients,
     graph::{
-        checkpoint::{NodeStates, State, StateNull},
+        checkpoint::{NodeStates, State, StateBoxed, StateNull},
         NodeRef, Requirement, {Graph, Step},
     },
     tensor::AutodiffTensor,
 };
 use burn_tensor::{backend::Backend, ops, Shape};
+use std::any::Any;
 use std::{marker::PhantomData, process::Output};
 
 use super::OpsSpec;
@@ -121,7 +122,7 @@ where
         let ops = Ops::new(parents, autodiff_tensor.node.clone(), ops_spec);
         match ops_spec.bottleneck() {
             ComputeBound => {
-                autodiff_tensor.register_output(output);
+                autodiff_tensor.register_output(output, Box::new(ops));
             }
             MemoryBound => {}
         }
@@ -171,8 +172,34 @@ where
         states.get_input::<B, OS, I, O, D, N>(self.node)
     }
 
-    pub(crate) fn forward(&self, inputs: I) -> O {
-        self.ops_spec.forward(inputs)
-    }
+    // pub(crate) fn forward(&self, inputs: I) -> O {
+    //
+    // }
 }
+
+pub trait Operation: Send + Sync + std::fmt::Debug + 'static {
+    fn node(&self) -> NodeRef;
+    fn forward(&self, input: StateBoxed) -> StateBoxed;
+}
+
+impl<B, OS, I, O, const D: usize, const N: usize> Operation for Ops<B, OS, I, O, D, N>
+where
+    B: Backend,
+    OS: OpsSpec<B, I, O, D, N>,
+    I: State,
+    O: State,
+{
+    fn node(&self) -> NodeRef {
+        self.node
+    }
+
+    fn forward(&self, input: StateBoxed) -> StateBoxed {
+        // ouch
+        let x = *input
+            .as_any()
+            .downcast_ref::<I>()
+            .expect("Downcast did not work");
+        Box::new(self.ops_spec.forward(x))
+    }
+}
 
diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index 5a1ba2db7..70be6aef6 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -6,7 +6,7 @@ use crate::{
         checkpoint::{Bottleneck, NodeStates, StateStruct},
         Graph, NodeRef, Requirement, Step,
     },
-    ops::{binary, broadcast_shape, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
+    ops::{binary, broadcast_shape, tensor, unary, unary_different_backend, Ops, OpsKind, OpsSpec},
     tensor::AutodiffTensor,
     utils::duplicate,
     Autodiff,
@@ -305,7 +305,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
                 states: &NodeStates,
             ) {
                 // let (lhs, rhs, broadcast) = ops.state;
-                let [lhs, rhs] = ops.fetch_inputs(states).tensors;
+                let tensors = ops.fetch_inputs(states).tensors;
+                let lhs = tensors[0];
+                let rhs = tensors[1];
                 let broadcast = BinaryOpsBroadcast::new::<B>(&lhs, &rhs);
 
                 let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));
@@ -331,9 +333,9 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
             }
 
             fn forward(&self, input: Self::Input) -> Self::Output {
-                let [lhs, rhs] = input.tensors;
+                let (lhs, rhs) = (input.tensors[0], input.tensors[1]);
                 let result = B::div(lhs, rhs);
-                StateStruct::new([result])
+                StateStruct::new([result].into())
             }
 
             fn bottleneck(&self) -> Bottleneck {
diff --git a/burn-autodiff/src/tensor.rs b/burn-autodiff/src/tensor.rs
index 8f72d8660..181091167 100644
--- a/burn-autodiff/src/tensor.rs
+++ b/burn-autodiff/src/tensor.rs
@@ -3,7 +3,7 @@ use burn_tensor::backend::Backend;
 use crate::{
     grads::Gradients,
     graph::{
-        checkpoint::NodeStates,
+        checkpoint::{NodeStates, OperationBoxed},
         Node, NodeID, NodeRef, Requirement, {Graph, Step},
     },
 };
@@ -105,7 +105,8 @@ impl<B: Backend, const D: usize> AutodiffTensor<B, D> {
         self
     }
 
-    pub fn register_output(&self, output: B::TensorPrimitive<D>) {
-        self.graph.register_output::<B, D>(output, self.node.id);
+    pub fn register_output(&self, output: B::TensorPrimitive<D>, operation: OperationBoxed) {
+        self.graph
+            .register_output::<B, D>(output, self.node.id, operation);
     }
 }
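
Note on the type-erasure pattern: the new `State: Downcast` bound, the `Operation` trait, and `outputs_to_input` all recover concrete types from boxed trait objects via `downcast-rs`. Below is a minimal, self-contained sketch of that pattern, assuming only the `downcast-rs` crate added in Cargo.toml; `StateStruct` is simplified to a non-generic stand-in holding `Vec<f32>` instead of tensor primitives.

    use downcast_rs::{impl_downcast, Downcast};
    use std::fmt::Debug;

    // Stand-in for the `State` trait: type-erased, but recoverable via `Downcast`.
    trait State: Send + Sync + Debug + Downcast {}
    impl_downcast!(State);

    // Simplified stand-in for `StateStruct`; the real one holds tensor primitives.
    #[derive(Debug, Clone)]
    struct StateStruct {
        tensors: Vec<f32>,
    }
    impl State for StateStruct {}

    fn main() {
        let boxed: Box<dyn State> = Box::new(StateStruct {
            tensors: vec![1.0, 2.0],
        });

        // `as_any()` is provided by `Downcast`; `downcast_ref` yields `Option<&StateStruct>`.
        // Cloning the recovered value avoids moving out of the borrowed trait object.
        let concrete: StateStruct = boxed
            .as_any()
            .downcast_ref::<StateStruct>()
            .expect("Downcast failed")
            .clone();

        assert_eq!(concrete.tensors, vec![1.0, 2.0]);
    }

Dereferencing the `downcast_ref` result in place, as `Operation::forward` and `outputs_to_input` do in the patch, attempts to move out of a shared reference; cloning the recovered value, as in the sketch, is one way around the clone/move errors the commit title refers to.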
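
For the overall flow that `NodeStates` is aiming at, here is a simplified, hypothetical sketch using plain `usize` ids and `f32` values in place of the burn node and tensor types: compute-bound outputs are stored at registration time, and missing (memory-bound) outputs are rebuilt by replaying the registered operation on its parents' outputs.

    use std::collections::HashMap;

    type NodeId = usize;

    // Hypothetical stand-in for the boxed `Operation`: it knows its parents and
    // can recompute its output from their outputs.
    trait Operation {
        fn parents(&self) -> Vec<NodeId>;
        fn forward(&self, inputs: Vec<f32>) -> f32;
    }

    struct Add(Vec<NodeId>);
    impl Operation for Add {
        fn parents(&self) -> Vec<NodeId> {
            self.0.clone()
        }
        fn forward(&self, inputs: Vec<f32>) -> f32 {
            inputs.iter().sum()
        }
    }

    #[derive(Default)]
    struct NodeStates {
        states: HashMap<NodeId, f32>,
        operations: HashMap<NodeId, Box<dyn Operation>>,
    }

    impl NodeStates {
        // `&mut self` here; a by-value `mut self` hands ownership of the whole
        // map to the first caller.
        fn register(&mut self, id: NodeId, state: Option<f32>, op: Box<dyn Operation>) {
            if let Some(state) = state {
                self.states.insert(id, state);
            }
            self.operations.insert(id, op);
        }

        // Checkpointed nodes return their stored state; others are recomputed
        // by replaying their operation on the parents' outputs.
        fn get_output(&self, id: &NodeId) -> f32 {
            match self.states.get(id) {
                Some(state) => *state,
                None => {
                    let op = self.operations.get(id).expect("unknown node");
                    let inputs: Vec<f32> =
                        op.parents().iter().map(|p| self.get_output(p)).collect();
                    op.forward(inputs)
                }
            }
        }
    }

    fn main() {
        let mut states = NodeStates::default();
        states.register(0, Some(1.5), Box::new(Add(vec![])));
        states.register(1, Some(2.5), Box::new(Add(vec![])));
        // No stored state for node 2, so its operation is replayed on demand.
        states.register(2, None, Box::new(Add(vec![0, 1])));
        assert_eq!(states.get_output(&2), 4.0);
    }

This sketch is not the burn API; it only illustrates the register-then-replay design, with `register` taking `&mut self` where the patch currently takes `mut self` by value.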