mirror of https://github.com/tracel-ai/burn.git
feat: support sub autograd
This commit is contained in:
parent
f241a6c114
commit
5ce657ded9
|
@ -82,8 +82,8 @@ mod tests {
|
|||
#[test]
|
||||
fn should_diff_add() {
|
||||
let tape = Tape::new_ref();
|
||||
let data_1 = Data::from([2.0]);
|
||||
let data_2 = Data::from([4.0]);
|
||||
let data_1 = Data::from([2.0, 5.0]);
|
||||
let data_2 = Data::from([4.0, 1.0]);
|
||||
|
||||
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||
|
@ -94,8 +94,22 @@ mod tests {
|
|||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([1.0]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([1.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([6.0]));
|
||||
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([6.0, 6.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_add_scalar() {
|
||||
let tape = Tape::new_ref();
|
||||
let data = Data::from([2.0, 10.0]);
|
||||
|
||||
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||
let tensor_out = tensor.clone() + 5.0;
|
||||
tensor_out.backprob();
|
||||
|
||||
let grad = tensor.grad();
|
||||
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_out.into_data(), Data::from([7.0, 15.0]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
mod add;
|
||||
mod mul;
|
||||
mod sub;
|
||||
|
||||
mod macros;
|
||||
pub use macros::*;
|
||||
|
|
|
@ -82,8 +82,8 @@ mod tests {
|
|||
#[test]
|
||||
fn should_diff_mul() {
|
||||
let tape = Tape::new_ref();
|
||||
let data_1 = Data::from([1.0]);
|
||||
let data_2 = Data::from([4.0]);
|
||||
let data_1 = Data::from([1.0, 7.0]);
|
||||
let data_2 = Data::from([4.0, 7.0]);
|
||||
|
||||
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||
|
@ -96,20 +96,20 @@ mod tests {
|
|||
|
||||
assert_eq!(grad_1.into_data(), data_2);
|
||||
assert_eq!(grad_2.into_data(), data_1);
|
||||
assert_eq!(tensor_3.into_data(), Data::from([4.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([4.0, 49.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_mul_scalar() {
|
||||
let tape = Tape::new_ref();
|
||||
let data = Data::from([2.0]);
|
||||
let data = Data::from([2.0, 5.0]);
|
||||
|
||||
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||
let tensor_out = tensor.clone() * 4.0;
|
||||
tensor_out.backprob();
|
||||
|
||||
let grad = tensor.grad();
|
||||
assert_eq!(tensor_out.into_data(), Data::from([8.0]));
|
||||
assert_eq!(grad.into_data(), Data::from([4.0]));
|
||||
assert_eq!(tensor_out.into_data(), Data::from([8.0, 20.0]));
|
||||
assert_eq!(grad.into_data(), Data::from([4.0, 4.0]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
use crate::{
|
||||
backend::autodiff::{ADFloat, ADFloatTensor, ADTensor},
|
||||
define_ops, execute_ops,
|
||||
ops::{
|
||||
BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps,
|
||||
SingleRecordedState,
|
||||
},
|
||||
register_ops, TensorOpsSub,
|
||||
};
|
||||
use num_traits::Float;
|
||||
|
||||
register_ops!(
|
||||
ops BinaryOps<T, T, T>,
|
||||
name ADTensorSubOps,
|
||||
forward |left, right| left * right,
|
||||
partial_left |state: &BinaryRecordedState<T, T, T>| state.left.ones(),
|
||||
partial_right |state: &BinaryRecordedState<T, T, T>| state.right.ones().neg(),
|
||||
);
|
||||
|
||||
register_ops!(
|
||||
ops SingleOps<T, T>,
|
||||
name ADTensorSubScalarOps state P,
|
||||
forward |state, input| input * state,
|
||||
partial |_state, state_recorded: &SingleRecordedState<T, T>| state_recorded.input.ones(),
|
||||
);
|
||||
|
||||
impl<T, P, const D: usize> TensorOpsSub<P, D> for ADTensor<P, D, T>
|
||||
where
|
||||
T: ADFloatTensor<P, D>,
|
||||
P: ADFloat,
|
||||
{
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let node = execute_ops!(
|
||||
lhs self.node.clone(),
|
||||
rhs other.node.clone(),
|
||||
out TensorOpsSub::sub(&self.tensor(), &other.tensor()),
|
||||
tape self.tape.clone(),
|
||||
ops ADTensorSubOps::new(),
|
||||
);
|
||||
self.from_existing(node)
|
||||
}
|
||||
|
||||
fn sub_scalar(&self, other: &P) -> Self {
|
||||
let node = execute_ops!(
|
||||
input self.node.clone(),
|
||||
out TensorOpsSub::sub_scalar(&self.tensor(), &other),
|
||||
tape self.tape.clone(),
|
||||
ops ADTensorSubScalarOps::new(other.clone()),
|
||||
);
|
||||
self.from_existing(node)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> std::ops::Sub<P> for ADTensor<P, D, T>
|
||||
where
|
||||
T: ADFloatTensor<P, D> + 'static,
|
||||
P: ADFloat + 'static,
|
||||
{
|
||||
type Output = ADTensor<P, D, T>;
|
||||
|
||||
fn sub(self, rhs: P) -> Self::Output {
|
||||
TensorOpsSub::sub_scalar(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> std::ops::Sub<ADTensor<P, D, T>> for ADTensor<P, D, T>
|
||||
where
|
||||
T: ADFloatTensor<P, D> + 'static,
|
||||
P: ADFloat + 'static,
|
||||
{
|
||||
type Output = ADTensor<P, D, T>;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
TensorOpsSub::sub(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{backend::autodiff::helper::ADTchTensor, tape::Tape, Data, TensorBase};
|
||||
|
||||
#[test]
|
||||
fn should_diff_sub() {
|
||||
let tape = Tape::new_ref();
|
||||
let data_1 = Data::from([2.0, 5.0]);
|
||||
let data_2 = Data::from([4.0, 1.0]);
|
||||
|
||||
let tensor_1 = ADTchTensor::from_data(data_1.clone(), tape.clone());
|
||||
let tensor_2 = ADTchTensor::from_data(data_2.clone(), tape.clone());
|
||||
|
||||
let tensor_3 = tensor_1.clone() - tensor_2.clone();
|
||||
tensor_3.backprob();
|
||||
|
||||
let grad_1 = tensor_1.grad();
|
||||
let grad_2 = tensor_2.grad();
|
||||
|
||||
assert_eq!(grad_1.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(grad_2.into_data(), Data::from([-1.0, -1.0]));
|
||||
assert_eq!(tensor_3.into_data(), Data::from([-2.0, 4.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_sub_scalar() {
|
||||
let tape = Tape::new_ref();
|
||||
let data = Data::from([2.0, 10.0]);
|
||||
|
||||
let tensor = ADTchTensor::from_data(data.clone(), tape.clone());
|
||||
let tensor_out = tensor.clone() - 5.0;
|
||||
tensor_out.backprob();
|
||||
|
||||
let grad = tensor.grad();
|
||||
assert_eq!(grad.into_data(), Data::from([1.0, 1.0]));
|
||||
assert_eq!(tensor_out.into_data(), Data::from([-3.0, 5.0]));
|
||||
}
|
||||
}
|
|
@ -4,3 +4,4 @@ mod matmul;
|
|||
mod mul;
|
||||
mod neg;
|
||||
mod reshape;
|
||||
mod sub;
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
use crate::{backend::tch::TchTensor, Data, TensorOpsSub};
|
||||
use std::ops::Sub;
|
||||
|
||||
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorOpsSub<P, D>
|
||||
for TchTensor<P, D>
|
||||
{
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let tensor = (&self.tensor).sub(&other.tensor);
|
||||
let kind = self.kind.clone();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self {
|
||||
tensor,
|
||||
shape,
|
||||
kind,
|
||||
}
|
||||
}
|
||||
fn sub_scalar(&self, other: &P) -> Self {
|
||||
let elems: [P; D] = [*other; D];
|
||||
let data = Data::from(elems);
|
||||
let other = TchTensor::from_data(data, self.tensor.device());
|
||||
let tensor = (&self.tensor).sub(&other.tensor);
|
||||
let kind = self.kind.clone();
|
||||
let shape = self.shape.clone();
|
||||
|
||||
Self {
|
||||
tensor,
|
||||
shape,
|
||||
kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<Self>
|
||||
for TchTensor<P, D>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
TensorOpsSub::sub(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: tch::kind::Element + Default + std::fmt::Debug + Copy, const D: usize> std::ops::Sub<P>
|
||||
for TchTensor<P, D>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: P) -> Self::Output {
|
||||
TensorOpsSub::sub_scalar(&self, &rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TensorBase;
|
||||
|
||||
#[test]
|
||||
fn should_support_sub_ops() {
|
||||
let data_1 = Data::<f64, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_2 = Data::<f64, 2>::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);
|
||||
let data_expected = Data::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);
|
||||
let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu);
|
||||
let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu);
|
||||
|
||||
let data_actual = (tensor_1 - tensor_2).into_data();
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
|
|||
+ TensorOpsMul<P, D>
|
||||
+ TensorOpsNeg<P, D>
|
||||
+ TensorOpsAdd<P, D>
|
||||
+ TensorOpsSub<P, D>
|
||||
+ TensorOpsMatmul<P, D>
|
||||
+ std::fmt::Debug
|
||||
{
|
||||
|
@ -30,11 +31,20 @@ where
|
|||
fn add_scalar(&self, other: &P) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsSub<P, const D: usize>:
|
||||
std::ops::Sub<Self, Output = Self> + std::ops::Sub<P, Output = Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn sub(&self, other: &Self) -> Self;
|
||||
fn sub_scalar(&self, other: &P) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsMatmul<P, const D: usize> {
|
||||
fn matmul(&self, other: &Self) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg {
|
||||
pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg<Output = Self> {
|
||||
fn neg(&self) -> Self;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue