Feat/swiglu (#1507)

Aasheesh Singh 2024-03-25 15:55:27 -04:00 committed by GitHub
parent 4542ceddca
commit 613e698007
3 changed files with 149 additions and 19 deletions


@@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
These methods are available for all modules.
| Burn API | PyTorch Equivalent |
|-----------------------------------------|------------------------------------------|
| `module.devices()` | N/A |
| `module.fork(device)` | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)` |
@@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif
autodiff support.
| Burn API | PyTorch Equivalent |
|------------------|--------------------|
| `module.valid()` | `module.eval()` |
## Visitor & Mapper
@@ -106,24 +106,25 @@ Burn comes with built-in modules that you can use to build your own modules.
### General
| Burn API       | PyTorch Equivalent                            |
|----------------|-----------------------------------------------|
| `BatchNorm`    | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc.       |
| `LayerNorm`    | `nn.LayerNorm`                                |
| `GroupNorm`    | `nn.GroupNorm`                                |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `Dropout`      | `nn.Dropout`                                  |
| `Gelu`         | `nn.GELU`                                     |
| `Prelu`        | `nn.PReLU`                                    |
| `LeakyRelu`    | `nn.LeakyReLU`                                |
| `Linear`       | `nn.Linear`                                   |
| `Embedding`    | `nn.Embedding`                                |
| `Relu`         | `nn.ReLU`                                     |
| `SwiGlu`       | _No direct equivalent_                        |
### Convolutions
| Burn API | PyTorch Equivalent |
|-------------------|----------------------|
| `Conv1d` | `nn.Conv1d` |
| `Conv2d` | `nn.Conv2d` |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
@@ -132,7 +133,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Pooling
| Burn API | PyTorch Equivalent |
|---------------------|------------------------|
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
| `AvgPool1d` | `nn.AvgPool1d` |
@@ -143,7 +144,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### RNNs
| Burn API | PyTorch Equivalent |
|------------------|------------------------|
| `Gru` | `nn.GRU` |
| `Lstm` | `nn.LSTM` |
| `GateController` | _No direct equivalent_ |
@@ -151,7 +152,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Transformer
| Burn API | PyTorch Equivalent |
|----------------------|-------------------------|
| `MultiHeadAttention` | `nn.MultiheadAttention` |
| `TransformerDecoder` | `nn.TransformerDecoder` |
| `TransformerEncoder` | `nn.TransformerEncoder` |
@@ -160,7 +161,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Loss
| Burn API | PyTorch Equivalent |
|--------------------|-----------------------|
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
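The `SwiGlu` row above maps to the layer added by this commit (full source below). As a quick, non-authoritative sketch of how it follows the same config-then-init pattern as the other modules in these tables, assuming the usual `burn::nn` re-export path and purely illustrative sizes:

```rust
use burn::nn::{SwiGlu, SwiGluConfig};
use burn::tensor::{backend::Backend, Tensor};

/// Builds a SwiGLU layer and applies it to a `[batch_size, seq_length, d_input]` tensor.
/// The sizes (d_input = 64, d_output = 256) are placeholders, not values from this commit.
fn apply_swiglu<B: Backend>(device: &B::Device, input: Tensor<B, 3>) -> Tensor<B, 3> {
    let swiglu: SwiGlu<B> = SwiGluConfig::new(64, 256)
        .with_bias(false) // the default; shown only for clarity
        .init(device);
    swiglu.forward(input)
}
```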


@@ -28,6 +28,7 @@ mod pos_encoding;
mod prelu;
mod relu;
mod rnn;
mod swiglu;
mod unfold;
pub use dropout::*;
@@ -42,4 +43,5 @@ pub use pos_encoding::*;
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
pub use swiglu::*;
pub use unfold::*;
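Because the new module is registered with `mod swiglu;` and glob re-exported with `pub use swiglu::*;`, downstream code should be able to reach it through the crate's `nn` namespace. A hedged sketch, assuming the usual top-level `burn::nn` re-export; `SwiGluRecord` is the record type generated by the `Module` derive in the new file:

```rust
// Import path assumed from the glob re-export added above.
use burn::nn::{SwiGlu, SwiGluConfig, SwiGluRecord};
use burn::tensor::backend::Backend;

/// Rebuilds a SwiGlu layer from a previously saved record (e.g. loaded weights),
/// mirroring the `init_with` constructor added in this commit.
fn from_record<B: Backend>(config: &SwiGluConfig, record: SwiGluRecord<B>) -> SwiGlu<B> {
    config.init_with(record)
}
```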


@@ -0,0 +1,127 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::activation::silu;
use crate::tensor::{backend::Backend, Tensor};
use super::{Initializer, Linear, LinearConfig};
/// Configuration to create a [SwiGlu](SwiGlu) activation layer.
#[derive(Config, Debug)]
pub struct SwiGluConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformations.
    /// Defaults to `false`, which is the common choice in SwiGLU implementations.
    #[config(default = false)]
    pub bias: bool,
    /// The type of function used to initialize the linear layer parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
/// Applies the SwiGLU or Swish Gated Linear Unit to the input tensor.
///
/// The SwiGLU activation function is defined as:
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
///
/// # Params
///
/// - `linear_inner`: The inner linear layer, whose output is passed through the Swish
///   activation, with `d_input` input features and `d_output` output features.
/// - `linear_outer`: The outer linear layer used for the element-wise multiplication,
///   with `d_input` input features and `d_output` output features.
#[derive(Module, Debug)]
pub struct SwiGlu<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
}
impl SwiGluConfig {
    /// Initialize a new [SwiGLU](SwiGlu) activation layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
        SwiGlu {
            linear_inner: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            linear_outer: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
        }
    }

    /// Initialize a new [SwiGLU](SwiGlu) activation layer with a [record](SwiGluRecord).
    pub fn init_with<B: Backend>(&self, record: SwiGluRecord<B>) -> SwiGlu<B> {
        SwiGlu {
            linear_inner: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .init_with(record.linear_inner),
            linear_outer: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .init_with(record.linear_outer),
        }
    }
}
impl<B: Backend> SwiGlu<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_input]`
    /// - output: `[batch_size, seq_length, d_output]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Swish-gate the inner projection, then scale the outer projection element-wise.
        let x = self.linear_inner.forward(input.clone());
        let x = silu(x);
        x.mul(self.linear_outer.forward(input))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_swiglu_forward_no_bias() {
        TestBackend::seed(0);
        let device = Default::default();
        let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 });
        let swiglu = config.init(&device);

        let input =
            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
        let output = swiglu.forward(input);
        // With all weights = 0.5 and no bias, both linear layers map [1, 2, 3] to 3.0 per
        // element, so each output element is silu(3.0) * 3.0 ≈ 8.5732 (56.2189 for row two).
        let expected_output = Tensor::<TestBackend, 2>::from_data(
            [[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    fn test_swiglu_forward_with_bias() {
        TestBackend::seed(0);
        let device = Default::default();
        let config = SwiGluConfig::new(3, 3)
            .with_bias(true)
            .with_initializer(Initializer::Constant { value: 0.5 });
        let swiglu = config.init(&device);

        let input =
            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
        let output = swiglu.forward(input);
        // The constant bias of 0.5 shifts each pre-activation to 3.5 (row one) and 8.0 (row two),
        // giving silu(3.5) * 3.5 ≈ 11.8909 and silu(8.0) * 8.0 ≈ 63.9785.
        let expected_output = Tensor::<TestBackend, 2>::from_data(
            [[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }
}
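For context on how a layer like this is typically consumed, here is a hedged sketch of a gated feed-forward block built around the new module; the `FeedForward` name, the projection back to `d_model`, and the dimensions are illustrative assumptions, not part of this commit:

```rust
use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig, SwiGlu, SwiGluConfig};
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical gated feed-forward block: a SwiGLU expansion followed by a
/// linear projection back to the model dimension.
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
    swiglu: SwiGlu<B>,
    proj: Linear<B>,
}

#[derive(Config, Debug)]
pub struct FeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
}

impl FeedForwardConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> FeedForward<B> {
        FeedForward {
            // Gated expansion from d_model to d_ff (bias off by default, as in SwiGluConfig).
            swiglu: SwiGluConfig::new(self.d_model, self.d_ff).init(device),
            // Projection back down to d_model.
            proj: LinearConfig::new(self.d_ff, self.d_model).init(device),
        }
    }
}

impl<B: Backend> FeedForward<B> {
    /// Input and output shapes: `[batch_size, seq_length, d_model]`.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        self.proj.forward(self.swiglu.forward(input))
    }
}
```

A block like this would usually stand in for the plain linear-plus-activation feed-forward inside a transformer layer, which is where SwiGLU is most commonly used.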