mirror of https://github.com/tracel-ai/burn.git
fix: single recorded ops
This commit is contained in:
parent 6ee628a748
commit f241a6c114
@@ -15,6 +15,9 @@ impl NodeId {
             value: nanoid::nanoid!(),
         }
     }
+    pub fn to_string(&self) -> String {
+        self.value.to_string()
+    }
 }
 
 pub trait Node<Out>: std::fmt::Debug {
@@ -69,7 +69,7 @@ where
     Ops: SingleOps<In, Out>,
 {
     fn id(&self) -> NodeId {
-        self.input.borrow().id()
+        self.out.borrow().id()
     }
 
     fn backward(&mut self) {
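The one-line change above is the fix named in the commit title: a recorded single op previously answered `id()` with its *input* node's id, so a backward pass starting from that op's output could never match it on the tape. Reporting the output node's id restores the lookup that `Tape::backward` depends on. A minimal sketch of that invariant, using hypothetical stand-in types rather than the crate's real `NodeRef` machinery:

    // Hypothetical stand-ins for illustration only.
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    struct NodeId(u64);

    struct Node {
        id: NodeId,
    }

    struct SingleRecordedOps {
        input: Node,
        out: Node,
    }

    impl SingleRecordedOps {
        // The id a recorded op exposes must be the node it *produced*:
        // the tape searches for the op whose output matches the node
        // the backward pass starts from.
        fn id(&self) -> NodeId {
            self.out.id // before the fix: self.input.id
        }
    }

    fn main() {
        let ops = SingleRecordedOps {
            input: Node { id: NodeId(0) },
            out: Node { id: NodeId(1) },
        };
        // A backward pass starting from the op's output can now find it.
        assert_eq!(ops.id(), NodeId(1));
    }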
@@ -14,6 +14,10 @@ impl Tape {
         }
     }
 
+    pub fn new_ref() -> TapeRef {
+        Rc::new(RefCell::new(Self::new()))
+    }
+
     pub fn backward(&mut self, from: NodeId) {
         let mut init = false;
 
@@ -21,17 +25,14 @@ impl Tape {
             if init {
                 ops.backward();
             } else if ops.id() == from {
-                ops.set_last_ops();
                 init = true;
+                ops.set_last_ops();
+                ops.backward();
             }
         }
     }
 
     pub fn add(&mut self, ops: RecordedOpsRef) {
-        println!("---");
-        println!("Adding ops {:?}", ops);
-        println!("---");
         self.operations.push(ops)
     }
 }
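The control-flow change here is the substance of the commit: in the old code, the op matching `from` was seeded via `set_last_ops()` but its own `backward()` was never invoked (only ops visited *after* the match ran it), so a tape containing a single recorded op propagated nothing. The new branch seeds the starting op and immediately propagates through it. A self-contained sketch of that traversal, assuming the loop walks the tape from the most recently recorded op backwards (the hunk does not show the loop header itself):

    use std::{cell::RefCell, rc::Rc};

    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    struct NodeId(u64);

    trait RecordedOps {
        fn id(&self) -> NodeId;
        fn set_last_ops(&mut self);
        fn backward(&mut self);
    }

    type RecordedOpsRef = Box<dyn RecordedOps>;
    type TapeRef = Rc<RefCell<Tape>>;

    #[derive(Default)]
    struct Tape {
        operations: Vec<RecordedOpsRef>,
    }

    impl Tape {
        fn new_ref() -> TapeRef {
            Rc::new(RefCell::new(Tape::default()))
        }

        fn add(&mut self, ops: RecordedOpsRef) {
            self.operations.push(ops)
        }

        fn backward(&mut self, from: NodeId) {
            let mut init = false;
            // Assumed reverse walk: most recently recorded op first.
            for ops in self.operations.iter_mut().rev() {
                if init {
                    ops.backward();
                } else if ops.id() == from {
                    init = true;
                    ops.set_last_ops(); // seed the output gradient
                    ops.backward();     // the fix: propagate through the start op too
                }
            }
        }
    }

    struct DummyOps {
        node: NodeId,
    }

    impl RecordedOps for DummyOps {
        fn id(&self) -> NodeId { self.node }
        fn set_last_ops(&mut self) { println!("seed {:?}", self.node); }
        fn backward(&mut self) { println!("backward {:?}", self.node); }
    }

    fn main() {
        let tape = Tape::new_ref();
        tape.borrow_mut().add(Box::new(DummyOps { node: NodeId(0) }));
        tape.borrow_mut().add(Box::new(DummyOps { node: NodeId(1) }));
        tape.borrow_mut().backward(NodeId(1)); // seeds op 1, then visits op 0
    }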
@@ -13,15 +13,15 @@ register_ops!(
     ops BinaryOps<T, T, T>,
     name ADTensorAddOps,
     forward |left, right| left * right,
-    partial_left |state: &BinaryRecordedState<T, T, T>| state.right.clone(),
-    partial_right |state: &BinaryRecordedState<T, T, T>| state.left.clone(),
+    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
+    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones(),
 );
 
 register_ops!(
     ops SingleOps<T, T>,
     name ADTensorAddScalarOps state P,
     forward |state, input| input * state,
-    partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state,
+    partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
 );
 
 impl<T, P, const D: usize> TensorOpsAdd<P, D> for ADTensor<P, D, T>
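The new partials follow directly from the derivative of addition: ∂(a + b)/∂a = 1 and ∂(a + b)/∂b = 1, so each partial is a tensor of ones in the operand's shape, which is exactly what the switch from `clone()` to `ones()` encodes (the old `clone()` partials were multiplication's rule, ∂(a · b)/∂a = b, apparently left over from the mul ops). Likewise for the scalar variant, ∂(x + c)/∂x = 1, hence `ones()` with the state argument unused. A scalar sketch of the same rule, with hypothetical plain-slice tensors standing in for the crate's:

    // d(a + b)/da evaluated elementwise: ones, independent of the other operand.
    fn add_partial_left(left: &[f64]) -> Vec<f64> {
        left.iter().map(|_| 1.0).collect()
    }

    fn main() {
        assert_eq!(add_partial_left(&[2.0, -3.0]), vec![1.0, 1.0]);
    }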
@@ -74,3 +74,28 @@ where
         TensorOpsAdd::add(&self, &rhs)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
+
+    #[test]
+    fn should_diff_add() {
+        let tape = Tape::new_ref();
+        let data_1 = Data::from([2.0]);
+        let data_2 = Data::from([4.0]);
+
+        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
+        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
+
+        let tensor_3 = tensor_1.clone() + tensor_2.clone();
+        tensor_3.backprob();
+
+        let grad_1 = tensor_1.grad();
+        let grad_2 = tensor_2.grad();
+
+        assert_eq!(grad_1.into_data(), Data::from([1.0]));
+        assert_eq!(grad_2.into_data(), Data::from([1.0]));
+        assert_eq!(tensor_3.into_data(), Data::from([6.0]));
+    }
+}
@@ -77,35 +77,39 @@ where
 
 #[cfg(test)]
 mod tests {
-    use super::*;
-    use crate::{
-        backend::tch::TchTensor,
-        tape::{Tape, TapeRef},
-        Data, TensorBase,
-    };
-    use std::cell::RefCell;
+    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
 
     #[test]
     fn should_diff_mul() {
-        let tape = TapeRef::new(RefCell::new(Tape::new()));
+        let tape = Tape::new_ref();
         let data_1 = Data::from([1.0]);
         let data_2 = Data::from([4.0]);
 
-        let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
-        let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
+        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
+        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
 
-        let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
-        let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
+        let tensor_3 = tensor_1.clone() * tensor_2.clone();
+        tensor_3.backprob();
 
-        let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
-        let data_ad_3 = tensor_ad_3.tensor().into_data();
-        assert_eq!(data_ad_3, Data::from([4.0]));
-
-        tensor_ad_3.backprob();
-        let grad_1 = tensor_ad_1.grad();
-        let grad_2 = tensor_ad_2.grad();
+        let grad_1 = tensor_1.grad();
+        let grad_2 = tensor_2.grad();
 
         assert_eq!(grad_1.into_data(), data_2);
         assert_eq!(grad_2.into_data(), data_1);
+        assert_eq!(tensor_3.into_data(), Data::from([4.0]));
     }
+
+    #[test]
+    fn should_diff_mul_scalar() {
+        let tape = Tape::new_ref();
+        let data = Data::from([2.0]);
+
+        let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
+        let tensor_out = tensor.clone() * 4.0;
+        tensor_out.backprob();
+
+        let grad = tensor.grad();
+        assert_eq!(tensor_out.into_data(), Data::from([8.0]));
+        assert_eq!(grad.into_data(), Data::from([4.0]));
+    }
 }
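These tests pin down the product rule: for y = a · b, ∂y/∂a = b and ∂y/∂b = a, so each input's gradient equals the other operand's data (grad_1 equals data_2 and vice versa), and for the scalar case y = 4x, dy/dx = 4, giving the [4.0] gradient asserted above. A worked check of the same numbers in plain Rust, independent of the crate's tensors:

    fn main() {
        let (a, b) = (1.0_f64, 4.0_f64);
        let y = a * b;
        let (grad_a, grad_b) = (b, a); // product rule: dy/da = b, dy/db = a
        assert_eq!((y, grad_a, grad_b), (4.0, 4.0, 1.0));

        let x = 2.0_f64;
        let y_scalar = x * 4.0;
        let grad_x = 4.0; // d(4x)/dx = 4
        assert_eq!((y_scalar, grad_x), (8.0, 4.0));
    }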
@@ -7,7 +7,7 @@ use crate::{
 };
 use num_traits::Float;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ADTensor<P, const D: usize, T> {
     pub node: NodeRef<T>,
     pub shape: Shape<D>,
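Deriving Clone on ADTensor is what lets the new tests write `tensor_1.clone() + tensor_2.clone()`: the operator impls consume their operands, and a test that later calls `.grad()` on an input needs to keep a handle to it. Since the tensor holds its graph node behind a `NodeRef<T>` (presumably a reference-counted handle — an assumption about the crate's internals), the clone would share the underlying node rather than duplicate it. A sketch under that assumption:

    use std::{cell::RefCell, rc::Rc};

    // Assumed shape of NodeRef: a shared, mutable handle to the graph node.
    type NodeRef<T> = Rc<RefCell<T>>;

    #[derive(Debug, Clone)]
    struct ADTensorSketch<T> {
        node: NodeRef<T>, // Rc clone bumps the refcount; same node
    }

    fn main() {
        let t = ADTensorSketch { node: Rc::new(RefCell::new(42)) };
        let u = t.clone();
        assert!(Rc::ptr_eq(&t.node, &u.node)); // both views of one node
    }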
@@ -94,3 +94,21 @@ impl<P: Float + Default> ADKind<P> {
         Self { _p: P::default() }
     }
 }
+
+#[cfg(test)]
+pub mod helper {
+    use super::*;
+    use crate::{
+        backend::{autodiff::ADFloat, tch::TchTensor},
+        Data,
+    };
+
+    pub type ADTchTensor<P, const D: usize> = ADTensor<P, D, TchTensor<P, D>>;
+
+    impl<P: ADFloat + tch::kind::Element + Into<f64>, const D: usize> ADTchTensor<P, D> {
+        pub fn from_data(data: Data<P, D>, tape: TapeRef) -> Self {
+            let tensor = TchTensor::from_data(data, tch::Device::Cpu);
+            ADTensor::from_tensor(tensor, tape)
+        }
+    }
+}