From 94b0283bac6b0ebe4cc0b6719fd11d754a813f2a Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Sat, 5 Nov 2022 16:29:52 -0400
Subject: [PATCH] refactor: matmul-ops (#69)

---
 .../src/tensor/backend/autodiff/ops/matmul.rs | 108 ------------------
 .../src/tensor/backend/autodiff/ops/mod.rs    |   1 -
 .../src/tensor/backend/autodiff/ops/tensor.rs |  48 +++++++-
 burn-tensor/src/tensor/backend/base.rs        |   3 +-
 .../src/tensor/backend/ndarray/ops/matmul.rs  |  29 -----
 .../src/tensor/backend/ndarray/ops/mod.rs     |   1 -
 .../src/tensor/backend/ndarray/tensor_ops.rs  |  24 +++-
 .../src/tensor/backend/tch/ops/matmul.rs      |  15 ---
 burn-tensor/src/tensor/backend/tch/ops/mod.rs |   1 -
 .../src/tensor/backend/tch/tensor_ops.rs      |  12 ++
 burn-tensor/src/tensor/base.rs                |   2 +-
 burn-tensor/src/tensor/ops/base.rs            |   8 +-
 burn-tensor/tests/tensor/grad/matmul.rs       |  75 ++++++++++++
 burn-tensor/tests/tensor/grad/mod.rs          |   1 +
 14 files changed, 164 insertions(+), 164 deletions(-)
 delete mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
 delete mode 100644 burn-tensor/src/tensor/backend/ndarray/ops/matmul.rs
 delete mode 100644 burn-tensor/src/tensor/backend/tch/ops/matmul.rs
 create mode 100644 burn-tensor/tests/tensor/grad/matmul.rs

diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
deleted file mode 100644
index 5eca5a410..000000000
--- a/burn-tensor/src/tensor/backend/autodiff/ops/matmul.rs
+++ /dev/null
@@ -1,108 +0,0 @@
-use crate::graph::ops::{BinaryOps, BinaryOpsNodeState};
-use crate::tensor::backend::autodiff::ADTensor;
-use crate::tensor::backend::Backend;
-use crate::tensor::ops::*;
-use crate::{execute_ops, register_ops};
-
-register_ops!(
-    ops BinaryOps,
-    name ADTensorMatmulOps,
-    partial_left |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        let out_grad = state.output.grad();
-        let rhs = state.right.value().transpose();
-        out_grad.matmul(&rhs)
-    },
-    partial_right |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        let out_grad = state.output.grad();
-        let lhs = state.left.value().transpose();
-        lhs.matmul(&out_grad)
-    },
-);
-
-impl<B: Backend, const D: usize> TensorOpsMatmul<B::Elem, D> for ADTensor<D, B> {
-    fn matmul(&self, other: &Self) -> Self {
-        execute_ops!(
-            lhs self.node.clone(),
-            rhs other.node.clone(),
-            out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()),
-            ops ADTensorMatmulOps::<B, D>::new(),
-        )
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
-
-    #[test]
-    fn should_diff_matmul() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-
-        let tensor_3 = &tensor_1.matmul(&tensor_2);
-        let grads = tensor_3.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
-        assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
-        assert_eq!(
-            tensor_3.clone().into_data(),
-            Data::from([[18.0, 28.0], [14.0, 23.0]])
-        );
-    }
-
-    #[test]
-    fn test_matmul_complex_1() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-        let tensor_3 = TestADTensor::from_data(data_3);
-
-        let tensor_4 = tensor_1.matmul(&tensor_2);
-        let tensor_5 = tensor_4.matmul(&tensor_3);
-
-        let grads = tensor_5.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
-        assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
-    }
-    #[test]
-    fn test_matmul_complex_2() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-        let tensor_3 = TestADTensor::from_data(data_3);
-
-        let tensor_4 = tensor_1.matmul(&tensor_2);
-        let tensor_5 = tensor_4.matmul(&tensor_3);
-        let tensor_6 = tensor_1.matmul(&tensor_5);
-
-        let grads = tensor_6.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(
-            grad_1.to_data(),
-            Data::from([[800.0, 792.0], [360.0, 592.0]])
-        );
-        assert_eq!(
-            grad_2.to_data(),
-            Data::from([[264., 264.0], [344.0, 344.0]])
-        );
-    }
-}
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
index 15615fe70..bb5550ef0 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -10,7 +10,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod module;
 mod neg;
 mod pow;
diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
index e0112063f..605e71da6 100644
--- a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
+++ b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs
@@ -5,7 +5,7 @@ use crate::{
         Backend,
     },
     graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
-    ops::{Ones, TensorOps, TensorOpsNeg},
+    ops::{Ones, TensorOps, TensorOpsNeg, TensorOpsTranspose},
     Data, Shape,
 };
 
@@ -376,4 +376,50 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
         unary_ops_wrapper(lhs.node.clone(), output, ops)
     }
+
+    fn matmul<const D: usize>(
+        lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        #[derive(Default, Debug)]
+        struct MatmulBackward<B: Backend, const D: usize> {
+            _b: B,
+        }
+
+        impl<B: Backend, const D: usize>
+            BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+            for MatmulBackward<B, D>
+        {
+            fn partial_left(
+                &self,
+                state: &BinaryOpsNodeState<
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                >,
+            ) -> B::TensorPrimitive<D> {
+                let out_grad = state.output.grad();
+                let rhs = state.right.value().transpose();
+                B::matmul(&out_grad, &rhs)
+            }
+
+            fn partial_right(
+                &self,
+                state: &BinaryOpsNodeState<
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                >,
+            ) -> B::TensorPrimitive<D> {
+                let out_grad = state.output.grad();
+                let lhs = state.left.value().transpose();
+                B::matmul(&lhs, &out_grad)
+            }
+        }
+
+        let output = B::matmul(lhs.tensor_ref(), rhs.tensor_ref());
+        let ops = MatmulBackward::<B, D>::default();
+
+        binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
+    }
 }
diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs
index 0fbf3b60d..c33ee61b2 100644
--- a/burn-tensor/src/tensor/backend/base.rs
+++ b/burn-tensor/src/tensor/backend/base.rs
@@ -21,8 +21,7 @@ pub trait Backend:
     type FullPrecisionElem: Element;
     type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
     type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
-    type TensorPrimitive<const D: usize>: TensorOpsMatmul<Self::Elem, D>
-        + std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
+    type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
         + TensorOpsTranspose<Self::Elem, D>
         + TensorOpsNeg<Self::Elem, D>
         + TensorOpsDetach<Self::Elem, D>
diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/matmul.rs b/burn-tensor/src/tensor/backend/ndarray/ops/matmul.rs
deleted file mode 100644
index 7f01ad8df..000000000
--- a/burn-tensor/src/tensor/backend/ndarray/ops/matmul.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-use crate::tensor::{
-    backend::ndarray::{BatchMatrix, NdArrayTensor},
-    ops::*,
-};
-use ndarray::LinalgScalar;
-
-impl<P, const D: usize> TensorOpsMatmul<P, D> for NdArrayTensor<P, D>
-where
-    P: Clone + LinalgScalar + Default + std::fmt::Debug,
-{
-    fn matmul(&self, other: &Self) -> Self {
-        let batch_self = BatchMatrix::from_ndarray(self.array.clone(), self.shape);
-        let batch_other = BatchMatrix::from_ndarray(other.array.clone(), other.shape);
-
-        let self_iter = batch_self.arrays.iter();
-        let other_iter = batch_other.arrays.iter();
-        let arrays = self_iter
-            .zip(other_iter)
-            .map(|(lhs, rhs)| lhs.dot(rhs))
-            .map(|output| output.into_shared())
-            .collect();
-
-        let mut shape = self.shape;
-        shape.dims[D - 1] = other.shape.dims[D - 1];
-        let output = BatchMatrix::new(arrays, shape);
-
-        Self::from_bmatrix(output)
-    }
-}
diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
index ea0bb132b..4e91db00e 100644
--- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
@@ -9,7 +9,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod neg;
 mod pow;
 mod precision;
diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
index 29c84f10d..fae129364 100644
--- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
+++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
@@ -1,4 +1,4 @@
-use super::{NdArrayBackend, NdArrayTensor};
+use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
 use crate::{
     backend::{Backend, NdArrayDevice},
     ops::TensorOps,
@@ -150,4 +150,26 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
 
         NdArrayTensor { array, shape }
     }
+
+    fn matmul<const D: usize>(
+        lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+        rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
+        let batch_self = BatchMatrix::from_ndarray(lhs.array.clone(), lhs.shape);
+        let batch_other = BatchMatrix::from_ndarray(rhs.array.clone(), rhs.shape);
+
+        let self_iter = batch_self.arrays.iter();
+        let other_iter = batch_other.arrays.iter();
+        let arrays = self_iter
+            .zip(other_iter)
+            .map(|(lhs, rhs)| lhs.dot(rhs))
+            .map(|output| output.into_shared())
+            .collect();
+
+        let mut shape = lhs.shape;
+        shape.dims[D - 1] = rhs.shape.dims[D - 1];
+        let output = BatchMatrix::new(arrays, shape);
+
+        NdArrayTensor::from_bmatrix(output)
+    }
 }
diff --git a/burn-tensor/src/tensor/backend/tch/ops/matmul.rs b/burn-tensor/src/tensor/backend/tch/ops/matmul.rs
deleted file mode 100644
index 38d953836..000000000
--- a/burn-tensor/src/tensor/backend/tch/ops/matmul.rs
+++ /dev/null
@@ -1,15 +0,0 @@
-use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};
-
-impl<P: tch::kind::Element, const D: usize> TensorOpsMatmul<P, D> for TchTensor<P, D> {
-    fn matmul(&self, other: &Self) -> Self {
-        let tensor = self.tensor.matmul(&other.tensor);
-        let kind = self.kind.clone();
-        let shape = Shape::from(tensor.size());
-
-        Self {
-            kind,
-            tensor,
-            shape,
-        }
-    }
-}
diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs
index ea0bb132b..4e91db00e 100644
--- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs
+++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs
@@ -9,7 +9,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod neg;
 mod pow;
 mod precision;
diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
index ce2fdcf76..71cc743a9 100644
--- a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
+++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs
@@ -174,4 +174,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
             kind,
         }
     }
+
+    fn matmul<const D: usize>(lhs: &TchTensor<E, D>, rhs: &TchTensor<E, D>) -> TchTensor<E, D> {
+        let tensor = lhs.tensor.matmul(&rhs.tensor);
+        let kind = lhs.kind;
+        let shape = Shape::from(tensor.size());
+
+        TchTensor {
+            tensor,
+            shape,
+            kind,
+        }
+    }
 }
diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs
index 7e37f47a0..a57f55ac4 100644
--- a/burn-tensor/src/tensor/base.rs
+++ b/burn-tensor/src/tensor/base.rs
@@ -217,7 +217,7 @@ where
     ///
     /// If the two tensors dont' have a compatible shape.
     pub fn matmul(&self, other: &Self) -> Self {
-        Self::new(self.value.matmul(&other.value))
+        Self::new(B::matmul(&self.value, &other.value))
     }
 
     /// Switch sign of each element in the tensor.
diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs
index 428e3c4d0..013e9c19c 100644
--- a/burn-tensor/src/tensor/ops/base.rs
+++ b/burn-tensor/src/tensor/ops/base.rs
@@ -98,6 +98,10 @@ pub trait TensorOps<B: Backend> {
         lhs: &B::TensorPrimitive<D>,
         rhs: &B::Elem,
     ) -> B::TensorPrimitive<D>;
+    fn matmul<const D: usize>(
+        lhs: &B::TensorPrimitive<D>,
+        rhs: &B::TensorPrimitive<D>,
+    ) -> B::TensorPrimitive<D>;
 }
 
 pub trait TensorOpsTranspose<E, const D: usize> {
@@ -105,10 +109,6 @@
     fn transpose(&self) -> Self;
     fn swap_dims(&self, dim1: usize, dim2: usize) -> Self;
 }
 
-pub trait TensorOpsMatmul<E, const D: usize> {
-    fn matmul(&self, other: &Self) -> Self;
-}
-
 pub trait TensorOpsNeg<E, const D: usize> {
     fn neg(&self) -> Self;
 }
diff --git a/burn-tensor/tests/tensor/grad/matmul.rs b/burn-tensor/tests/tensor/grad/matmul.rs
new file mode 100644
index 000000000..9c7a927a6
--- /dev/null
+++ b/burn-tensor/tests/tensor/grad/matmul.rs
@@ -0,0 +1,75 @@
+use crate::tensor::TestADTensor;
+use burn_tensor::Data;
+
+#[test]
+fn should_diff_matmul() {
+    let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
+    let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+
+    let tensor_3 = &tensor_1.matmul(&tensor_2);
+    let grads = tensor_3.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
+    assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
+    assert_eq!(
+        tensor_3.clone().into_data(),
+        Data::from([[18.0, 28.0], [14.0, 23.0]])
+    );
+}
+
+#[test]
+fn test_matmul_complex_1() {
+    let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+    let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+    let tensor_3 = TestADTensor::from_data(data_3);
+
+    let tensor_4 = tensor_1.matmul(&tensor_2);
+    let tensor_5 = tensor_4.matmul(&tensor_3);
+
+    let grads = tensor_5.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
+    assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
+}
+
+#[test]
+fn test_matmul_complex_2() {
+    let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+    let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+    let tensor_3 = TestADTensor::from_data(data_3);
+
+    let tensor_4 = tensor_1.matmul(&tensor_2);
+    let tensor_5 = tensor_4.matmul(&tensor_3);
+    let tensor_6 = tensor_1.matmul(&tensor_5);
+
+    let grads = tensor_6.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(
+        grad_1.to_data(),
+        Data::from([[800.0, 792.0], [360.0, 592.0]])
+    );
+    assert_eq!(
+        grad_2.to_data(),
+        Data::from([[264., 264.0], [344.0, 344.0]])
+    );
+}
diff --git a/burn-tensor/tests/tensor/grad/mod.rs b/burn-tensor/tests/tensor/grad/mod.rs
index 72ed2915a..aa651f368 100644
--- a/burn-tensor/tests/tensor/grad/mod.rs
+++ b/burn-tensor/tests/tensor/grad/mod.rs
@@ -2,6 +2,7 @@ mod add;
 mod aggregation;
 mod cross_entropy;
 mod div;
+mod matmul;
 mod mul;
 mod softmax;
 mod sub;
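Note on the backward pass introduced in this patch: `partial_left` and `partial_right` in `MatmulBackward` encode the standard matrix-product gradients, dL/dlhs = grad_out . rhs^T and dL/drhs = lhs^T . grad_out. The sketch below is not part of the patch; it uses plain fixed-size arrays and illustrative helper names (`Mat2`, `matmul`, `transpose`) rather than burn types, and it assumes `backward()` seeds the output gradient with ones, which is what the expected numbers in `should_diff_matmul` imply.

```rust
// Standalone check of the matmul gradient formulas used by partial_left /
// partial_right, against the values asserted in should_diff_matmul.
type Mat2 = [[f64; 2]; 2];

// Naive 2x2 matrix product: c[i][j] = sum_k a[i][k] * b[k][j].
fn matmul(a: &Mat2, b: &Mat2) -> Mat2 {
    let mut c = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn transpose(a: &Mat2) -> Mat2 {
    [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]
}

fn main() {
    let lhs: Mat2 = [[1.0, 7.0], [2.0, 3.0]];
    let rhs: Mat2 = [[4.0, 7.0], [2.0, 3.0]];

    // Forward pass: output = lhs . rhs, matching tensor_3 in the test.
    let output = matmul(&lhs, &rhs);
    assert_eq!(output, [[18.0, 28.0], [14.0, 23.0]]);

    // Assumed seed gradient of ones for the scalar-free backward() call.
    let grad_out: Mat2 = [[1.0, 1.0], [1.0, 1.0]];

    // partial_left: dL/dlhs = grad_out . rhs^T  -> grad_1 in the test.
    let grad_lhs = matmul(&grad_out, &transpose(&rhs));
    assert_eq!(grad_lhs, [[11.0, 5.0], [11.0, 5.0]]);

    // partial_right: dL/drhs = lhs^T . grad_out -> grad_2 in the test.
    let grad_rhs = matmul(&transpose(&lhs), &grad_out);
    assert_eq!(grad_rhs, [[3.0, 3.0], [10.0, 10.0]]);
}
```

The same two formulas, applied through the chain rule across the intermediate products, account for the expected gradients in the `test_matmul_complex_1` and `test_matmul_complex_2` cases as well.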