mirror of https://github.com/tracel-ai/burn.git
feat: support sub autograd
This commit is contained in:
parent
f241a6c114
commit
5ce657ded9
|
@ -82,8 +82,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_add() {
|
fn should_diff_add() {
|
||||||
let tape = Tape::new_ref();
|
let tape = Tape::new_ref();
|
||||||
let data_1 = Data::from([2.0]);
|
let data_1 = Data::from([2.0, 5.0]);
|
||||||
let data_2 = Data::from([4.0]);
|
let data_2 = Data::from([4.0, 1.0]);
|
||||||
|
|
||||||
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||||
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||||
|
@ -94,8 +94,22 @@ mod tests {
|
||||||
let grad_1 = tensor_1.grad();
|
let grad_1 = tensor_1.grad();
|
||||||
let grad_2 = tensor_2.grad();
|
let grad_2 = tensor_2.grad();
|
||||||
|
|
||||||
assert_eq!(grad_1.into_data(), Data::from([1.0]));
|
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||||
assert_eq!(grad_2.into_data(), Data::from([1.0]));
|
assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
|
||||||
assert_eq!(tensor_3.into_data(), Data::from([6.0]));
|
assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_add_scalar() {
|
||||||
|
let tape = Tape::new_ref();
|
||||||
|
let data = Data::from([2.0, 10.0]);
|
||||||
|
|
||||||
|
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||||
|
let tensor_out = tensor.clone() + 5.0;
|
||||||
|
tensor_out.backprob();
|
||||||
|
|
||||||
|
let grad = tensor.grad();
|
||||||
|
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||||
|
assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod add;
|
mod add;
|
||||||
mod mul;
|
mod mul;
|
||||||
|
mod sub;
|
||||||
|
|
||||||
mod macros;
|
mod macros;
|
||||||
pub use macros::*;
|
pub use macros::*;
|
||||||
|
|
|
@ -82,8 +82,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_mul() {
|
fn should_diff_mul() {
|
||||||
let tape = Tape::new_ref();
|
let tape = Tape::new_ref();
|
||||||
let data_1 = Data::from([1.0]);
|
let data_1 = Data::from([1.0, 7.0]);
|
||||||
let data_2 = Data::from([4.0]);
|
let data_2 = Data::from([4.0, 7.0]);
|
||||||
|
|
||||||
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||||
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||||
|
@ -96,20 +96,20 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(grad_1.into_data(), data_2);
|
assert_eq!(grad_1.into_data(), data_2);
|
||||||
assert_eq!(grad_2.into_data(), data_1);
|
assert_eq!(grad_2.into_data(), data_1);
|
||||||
assert_eq!(tensor_3.into_data(), Data::from([4.0]));
|
assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_diff_mul_scalar() {
|
fn should_diff_mul_scalar() {
|
||||||
let tape = Tape::new_ref();
|
let tape = Tape::new_ref();
|
||||||
let data = Data::from([2.0]);
|
let data = Data::from([2.0, 5.0]);
|
||||||
|
|
||||||
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||||
let tensor_out = tensor.clone() * 4.0;
|
let tensor_out = tensor.clone() * 4.0;
|
||||||
tensor_out.backprob();
|
tensor_out.backprob();
|
||||||
|
|
||||||
let grad = tensor.grad();
|
let grad = tensor.grad();
|
||||||
assert_eq!(tensor_out.into_data(), Data::from([8.0]));
|
assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
|
||||||
assert_eq!(grad.into_data(), Data::from([4.0]));
|
assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
use crate::{
|
||||||
|
backend::autodiff::{ADFloat, ADFloatTensor, ADTensor},
|
||||||
|
define_ops, execute_ops,
|
||||||
|
ops::{
|
||||||
|
BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps,
|
||||||
|
SingleRecordedState,
|
||||||
|
},
|
||||||
|
register_ops, TensorOpsSub,
|
||||||
|
};
|
||||||
|
use num_traits::Float;
|
||||||
|
|
||||||
|
register_ops!(
|
||||||
|
ops BinaryOps<T, T, T>,
|
||||||
|
name ADTensorSubOps,
|
||||||
|
forward |left, right| left * right,
|
||||||
|
partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
|
||||||
|
partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones().neg(),
|
||||||
|
);
|
||||||
|
|
||||||
|
register_ops!(
|
||||||
|
ops SingleOps<T, T>,
|
||||||
|
name ADTensorSubScalarOps state P,
|
||||||
|
forward |state, input| input * state,
|
||||||
|
partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
|
||||||
|
);
|
||||||
|
|
||||||
|
impl<T, P, const D: usize> TensorOpsSub<P, D> for ADTensor<P, D, T>
|
||||||
|
where
|
||||||
|
T: ADFloatTensor<P, D>,
|
||||||
|
P: ADFloat,
|
||||||
|
{
|
||||||
|
fn sub(&self, other: &Self) -> Self {
|
||||||
|
let node = execute_ops!(
|
||||||
|
lhs self.node.clone(),
|
||||||
|
rhs other.node.clone(),
|
||||||
|
out TensorOpsSub::sub(&self.tensor(), &other.tensor()),
|
||||||
|
tape self.tape.clone(),
|
||||||
|
ops ADTensorSubOps::new(),
|
||||||
|
);
|
||||||
|
self.from_existing(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sub_scalar(&self, other: &P) -> Self {
|
||||||
|
let node = execute_ops!(
|
||||||
|
input self.node.clone(),
|
||||||
|
out TensorOpsSub::sub_scalar(&self.tensor(), &other),
|
||||||
|
tape self.tape.clone(),
|
||||||
|
ops ADTensorSubScalarOps::new(other.clone()),
|
||||||
|
);
|
||||||
|
self.from_existing(node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, P, const D: usize> std::ops::Sub<P> for ADTensor<P, D, T>
|
||||||
|
where
|
||||||
|
T: ADFloatTensor<P, D> + 'static,
|
||||||
|
P: ADFloat + 'static,
|
||||||
|
{
|
||||||
|
type Output = ADTensor<P, D, T>;
|
||||||
|
|
||||||
|
fn sub(self, rhs: P) -> Self::Output {
|
||||||
|
TensorOpsSub::sub_scalar(&self, &rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, P, const D: usize> std::ops::Sub<ADTensor<P, D, T>> for ADTensor<P, D, T>
|
||||||
|
where
|
||||||
|
T: ADFloatTensor<P, D> + 'static,
|
||||||
|
P: ADFloat + 'static,
|
||||||
|
{
|
||||||
|
type Output = ADTensor<P, D, T>;
|
||||||
|
|
||||||
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
|
TensorOpsSub::sub(&self, &rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_sub() {
|
||||||
|
let tape = Tape::new_ref();
|
||||||
|
let data_1 = Data::from([2.0, 5.0]);
|
||||||
|
let data_2 = Data::from([4.0, 1.0]);
|
||||||
|
|
||||||
|
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||||
|
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||||
|
|
||||||
|
let tensor_3 = tensor_1.clone() - tensor_2.clone();
|
||||||
|
tensor_3.backprob();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad();
|
||||||
|
let grad_2 = tensor_2.grad();
|
||||||
|
|
||||||
|
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||||
|
assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
|
||||||
|
assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_sub_scalar() {
|
||||||
|
let tape = Tape::new_ref();
|
||||||
|
let data = Data::from([2.0, 10.0]);
|
||||||
|
|
||||||
|
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||||
|
let tensor_out = tensor.clone() - 5.0;
|
||||||
|
tensor_out.backprob();
|
||||||
|
|
||||||
|
let grad = tensor.grad();
|
||||||
|
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||||
|
assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,3 +4,4 @@ mod matmul;
|
||||||
mod mul;
|
mod mul;
|
||||||
mod neg;
|
mod neg;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
|
mod sub;
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
use crate::{backend::tch::TchTensor, Data, TensorOpsSub};
|
||||||
|
use std::ops::Sub;
|
||||||
|
|
||||||
|
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorOpsSub<P, D>
|
||||||
|
for TchTensor<P, D>
|
||||||
|
{
|
||||||
|
fn sub(&self, other: &Self) -> Self {
|
||||||
|
let tensor = (&self.tensor).sub(&other.tensor);
|
||||||
|
let kind = self.kind.clone();
|
||||||
|
let shape = self.shape.clone();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
tensor,
|
||||||
|
shape,
|
||||||
|
kind,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn sub_scalar(&self, other: &P) -> Self {
|
||||||
|
let elems: [P; D] = [*other; D];
|
||||||
|
let data = Data::from(elems);
|
||||||
|
let other = TchTensor::from_data(data, self.tensor.device());
|
||||||
|
let tensor = (&self.tensor).sub(&other.tensor);
|
||||||
|
let kind = self.kind.clone();
|
||||||
|
let shape = self.shape.clone();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
tensor,
|
||||||
|
shape,
|
||||||
|
kind,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<Self>
|
||||||
|
for TchTensor<P, D>
|
||||||
|
{
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
|
TensorOpsSub::sub(&self, &rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<P>
|
||||||
|
for TchTensor<P, D>
|
||||||
|
{
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn sub(self, rhs: P) -> Self::Output {
|
||||||
|
TensorOpsSub::sub_scalar(&self, &rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::TensorBase;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_sub_ops() {
|
||||||
|
let data_1 = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
|
let data_2 = Data::<f64, 2>::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
|
||||||
|
let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);
|
||||||
|
let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu);
|
||||||
|
let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu);
|
||||||
|
|
||||||
|
let data_actual = (tensor_1 - tensor_2).into_data();
|
||||||
|
|
||||||
|
assert_eq!(data_expected, data_actual);
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
|
||||||
+ TensorOpsMul<P, D>
|
+ TensorOpsMul<P, D>
|
||||||
+ TensorOpsNeg<P, D>
|
+ TensorOpsNeg<P, D>
|
||||||
+ TensorOpsAdd<P, D>
|
+ TensorOpsAdd<P, D>
|
||||||
|
+ TensorOpsSub<P, D>
|
||||||
+ TensorOpsMatmul<P, D>
|
+ TensorOpsMatmul<P, D>
|
||||||
+ std::fmt::Debug
|
+ std::fmt::Debug
|
||||||
{
|
{
|
||||||
|
@ -30,11 +31,20 @@ where
|
||||||
fn add_scalar(&self, other: &P) -> Self;
|
fn add_scalar(&self, other: &P) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait TensorOpsSub<P, const D: usize>:
|
||||||
|
std::ops::Sub<Self, Output = Self> + std::ops::Sub<P, Output = Self>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
fn sub(&self, other: &Self) -> Self;
|
||||||
|
fn sub_scalar(&self, other: &P) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait TensorOpsMatmul<P, const D: usize> {
|
pub trait TensorOpsMatmul<P, const D: usize> {
|
||||||
fn matmul(&self, other: &Self) -> Self;
|
fn matmul(&self, other: &Self) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg {
|
pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg<Output = Self> {
|
||||||
fn neg(&self) -> Self;
|
fn neg(&self) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue