mirror of https://github.com/tracel-ai/burn.git
Feat/swiglu (#1507)
This commit is contained in:
parent 4542ceddca
commit 613e698007
@@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.

These methods are available for all modules.

| Burn API                   | PyTorch Equivalent                      |
|----------------------------|-----------------------------------------|
| `module.devices()`         | N/A                                     |
| `module.fork(device)`      | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)`                     |
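As a rough sketch of how these methods read in practice (the `Linear` layer and the exact `&device` argument style are illustrative assumptions, not part of this diff):

```rust
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::backend::Backend;

// Sketch of the device-related `Module` methods listed above.
fn relocate<B: Backend>(device: &B::Device) -> Linear<B> {
    let module: Linear<B> = LinearConfig::new(16, 32).init(device);
    let _devices = module.devices(); // devices currently holding the module's tensors
    let module = module.fork(device); // move and detach, like `to(device).detach()`
    module.to_device(device) // plain move, like PyTorch's `module.to(device)`
}
```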
@@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signify
autodiff support.

| Burn API         | PyTorch Equivalent |
|------------------|--------------------|
| `module.valid()` | `module.eval()`    |
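A minimal sketch of the pattern (assuming an autodiff backend `B` and any derived module; a `Linear` layer stands in here for illustration):

```rust
use burn::module::AutodiffModule;
use burn::nn::Linear;
use burn::tensor::backend::AutodiffBackend;

// Sketch: `valid()` yields the module on the inner (non-autodiff) backend,
// so no autodiff graph is recorded during validation, much like
// `module.eval()` in PyTorch.
fn validation_model<B: AutodiffBackend>(model: &Linear<B>) -> Linear<B::InnerBackend> {
    model.valid()
}
```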

## Visitor & Mapper
@@ -106,24 +106,25 @@ Burn comes with built-in modules that you can use to build your own modules.

### General

| Burn API       | PyTorch Equivalent                            |
|----------------|-----------------------------------------------|
| `BatchNorm`    | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc.       |
| `LayerNorm`    | `nn.LayerNorm`                                |
| `GroupNorm`    | `nn.GroupNorm`                                |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `Dropout`      | `nn.Dropout`                                  |
| `Gelu`         | `nn.GELU`                                     |
| `Prelu`        | `nn.PReLU`                                    |
| `LeakyRelu`    | `nn.LeakyReLU`                                |
| `Linear`       | `nn.Linear`                                   |
| `Embedding`    | `nn.Embedding`                                |
| `Relu`         | `nn.ReLU`                                     |
| `SwiGlu`       | _No direct equivalent_                        |

### Convolutions

| Burn API          | PyTorch Equivalent   |
|-------------------|----------------------|
| `Conv1d`          | `nn.Conv1d`          |
| `Conv2d`          | `nn.Conv2d`          |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
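All of the built-in modules above follow the same pattern the new `SwiGlu` layer uses: build a `*Config`, then `init` it on a device. A minimal sketch (the `LayerNormConfig::new(d_model)` constructor is assumed from the `nn` API; it is not part of this diff):

```rust
use burn::nn::{LayerNorm, LayerNormConfig};
use burn::tensor::backend::Backend;

// Config-then-init pattern shared by the modules in the tables above.
fn build_layer_norm<B: Backend>(device: &B::Device) -> LayerNorm<B> {
    LayerNormConfig::new(64).init(device)
}
```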
@@ -132,7 +133,7 @@ Burn comes with built-in modules that you can use to build your own modules.

### Pooling

| Burn API            | PyTorch Equivalent     |
|---------------------|------------------------|
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
| `AvgPool1d`         | `nn.AvgPool1d`         |
@@ -143,7 +144,7 @@ Burn comes with built-in modules that you can use to build your own modules.

### RNNs

| Burn API         | PyTorch Equivalent     |
|------------------|------------------------|
| `Gru`            | `nn.GRU`               |
| `Lstm`           | `nn.LSTM`              |
| `GateController` | _No direct equivalent_ |
@@ -151,7 +152,7 @@ Burn comes with built-in modules that you can use to build your own modules.

### Transformer

| Burn API             | PyTorch Equivalent      |
|----------------------|-------------------------|
| `MultiHeadAttention` | `nn.MultiheadAttention` |
| `TransformerDecoder` | `nn.TransformerDecoder` |
| `TransformerEncoder` | `nn.TransformerEncoder` |
@@ -160,7 +161,7 @@ Burn comes with built-in modules that you can use to build your own modules.

### Loss

| Burn API           | PyTorch Equivalent    |
|--------------------|-----------------------|
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss`          | `nn.MSELoss`          |
| `HuberLoss`        | `nn.HuberLoss`        |
@@ -28,6 +28,7 @@ mod pos_encoding;
mod prelu;
mod relu;
mod rnn;
mod swiglu;
mod unfold;

pub use dropout::*;
@@ -42,4 +43,5 @@ pub use pos_encoding::*;
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
pub use swiglu::*;
pub use unfold::*;
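With `mod swiglu;` declared and `pub use swiglu::*;` re-exported, the new layer becomes reachable through the crate's `nn` namespace; assuming the usual `burn::nn` re-export path, downstream code imports it as:

```rust
use burn::nn::{SwiGlu, SwiGluConfig};
```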
@@ -0,0 +1,127 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::tensor::activation::silu;
use crate::tensor::{backend::Backend, Tensor};

use super::{Initializer, Linear, LinearConfig};

/// Configuration to create a [SwiGlu](SwiGlu) activation layer.
#[derive(Config, Debug)]
pub struct SwiGluConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the output features.
    pub d_output: usize,
    /// If a bias should be applied during the linear transformations.
    /// Defaults to `false`, as is typical for SwiGLU implementations.
    #[config(default = false)]
    pub bias: bool,
    /// The type of function used to initialize the linear layer parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// Applies the SwiGLU, or Swish Gated Linear Unit, to the input tensor.
///
/// The SwiGLU activation function is defined as:
/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`
///
/// # Params
///
/// - `linear_inner`: The inner linear layer for the Swish activation function,
///   with `d_input` input features and `d_output` output features.
/// - `linear_outer`: The outer linear layer for the element-wise multiplication,
///   with `d_input` input features and `d_output` output features.
#[derive(Module, Debug)]
pub struct SwiGlu<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
}

impl SwiGluConfig {
    /// Initialize a new [SwiGlu](SwiGlu) activation layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {
        SwiGlu {
            linear_inner: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
            linear_outer: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .with_initializer(self.initializer.clone())
                .init(device),
        }
    }

    /// Initialize a new [SwiGlu](SwiGlu) activation layer with a [record](SwiGluRecord).
    pub fn init_with<B: Backend>(&self, record: SwiGluRecord<B>) -> SwiGlu<B> {
        SwiGlu {
            linear_inner: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .init_with(record.linear_inner),
            linear_outer: LinearConfig::new(self.d_input, self.d_output)
                .with_bias(self.bias)
                .init_with(record.linear_outer),
        }
    }
}

impl<B: Backend> SwiGlu<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_input]`
    /// - output: `[batch_size, seq_length, d_output]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input.clone());
        let x = silu(x);
        x.mul(self.linear_outer.forward(input))
    }
}
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn test_swiglu_forward_no_bias() {
|
||||
TestBackend::seed(0);
|
||||
let device = Default::default();
|
||||
let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 });
|
||||
let swiglu = config.init(&device);
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
|
||||
let output = swiglu.forward(input);
|
||||
let expected_output = Tensor::<TestBackend, 2>::from_data(
|
||||
[[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],
|
||||
&device,
|
||||
);
|
||||
output
|
||||
.to_data()
|
||||
.assert_approx_eq(&expected_output.to_data(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_swiglu_forward_with_bias() {
|
||||
TestBackend::seed(0);
|
||||
let device = Default::default();
|
||||
let config = SwiGluConfig::new(3, 3)
|
||||
.with_bias(true)
|
||||
.with_initializer(Initializer::Constant { value: 0.5 });
|
||||
let swiglu = config.init(&device);
|
||||
let input =
|
||||
Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
|
||||
let output = swiglu.forward(input);
|
||||
let expected_output = Tensor::<TestBackend, 2>::from_data(
|
||||
[[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],
|
||||
&device,
|
||||
);
|
||||
output
|
||||
.to_data()
|
||||
.assert_approx_eq(&expected_output.to_data(), 4);
|
||||
}
|
||||
}
|
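Taken together, a hedged end-to-end sketch of the new layer (the `NdArray` backend, its feature flag, and the tensor sizes here are illustrative assumptions, not part of this diff):

```rust
use burn::backend::NdArray;
use burn::nn::{SwiGlu, SwiGluConfig};
use burn::tensor::Tensor;

fn main() {
    type B = NdArray;
    let device = Default::default();

    // SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)
    let swiglu: SwiGlu<B> = SwiGluConfig::new(8, 16).init(&device);

    // Shapes: [batch_size, seq_length, d_input] -> [batch_size, seq_length, d_output]
    let input = Tensor::<B, 3>::zeros([2, 4, 8], &device);
    let output = swiglu.forward(input);
    assert_eq!(output.dims(), [2, 4, 16]);
}
```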