mirror of https://github.com/tracel-ai/burn.git

Support for rotary positional encoding to transformer modules (#1604)

* add rotary positional encoding to transformer modules
* fix f64 error
* use num_traits
* add panic condition

This commit is contained in:
parent 06ce2b02d6
commit fb1da53a38
@@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
These methods are available for all modules.

| Burn API                   | PyTorch Equivalent                      |
- | -------------------------- | --------------------------------------- |
+ |----------------------------|-----------------------------------------|
| `module.devices()`         | N/A                                     |
| `module.fork(device)`      | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)`                     |
@@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signify a module with
autodiff support.

| Burn API         | PyTorch Equivalent |
- | ---------------- | ------------------ |
+ |------------------|--------------------|
| `module.valid()` | `module.eval()`    |

## Visitor & Mapper
@@ -107,7 +107,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### General

| Burn API    | PyTorch Equivalent                      |
- | ----------- | --------------------------------------- |
+ |-------------|-----------------------------------------|
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
| `Dropout`   | `nn.Dropout`                            |
| `Embedding` | `nn.Embedding`                          |
@@ -125,7 +125,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Convolutions

| Burn API          | PyTorch Equivalent   |
- | ----------------- | -------------------- |
+ |-------------------|----------------------|
| `Conv1d`          | `nn.Conv1d`          |
| `Conv2d`          | `nn.Conv2d`          |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
@@ -134,7 +134,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Pooling

| Burn API            | PyTorch Equivalent     |
- | ------------------- | ---------------------- |
+ |---------------------|------------------------|
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
| `AvgPool1d`         | `nn.AvgPool1d`         |
@@ -145,7 +145,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### RNNs

| Burn API         | PyTorch Equivalent     |
- | ---------------- | ---------------------- |
+ |------------------|------------------------|
| `Gru`            | `nn.GRU`               |
| `Lstm`           | `nn.LSTM`              |
| `GateController` | _No direct equivalent_ |
@@ -153,16 +153,17 @@ Burn comes with built-in modules that you can use to build your own modules.
### Transformer

| Burn API             | PyTorch Equivalent      |
- | -------------------- | ----------------------- |
+ |----------------------|-------------------------|
| `MultiHeadAttention` | `nn.MultiheadAttention` |
| `TransformerDecoder` | `nn.TransformerDecoder` |
| `TransformerEncoder` | `nn.TransformerEncoder` |
| `PositionalEncoding` | _No direct equivalent_  |
+ | `RotaryEncoding`     | _No direct equivalent_  |

### Loss

| Burn API           | PyTorch Equivalent    |
- | ------------------ | --------------------- |
+ |--------------------|-----------------------|
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss`          | `nn.MSELoss`          |
| `HuberLoss`        | `nn.HuberLoss`        |
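For orientation, here is a minimal sketch of how the new `RotaryEncoding` module is used. It is not part of this commit's diff; it mirrors the unit tests added below, and the sequence length, hidden size, function name, and import path through `burn::nn` are illustrative assumptions.

```rust
use burn::nn::{RotaryEncoding, RotaryEncodingConfig};
use burn::tensor::{backend::Backend, Tensor};

/// Applies RoPE to a (batch_size, num_heads, seq_len, d_model) tensor of queries or keys.
fn apply_rope<B: Backend>(device: &B::Device, queries: Tensor<B, 4>) -> Tensor<B, 4> {
    // Precompute rotation frequencies up to a maximum sequence length of 512
    // for a hidden dimension of 64 (the hidden dimension must be even,
    // otherwise `init` panics).
    let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(512, 64).init(device);

    // The output keeps the input shape; only the last dimension is rotated pairwise.
    rope.forward(queries)
}
```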
@@ -28,6 +28,7 @@ mod pos_encoding;
mod prelu;
mod relu;
mod rnn;
+ mod rope_encoding;
mod swiglu;
mod unfold;

@@ -43,5 +44,6 @@ pub use pos_encoding::*;
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
+ pub use rope_encoding::*;
pub use swiglu::*;
pub use unfold::*;
@@ -0,0 +1,219 @@
The new `rope_encoding` module, added in full:

```rust
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use alloc::vec;
use burn_tensor::Int;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer.
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Maximum sequence length of input
    max_sequence_length: usize,

    /// Size of the input embedding or hidden dimension
    d_model: usize,

    /// Scaling factor for frequency computation. Defaults to 10000.0
    #[config(default = "10000.0")]
    theta: f32,
}

impl RotaryEncodingConfig {
    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
    ///
    /// # Panics
    ///
    /// Panics if the input embedding dimension is not even.
    /// Panics if the theta parameter is not positive.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
        assert_eq!(
            self.d_model % 2,
            0,
            "The input embedding dimension must be even"
        );
        assert!(
            self.theta > 0.0,
            "Theta parameter must be positive (default: 10000)."
        );

        // Calculate the rotation frequencies for positional embeddings based on the formula
        // `theta_i = 1 / (10000 ^ (2i / d_model)) for i in [0..d_model/2]`
        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
            .float()
            .div_scalar(self.d_model as f32);

        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
        // This is done since burn doesn't support exponentiation of a scalar to a tensor
        let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
        let theta_i = theta_i.powf_scalar(-1.0);

        // Generate frequency values for positional embeddings
        let frequencies: Tensor<B, 2> =
            Tensor::<B, 1, Int>::arange(0..self.max_sequence_length as i64, device)
                .float()
                .unsqueeze()
                .transpose()
                .repeat(1, self.d_model / 2)
                * theta_i.unsqueeze();

        // Convert frequency values to complex numbers (polar form)
        let p_cos = frequencies.clone().cos();
        let p_sin = frequencies.sin();

        // Create the frequency tensor of shape (max_sequence_length, d_model, 2) with the real (cos)
        // and imaginary (sin) components along the last dimension
        let freq_complex: Tensor<B, 3> = Tensor::cat(vec![p_cos, p_sin], 1)
            .reshape([self.max_sequence_length, 2, self.d_model / 2])
            .transpose()
            .unsqueeze_dim::<4>(2)
            .repeat(2, 2)
            .reshape([self.max_sequence_length, self.d_model, 2]);

        RotaryEncoding { freq_complex }
    }
}
```
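To spell out the precomputation above (a reading of the code, not text from the commit): for every position $m < \text{max\_sequence\_length}$ and pair index $i < d_{\text{model}}/2$,

$$
\theta_i = \theta^{-2i/d_{\text{model}}},\qquad
\texttt{freq\_complex}[m,\,2i] = \texttt{freq\_complex}[m,\,2i+1] = \bigl(\cos(m\,\theta_i),\ \sin(m\,\theta_i)\bigr),
$$

where the power is evaluated through the identity $\theta^{2i/d_{\text{model}}} = \exp\!\bigl(\tfrac{2i}{d_{\text{model}}}\ln\theta\bigr)$ noted in the comments, and $\theta$ defaults to 10000.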
The module definition and its forward pass:

```rust
/// A module that applies rotary positional encoding to a tensor.
/// Rotary Position Encoding or Embedding (RoPE) is a type of position embedding which encodes
/// absolute positional information with a rotation matrix and naturally incorporates
/// explicit relative position dependency in the self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
#[derive(Module, Debug)]
pub struct RotaryEncoding<B: Backend> {
    /// Frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
    freq_complex: Tensor<B, 3>,
}

#[allow(clippy::single_range_in_vec_init)]
impl<B: Backend> RotaryEncoding<B> {
    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
    ///
    /// Arguments:
    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
    ///   for (batch_size, seq_len, hidden_dim) or (batch_size, num_heads, seq_len, hidden_dim)
    ///   respectively.
    ///
    /// Returns:
    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
    ///
    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        assert!(
            D >= 2,
            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
        );

        let device = x.device();
        let input_shape = x.shape();

        // Extract the sequence length and embedding dimension; other dimensions are kept generic
        // to allow both 3D and 4D tensors, i.e. batch_size or (batch_size, num_heads)
        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);

        // Create a dummy tensor with signed ones based on the 2D rotation matrix
        // [[cos, -sin], [sin, cos]]
        let sign_tensor =
            Tensor::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);

        // Rotate input using the frequency tensor. Slice the frequencies up to the input sequence length
        let out: Tensor<B, 4> = x
            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
            .matmul(sign_tensor.unsqueeze())
            .reshape([dummy_dim_size, seq_len, d_model, 2])
            * self.freq_complex.clone().slice([0..seq_len]).unsqueeze();

        // Sum the real and imaginary components to get the output tensor and reshape to the original shape
        out.sum_dim(D - 1).reshape(input_shape)
    }
}
```
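Reading the forward pass (again, not part of the diff): each consecutive pair $(x_{2i}, x_{2i+1})$ at sequence position $m$ is rotated by the angle $m\,\theta_i$,

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}.
$$

The matmul with `sign_tensor` maps each pair $(a, b)$ to $(a, -b, b, a)$; multiplying element-wise by the duplicated $(\cos, \sin)$ entries of `freq_complex` and summing over the last dimension then yields $(a\cos m\theta_i - b\sin m\theta_i,\ a\sin m\theta_i + b\cos m\theta_i)$, which is exactly the rotation above.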
The unit tests added with the module:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = Tensor::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        );

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = input.unsqueeze::<4>();

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .squeeze(0)
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
        let input = Tensor::zeros([1, 2, 2, 4], &device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
            ],
            &device,
        );

        output
            .squeeze(0)
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    #[should_panic]
    fn test_valid_input_hidden_dim() {
        // Hidden dimension must be even to be able to split into real and imaginary components
        // for rotation
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }
}
```
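As a hand check of the expected values in the first test (my arithmetic, not part of the commit): with `d_model = 4` the pair frequencies are $\theta_0 = 1$ and $\theta_1 = 10000^{-2/4} = 0.01$. Position $m = 0$ receives the identity rotation, so the rows `[1, 2, 3, 4]` and `[9, 10, 11, 12]` pass through unchanged, while position $m = 1$ gives, for the first batch entry,

$$
5\cos 1 - 6\sin 1 \approx -2.3473,\quad
5\sin 1 + 6\cos 1 \approx 7.4492,\quad
7\cos 0.01 - 8\sin 0.01 \approx 6.9197,\quad
7\sin 0.01 + 8\cos 0.01 \approx 8.0696,
$$

matching the second row of `expected_output`.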