Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
This commit is contained in:
parent cf7d7fcf2f
commit 1b98f84a2b
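In short: the commit wires one new op through every backend — an interleaved rotary-embedding ("rope-i") kernel for CPU (rayon), CUDA, and Metal — exposed as `candle_nn::rotary_emb::rope_i` and used by the quantized llama example. A minimal usage sketch, with the shapes the new test uses (the random inputs here are illustrative only):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (2, 5, 10, 16);
    // xs: (batch, num_heads, seq_len, head_dim), must be contiguous.
    let xs = Tensor::randn(0f32, 1., (b, h, t, d), &dev)?;
    // cos/sin: one entry per (position, rotated pair), i.e. (seq_len, head_dim / 2).
    let cos = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let ys = candle_nn::rotary_emb::rope_i(&xs, &cos, &sin)?;
    assert_eq!(ys.dims4()?, (b, h, t, d));
    Ok(())
}
```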
@@ -147,6 +147,20 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
    }
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t rope_idx = idx % (td / 2);
    T c = cos[rope_idx];
    T s = sin[rope_idx];

    dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
    dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}

template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
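Each CUDA thread owns one interleaved pair: thread `idx` rotates `(src[2*idx], src[2*idx+1])` by the angle at `rope_idx = idx % (td / 2)`, so a single `(t, d/2)` cos/sin table is reused across all `bh = b*h` head blocks. A scalar Rust model of the thread body above (an illustrative helper, not part of the diff):

```rust
// What one `ropei` thread computes, written out sequentially.
fn ropei_thread(idx: usize, src: &[f32], cos: &[f32], sin: &[f32], dst: &mut [f32], bh: usize, td: usize) {
    if 2 * idx >= bh * td {
        return; // padding thread from the rounded-up launch
    }
    let rope_idx = idx % (td / 2); // the cos/sin row is shared across the bh blocks
    let (c, s) = (cos[rope_idx], sin[rope_idx]);
    dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
    dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}
```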
@@ -402,9 +416,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \

#define ROPEI_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td) { \
    ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
} \

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPEI_OP(__nv_bfloat16, rope_i_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

@@ -412,6 +438,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPEI_OP(__half, rope_i_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

@@ -423,6 +450,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPEI_OP(float, rope_i_f32)
ROPEI_OP(double, rope_i_f64)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
@@ -808,6 +808,47 @@ pub fn call_rms_norm(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_rope_i(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    bh: usize,
    td: usize,
    src: &Buffer,
    src_offset: usize,
    cos: &Buffer,
    cos_offset: usize,
    sin: &Buffer,
    sin_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            bh,
            td,
            (src, src_offset),
            (cos, cos_offset),
            (sin, sin_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
    encoder.use_resource(src, metal::MTLResourceUsage::Read);
    encoder.use_resource(cos, metal::MTLResourceUsage::Read);
    encoder.use_resource(sin, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_affine(
    device: &Device,
@@ -313,6 +313,26 @@ kernel void NAME(
    } \
} \

#define ROPEI(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &bh, \
    constant size_t &td, \
    device const TYPENAME *src, \
    device const TYPENAME *cos, \
    device const TYPENAME *sin, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (2 * tid >= bh * td) { \
        return; \
    } \
    size_t rope_idx = tid % (td / 2); \
    TYPENAME c = cos[rope_idx]; \
    TYPENAME s = sin[rope_idx]; \
    dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
    dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
}\

REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)

@@ -341,6 +361,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
ROPEI(rope_i_f32, float)
ROPEI(rope_i_f16, half)

#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)

@@ -359,4 +381,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
ROPEI(rope_i_bf16, bfloat)
#endif
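The `2 * tid >= bh * td` guard (same as on the CUDA side) exists because the dispatch rounds the thread count up to whole threadgroups; the extra threads must return before touching `dst`. A quick bounds check with the shapes used in the test below:

```rust
fn main() {
    // bh = b*h = 10 and td = t*d = 160, matching the rope test added below.
    let (bh, td) = (10usize, 160);
    let pairs = bh * td / 2;                // 800 threads do real work
    assert!(2 * (pairs - 1) + 1 < bh * td); // the last worker stays in bounds
    assert!(2 * pairs >= bh * td);          // thread 800 hits the guard and returns
}
```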
@@ -25,6 +25,7 @@ candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
rand = { workspace = true }

[features]
default = []
@@ -12,6 +12,7 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sequential;
pub mod var_builder;
pub mod var_map;
@@ -0,0 +1,247 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;

/// Interleaved variant of rotary embeddings.
/// The x0 and x1 values are interleaved on the n_embd (= head_dim) dimension.
/// The resulting y0 and y1 are also interleaved with:
///   y0 = x0*cos - x1*sin
///   y1 = x0*sin + x1*cos
#[derive(Debug, Clone)]
struct RotaryEmbI;

impl candle::CustomOp3 for RotaryEmbI {
    fn name(&self) -> &'static str {
        "rotary-emb-int"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        fn inner<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            l_src: &Layout,
            cos: &[T],
            l_cos: &Layout,
            sin: &[T],
            l_sin: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("input src has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("input cos has to be contiguous"),
                Some((o1, o2)) => &cos[o1..o2],
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("input sin has to be contiguous"),
                Some((o1, o2)) => &sin[o1..o2],
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el_count = b * h * t * d;
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(t * d)
                .zip(dst.par_chunks_mut(t * d))
                .for_each(|(src, dst)| {
                    for i_over_2 in 0..t * d / 2 {
                        let i = 2 * i_over_2;
                        dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
                        dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, (b, h, t, d).into()))
        }

        use candle::backend::BackendStorage;
        use CpuStorage::{BF16, F16, F32, F64};
        match (s1, s2, s3) {
            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        }
    }
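A worked example of that inner loop for a single row with t = 1 and d = 4 (hand-picked values, not from the diff): the pairs (x0, x1) and (x2, x3) are rotated by the angles stored at positions 0 and 1 of the cos/sin row.

```rust
fn main() {
    // One (t*d) row: two interleaved pairs, two angles.
    let src = [1.0f32, 2.0, 3.0, 4.0];
    let (cos, sin) = ([0.8f32, 0.6], [0.6f32, 0.8]);
    let mut dst = [0.0f32; 4];
    for i_over_2 in 0..2 {
        let i = 2 * i_over_2;
        dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
        dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
    }
    // (1, 2) rotated by angle 0: (1*0.8 - 2*0.6, 1*0.6 + 2*0.8) = (-0.4, 2.2)
    // (3, 4) rotated by angle 1: (3*0.6 - 4*0.8, 3*0.8 + 4*0.6) = (-1.4, 4.8)
    for (got, want) in dst.iter().zip([-0.4f32, 2.2, -1.4, 4.8]) {
        assert!((got - want).abs() < 1e-6);
    }
}
```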
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        fn inner<T: DeviceRepr + WithDType>(
            src: &CudaSlice<T>,
            l_src: &Layout,
            cos: &CudaSlice<T>,
            l_cos: &Layout,
            sin: &CudaSlice<T>,
            l_sin: &Layout,
            dev: &CudaDevice,
        ) -> Result<CudaSlice<T>> {
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("src input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("cos input has to be contiguous"),
                Some((o1, o2)) => cos.slice(o1..o2),
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("sin input has to be contiguous"),
                Some((o1, o2)) => sin.slice(o1..o2),
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
            let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
            // SAFETY: ffi.
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(dst)
        }

        use candle::backend::BackendStorage;
        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
        let dev = s1.device();
        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        };
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }
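The launch geometry follows directly from the pair-per-thread scheme: `LaunchConfig::for_num_elems((el / 2) as u32)` asks for one thread per output pair, and cudarc rounds that up to whole blocks (which is why the kernel needs its bounds guard). Worked numbers using the shapes from the test below:

```rust
fn main() {
    let (b, h, t, d) = (2usize, 5, 10, 16);
    let el = b * h * t * d; // 1600 elements in the (b, h, t, d) tensor
    let threads = el / 2;   // 800: each thread writes dst[2*i] and dst[2*i + 1]
    // Assuming a block size of 1024 for illustration; the real value is cudarc's choice.
    let blocks = (threads + 1024 - 1) / 1024; // 1 block; 224 threads masked off by the guard
    assert_eq!((threads, blocks), (800, 1));
}
```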
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        src: &candle::MetalStorage,
        l_src: &Layout,
        cos: &candle::MetalStorage,
        l_cos: &Layout,
        sin: &candle::MetalStorage,
        l_sin: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = src.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
            candle::bail!(
                "dtype mismatch in rope-i {:?} {:?} {:?}",
                src.dtype(),
                cos.dtype(),
                sin.dtype()
            )
        }
        let name = match src.dtype() {
            candle::DType::F32 => "rope_i_f32",
            candle::DType::F16 => "rope_i_f16",
            candle::DType::BF16 => "rope_i_bf16",
            dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
        };
        let (b, h, t, d) = l_src.shape().dims4()?;
        let el = b * h * t * d;
        let output = device.new_buffer(el, src.dtype(), "rope-i")?;
        candle_metal_kernels::call_rope_i(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            b * h,
            t * d,
            src.buffer(),
            l_src.start_offset() * src.dtype().size_in_bytes(),
            cos.buffer(),
            l_cos.start_offset() * cos.dtype().size_in_bytes(),
            sin.buffer(),
            l_sin.start_offset() * sin.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
        Ok((out, l_src.shape().clone()))
    }
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
    if cos_n_embd * 2 != n_embd
        || sin_n_embd * 2 != n_embd
        || seq_len > cos_seq_len
        || seq_len > sin_seq_len
    {
        candle::bail!(
            "inconsistent last dim size in rope {:?} {:?} {:?}",
            xs.shape(),
            cos.shape(),
            sin.shape()
        )
    }
    if !xs.is_contiguous() {
        candle::bail!("xs has to be contiguous in rope")
    }
    if !cos.is_contiguous() {
        candle::bail!("cos has to be contiguous in rope")
    }
    if !sin.is_contiguous() {
        candle::bail!("sin has to be contiguous in rope")
    }
    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
}
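The contiguity checks matter in practice: attention code typically produces its `(b, h, t, d)` view through a transpose, which is not contiguous, so callers insert a `.contiguous()?` first, exactly as the quantized llama example below does. A minimal sketch (the transpose here is illustrative):

```rust
use candle::{Result, Tensor};

fn apply_rope(q: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    // q often arrives as (b, t, h, d); transposing to (b, h, t, d) breaks
    // contiguity, and rope_i would bail without this copy.
    let q = q.transpose(1, 2)?.contiguous()?;
    candle_nn::rotary_emb::rope_i(&q, cos, sin)
}
```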
pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
    let cos = cos
        .narrow(0, 0, seq_len)?
        .reshape((seq_len, n_embd / 2, 1))?;
    let sin = sin
        .narrow(0, 0, seq_len)?
        .reshape((seq_len, n_embd / 2, 1))?;
    let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
    let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
    let x0 = x.narrow(D::Minus1, 0, 1)?;
    let x1 = x.narrow(D::Minus1, 1, 1)?;
    let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
    let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
    let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
    let rope = rope.flatten_from(D::Minus2)?;
    Ok(rope)
}
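The slow path leans on `narrow`/`cat`/`flatten_from` to split the even and odd lanes and merge them back in interleaved order; it serves as the reference the fast kernels are tested against. A toy round-trip check of that reshape (illustrative values):

```rust
use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    // View [0, 1, 2, 3] as two (even, odd) pairs, split the lanes, re-merge.
    let x = Tensor::arange(0f32, 4., &Device::Cpu)?.reshape((1, 1, 1, 2, 2))?;
    let x0 = x.narrow(D::Minus1, 0, 1)?; // even lane: 0, 2
    let x1 = x.narrow(D::Minus1, 1, 1)?; // odd lane:  1, 3
    let y = Tensor::cat(&[x0, x1], D::Minus1)?.flatten_from(D::Minus2)?;
    assert_eq!(y.flatten_all()?.to_vec1::<f32>()?, vec![0., 1., 2., 3.]);
    Ok(())
}
```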
@@ -86,5 +86,33 @@ fn softmax_numerical_stability() -> Result<()> {
    Ok(())
}

fn rope(device: &Device) -> Result<()> {
    use rand::{rngs::StdRng, Rng, SeedableRng};

    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
    let rope1 = candle_nn::rotary_emb::rope_i(&src, &cos, &sin)?;
    let rope2 = candle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?;
    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if device.is_cpu() {
        assert_eq!(sum_diff, 0.);
    } else if device.is_cuda() {
        assert!(sum_diff < 1e-4);
    }
    Ok(())
}

test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use crate::quantized_nn::RmsNorm;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module};

 pub const MAX_SEQ_LEN: usize = 4096;
@@ -154,31 +154,10 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
-        let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
-        let cos = self
-            .cos
-            .narrow(0, index_pos, seq_len)?
-            .reshape((seq_len, n_embd / 2, 1))?;
-        let sin = self
-            .sin
-            .narrow(0, index_pos, seq_len)?
-            .reshape((seq_len, n_embd / 2, 1))?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
-        // This mimics the llama.cpp behavior.
-        // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
-        // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
-        // The resulting y0 and y1 are also interleaved with:
-        // y0 = x0*cos - x1*sin
-        // y1 = x0*sin + x1*cos
-        let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
-        let x0 = x.narrow(D::Minus1, 0, 1)?;
-        let x1 = x.narrow(D::Minus1, 1, 1)?;
-        let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
-        let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
-        let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
-        let rope = rope.flatten_from(D::Minus2)?;
-        Ok(rope)
+        let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
+        let cos = self.cos.narrow(0, index_pos, seq_len)?;
+        let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
     }

     fn forward_attn(
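Not shown in this diff: the `self.cos` / `self.sin` tables consumed above are precomputed once per model, with shape `(MAX_SEQ_LEN, head_dim / 2)` so that `rope_i`'s shape checks pass. A hedged sketch of building such tables with the standard RoPE frequencies (the 10000 base and this helper are assumptions, not part of the commit):

```rust
use candle::{Device, Result, Tensor};

// Hypothetical helper: RoPE angle tables of shape (max_seq_len, head_dim / 2).
fn rope_tables(max_seq_len: usize, head_dim: usize, dev: &Device) -> Result<(Tensor, Tensor)> {
    let theta: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / 10000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::from_vec(theta, (1, head_dim / 2), dev)?;
    let pos = Tensor::arange(0f32, max_seq_len as f32, dev)?.reshape((max_seq_len, 1))?;
    let angles = pos.broadcast_mul(&theta)?; // (max_seq_len, head_dim / 2)
    Ok((angles.cos()?, angles.sin()?))
}
```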