mirror of https://github.com/tracel-ai/burn.git
Adding burn::nn::Sigmoid (#2031)
This commit is contained in:
parent
ed8a91d48a
commit
9804bf81b2
|
@ -29,6 +29,7 @@ mod prelu;
|
||||||
mod relu;
|
mod relu;
|
||||||
mod rnn;
|
mod rnn;
|
||||||
mod rope_encoding;
|
mod rope_encoding;
|
||||||
|
mod sigmoid;
|
||||||
mod swiglu;
|
mod swiglu;
|
||||||
mod tanh;
|
mod tanh;
|
||||||
mod unfold;
|
mod unfold;
|
||||||
|
@ -46,6 +47,7 @@ pub use prelu::*;
|
||||||
pub use relu::*;
|
pub use relu::*;
|
||||||
pub use rnn::*;
|
pub use rnn::*;
|
||||||
pub use rope_encoding::*;
|
pub use rope_encoding::*;
|
||||||
|
pub use sigmoid::*;
|
||||||
pub use swiglu::*;
|
pub use swiglu::*;
|
||||||
pub use tanh::*;
|
pub use tanh::*;
|
||||||
pub use unfold::*;
|
pub use unfold::*;
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
use crate as burn;
|
||||||
|
|
||||||
|
use crate::module::Module;
|
||||||
|
use crate::tensor::backend::Backend;
|
||||||
|
use crate::tensor::Tensor;
|
||||||
|
|
||||||
|
/// Applies the sigmoid function element-wise
|
||||||
|
/// See also [sigmoid](burn::tensor::activation::sigmoid)
|
||||||
|
#[derive(Module, Clone, Debug, Default)]
|
||||||
|
pub struct Sigmoid;
|
||||||
|
|
||||||
|
impl Sigmoid {
|
||||||
|
/// Create the module.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {}
|
||||||
|
}
|
||||||
|
/// Applies the forward pass on the input tensor.
|
||||||
|
///
|
||||||
|
/// # Shapes
|
||||||
|
///
|
||||||
|
/// - input: `[..., any]`
|
||||||
|
/// - output: `[..., any]`
|
||||||
|
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
crate::tensor::activation::sigmoid(input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn display() {
|
||||||
|
let layer = Sigmoid::new();
|
||||||
|
|
||||||
|
assert_eq!(alloc::format!("{}", layer), "Sigmoid");
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue