Add cos, sin and tanh operations (#155)

* Add cos, sin and tanh operations

* Add tests

* Fix formatting
Makro 2023-01-24 16:40:30 -08:00 committed by GitHub
parent b7aa066e5c
commit f6f0d0e4f3
14 changed files with 286 additions and 1 deletion

View File

@@ -1088,6 +1088,85 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn cos<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
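// d/dx cos(x) = -sin(x): scale the upstream gradient by -sin(input).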
let value = B::neg(&B::sin(&state.input.value()));
B::mul(&state.output.grad(), &value)
}
}
let output = B::cos(tensor.tensor_ref());
let ops = Backward::<B, D>::new(B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn sin<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
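// d/dx sin(x) = cos(x): scale the upstream gradient by cos(input).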
let value = B::cos(&state.input.value());
B::mul(&state.output.grad(), &value)
}
}
let output = B::sin(tensor.tensor_ref());
let ops = Backward::<B, D>::new(B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn tanh<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
_b: B,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
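// d/dx tanh(x) = 1 - tanh(x)^2, evaluated here from the forward output.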
let value =
B::add_scalar(&B::neg(&B::powf(&state.output.value(), 2.0)), &1.to_elem());
B::mul(&state.output.grad(), &value)
}
}
let output = B::tanh(tensor.tensor_ref());
let ops = Backward::<B, D>::new(B::default());
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn erf<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
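
Each `partial` in the hunk above computes the elementwise chain rule: grad_input = grad_output * f'(input). Note that cos and sin evaluate the derivative from `state.input`, while tanh evaluates it from `state.output`: since 1 - tanh(x)^2 is expressible in terms of the forward result, the backward pass never has to recompute tanh.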

View File

@@ -0,0 +1,29 @@
#[burn_tensor_testgen::testgen(ad_cos)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_cos() {
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.cos());
let tensor_4 = tensor_3.matmul(&tensor_2);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), 3);
grad_2.to_data().assert_approx_eq(
&Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
3,
);
}
}

View File

@@ -5,6 +5,7 @@ mod cat;
mod complex;
mod conv1d;
mod conv2d;
mod cos;
mod cross_entropy;
mod div;
mod erf;
@@ -20,9 +21,11 @@ mod neg;
mod pow;
mod relu;
mod reshape;
mod sin;
mod softmax;
mod sqrt;
mod sub;
mod tanh;
mod transpose;
#[macro_export]
@@ -44,6 +47,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_add!();
burn_autodiff::testgen_ad_aggregation!();
burn_autodiff::testgen_ad_cat!();
burn_autodiff::testgen_ad_cos!();
burn_autodiff::testgen_ad_cross_entropy_loss!();
burn_autodiff::testgen_ad_div!();
burn_autodiff::testgen_ad_erf!();
@@ -55,11 +59,13 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_relu!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_transpose!();
};
}

View File

@@ -0,0 +1,29 @@
#[burn_tensor_testgen::testgen(ad_sin)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_sin() {
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.sin());
let tensor_4 = tensor_3.matmul(&tensor_2);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 3);
grad_2.to_data().assert_approx_eq(
&Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]),
3,
);
}
}

View File

@@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_tanh)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_tanh() {
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.tanh());
let tensor_4 = tensor_3.matmul(&tensor_2);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3);
}
}
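
The constant 32s in grad_1 are expected: every entry of data_2 is at least 6, where tanh saturates to roughly 1, so tensor_2.tanh() is close to a matrix of ones. tanh(tensor_2).matmul(tensor_2) is then approximately [[15.0, 17.0], [15.0, 17.0]], and each entry of grad_1 reduces to the row sum 15 + 17 = 32.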

View File

@@ -537,6 +537,36 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
NdArrayTensor { array, shape }
}
fn cos<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
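// Widen each element to f64, apply the std cos, then narrow back to E via to_elem();
// sin and tanh below follow the same mapv pattern.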
let array = tensor
.array
.mapv(|a| a.to_f64().unwrap().cos().to_elem())
.into_shared();
let shape = tensor.shape;
NdArrayTensor { array, shape }
}
fn sin<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv(|a| a.to_f64().unwrap().sin().to_elem())
.into_shared();
let shape = tensor.shape;
NdArrayTensor { array, shape }
}
fn tanh<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv(|a| a.to_f64().unwrap().tanh().to_elem())
.into_shared();
let shape = tensor.shape;
NdArrayTensor { array, shape }
}
fn erf<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array

View File

@@ -451,6 +451,18 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
to_tensor(tensor.tensor.sqrt())
}
fn cos<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.cos())
}
fn sin<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.sin())
}
fn tanh<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.tanh())
}
fn erf<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.erf())
}
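
The tch backend needs no hand-written kernels here: `tch::Tensor` already exposes `cos`, `sin`, and `tanh` backed by LibTorch, so each op is a one-line delegation wrapped by `to_tensor`.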

View File

@@ -105,6 +105,21 @@ where
Self::new(B::sqrt(&self.value))
}
/// Applies element wise cosine operation.
pub fn cos(&self) -> Self {
Self::new(B::cos(&self.value))
}
/// Applies element wise sine operation.
pub fn sin(&self) -> Self {
Self::new(B::sin(&self.value))
}
/// Applies element wise hyperbolic tangent operation.
pub fn tanh(&self) -> Self {
Self::new(B::tanh(&self.value))
}
/// Returns the shape of the current tensor.
pub fn shape(&self) -> &Shape<D> {
B::shape(&self.value)
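
For orientation, a minimal usage sketch of the new methods (using the `TestBackend` alias from the test suites below):

let tensor = Tensor::<TestBackend, 2>::from_data(Data::from([[0.0, 1.0], [2.0, 3.0]]));
let cos = tensor.cos(); // elementwise cosine
let sin = tensor.sin(); // elementwise sine
let tanh = tensor.tanh(); // elementwise hyperbolic tangent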

View File

@@ -215,6 +215,9 @@ pub trait TensorOps<B: Backend> {
fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
fn sqrt<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn cos<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn sin<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn tanh<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn cat<const D: usize>(tensors: &[B::TensorPrimitive<D>], dim: usize) -> B::TensorPrimitive<D>;
fn relu<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
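
All three are required trait methods with no default body, so every backend must implement them; the ndarray, tch, and autodiff hunks above are exactly those implementations.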

View File

@@ -21,6 +21,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_div!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
@@ -33,6 +34,8 @@ macro_rules! testgen_all {
burn_tensor::testgen_powf!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_transpose!();

View File

@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(cos)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_cos_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.cos().into_data();
let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}

View File

@@ -1,6 +1,7 @@ mod add;
mod add;
mod aggregation;
mod arg;
mod cos;
mod div;
mod erf;
mod exp;
@@ -13,5 +14,7 @@ mod neg;
mod powf;
mod repeat;
mod reshape;
mod sin;
mod sub;
mod tanh;
mod transpose;

View File

@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(sin)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_sin_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.sin().into_data();
let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}

View File

@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(tanh)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_tanh_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.tanh().into_data();
let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}