Support the Accelerate BLAS on macOS. (#325)
* Add the accelerate feature.
* FFI tweaks.
parent 0b175fcbbd
commit b278834267
@@ -24,6 +24,7 @@ categories = ["science"]
 license = "MIT/Apache-2.0"
 
 [workspace.dependencies]
+accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
@@ -10,6 +10,7 @@ license.workspace = true
 readme = "README.md"
 
 [dependencies]
+accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
 cudarc = { workspace = true, optional = true }
@@ -32,3 +33,4 @@ anyhow = { workspace = true }
 default = []
 cuda = ["dep:cudarc", "dep:candle-kernels"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
+accelerate = ["dep:libc", "dep:accelerate-src"]
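The new feature mirrors the existing mkl wiring: it pulls in libc for the raw C types and accelerate-src, which arranges for the Accelerate framework to be linked. On macOS the backend is then opted into at build time, e.g. with cargo build --features accelerate.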
@@ -1,6 +1,9 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use anyhow::Result;
 use candle_core::{Device, Tensor};
 
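The explicit extern crate accelerate_src; looks redundant, but it is the usual pattern for *-src crates: nothing else references a symbol from the crate, so without it the crate could be pruned from the final link and the framework would never be pulled in. The same trick is already used for intel_mkl_src above.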
@@ -0,0 +1,111 @@
+#![allow(dead_code)]
+use libc::{c_char, c_double, c_float, c_int};
+
+mod ffi {
+    use super::*;
+    extern "C" {
+        // It would be nice to be able to switch to the NEWLAPACK version of the function but this
+        // seems to trigger some link error. Available function names can be seen here:
+        // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
+        #[link_name = "sgemm_"]
+        pub fn sgemm_ffi(
+            transa: *const c_char,
+            transb: *const c_char,
+            m: *const c_int,
+            n: *const c_int,
+            k: *const c_int,
+            alpha: *const c_float,
+            a: *const c_float,
+            lda: *const c_int,
+            b: *const c_float,
+            ldb: *const c_int,
+            beta: *const c_float,
+            c: *mut c_float,
+            ldc: *const c_int,
+        );
+        #[link_name = "dgemm_"]
+        pub fn dgemm_ffi(
+            transa: *const c_char,
+            transb: *const c_char,
+            m: *const c_int,
+            n: *const c_int,
+            k: *const c_int,
+            alpha: *const c_double,
+            a: *const c_double,
+            lda: *const c_int,
+            b: *const c_double,
+            ldb: *const c_int,
+            beta: *const c_double,
+            c: *mut c_double,
+            ldc: *const c_int,
+        );
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn sgemm(
+    transa: u8,
+    transb: u8,
+    m: i32,
+    n: i32,
+    k: i32,
+    alpha: f32,
+    a: &[f32],
+    lda: i32,
+    b: &[f32],
+    ldb: i32,
+    beta: f32,
+    c: &mut [f32],
+    ldc: i32,
+) {
+    ffi::sgemm_ffi(
+        &(transa as c_char),
+        &(transb as c_char),
+        &m,
+        &n,
+        &k,
+        &alpha,
+        a.as_ptr(),
+        &lda,
+        b.as_ptr(),
+        &ldb,
+        &beta,
+        c.as_mut_ptr(),
+        &ldc,
+    )
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn dgemm(
+    transa: u8,
+    transb: u8,
+    m: i32,
+    n: i32,
+    k: i32,
+    alpha: f64,
+    a: &[f64],
+    lda: i32,
+    b: &[f64],
+    ldb: i32,
+    beta: f64,
+    c: &mut [f64],
+    ldc: i32,
+) {
+    ffi::dgemm_ffi(
+        &(transa as c_char),
+        &(transb as c_char),
+        &m,
+        &n,
+        &k,
+        &alpha,
+        a.as_ptr(),
+        &lda,
+        b.as_ptr(),
+        &ldb,
+        &beta,
+        c.as_mut_ptr(),
+        &ldc,
+    )
+}
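The wrappers take slices rather than raw pointers but remain unsafe: the caller must guarantee that m, n, k and the leading dimensions are consistent with the slice lengths. As a quick illustration (hypothetical, not part of the commit), a test inside this module could exercise sgemm by multiplying a row-major matrix by the identity, using the same column-major operand swap that the CPU backend below relies on:

    #[cfg(all(test, feature = "accelerate"))]
    mod tests {
        use super::sgemm;

        #[test]
        fn sgemm_identity() {
            // Row-major C = A * I computed as column-major C^T = I^T * A^T:
            // the right-hand operand goes in as BLAS `a` and m/n are swapped.
            let a = vec![1.0f32, 2.0, 3.0, 4.0]; // 2x2, row-major
            let id = vec![1.0f32, 0.0, 0.0, 1.0]; // 2x2 identity
            let mut c = vec![0.0f32; 4];
            unsafe { sgemm(b'N', b'N', 2, 2, 2, 1., &id, 2, &a, 2, 0., &mut c, 2) };
            assert_eq!(c, [1.0, 2.0, 3.0, 4.0]);
        }
    }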
@@ -974,7 +974,7 @@ impl MatMul {
 impl Map2 for MatMul {
     const OP: &'static str = "mat_mul";
 
-    #[cfg(not(feature = "mkl"))]
+    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
     fn f<T: 'static + WithDType + num_traits::Num + Copy>(
         &self,
         lhs: &[T],
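The tightened cfg keeps the implementations mutually exclusive: the generic fallback compiles only when neither mkl nor accelerate is enabled, so exactly one definition of MatMul::f survives in any given build.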
@@ -1053,6 +1053,109 @@ impl Map2 for MatMul {
         Ok(dst)
     }
 
+    #[cfg(feature = "accelerate")]
+    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let (b, m, n, k) = self.0;
+        let lhs = &lhs[lhs_l.start_offset()..];
+        let rhs = &rhs[rhs_l.start_offset()..];
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
+        let c_skip: usize = m * n;
+
+        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+            (n as i32, b'N')
+        } else if rhs_m1 == k && rhs_m2 == 1 {
+            (k as i32, b'T')
+        } else {
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+        };
+        // The b tensor has dims batching, m, k (lhs)
+        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+            (k as i32, b'N')
+        } else if lhs_m1 == m && lhs_m2 == 1 {
+            (m as i32, b'T')
+        } else {
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+        };
+
+        let mut dst = vec![T::zero(); b * m * n];
+        match T::DTYPE {
+            DType::F16 => {
+                crate::bail!("the accelerate backend does not support f16 matmul")
+            }
+            DType::F32 => {
+                for step in 0..b {
+                    let lhs_p = &lhs[step * a_skip..];
+                    let rhs_p = &rhs[step * b_skip..];
+                    let dst_p = &mut dst[step * c_skip..];
+                    unsafe {
+                        let a = rhs_p.as_ptr() as *const f32;
+                        let b = lhs_p.as_ptr() as *const f32;
+                        let c = dst_p.as_mut_ptr() as *mut f32;
+                        let a = std::slice::from_raw_parts(a, a_skip);
+                        let b = std::slice::from_raw_parts(b, b_skip);
+                        let c = std::slice::from_raw_parts_mut(c, c_skip);
+                        crate::accelerate::sgemm(
+                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+                        )
+                    }
+                }
+            }
+            DType::F64 => {
+                for step in 0..b {
+                    let lhs_p = &lhs[step * a_skip..];
+                    let rhs_p = &rhs[step * b_skip..];
+                    let dst_p = &mut dst[step * c_skip..];
+                    unsafe {
+                        let a = rhs_p.as_ptr() as *const f64;
+                        let b = lhs_p.as_ptr() as *const f64;
+                        let c = dst_p.as_mut_ptr() as *mut f64;
+                        let a = std::slice::from_raw_parts(a, a_skip);
+                        let b = std::slice::from_raw_parts(b, b_skip);
+                        let c = std::slice::from_raw_parts_mut(c, c_skip);
+                        crate::accelerate::dgemm(
+                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+                        )
+                    }
+                }
+            }
+            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+        }
+        Ok(dst)
+    }
+
     #[cfg(feature = "mkl")]
     fn f<T: 'static + WithDType + num_traits::Num + Copy>(
         &self,
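The apparent swap of the operands is the standard row-major/column-major trick: BLAS gemm routines are column-major, while candle tensors are row-major. Since a row-major matrix read as column-major is its transpose, computing C = lhs . rhs is the same as the column-major product C^T = rhs^T . lhs^T. Passing rhs as the BLAS a operand, lhs as b, swapping m and n, and using ldc = n therefore writes the row-major result straight into dst, with one gemm call per batch element.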
@@ -33,6 +33,8 @@
 //!
 //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
 
+#[cfg(feature = "accelerate")]
+mod accelerate;
 pub mod backend;
 pub mod backprop;
 mod conv;
@@ -11,16 +11,14 @@ pub fn get_num_threads() -> usize {
     }
 }
 
+pub fn has_accelerate() -> bool {
+    cfg!(feature = "accelerate")
+}
+
 pub fn has_mkl() -> bool {
-    #[cfg(feature = "mkl")]
-    return true;
-    #[cfg(not(feature = "mkl"))]
-    return false;
+    cfg!(feature = "mkl")
 }
 
 pub fn cuda_is_available() -> bool {
-    #[cfg(feature = "cuda")]
-    return true;
-    #[cfg(not(feature = "cuda"))]
-    return false;
+    cfg!(feature = "cuda")
 }
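Besides adding has_accelerate, this rewrites the existing helpers with the cfg! macro, which gives the same compile-time answer as the old paired returns without the duplicated branches. A minimal sketch of their use (assuming the helpers are exposed as candle_core::utils::*, as the surrounding file suggests):

    fn main() {
        // Report which optional backends this build was compiled with.
        println!("accelerate: {}", candle_core::utils::has_accelerate());
        println!("mkl: {}", candle_core::utils::has_mkl());
        println!("cuda: {}", candle_core::utils::cuda_is_available());
    }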
@@ -10,6 +10,7 @@ license.workspace = true
 readme = "README.md"
 
 [dependencies]
+accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
 candle-nn = { path = "../candle-nn", version = "0.1.0" }
@@ -41,6 +42,7 @@ anyhow = { workspace = true }
 
 [features]
 default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@@ -9,6 +9,9 @@
 // In order to convert the llama weights to a .npz file, run:
 // python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
@@ -1,5 +1,8 @@
 // https://github.com/karpathy/llama2.c
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
@@ -10,6 +10,7 @@ license.workspace = true
 readme = "README.md"
 
 [dependencies]
+accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
@@ -20,5 +21,6 @@ anyhow = { workspace = true }
 
 [features]
 default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
@@ -10,6 +10,7 @@ license.workspace = true
 readme = "README.md"
 
 [dependencies]
+accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
 hf-hub = { workspace = true }
 candle-nn = { path = "../candle-nn", version = "0.1.0" }
@@ -20,5 +21,6 @@ wav = { workspace = true }
 
 [features]
 default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
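Each downstream crate forwards its accelerate feature to the crates it depends on, so enabling it once at the top level, for example building an example with --features accelerate on macOS, routes every f32/f64 matmul in the stack through the Accelerate BLAS.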