mirror of https://github.com/tracel-ai/burn.git
fix the problem of sigmoid gradient generating NaN (#1140)
* use sigmoid derivative formulas
* add test
* fix test error
* move sigmoid to tensor/ops/activation.rs
* use full precision in the default implementation
* rename the param of `sigmoid_backward`
This commit is contained in:
parent b99726f804
commit a5bdf38c92
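Background on the fix: the public `sigmoid` was previously decomposed as `log_sigmoid(tensor).exp()` (see the hunk further down), so its gradient was assembled from the generic autodiff graph of `exp`/`log` rather than from the analytic derivative. For an input such as `x = -90`, intermediate terms of the form `exp(-x)` overflow `f32`, and the gradient comes out as NaN. The commit instead saves the forward output and applies the derivative formula `sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))`. A minimal stand-alone sketch of the failure mode, using plain `f32` math rather than the actual autodiff graph:

```rust
fn main() {
    let x: f32 = -90.0;

    // Naive derivative written directly in terms of exp(-x):
    //   d/dx sigmoid(x) = exp(-x) / (1 + exp(-x))^2
    let e = (-x).exp(); // exp(90) ≈ 1.2e39 overflows f32 -> inf
    let naive_grad = e / (1.0 + e).powi(2); // inf / inf -> NaN
    println!("naive grad:  {naive_grad}");

    // Derivative expressed through the forward output, as this commit does:
    //   d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    let y = 1.0 / (1.0 + e); // 1 / inf -> 0.0
    let stable_grad = y * (1.0 - y); // 0.0, finite
    println!("stable grad: {stable_grad}");
}
```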
@@ -54,4 +54,27 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
             OpsKind::UnTracked(prep) => prep.finish(output),
         }
     }
+
+    fn sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        #[derive(Debug)]
+        struct Sigmoid;
+
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Sigmoid {
+            type State = B::TensorPrimitive<D>;
+
+            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
+                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
+                    B::sigmoid_backward(ops.state, grad)
+                });
+            }
+        }
+
+        match Sigmoid.prepare([tensor.node], [tensor.graph]).stateful() {
+            OpsKind::Tracked(prep) => {
+                let output = B::sigmoid(tensor.primitive);
+                prep.finish(output.clone(), output)
+            }
+            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
+        }
+    }
 }
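In the tracked branch above, the forward output is stored as the op's state, and `backward` multiplies the incoming gradient by `output * (1 - output)` via `B::sigmoid_backward` (added later in this commit). A minimal scalar sketch of that contract, using plain `f32` values as stand-ins for the `Backend` tensor primitives:

```rust
// Scalar stand-in for the forward op; the real code works on B::TensorPrimitive<D>.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Mirrors sigmoid_backward: the derivative is expressed through the saved output.
fn sigmoid_backward(output: f32, grad: f32) -> f32 {
    output * (1.0 - output) * grad
}

fn main() {
    let x = 0.8762_f32;
    let y = sigmoid(x); // forward output, saved as the op state
    // With an upstream gradient of 1.0 (as produced by tensor_2.backward() in the
    // new test below), the result is ≈ 0.2075, matching the expected 0.207549.
    println!("{}", sigmoid_backward(y, 1.0));
}
```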
@@ -38,6 +38,7 @@ mod recip;
 mod relu;
 mod reshape;
 mod select;
+mod sigmoid;
 mod sin;
 mod slice;
 mod softmax;
@@ -103,6 +104,7 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_abs!();
         burn_autodiff::testgen_ad_sub!();
         burn_autodiff::testgen_ad_tanh!();
+        burn_autodiff::testgen_ad_sigmoid!();
         burn_autodiff::testgen_ad_transpose!();
     };
 }
@@ -0,0 +1,33 @@
+#[burn_tensor_testgen::testgen(ad_sigmoid)]
+mod tests {
+    use super::*;
+    use burn_tensor::{activation, Data};
+
+    #[test]
+    fn should_diff_sigmoid() {
+        let data = Data::<f32, 1>::from([0.8762]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
+        let tensor_2 = activation::sigmoid(tensor_1.clone());
+        let grads = tensor_2.backward();
+
+        let grad = tensor_1.grad(&grads).unwrap();
+
+        grad.to_data().assert_approx_eq(&Data::from([0.207549]), 4);
+    }
+
+    #[test]
+    fn small_neg_val_should_not_cause_grad_overflow() {
+        let data = Data::<f32, 1>::from([-90.0]);
+
+        let device = Default::default();
+        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
+        let tensor_2 = activation::sigmoid(tensor_1.clone());
+        let grads = tensor_2.backward();
+
+        let grad = tensor_1.grad(&grads).unwrap();
+
+        grad.to_data().assert_approx_eq(&Data::from([0.0]), 4);
+    }
+}
@@ -22,4 +22,8 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
 
         TchTensor::from_existing(tensor, storage)
     }
+
+    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
+        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
+    }
 }
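The two closures give `unary_ops` an in-place LibTorch kernel (`sigmoid_`) and an allocating one (`sigmoid`); the helper applies the in-place kernel only when it can safely reuse the tensor's storage. A hypothetical, self-contained sketch of that dispatch pattern, using `Rc` as a stand-in for the shared storage (the names `apply_unary`, `inplace`, and `alloc` are illustrative and not part of the burn-tch API):

```rust
use std::rc::Rc;

// Illustrative only: run the in-place closure when we uniquely own the
// storage, otherwise fall back to the allocating closure.
fn apply_unary<T: Clone>(
    storage: Rc<Vec<T>>,
    inplace: impl FnOnce(&mut Vec<T>),
    alloc: impl FnOnce(&[T]) -> Vec<T>,
) -> Rc<Vec<T>> {
    match Rc::try_unwrap(storage) {
        Ok(mut owned) => {
            inplace(&mut owned);
            Rc::new(owned)
        }
        Err(shared) => Rc::new(alloc(&shared)),
    }
}

fn main() {
    let unique = Rc::new(vec![0.0_f32, 1.0]);
    let out = apply_unary(
        unique,
        |v| v.iter_mut().for_each(|x| *x = 1.0 / (1.0 + (-*x).exp())),
        |v| v.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect(),
    );
    println!("{:?}", out); // storage was unique, so sigmoid ran in place: [0.5, 0.7310586]
}
```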
@@ -78,7 +78,7 @@ pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize)
 
 /// Applies the sigmoid function.
 pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    log_sigmoid(tensor).exp()
+    Tensor::from_primitive(B::sigmoid(tensor.primitive))
 }
 
 /// Applies the log sigmoid function.
@@ -1,3 +1,4 @@
+use crate::tensor::ops::tensor::TensorOps;
 use crate::{backend::Backend, ElementConversion};
 use core::f64::consts::SQRT_2;
 
@@ -102,4 +103,43 @@ pub trait ActivationOps<B: Backend> {
 
         B::mul(y, grad)
     }
+
+    /// Applies the Sigmoid activation function.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor.
+    ///
+    /// # Returns
+    ///
+    /// The output tensor.
+    fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
+        let tensor_full = B::to_full_precision(&tensor);
+        let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
+            B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
+                B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
+                1.0.elem(),
+            )),
+        ));
+
+        B::from_full_precision(tensor_tmp)
+    }
+
+    /// Applies the Sigmoid activation function backward.
+    ///
+    /// # Arguments
+    ///
+    /// * `output` - The output tensor of the sigmoid function.
+    /// * `grad` - The gradient.
+    ///
+    /// # Returns
+    ///
+    /// The output tensor.
+    fn sigmoid_backward<const D: usize>(
+        output: FloatTensor<B, D>,
+        grad: FloatTensor<B, D>,
+    ) -> FloatTensor<B, D> {
+        let value = B::mul(output.clone(), B::add_scalar(B::neg(output), 1.0.elem()));
+        B::mul(value, grad)
+    }
 }
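The default `sigmoid` above evaluates `exp(-log(1 + exp(-x)))` in the backend's full-precision type, so the intermediate `exp(-x)` stays finite where an `f32` evaluation would overflow, and `sigmoid_backward` applies the derivative `output * (1 - output)` to the incoming gradient. A small stand-alone `f64` check of that forward formula, using plain `std` math rather than the `Backend` trait:

```rust
// sigmoid(x) written the same way as the default implementation:
// exp(-ln(1 + exp(-x))), evaluated in f64 ("full precision").
fn sigmoid_full_precision(x: f64) -> f64 {
    (-(1.0 + (-x).exp()).ln()).exp()
}

fn main() {
    // Agrees with 1 / (1 + exp(-x)) for a moderate input (≈ 0.7061).
    println!("{}", sigmoid_full_precision(0.8762));
    // Stays finite for the input used by the overflow test (≈ 8.2e-40):
    // exp(90) ≈ 1.2e39 fits in f64, whereas it would overflow f32.
    println!("{}", sigmoid_full_precision(-90.0));
}
```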