feat: support sub autograd
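
Add a `TensorOpsSub` trait (tensor - tensor and tensor - scalar) to the
tensor ops, implement it for the tch backend, and register the autodiff
ops with partials d(l - r)/dl = 1 and d(l - r)/dr = -1. The existing
add/mul autodiff tests are extended from single-element to two-element
data.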

nathaniel 2022-07-18 20:23:21 -04:00
parent f241a6c114
commit 5ce657ded9
7 changed files with 224 additions and 12 deletions

View File

@@ -82,8 +82,8 @@ mod tests {
     #[test]
     fn should_diff_add() {
         let tape = Tape::new_ref();
-        let data_1 = Data::from([2.0]);
-        let data_2 = Data::from([4.0]);
+        let data_1 = Data::from([2.0, 5.0]);
+        let data_2 = Data::from([4.0, 1.0]);
         let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
         let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
@@ -94,8 +94,22 @@ mod tests {
         let grad_1 = tensor_1.grad();
         let grad_2 = tensor_2.grad();
-        assert_eq!(grad_1.into_data(), Data::from([1.0]));
-        assert_eq!(grad_2.into_data(), Data::from([1.0]));
-        assert_eq!(tensor_3.into_data(), Data::from([6.0]));
+        assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
+        assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
+        assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
     }
+
+    #[test]
+    fn should_diff_add_scalar() {
+        let tape = Tape::new_ref();
+        let data = Data::from([2.0, 10.0]);
+        let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
+
+        let tensor_out = tensor.clone() + 5.0;
+        tensor_out.backprob();
+
+        let grad = tensor.grad();
+        assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
+        assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
+    }
 }

View File

@@ -1,5 +1,6 @@
 mod add;
 mod mul;
+mod sub;
 mod macros;

 pub use macros::*;

View File

@@ -82,8 +82,8 @@ mod tests {
     #[test]
     fn should_diff_mul() {
         let tape = Tape::new_ref();
-        let data_1 = Data::from([1.0]);
-        let data_2 = Data::from([4.0]);
+        let data_1 = Data::from([1.0, 7.0]);
+        let data_2 = Data::from([4.0, 7.0]);
         let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
         let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
@@ -96,20 +96,20 @@ mod tests {
         assert_eq!(grad_1.into_data(), data_2);
         assert_eq!(grad_2.into_data(), data_1);
-        assert_eq!(tensor_3.into_data(), Data::from([4.0]));
+        assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
     }

     #[test]
     fn should_diff_mul_scalar() {
         let tape = Tape::new_ref();
-        let data = Data::from([2.0]);
+        let data = Data::from([2.0, 5.0]);
         let tensor = ADTchTensor::from_data(data.clone(), tape.clone());

         let tensor_out = tensor.clone() * 4.0;
         tensor_out.backprob();

         let grad = tensor.grad();
-        assert_eq!(tensor_out.into_data(), Data::from([8.0]));
-        assert_eq!(grad.into_data(), Data::from([4.0]));
+        assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
+        assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
     }
 }

View File

@@ -0,0 +1,115 @@
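//! Autodiff support for subtraction.
//!
//! For `y = left - right`, the partials are `dy/dleft = 1` and
//! `dy/dright = -1`, which the `partial_left`/`partial_right`
//! closures below encode as `ones()` and `ones().neg()`.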
use crate::{
    backend::autodiff::{ADFloat, ADFloatTensor, ADTensor},
    define_ops, execute_ops,
    ops::{
        BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps,
        SingleRecordedState,
    },
    register_ops, TensorOpsSub,
};
use num_traits::Float;

register_ops!(
    ops BinaryOps<T, T, T>,
    name ADTensorSubOps,
    forward |left, right| left - right,
    // d(left - right)/d(left) = 1
    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
    // d(left - right)/d(right) = -1
    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones().neg(),
);

register_ops!(
    ops SingleOps<T, T>,
    name ADTensorSubScalarOps state P,
    forward |state, input| input - state,
    // d(input - scalar)/d(input) = 1
    partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
);
impl<T, P, const D: usize> TensorOpsSub<P, D> for ADTensor<P, D, T>
where
    T: ADFloatTensor<P, D>,
    P: ADFloat,
{
    fn sub(&self, other: &Self) -> Self {
        let node = execute_ops!(
            lhs self.node.clone(),
            rhs other.node.clone(),
            out TensorOpsSub::sub(&self.tensor(), &other.tensor()),
            tape self.tape.clone(),
            ops ADTensorSubOps::new(),
        );
        self.from_existing(node)
    }

    fn sub_scalar(&self, other: &P) -> Self {
        let node = execute_ops!(
            input self.node.clone(),
            out TensorOpsSub::sub_scalar(&self.tensor(), &other),
            tape self.tape.clone(),
            ops ADTensorSubScalarOps::new(other.clone()),
        );
        self.from_existing(node)
    }
}

impl<T, P, const D: usize> std::ops::Sub<P> for ADTensor<P, D, T>
where
    T: ADFloatTensor<P, D> + 'static,
    P: ADFloat + 'static,
{
    type Output = ADTensor<P, D, T>;

    fn sub(self, rhs: P) -> Self::Output {
        TensorOpsSub::sub_scalar(&self, &rhs)
    }
}

impl<T, P, const D: usize> std::ops::Sub<ADTensor<P, D, T>> for ADTensor<P, D, T>
where
    T: ADFloatTensor<P, D> + 'static,
    P: ADFloat + 'static,
{
    type Output = ADTensor<P, D, T>;

    fn sub(self, rhs: Self) -> Self::Output {
        TensorOpsSub::sub(&self, &rhs)
    }
}
#[cfg(test)]
mod tests {
    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};

    #[test]
    fn should_diff_sub() {
        let tape = Tape::new_ref();
        let data_1 = Data::from([2.0, 5.0]);
        let data_2 = Data::from([4.0, 1.0]);

        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());

        let tensor_3 = tensor_1.clone() - tensor_2.clone();
        tensor_3.backprob();

        let grad_1 = tensor_1.grad();
        let grad_2 = tensor_2.grad();
        assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
        assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
        assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
    }

    #[test]
    fn should_diff_sub_scalar() {
        let tape = Tape::new_ref();
        let data = Data::from([2.0, 10.0]);
        let tensor = ADTchTensor::from_data(data.clone(), tape.clone());

        let tensor_out = tensor.clone() - 5.0;
        tensor_out.backprob();

        let grad = tensor.grad();
        assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
        assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
    }
}

View File

@@ -4,3 +4,4 @@ mod matmul;
 mod mul;
 mod neg;
 mod reshape;
+mod sub;

View File

@@ -0,0 +1,71 @@
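//! Subtraction ops for the tch (LibTorch) backend, covering both
//! `tensor - tensor` and `tensor - scalar`.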
use crate::{backend::tch::TchTensor, Data, TensorOpsSub};
use std::ops::Sub;

impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorOpsSub<P, D>
    for TchTensor<P, D>
{
    fn sub(&self, other: &Self) -> Self {
        let tensor = (&self.tensor).sub(&other.tensor);
        let kind = self.kind.clone();
        let shape = self.shape.clone();

        Self {
            tensor,
            shape,
            kind,
        }
    }

    fn sub_scalar(&self, other: &P) -> Self {
        // Materialize the scalar as a small tensor and rely on
        // broadcasting when subtracting it from `self`.
        let elems: [P; D] = [*other; D];
        let data = Data::from(elems);
        let other = TchTensor::from_data(data, self.tensor.device());

        let tensor = (&self.tensor).sub(&other.tensor);
        let kind = self.kind.clone();
        let shape = self.shape.clone();

        Self {
            tensor,
            shape,
            kind,
        }
    }
}

impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<Self>
    for TchTensor<P, D>
{
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        TensorOpsSub::sub(&self, &rhs)
    }
}

impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<P>
    for TchTensor<P, D>
{
    type Output = Self;

    fn sub(self, rhs: P) -> Self::Output {
        TensorOpsSub::sub_scalar(&self, &rhs)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorBase;

    #[test]
    fn should_support_sub_ops() {
        let data_1 = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let data_2 = Data::<f64, 2>::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
        let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);
        let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu);
        let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu);

        let data_actual = (tensor_1 - tensor_2).into_data();

        assert_eq!(data_expected, data_actual);
    }
}

View File

@@ -11,6 +11,7 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
     + TensorOpsMul<P, D>
     + TensorOpsNeg<P, D>
     + TensorOpsAdd<P, D>
+    + TensorOpsSub<P, D>
     + TensorOpsMatmul<P, D>
     + std::fmt::Debug
 {
@@ -30,11 +31,20 @@ where
     fn add_scalar(&self, other: &P) -> Self;
 }

+pub trait TensorOpsSub<P, const D: usize>:
+    std::ops::Sub<Self, Output = Self> + std::ops::Sub<P, Output = Self>
+where
+    Self: Sized,
+{
+    fn sub(&self, other: &Self) -> Self;
+    fn sub_scalar(&self, other: &P) -> Self;
+}
+
 pub trait TensorOpsMatmul<P, const D: usize> {
     fn matmul(&self, other: &Self) -> Self;
 }

-pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg {
+pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg<Output = Self> {
     fn neg(&self) -> Self;
 }
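
For reference, a minimal self-contained sketch of the contract the new
trait imposes on a backend. The `ToyTensor` type is hypothetical and
only stands in for a real backend such as `TchTensor`; the trait body
is copied from the hunk above.

// Hypothetical toy backend illustrating the `TensorOpsSub` contract.
// The supertrait bounds require both `Sub<Self>` and `Sub<P>` operator
// impls, so `t1 - t2` and `t - 5.0` both compile, exactly as the
// autodiff tests use them.
pub trait TensorOpsSub<P, const D: usize>:
    std::ops::Sub<Self, Output = Self> + std::ops::Sub<P, Output = Self>
where
    Self: Sized,
{
    fn sub(&self, other: &Self) -> Self;
    fn sub_scalar(&self, other: &P) -> Self;
}

#[derive(Debug, Clone, PartialEq)]
pub struct ToyTensor<const D: usize> {
    pub values: Vec<f64>,
}

impl<const D: usize> TensorOpsSub<f64, D> for ToyTensor<D> {
    fn sub(&self, other: &Self) -> Self {
        // Element-wise subtraction.
        let values = self
            .values
            .iter()
            .zip(&other.values)
            .map(|(a, b)| a - b)
            .collect();
        Self { values }
    }

    fn sub_scalar(&self, other: &f64) -> Self {
        // Subtract the scalar from every element.
        let values = self.values.iter().map(|a| a - other).collect();
        Self { values }
    }
}

impl<const D: usize> std::ops::Sub for ToyTensor<D> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        TensorOpsSub::<f64, D>::sub(&self, &rhs)
    }
}

impl<const D: usize> std::ops::Sub<f64> for ToyTensor<D> {
    type Output = Self;

    fn sub(self, rhs: f64) -> Self {
        TensorOpsSub::<f64, D>::sub_scalar(&self, &rhs)
    }
}

#[test]
fn toy_sub_matches_the_real_tests() {
    let a = ToyTensor::<1> { values: vec![2.0, 5.0] };
    let b = ToyTensor::<1> { values: vec![4.0, 1.0] };
    assert_eq!((a.clone() - b).values, vec![-2.0, 4.0]);
    assert_eq!((a - 5.0).values, vec![-3.0, 0.0]);
}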