mirror of https://github.com/tracel-ai/burn.git
feat: implement matmul diff
parent 5ce657ded9
commit 2084aced63
@@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
 use std::ops::{Add, Mul};
 
 pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug {
-    fn forward(&self, left: Lhs, right: Rhs) -> Out;
     fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs;
     fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs;
 }
@@ -80,13 +79,11 @@ where
     }
 
     fn backward(&mut self) {
-        let left = self.lhs.borrow().value();
-        let right = self.rhs.borrow().value();
-        let output = self.out.borrow().value();
-        let state = BinaryRecordedState::new(&left, &right, &output);
+        let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out);
+
         let partial_left = self.ops.partial_left(&state);
         let partial_right: Rhs = self.ops.partial_right(&state);
 
         let grad_mine = self.out.borrow_mut().grad();
 
         self.lhs
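Why the state now carries node references rather than plain values: a partial like matmul's needs more than the recorded operand values; it must also reach the output node's gradient, so it is simpler to hand each op the node references and let it borrow exactly what it needs. A minimal, self-contained sketch of the borrow pattern, assuming NodeRef is essentially an Rc<RefCell<...>> (the SimpleNode type below is illustrative, not the crate's API):

    use std::cell::RefCell;
    use std::rc::Rc;

    // Illustrative stand-in for the crate's NodeRef: shared, interiorly mutable.
    type NodeRef<T> = Rc<RefCell<SimpleNode<T>>>;

    struct SimpleNode<T> {
        value: T,
        grad: T,
    }

    fn main() {
        let out: NodeRef<f64> = Rc::new(RefCell::new(SimpleNode { value: 8.0, grad: 1.0 }));
        // A partial can borrow the recorded value on demand...
        let value = out.borrow().value;
        // ...or take a mutable borrow to read and update the gradient.
        out.borrow_mut().grad += value;
        assert_eq!(out.borrow().grad, 9.0);
    }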
@@ -1,10 +1,10 @@
-use crate::node::NodeId;
+use crate::node::{NodeId, NodeRef};
 
 #[derive(new)]
 pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> {
-    pub left: &'a Lhs,
-    pub right: &'a Rhs,
-    pub output: &'a Out,
+    pub left: &'a NodeRef<Lhs>,
+    pub right: &'a NodeRef<Rhs>,
+    pub output: &'a NodeRef<Out>,
 }
 
 #[derive(new)]
@@ -3,7 +3,6 @@ use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
 use std::ops::{Add, Mul};
 
 pub trait SingleOps<In, Out>: std::fmt::Debug {
-    fn forward(&self, input: In) -> Out;
     fn partial(&self, state: &SingleRecordedState<In, Out>) -> In;
 }
 
@@ -12,15 +12,13 @@ use num_traits::Float;
 register_ops!(
     ops BinaryOps<T, T, T>,
     name ADTensorAddOps,
-    forward |left, right| left * right,
-    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
-    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones(),
+    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().ones(),
+    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().ones(),
 );
 
 register_ops!(
     ops SingleOps<T, T>,
     name ADTensorAddScalarOps state P,
-    forward |state, input| input * state,
     partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
 );
 
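A note on the deleted lines above: the forward closure for ADTensorAddOps read left * right, and the scalar variant input * state, apparently leftovers copied from the mul implementation (the same pattern appears in the sub hunk further down). Since the forward result was produced eagerly elsewhere rather than through these closures, as the execute_ops! call in the new matmul file below suggests, they were dead code, and dropping the forward arm removes the inconsistency along with it.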
@@ -42,7 +42,6 @@ macro_rules! register_ops {
     (
         ops $ops:ty,
         name $name:ident,
-        forward $forward:expr,
        partial_left $partial_left:expr,
        partial_right $partial_right:expr,
    ) => {
@@ -55,10 +54,6 @@ macro_rules! register_ops {
             P: $crate::tensor::backend::autodiff::ADFloat,
             T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
         {
-            fn forward(&self, left: T, right: T) -> T {
-                $forward(left, right)
-            }
-
             fn partial_left(&self, state: &$crate::graph::ops::BinaryRecordedState<T, T, T>) -> T {
                 $partial_left(state)
             }
@@ -71,7 +66,6 @@ macro_rules! register_ops {
     (
         ops $ops:ty,
         name $name:ident state $ops_tensor_state:ident,
-        forward $forward:expr,
         partial $partial:expr,
     ) => {
         define_ops!(
@@ -84,10 +78,6 @@ macro_rules! register_ops {
             P: $crate::tensor::backend::autodiff::ADFloat,
             T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
         {
-            fn forward(&self, input: T) -> T {
-                $forward(self.state, input)
-            }
-
             fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T {
                 $partial(self.state, state)
             }
@@ -96,7 +86,6 @@ macro_rules! register_ops {
     (
         ops $ops:ty,
         name $name:ident,
-        forward $forward:expr,
         partial $partial:expr,
     ) => {
         define_ops!(
@@ -108,10 +97,6 @@ macro_rules! register_ops {
             P: $crate::tensor::backend::autodiff::ADFloat,
             T: $crate::tensor::backend::autodiff::ADFloatTensor<P, D>,
         {
-            fn forward(&self, input: T) -> T {
-                $forward(input)
-            }
-
             fn partial(&self, state: &$crate::graph::ops::SingleRecordedState<T, T>) -> T {
                 $partial(state)
             }
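Taken together, the macro hunks above delete every forward arm: after this commit an op object only knows how to differentiate, while the forward value is computed once, at recording time (the out TensorOpsMatmul::matmul(...) argument of the execute_ops! call in the new file below). A minimal, self-contained sketch of this record-eagerly, differentiate-lazily split; every name here is illustrative, not the crate's API:

    // The output is computed when the op is recorded, so the recorded op
    // itself only stores the inputs it needs to produce partials later.
    struct RecordedMul {
        lhs: f64,
        rhs: f64,
        out: f64,
    }

    impl RecordedMul {
        fn new(lhs: f64, rhs: f64) -> Self {
            let out = lhs * rhs; // eager forward pass, at record time
            Self { lhs, rhs, out }
        }

        fn partial_left(&self) -> f64 {
            self.rhs // d(l * r) / dl = r
        }

        fn partial_right(&self) -> f64 {
            self.lhs // d(l * r) / dr = l
        }
    }

    fn main() {
        let op = RecordedMul::new(2.0, 5.0);
        assert_eq!(op.out, 10.0);
        assert_eq!(op.partial_left(), 5.0);
        assert_eq!(op.partial_right(), 2.0);
    }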
@@ -0,0 +1,68 @@
+use crate::execute_ops;
+use crate::{
+    backend::autodiff::{ADFloat, ADFloatTensor, ADTensor},
+    ops::{BinaryOps, BinaryRecordedOps, BinaryRecordedState},
+    register_ops, TensorOpsMatmul,
+};
+use num_traits::Float;
+
+register_ops!(
+    ops BinaryOps<T, T, T>,
+    name ADTensorMatmulOps,
+    partial_left |state: &BinaryRecordedState<T, T, T>| {
+        let out_grad = state.output.borrow_mut().grad();
+        let rhs = state.right.borrow().value().transpose();
+        out_grad.matmul(&rhs)
+    },
+    partial_right |state: &BinaryRecordedState<T, T, T>| {
+        let out_grad = state.output.borrow_mut().grad();
+        let lhs = state.left.borrow().value().transpose();
+        lhs.matmul(&out_grad)
+    },
+);
+
+impl<T, P, const D: usize> TensorOpsMatmul<P, D> for ADTensor<P, D, T>
+where
+    T: ADFloatTensor<P, D>,
+    P: ADFloat,
+{
+    fn matmul(&self, other: &Self) -> Self {
+        let node = execute_ops!(
+            lhs self.node.clone(),
+            rhs other.node.clone(),
+            out TensorOpsMatmul::matmul(&self.tensor(), &other.tensor()),
+            tape self.tape.clone(),
+            ops ADTensorMatmulOps::new(),
+        );
+        self.from_existing(node)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
+
+    #[test]
+    fn should_diff_mul() {
+        let tape = Tape::new_ref();
+        let data_1: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
+        let data_2: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
+
+        let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
+        let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
+
+        let tensor_3 = &tensor_1.matmul(&tensor_2);
+        tensor_3.backprob();
+
+        let grad_1 = tensor_1.grad();
+        let grad_2 = tensor_2.grad();
+
+        assert_eq!(grad_1.into_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
+        assert_eq!(grad_2.into_data(), Data::from([[3.0, 3.0], [10.0, 10.0]]));
+        assert_eq!(
+            tensor_3.clone().into_data(),
+            Data::from([[18.0, 28.0], [14.0, 23.0]])
+        );
+    }
+}
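For reference, the rule the two partials implement: for C = A B with upstream gradient G = ∂L/∂C,

    \frac{\partial L}{\partial A} = G B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top} G

Assuming backprob seeds the output gradient with ones (the test does not show the seed), G of all ones gives G Bᵀ = [[11, 5], [11, 5]] and Aᵀ G = [[3, 3], [10, 10]], matching the asserted grad_1 and grad_2; the forward product [[18, 28], [14, 23]] also checks out by hand. The test name should_diff_mul looks like a copy-paste from the mul test, though what it exercises is matmul.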
@@ -1,4 +1,5 @@
 mod add;
+mod matmul;
 mod mul;
 mod sub;
 
@@ -12,15 +12,13 @@ use num_traits::Float;
 register_ops!(
     ops BinaryOps<T, T, T>,
     name ADTensorMulOps,
-    forward |left, right| left * right,
-    partial_left |state: &BinaryRecordedState<T, T, T>| state.right.clone(),
-    partial_right |state: &BinaryRecordedState<T, T, T>| state.left.clone(),
+    partial_left |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().clone(),
+    partial_right |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().clone(),
 );
 
 register_ops!(
     ops SingleOps<T, T>,
     name ADTensorMulScalarOps state P,
-    forward |state, input| input * state,
     partial |state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones() * state,
 );
 
@@ -12,15 +12,13 @@ use num_traits::Float;
 register_ops!(
     ops BinaryOps<T, T, T>,
     name ADTensorSubOps,
-    forward |left, right| left * right,
-    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
-    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones().neg(),
+    partial_left |state: &BinaryRecordedState<T, T, T>| state.left.borrow().value().ones(),
+    partial_right |state: &BinaryRecordedState<T, T, T>| state.right.borrow().value().ones().neg(),
 );
 
 register_ops!(
     ops SingleOps<T, T>,
     name ADTensorSubScalarOps state P,
-    forward |state, input| input * state,
     partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
 );
 
@@ -5,3 +5,4 @@ mod mul;
 mod neg;
 mod reshape;
 mod sub;
+mod transpose;
@@ -0,0 +1,15 @@
+use crate::{backend::tch::TchTensor, Shape, TensorOpsTranspose};
+
+impl<P: tch::kind::Element, const D: usize> TensorOpsTranspose<P, D> for TchTensor<P, D> {
+    fn transpose(&self) -> Self {
+        let tensor = self.tensor.transpose(-2, -1);
+        let kind = self.kind.clone();
+        let shape = Shape::from(tensor.size());
+
+        Self {
+            kind,
+            tensor,
+            shape,
+        }
+    }
+}
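The tch implementation transposes only the last two dimensions, which is exactly the matrix transpose the matmul partials above need; for D > 2 any leading (batch) dimensions are left alone:

    \left(X^{\top}\right)_{\ldots,\,i,\,j} = X_{\ldots,\,j,\,i}

Note also that the shape is re-derived from the transposed tensor's size rather than swapped by hand.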
@@ -13,6 +13,7 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
     + TensorOpsAdd<P, D>
     + TensorOpsSub<P, D>
     + TensorOpsMatmul<P, D>
+    + TensorOpsTranspose<P, D>
     + std::fmt::Debug
 {
 }
@@ -40,6 +41,10 @@ where
     fn sub_scalar(&self, other: &P) -> Self;
 }
 
+pub trait TensorOpsTranspose<P, const D: usize> {
+    fn transpose(&self) -> Self;
+}
+
 pub trait TensorOpsMatmul<P, const D: usize> {
     fn matmul(&self, other: &Self) -> Self;
 }