diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index abff2d1b..d902b9d5 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -1,3 +1,6 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs
index 1c8b0136..d6d182e8 100644
--- a/candle-core/examples/cuda_sum_benchmark.rs
+++ b/candle-core/examples/cuda_sum_benchmark.rs
@@ -1,6 +1,9 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use std::str::FromStr;
 
 use anyhow::Result;
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
index 8b0df5c1..3c92dc97 100644
--- a/candle-core/src/accelerate.rs
+++ b/candle-core/src/accelerate.rs
@@ -1,5 +1,5 @@
 #![allow(dead_code)]
-use libc::{c_char, c_double, c_float, c_int};
+use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
 
 mod ffi {
     use super::*;
@@ -39,6 +39,90 @@ mod ffi {
             c: *mut c_double,
             ldc: *const c_int,
         );
+
+        pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
+
+        pub fn vDSP_vaddD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vadd(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vsubD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vsub(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vmulD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vmul(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vdivD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vdiv(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
     }
 }
 
@@ -109,3 +193,158 @@ pub unsafe fn dgemm(
         &ldc,
     )
 }
+
+#[inline]
+pub fn vs_exp(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_exp(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
<> {y_len}") + } + unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sqrt(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_sqrt(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sin(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_sin(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} +#[inline] +pub fn vs_cos(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_cos(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} +#[inline] +pub fn vs_ln(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_ln(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sqr(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) +} + +#[inline] +pub fn vd_sqr(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) +} + +macro_rules! binary_op { + ($fn_name:ident, $ty:ty, $accelerate_name:ident) => { + #[inline] + pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) { + let a_len = a.len(); + let b_len = b.len(); + let y_len = y.len(); + if a_len != y_len || b_len != y_len { + panic!( + "{} a,b,y len mismatch {a_len} {b_len} {y_len}", + stringify!($fn_name) + ); + } + unsafe { + // Weird quirk of accelerate, the rhs comes before the lhs. 
+                ffi::$accelerate_name(
+                    b.as_ptr(),
+                    1,
+                    a.as_ptr(),
+                    1,
+                    y.as_mut_ptr(),
+                    1,
+                    a_len as u64,
+                )
+            }
+        }
+    };
+}
+binary_op!(vs_add, f32, vDSP_vadd);
+binary_op!(vd_add, f64, vDSP_vaddD);
+binary_op!(vs_sub, f32, vDSP_vsub);
+binary_op!(vd_sub, f64, vDSP_vsubD);
+binary_op!(vs_mul, f32, vDSP_vmul);
+binary_op!(vd_mul, f64, vDSP_vmulD);
+binary_op!(vs_div, f32, vDSP_vdiv);
+binary_op!(vd_div, f64, vDSP_vdivD);
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index f99d8adc..2b57f7f7 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -338,6 +338,21 @@ macro_rules! bin_op {
             fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                 crate::mkl::$f64_vec(xs1, xs2, ys)
             }
+
+            #[cfg(feature = "accelerate")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
+                crate::accelerate::$f32_vec(xs1, xs2, ys)
+            }
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
+                crate::accelerate::$f64_vec(xs1, xs2, ys)
+            }
         }
     };
 }
@@ -424,6 +439,21 @@ macro_rules! unary_op {
             fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                 crate::mkl::$f64_vec(xs, ys)
             }
+
+            #[cfg(feature = "accelerate")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+                crate::accelerate::$f32_vec(xs, ys)
+            }
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+                crate::accelerate::$f64_vec(xs, ys)
+            }
         }
     };
 }
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index f5d7cd14..6657c918 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -85,8 +85,14 @@ fn unary_grad(device: &Device) -> Result<()> {
     let y = (x.log()? + 1.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
-    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
-    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.0986, 1.0, 2.3863, -0.8971]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.3333, 1.0, 0.25, 6.6667]
+    );
     let y = x.exp()?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
@@ -141,7 +147,7 @@ fn unary_grad(device: &Device) -> Result<()> {
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
-    assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
+    assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]);
     let y = x.neg()?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
@@ -155,7 +161,10 @@ fn unary_grad(device: &Device) -> Result<()> {
     let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
-    assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [0.3333, 1.0, 0.25, 6.6667]
+    );
     assert_eq!(
         grad_x.to_vec1::<f32>()?,
         [-0.11111111, -1.0, -0.0625, -44.444443],
diff --git a/candle-core/tests/test_utils.rs b/candle-core/tests/test_utils.rs
index 0c7ec1b6..327e88c6 100644
--- a/candle-core/tests/test_utils.rs
+++ b/candle-core/tests/test_utils.rs
@@ -1,5 +1,8 @@
 #![allow(dead_code)]
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use candle_core::{Result, Tensor};
 
 #[macro_export]
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index a01191a5..c37d9a96 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -1,5 +1,8 @@
 // TODO: Add an offline mode.
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 7e614c9c..bc12692d 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,9 +1,11 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
-// - kv-cache support?
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d3ebe02e..2fa04fb0 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,4 +1,4 @@
-use candle::{Device, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
 
@@ -105,12 +105,16 @@ struct MultiHeadAttention {
     out: Linear,
     n_head: usize,
     span: tracing::Span,
+    softmax_span: tracing::Span,
+    matmul_span: tracing::Span,
     kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
         let query = linear(n_state, n_state, vb.pp("q_proj"))?;
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -122,6 +126,8 @@ impl MultiHeadAttention {
             out,
             n_head,
             span,
+            softmax_span,
+            matmul_span,
             kv_cache: None,
         })
     }
@@ -178,13 +184,24 @@ impl MultiHeadAttention {
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
         let v = self.reshape_head(v)?.contiguous()?;
-        let mut qk = q.matmul(&k)?;
+        let mut qk = {
+            let _enter = self.matmul_span.enter();
+            q.matmul(&k)?
+        };
         if let Some(mask) = mask {
-            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = softmax(&qk, candle::D::Minus1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
+        let w = {
+            let _enter = self.softmax_span.enter();
+            softmax(&qk, candle::D::Minus1)?
+        };
+        let wv = {
+            let _enter = self.matmul_span.enter();
+            w.matmul(&v)?
+        }
+        .transpose(1, 2)?
+        .flatten_from(2)?;
         Ok(wv)
     }
 }
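
For reference, the vectorized paths added above should be picked up transparently by the existing CPU kernels: with the "accelerate" feature enabled on macOS, contiguous f32/f64 unary and binary ops dispatch to the vvexpf/vDSP_* routines through the F32_VEC/F64_VEC branches of the bin_op!/unary_op! macros. A minimal sketch of what that looks like from the public tensor API (assuming a small example inside candle-core built with --features accelerate; the snippet below is illustrative only):

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let a = Tensor::new(&[1f32, 2., 3., 4.], &device)?;
    let b = Tensor::new(&[10f32, 20., 30., 40.], &device)?;

    // Element-wise add on contiguous f32 data: with the accelerate feature this
    // routes through vDSP_vadd via the f32_vec path added to the bin_op! macro.
    let sum = (&a + &b)?;

    // Unary ops such as exp go through vvexpf via the unary_op! macro.
    let e = a.exp()?;

    println!("{:?}", sum.to_vec1::<f32>()?);
    println!("{:?}", e.to_vec1::<f32>()?);
    Ok(())
}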