mirror of https://github.com/tracel-ai/burn.git
Move log_sigmoid to activation ops (#1558)
This commit is contained in:
parent 38479be726
commit 8d210a152f
@@ -128,4 +128,42 @@ impl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C>
             OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
         }
     }
+
+    fn log_sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct LogSigmoid<const D: usize>;
+
+        retro_unary!(RetroLogSigmoid, B::log_sigmoid);
+
+        impl<const D: usize, B: Backend> Backward<B, D, 1> for LogSigmoid<D> {
+            type State = NodeID;
+
+            fn backward(
+                self,
+                ops: Ops<Self::State, 1>,
+                grads: &mut Gradients,
+                checkpointer: &mut Checkpointer,
+            ) {
+                let input = checkpointer.retrieve_node_output(ops.state);
+
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    B::log_sigmoid_backward(input, grad)
+                });
+            }
+        }
+
+        match LogSigmoid::<D>
+            .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
+            .memory_bound()
+            .retro_forward(RetroLogSigmoid::<B, D>::new(tensor.node.id.clone()))
+            .parents([&tensor])
+            .stateful()
+        {
+            OpsKind::Tracked(mut prep) => {
+                let state = prep.checkpoint(&tensor);
+                prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))
+            }
+            OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),
+        }
+    }
 }
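For reference (standard calculus, not part of the commit): the gradient that `B::log_sigmoid_backward` has to reproduce is

    d/dx log(sigmoid(x)) = sigmoid'(x) / sigmoid(x) = 1 - sigmoid(x) = sigmoid(-x)

which is what the default trait implementation further down computes, and what the new test below asserts numerically.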
@@ -0,0 +1,20 @@
+#[burn_tensor_testgen::testgen(ad_log_sigmoid)]
+mod tests {
+    use super::*;
+    use burn_tensor::{activation, Data};
+
+    #[test]
+    fn should_diff_log_sigmoid() {
+        let data = Data::<f32, 2>::from([[0.8762, -0.1423], [-300., 200.]]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
+        let tensor_2 = activation::log_sigmoid(tensor_1.clone());
+        let grads = tensor_2.backward();
+
+        let grad = tensor_1.grad(&grads).unwrap();
+
+        grad.to_data()
+            .assert_approx_eq(&Data::from([[0.293966, 0.535515], [1.000000, 0.000000]]), 4);
+    }
+}
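The expected gradients in this test follow from the closed form above. A quick plain-f64 sanity check, independent of burn (the `sigmoid` helper is just illustrative):

// d/dx log(sigmoid(x)) = sigmoid(-x)
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn main() {
    for &x in &[0.8762, -0.1423, -300.0, 200.0] {
        println!("x = {x}: grad = {:.6}", sigmoid(-x));
    }
    // Prints approximately 0.293966, 0.535515, 1.000000, 0.000000 — the values asserted above.
}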
@@ -29,6 +29,7 @@ mod gelu;
 mod gradients;
 mod log;
 mod log1p;
+mod log_sigmoid;
 mod mask;
 mod matmul;
 mod maxmin;
@@ -117,6 +118,7 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_sub!();
         burn_autodiff::testgen_ad_tanh!();
         burn_autodiff::testgen_ad_sigmoid!();
+        burn_autodiff::testgen_ad_log_sigmoid!();
         burn_autodiff::testgen_ad_transpose!();
         burn_autodiff::testgen_ad_permute!();
         burn_autodiff::testgen_ad_flip!();
@@ -26,4 +26,15 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
     fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
         tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
     }
+
+    fn log_sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        // NOTE: we don't override log_sigmoid_backward because Torch has a special backward
+        // formula that uses a buffer with computed values from the forward pass
+
+        // no in-place log_sigmoid_
+        let storage = tensor.storage.clone();
+        let tensor = tensor.tensor.log_sigmoid();
+
+        TchTensor::from_existing(tensor, storage)
+    }
 }
@@ -1,7 +1,6 @@
 use crate::backend::Backend;
 use crate::check::TensorCheck;
 use crate::{check, Tensor};
-use crate::{ElementPrecision, Precision};
 
 /// Applies the rectified linear unit function.
 pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
@@ -125,39 +124,7 @@ pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D>
 
 /// Applies the log sigmoid function.
 pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    /// To avoid overflow, we use the log-sum-exp trick.
-    ///
-    /// ```ignore
-    /// log(sigmoid(x)) = log(1/(1 + exp(-x)))
-    ///                 = log(1) - log(1 + exp(-x))
-    ///                 = -log(1 + exp(-x))
-    ///                 = -log(exp(0) + exp(-x))
-    /// ```
-    /// The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
-    /// subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
-    /// following equivalence:
-    /// ```ignore
-    /// log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
-    /// ```
-    ///
-    /// This extends the range of values for which we obtain accurate results.
-    fn numerically_stable_log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
-        // max(-x, 0)
-        let max_elem = x.clone().neg().max_pair(x.zeros_like());
-
-        // log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
-        let z = (max_elem.clone().neg().exp() + (x.neg() - max_elem.clone()).exp()).log();
-
-        z.neg() - max_elem
-    }
-    match B::FloatElem::precision() {
-        Precision::Half => {
-            let tensor_full = tensor.into_full_precision();
-            let tensor_tmp = numerically_stable_log_sigmoid(tensor_full);
-            Tensor::from_full_precision(tensor_tmp)
-        }
-        _ => numerically_stable_log_sigmoid(tensor),
-    }
+    Tensor::from_primitive(B::log_sigmoid(tensor.primitive))
 }
 
 /// Applies the silu function
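The removed doc comment above explains the log-sum-exp trick that now lives in the `ActivationOps` default implementation (next hunk). A minimal scalar sketch of the overflow it avoids, in plain f64 with illustrative helper names (not from the burn codebase):

// Naive -ln(1 + exp(-x)) overflows for large negative x; the rearranged form stays finite.
fn naive(x: f64) -> f64 {
    -(1.0 + (-x).exp()).ln()
}

fn stable(x: f64) -> f64 {
    // log(sigmoid(x)) = -(max(-x, 0) + ln(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    let m = (-x).max(0.0);
    -(m + ((-m).exp() + (-x - m).exp()).ln())
}

fn main() {
    println!("{} vs {}", naive(-300.0), stable(-300.0)); // -inf vs -300
    println!("{} vs {}", naive(0.5), stable(0.5));       // both approximately -0.474077
}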
@@ -181,4 +181,98 @@ pub trait ActivationOps<B: Backend> {
         );
         B::float_mul(value, grad)
     }
+
+    /// Applies the LogSigmoid activation function.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor.
+    ///
+    /// # Returns
+    ///
+    /// The output tensor.
+    fn log_sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
+        // To avoid overflow, we use the log-sum-exp trick.
+        //
+        // ```ignore
+        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
+        //                 = log(1) - log(1 + exp(-x))
+        //                 = -log(1 + exp(-x))
+        //                 = -log(exp(0) + exp(-x))
+        // ```
+        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
+        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
+        // following equivalence:
+        // ```ignore
+        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
+        // ```
+        //
+        // This extends the range of values for which we obtain accurate results.
+
+        // max(-x, 0)
+        let tensor_neg = B::float_neg(tensor);
+        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
+        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
+        let max_elem_neg = B::float_neg(max_elem.clone());
+
+        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
+        let z = B::float_add(
+            B::float_exp(max_elem_neg.clone()),
+            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
+        );
+
+        // -max(-x, 0) - log(z)
+        B::float_sub(max_elem_neg, B::float_log(z))
+    }
+
+    /// Applies the LogSigmoid activation function backward.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input tensor.
+    /// * `grad` - The gradient.
+    ///
+    /// # Returns
+    ///
+    /// The output gradient.
+    fn log_sigmoid_backward<const D: usize>(
+        x: FloatTensor<B, D>,
+        grad: FloatTensor<B, D>,
+    ) -> FloatTensor<B, D> {
+        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
+        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
+        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
+        //
+        // This simplifies to:
+        // -max_derive - (z-1)/z if x is >= 0
+        // -max_derive + (z-1)/z if x is < 0
+
+        let shape = B::float_shape(&x);
+        let device = B::float_device(&x);
+
+        // max(-x, 0)
+        let x_neg = B::float_neg(x);
+        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0 or x >= 0
+        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());
+
+        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
+        let z = B::float_add(
+            B::float_exp(B::float_neg(max_elem.clone())),
+            B::float_exp(B::float_sub(x_neg, max_elem)),
+        );
+
+        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
+        let ones = B::float_ones(shape, &device);
+        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
+        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());
+
+        // grad * (max_derive - sign * (1 - (1 / z)))
+        B::float_mul(
+            grad,
+            B::float_sub(
+                max_derive,
+                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
+            ),
+        )
+    }
 }
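To see that the masked expression in `log_sigmoid_backward` matches the closed-form gradient, here is a scalar replica in plain f64 (illustrative names, not part of the commit); both branches of the mask reduce to sigmoid(-x):

// Scalar version of: grad * (max_derive - sign * (1 - 1/z)).
fn backward_factor(x: f64) -> f64 {
    let x_neg = -x;
    let m = x_neg.max(0.0);                 // max(-x, 0)
    let z = (-m).exp() + (x_neg - m).exp(); // z as defined above
    // mask is "-x < 0", i.e. x > 0
    let (max_derive, sign) = if x > 0.0 { (0.0, -1.0) } else { (1.0, 1.0) };
    max_derive - sign * (1.0 - 1.0 / z)
}

fn main() {
    for &x in &[-5.0, -0.1423, 0.0, 0.8762, 5.0] {
        let closed_form = 1.0 / (1.0 + x.exp()); // sigmoid(-x)
        assert!((backward_factor(x) - closed_form).abs() < 1e-9);
    }
    println!("backward factor matches sigmoid(-x)");
}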