Move log_sigmoid to activation ops (#1558)

Guillaume Lagrange 2024-04-02 09:25:40 -04:00 committed by GitHub
parent 38479be726
commit 8d210a152f
6 changed files with 166 additions and 34 deletions


@ -128,4 +128,42 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
}
}
fn log_sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct LogSigmoid<const D: usize>;
retro_unary!(RetroLogSigmoid, B::log_sigmoid);
impl<const D: usize, B: Backend> Backward<B, D, 1> for LogSigmoid<D> {
type State = NodeID;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let input = checkpointer.retrieve_node_output(ops.state);
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
B::log_sigmoid_backward(input, grad)
});
}
}
match LogSigmoid::<D>
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id.clone()))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(mut prep) => {
let state = prep.checkpoint(&tensor);
prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
}
OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
}
}
}
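The backward step above checkpoints only the input node (`ops.state` is a `NodeID`) and feeds the retrieved value to `B::log_sigmoid_backward`, which suffices because d/dx log(sigmoid(x)) = sigmoid(-x) depends only on the input. A minimal standalone sketch (plain Rust, no burn types; the helper names are mine) checking that identity against a central finite difference:

```rust
// Sanity check, independent of the commit: the gradient produced by
// `log_sigmoid_backward` should be grad * sigmoid(-x).
fn log_sigmoid(x: f64) -> f64 {
    // Numerically stable form, mirroring the backend default implementation.
    let m = (-x).max(0.0);
    -(m + ((-m).exp() + (-x - m).exp()).ln())
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    let eps = 1e-6;
    for &x in &[0.5, -2.0, 3.0, -7.0] {
        // Central finite difference of the forward pass.
        let numeric = (log_sigmoid(x + eps) - log_sigmoid(x - eps)) / (2.0 * eps);
        let analytic = sigmoid(-x);
        assert!((numeric - analytic).abs() < 1e-5);
    }
    println!("d/dx log_sigmoid(x) == sigmoid(-x) checks out");
}
```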


@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_log_sigmoid)]
mod tests {
use super::*;
use burn_tensor::{activation, Data};
#[test]
fn should_diff_log_sigmoid() {
let data = Data::<f32, 2>::from([[0.8762, -0.1423], [-300., 200.]]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
let tensor_2 = activation::log_sigmoid(tensor_1.clone());
let grads = tensor_2.backward();
let grad = tensor_1.grad(&grads).unwrap();
grad.to_data()
.assert_approx_eq(&Data::from([[0.293966, 0.535515], [1.000000, 0.000000]]), 4);
}
}
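For reference, since d/dx log(sigmoid(x)) = sigmoid(-x), the expected gradients are sigmoid(-0.8762) ≈ 0.2940, sigmoid(0.1423) ≈ 0.5355, sigmoid(300) ≈ 1.0 and sigmoid(-200) ≈ 0.0, which is exactly what the test asserts to four decimal places. The -300 input in particular would overflow a naive log(1/(1 + exp(-x))) forward pass (exp(300) is infinite in f32) and poison the gradient, so it exercises the numerically stable path.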


@ -29,6 +29,7 @@ mod gelu;
mod gradients;
mod log;
mod log1p;
mod log_sigmoid;
mod mask;
mod matmul;
mod maxmin;
@ -117,6 +118,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_sigmoid!();
burn_autodiff::testgen_ad_log_sigmoid!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_permute!();
burn_autodiff::testgen_ad_flip!();


@ -26,4 +26,15 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
}
fn log_sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
// NOTE: we don't override log_sigmoid_backward because Torch has a special backward
// formula that uses a buffer with computed values from the forward pass
// no in-place log_sigmoid_
let storage = tensor.storage.clone();
let tensor = tensor.tensor.log_sigmoid();
TchTensor::from_existing(tensor, storage)
}
}
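A hypothetical smoke test (not part of this commit) showing the moved op reached through the public API on the LibTorch backend; the exact helper names (`LibTorchDevice`, `from_floats`, `into_data`) are assumed from burn/burn-tch around this release:

```rust
use burn_tch::{LibTorch, LibTorchDevice};
use burn_tensor::{activation, Tensor};

fn main() {
    let device = LibTorchDevice::Cpu;
    let x = Tensor::<LibTorch<f32>, 1>::from_floats([0.8762, -0.1423, -300.0, 200.0], &device);
    // Dispatches to tch's out-of-place `log_sigmoid` via the ActivationOps impl above.
    let y = activation::log_sigmoid(x);
    println!("{:?}", y.into_data());
}
```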


@ -1,7 +1,6 @@
use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{check, Tensor};
use crate::{ElementPrecision, Precision};
/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@ -125,39 +124,7 @@ pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D>
/// Applies the log sigmoid function.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
/// To avoid overflow, we use the log-sum-exp trick.
///
/// ```ignore
/// log(sigmoid(x)) = log(1/(1 + exp(-x)))
/// = log(1) - log(1 + exp(-x))
/// = -log(1 + exp(-x))
/// = -log(exp(0) + exp(-x))
/// ```
/// The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
/// subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
/// following equivalence:
/// ```ignore
/// log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
/// ```
///
/// This extends the range of values for which we obtain accurate results.
fn numerically_stable_log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
// max(-x, 0)
let max_elem = x.clone().neg().max_pair(x.zeros_like());
// log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
let z = (max_elem.clone().neg().exp() + (x.neg() - max_elem.clone()).exp()).log();
z.neg() - max_elem
}
match B::FloatElem::precision() {
Precision::Half => {
let tensor_full = tensor.into_full_precision();
let tensor_tmp = numerically_stable_log_sigmoid(tensor_full);
Tensor::from_full_precision(tensor_tmp)
}
_ => numerically_stable_log_sigmoid(tensor),
}
Tensor::from_primitive(B::log_sigmoid(tensor.primitive))
}
/// Applies the silu function
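With this change the public function is a thin delegate to `B::log_sigmoid`, and the half-precision special case (`ElementPrecision`/`Precision`) is gone because the numerically stable formula now lives in the backend default implementation (see the `ActivationOps` hunk below). A quick standalone illustration of why the log-sum-exp form is kept, in plain f32 and independent of burn:

```rust
// Naive vs. log-sum-exp log-sigmoid in f32 (illustration only, not from the commit).
fn naive(x: f32) -> f32 {
    (1.0 / (1.0 + (-x).exp())).ln()
}

fn stable(x: f32) -> f32 {
    let m = (-x).max(0.0); // max(-x, 0)
    -(m + ((-m).exp() + (-x - m).exp()).ln())
}

fn main() {
    for &x in &[0.0_f32, 20.0, -20.0, -300.0] {
        println!("x = {x:>7.1}  naive = {:>10.4}  stable = {:>10.4}", naive(x), stable(x));
    }
    // At x = -300, exp(300) overflows f32, so naive() returns -inf,
    // while stable() returns the correct value of roughly -300.
}
```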


@ -181,4 +181,98 @@ pub trait ActivationOps<B: Backend> {
);
B::float_mul(value, grad)
}
/// Applies the LogSigmoid activation function.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The output tensor.
fn log_sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
// To avoid overflow, we use the log-sum-exp trick.
//
// ```ignore
// log(sigmoid(x)) = log(1/(1 + exp(-x)))
// = log(1) - log(1 + exp(-x))
// = -log(1 + exp(-x))
// = -log(exp(0) + exp(-x))
// ```
// The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
// subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
// following equivalence:
// ```ignore
// log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
// ```
//
// This extends the range of values for which we obtain accurate results.
// max(-x, 0)
let tensor_neg = B::float_neg(tensor);
let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
let max_elem_neg = B::float_neg(max_elem.clone());
// z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
let z = B::float_add(
B::float_exp(max_elem_neg.clone()),
B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
);
// -max(-x, 0) - log(z)
B::float_sub(max_elem_neg, B::float_log(z))
}
/// Applies the LogSigmoid activation function backward.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `grad` - The gradient.
///
/// # Returns
///
/// The output gradient.
fn log_sigmoid_backward<const D: usize>(
x: FloatTensor<B, D>,
grad: FloatTensor<B, D>,
) -> FloatTensor<B, D> {
// Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
// -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
// where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)) and max_derive = d/dx max(-x, 0).
//
// This simplifies to:
// (z-1)/z if x >= 0
// 1/z     if x <  0
// i.e. sigmoid(-x) in both cases.
let shape = B::float_shape(&x);
let device = B::float_device(&x);
// max(-x, 0)
let x_neg = B::float_neg(x);
let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0 or x >= 0
let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());
// z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
let z = B::float_add(
B::float_exp(B::float_neg(max_elem.clone())),
B::float_exp(B::float_sub(x_neg, max_elem)),
);
// max_derive is 1 if x < 0 and 0 if x >= 0 (|d/dx max(-x, 0)|; the derivative itself is -1 or 0)
let ones = B::float_ones(shape, &device);
let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());
// grad * (max_derive - sign * (1 - (1 / z)))
B::float_mul(
grad,
B::float_sub(
max_derive,
B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
),
)
}
}
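To double-check the simplification in the comment above: for x >= 0, max(-x, 0) = 0 and z = 1 + exp(-x), so the derivative is (z-1)/z; for x < 0, max(-x, 0) = -x and z = exp(x) + 1, so it is 1/z. Both reduce to sigmoid(-x). A small standalone verification (plain Rust, not from the commit):

```rust
// Both branches of the default `log_sigmoid_backward` formula equal sigmoid(-x).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn piecewise_grad(x: f64) -> f64 {
    let m = (-x).max(0.0); // max(-x, 0)
    let z = (-m).exp() + (-x - m).exp();
    if x >= 0.0 {
        (z - 1.0) / z // max_derive = 0, sign = -1
    } else {
        1.0 / z // max_derive = 1, sign = 1
    }
}

fn main() {
    for &x in &[0.8762, -0.1423, -7.5, 12.0] {
        assert!((piecewise_grad(x) - sigmoid(-x)).abs() < 1e-12);
    }
    println!("piecewise backward formula matches sigmoid(-x)");
}
```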