refactor: matmul-ops (#69)

Nathaniel Simard 2022-11-05 16:29:52 -04:00 committed by GitHub
parent 10d1c13c88
commit 94b0283bac
14 changed files with 164 additions and 164 deletions

View File

@@ -1,108 +0,0 @@
use crate::graph::ops::{BinaryOps, BinaryOpsNodeState};
use crate::tensor::backend::autodiff::ADTensor;
use crate::tensor::backend::Backend;
use crate::tensor::ops::*;
use crate::{execute_ops, register_ops};

register_ops!(
    ops BinaryOps,
    name ADTensorMatmulOps,
    partial_left |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
        let out_grad = state.output.grad();
        let rhs = state.right.value().transpose();
        out_grad.matmul(&rhs)
    },
    partial_right |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
        let out_grad = state.output.grad();
        let lhs = state.left.value().transpose();
        lhs.matmul(&out_grad)
    },
);

impl<B: Backend, P, const D: usize> TensorOpsMatmul<P, D> for ADTensor<D, B> {
    fn matmul(&self, other: &Self) -> Self {
        execute_ops!(
            lhs self.node.clone(),
            rhs other.node.clone(),
            out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()),
            ops ADTensorMatmulOps::<B, D>::new(),
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    #[test]
    fn should_diff_matmul() {
        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);

        let tensor_3 = &tensor_1.matmul(&tensor_2);
        let grads = tensor_3.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
        assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
        assert_eq!(
            tensor_3.clone().into_data(),
            Data::from([[18.0, 28.0], [14.0, 23.0]])
        );
    }

    #[test]
    fn test_matmul_complex_1() {
        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);
        let tensor_3 = TestADTensor::from_data(data_3);

        let tensor_4 = tensor_1.matmul(&tensor_2);
        let tensor_5 = tensor_4.matmul(&tensor_3);
        let grads = tensor_5.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
        assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
    }

    #[test]
    fn test_matmul_complex_2() {
        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);
        let tensor_3 = TestADTensor::from_data(data_3);

        let tensor_4 = tensor_1.matmul(&tensor_2);
        let tensor_5 = tensor_4.matmul(&tensor_3);
        let tensor_6 = tensor_1.matmul(&tensor_5);
        let grads = tensor_6.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        assert_eq!(
            grad_1.to_data(),
            Data::from([[800.0, 792.0], [360.0, 592.0]])
        );
        assert_eq!(
            grad_2.to_data(),
            Data::from([[264., 264.0], [344.0, 344.0]])
        );
    }
}

View File

@@ -10,7 +10,6 @@ mod index;
mod log;
mod map_comparison;
mod mask;
mod matmul;
mod module;
mod neg;
mod pow;

View File

@@ -5,7 +5,7 @@ use crate::{
        Backend,
    },
    graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
    ops::{Ones, TensorOps, TensorOpsNeg},
    ops::{Ones, TensorOps, TensorOpsNeg, TensorOpsTranspose},
    Data, Shape,
};
@@ -376,4 +376,50 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        unary_ops_wrapper(lhs.node.clone(), output, ops)
    }

    fn matmul<const D: usize>(
        lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
        rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(Default, Debug)]
        struct MatmulBackward<B: Backend, const D: usize> {
            _b: B,
        }

        impl<B: Backend, const D: usize>
            BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
            for MatmulBackward<B, D>
        {
            fn partial_left(
                &self,
                state: &BinaryOpsNodeState<
                    B::TensorPrimitive<D>,
                    B::TensorPrimitive<D>,
                    B::TensorPrimitive<D>,
                >,
            ) -> B::TensorPrimitive<D> {
                let out_grad = state.output.grad();
                let rhs = state.right.value().transpose();
                B::matmul(&out_grad, &rhs)
            }

            fn partial_right(
                &self,
                state: &BinaryOpsNodeState<
                    B::TensorPrimitive<D>,
                    B::TensorPrimitive<D>,
                    B::TensorPrimitive<D>,
                >,
            ) -> B::TensorPrimitive<D> {
                let out_grad = state.output.grad();
                let lhs = state.left.value().transpose();
                B::matmul(&lhs, &out_grad)
            }
        }

        let output = B::matmul(lhs.tensor_ref(), rhs.tensor_ref());
        let ops = MatmulBackward::<B, D>::default();

        binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
    }
}
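Note (not part of the commit): partial_left and partial_right implement the standard matmul gradient rule. For C = A B with incoming gradient ∂L/∂C, the backward pass computes

\[
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\, B^{\top},
\qquad
\frac{\partial L}{\partial B} = A^{\top}\, \frac{\partial L}{\partial C},
\]

which is exactly what the two bodies do: transpose the cached operand and call B::matmul(&out_grad, &rhs) for the left partial and B::matmul(&lhs, &out_grad) for the right partial.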

View File

@@ -21,8 +21,7 @@ pub trait Backend:
    type FullPrecisionElem: Element;
    type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
    type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
    type TensorPrimitive<const D: usize>: TensorOpsMatmul<Self::Elem, D>
        + std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
    type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
        + TensorOpsTranspose<Self::Elem, D>
        + TensorOpsNeg<Self::Elem, D>
        + TensorOpsDetach<Self::Elem, D>

View File

@@ -1,29 +0,0 @@
use crate::tensor::{
    backend::ndarray::{BatchMatrix, NdArrayTensor},
    ops::*,
};
use ndarray::LinalgScalar;

impl<P, const D: usize> TensorOpsMatmul<P, D> for NdArrayTensor<P, D>
where
    P: Clone + LinalgScalar + Default + std::fmt::Debug,
{
    fn matmul(&self, other: &Self) -> Self {
        let batch_self = BatchMatrix::from_ndarray(self.array.clone(), self.shape);
        let batch_other = BatchMatrix::from_ndarray(other.array.clone(), other.shape);

        let self_iter = batch_self.arrays.iter();
        let other_iter = batch_other.arrays.iter();
        let arrays = self_iter
            .zip(other_iter)
            .map(|(lhs, rhs)| lhs.dot(rhs))
            .map(|output| output.into_shared())
            .collect();

        let mut shape = self.shape;
        shape.dims[D - 1] = other.shape.dims[D - 1];
        let output = BatchMatrix::new(arrays, shape);

        Self::from_bmatrix(output)
    }
}

View File

@@ -9,7 +9,6 @@ mod index;
mod log;
mod map_comparison;
mod mask;
mod matmul;
mod neg;
mod pow;
mod precision;

View File

@@ -1,4 +1,4 @@
use super::{NdArrayBackend, NdArrayTensor};
use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
use crate::{
    backend::{Backend, NdArrayDevice},
    ops::TensorOps,
@@ -150,4 +150,26 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        NdArrayTensor { array, shape }
    }

    fn matmul<const D: usize>(
        lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
        rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
        let batch_self = BatchMatrix::from_ndarray(lhs.array.clone(), lhs.shape);
        let batch_other = BatchMatrix::from_ndarray(rhs.array.clone(), rhs.shape);

        let self_iter = batch_self.arrays.iter();
        let other_iter = batch_other.arrays.iter();
        let arrays = self_iter
            .zip(other_iter)
            .map(|(lhs, rhs)| lhs.dot(rhs))
            .map(|output| output.into_shared())
            .collect();

        let mut shape = lhs.shape;
        shape.dims[D - 1] = rhs.shape.dims[D - 1];
        let output = BatchMatrix::new(arrays, shape);

        NdArrayTensor::from_bmatrix(output)
    }
}
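Note (not part of the commit): the ndarray backend treats a D-dimensional tensor as a batch of 2-D matrices, multiplies each pair with ndarray's dot, and takes the trailing dimension of the output shape from rhs. A minimal standalone sketch of the same idea, written directly against the ndarray crate for a fixed 3-D case (batched_matmul is illustrative, not burn's BatchMatrix API):

use ndarray::{Array2, Array3, Axis};

/// Multiply two stacks of matrices pairwise: [batch, m, k] x [batch, k, n] -> [batch, m, n].
fn batched_matmul(lhs: &Array3<f32>, rhs: &Array3<f32>) -> Array3<f32> {
    let (batch, m, _k) = lhs.dim();
    let n = rhs.dim().2;
    let mut out = Array3::<f32>::zeros((batch, m, n));
    for b in 0..batch {
        // 2-D views of the b-th matrices; `dot` is the plain matrix product.
        let product: Array2<f32> = lhs
            .index_axis(Axis(0), b)
            .dot(&rhs.index_axis(Axis(0), b));
        out.index_axis_mut(Axis(0), b).assign(&product);
    }
    out
}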

View File

@@ -1,15 +0,0 @@
use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};

impl<P: tch::kind::Element, const D: usize> TensorOpsMatmul<P, D> for TchTensor<P, D> {
    fn matmul(&self, other: &Self) -> Self {
        let tensor = self.tensor.matmul(&other.tensor);
        let kind = self.kind.clone();
        let shape = Shape::from(tensor.size());

        Self {
            kind,
            tensor,
            shape,
        }
    }
}

View File

@@ -9,7 +9,6 @@ mod index;
mod log;
mod map_comparison;
mod mask;
mod matmul;
mod neg;
mod pow;
mod precision;

View File

@@ -174,4 +174,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
            kind,
        }
    }

    fn matmul<const D: usize>(lhs: &TchTensor<E, D>, rhs: &TchTensor<E, D>) -> TchTensor<E, D> {
        let tensor = lhs.tensor.matmul(&rhs.tensor);
        let kind = lhs.kind;
        let shape = Shape::from(tensor.size());

        TchTensor {
            tensor,
            shape,
            kind,
        }
    }
}

View File

@@ -217,7 +217,7 @@ where
    ///
    /// If the two tensors don't have a compatible shape.
    pub fn matmul(&self, other: &Self) -> Self {
        Self::new(self.value.matmul(&other.value))
        Self::new(B::matmul(&self.value, &other.value))
    }

    /// Switch sign of each element in the tensor.

View File

@@ -98,6 +98,10 @@ pub trait TensorOps<B: Backend> {
        lhs: &B::TensorPrimitive<D>,
        rhs: &B::Elem,
    ) -> B::TensorPrimitive<D>;

    fn matmul<const D: usize>(
        lhs: &B::TensorPrimitive<D>,
        rhs: &B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D>;
}

pub trait TensorOpsTranspose<E, const D: usize> {
@@ -105,10 +109,6 @@ pub trait TensorOpsTranspose<E, const D: usize> {
    fn swap_dims(&self, dim1: usize, dim2: usize) -> Self;
}

pub trait TensorOpsMatmul<E, const D: usize> {
    fn matmul(&self, other: &Self) -> Self;
}

pub trait TensorOpsNeg<E, const D: usize> {
    fn neg(&self) -> Self;
}
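Note (not part of the commit): the shape of this change is that matmul moves from the per-primitive TensorOpsMatmul trait (removed above) to an associated function on the backend-level TensorOps trait. A minimal, self-contained sketch of that pattern with toy types (ToyBackend and ToyTensor are hypothetical, not burn code):

pub trait Backend {
    type TensorPrimitive<const D: usize>;
}

pub trait TensorOps<B: Backend> {
    fn matmul<const D: usize>(
        lhs: &B::TensorPrimitive<D>,
        rhs: &B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D>;
}

pub struct ToyTensor<const D: usize> {
    pub data: Vec<f32>,
    pub dims: [usize; D],
}

pub struct ToyBackend;

impl Backend for ToyBackend {
    type TensorPrimitive<const D: usize> = ToyTensor<D>;
}

impl TensorOps<ToyBackend> for ToyBackend {
    fn matmul<const D: usize>(lhs: &ToyTensor<D>, rhs: &ToyTensor<D>) -> ToyTensor<D> {
        // Placeholder body: a real backend performs the batched product here;
        // the output keeps lhs's leading dims and takes the trailing dim from rhs.
        let mut dims = lhs.dims;
        dims[D - 1] = rhs.dims[D - 1];
        ToyTensor { data: Vec::new(), dims }
    }
}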

View File

@@ -0,0 +1,75 @@
use crate::tensor::TestADTensor;
use burn_tensor::Data;

#[test]
fn should_diff_matmul() {
    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);

    let tensor_1 = TestADTensor::from_data(data_1);
    let tensor_2 = TestADTensor::from_data(data_2);

    let tensor_3 = &tensor_1.matmul(&tensor_2);
    let grads = tensor_3.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
    assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
    assert_eq!(
        tensor_3.clone().into_data(),
        Data::from([[18.0, 28.0], [14.0, 23.0]])
    );
}

#[test]
fn test_matmul_complex_1() {
    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
    let data_3: Data<f32, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);

    let tensor_1 = TestADTensor::from_data(data_1);
    let tensor_2 = TestADTensor::from_data(data_2);
    let tensor_3 = TestADTensor::from_data(data_3);

    let tensor_4 = tensor_1.matmul(&tensor_2);
    let tensor_5 = tensor_4.matmul(&tensor_3);
    let grads = tensor_5.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
    assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
}

#[test]
fn test_matmul_complex_2() {
    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
    let data_3: Data<f32, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);

    let tensor_1 = TestADTensor::from_data(data_1);
    let tensor_2 = TestADTensor::from_data(data_2);
    let tensor_3 = TestADTensor::from_data(data_3);

    let tensor_4 = tensor_1.matmul(&tensor_2);
    let tensor_5 = tensor_4.matmul(&tensor_3);
    let tensor_6 = tensor_1.matmul(&tensor_5);
    let grads = tensor_6.backward();

    let grad_1 = tensor_1.grad(&grads).unwrap();
    let grad_2 = tensor_2.grad(&grads).unwrap();

    assert_eq!(
        grad_1.to_data(),
        Data::from([[800.0, 792.0], [360.0, 592.0]])
    );
    assert_eq!(
        grad_2.to_data(),
        Data::from([[264., 264.0], [344.0, 344.0]])
    );
}
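Note (not part of the commit): a worked check of the expected values in should_diff_matmul, assuming backward() seeds the output gradient with an all-ones matrix (the expected data is consistent with this):

\[
A = \begin{bmatrix} 1 & 7 \\ 2 & 3 \end{bmatrix},\quad
B = \begin{bmatrix} 4 & 7 \\ 2 & 3 \end{bmatrix},\quad
C = A B = \begin{bmatrix} 18 & 28 \\ 14 & 23 \end{bmatrix},\quad
\frac{\partial L}{\partial C} = \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix},
\]
\[
\frac{\partial L}{\partial A}
= \frac{\partial L}{\partial C}\, B^{\top}
= \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}
  \begin{bmatrix} 4 & 2 \\ 7 & 3 \end{bmatrix}
= \begin{bmatrix} 11 & 5 \\ 11 & 5 \end{bmatrix},
\qquad
\frac{\partial L}{\partial B}
= A^{\top}\, \frac{\partial L}{\partial C}
= \begin{bmatrix} 1 & 2 \\ 7 & 3 \end{bmatrix}
  \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}
= \begin{bmatrix} 3 & 3 \\ 10 & 10 \end{bmatrix},
\]

matching the expected tensor_3 data, grad_1, and grad_2 in the test.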

View File

@@ -2,6 +2,7 @@ mod add;
mod aggregation;
mod cross_entropy;
mod div;
mod matmul;
mod mul;
mod softmax;
mod sub;