Contiguous variant of the rope kernel. (#1929)

* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
This commit is contained in:
Laurent Mazare 2024-03-25 09:11:20 +01:00 committed by GitHub
parent 1b98f84a2b
commit e7f8e72588
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 389 additions and 13 deletions

View File

@ -160,6 +160,24 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx > bh * td) return;
uint32_t i_bh = idx / (td / 2);
uint32_t i_td = idx - (td / 2) * i_bh;
uint32_t i_t = i_td / (d / 2);
uint32_t i_d = i_td - (d / 2) * i_t;
uint32_t i1 = i_bh * td + i_t * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
template <typename T>
__device__ void
@ -416,8 +434,8 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define ROPEI_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
@ -426,11 +444,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t td) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
} \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td, \
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPEI_OP(__nv_bfloat16, rope_i_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
@ -438,7 +466,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPEI_OP(__half, rope_i_f16)
ROPE_OP(__half, rope_f16, rope_i_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@ -450,8 +478,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPEI_OP(float, rope_i_f32)
ROPEI_OP(double, rope_i_f64)
ROPE_OP(float, rope_f32, rope_i_f32)
ROPE_OP(double, rope_f64, rope_i_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)

View File

@ -849,6 +849,49 @@ pub fn call_rope_i(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
td: usize,
d: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
cos_offset: usize,
sin: &Buffer,
sin_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
bh,
td,
d,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,

View File

@ -313,8 +313,8 @@ kernel void NAME(
} \
} \
#define ROPEI(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
device const TYPENAME *src, \
@ -332,6 +332,31 @@ kernel void FN_NAME( \
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
if (2 * idx >= bh * td) { \
return; \
} \
size_t i_bh = idx / (td / 2); \
size_t i_td = idx - (td / 2) * i_bh; \
size_t i_t = i_td / (d / 2); \
size_t i_d = i_td - (d / 2) * i_t; \
size_t i1 = i_bh * td + i_t * d + i_d; \
size_t i2 = i1 + d / 2; \
size_t i_cs = i_t * (d / 2) + i_d; \
TYPENAME c = cos[i_cs]; \
TYPENAME s = sin[i_cs]; \
dst[i1] = src[i1] * c - src[i2] * s; \
dst[i2] = src[i1] * s + src[i2] * c; \
}\
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
@ -361,8 +386,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
ROPEI(rope_i_f32, float)
ROPEI(rope_i_f16, half)
ROPEI(rope_f32, rope_i_f32, float)
ROPEI(rope_f16, rope_i_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@ -381,5 +406,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
ROPEI(rope_i_bf16, bfloat)
ROPEI(rope_bf16, rope_i_bf16, bfloat)
#endif

View File

@ -245,3 +245,255 @@ pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
}
/// Contiguous variant of rope embeddings.
#[derive(Debug, Clone)]
struct RotaryEmb;
impl candle::CustomOp3 for RotaryEmb {
fn name(&self) -> &'static str {
"rotary-emb"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
fn inner<T: candle::WithDType + num_traits::Float>(
src: &[T],
l_src: &Layout,
cos: &[T],
l_cos: &Layout,
sin: &[T],
l_sin: &Layout,
) -> Result<(CpuStorage, Shape)> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("input src has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("input cos has to be contiguous"),
Some((o1, o2)) => &cos[o1..o2],
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("input sin has to be contiguous"),
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, h, t, d).into()))
}
use candle::backend::BackendStorage;
use CpuStorage::{BF16, F16, F32, F64};
match (s1, s2, s3) {
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
fn inner<T: DeviceRepr + WithDType>(
src: &CudaSlice<T>,
l_src: &Layout,
cos: &CudaSlice<T>,
l_cos: &Layout,
sin: &CudaSlice<T>,
l_sin: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("src input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("cos input has to be contiguous"),
Some((o1, o2)) => cos.slice(o1..o2),
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("sin input has to be contiguous"),
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&cos,
&sin,
&dst,
(b * h) as u32,
(t * d) as u32,
d as u32,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
use candle::backend::BackendStorage;
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
let dev = s1.device();
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
};
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
src: &candle::MetalStorage,
l_src: &Layout,
cos: &candle::MetalStorage,
l_cos: &Layout,
sin: &candle::MetalStorage,
l_sin: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = src.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
candle::bail!(
"dtype mismatch in rope {:?} {:?} {:?}",
src.dtype(),
cos.dtype(),
sin.dtype()
)
}
let name = match src.dtype() {
candle::DType::F32 => "rope_f32",
candle::DType::F16 => "rope_f16",
candle::DType::BF16 => "rope_bf16",
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
device.metal_device(),
&command_buffer,
kernels,
name,
b * h,
t * d,
d,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
l_cos.start_offset() * cos.dtype().size_in_bytes(),
sin.buffer(),
l_sin.start_offset() * sin.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
Ok((out, l_src.shape().clone()))
}
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
|| seq_len > sin_seq_len
{
candle::bail!(
"inconsistent last dim size in rope {:?} {:?} {:?}",
xs.shape(),
cos.shape(),
sin.shape()
)
}
if !xs.is_contiguous() {
candle::bail!("xs has to be contiguous in rope")
}
if !cos.is_contiguous() {
candle::bail!("cos has to be contiguous in rope")
}
if !sin.is_contiguous() {
candle::bail!("sin has to be contiguous in rope")
}
xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
let cos = cos.narrow(0, 0, seq_len)?;
let sin = sin.narrow(0, 0, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
}

View File

@ -86,7 +86,7 @@ fn softmax_numerical_stability() -> Result<()> {
Ok(())
}
fn rope(device: &Device) -> Result<()> {
fn ropei(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
@ -107,12 +107,40 @@ fn rope(device: &Device) -> Result<()> {
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
} else if device.is_cuda() {
} else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
fn rope(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
} else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);