mirror of https://github.com/tracel-ai/burn.git
Feat: Add Leaky Relu Model (#1467)
parent 53eb3ecfa9
commit 4de1272344
@@ -115,6 +115,7 @@ Burn comes with built-in modules that you can use to build your own modules.
 | `Dropout` | `nn.Dropout` |
 | `Gelu` | `nn.GELU` |
 | `Prelu` | `nn.PReLU` |
+| `LeakyRelu` | `nn.LeakyReLU` |
 | `Linear` | `nn.Linear` |
 | `Embedding` | `nn.Embedding` |
 | `Relu` | `nn.ReLU` |
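For context, a minimal usage sketch of the new layer, assuming a recent `burn` with the `ndarray` feature enabled (the backend choice and input values here are illustrative, not part of the commit):

```rust
use burn::backend::NdArray;
use burn::nn::{LeakyRelu, LeakyReluConfig};
use burn::tensor::Tensor;

fn main() {
    // Any Burn backend works; NdArray keeps the example CPU-only.
    type B = NdArray;
    let device = Default::default();

    // `negative_slope` defaults to 0.01, matching the config in the diff below.
    let layer: LeakyRelu<B> = LeakyReluConfig::new().init(&device);

    let input = Tensor::<B, 2>::from_floats([[1.0, -2.0]], &device);
    let output = layer.forward(input);
    // Positive values pass through; negative values are scaled: [[1.0, -0.02]]
    println!("{}", output);
}
```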
@@ -0,0 +1,93 @@
+use crate as burn;
+
+use crate::config::Config;
+use crate::module::Module;
+use crate::tensor::backend::Backend;
+use crate::tensor::Data;
+use crate::tensor::Tensor;
+
+/// Leaky ReLU layer.
+#[derive(Module, Debug)]
+pub struct LeakyRelu<B: Backend> {
+    /// The negative slope, stored as a one-element tensor so it can be
+    /// passed to `prelu` as a non-learnable weight.
+    pub negative_slope: Tensor<B, 1>,
+}
+
+/// Configuration to create a [Leaky ReLU](LeakyRelu) layer.
+#[derive(Config, Debug)]
+pub struct LeakyReluConfig {
+    /// The negative slope. Default is 0.01.
+    #[config(default = "0.01")]
+    pub negative_slope: f32,
+}
+
+impl LeakyReluConfig {
+    /// Initialize a new [Leaky ReLU](LeakyRelu) layer.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> LeakyRelu<B> {
+        LeakyRelu {
+            negative_slope: Tensor::from_data(
+                Data::from([self.negative_slope]).convert(),
+                device,
+            ),
+        }
+    }
+}
+
+impl<B: Backend> LeakyRelu<B> {
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
+    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        // Leaky ReLU is the special case of PReLU in which the weight is a
+        // single value shared across all channels and is not learnable.
+        crate::tensor::activation::prelu(input, self.negative_slope.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::Data;
+
+    #[test]
+    fn test_leaky_relu_forward() {
+        let device = <TestBackend as Backend>::Device::default();
+        let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init(&device);
+        let input = Tensor::<TestBackend, 2>::from_data(Data::from([[0.4410, -0.2507]]), &device);
+
+        let out = model.forward(input);
+
+        // Compare approximately: the negative branch involves a float multiply.
+        out.to_data()
+            .assert_approx_eq(&Data::from([[0.4410, -0.002507]]), 4);
+    }
+
+    #[test]
+    fn test_leaky_relu_forward_multi_dim() {
+        let input = [
+            [
+                [-1.0222, 1.5810, 0.3457, -1.3530],
+                [0.0231, 0.8681, 0.2473, -0.0377],
+                [0.3520, -1.1199, 1.2219, 0.2804],
+            ],
+            [
+                [1.0002, 0.7259, 0.8779, 0.2084],
+                [1.5615, -0.1057, -0.4886, -1.5184],
+                [-0.5523, -0.2741, -0.0210, -1.1352],
+            ],
+        ];
+        let expected_output = [
+            [
+                [-1.0222e-02, 1.5810e+00, 3.457e-01, -1.3530e-02],
+                [2.31e-02, 8.681e-01, 2.473e-01, -3.77e-04],
+                [3.52e-01, -1.1199e-02, 1.2219e+00, 2.804e-01],
+            ],
+            [
+                [1.0002e+00, 7.259e-01, 8.779e-01, 2.084e-01],
+                [1.5615e+00, -1.057e-03, -4.886e-03, -1.5184e-02],
+                [-5.523e-03, -2.741e-03, -2.1e-04, -1.1352e-02],
+            ],
+        ];
+
+        let device = <TestBackend as Backend>::Device::default();
+        let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init(&device);
+
+        let input_data = Tensor::<TestBackend, 3>::from_data(Data::from(input), &device);
+        let actual_output = model.forward(input_data);
+        actual_output
+            .to_data()
+            .assert_approx_eq(&Data::from(expected_output), 4);
+    }
+}
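The expected values in both tests follow directly from the elementwise definition the layer realizes through `prelu`. For reference, a minimal backend-free sketch of that definition (a standalone illustration, not part of the commit):

```rust
/// Reference semantics: f(x) = x for x >= 0, f(x) = negative_slope * x otherwise.
fn leaky_relu_scalar(x: f32, negative_slope: f32) -> f32 {
    if x >= 0.0 { x } else { negative_slope * x }
}

fn main() {
    // Reproduces the single-element test above with the default slope of 0.01:
    // positive inputs pass through exactly, negative inputs are scaled.
    assert_eq!(leaky_relu_scalar(0.4410, 0.01), 0.4410);
    assert!((leaky_relu_scalar(-0.2507, 0.01) - (-2.507e-3)).abs() < 1e-7);
}
```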
@@ -20,6 +20,7 @@ mod dropout;
 mod embedding;
 mod gelu;
 mod initializer;
+mod leaky_relu;
 mod linear;
 mod norm;
 mod padding;

@@ -33,6 +34,7 @@ pub use dropout::*;
 pub use embedding::*;
 pub use gelu::*;
 pub use initializer::*;
+pub use leaky_relu::*;
 pub use linear::*;
 pub use norm::*;
 pub use padding::*;
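A note on the wiring above: `mod leaky_relu;` stays private, and only `pub use leaky_relu::*;` exposes its items, following the crate's existing re-export pattern. Downstream code therefore imports the new types from `nn` directly, as in this sketch:

```rust
// The glob re-export flattens the namespace, so this works:
use burn::nn::{LeakyRelu, LeakyReluConfig};
// ...while `use burn::nn::leaky_relu::LeakyRelu;` would not compile,
// because the `leaky_relu` module itself is private.
```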