Feat/activation ops (#338)

* perf: GELU

* Refactor relu
Nathaniel Simard 2023-05-09 08:32:35 -04:00 committed by GitHub
parent 844b199dc1
commit 69001b0d69
17 changed files with 195 additions and 47 deletions


@@ -0,0 +1,55 @@
use crate::{
    grads::Gradients,
    ops::{unary, Backward, Ops, OpsKind},
    tensor::ADTensor,
    ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::ActivationOps};

impl<B: Backend> ActivationOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    fn gelu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct Gelu<const D: usize>;

        impl<const D: usize, B: Backend> Backward<B, D, 1> for Gelu<D> {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                let input = ops.state;

                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::gelu_backward(input, grad)
                });
            }
        }

        match Gelu::<D>.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => {
                let output = B::gelu(tensor.primitive.clone());
                prep.finish(tensor.primitive, output)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),
        }
    }

    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
        #[derive(Debug)]
        struct Relu;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::relu_backward(ops.state, grad)
                });
            }
        }

        let output = B::relu(tensor.primitive);

        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
            OpsKind::UnTracked(prep) => prep.finish(output),
        }
    }
}
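A note on the state each op keeps: gelu stores the forward input, since B::gelu_backward needs the original x, while relu only keeps its output. A minimal standalone sketch (plain slices, not burn code) of why the output alone is enough for relu's backward:

// Illustration only: relu'(x) is 1 where x > 0 and 0 elsewhere, and the forward
// output is positive exactly where the input was, so masking the incoming
// gradient by the *output* reproduces the gradient with respect to the input.
fn relu(xs: &[f32]) -> Vec<f32> {
    xs.iter().map(|&x| x.max(0.0)).collect()
}

fn relu_backward(output: &[f32], grad: &[f32]) -> Vec<f32> {
    output
        .iter()
        .zip(grad)
        .map(|(&o, &g)| if o <= 0.0 { 0.0 } else { g })
        .collect()
}

fn main() {
    let y = relu(&[-1.0, 0.5, 2.0]);
    assert_eq!(relu_backward(&y, &[1.0, 1.0, 1.0]), vec![0.0, 1.0, 1.0]);
}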


@@ -1,3 +1,4 @@
+mod activation;
 mod backward;
 mod base;
 mod bool_tensor;


@@ -1304,29 +1304,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
         output.register_step(ops)
     }
-
-    fn relu<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
-        #[derive(Debug)]
-        struct Relu;
-        impl<B: Backend, const D: usize> Backward<B, D, 1> for Relu {
-            type State = B::TensorPrimitive<D>;
-
-            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
-                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
-                    let zero = 0.elem();
-                    let mask = B::lower_equal_elem(ops.state, zero);
-                    B::mask_fill(grad, mask, zero)
-                });
-            }
-        }
-
-        let output = B::relu(tensor.primitive);
-
-        match Relu.prepare([tensor.node], [tensor.graph]).statefull() {
-            OpsKind::Tracked(prep) => prep.finish(output.clone(), output),
-            OpsKind::UnTracked(prep) => prep.finish(output),
-        }
-    }
 }

 /// Make sure the grad tensor has the given shape.


@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(ad_gelu)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data};

    #[test]
    fn should_diff_gelu() {
        let tensor_1 = TestADTensor::from_floats([[0.0, 1.0], [-3.0, 4.0]]).require_grad();
        let tensor_2 = TestADTensor::from_floats([[6.0, -0.5], [9.0, 10.0]]).require_grad();

        let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));
        let x = tensor_1.clone().matmul(x);
        let grads = x.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[1.4629, 1.4629], [48.2286, 153.4629]]), 2);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[-15.0000, -1.9895], [17.0000, 17.0000]]), 2);
    }
}


@@ -12,6 +12,7 @@ mod cross_entropy;
 mod div;
 mod erf;
 mod exp;
+mod gelu;
 mod index;
 mod index_select;
 mod index_select_dim;
@@ -42,6 +43,10 @@ macro_rules! testgen_all {
         // Behavior
         burn_autodiff::testgen_ad_broadcast!();

+        // Activation
+        burn_autodiff::testgen_ad_relu!();
+        burn_autodiff::testgen_ad_gelu!();
+
         // Modules
         burn_autodiff::testgen_ad_conv1d!();
         burn_autodiff::testgen_ad_conv2d!();
@@ -70,7 +75,6 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_mul!();
         burn_autodiff::testgen_ad_neg!();
         burn_autodiff::testgen_ad_powf!();
-        burn_autodiff::testgen_ad_relu!();
         burn_autodiff::testgen_ad_reshape!();
         burn_autodiff::testgen_ad_sin!();
         burn_autodiff::testgen_ad_softmax!();


@@ -0,0 +1,17 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::{ops::ActivationOps, ElementConversion};

impl<E: FloatNdArrayElement> ActivationOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
        let zero = 0.elem();
        let array = tensor
            .array
            .mapv_into(|elem| match elem < zero {
                true => zero,
                false => elem,
            })
            .into_shared();

        NdArrayTensor::new(array)
    }
}


@@ -1,3 +1,4 @@
+mod activations;
 mod base;
 mod bool_tensor;
 mod int_tensor;


@@ -417,19 +417,6 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
     fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
         NdArrayOps::cat(tensors, dim)
     }
-
-    fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
-        let zero = 0.elem();
-        let array = tensor
-            .array
-            .mapv_into(|elem| match elem < zero {
-                true => 0.0.elem(),
-                false => elem,
-            })
-            .into_shared();
-
-        NdArrayTensor::new(array)
-    }
 }

 fn arg<E: FloatNdArrayElement, F, const D: usize>(


@@ -0,0 +1,20 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::ops::ActivationOps;

impl<E: TchElement> ActivationOps<TchBackend<E>> for TchBackend<E> {
    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }

    fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }

    fn gelu_backward<const D: usize>(
        tensor: TchTensor<E, D>,
        grad: TchTensor<E, D>,
    ) -> TchTensor<E, D> {
        TchTensor::new(tensor.tensor.gelu_backward(&grad.tensor, "none"))
    }
}
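These overrides route the functional API straight to LibTorch's kernels (gelu with the "none" approximation mode, i.e. the exact erf-based form). A usage sketch, with the crate path and tensor shape assumed rather than taken from this diff:

// Sketch only; the `burn_tch::TchBackend` import path and the 2-D shape are assumptions.
use burn_tch::TchBackend;
use burn_tensor::{activation, Tensor};

fn gelu_on_tch(x: Tensor<TchBackend<f32>, 2>) -> Tensor<TchBackend<f32>, 2> {
    // Dispatches through ActivationOps::gelu, so it hits tch's kernel rather
    // than the generic erf-based default built from TensorOps primitives.
    activation::gelu(x)
}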


@@ -1,3 +1,4 @@
+mod activation;
 mod base;
 mod bool_tensor;
 mod int_tensor;


@@ -380,8 +380,4 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
     fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
         TchOps::cat(tensors, dim)
     }
-
-    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
-        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
-    }
 }


@@ -1,7 +1,6 @@
 use crate::backend::Backend;
 use crate::Tensor;
 use crate::{ElementPrecision, Precision};
-use core::f64::consts::SQRT_2;

 /// Applies the rectified linear unit function.
 pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@@ -10,9 +9,7 @@ pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
 /// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
 pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    let x = tensor.clone().div_scalar(SQRT_2).erf().add_scalar(1.0_f32);
-    tensor.mul(x) / 2
+    Tensor::from_primitive(B::gelu(tensor.primitive))
 }

 /// Applies the softmax function.
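For reference, both the removed inline implementation above and the new B::gelu default in ActivationOps compute the exact erf-based formulation (Φ is the standard normal CDF):

\operatorname{GELU}(x) \;=\; x\,\Phi(x) \;=\; \frac{x}{2}\Bigl(1 + \operatorname{erf}\bigl(x / \sqrt{2}\bigr)\Bigr)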


@@ -54,6 +54,7 @@ pub trait Backend:
     + BoolTensorOps<Self>
     + IntTensorOps<Self>
     + ModuleOps<Self>
+    + ActivationOps<Self>
     + Clone
     + Sized
     + Default


@@ -0,0 +1,64 @@
use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
        let mask = B::lower_equal_elem(tensor.clone(), 0.elem());

        B::mask_fill(tensor, mask, 0.elem())
    }

    fn relu_backward<const D: usize>(
        output: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
        let mask = B::lower_equal_elem(output, 0.elem());

        B::mask_fill(grad, mask, 0.elem())
    }

    fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
        let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::erf(x);
        let x = B::add_scalar(x, 1i32.elem());
        let x = B::mul(tensor, x);

        B::div_scalar(x, 2i32.elem())
    }

    fn gelu_backward<const D: usize>(
        x: B::TensorPrimitive<D>,
        grad: B::TensorPrimitive<D>,
    ) -> B::TensorPrimitive<D> {
        // Derivative of the approximate gelu implementation based on tanh.
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::powf(x.clone(), 3.0);

        let c1 = B::mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::mul_scalar(x3, constant_3.elem());
        let c4 = B::mul_scalar(x, constant_4.elem());

        let inner1 = B::add(c1, c2);
        let inner2 = B::add(c3, c4);

        let tanh = B::tanh(inner1);

        let sech = B::powf(tanh.clone(), 2.0);
        let sech = B::neg(sech);
        let sech = B::add_scalar(sech, 1.elem());

        let y1 = B::mul_scalar(tanh, 0.5.elem());
        let y2 = B::mul(inner2, sech);
        let y2 = B::add_scalar(y2, 0.5.elem());
        let y = B::add(y1, y2);

        B::mul(y, grad)
    }
}
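The magic constants in the default gelu_backward encode the derivative of the tanh approximation of GELU; a sketch of the derivation (mine, not part of the commit) that reproduces them:

u(x) \;=\; \sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr) \;\approx\; 0.797885\,x + 0.0356774\,x^{3}

\frac{d}{dx}\Bigl[\tfrac{x}{2}\bigl(1 + \tanh u\bigr)\Bigr]
\;=\; \tfrac{1}{2}\tanh u \;+\; \bigl(0.398942\,x + 0.0535161\,x^{3}\bigr)\bigl(1 - \tanh^{2} u\bigr) \;+\; \tfrac{1}{2}

In the code above, inner1 is u, inner2 is the 0.398942 x + 0.0535161 x³ factor, and the variable named sech actually holds sech²(u) = 1 − tanh²(u).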


@@ -1,8 +1,10 @@
+mod activation;
 mod bool_tensor;
 mod int_tensor;
 mod modules;
 mod tensor;

+pub use activation::*;
 pub use bool_tensor::*;
 pub use int_tensor::*;
 pub use modules::*;


@@ -241,5 +241,4 @@ pub trait TensorOps<B: Backend> {
         tensors: Vec<B::TensorPrimitive<D>>,
         dim: usize,
     ) -> B::TensorPrimitive<D>;
-    fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
 }


@@ -15,6 +15,7 @@ mod tests {
         let data_expected = Data::from([[
             0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
         ]]);
-        data_expected.assert_approx_eq(&data_actual, 3);
+        data_expected.assert_approx_eq(&data_actual, 2); // Low precision to allow approximation
+        // implementation using tanh
     }
 }