mirror of https://github.com/tracel-ai/burn.git
refactor: matmul-ops (#69)
This commit is contained in:
parent 10d1c13c88
commit 94b0283bac
@@ -1,108 +0,0 @@
-use crate::graph::ops::{BinaryOps, BinaryOpsNodeState};
-use crate::tensor::backend::autodiff::ADTensor;
-use crate::tensor::backend::Backend;
-use crate::tensor::ops::*;
-use crate::{execute_ops, register_ops};
-
-register_ops!(
-    ops BinaryOps,
-    name ADTensorMatmulOps,
-    partial_left |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        let out_grad = state.output.grad();
-        let rhs = state.right.value().transpose();
-        out_grad.matmul(&rhs)
-    },
-    partial_right |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
-        let out_grad = state.output.grad();
-        let lhs = state.left.value().transpose();
-        lhs.matmul(&out_grad)
-    },
-);
-
-impl<B: Backend, P, const D: usize> TensorOpsMatmul<P, D> for ADTensor<D, B> {
-    fn matmul(&self, other: &Self) -> Self {
-        execute_ops!(
-            lhs self.node.clone(),
-            rhs other.node.clone(),
-            out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()),
-            ops ADTensorMatmulOps::<B, D>::new(),
-        )
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
-
-    #[test]
-    fn should_diff_matmul() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-
-        let tensor_3 = &tensor_1.matmul(&tensor_2);
-        let grads = tensor_3.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
-        assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
-        assert_eq!(
-            tensor_3.clone().into_data(),
-            Data::from([[18.0, 28.0], [14.0, 23.0]])
-        );
-    }
-
-    #[test]
-    fn test_matmul_complex_1() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-        let tensor_3 = TestADTensor::from_data(data_3);
-
-        let tensor_4 = tensor_1.matmul(&tensor_2);
-        let tensor_5 = tensor_4.matmul(&tensor_3);
-
-        let grads = tensor_5.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
-        assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
-    }
-
-    #[test]
-    fn test_matmul_complex_2() {
-        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
-        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
-        let data_3: Data<f64, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-        let tensor_3 = TestADTensor::from_data(data_3);
-
-        let tensor_4 = tensor_1.matmul(&tensor_2);
-        let tensor_5 = tensor_4.matmul(&tensor_3);
-        let tensor_6 = tensor_1.matmul(&tensor_5);
-
-        let grads = tensor_6.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        assert_eq!(
-            grad_1.to_data(),
-            Data::from([[800.0, 792.0], [360.0, 592.0]])
-        );
-        assert_eq!(
-            grad_2.to_data(),
-            Data::from([[264., 264.0], [344.0, 344.0]])
-        );
-    }
-}
@@ -10,7 +10,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod module;
 mod neg;
 mod pow;
@@ -5,7 +5,7 @@ use crate::{
        Backend,
    },
    graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
-   ops::{Ones, TensorOps, TensorOpsNeg},
+   ops::{Ones, TensorOps, TensorOpsNeg, TensorOpsTranspose},
    Data, Shape,
 };
@@ -376,4 +376,50 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

         unary_ops_wrapper(lhs.node.clone(), output, ops)
     }
+
+    fn matmul<const D: usize>(
+        lhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        rhs: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        #[derive(Default, Debug)]
+        struct MatmulBackward<B: Backend, const D: usize> {
+            _b: B,
+        }
+
+        impl<B: Backend, const D: usize>
+            BinaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+            for MatmulBackward<B, D>
+        {
+            fn partial_left(
+                &self,
+                state: &BinaryOpsNodeState<
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                >,
+            ) -> B::TensorPrimitive<D> {
+                let out_grad = state.output.grad();
+                let rhs = state.right.value().transpose();
+                B::matmul(&out_grad, &rhs)
+            }
+
+            fn partial_right(
+                &self,
+                state: &BinaryOpsNodeState<
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                    B::TensorPrimitive<D>,
+                >,
+            ) -> B::TensorPrimitive<D> {
+                let out_grad = state.output.grad();
+                let lhs = state.left.value().transpose();
+                B::matmul(&lhs, &out_grad)
+            }
+        }
+
+        let output = B::matmul(lhs.tensor_ref(), rhs.tensor_ref());
+        let ops = MatmulBackward::<B, D>::default();
+
+        binary_ops_wrapper(lhs.node.clone(), rhs.node.clone(), output, ops)
+    }
 }
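The `partial_left`/`partial_right` methods above implement the standard matmul gradient rule: for C = A·B, dL/dA = dL/dC · B^T and dL/dB = A^T · dL/dC. Below is a minimal standalone sketch of that rule, checked against the values asserted in this commit's tests. `Mat2`, `matmul2`, and `transpose2` are illustrative helpers, not burn APIs, and it assumes `backward()` seeds the output gradient with ones, which is what the expected values imply.

```rust
// Standalone sanity check of the gradient rule used by partial_left/partial_right:
// for C = A * B, dL/dA = dL/dC * B^T and dL/dB = A^T * dL/dC.
type Mat2 = [[f64; 2]; 2];

// Plain 2x2 matrix product (hypothetical helper, not a burn API).
fn matmul2(a: &Mat2, b: &Mat2) -> Mat2 {
    let mut c = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

// Plain 2x2 transpose (hypothetical helper).
fn transpose2(a: &Mat2) -> Mat2 {
    [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]
}

fn main() {
    let a: Mat2 = [[1.0, 7.0], [2.0, 3.0]]; // data_1 from the tests
    let b: Mat2 = [[4.0, 7.0], [2.0, 3.0]]; // data_2 from the tests
    // Assumption: backward() seeds the output gradient with ones,
    // consistent with the expected values asserted in the tests.
    let grad_out: Mat2 = [[1.0, 1.0], [1.0, 1.0]];

    let grad_a = matmul2(&grad_out, &transpose2(&b)); // partial_left
    let grad_b = matmul2(&transpose2(&a), &grad_out); // partial_right

    assert_eq!(grad_a, [[11.0, 5.0], [11.0, 5.0]]); // matches should_diff_matmul
    assert_eq!(grad_b, [[3.0, 3.0], [10.0, 10.0]]);
}
```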
@@ -21,8 +21,7 @@ pub trait Backend:
     type FullPrecisionElem: Element;
     type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
     type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
-    type TensorPrimitive<const D: usize>: TensorOpsMatmul<Self::Elem, D>
-        + std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
+    type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
         + TensorOpsTranspose<Self::Elem, D>
         + TensorOpsNeg<Self::Elem, D>
         + TensorOpsDetach<Self::Elem, D>
@@ -1,29 +0,0 @@
-use crate::tensor::{
-    backend::ndarray::{BatchMatrix, NdArrayTensor},
-    ops::*,
-};
-use ndarray::LinalgScalar;
-
-impl<P, const D: usize> TensorOpsMatmul<P, D> for NdArrayTensor<P, D>
-where
-    P: Clone + LinalgScalar + Default + std::fmt::Debug,
-{
-    fn matmul(&self, other: &Self) -> Self {
-        let batch_self = BatchMatrix::from_ndarray(self.array.clone(), self.shape);
-        let batch_other = BatchMatrix::from_ndarray(other.array.clone(), other.shape);
-
-        let self_iter = batch_self.arrays.iter();
-        let other_iter = batch_other.arrays.iter();
-        let arrays = self_iter
-            .zip(other_iter)
-            .map(|(lhs, rhs)| lhs.dot(rhs))
-            .map(|output| output.into_shared())
-            .collect();
-
-        let mut shape = self.shape;
-        shape.dims[D - 1] = other.shape.dims[D - 1];
-        let output = BatchMatrix::new(arrays, shape);
-
-        Self::from_bmatrix(output)
-    }
-}
@@ -9,7 +9,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod neg;
 mod pow;
 mod precision;
@@ -1,4 +1,4 @@
-use super::{NdArrayBackend, NdArrayTensor};
+use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
 use crate::{
     backend::{Backend, NdArrayDevice},
     ops::TensorOps,
@@ -150,4 +150,26 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {

         NdArrayTensor { array, shape }
     }
+
+    fn matmul<const D: usize>(
+        lhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+        rhs: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
+    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
+        let batch_self = BatchMatrix::from_ndarray(lhs.array.clone(), lhs.shape);
+        let batch_other = BatchMatrix::from_ndarray(rhs.array.clone(), rhs.shape);
+
+        let self_iter = batch_self.arrays.iter();
+        let other_iter = batch_other.arrays.iter();
+        let arrays = self_iter
+            .zip(other_iter)
+            .map(|(lhs, rhs)| lhs.dot(rhs))
+            .map(|output| output.into_shared())
+            .collect();
+
+        let mut shape = lhs.shape;
+        shape.dims[D - 1] = rhs.shape.dims[D - 1];
+        let output = BatchMatrix::new(arrays, shape);
+
+        NdArrayTensor::from_bmatrix(output)
+    }
 }
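For reference, the per-batch multiply above is ndarray's 2D `dot`, and the output keeps the lhs dimensions except for the last one, which is taken from rhs (`shape.dims[D - 1] = rhs.shape.dims[D - 1]`). A minimal sketch of that single-batch step with plain `ndarray` (this assumes only the `ndarray` crate; it is not burn's `BatchMatrix` API):

```rust
use ndarray::arr2;

fn main() {
    let lhs = arr2(&[[1.0_f64, 7.0], [2.0, 3.0]]); // (M, K) = (2, 2)
    let rhs = arr2(&[[4.0_f64, 7.0], [2.0, 3.0]]); // (K, N) = (2, 2)

    // `dot` is the 2D matrix product applied to each batch entry above;
    // the result has M rows from lhs and N columns from rhs.
    let out = lhs.dot(&rhs);

    assert_eq!(out, arr2(&[[18.0, 28.0], [14.0, 23.0]])); // same product as in the tests
}
```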
@@ -1,15 +0,0 @@
-use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};
-
-impl<P: tch::kind::Element, const D: usize> TensorOpsMatmul<P, D> for TchTensor<P, D> {
-    fn matmul(&self, other: &Self) -> Self {
-        let tensor = self.tensor.matmul(&other.tensor);
-        let kind = self.kind.clone();
-        let shape = Shape::from(tensor.size());
-
-        Self {
-            kind,
-            tensor,
-            shape,
-        }
-    }
-}
@@ -9,7 +9,6 @@ mod index;
 mod log;
 mod map_comparison;
 mod mask;
-mod matmul;
 mod neg;
 mod pow;
 mod precision;
@@ -174,4 +174,16 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
             kind,
         }
     }
+
+    fn matmul<const D: usize>(lhs: &TchTensor<E, D>, rhs: &TchTensor<E, D>) -> TchTensor<E, D> {
+        let tensor = lhs.tensor.matmul(&rhs.tensor);
+        let kind = lhs.kind;
+        let shape = Shape::from(tensor.size());
+
+        TchTensor {
+            tensor,
+            shape,
+            kind,
+        }
+    }
 }
@@ -217,7 +217,7 @@ where
     ///
     /// If the two tensors dont' have a compatible shape.
     pub fn matmul(&self, other: &Self) -> Self {
-        Self::new(self.value.matmul(&other.value))
+        Self::new(B::matmul(&self.value, &other.value))
     }

     /// Switch sign of each element in the tensor.
@@ -98,6 +98,10 @@ pub trait TensorOps<B: Backend> {
         lhs: &B::TensorPrimitive<D>,
         rhs: &B::Elem,
     ) -> B::TensorPrimitive<D>;
+    fn matmul<const D: usize>(
+        lhs: &B::TensorPrimitive<D>,
+        rhs: &B::TensorPrimitive<D>,
+    ) -> B::TensorPrimitive<D>;
 }

 pub trait TensorOpsTranspose<E, const D: usize> {
@@ -105,10 +109,6 @@ pub trait TensorOpsTranspose<E, const D: usize> {
     fn swap_dims(&self, dim1: usize, dim2: usize) -> Self;
 }

-pub trait TensorOpsMatmul<E, const D: usize> {
-    fn matmul(&self, other: &Self) -> Self;
-}
-
 pub trait TensorOpsNeg<E, const D: usize> {
     fn neg(&self) -> Self;
 }
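The net effect of the trait changes above: the per-tensor `TensorOpsMatmul` trait is removed and `matmul` becomes a function on the backend-level `TensorOps` trait, so generic code calls `B::matmul(&lhs, &rhs)` instead of `lhs.matmul(&rhs)`. A toy sketch of that shape of API follows; every type here is a simplified stand-in, not burn's actual `Backend`/`TensorOps` definitions.

```rust
// Hypothetical, simplified backend-level ops trait to illustrate the dispatch style.
#[derive(Debug, PartialEq)]
struct Tensor2(Vec<Vec<f64>>); // stand-in tensor primitive

trait TensorOps {
    type Primitive;
    fn matmul(lhs: &Self::Primitive, rhs: &Self::Primitive) -> Self::Primitive;
}

struct CpuBackend;

impl TensorOps for CpuBackend {
    type Primitive = Tensor2;

    // Naive row-by-column product, enough to show the call shape.
    fn matmul(lhs: &Tensor2, rhs: &Tensor2) -> Tensor2 {
        let (n, m, p) = (lhs.0.len(), rhs.0.len(), rhs.0[0].len());
        let mut out = vec![vec![0.0; p]; n];
        for i in 0..n {
            for j in 0..p {
                for k in 0..m {
                    out[i][j] += lhs.0[i][k] * rhs.0[k][j];
                }
            }
        }
        Tensor2(out)
    }
}

// Generic code dispatches through the backend, mirroring
// `Self::new(B::matmul(&self.value, &other.value))` in the diff above.
fn multiply<B: TensorOps>(lhs: &B::Primitive, rhs: &B::Primitive) -> B::Primitive {
    B::matmul(lhs, rhs)
}

fn main() {
    let a = Tensor2(vec![vec![1.0, 7.0], vec![2.0, 3.0]]);
    let b = Tensor2(vec![vec![4.0, 7.0], vec![2.0, 3.0]]);
    let c = multiply::<CpuBackend>(&a, &b);
    assert_eq!(c, Tensor2(vec![vec![18.0, 28.0], vec![14.0, 23.0]]));
}
```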
@@ -0,0 +1,75 @@
+use crate::tensor::TestADTensor;
+use burn_tensor::Data;
+
+#[test]
+fn should_diff_matmul() {
+    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
+    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+
+    let tensor_3 = &tensor_1.matmul(&tensor_2);
+    let grads = tensor_3.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
+    assert_eq!(grad_2.to_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
+    assert_eq!(
+        tensor_3.clone().into_data(),
+        Data::from([[18.0, 28.0], [14.0, 23.0]])
+    );
+}
+
+#[test]
+fn test_matmul_complex_1() {
+    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+    let data_3: Data<f32, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+    let tensor_3 = TestADTensor::from_data(data_3);
+
+    let tensor_4 = tensor_1.matmul(&tensor_2);
+    let tensor_5 = tensor_4.matmul(&tensor_3);
+
+    let grads = tensor_5.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(grad_1.to_data(), Data::from([[44.0, 20.0], [44.0, 20.0]]));
+    assert_eq!(grad_2.to_data(), Data::from([[56.0, 56.0], [16.0, 16.0]]));
+}
+
+#[test]
+fn test_matmul_complex_2() {
+    let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [13.0, -3.0]]);
+    let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+    let data_3: Data<f32, 2> = Data::from([[2.0, 2.0], [2.0, 2.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+    let tensor_3 = TestADTensor::from_data(data_3);
+
+    let tensor_4 = tensor_1.matmul(&tensor_2);
+    let tensor_5 = tensor_4.matmul(&tensor_3);
+    let tensor_6 = tensor_1.matmul(&tensor_5);
+
+    let grads = tensor_6.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    assert_eq!(
+        grad_1.to_data(),
+        Data::from([[800.0, 792.0], [360.0, 592.0]])
+    );
+    assert_eq!(
+        grad_2.to_data(),
+        Data::from([[264., 264.0], [344.0, 344.0]])
+    );
+}
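As a quick check of the values asserted in `test_matmul_complex_1` (again assuming `backward()` seeds the output gradient with ones, G = [[1, 1], [1, 1]]): the gradient flowing back into tensor_4 is G · data_3^T = [[4, 4], [4, 4]]; then grad_1 = [[4, 4], [4, 4]] · data_2^T = [[44, 20], [44, 20]] and grad_2 = data_1^T · [[4, 4], [4, 4]] = [[56, 56], [16, 16]], matching the assertions above.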
@@ -2,6 +2,7 @@ mod add;
 mod aggregation;
 mod cross_entropy;
 mod div;
+mod matmul;
 mod mul;
 mod softmax;
 mod sub;