fix: single recorded ops

This commit is contained in:
nathaniel 2022-07-18 20:03:30 -04:00
parent 6ee628a748
commit f241a6c114
6 changed files with 79 additions and 28 deletions

View File

@ -15,6 +15,9 @@ impl NodeId {
value: nanoid::nanoid!(),
}
}
pub fn to_string(&self) -> String {
self.value.to_string()
}
}
pub trait Node<Out>: std::fmt::Debug {

View File

@ -69,7 +69,7 @@ where
Ops: SingleOps<In, Out>,
{
fn id(&self) -> NodeId {
self.input.borrow().id()
self.out.borrow().id()
}
fn backward(&mut self) {

View File

@ -14,6 +14,10 @@ impl Tape {
}
}
pub fn new_ref() -> TapeRef {
Rc::new(RefCell::new(Self::new()))
}
pub fn backward(&mut self, from: NodeId) {
let mut init = false;
@ -21,17 +25,14 @@ impl Tape {
if init {
ops.backward();
} else if ops.id() == from {
ops.set_last_ops();
init = true;
ops.set_last_ops();
ops.backward();
}
}
}
pub fn add(&mut self, ops: RecordedOpsRef) {
println!("---");
println!("Adding ops {:?}", ops);
println!("---");
self.operations.push(ops)
}
}

View File

@ -13,15 +13,15 @@ register_ops!(
ops BinaryOps<T, T, T>,
name ADTensorAddOps,
forward |left, right| left * right,
partial_left |state: &BinaryRecordedState<T, T, T>| state.right.clone(),
partial_right |state: &BinaryRecordedState<T, T, T>| state.left.clone(),
partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones(),
);
register_ops!(
ops SingleOps<T, T>,
name ADTensorAddScalarOps state P,
forward |state, input| input * state,
partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state,
partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
);
impl<T, P, const D: usize> TensorOpsAdd<P, D> for ADTensor<P, D, T>
@ -74,3 +74,28 @@ where
TensorOpsAdd::add(&self, &rhs)
}
}
#[cfg(test)]
mod tests {
use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
#[test]
fn should_diff_add() {
let tape = Tape::new_ref();
let data_1 = Data::from([2.0]);
let data_2 = Data::from([4.0]);
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
let tensor_3 = tensor_1.clone() + tensor_2.clone();
tensor_3.backprob();
let grad_1 = tensor_1.grad();
let grad_2 = tensor_2.grad();
assert_eq!(grad_1.into_data(), Data::from([1.0]));
assert_eq!(grad_2.into_data(), Data::from([1.0]));
assert_eq!(tensor_3.into_data(), Data::from([6.0]));
}
}

View File

@ -77,35 +77,39 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{
backend::tch::TchTensor,
tape::{Tape, TapeRef},
Data, TensorBase,
};
use std::cell::RefCell;
use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
#[test]
fn should_diff_mul() {
let tape = TapeRef::new(RefCell::new(Tape::new()));
let tape = Tape::new_ref();
let data_1 = Data::from([1.0]);
let data_2 = Data::from([4.0]);
let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
let tensor_3 = tensor_1.clone() * tensor_2.clone();
tensor_3.backprob();
let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
let data_ad_3 = tensor_ad_3.tensor().into_data();
assert_eq!(data_ad_3, Data::from([4.0]));
tensor_ad_3.backprob();
let grad_1 = tensor_ad_1.grad();
let grad_2 = tensor_ad_2.grad();
let grad_1 = tensor_1.grad();
let grad_2 = tensor_2.grad();
assert_eq!(grad_1.into_data(), data_2);
assert_eq!(grad_2.into_data(), data_1);
assert_eq!(tensor_3.into_data(), Data::from([4.0]));
}
#[test]
fn should_diff_mul_scalar() {
let tape = Tape::new_ref();
let data = Data::from([2.0]);
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
let tensor_out = tensor.clone() * 4.0;
tensor_out.backprob();
let grad = tensor.grad();
assert_eq!(tensor_out.into_data(), Data::from([8.0]));
assert_eq!(grad.into_data(), Data::from([4.0]));
}
}

View File

@ -7,7 +7,7 @@ use crate::{
};
use num_traits::Float;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ADTensor<P, const D: usize, T> {
pub node: NodeRef<T>,
pub shape: Shape<D>,
@ -94,3 +94,21 @@ impl<P: Float + Default> ADKind<P> {
Self { _p: P::default() }
}
}
#[cfg(test)]
pub mod helper {
use super::*;
use crate::{
backend::{autodiff::ADFloat, tch::TchTensor},
Data,
};
pub type ADTchTensor<P, const D: usize> = ADTensor<P, D, TchTensor<P, D>>;
impl<P: ADFloat + tch::kind::Element + Into<f64>, const D: usize> ADTchTensor<P, D> {
pub fn from_data(data: Data<P, D>, tape: TapeRef) -> Self {
let tensor = TchTensor::from_data(data, tch::Device::Cpu);
ADTensor::from_tensor(tensor, tape)
}
}
}