diff --git a/burn-tensor/src/graph/ops/binary.rs b/burn-tensor/src/graph/ops/binary.rs index 3ec06fd3e..ef16a8921 100644 --- a/burn-tensor/src/graph/ops/binary.rs +++ b/burn-tensor/src/graph/ops/binary.rs @@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros}; use std::ops::{Add, Mul}; pub trait BinaryOps: std::fmt::Debug { - fn forward(&self, left: Lhs, right: Rhs) -> Out; fn partial_left(&self, state: &BinaryRecordedState) -> Lhs; fn partial_right(&self, state: &BinaryRecordedState) -> Rhs; } @@ -80,13 +79,11 @@ where } fn backward(&mut self) { - let left = self.lhs.borrow().value(); - let right = self.rhs.borrow().value(); - let output = self.out.borrow().value(); - let state = BinaryRecordedState::new(&left, &right, &output); + let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out); let partial_left = self.ops.partial_left(&state); let partial_right: Rhs = self.ops.partial_right(&state); + let grad_mine = self.out.borrow_mut().grad(); self.lhs diff --git a/burn-tensor/src/graph/ops/ops.rs b/burn-tensor/src/graph/ops/ops.rs index 7b1a5c047..a8e150cb6 100644 --- a/burn-tensor/src/graph/ops/ops.rs +++ b/burn-tensor/src/graph/ops/ops.rs @@ -1,10 +1,10 @@ -use crate::node::NodeId; +use crate::node::{NodeId, NodeRef}; #[derive(new)] pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> { - pub left: &'a Lhs, - pub right: &'a Rhs, - pub output: &'a Out, + pub left: &'a NodeRef, + pub right: &'a NodeRef, + pub output: &'a NodeRef, } #[derive(new)] diff --git a/burn-tensor/src/graph/ops/single.rs b/burn-tensor/src/graph/ops/single.rs index 90ce04a50..d82fb537e 100644 --- a/burn-tensor/src/graph/ops/single.rs +++ b/burn-tensor/src/graph/ops/single.rs @@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros}; use std::ops::{Add, Mul}; pub trait SingleOps: std::fmt::Debug { - fn forward(&self, input: In) -> Out; fn partial(&self, state: &SingleRecordedState) -> In; } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs 
b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs index 30e0f353b..b26ac8f44 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs @@ -12,15 +12,13 @@ use num_traits::Float; register_ops!( ops BinaryOps, name ADTensorAddOps, - forward |left, right| left * right, - partial_left |state: &BinaryRecordedState| state.left.ones(), - partial_right |state: &BinaryRecordedState| state.right.ones(), + partial_left |state: &BinaryRecordedState| state.left.borrow().value().ones(), + partial_right |state: &BinaryRecordedState| state.right.borrow().value().ones(), ); register_ops!( ops SingleOps, name ADTensorAddScalarOps state P, - forward |state, input| input * state, partial |_state, state_recorded: &SingleRecordedState| state_recorded.input.ones(), ); diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs index 7cddd365d..f4de74a9a 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs @@ -42,7 +42,6 @@ macro_rules! register_ops { ( ops $ops:ty, name $name:ident, - forward $forward:expr, partial_left $partial_left:expr, partial_right $partial_right:expr, ) => { @@ -55,10 +54,6 @@ macro_rules! register_ops { P: $crate::tensor::backend::autodiff::ADFloat, T: $crate::tensor::backend::autodiff::ADFloatTensor, { - fn forward(&self, left: T, right: T) -> T { - $forward(left, right) - } - fn partial_left(&self, state: &$crate::graph::ops::BinaryRecordedState) -> T { $partial_left(state) } @@ -71,7 +66,6 @@ macro_rules! register_ops { ( ops $ops:ty, name $name:ident state $ops_tensor_state:ident, - forward $forward:expr, partial $partial:expr, ) => { define_ops!( @@ -84,10 +78,6 @@ macro_rules! 
register_ops { P: $crate::tensor::backend::autodiff::ADFloat, T: $crate::tensor::backend::autodiff::ADFloatTensor, { - fn forward(&self, input: T) -> T { - $forward(self.state, input) - } - fn partial(&self, state: &$crate::graph::ops::SingleRecordedState) -> T { $partial(self.state, state) } @@ -96,7 +86,6 @@ macro_rules! register_ops { ( ops $ops:ty, name $name:ident, - forward $forward:expr, partial $partial:expr, ) => { define_ops!( @@ -108,10 +97,6 @@ macro_rules! register_ops { P: $crate::tensor::backend::autodiff::ADFloat, T: $crate::tensor::backend::autodiff::ADFloatTensor, { - fn forward(&self, input: T) -> T { - $forward(input) - } - fn partial(&self, state: &$crate::graph::ops::SingleRecordedState) -> T { $partial(state) } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs new file mode 100644 index 000000000..43e399e5f --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs @@ -0,0 +1,68 @@ +use crate::execute_ops; +use crate::{ + backend::autodiff::{ADFloat, ADFloatTensor, ADTensor}, + ops::{BinaryOps, BinaryRecordedOps, BinaryRecordedState}, + register_ops, TensorOpsMatmul, +}; +use num_traits::Float; + +register_ops!( + ops BinaryOps, + name ADTensorMatmulOps, + partial_left |state: &BinaryRecordedState| { + let out_grad = state.output.borrow_mut().grad(); + let rhs = state.right.borrow().value().transpose(); + out_grad.matmul(&rhs) + }, + partial_right |state: &BinaryRecordedState| { + let out_grad = state.output.borrow_mut().grad(); + let lhs = state.left.borrow().value().transpose(); + lhs.matmul(&out_grad) + }, +); + +impl TensorOpsMatmul for ADTensor +where + T: ADFloatTensor, + P: ADFloat, +{ + fn matmul(&self, other: &Self) -> Self { + let node = execute_ops!( + lhs self.node.clone(), + rhs other.node.clone(), + out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()), + tape self.tape.clone(), + ops ADTensorMatmulOps::new(), + ); + 
self.from_existing(node) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase}; + + #[test] + fn should_diff_matmul() { + let tape = Tape::new_ref(); + let data_1: Data = Data::from([[1.0, 7.0], [2.0, 3.0]]); + let data_2: Data = Data::from([[4.0, 7.0], [2.0, 3.0]]); + + let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone()); + let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); + + let tensor_3 = &tensor_1.matmul(&tensor_2); + tensor_3.backprob(); + + let grad_1 = tensor_1.grad(); + let grad_2 = tensor_2.grad(); + + assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]])); + assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]])); + assert_eq!( + tensor_3.clone().into_data(), + Data::from([[18.0, 28.0], [14.0, 23.0]]) + ); + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index b8fb99f7a..fa67969aa 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -1,4 +1,5 @@ mod add; +mod matmul; mod mul; mod sub; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs index 0fa6fceef..6a8b7b620 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs @@ -12,15 +12,13 @@ use num_traits::Float; register_ops!( ops BinaryOps, name ADTensorMulOps, - forward |left, right| left * right, - partial_left |state: &BinaryRecordedState| state.right.clone(), - partial_right |state: &BinaryRecordedState| state.left.clone(), + partial_left |state: &BinaryRecordedState| state.right.borrow().value().clone(), + partial_right |state: &BinaryRecordedState| state.left.borrow().value().clone(), ); register_ops!( ops SingleOps, name ADTensorMulScalarOps state P, - forward |state, input| input * 
state, partial |state, state_recorded: &SingleRecordedState| state_recorded.input.ones() * state, ); diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs index ad226caa1..fb482c521 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs @@ -12,15 +12,13 @@ use num_traits::Float; register_ops!( ops BinaryOps, name ADTensorSubOps, - forward |left, right| left * right, - partial_left |state: &BinaryRecordedState| state.left.ones(), - partial_right |state: &BinaryRecordedState| state.right.ones().neg(), + partial_left |state: &BinaryRecordedState| state.left.borrow().value().ones(), + partial_right |state: &BinaryRecordedState| state.right.borrow().value().ones().neg(), ); register_ops!( ops SingleOps, name ADTensorSubScalarOps state P, - forward |state, input| input * state, partial |_state, state_recorded: &SingleRecordedState| state_recorded.input.ones(), ); diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index cd516f69e..07495a134 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -5,3 +5,4 @@ mod mul; mod neg; mod reshape; mod sub; +mod transpose; diff --git a/burn-tensor/src/tensor/backend/tch/ops/transpose.rs b/burn-tensor/src/tensor/backend/tch/ops/transpose.rs new file mode 100644 index 000000000..01028a85e --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/transpose.rs @@ -0,0 +1,15 @@ +use crate::{backend::tch::TchTensor, Shape, TensorOpsTranspose}; + +impl TensorOpsTranspose for TchTensor { + fn transpose(&self) -> Self { + let tensor = self.tensor.transpose(-2, -1); + let kind = self.kind.clone(); + let shape = Shape::from(tensor.size()); + + Self { + kind, + tensor, + shape, + } + } +} diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs index 764948fb6..8bd55e6a9 100644 
--- a/burn-tensor/src/tensor/tensor.rs +++ b/burn-tensor/src/tensor/tensor.rs @@ -13,6 +13,7 @@ pub trait FloatTensor: + TensorOpsAdd + TensorOpsSub + TensorOpsMatmul + + TensorOpsTranspose + std::fmt::Debug { } @@ -40,6 +41,10 @@ where fn sub_scalar(&self, other: &P) -> Self; } +pub trait TensorOpsTranspose { + fn transpose(&self) -> Self; +} + pub trait TensorOpsMatmul { fn matmul(&self, other: &Self) -> Self; }