feat: implement matmul diff

This commit is contained in:
nathaniel 2022-07-18 21:55:39 -04:00
parent 5ce657ded9
commit 2084aced63
12 changed files with 102 additions and 37 deletions

View File

@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug { pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug {
fn forward(&self, left: Lhs, right: Rhs) -> Out;
fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs; fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs;
fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs; fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs;
} }
@ -80,13 +79,11 @@ where
} }
fn backward(&mut self) { fn backward(&mut self) {
let left = self.lhs.borrow().value(); let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out);
let right = self.rhs.borrow().value();
let output = self.out.borrow().value();
let state = BinaryRecordedState::new(&left, &right, &output);
let partial_left = self.ops.partial_left(&state); let partial_left = self.ops.partial_left(&state);
let partial_right: Rhs = self.ops.partial_right(&state); let partial_right: Rhs = self.ops.partial_right(&state);
let grad_mine = self.out.borrow_mut().grad(); let grad_mine = self.out.borrow_mut().grad();
self.lhs self.lhs

View File

@ -1,10 +1,10 @@
use crate::node::NodeId; use crate::node::{NodeId, NodeRef};
#[derive(new)] #[derive(new)]
pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> { pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> {
pub left: &'a Lhs, pub left: &'a NodeRef<Lhs>,
pub right: &'a Rhs, pub right: &'a NodeRef<Rhs>,
pub output: &'a Out, pub output: &'a NodeRef<Out>,
} }
#[derive(new)] #[derive(new)]

View File

@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
use std::ops::{Add, Mul}; use std::ops::{Add, Mul};
pub trait SingleOps<In, Out>: std::fmt::Debug { pub trait SingleOps<In, Out>: std::fmt::Debug {
fn forward(&self, input: In) -> Out;
fn partial(&self, state: &SingleRecordedState<In, Out>) -> In; fn partial(&self, state: &SingleRecordedState<In, Out>) -> In;
} }

View File

@ -12,15 +12,13 @@ use num_traits::Float;
register_ops!( register_ops!(
ops BinaryOps<T, T, T>, ops BinaryOps<T, T, T>,
name ADTensorAddOps, name ADTensorAddOps,
forward |left, right| left * right, partial_left |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().ones(),
partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(), partial_right |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().ones(),
partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones(),
); );
register_ops!( register_ops!(
ops SingleOps<T, T>, ops SingleOps<T, T>,
name ADTensorAddScalarOps state P, name ADTensorAddScalarOps state P,
forward |state, input| input * state,
partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(), partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
); );

View File

@ -42,7 +42,6 @@ macro_rules! register_ops {
( (
ops $ops:ty, ops $ops:ty,
name $name:ident, name $name:ident,
forward $forward:expr,
partial_left $partial_left:expr, partial_left $partial_left:expr,
partial_right $partial_right:expr, partial_right $partial_right:expr,
) => { ) => {
@ -55,10 +54,6 @@ macro_rules! register_ops {
P: $crate::tensor::backend::autodiff::ADFloat, P: $crate::tensor::backend::autodiff::ADFloat,
T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>, T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
{ {
fn forward(&self, left: T, right: T) -> T {
$forward(left, right)
}
fn partial_left(&self, state: &$crate::graph::ops::BinaryRecordedState<T, T, T>) -> T { fn partial_left(&self, state: &$crate::graph::ops::BinaryRecordedState<T, T, T>) -> T {
$partial_left(state) $partial_left(state)
} }
@ -71,7 +66,6 @@ macro_rules! register_ops {
( (
ops $ops:ty, ops $ops:ty,
name $name:ident state $ops_tensor_state:ident, name $name:ident state $ops_tensor_state:ident,
forward $forward:expr,
partial $partial:expr, partial $partial:expr,
) => { ) => {
define_ops!( define_ops!(
@ -84,10 +78,6 @@ macro_rules! register_ops {
P: $crate::tensor::backend::autodiff::ADFloat, P: $crate::tensor::backend::autodiff::ADFloat,
T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>, T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
{ {
fn forward(&self, input: T) -> T {
$forward(self.state, input)
}
fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T { fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T {
$partial(self.state, state) $partial(self.state, state)
} }
@ -96,7 +86,6 @@ macro_rules! register_ops {
( (
ops $ops:ty, ops $ops:ty,
name $name:ident, name $name:ident,
forward $forward:expr,
partial $partial:expr, partial $partial:expr,
) => { ) => {
define_ops!( define_ops!(
@ -108,10 +97,6 @@ macro_rules! register_ops {
P: $crate::tensor::backend::autodiff::ADFloat, P: $crate::tensor::backend::autodiff::ADFloat,
T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>, T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
{ {
fn forward(&self, input: T) -> T {
$forward(input)
}
fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T { fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T {
$partial(state) $partial(state)
} }

View File

@ -0,0 +1,68 @@
use crate::execute_ops;
use crate::{
backend::autodiff::{ADFloat, ADFloatTensor, ADTensor},
ops::{BinaryOps, BinaryRecordedOps, BinaryRecordedState},
register_ops, TensorOpsMatmul,
};
use num_traits::Float;
register_ops!(
ops BinaryOps<T, T, T>,
name ADTensorMatmulOps,
partial_left |state: &BinaryRecordedState<T, T, T>| {
let out_grad = state.output.borrow_mut().grad();
let rhs = state.right.borrow().value().transpose();
out_grad.matmul(&rhs)
},
partial_right |state: &BinaryRecordedState<T, T, T>| {
let out_grad = state.output.borrow_mut().grad();
let lhs = state.left.borrow().value().transpose();
lhs.matmul(&out_grad)
},
);
impl<T, P, const D: usize> TensorOpsMatmul<P, D> for ADTensor<P, D, T>
where
T: ADFloatTensor<P, D>,
P: ADFloat,
{
fn matmul(&self, other: &Self) -> Self {
let node = execute_ops!(
lhs self.node.clone(),
rhs other.node.clone(),
out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()),
tape self.tape.clone(),
ops ADTensorMatmulOps::new(),
);
self.from_existing(node)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
#[test]
fn should_diff_mul() {
let tape = Tape::new_ref();
let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
let tensor_3 = &tensor_1.matmul(&tensor_2);
tensor_3.backprob();
let grad_1 = tensor_1.grad();
let grad_2 = tensor_2.grad();
assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
assert_eq!(
tensor_3.clone().into_data(),
Data::from([[18.0, 28.0], [14.0, 23.0]])
);
}
}

View File

@ -1,4 +1,5 @@
mod add; mod add;
mod matmul;
mod mul; mod mul;
mod sub; mod sub;

View File

@ -12,15 +12,13 @@ use num_traits::Float;
register_ops!( register_ops!(
ops BinaryOps<T, T, T>, ops BinaryOps<T, T, T>,
name ADTensorMulOps, name ADTensorMulOps,
forward |left, right| left * right, partial_left |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().clone(),
partial_left |state: &BinaryRecordedState<T, T, T>| state.right.clone(), partial_right |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().clone(),
partial_right |state: &BinaryRecordedState<T, T, T>| state.left.clone(),
); );
register_ops!( register_ops!(
ops SingleOps<T, T>, ops SingleOps<T, T>,
name ADTensorMulScalarOps state P, name ADTensorMulScalarOps state P,
forward |state, input| input * state,
partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state, partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state,
); );

View File

@ -12,15 +12,13 @@ use num_traits::Float;
register_ops!( register_ops!(
ops BinaryOps<T, T, T>, ops BinaryOps<T, T, T>,
name ADTensorSubOps, name ADTensorSubOps,
forward |left, right| left * right, partial_left |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().ones(),
partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(), partial_right |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().ones().neg(),
partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones().neg(),
); );
register_ops!( register_ops!(
ops SingleOps<T, T>, ops SingleOps<T, T>,
name ADTensorSubScalarOps state P, name ADTensorSubScalarOps state P,
forward |state, input| input * state,
partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(), partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
); );

View File

@ -5,3 +5,4 @@ mod mul;
mod neg; mod neg;
mod reshape; mod reshape;
mod sub; mod sub;
mod transpose;

View File

@ -0,0 +1,15 @@
use crate::{backend::tch::TchTensor, Shape, TensorOpsTranspose};
impl<P: tch::kind::Element, const D: usize> TensorOpsTranspose<P, D> for TchTensor<P, D> {
fn transpose(&self) -> Self {
let tensor = self.tensor.transpose(-2, -1);
let kind = self.kind.clone();
let shape = Shape::from(tensor.size());
Self {
kind,
tensor,
shape,
}
}
}

View File

@ -13,6 +13,7 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
+ TensorOpsAdd<P, D> + TensorOpsAdd<P, D>
+ TensorOpsSub<P, D> + TensorOpsSub<P, D>
+ TensorOpsMatmul<P, D> + TensorOpsMatmul<P, D>
+ TensorOpsTranspose<P, D>
+ std::fmt::Debug + std::fmt::Debug
{ {
} }
@ -40,6 +41,10 @@ where
fn sub_scalar(&self, other: &P) -> Self; fn sub_scalar(&self, other: &P) -> Self;
} }
pub trait TensorOpsTranspose<P, const D: usize> {
fn transpose(&self) -> Self;
}
pub trait TensorOpsMatmul<P, const D: usize> { pub trait TensorOpsMatmul<P, const D: usize> {
fn matmul(&self, other: &Self) -> Self; fn matmul(&self, other: &Self) -> Self;
} }