mirror of https://github.com/tracel-ai/burn.git
Add RoPE `init_with_frequency_scaling` (#2194)
* Add RoPE init_with_frequency_scaling * Fix clippy
This commit is contained in:
parent
17de832c6e
commit
4999421f6c
|
@ -31,6 +31,35 @@ impl RotaryEncodingConfig {
|
|||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
|
||||
self.initialize(None, device)
|
||||
}
|
||||
|
||||
/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
|
||||
/// This is useful to apply different RoPE extensions.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
pub fn init_with_frequency_scaling<B: Backend>(
|
||||
&self,
|
||||
scaling: fn(Tensor<B, 1>) -> Tensor<B, 1>,
|
||||
device: &B::Device,
|
||||
) -> RotaryEncoding<B> {
|
||||
self.initialize(Some(scaling), device)
|
||||
}
|
||||
|
||||
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the size of input embedding dimension is not even.
|
||||
/// Panics if the theta parameter is not positive.
|
||||
fn initialize<B: Backend>(
|
||||
&self,
|
||||
scaling: Option<fn(Tensor<B, 1>) -> Tensor<B, 1>>,
|
||||
device: &B::Device,
|
||||
) -> RotaryEncoding<B> {
|
||||
assert_eq!(
|
||||
self.d_model % 2,
|
||||
0,
|
||||
|
@ -42,7 +71,7 @@ impl RotaryEncodingConfig {
|
|||
);
|
||||
|
||||
// Calculate the rotation frequencies for positional embeddings based on the formula
|
||||
// `theta_i = 1 / (10000 ^ (2i / d_model)) for i in [0..d_model/2]`
|
||||
// `theta_i = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
|
||||
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
|
||||
.float()
|
||||
.div_scalar(self.d_model as f32);
|
||||
|
@ -50,7 +79,11 @@ impl RotaryEncodingConfig {
|
|||
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
|
||||
// This is done since burn doesn't support exponentiation of scalar to tensor
|
||||
let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
|
||||
let theta_i = theta_i.powf_scalar(-1.0);
|
||||
let mut theta_i = theta_i.powf_scalar(-1.0);
|
||||
|
||||
if let Some(scaling) = scaling {
|
||||
theta_i = scaling(theta_i)
|
||||
}
|
||||
|
||||
// Generate frequency values for positional embeddings
|
||||
let frequencies: Tensor<B, 2> =
|
||||
|
@ -265,6 +298,109 @@ mod tests {
|
|||
let _output = pe.forward(input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rotary_encoding_frequencies() {
|
||||
let device = Default::default();
|
||||
let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);
|
||||
|
||||
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
|
||||
[
|
||||
[
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
],
|
||||
[
|
||||
[5.4030e-01, 8.4147e-01],
|
||||
[9.9500e-01, 9.9833e-02],
|
||||
[9.9995e-01, 9.9998e-03],
|
||||
[9.9999e-01, 9.9999e-04],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
)
|
||||
.unsqueeze_dim::<4>(2)
|
||||
.repeat_dim(2, 2)
|
||||
.reshape([2, 8, 2]);
|
||||
|
||||
rotary_encoding
|
||||
.freq_complex
|
||||
.to_data()
|
||||
.assert_approx_eq(&expected_freqs.to_data(), 4);
|
||||
}
|
||||
|
||||
fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
|
||||
// Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
|
||||
let scale_factor = 8.;
|
||||
let low_freq_factor = 1.;
|
||||
let high_freq_factor = 4.;
|
||||
let old_context_len = 8192.;
|
||||
|
||||
let low_freq_wavelen = old_context_len / low_freq_factor;
|
||||
let high_freq_wavelen = old_context_len / high_freq_factor;
|
||||
|
||||
let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
|
||||
|
||||
// if wavelen >= high_freq_wavelen
|
||||
let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
|
||||
let smooth = wavelen
|
||||
.clone()
|
||||
.recip()
|
||||
.mul_scalar(old_context_len)
|
||||
.sub_scalar(low_freq_factor)
|
||||
.div_scalar(high_freq_factor - low_freq_factor);
|
||||
// (1 - smooth) * freq / scale_factor + smooth * freq
|
||||
let new_freqs = smooth
|
||||
.clone()
|
||||
.neg()
|
||||
.add_scalar(1.)
|
||||
.mul(freqs.clone().div_scalar(scale_factor))
|
||||
.add(smooth.clone().mul(freqs.clone()));
|
||||
let new_freqs = freqs.clone().mask_where(cond, new_freqs);
|
||||
|
||||
// if wavelen > low_freq_wavelen
|
||||
let cond = wavelen.clone().greater_elem(low_freq_wavelen);
|
||||
let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
|
||||
|
||||
// if wavelen < high_freq_wavelen
|
||||
let cond = wavelen.lower_elem(high_freq_wavelen);
|
||||
new_freqs.mask_where(cond, freqs)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rotary_encoding_with_frequency_scaling() {
|
||||
let device = Default::default();
|
||||
let rotary_encoding = RotaryEncodingConfig::new(2, 8)
|
||||
.init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);
|
||||
|
||||
let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
|
||||
[
|
||||
[
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
[1.0000, 0.0000],
|
||||
],
|
||||
[
|
||||
[5.4030e-01, 8.4148e-01],
|
||||
[9.9500e-01, 9.9833e-02],
|
||||
[9.9995e-01, 9.9998e-03],
|
||||
[1.0000, 2.1361e-04],
|
||||
],
|
||||
],
|
||||
&device,
|
||||
)
|
||||
.unsqueeze_dim::<4>(2)
|
||||
.repeat_dim(2, 2)
|
||||
.reshape([2, 8, 2]);
|
||||
|
||||
rotary_encoding
|
||||
.freq_complex
|
||||
.to_data()
|
||||
.assert_approx_eq(&expected_freqs.to_data(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = RotaryEncodingConfig::new(10, 4);
|
||||
|
|
Loading…
Reference in New Issue