fix: single recorded ops

nathaniel 2022-07-18 20:03:30 -04:00
parent 6ee628a748
commit f241a6c114
6 changed files with 79 additions and 28 deletions

View File

@@ -15,6 +15,9 @@ impl NodeId {
             value: nanoid::nanoid!(),
         }
     }
+
+    pub fn to_string(&self) -> String {
+        self.value.to_string()
+    }
 }

 pub trait Node<Out>: std::fmt::Debug {
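For orientation, the hunk above sits in a NodeId type that wraps a nanoid string. A minimal sketch of the whole type; only `value` and `nanoid::nanoid!()` appear in the diff, so the struct definition and derives are assumptions:

    // Sketch only: field and constructor come from the hunk above,
    // the struct definition and derives are assumed.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct NodeId {
        value: String,
    }

    impl NodeId {
        pub fn new() -> Self {
            Self {
                value: nanoid::nanoid!(),
            }
        }

        // Added in this hunk: expose the id as a plain String.
        pub fn to_string(&self) -> String {
            self.value.to_string()
        }
    }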

View File

@ -69,7 +69,7 @@ where
Ops: SingleOps<In, Out>, Ops: SingleOps<In, Out>,
{ {
fn id(&self) -> NodeId { fn id(&self) -> NodeId {
self.input.borrow().id() self.out.borrow().id()
} }
fn backward(&mut self) { fn backward(&mut self) {
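This one-line change is the fix the commit title refers to: a recorded single op now reports the id of its output node instead of its input node, so the `ops.id() == from` check in the tape hunk below can match the op that produced `from`. A sketch of the presumed struct behind `self`; only the `input` and `out` fields are visible in the diff, everything else is an assumption:

    // Presumed shape of the recorded single op; `input` and `out` appear
    // in the diff, the rest is assumed for illustration.
    struct SingleRecordedOps<In, Out, Ops> {
        input: NodeRef<In>,
        out: NodeRef<Out>,
        ops: Ops,
    }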

View File

@@ -14,6 +14,10 @@ impl Tape {
         }
     }
+
+    pub fn new_ref() -> TapeRef {
+        Rc::new(RefCell::new(Self::new()))
+    }

     pub fn backward(&mut self, from: NodeId) {
         let mut init = false;
@@ -21,17 +25,14 @@ impl Tape {
             if init {
                 ops.backward();
             } else if ops.id() == from {
-                ops.set_last_ops();
                 init = true;
+                ops.set_last_ops();
                 ops.backward();
             }
         }
     }

     pub fn add(&mut self, ops: RecordedOpsRef) {
-        println!("---");
-        println!("Adding ops {:?}", ops);
-        println!("---");
         self.operations.push(ops)
     }
 }
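Pieced together from the hunks above, Tape looks roughly like this. Only `new_ref`, `backward`, and `add` are visible in the diff; the `operations` field, `new`, and the `RecordedOpsRef` storage are assumptions (the alias itself appears in `add`'s signature):

    use std::cell::RefCell;
    use std::rc::Rc;

    pub type TapeRef = Rc<RefCell<Tape>>;

    pub struct Tape {
        operations: Vec<RecordedOpsRef>, // assumed backing store for add/backward
    }

    impl Tape {
        pub fn new() -> Self {
            Self {
                operations: Vec::new(),
            }
        }

        // Added in this commit: one-liner for the Rc<RefCell<Tape>> wrapping
        // that the old tests spelled out by hand.
        pub fn new_ref() -> TapeRef {
            Rc::new(RefCell::new(Self::new()))
        }
    }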

View File

@@ -13,15 +13,15 @@ register_ops!(
     ops BinaryOps<T, T, T>,
     name ADTensorAddOps,
     forward |left, right| left * right,
-    partial_left |state: &BinaryRecordedState<T, T, T>| state.right.clone(),
-    partial_right |state: &BinaryRecordedState<T, T, T>| state.left.clone(),
+    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
+    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones(),
 );

 register_ops!(
     ops SingleOps<T, T>,
     name ADTensorAddScalarOps state P,
     forward |state, input| input * state,
-    partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state,
+    partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
 );

 impl<T, P, const D: usize> TensorOpsAdd<P, D> for ADTensor<P, D, T>
@@ -74,3 +74,28 @@ where
         TensorOpsAdd::add(&self, &rhs)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
+
+    #[test]
+    fn should_diff_add() {
+        let tape = Tape::new_ref();
+        let data_1 = Data::from([2.0]);
+        let data_2 = Data::from([4.0]);
+
+        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
+        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
+
+        let tensor_3 = tensor_1.clone() + tensor_2.clone();
+        tensor_3.backprob();
+
+        let grad_1 = tensor_1.grad();
+        let grad_2 = tensor_2.grad();
+
+        assert_eq!(grad_1.into_data(), Data::from([1.0]));
+        assert_eq!(grad_2.into_data(), Data::from([1.0]));
+        assert_eq!(tensor_3.into_data(), Data::from([6.0]));
+    }
+}
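The partial changes in the first hunk are just the derivative of addition: for f(a, b) = a + b, ∂f/∂a = ∂f/∂b = 1, which is what `ones()` encodes, and likewise g(x) = x + c gives g'(x) = 1 for the scalar case; that is also why `should_diff_add` expects both gradients to be [1.0]. A standalone finite-difference check of those values (plain Rust, independent of the crate):

    // Numeric sanity check of the addition partials encoded by `ones()`.
    fn main() {
        let (a, b, eps) = (2.0_f64, 4.0_f64, 1e-6);
        let f = |a: f64, b: f64| a + b;
        let df_da = (f(a + eps, b) - f(a, b)) / eps; // ≈ 1.0 → partial_left
        let df_db = (f(a, b + eps) - f(a, b)) / eps; // ≈ 1.0 → partial_right
        assert!((df_da - 1.0).abs() < 1e-4);
        assert!((df_db - 1.0).abs() < 1e-4);
    }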

View File

@@ -77,35 +77,39 @@ where
 #[cfg(test)]
 mod tests {
-    use super::*;
-    use crate::{
-        backend::tch::TchTensor,
-        tape::{Tape, TapeRef},
-        Data, TensorBase,
-    };
-    use std::cell::RefCell;
+    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};

     #[test]
     fn should_diff_mul() {
-        let tape = TapeRef::new(RefCell::new(Tape::new()));
+        let tape = Tape::new_ref();
         let data_1 = Data::from([1.0]);
         let data_2 = Data::from([4.0]);

-        let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
-        let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
-
-        let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
-        let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
-
-        let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
-        let data_ad_3 = tensor_ad_3.tensor().into_data();
-        assert_eq!(data_ad_3, Data::from([4.0]));
-
-        tensor_ad_3.backprob();
-
-        let grad_1 = tensor_ad_1.grad();
-        let grad_2 = tensor_ad_2.grad();
+        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
+        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
+
+        let tensor_3 = tensor_1.clone() * tensor_2.clone();
+        tensor_3.backprob();
+
+        let grad_1 = tensor_1.grad();
+        let grad_2 = tensor_2.grad();

         assert_eq!(grad_1.into_data(), data_2);
         assert_eq!(grad_2.into_data(), data_1);
+        assert_eq!(tensor_3.into_data(), Data::from([4.0]));
+    }
+
+    #[test]
+    fn should_diff_mul_scalar() {
+        let tape = Tape::new_ref();
+        let data = Data::from([2.0]);
+
+        let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
+        let tensor_out = tensor.clone() * 4.0;
+
+        tensor_out.backprob();
+        let grad = tensor.grad();
+
+        assert_eq!(tensor_out.into_data(), Data::from([8.0]));
+        assert_eq!(grad.into_data(), Data::from([4.0]));
     }
 }
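The expected gradients in both tests are the product rule worked out on the test data:

    z = x * y              =>  ∂z/∂x = y = 4.0 (grad_1),  ∂z/∂y = x = 1.0 (grad_2)
    out = x * 4.0, x = 2.0  =>  out = 8.0,  d(out)/dx = 4.0 (grad)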

View File

@@ -7,7 +7,7 @@ use crate::{
 };
 use num_traits::Float;

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ADTensor<P, const D: usize, T> {
     pub node: NodeRef<T>,
     pub shape: Shape<D>,
@@ -94,3 +94,21 @@ impl<P: Float + Default> ADKind<P> {
         Self { _p: P::default() }
     }
 }
+
+#[cfg(test)]
+pub mod helper {
+    use super::*;
+    use crate::{
+        backend::{autodiff::ADFloat, tch::TchTensor},
+        Data,
+    };
+
+    pub type ADTchTensor<P, const D: usize> = ADTensor<P, D, TchTensor<P, D>>;
+
+    impl<P: ADFloat + tch::kind::Element + Into<f64>, const D: usize> ADTchTensor<P, D> {
+        pub fn from_data(data: Data<P, D>, tape: TapeRef) -> Self {
+            let tensor = TchTensor::from_data(data, tch::Device::Cpu);
+            ADTensor::from_tensor(tensor, tape)
+        }
+    }
+}
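The helper module exists so the op tests above can build a tch-backed autodiff tensor in one line, and the new `Clone` derive is what lets those tests write `tensor_1.clone() * tensor_2.clone()`. A usage sketch mirroring those tests (assumes the crate paths shown in the imports above; the expected gradient comment is an inference from the mul-scalar test):

    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data};

    #[test]
    fn sketch() {
        let tape = Tape::new_ref();
        let x = ADTchTensor::from_data(Data::from([3.0]), tape.clone());
        let y = x.clone() * 2.0;
        y.backprob();
        let grad = x.grad(); // expected: [2.0], the scalar factor
    }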