mirror of https://github.com/tracel-ai/burn.git
Add RoPE `init_with_frequency_scaling` (#2194)
* Add RoPE init_with_frequency_scaling * Fix clippy
This commit is contained in:
parent
17de832c6e
commit
4999421f6c
|
@ -31,6 +31,35 @@ impl RotaryEncodingConfig {
|
||||||
/// Panics if the size of input embedding dimension is not even.
|
/// Panics if the size of input embedding dimension is not even.
|
||||||
/// Panics if the theta parameter is not positive.
|
/// Panics if the theta parameter is not positive.
|
||||||
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
|
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
|
||||||
|
self.initialize(None, device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
|
||||||
|
/// This is useful to apply different RoPE extensions.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the size of input embedding dimension is not even.
|
||||||
|
/// Panics if the theta parameter is not positive.
|
||||||
|
pub fn init_with_frequency_scaling<B: Backend>(
|
||||||
|
&self,
|
||||||
|
scaling: fn(Tensor<B, 1>) -> Tensor<B, 1>,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> RotaryEncoding<B> {
|
||||||
|
self.initialize(Some(scaling), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the size of input embedding dimension is not even.
|
||||||
|
/// Panics if the theta parameter is not positive.
|
||||||
|
fn initialize<B: Backend>(
|
||||||
|
&self,
|
||||||
|
scaling: Option<fn(Tensor<B, 1>) -> Tensor<B, 1>>,
|
||||||
|
device: &B::Device,
|
||||||
|
) -> RotaryEncoding<B> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
self.d_model % 2,
|
self.d_model % 2,
|
||||||
0,
|
0,
|
||||||
|
@ -42,7 +71,7 @@ impl RotaryEncodingConfig {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Calculate the rotation frequencies for positional embeddings based on the formula
|
// Calculate the rotation frequencies for positional embeddings based on the formula
|
||||||
// `theta_i = 1 / (10000 ^ (2i / d_model)) for i in [0..d_model/2]`
|
// `theta_i = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
|
||||||
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
|
let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
|
||||||
.float()
|
.float()
|
||||||
.div_scalar(self.d_model as f32);
|
.div_scalar(self.d_model as f32);
|
||||||
|
@ -50,7 +79,11 @@ impl RotaryEncodingConfig {
|
||||||
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
|
// Calculate (theta ^ (2i / d_model)) by using the log base property `exp(log(theta) * (2i / d_model))`
|
||||||
// This is done since burn doesn't support exponentiation of scalar to tensor
|
// This is done since burn doesn't support exponentiation of scalar to tensor
|
||||||
let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
|
let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
|
||||||
let theta_i = theta_i.powf_scalar(-1.0);
|
let mut theta_i = theta_i.powf_scalar(-1.0);
|
||||||
|
|
||||||
|
if let Some(scaling) = scaling {
|
||||||
|
theta_i = scaling(theta_i)
|
||||||
|
}
|
||||||
|
|
||||||
// Generate frequency values for positional embeddings
|
// Generate frequency values for positional embeddings
|
||||||
let frequencies: Tensor<B, 2> =
|
let frequencies: Tensor<B, 2> =
|
||||||
|
@ -265,6 +298,109 @@ mod tests {
|
||||||
let _output = pe.forward(input);
|
let _output = pe.forward(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks the unscaled `freq_complex` table against reference values.
#[test]
fn test_rotary_encoding_frequencies() {
    let device = Default::default();
    // Presumably (max_sequence_length = 2, d_model = 8): the reference table below has
    // two position groups with four frequencies each.
    let rope = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

    // Reference pairs, one row per frequency and one group per position. Each pair is
    // duplicated along a new axis and flattened to the [2, 8, 2] layout compared below.
    let reference = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
            ],
            [
                [5.4030e-01, 8.4147e-01],
                [9.9500e-01, 9.9833e-02],
                [9.9995e-01, 9.9998e-03],
                [9.9999e-01, 9.9999e-04],
            ],
        ],
        &device,
    )
    .unsqueeze_dim::<4>(2)
    .repeat_dim(2, 2)
    .reshape([2, 8, 2]);

    let actual = rope.freq_complex.to_data();
    let expected = reference.to_data();
    actual.assert_approx_eq(&expected, 4);
}
|
||||||
|
|
||||||
|
/// Piecewise frequency scaling used as the `scaling` hook in the tests.
///
/// Each frequency is treated according to its wavelength (`2 * PI / freq`):
/// - below `old_context_len / high_freq_factor`: left unchanged;
/// - above `old_context_len / low_freq_factor`: divided by the scale factor;
/// - otherwise: a smooth blend of the two.
fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
    // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
    let scale = 8.;
    let low_factor = 1.;
    let high_factor = 4.;
    let orig_context_len = 8192.;

    let long_wavelen_limit = orig_context_len / low_factor;
    let short_wavelen_limit = orig_context_len / high_factor;

    let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

    // freq / scale_factor, used both in the blend and for the fully scaled band.
    let scaled_down = freqs.clone().div_scalar(scale);

    // smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    let smooth = wavelen
        .clone()
        .recip()
        .mul_scalar(orig_context_len)
        .sub_scalar(low_factor)
        .div_scalar(high_factor - low_factor);

    // (1 - smooth) * freq / scale_factor + smooth * freq
    let scaled_share = smooth
        .clone()
        .neg()
        .add_scalar(1.)
        .mul(scaled_down.clone());
    let unscaled_share = smooth.mul(freqs.clone());
    let blended = scaled_share.add(unscaled_share);

    // Masks for the three wavelength bands.
    let at_or_above_short = wavelen.clone().greater_equal_elem(short_wavelen_limit);
    let above_long = wavelen.clone().greater_elem(long_wavelen_limit);
    let below_short = wavelen.lower_elem(short_wavelen_limit);

    // Later masks take precedence over earlier ones, so the order matters.
    freqs
        .clone()
        // if wavelen >= high_freq_wavelen
        .mask_where(at_or_above_short, blended)
        // if wavelen > low_freq_wavelen
        .mask_where(above_long, scaled_down)
        // if wavelen < high_freq_wavelen
        .mask_where(below_short, freqs)
}
|
||||||
|
|
||||||
|
/// Checks `freq_complex` when the frequencies are passed through
/// `apply_freq_scaling_by_parts` during initialization.
#[test]
fn test_rotary_encoding_with_frequency_scaling() {
    let device = Default::default();
    let config = RotaryEncodingConfig::new(2, 8);
    let rope = config.init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

    // Reference pairs, one row per frequency and one group per position. Each pair is
    // duplicated along a new axis and flattened to the [2, 8, 2] layout compared below.
    let reference = Tensor::<TestBackend, 3>::from_floats(
        [
            [
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
                [1.0000, 0.0000],
            ],
            [
                [5.4030e-01, 8.4148e-01],
                [9.9500e-01, 9.9833e-02],
                [9.9995e-01, 9.9998e-03],
                [1.0000, 2.1361e-04],
            ],
        ],
        &device,
    )
    .unsqueeze_dim::<4>(2)
    .repeat_dim(2, 2)
    .reshape([2, 8, 2]);

    let actual = rope.freq_complex.to_data();
    let expected = reference.to_data();
    actual.assert_approx_eq(&expected, 4);
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn display() {
|
fn display() {
|
||||||
let config = RotaryEncodingConfig::new(10, 4);
|
let config = RotaryEncodingConfig::new(10, 4);
|
||||||
|
|
Loading…
Reference in New Issue