mirror of https://github.com/tracel-ai/burn.git
Move log_sigmoid to activation ops (#1558)
This commit is contained in:
parent 38479be726
commit 8d210a152f
@@ -128,4 +128,42 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodi
            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
        }
    }

    fn log_sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        #[derive(Debug)]
        struct LogSigmoid<const D: usize>;

        retro_unary!(RetroLogSigmoid, B::log_sigmoid);

        impl<const D: usize, B: Backend> Backward<B, D, 1> for LogSigmoid<D> {
            type State = NodeID;

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                let input = checkpointer.retrieve_node_output(ops.state);

                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::log_sigmoid_backward(input, grad)
                });
            }
        }

        match LogSigmoid::<D>
            .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
            .memory_bound()
            .retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id.clone()))
            .parents([&tensor])
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                let state = prep.checkpoint(&tensor);
                prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
            }
            OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
        }
    }
}
@@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_log_sigmoid)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data};

    #[test]
    fn should_diff_log_sigmoid() {
        let data = Data::<f32, 2>::from([[0.8762, -0.1423], [-300., 200.]]);

        let device = Default::default();
        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let tensor_2 = activation::log_sigmoid(tensor_1.clone());
        let grads = tensor_2.backward();

        let grad = tensor_1.grad(&grads).unwrap();

        grad.to_data()
            .assert_approx_eq(&Data::from([[0.293966, 0.535515], [1.000000, 0.000000]]), 4);
    }
}
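For reference, the expected gradients in this test follow from d/dx log_sigmoid(x) = sigmoid(-x). A minimal standalone check with plain Rust scalars (illustrative only, not part of the diff; the helper below is hypothetical):

// Reproduces the expected test gradients from the closed-form derivative sigmoid(-x).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    for x in [0.8762, -0.1423, -300.0, 200.0] {
        // sigmoid(-0.8762) ~ 0.293966, sigmoid(0.1423) ~ 0.535515,
        // sigmoid(300) ~ 1.0, sigmoid(-200) ~ 0.0
        println!("x = {x:>9}: grad = {:.6}", sigmoid(-x));
    }
}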
@@ -29,6 +29,7 @@ mod gelu;
mod gradients;
mod log;
mod log1p;
mod log_sigmoid;
mod mask;
mod matmul;
mod maxmin;
@@ -117,6 +118,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_sub!();
        burn_autodiff::testgen_ad_tanh!();
        burn_autodiff::testgen_ad_sigmoid!();
        burn_autodiff::testgen_ad_log_sigmoid!();
        burn_autodiff::testgen_ad_transpose!();
        burn_autodiff::testgen_ad_permute!();
        burn_autodiff::testgen_ad_flip!();
@@ -26,4 +26,15 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }

    fn log_sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        // NOTE: we don't override log_sigmoid_backward because Torch has a special backward
        // formula that uses a buffer of values computed during the forward pass.

        // There is no in-place log_sigmoid_ variant.
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();

        TchTensor::from_existing(tensor, storage)
    }
}
@@ -1,7 +1,6 @@
use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{check, Tensor};
use crate::{ElementPrecision, Precision};

/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@@ -125,39 +124,7 @@ pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D>

/// Applies the log sigmoid function.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    /// To avoid overflow, we use the log-sum-exp trick.
    ///
    /// ```ignore
    /// log(sigmoid(x)) = log(1/(1 + exp(-x)))
    ///                 = log(1) - log(1 + exp(-x))
    ///                 = -log(1 + exp(-x))
    ///                 = -log(exp(0) + exp(-x))
    /// ```
    /// The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
    /// subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
    /// following equivalence:
    /// ```ignore
    /// log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    /// ```
    ///
    /// This extends the range of values for which we obtain accurate results.
    fn numerically_stable_log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
        // max(-x, 0)
        let max_elem = x.clone().neg().max_pair(x.zeros_like());

        // log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
        let z = (max_elem.clone().neg().exp() + (x.neg() - max_elem.clone()).exp()).log();

        z.neg() - max_elem
    }
    match B::FloatElem::precision() {
        Precision::Half => {
            let tensor_full = tensor.into_full_precision();
            let tensor_tmp = numerically_stable_log_sigmoid(tensor_full);
            Tensor::from_full_precision(tensor_tmp)
        }
        _ => numerically_stable_log_sigmoid(tensor),
    }
    Tensor::from_primitive(B::log_sigmoid(tensor.primitive))
}

/// Applies the silu function
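The overflow problem the moved doc comment describes is easy to reproduce with scalar floats. A minimal sketch with plain Rust f32 (illustrative only, not part of the diff; the two helper functions are hypothetical), contrasting the naive formula with the log-sum-exp form:

// Naive log(sigmoid(x)) = -log(1 + exp(-x)): exp(-x) overflows f32 for large negative x.
fn log_sigmoid_naive(x: f32) -> f32 {
    -(1.0 + (-x).exp()).ln()
}

// Stable form: log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))).
fn log_sigmoid_stable(x: f32) -> f32 {
    let max_elem = (-x).max(0.0);
    -(max_elem + ((-max_elem).exp() + (-x - max_elem).exp()).ln())
}

fn main() {
    for x in [0.8762f32, -0.1423, -300.0, 200.0] {
        // At x = -300 the naive version overflows to -inf; the stable one returns -300.
        println!("x = {x:>8}: naive = {}, stable = {}", log_sigmoid_naive(x), log_sigmoid_stable(x));
    }
}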
@@ -181,4 +181,98 @@ pub trait ActivationOps<B: Backend> {
        );
        B::float_mul(value, grad)
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward<const D: usize>(
        x: FloatTensor<B, D>,
        grad: FloatTensor<B, D>,
    ) -> FloatTensor<B, D> {
        // The derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        // and max_derive = d/dx max(-x, 0), i.e. -1 if x < 0 and 0 if x >= 0.
        //
        // This simplifies to:
        // (z - 1) / z if x >= 0
        // 1 / z       if x < 0

        let shape = B::float_shape(&x);
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0 or x >= 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // max_derive is 1 if x < 0 and 0 if x >= 0 (the magnitude of d/dx max(-x, 0))
        let ones = B::float_ones(shape, &device);
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}
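To sanity-check the branchwise simplification used in log_sigmoid_backward, its two cases can be compared against the closed-form derivative d/dx log_sigmoid(x) = sigmoid(-x) with plain scalars. A throwaway sketch (illustrative only, not part of the diff; branch_formula is a hypothetical helper mirroring the tensor code above):

// Compares the (z - 1)/z and 1/z branches with sigmoid(-x), the derivative of log(sigmoid(x)).
fn branch_formula(x: f64) -> f64 {
    let max_elem = (-x).max(0.0); // max(-x, 0)
    let z = (-max_elem).exp() + (-x - max_elem).exp();
    if x >= 0.0 { (z - 1.0) / z } else { 1.0 / z }
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    for x in [-5.0, -0.1423, 0.0, 0.8762, 5.0] {
        let (a, b) = (branch_formula(x), sigmoid(-x));
        assert!((a - b).abs() < 1e-12, "mismatch at x = {x}: {a} vs {b}");
    }
    println!("branch formula matches sigmoid(-x)");
}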