Add RoPE `init_with_frequency_scaling` (#2194)

* Add RoPE init_with_frequency_scaling

* Fix clippy
This commit is contained in:
Guillaume Lagrange 2024-08-23 10:30:23 -04:00 committed by GitHub
parent 17de832c6e
commit 4999421f6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 138 additions and 2 deletions

View File

@ -31,6 +31,35 @@ impl RotaryEncodingConfig {
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
self.initialize(None, device)
}
/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
/// This is useful to apply different RoPE extensions.
///
/// # Panics
///
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
pub fn init_with_frequency_scaling<B: Backend>(
&self,
scaling: fn(Tensor<B, 1>) -> Tensor<B, 1>,
device: &B::Device,
) -> RotaryEncoding<B> {
self.initialize(Some(scaling), device)
}
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
///
/// # Panics
///
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
fn initialize<B: Backend>(
&self,
scaling: Option<fn(Tensor<B, 1>) -> Tensor<B, 1>>,
device: &B::Device,
) -> RotaryEncoding<B> {
assert_eq!(
self.d_model % 2,
0,
@ -42,7 +71,7 @@ impl RotaryEncodingConfig {
);
// Calculate the rotation frequencies for positional embeddings based on the formula
// `theta_i = 1 / (10000 ^ (2i / d_model)) for i in [0..d_model/2]`
// `theta_i = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
.float()
.div_scalar(self.d_model as f32);
@ -50,7 +79,11 @@ impl RotaryEncodingConfig {
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
// This is done since burn doesn't support exponentiation of scalar to tensor
let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
let theta_i = theta_i.powf_scalar(-1.0);
let mut theta_i = theta_i.powf_scalar(-1.0);
if let Some(scaling) = scaling {
theta_i = scaling(theta_i)
}
// Generate frequency values for positional embeddings
let frequencies: Tensor<B, 2> =
@ -265,6 +298,109 @@ mod tests {
let _output = pe.forward(input);
}
#[test]
fn test_rotary_encoding_frequencies() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
],
[
[5.4030e-01, 8.4147e-01],
[9.9500e-01, 9.9833e-02],
[9.9995e-01, 9.9998e-03],
[9.9999e-01, 9.9999e-04],
],
],
&device,
)
.unsqueeze_dim::<4>(2)
.repeat_dim(2, 2)
.reshape([2, 8, 2]);
rotary_encoding
.freq_complex
.to_data()
.assert_approx_eq(&expected_freqs.to_data(), 4);
}
fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
// Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
let scale_factor = 8.;
let low_freq_factor = 1.;
let high_freq_factor = 4.;
let old_context_len = 8192.;
let low_freq_wavelen = old_context_len / low_freq_factor;
let high_freq_wavelen = old_context_len / high_freq_factor;
let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
// if wavelen >= high_freq_wavelen
let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
let smooth = wavelen
.clone()
.recip()
.mul_scalar(old_context_len)
.sub_scalar(low_freq_factor)
.div_scalar(high_freq_factor - low_freq_factor);
// (1 - smooth) * freq / scale_factor + smooth * freq
let new_freqs = smooth
.clone()
.neg()
.add_scalar(1.)
.mul(freqs.clone().div_scalar(scale_factor))
.add(smooth.clone().mul(freqs.clone()));
let new_freqs = freqs.clone().mask_where(cond, new_freqs);
// if wavelen > low_freq_wavelen
let cond = wavelen.clone().greater_elem(low_freq_wavelen);
let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
// if wavelen < high_freq_wavelen
let cond = wavelen.lower_elem(high_freq_wavelen);
new_freqs.mask_where(cond, freqs)
}
#[test]
fn test_rotary_encoding_with_frequency_scaling() {
let device = Default::default();
let rotary_encoding = RotaryEncodingConfig::new(2, 8)
.init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
[
[
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
[1.0000, 0.0000],
],
[
[5.4030e-01, 8.4148e-01],
[9.9500e-01, 9.9833e-02],
[9.9995e-01, 9.9998e-03],
[1.0000, 2.1361e-04],
],
],
&device,
)
.unsqueeze_dim::<4>(2)
.repeat_dim(2, 2)
.reshape([2, 8, 2]);
rotary_encoding
.freq_complex
.to_data()
.assert_approx_eq(&expected_freqs.to_data(), 4);
}
#[test]
fn display() {
let config = RotaryEncodingConfig::new(10, 4);