From 5ce657ded99f06c086c485206be473cbf8428a28 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 18 Jul 2022 20:23:21 -0400 Subject: [PATCH] feat: support sub autograd --- .../src/tensor/backend/autodiff/ops/add.rs | 24 +++- .../src/tensor/backend/autodiff/ops/mod.rs | 1 + .../src/tensor/backend/autodiff/ops/mul.rs | 12 +- .../src/tensor/backend/autodiff/ops/sub.rs | 115 ++++++++++++++++++ burn-tensor/src/tensor/backend/tch/ops/mod.rs | 1 + burn-tensor/src/tensor/backend/tch/ops/sub.rs | 71 +++++++++++ burn-tensor/src/tensor/tensor.rs | 12 +- 7 files changed, 224 insertions(+), 12 deletions(-) create mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/sub.rs create mode 100644 burn-tensor/src/tensor/backend/tch/ops/sub.rs diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs index 99a049015..30e0f353b 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs @@ -82,8 +82,8 @@ mod tests { #[test] fn should_diff_add() { let tape = Tape::new_ref(); - let data_1 = Data::from([2.0]); - let data_2 = Data::from([4.0]); + let data_1 = Data::from([2.0, 5.0]); + let data_2 = Data::from([4.0, 1.0]); let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone()); let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); @@ -94,8 +94,22 @@ mod tests { let grad_1 = tensor_1.grad(); let grad_2 = tensor_2.grad(); - assert_eq!(grad_1.into_data(), Data::from([1.0])); - assert_eq!(grad_2.into_data(), Data::from([1.0])); - assert_eq!(tensor_3.into_data(), Data::from([6.0])); + assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0])); + } + + #[test] + fn should_diff_add_scalar() { + let tape = Tape::new_ref(); + let data = Data::from([2.0, 10.0]); + + let tensor = ADTchTensor::from_data(data.clone(), tape.clone()); + let 
tensor_out = tensor.clone() + 5.0; + tensor_out.backprob(); + + let grad = tensor.grad(); + assert_eq!(grad.into_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0])); } } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index 1134220eb..b8fb99f7a 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -1,5 +1,6 @@ mod add; mod mul; +mod sub; mod macros; pub use macros::*; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs index ffeec8895..0fa6fceef 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs @@ -82,8 +82,8 @@ mod tests { #[test] fn should_diff_mul() { let tape = Tape::new_ref(); - let data_1 = Data::from([1.0]); - let data_2 = Data::from([4.0]); + let data_1 = Data::from([1.0, 7.0]); + let data_2 = Data::from([4.0, 7.0]); let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone()); let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); @@ -96,20 +96,20 @@ mod tests { assert_eq!(grad_1.into_data(), data_2); assert_eq!(grad_2.into_data(), data_1); - assert_eq!(tensor_3.into_data(), Data::from([4.0])); + assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0])); } #[test] fn should_diff_mul_scalar() { let tape = Tape::new_ref(); - let data = Data::from([2.0]); + let data = Data::from([2.0, 5.0]); let tensor = ADTchTensor::from_data(data.clone(), tape.clone()); let tensor_out = tensor.clone() * 4.0; tensor_out.backprob(); let grad = tensor.grad(); - assert_eq!(tensor_out.into_data(), Data::from([8.0])); - assert_eq!(grad.into_data(), Data::from([4.0])); + assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0])); + assert_eq!(grad.into_data(), Data::from([4.0, 4.0])); } } diff --git 
a/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs new file mode 100644 index 000000000..ad226caa1 --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/sub.rs @@ -0,0 +1,115 @@ +use crate::{ + backend::autodiff::{ADFloat, ADFloatTensor, ADTensor}, + define_ops, execute_ops, + ops::{ + BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps, + SingleRecordedState, + }, + register_ops, TensorOpsSub, }; +use num_traits::Float; + +register_ops!( + ops BinaryOps, + name ADTensorSubOps, + forward |left, right| left - right, + partial_left |state: &BinaryRecordedState| state.left.ones(), + partial_right |state: &BinaryRecordedState| state.right.ones().neg(), +); + +register_ops!( + ops SingleOps, + name ADTensorSubScalarOps state P, + forward |state, input| input - state, + partial |_state, state_recorded: &SingleRecordedState| state_recorded.input.ones(), +); + +impl TensorOpsSub for ADTensor +where + T: ADFloatTensor, + P: ADFloat, +{ + fn sub(&self, other: &Self) -> Self { + let node = execute_ops!( + lhs self.node.clone(), + rhs other.node.clone(), + out TensorOpsSub::sub(&self.tensor(), &other.tensor()), + tape self.tape.clone(), + ops ADTensorSubOps::new(), + ); + self.from_existing(node) + } + + fn sub_scalar(&self, other: &P) -> Self { + let node = execute_ops!( + input self.node.clone(), + out TensorOpsSub::sub_scalar(&self.tensor(), &other), + tape self.tape.clone(), + ops ADTensorSubScalarOps::new(other.clone()), + ); + self.from_existing(node) + } +} + +impl std::ops::Sub

for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn sub(self, rhs: P) -> Self::Output { + TensorOpsSub::sub_scalar(&self, &rhs) + } +} + +impl std::ops::Sub> for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn sub(self, rhs: Self) -> Self::Output { + TensorOpsSub::sub(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase}; + + #[test] + fn should_diff_sub() { + let tape = Tape::new_ref(); + let data_1 = Data::from([2.0, 5.0]); + let data_2 = Data::from([4.0, 1.0]); + + let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone()); + let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone()); + + let tensor_3 = tensor_1.clone() - tensor_2.clone(); + tensor_3.backprob(); + + let grad_1 = tensor_1.grad(); + let grad_2 = tensor_2.grad(); + + assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0])); + assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0])); + assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0])); + } + + #[test] + fn should_diff_sub_scalar() { + let tape = Tape::new_ref(); + let data = Data::from([2.0, 10.0]); + + let tensor = ADTchTensor::from_data(data.clone(), tape.clone()); + let tensor_out = tensor.clone() - 5.0; + tensor_out.backprob(); + + let grad = tensor.grad(); + assert_eq!(grad.into_data(), Data::from([1.0, 1.0])); + assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0])); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index 9d4e98467..cd516f69e 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -4,3 +4,4 @@ mod matmul; mod mul; mod neg; mod reshape; +mod sub; diff --git a/burn-tensor/src/tensor/backend/tch/ops/sub.rs b/burn-tensor/src/tensor/backend/tch/ops/sub.rs new file mode 100644 index 
000000000..bad313e85 --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/sub.rs @@ -0,0 +1,71 @@ +use crate::{backend::tch::TchTensor, Data, TensorOpsSub}; +use std::ops::Sub; + +impl TensorOpsSub + for TchTensor +{ + fn sub(&self, other: &Self) -> Self { + let tensor = (&self.tensor).sub(&other.tensor); + let kind = self.kind.clone(); + let shape = self.shape.clone(); + + Self { + tensor, + shape, + kind, + } + } + fn sub_scalar(&self, other: &P) -> Self { + let elems: [P; D] = [*other; D]; + let data = Data::from(elems); + let other = TchTensor::from_data(data, self.tensor.device()); + let tensor = (&self.tensor).sub(&other.tensor); + let kind = self.kind.clone(); + let shape = self.shape.clone(); + + Self { + tensor, + shape, + kind, + } + } +} + +impl std::ops::Sub + for TchTensor +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + TensorOpsSub::sub(&self, &rhs) + } +} + +impl std::ops::Sub

+ for TchTensor +{ + type Output = Self; + + fn sub(self, rhs: P) -> Self::Output { + TensorOpsSub::sub_scalar(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TensorBase; + + #[test] + fn should_support_sub_ops() { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); + let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu); + + let data_actual = (tensor_1 - tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs index 5747b61ca..764948fb6 100644 --- a/burn-tensor/src/tensor/tensor.rs +++ b/burn-tensor/src/tensor/tensor.rs @@ -11,6 +11,7 @@ pub trait FloatTensor: + TensorOpsMul + TensorOpsNeg + TensorOpsAdd + + TensorOpsSub + TensorOpsMatmul + std::fmt::Debug { @@ -30,11 +31,20 @@ where fn add_scalar(&self, other: &P) -> Self; } +pub trait TensorOpsSub: + std::ops::Sub + std::ops::Sub +where + Self: Sized, +{ + fn sub(&self, other: &Self) -> Self; + fn sub_scalar(&self, other: &P) -> Self; +} + pub trait TensorOpsMatmul { fn matmul(&self, other: &Self) -> Self; } -pub trait TensorOpsNeg: std::ops::Neg { +pub trait TensorOpsNeg: std::ops::Neg { fn neg(&self) -> Self; }