mirror of https://github.com/tracel-ai/burn.git
parent: 844b199dc1
commit: 69001b0d69
@@ -0,0 +1,55 @@
+use crate::{
+    grads::Gradients,
+    ops::{unary, Backward, Ops, OpsKind},
+    tensor::ADTensor,
+    ADBackendDecorator,
+};
+use burn_tensor::{backend::Backend, ops::ActivationOps};
+
+impl<B: Backend> ActivationOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
+    fn gelu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
+        #[derive(Debug)]
+        struct Gelu<const D: usize>;
+
+        impl<const D: usize, B: Backend> Backward<B, D, 1> for Gelu<D> {
+            type State = B::TensorPrimitive<D>;
+
+            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
+                let input = ops.state;
+
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    B::gelu_backward(input, grad)
+                });
+            }
+        }
+
+        match Gelu::<D>.prepare([tensor.node], [tensor.graph]).statefull() {
+            OpsKind::Tracked(prep) => {
+                let output = B::gelu(tensor.primitive.clone());
+                prep.finish(tensor.primitive, output)
+            }
+            OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
+        }
+    }
+
+    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
+        #[derive(Debug)]
+        struct Relu;
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
+            type State = B::TensorPrimitive<D>;
+
+            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    B::relu_backward(ops.state, grad)
+                });
+            }
+        }
+        let output = B::relu(tensor.primitive);
+
+        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
+            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
+            OpsKind::UnTracked(prep) => prep.finish(output),
+        }
+    }
+}
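A note on saved state in the two autodiff nodes above: the Gelu node keeps the input primitive, because `B::gelu_backward` re-evaluates the derivative at the original input, while the Relu node keeps its own output, which is all `B::relu_backward` needs (see the `ActivationOps` defaults later in this diff). A standalone scalar sketch of that relu rule, using plain f64 and illustrative names rather than burn APIs:

// The gradient passes through wherever the saved relu output is positive; it is zeroed elsewhere.
fn relu_backward_scalar(output: f64, grad: f64) -> f64 {
    if output <= 0.0 {
        0.0
    } else {
        grad
    }
}

fn main() {
    assert_eq!(relu_backward_scalar(2.0, 5.0), 5.0);
    assert_eq!(relu_backward_scalar(-1.0, 5.0), 0.0);
    assert_eq!(relu_backward_scalar(0.0, 5.0), 0.0); // 0 is masked too, matching lower_equal_elem
}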
@@ -1,3 +1,4 @@
+mod activation;
 mod backward;
 mod base;
 mod bool_tensor;
@@ -1304,29 +1304,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
         output.register_step(ops)
     }
-
-    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
-        #[derive(Debug)]
-        struct Relu;
-
-        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
-            type State = B::TensorPrimitive<D>;
-
-            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
-                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
-                    let zero = 0.elem();
-                    let mask = B::lower_equal_elem(ops.state, zero);
-                    B::mask_fill(grad, mask, zero)
-                });
-            }
-        }
-        let output = B::relu(tensor.primitive);
-
-        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
-            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
-            OpsKind::UnTracked(prep) => prep.finish(output),
-        }
-    }
 }

 /// Make sure the grad tensor has the given shape.
@@ -0,0 +1,25 @@
+#[burn_tensor_testgen::testgen(ad_gelu)]
+mod tests {
+    use super::*;
+    use burn_tensor::{activation, Data};
+
+    #[test]
+    fn should_diff_gelu() {
+        let tensor_1 = TestADTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad();
+        let tensor_2 = TestADTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad();
+
+        let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
+        let x = tensor_1.clone().matmul(x);
+        let grads = x.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+
+        grad_1
+            .to_data()
+            .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2);
+        grad_2
+            .to_data()
+            .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2);
+    }
+}
@@ -12,6 +12,7 @@ mod cross_entropy;
 mod div;
 mod erf;
 mod exp;
+mod gelu;
 mod index;
 mod index_select;
 mod index_select_dim;
@@ -42,6 +43,10 @@ macro_rules! testgen_all {
         // Behavior
         burn_autodiff::testgen_ad_broadcast!();

+        // Activation
+        burn_autodiff::testgen_ad_relu!();
+        burn_autodiff::testgen_ad_gelu!();
+
         // Modules
         burn_autodiff::testgen_ad_conv1d!();
         burn_autodiff::testgen_ad_conv2d!();
@@ -70,7 +75,6 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_mul!();
         burn_autodiff::testgen_ad_neg!();
         burn_autodiff::testgen_ad_powf!();
-        burn_autodiff::testgen_ad_relu!();
         burn_autodiff::testgen_ad_reshape!();
         burn_autodiff::testgen_ad_sin!();
         burn_autodiff::testgen_ad_softmax!();
@@ -0,0 +1,17 @@
+use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
+use burn_tensor::{ops::ActivationOps, ElementConversion};
+
+impl<E: FloatNdArrayElement> ActivationOps<NdArrayBackend<E>> for NdArrayBackend<E> {
+    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
+        let zero = 0.elem();
+        let array = tensor
+            .array
+            .mapv_into(|elem| match elem < zero {
+                true => zero,
+                false => elem,
+            })
+            .into_shared();
+
+        NdArrayTensor::new(array)
+    }
+}
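Because `ActivationOps` ships default implementations (see the trait later in this diff), a backend only overrides what it can do faster: the ndarray backend replaces the mask-and-fill default relu with a single in-place `mapv_into` pass. A standalone sketch over plain f64 slices, with illustrative names rather than burn APIs, contrasting the two strategies:

// Mask-based relu, mirroring the trait default: build a "<= 0" mask, then fill masked spots with zero.
fn relu_mask(values: &[f64]) -> Vec<f64> {
    let mask: Vec<bool> = values.iter().map(|&v| v <= 0.0).collect();
    values
        .iter()
        .zip(mask)
        .map(|(&v, masked)| if masked { 0.0 } else { v })
        .collect()
}

// Element-wise relu, mirroring the ndarray override: one pass, no intermediate mask tensor.
fn relu_map(values: &[f64]) -> Vec<f64> {
    values.iter().map(|&v| if v < 0.0 { 0.0 } else { v }).collect()
}

fn main() {
    let input = [-2.0, 0.0, 3.5];
    assert_eq!(relu_mask(&input), relu_map(&input));
}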
@@ -1,3 +1,4 @@
+mod activations;
 mod base;
 mod bool_tensor;
 mod int_tensor;
@@ -417,19 +417,6 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
     fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
         NdArrayOps::cat(tensors, dim)
     }
-
-    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
-        let zero = 0.elem();
-        let array = tensor
-            .array
-            .mapv_into(|elem| match elem < zero {
-                true => 0.0.elem(),
-                false => elem,
-            })
-            .into_shared();
-
-        NdArrayTensor::new(array)
-    }
 }

 fn arg<E: FloatNdArrayElement, F, const D: usize>(
@@ -0,0 +1,20 @@
+use crate::{element::TchElement, TchBackend, TchTensor};
+use burn_tensor::ops::ActivationOps;
+
+impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
+    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
+    }
+    fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        tensor.unary_ops(
+            |mut tensor| tensor.gelu_("none"),
+            |tensor| tensor.gelu("none"),
+        )
+    }
+    fn gelu_backward<const D: usize>(
+        tensor: TchTensor<E, D>,
+        grad: TchTensor<E, D>,
+    ) -> TchTensor<E, D> {
+        TchTensor::new(tensor.tensor.gelu_backward(&grad.tensor, "none"))
+    }
+}
@@ -1,3 +1,4 @@
+mod activation;
 mod base;
 mod bool_tensor;
 mod int_tensor;
@@ -380,8 +380,4 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
     fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
         TchOps::cat(tensors, dim)
     }
-
-    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
-        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
-    }
 }
@@ -1,7 +1,6 @@
 use crate::backend::Backend;
 use crate::Tensor;
 use crate::{ElementPrecision, Precision};
-use core::f64::consts::SQRT_2;

 /// Applies the rectified linear unit function.
 pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@@ -10,9 +9,7 @@ pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {

 /// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
 pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    let x = tensor.clone().div_scalar(SQRT_2).erf().add_scalar(1.0_f32);
-
-    tensor.mul(x) / 2
+    Tensor::from_primitive(B::gelu(tensor.primitive))
 }

 /// Applies the softmax function.
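Both the removed inline body and the new `B::gelu` default compute the same exact, erf-based GELU,

\operatorname{gelu}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),

where \Phi is the standard normal CDF; the change only moves the computation behind `ActivationOps`, so that backends (such as the tch backend earlier in this diff) can substitute a native kernel.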
@@ -54,6 +54,7 @@ pub trait Backend:
     + BoolTensorOps<Self>
     + IntTensorOps<Self>
     + ModuleOps<Self>
+    + ActivationOps<Self>
     + Clone
     + Sized
     + Default
@@ -0,0 +1,64 @@
+use crate::{backend::Backend, ElementConversion};
+use core::f64::consts::SQRT_2;
+
+/// Activation function operations.
+///
+/// This trait lets backend implementations override activation functions for better performance.
+pub trait ActivationOps<B: Backend> {
+    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
+        let mask = B::lower_equal_elem(tensor.clone(), 0.elem());
+
+        B::mask_fill(tensor, mask, 0.elem())
+    }
+    fn relu_backward<const D: usize>(
+        output: B::TensorPrimitive<D>,
+        grad: B::TensorPrimitive<D>,
+    ) -> B::TensorPrimitive<D> {
+        let mask = B::lower_equal_elem(output, 0.elem());
+
+        B::mask_fill(grad, mask, 0.elem())
+    }
+    fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
+        let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
+        let x = B::erf(x);
+        let x = B::add_scalar(x, 1i32.elem());
+        let x = B::mul(tensor, x);
+
+        B::div_scalar(x, 2i32.elem())
+    }
+
+    fn gelu_backward<const D: usize>(
+        x: B::TensorPrimitive<D>,
+        grad: B::TensorPrimitive<D>,
+    ) -> B::TensorPrimitive<D> {
+        // Derivative of the approximate gelu implementation based on tanh.
+
+        let constant_1 = 0.0356774;
+        let constant_2 = 0.797885;
+        let constant_3 = 0.0535161;
+        let constant_4 = 0.398942;
+
+        let x3 = B::powf(x.clone(), 3.0);
+
+        let c1 = B::mul_scalar(x3.clone(), constant_1.elem());
+        let c2 = B::mul_scalar(x.clone(), constant_2.elem());
+        let c3 = B::mul_scalar(x3, constant_3.elem());
+        let c4 = B::mul_scalar(x, constant_4.elem());
+
+        let inner1 = B::add(c1, c2);
+        let inner2 = B::add(c3, c4);
+
+        let tanh = B::tanh(inner1);
+
+        let sech = B::powf(tanh.clone(), 2.0);
+        let sech = B::neg(sech);
+        let sech = B::add_scalar(sech, 1.elem());
+
+        let y1 = B::mul_scalar(tanh, 0.5.elem());
+        let y2 = B::mul(inner2, sech);
+        let y2 = B::add_scalar(y2, 0.5.elem());
+        let y = B::add(y1, y2);
+
+        B::mul(y, grad)
+    }
+}
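The default `gelu_backward` above hard-codes the derivative of the tanh-based GELU approximation 0.5 * x * (1 + tanh(u)) with u = 0.797885*x + 0.0356774*x^3 (0.797885 is sqrt(2/pi)): the derivative is 0.5 + 0.5*tanh(u) + (0.398942*x + 0.0535161*x^3) * sech^2(u), and the variable named `sech` in the code actually holds sech^2(u) = 1 - tanh^2(u). Note that the forward default stays erf-based; only this backward pass uses the tanh form. A standalone scalar check, using plain f64 and illustrative names rather than burn APIs, comparing that formula against a central finite difference:

// Tanh approximation of GELU whose derivative the trait's `gelu_backward` encodes.
fn gelu_tanh(x: f64) -> f64 {
    0.5 * x * (1.0 + (0.797885 * x + 0.0356774 * x.powi(3)).tanh())
}

// Mirrors the constants and the y1/y2 assembly of the default `gelu_backward`.
fn gelu_backward_scalar(x: f64, grad: f64) -> f64 {
    let inner1 = 0.0356774 * x.powi(3) + 0.797885 * x;
    let inner2 = 0.0535161 * x.powi(3) + 0.398942 * x;
    let tanh = inner1.tanh();
    let sech2 = 1.0 - tanh * tanh; // what the code calls `sech`: 1 - tanh^2(u)
    (0.5 * tanh + inner2 * sech2 + 0.5) * grad
}

fn main() {
    // The analytic derivative agrees with a central finite difference at a few points.
    let eps = 1e-4;
    for x in [-3.0, -0.5, 0.0, 1.0, 4.0] {
        let fd = (gelu_tanh(x + eps) - gelu_tanh(x - eps)) / (2.0 * eps);
        assert!((fd - gelu_backward_scalar(x, 1.0)).abs() < 1e-4);
    }
}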
@@ -1,8 +1,10 @@
+mod activation;
 mod bool_tensor;
 mod int_tensor;
 mod modules;
 mod tensor;

+pub use activation::*;
 pub use bool_tensor::*;
 pub use int_tensor::*;
 pub use modules::*;
@@ -241,5 +241,4 @@ pub trait TensorOps<B: Backend> {
         tensors: Vec<B::TensorPrimitive<D>>,
         dim: usize,
     ) -> B::TensorPrimitive<D>;
-    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
 }
@@ -15,6 +15,7 @@ mod tests {
         let data_expected = Data::from([[
             0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
         ]]);
-        data_expected.assert_approx_eq(&data_actual, 3);
+        data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation
+                                                         // implementation using tanh
     }
 }