mirror of https://github.com/tracel-ai/burn.git
Add cos, sin and tanh operations (#155)
* Add cos, sin and tanh operations
* Add tests
* Fix formatting

parent b7aa066e5c
commit f6f0d0e4f3
@@ -1088,6 +1088,85 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }

    fn cos<const D: usize>(
        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(new, Debug)]
        struct Backward<B: Backend, const D: usize> {
            _b: B,
        }

        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
            for Backward<B, D>
        {
            fn partial(
                &self,
                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
            ) -> B::TensorPrimitive<D> {
                let value = B::neg(&B::sin(&state.input.value()));
                B::mul(&state.output.grad(), &value)
            }
        }

        let output = B::cos(tensor.tensor_ref());
        let ops = Backward::<B, D>::new(B::default());

        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }

    fn sin<const D: usize>(
        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(new, Debug)]
        struct Backward<B: Backend, const D: usize> {
            _b: B,
        }

        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
            for Backward<B, D>
        {
            fn partial(
                &self,
                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
            ) -> B::TensorPrimitive<D> {
                let value = B::cos(&state.input.value());
                B::mul(&state.output.grad(), &value)
            }
        }

        let output = B::sin(tensor.tensor_ref());
        let ops = Backward::<B, D>::new(B::default());

        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }

    fn tanh<const D: usize>(
        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        #[derive(new, Debug)]
        struct Backward<B: Backend, const D: usize> {
            _b: B,
        }

        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
            for Backward<B, D>
        {
            fn partial(
                &self,
                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
            ) -> B::TensorPrimitive<D> {
                let value =
                    B::add_scalar(&B::neg(&B::powf(&state.output.value(), 2.0)), &1.to_elem());
                B::mul(&state.output.grad(), &value)
            }
        }

        let output = B::tanh(tensor.tensor_ref());
        let ops = Backward::<B, D>::new(B::default());

        unary_ops_wrapper(tensor.node.clone(), output, ops)
    }

    fn erf<const D: usize>(
        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
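For reference, each `partial` above multiplies the incoming output gradient by the local derivative of the new operation (chain rule). The tanh case reuses the forward output, since its derivative can be written in terms of tanh itself:

\frac{d}{dx}\cos x = -\sin x, \qquad
\frac{d}{dx}\sin x = \cos x, \qquad
\frac{d}{dx}\tanh x = 1 - \tanh^2 x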
@@ -0,0 +1,29 @@
#[burn_tensor_testgen::testgen(ad_cos)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_cos() {
        let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
        let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);

        let tensor_3 = tensor_1.matmul(&tensor_2.cos());
        let tensor_4 = tensor_3.matmul(&tensor_2);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[26.8063, -27.7870], [26.8063, -27.7870]]), 3);
        grad_2.to_data().assert_approx_eq(
            &Data::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),
            3,
        );
    }
}
@@ -5,6 +5,7 @@ mod cat;
mod complex;
mod conv1d;
mod conv2d;
mod cos;
mod cross_entropy;
mod div;
mod erf;
@@ -20,9 +21,11 @@ mod neg;
mod pow;
mod relu;
mod reshape;
mod sin;
mod softmax;
mod sqrt;
mod sub;
mod tanh;
mod transpose;

#[macro_export]
@@ -44,6 +47,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_add!();
        burn_autodiff::testgen_ad_aggregation!();
        burn_autodiff::testgen_ad_cat!();
        burn_autodiff::testgen_ad_cos!();
        burn_autodiff::testgen_ad_cross_entropy_loss!();
        burn_autodiff::testgen_ad_div!();
        burn_autodiff::testgen_ad_erf!();
@@ -55,11 +59,13 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_mul!();
        burn_autodiff::testgen_ad_neg!();
        burn_autodiff::testgen_ad_powf!();
        burn_autodiff::testgen_ad_sqrt!();
        burn_autodiff::testgen_ad_relu!();
        burn_autodiff::testgen_ad_reshape!();
        burn_autodiff::testgen_ad_sin!();
        burn_autodiff::testgen_ad_softmax!();
        burn_autodiff::testgen_ad_sqrt!();
        burn_autodiff::testgen_ad_sub!();
        burn_autodiff::testgen_ad_tanh!();
        burn_autodiff::testgen_ad_transpose!();
    };
}
@@ -0,0 +1,29 @@
#[burn_tensor_testgen::testgen(ad_sin)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_sin() {
        let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
        let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);

        let tensor_3 = tensor_1.matmul(&tensor_2.sin());
        let tensor_4 = tensor_3.matmul(&tensor_2);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[8.8500, -4.9790], [8.8500, -4.9790]]), 3);
        grad_2.to_data().assert_approx_eq(
            &Data::from([[38.668987, 44.194775], [-59.97261, -80.46094]]),
            3,
        );
    }
}
@@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_tanh)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_tanh() {
        let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
        let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

        let tensor_1 = TestADTensor::from_data(data_1);
        let tensor_2 = TestADTensor::from_data(data_2);

        let tensor_3 = tensor_1.matmul(&tensor_2.tanh());
        let tensor_4 = tensor_3.matmul(&tensor_2);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[32.0, 32.0], [32.0, 32.0]]), 3);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[8.00092, 8.000153], [8.000003, 7.999995]]), 3);
    }
}
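As a sanity check on the expected `grad_1` in the tanh test above (assuming `backward()` seeds the output gradient with a tensor of ones, which is what the expected numbers suggest): with T_4 = T_1 tanh(T_2) T_2 and seed G = 1_{2x2}, the matmul backward rule gives

\frac{\partial}{\partial T_1} = G\,T_2^\top \tanh(T_2)^\top
\approx \begin{pmatrix} 13 & 19 \\ 13 & 19 \end{pmatrix}
\begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
= \begin{pmatrix} 32 & 32 \\ 32 & 32 \end{pmatrix},

since every entry of T_2 is at least 6 and tanh saturates to roughly 1 there, matching the expected [[32.0, 32.0], [32.0, 32.0]].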
@@ -537,6 +537,36 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
        NdArrayTensor { array, shape }
    }

    fn cos<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor
            .array
            .mapv(|a| a.to_f64().unwrap().cos().to_elem())
            .into_shared();
        let shape = tensor.shape;

        NdArrayTensor { array, shape }
    }

    fn sin<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor
            .array
            .mapv(|a| a.to_f64().unwrap().sin().to_elem())
            .into_shared();
        let shape = tensor.shape;

        NdArrayTensor { array, shape }
    }

    fn tanh<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor
            .array
            .mapv(|a| a.to_f64().unwrap().tanh().to_elem())
            .into_shared();
        let shape = tensor.shape;

        NdArrayTensor { array, shape }
    }

    fn erf<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let array = tensor
            .array
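The three ndarray implementations above share one idiom: `mapv` lifts a scalar function to the whole array, computing in `f64` via `to_f64()` and converting back with `to_elem()`. A minimal standalone sketch of that `mapv` idiom, using only the `ndarray` crate and plain `f64` elements (burn's element-conversion traits are omitted here):

use ndarray::array;

fn main() {
    // Build a small 2x2 array and apply cosine elementwise, mirroring the
    // `tensor.array.mapv(|a| ...)` calls in the backend code above.
    let a = array![[0.0_f64, 1.0], [2.0, 3.0]];
    let cos_a = a.mapv(f64::cos);
    println!("{cos_a}");
}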
@@ -451,6 +451,18 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
        to_tensor(tensor.tensor.sqrt())
    }

    fn cos<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.cos())
    }

    fn sin<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.sin())
    }

    fn tanh<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.tanh())
    }

    fn erf<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
        to_tensor(tensor.tensor.erf())
    }
@@ -105,6 +105,21 @@ where
        Self::new(B::sqrt(&self.value))
    }

    /// Applies element wise cosine operation.
    pub fn cos(&self) -> Self {
        Self::new(B::cos(&self.value))
    }

    /// Applies element wise sine operation.
    pub fn sin(&self) -> Self {
        Self::new(B::sin(&self.value))
    }

    /// Applies element wise hyperbolic tangent operation.
    pub fn tanh(&self) -> Self {
        Self::new(B::tanh(&self.value))
    }

    /// Returns the shape of the current tensor.
    pub fn shape(&self) -> &Shape<D> {
        B::shape(&self.value)
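With the backend implementations in place, the methods above expose the new operations on the high-level `Tensor` type. A minimal usage sketch in the style of this commit's generated test modules; `TestBackend` is assumed to be the backend alias injected by the surrounding testgen macro and is not defined here:

use burn_tensor::{Data, Tensor};

fn trig_demo() {
    // `from_data` and `into_data` are the same helpers used by the unit tests below.
    let tensor = Tensor::<TestBackend, 2>::from_data(Data::from([[0.0, 1.0], [2.0, 3.0]]));

    // Each method takes `&self` and returns a new tensor.
    let cos_data = tensor.cos().into_data();
    let sin_data = tensor.sin().into_data();
    let tanh_data = tensor.tanh().into_data();

    // Expected values taken from the unit tests added in this commit.
    cos_data.assert_approx_eq(&Data::from([[1.0, 0.5403], [-0.4161, -0.9899]]), 3);
    sin_data.assert_approx_eq(&Data::from([[0.0, 0.8414], [0.9092, 0.1411]]), 3);
    tanh_data.assert_approx_eq(&Data::from([[0.0, 0.7615], [0.9640, 0.9950]]), 3);
}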
@@ -215,6 +215,9 @@ pub trait TensorOps<B: Backend> {
    fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
    fn sqrt<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn cos<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn sin<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn tanh<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
    fn cat<const D: usize>(tensors: &[B::TensorPrimitive<D>], dim: usize) -> B::TensorPrimitive<D>;
    fn relu<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
@@ -21,6 +21,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_add!();
        burn_tensor::testgen_aggregation!();
        burn_tensor::testgen_arg!();
        burn_tensor::testgen_cos!();
        burn_tensor::testgen_div!();
        burn_tensor::testgen_erf!();
        burn_tensor::testgen_exp!();
@@ -33,6 +34,8 @@ macro_rules! testgen_all {
        burn_tensor::testgen_powf!();
        burn_tensor::testgen_repeat!();
        burn_tensor::testgen_reshape!();
        burn_tensor::testgen_sin!();
        burn_tensor::testgen_tanh!();
        burn_tensor::testgen_sub!();
        burn_tensor::testgen_transpose!();
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(cos)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_support_cos_ops() {
        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let tensor = Tensor::<TestBackend, 2>::from_data(data);

        let data_actual = tensor.cos().into_data();

        let data_expected = Data::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]);
        data_expected.assert_approx_eq(&data_actual, 3);
    }
}
@@ -1,6 +1,7 @@
mod add;
mod aggregation;
mod arg;
mod cos;
mod div;
mod erf;
mod exp;

@@ -13,5 +14,7 @@ mod neg;
mod powf;
mod repeat;
mod reshape;
mod sin;
mod sub;
mod tanh;
mod transpose;
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(sin)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_support_sin_ops() {
        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let tensor = Tensor::<TestBackend, 2>::from_data(data);

        let data_actual = tensor.sin().into_data();

        let data_expected = Data::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]);
        data_expected.assert_approx_eq(&data_actual, 3);
    }
}
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(tanh)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Tensor};

    #[test]
    fn should_support_tanh_ops() {
        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
        let tensor = Tensor::<TestBackend, 2>::from_data(data);

        let data_actual = tensor.tanh().into_data();

        let data_expected = Data::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]);
        data_expected.assert_approx_eq(&data_actual, 3);
    }
}