Fix sigmoid gradient generating NaN (#1140)

* use sigmoid derivative formulas

* add test

* fix test error

* move sigmoid to tensor/ops/activation.rs

* use full precision in the default implementation

* rename the param of `sigmoid_backward`
Author: wcshds
Date: 2024-01-17 05:20:18 +08:00 (committed by GitHub)
Parent: b99726f804
Commit: a5bdf38c92
6 changed files with 103 additions and 1 deletion
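For reference, the numerics behind the fix: the sigmoid and the derivative identity that `sigmoid_backward` applies, written in terms of the saved forward output y. The overflow reading below is inferred from the new -90.0 test case, not stated in the commit message.

    \sigma(x) = \frac{1}{1 + e^{-x}},
    \qquad
    \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr) = y\,(1 - y)

For x = -90 in f32, e^{-x} = e^{90} \approx 1.2 \times 10^{39} overflows to infinity, so a gradient path that recomputes the derivative through e^{-x} can end up evaluating \infty / \infty = NaN, whereas y\,(1 - y) computed from the saved output stays finite (it underflows to 0).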

View File

@@ -54,4 +54,27 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
            OpsKind::UnTracked(prep) => prep.finish(output),
        }
    }

    fn sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        #[derive(Debug)]
        struct Sigmoid;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Sigmoid {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::sigmoid_backward(ops.state, grad)
                });
            }
        }

        match Sigmoid.prepare([tensor.node], [tensor.graph]).stateful() {
            OpsKind::Tracked(prep) => {
                let output = B::sigmoid(tensor.primitive);
                prep.finish(output.clone(), output)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
        }
    }
}
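A standalone scalar sketch (plain f32, made-up helper names, no Burn types) of the pattern in this hunk: the tracked branch saves the forward output as the op state, and the backward later computes output * (1 - output) * grad from it instead of recomputing through the exponential. The "recomputed" variant is for contrast only, not the exact graph the old implementation built.

    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    // What `sigmoid_backward` does with the saved forward output.
    fn grad_from_output(output: f32, grad: f32) -> f32 {
        output * (1.0 - output) * grad
    }

    // A derivative recomputed through e^{-x}; overflows for very negative x.
    fn grad_recomputed(x: f32, grad: f32) -> f32 {
        let e = (-x).exp(); // e^{90} -> inf in f32 when x = -90
        e / ((1.0 + e) * (1.0 + e)) * grad
    }

    fn main() {
        let x = -90.0_f32;
        let y = sigmoid(x); // underflows to 0.0
        assert_eq!(grad_from_output(y, 1.0), 0.0); // finite, as the new test expects
        assert!(grad_recomputed(x, 1.0).is_nan()); // inf / inf -> NaN
    }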

View File

@@ -38,6 +38,7 @@ mod recip;
mod relu;
mod reshape;
mod select;
mod sigmoid;
mod sin;
mod slice;
mod softmax;
@@ -103,6 +104,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_abs!();
        burn_autodiff::testgen_ad_sub!();
        burn_autodiff::testgen_ad_tanh!();
        burn_autodiff::testgen_ad_sigmoid!();
        burn_autodiff::testgen_ad_transpose!();
    };
}

View File

@@ -0,0 +1,33 @@
#[burn_tensor_testgen::testgen(ad_sigmoid)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data};

    #[test]
    fn should_diff_sigmoid() {
        let data = Data::<f32, 1>::from([0.8762]);
        let device = Default::default();

        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let tensor_2 = activation::sigmoid(tensor_1.clone());

        let grads = tensor_2.backward();
        let grad = tensor_1.grad(&grads).unwrap();

        grad.to_data().assert_approx_eq(&Data::from([0.207549]), 4);
    }

    #[test]
    fn small_neg_val_should_not_cause_grad_overflow() {
        let data = Data::<f32, 1>::from([-90.0]);
        let device = Default::default();

        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let tensor_2 = activation::sigmoid(tensor_1.clone());

        let grads = tensor_2.backward();
        let grad = tensor_1.grad(&grads).unwrap();

        grad.to_data().assert_approx_eq(&Data::from([0.0]), 4);
    }
}

View File

@@ -22,4 +22,8 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
        TchTensor::from_existing(tensor, storage)
    }

    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }
}
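The `unary_ops` call above passes two closures: an in-place `sigmoid_()` and a copying `sigmoid()`. A minimal sketch of that dual-closure pattern in plain Rust, assuming (this is an assumption about burn-tch internals, not shown in the diff) that the in-place path is taken only when the data is uniquely owned:

    use std::rc::Rc;

    // Hypothetical stand-in for burn-tch's helper: run the in-place op when we
    // uniquely own the data, otherwise fall back to the copying op.
    fn unary_ops<T>(value: Rc<T>, inplace: impl FnOnce(T) -> T, copy: impl FnOnce(&T) -> T) -> T {
        match Rc::try_unwrap(value) {
            Ok(owned) => inplace(owned),
            Err(shared) => copy(&shared),
        }
    }

    fn main() {
        let data = Rc::new(vec![1.0_f32, 2.0]);
        let doubled = unary_ops(
            data,
            |mut v| { v.iter_mut().for_each(|x| *x *= 2.0); v },
            |v| v.iter().map(|x| x * 2.0).collect(),
        );
        assert_eq!(doubled, vec![2.0, 4.0]);
    }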

View File

@@ -78,7 +78,7 @@ pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize)
/// Applies the sigmoid function.
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    log_sigmoid(tensor).exp()
    Tensor::from_primitive(B::sigmoid(tensor.primitive))
}

/// Applies the log sigmoid function.
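This hunk removes the old body `log_sigmoid(tensor).exp()` and adds the direct dispatch `Tensor::from_primitive(B::sigmoid(tensor.primitive))`. The two forward paths agree mathematically, since exp(-ln(1 + e^{-x})) = 1/(1 + e^{-x}); what changes is which operations the autodiff graph records, and therefore how the gradient is built. A scalar sketch of the equivalence (plain f32, hypothetical helper names):

    // log_sigmoid(x) = -ln(1 + e^{-x}); exponentiating it recovers the sigmoid.
    fn log_sigmoid(x: f32) -> f32 {
        -(-x).exp().ln_1p()
    }

    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    fn main() {
        for x in [-3.0_f32, 0.0, 0.8762, 3.0] {
            // Same forward value either way; the fix is about the backward pass.
            assert!((log_sigmoid(x).exp() - sigmoid(x)).abs() < 1e-6);
        }
    }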

View File

@@ -1,3 +1,4 @@
use crate::tensor::ops::tensor::TensorOps;
use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;

@@ -102,4 +103,43 @@ pub trait ActivationOps<B: Backend> {
        B::mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
        let tensor_full = B::to_full_precision(&tensor);
        let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
            B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
                B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
                1.0.elem(),
            )),
        ));

        B::from_full_precision(tensor_tmp)
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid_backward<const D: usize>(
        output: FloatTensor<B, D>,
        grad: FloatTensor<B, D>,
    ) -> FloatTensor<B, D> {
        let value = B::mul(output.clone(), B::add_scalar(B::neg(output), 1.0.elem()));
        B::mul(value, grad)
    }
}
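The default `sigmoid` above evaluates exp(-(ln(1 + e^{-x}))) in the backend's full-precision type, which is algebraically the ordinary sigmoid, and `sigmoid_backward` applies the derivative identity to the saved output:

    \exp\!\bigl(-\log(1 + e^{-x})\bigr) = \frac{1}{1 + e^{-x}} = \sigma(x),
    \qquad
    \texttt{sigmoid\_backward}(y, g) = y\,(1 - y)\,g

As a worked check against the new test: at x = 0.8762, y = \sigma(x) \approx 0.7060 and y\,(1 - y) \approx 0.2075, matching the expected gradient 0.207549. Presumably the point of "use full precision in the default implementation" is that e^{-x} overflows much earlier in reduced precision (f16 already overflows near x \approx -11, since its maximum finite value is 65504), while the same computation stays representable in f32/f64; that rationale is an inference, not stated in the diff.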