diff --git a/burn-tensor/src/graph/node.rs b/burn-tensor/src/graph/node.rs index 70d22dec1..c3c4aae17 100644 --- a/burn-tensor/src/graph/node.rs +++ b/burn-tensor/src/graph/node.rs @@ -15,6 +15,9 @@ impl NodeId { value: nanoid::nanoid!(), } } + pub fn to_string(&self) -> String { + self.value.to_string() + } } pub trait Node: std::fmt::Debug { diff --git a/burn-tensor/src/graph/ops/single.rs b/burn-tensor/src/graph/ops/single.rs index 1855b6e74..90ce04a50 100644 --- a/burn-tensor/src/graph/ops/single.rs +++ b/burn-tensor/src/graph/ops/single.rs @@ -69,7 +69,7 @@ where Ops: SingleOps, { fn id(&self) -> NodeId { - self.input.borrow().id() + self.out.borrow().id() } fn backward(&mut self) { diff --git a/burn-tensor/src/graph/tape.rs b/burn-tensor/src/graph/tape.rs index c544f8045..58ff02494 100644 --- a/burn-tensor/src/graph/tape.rs +++ b/burn-tensor/src/graph/tape.rs @@ -14,6 +14,10 @@ impl Tape { } } + pub fn new_ref() -> TapeRef { + Rc::new(RefCell::new(Self::new())) + } + pub fn backward(&mut self, from: NodeId) { let mut init = false; @@ -21,17 +25,14 @@ impl Tape { if init { ops.backward(); } else if ops.id() == from { - ops.set_last_ops(); init = true; + ops.set_last_ops(); ops.backward(); } } } pub fn add(&mut self, ops: RecordedOpsRef) { - println!("---"); - println!("Adding ops {:?}", ops); - println!("---"); self.operations.push(ops) } } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs index 2bd97f296..99a049015 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs @@ -13,15 +13,15 @@ register_ops!( ops BinaryOps, name ADTensorAddOps, - forward |left, right| left * right, + forward |left, right| left + right, - partial_left |state: &BinaryRecordedState| state.right.clone(), - partial_right |state: &BinaryRecordedState| state.left.clone(), + partial_left |state: &BinaryRecordedState| state.left.ones(), + partial_right |state: &BinaryRecordedState| 
state.right.ones(), ); register_ops!( ops SingleOps, name ADTensorAddScalarOps state P, - forward |state, input| input * state, + forward |state, input| input + state, - partial |state, state_recorded: &SingleRecordedState| state_recorded.input.ones() * state, + partial |_state, state_recorded: &SingleRecordedState| state_recorded.input.ones(), ); impl TensorOpsAdd for ADTensor @@ -74,3 +74,28 @@ where TensorOpsAdd::add(&self, &rhs) } } + +#[cfg(test)] +mod tests { + use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase}; + + #[test] + fn should_diff_add() { + let tape = Tape::new_ref(); + let data_1 = Data::from([2.0]); + let data_2 = Data::from([4.0]); + + let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone()); + let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); + + let tensor_3 = tensor_1.clone() + tensor_2.clone(); + tensor_3.backprob(); + + let grad_1 = tensor_1.grad(); + let grad_2 = tensor_2.grad(); + + assert_eq!(grad_1.into_data(), Data::from([1.0])); + assert_eq!(grad_2.into_data(), Data::from([1.0])); + assert_eq!(tensor_3.into_data(), Data::from([6.0])); + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs index c3a44d5ab..ffeec8895 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs @@ -77,35 +77,39 @@ where #[cfg(test)] mod tests { - use super::*; - use crate::{ - backend::tch::TchTensor, - tape::{Tape, TapeRef}, - Data, TensorBase, - }; - use std::cell::RefCell; + use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase}; #[test] fn should_diff_mul() { - let tape = TapeRef::new(RefCell::new(Tape::new())); + let tape = Tape::new_ref(); let data_1 = Data::from([1.0]); let data_2 = Data::from([4.0]); - let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu); - let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu); + let tensor_1 = 
ADTchTensor::from_data(data_1.clone(), tape.clone()); + let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); - let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone()); - let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone()); + let tensor_3 = tensor_1.clone() * tensor_2.clone(); + tensor_3.backprob(); - let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2); - let data_ad_3 = tensor_ad_3.tensor().into_data(); - assert_eq!(data_ad_3, Data::from([4.0])); - - tensor_ad_3.backprob(); - let grad_1 = tensor_ad_1.grad(); - let grad_2 = tensor_ad_2.grad(); + let grad_1 = tensor_1.grad(); + let grad_2 = tensor_2.grad(); assert_eq!(grad_1.into_data(), data_2); assert_eq!(grad_2.into_data(), data_1); + assert_eq!(tensor_3.into_data(), Data::from([4.0])); + } + + #[test] + fn should_diff_mul_scalar() { + let tape = Tape::new_ref(); + let data = Data::from([2.0]); + + let tensor = ADTchTensor::from_data(data.clone(), tape.clone()); + let tensor_out = tensor.clone() * 4.0; + tensor_out.backprob(); + + let grad = tensor.grad(); + assert_eq!(tensor_out.into_data(), Data::from([8.0])); + assert_eq!(grad.into_data(), Data::from([4.0])); } } diff --git a/burn-tensor/src/tensor/backend/autodiff/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/tensor.rs index 1d7750f75..d9bc05f55 100644 --- a/burn-tensor/src/tensor/backend/autodiff/tensor.rs +++ b/burn-tensor/src/tensor/backend/autodiff/tensor.rs @@ -7,7 +7,7 @@ use crate::{ }; use num_traits::Float; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ADTensor { pub node: NodeRef, pub shape: Shape, @@ -94,3 +94,21 @@ impl ADKind

{ Self { _p: P::default() } } } + +#[cfg(test)] +pub mod helper { + use super::*; + use crate::{ + backend::{autodiff::ADFloat, tch::TchTensor}, + Data, + }; + + pub type ADTchTensor = ADTensor>; + + impl, const D: usize> ADTchTensor { + pub fn from_data(data: Data, tape: TapeRef) -> Self { + let tensor = TchTensor::from_data(data, tch::Device::Cpu); + ADTensor::from_tensor(tensor, tape) + } + } +}