mirror of https://github.com/tracel-ai/burn.git
parent 844b199dc1
commit 69001b0d69
@@ -0,0 +1,55 @@
use crate::{
    grads::Gradients,
    ops::{unary, Backward, Ops, OpsKind},
    tensor::ADTensor,
    ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::ActivationOps};

impl<B: Backend> ActivationOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    fn gelu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct Gelu<const D: usize>;

        impl<const D: usize, B: Backend> Backward<B, D, 1> for Gelu<D> {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                let input = ops.state;

                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::gelu_backward(input, grad)
                });
            }
        }

        match Gelu::<D>.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let output = B::gelu(tensor.primitive.clone());
                prep.finish(tensor.primitive, output)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
        }
    }

    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct Relu;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::relu_backward(ops.state, grad)
                });
            }
        }
        let output = B::relu(tensor.primitive);

        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
            OpsKind::UnTracked(prep) => prep.finish(output),
        }
    }
}
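A note on the state each op saves above: gelu keeps the forward input, because its derivative needs x, while relu can reuse its own output, since relu'(x) depends only on the sign, which relu(x) preserves. A minimal scalar sketch of the two backward rules, using the tanh-based gelu derivative that the default ActivationOps implementation uses; this is illustrative only, not the burn-autodiff API:

fn relu_backward_from_output(output: f32, grad: f32) -> f32 {
    // Zero the incoming gradient wherever the saved activation is <= 0.
    if output <= 0.0 { 0.0 } else { grad }
}

fn gelu_backward_from_input(x: f32, grad: f32) -> f32 {
    // Tanh-approximation derivative; the same constants appear in ActivationOps::gelu_backward.
    let u = 0.797885 * x + 0.0356774 * x.powi(3);
    let sech2 = 1.0 - u.tanh().powi(2);
    (0.5 * u.tanh() + (0.398942 * x + 0.0535161 * x.powi(3)) * sech2 + 0.5) * grad
}

fn main() {
    assert_eq!(relu_backward_from_output(-1.0, 5.0), 0.0);
    println!("gelu'(1.0) ≈ {:.3}", gelu_backward_from_input(1.0, 1.0)); // ≈ 1.083
}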
@@ -1,3 +1,4 @@
mod activation;
mod backward;
mod base;
mod bool_tensor;
@@ -1304,29 +1304,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
        output.register_step(ops)
    }

    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct Relu;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    let zero = 0.elem();
                    let mask = B::lower_equal_elem(ops.state, zero);
                    B::mask_fill(grad, mask, zero)
                });
            }
        }
        let output = B::relu(tensor.primitive);

        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
            OpsKind::UnTracked(prep) => prep.finish(output),
        }
    }
}

/// Make sure the grad tensor has the given shape.
@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(ad_gelu)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data};

    #[test]
    fn should_diff_gelu() {
        let tensor_1 = TestADTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad();

        let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
        let x = tensor_1.clone().matmul(x);
        let grads = x.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2);
    }
}
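Expected gradients like these are typically cross-checked against a reference implementation; when hand-deriving a backward pass such as gelu's, a finite-difference comparison is a common sanity check. A self-contained sketch of that technique in plain Rust, applied to the tanh-approximated gelu (not part of the burn test suite):

fn finite_diff<F: Fn(f64) -> f64>(f: F, x: f64) -> f64 {
    // Central difference: (f(x + eps) - f(x - eps)) / (2 * eps).
    let eps = 1e-6;
    (f(x + eps) - f(x - eps)) / (2.0 * eps)
}

fn main() {
    let gelu = |x: f64| 0.5 * x * (1.0 + (0.797885 * x + 0.0356774 * x.powi(3)).tanh());
    // ≈ 1.083, matching the analytic tanh-approximation derivative at x = 1.
    println!("numeric gelu'(1.0) ≈ {:.3}", finite_diff(gelu, 1.0));
}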
@@ -12,6 +12,7 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod gelu;
mod index;
mod index_select;
mod index_select_dim;
@@ -42,6 +43,10 @@ macro_rules! testgen_all {
        // Behavior
        burn_autodiff::testgen_ad_broadcast!();

        // Activation
        burn_autodiff::testgen_ad_relu!();
        burn_autodiff::testgen_ad_gelu!();

        // Modules
        burn_autodiff::testgen_ad_conv1d!();
        burn_autodiff::testgen_ad_conv2d!();
@@ -70,7 +75,6 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_mul!();
        burn_autodiff::testgen_ad_neg!();
        burn_autodiff::testgen_ad_powf!();
        burn_autodiff::testgen_ad_relu!();
        burn_autodiff::testgen_ad_reshape!();
        burn_autodiff::testgen_ad_sin!();
        burn_autodiff::testgen_ad_softmax!();
@@ -0,0 +1,17 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::{ops::ActivationOps, ElementConversion};

impl<E: FloatNdArrayElement> ActivationOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let zero = 0.elem();
        let array = tensor
            .array
            .mapv_into(|elem| match elem < zero {
                true => zero,
                false => elem,
            })
            .into_shared();

        NdArrayTensor::new(array)
    }
}
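mapv_into consumes the array and rewrites each element in place, so no boolean mask is materialized. A minimal standalone sketch of the same element-wise ReLU using the ndarray crate directly (assumed dependency; not the burn-ndarray wrapper types):

use ndarray::array;

fn main() {
    let a = array![[-1.0_f32, 2.0], [0.0, -3.5]];
    // Clamp negatives to zero, reusing the existing buffer.
    let relu = a.mapv_into(|x| if x < 0.0 { 0.0 } else { x });
    assert_eq!(relu, array![[0.0, 2.0], [0.0, 0.0]]);
}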
@@ -1,3 +1,4 @@
mod activations;
mod base;
mod bool_tensor;
mod int_tensor;
@@ -417,19 +417,6 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
    fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
        NdArrayOps::cat(tensors, dim)
    }

    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let zero = 0.elem();
        let array = tensor
            .array
            .mapv_into(|elem| match elem < zero {
                true => 0.0.elem(),
                false => elem,
            })
            .into_shared();

        NdArrayTensor::new(array)
    }
}

fn arg<E: FloatNdArrayElement, F, const D: usize>(
@@ -0,0 +1,20 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::ops::ActivationOps;

impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }
    fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }
    fn gelu_backward<const D: usize>(
        tensor: TchTensor<E, D>,
        grad: TchTensor<E, D>,
    ) -> TchTensor<E, D> {
        TchTensor::new(tensor.tensor.gelu_backward(&grad.tensor, "none"))
    }
}
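unary_ops receives two closures: an in-place kernel (relu_, gelu_) and an allocating fallback (relu, gelu), so the mutating variant can be used when the underlying tch tensor is not shared. A rough, hypothetical sketch of that dispatch using a plain Rc stand-in; the types and the ownership check below are assumptions for illustration, not the actual burn-tch implementation:

use std::rc::Rc;

// Toy stand-ins for illustration only.
struct FakeTensor(Vec<f32>);

struct Wrapper {
    tensor: Rc<FakeTensor>,
}

impl Wrapper {
    fn unary_ops<FIn, FOut>(mut self, fin: FIn, fout: FOut) -> Self
    where
        FIn: FnOnce(&mut FakeTensor),
        FOut: FnOnce(&FakeTensor) -> FakeTensor,
    {
        match Rc::get_mut(&mut self.tensor) {
            // Sole owner of the buffer: run the mutating kernel, no allocation.
            Some(tensor) => fin(tensor),
            // Buffer shared elsewhere (e.g. by an autodiff graph): allocate a new one.
            None => self.tensor = Rc::new(fout(&self.tensor)),
        }
        self
    }
}

fn main() {
    let w = Wrapper { tensor: Rc::new(FakeTensor(vec![-1.0, 2.0])) };
    let w = w.unary_ops(
        |t| t.0.iter_mut().for_each(|x| *x = x.max(0.0)),
        |t| FakeTensor(t.0.iter().map(|x| x.max(0.0)).collect()),
    );
    assert_eq!(w.tensor.0, vec![0.0, 2.0]);
}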
@@ -1,3 +1,4 @@
mod activation;
mod base;
mod bool_tensor;
mod int_tensor;
@@ -380,8 +380,4 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
    fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
        TchOps::cat(tensors, dim)
    }

    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }
}
@@ -1,7 +1,6 @@
use crate::backend::Backend;
use crate::Tensor;
use crate::{ElementPrecision, Precision};
use core::f64::consts::SQRT_2;

/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@@ -10,9 +9,7 @@ pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {

/// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let x = tensor.clone().div_scalar(SQRT_2).erf().add_scalar(1.0_f32);

    tensor.mul(x) / 2
    Tensor::from_primitive(B::gelu(tensor.primitive))
}

/// Applies the softmax function.
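For reference, the erf-based form that both the removed inline code and the new B::gelu default compute is the exact GELU definition (a standard identity, restated here for clarity):

\mathrm{gelu}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)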
@@ -54,6 +54,7 @@ pub trait Backend:
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + Clone
    + Sized
    + Default
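The only change to the Backend trait itself is the new ActivationOps supertrait. Because ActivationOps ships default method bodies, adding the bound does not force existing backends to implement anything; each backend overrides only the ops it can accelerate. A stripped-down sketch of that pattern with toy names and scalar types (not the real burn traits):

trait ActivationOps {
    // Default implementation; backends may override it.
    fn relu(x: f32) -> f32 {
        if x < 0.0 { 0.0 } else { x }
    }
}

trait Backend: ActivationOps {}

struct CpuBackend;
impl ActivationOps for CpuBackend {} // keeps the default relu
impl Backend for CpuBackend {}

fn main() {
    assert_eq!(<CpuBackend as ActivationOps>::relu(-3.0), 0.0);
}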
@@ -0,0 +1,64 @@
use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
        let mask = B::lower_equal_elem(tensor.clone(), 0.elem());

        B::mask_fill(tensor, mask, 0.elem())
    }
    fn relu_backward<const D: usize>(
        output: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
        let mask = B::lower_equal_elem(output, 0.elem());

        B::mask_fill(grad, mask, 0.elem())
    }
    fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
        let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::erf(x);
        let x = B::add_scalar(x, 1i32.elem());
        let x = B::mul(tensor, x);

        B::div_scalar(x, 2i32.elem())
    }

    fn gelu_backward<const D: usize>(
        x: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
        // Derivative of the approximate gelu implementation based on tanh.

        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::powf(x.clone(), 3.0);

        let c1 = B::mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::mul_scalar(x3, constant_3.elem());
        let c4 = B::mul_scalar(x, constant_4.elem());

        let inner1 = B::add(c1, c2);
        let inner2 = B::add(c3, c4);

        let tanh = B::tanh(inner1);

        let sech = B::powf(tanh.clone(), 2.0);
        let sech = B::neg(sech);
        let sech = B::add_scalar(sech, 1.elem());

        let y1 = B::mul_scalar(tanh, 0.5.elem());
        let y2 = B::mul(inner2, sech);
        let y2 = B::add_scalar(y2, 0.5.elem());
        let y = B::add(y1, y2);

        B::mul(y, grad)
    }
}
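The four constants in gelu_backward fall out of differentiating the tanh approximation of GELU; note that the variable named sech above actually holds sech^2(u). A sketch of the derivation in LaTeX:

\mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh(u)\bigr), \qquad u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr) \approx 0.797885\,x + 0.0356774\,x^{3}

\frac{d}{dx}\,\mathrm{gelu}(x) \approx \tfrac{1}{2}\tanh(u) + \bigl(0.398942\,x + 0.0535161\,x^{3}\bigr)\,\mathrm{sech}^{2}(u) + \tfrac{1}{2}

with 0.398942 \approx \tfrac{1}{2}\sqrt{2/\pi} and 0.0535161 \approx \tfrac{3}{2}\cdot 0.044715\cdot\sqrt{2/\pi}, i.e. constant_4 = constant_2 / 2 and constant_3 = 3 \cdot constant_1 / 2.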
@@ -1,8 +1,10 @@
mod activation;
mod bool_tensor;
mod int_tensor;
mod modules;
mod tensor;

pub use activation::*;
pub use bool_tensor::*;
pub use int_tensor::*;
pub use modules::*;
@@ -241,5 +241,4 @@ pub trait TensorOps<B: Backend> {
        tensors: Vec<B::TensorPrimitive<D>>,
        dim: usize,
    ) -> B::TensorPrimitive<D>;
    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}
@@ -15,6 +15,7 @@ mod tests {
        let data_expected = Data::from([[
            0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
        ]]);
        data_expected.assert_approx_eq(&data_actual, 3);
        data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation
                                                         // implementation using tanh
    }
}
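The tolerance drops from 3 to 2 decimal places because a backend is now free to implement gelu with the tanh approximation instead of erf, and the two forms can differ in the third or fourth decimal place. A quick standalone comparison, assuming the libm crate for erf (illustrative only, not part of the burn tests):

fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}

fn gelu_tanh(x: f64) -> f64 {
    let u = 0.797885 * x + 0.0356774 * x.powi(3);
    0.5 * x * (1.0 + u.tanh())
}

fn main() {
    // Largest gap between the exact and approximate forms on a coarse grid over [-4, 4].
    let max_diff = (-40..=40)
        .map(|i| i as f64 / 10.0)
        .map(|x| (gelu_erf(x) - gelu_tanh(x)).abs())
        .fold(0.0_f64, f64::max);
    println!("max |erf gelu - tanh gelu| on [-4, 4]: {max_diff:.1e}");
}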