Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels. * Clippy lints. * Export the gemm_f32 kernel. * Add the f16/bf16 variants. * Add the initial dispatch code. * More plugging of the mlx kernels. * Add a currently broken test. * Tweaks. * Bugfix + get the tests to pass. * Enable the gemm bf16 tests. * Add some randomized tests. * Update candle-metal-kernels/src/lib.rs Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> * More fixes. * More clippy fixes. --------- Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
parent
13b2a8a4a0
commit
5635650d38
|
@ -23,3 +23,4 @@ half = { version = "2.3.1", features = [
|
|||
"rand_distr",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
|
|
|
@ -11,33 +11,35 @@ pub use utils::BufferOffset;
|
|||
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
/// Identifies one of the kernel libraries shipped with this crate.
///
/// Variants are listed alphabetically. Every variant except [`Source::Mfa`]
/// maps to a Metal source string; `Mfa` refers to a precompiled metallib and
/// has no source text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    /// Kernels from `affine.metal`.
    Affine,
    /// Kernels from `binary.metal`.
    Binary,
    /// Kernels from `cast.metal`.
    Cast,
    /// Kernels from `conv.metal`.
    Conv,
    /// The MLX gemm kernels from `mlx_gemm.metal`.
    Gemm,
    /// Kernels from `indexing.metal`.
    Indexing,
    /// The precompiled MetalFlashAttention library.
    Mfa,
    /// Kernels from `quantized.metal`.
    Quantized,
    /// Kernels from `random.metal`.
    Random,
    /// Kernels from `reduce.metal`.
    Reduce,
    /// Kernels from `sort.metal`.
    Sort,
    /// Kernels from `ternary.metal`.
    Ternary,
    /// Kernels from `unary.metal`.
    Unary,
}
|
||||
|
||||
pub mod copy2d {
|
||||
|
@ -191,16 +193,17 @@ impl Kernels {
|
|||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
Source::Binary => BINARY,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Gemm => MLX_GEMM,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Random => RANDOM,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Sort => SORT,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Unary => UNARY,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Element type selector for the MLX gemm kernels.
///
/// `call_mlx_gemm` uses it to pick the kernel name suffix
/// (`bf16_bf16`, `f16_f16` or `f32_f32`).
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
    /// bfloat16 inputs and outputs.
    BF16,
    /// IEEE half-precision inputs and outputs.
    F16,
    /// Single-precision inputs and outputs.
    F32,
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_mlx_gemm(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
#[derive(Debug)]
|
||||
#[repr(C)]
|
||||
struct GemmParams {
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
lda: i32,
|
||||
ldb: i32,
|
||||
ldd: i32,
|
||||
tiles_n: i32,
|
||||
tiles_m: i32,
|
||||
batch_stride_a: isize,
|
||||
batch_stride_b: isize,
|
||||
batch_stride_d: isize,
|
||||
swizzle_log: i32,
|
||||
gemm_k_iterations_aligned: i32,
|
||||
batch_ndim: i32,
|
||||
}
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, false)
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
(m as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, false)
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
(k as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(10, Value::Bool(/* has_batch */ b > 1)),
|
||||
(100, Value::Bool(/* use_out_source */ false)),
|
||||
(110, Value::Bool(/* do_axpby */ false)),
|
||||
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
||||
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
||||
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
||||
(300, Value::Bool(/* do_gather */ false)),
|
||||
]));
|
||||
|
||||
let swizzle_log = 0;
|
||||
let tile = 1 << swizzle_log;
|
||||
let tn = n.div_ceil(bn);
|
||||
let tm = m.div_ceil(bm);
|
||||
let tn = tn * tile;
|
||||
let tm = tm.div_ceil(tile);
|
||||
|
||||
let batch_stride_a = if lhs_stride.len() > 2 {
|
||||
lhs_stride[lhs_stride.len() - 3]
|
||||
} else {
|
||||
m * k
|
||||
};
|
||||
let batch_stride_b = if rhs_stride.len() > 2 {
|
||||
rhs_stride[rhs_stride.len() - 3]
|
||||
} else {
|
||||
n * k
|
||||
};
|
||||
|
||||
let gemm_params = GemmParams {
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldd: n as i32,
|
||||
tiles_n: tn as i32,
|
||||
tiles_m: tm as i32,
|
||||
swizzle_log,
|
||||
batch_stride_a: batch_stride_a as isize,
|
||||
batch_stride_b: batch_stride_b as isize,
|
||||
batch_stride_d: (m * n) as isize,
|
||||
batch_ndim: 1i32,
|
||||
gemm_k_iterations_aligned: (k / bk) as i32,
|
||||
};
|
||||
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
||||
|
||||
// TODO(laurent): generate the name
|
||||
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
||||
let name = match (dtype, a_trans, b_trans) {
|
||||
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||
};
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(3, Some(output), 0);
|
||||
encoder.set_bytes(
|
||||
4,
|
||||
std::mem::size_of::<GemmParams>() as u64,
|
||||
&gemm_params as *const GemmParams as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
6, // batch_shape
|
||||
std::mem::size_of::<i32>() as u64,
|
||||
&(b as i32) as *const i32 as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
7,
|
||||
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
||||
batch_strides.as_ptr() as *const c_void,
|
||||
);
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: tn as u64,
|
||||
height: tm as u64,
|
||||
depth: /* batch_size_out */ b as u64,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32,
|
||||
height: wn,
|
||||
depth: wm,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
|||
|
||||
#[test]
|
||||
fn cast_f32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -360,7 +360,7 @@ fn cast_f32() {
|
|||
|
||||
#[test]
|
||||
fn cast_f16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -391,7 +391,7 @@ fn cast_f16() {
|
|||
|
||||
#[test]
|
||||
fn cast_bf16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -422,7 +422,7 @@ fn cast_bf16() {
|
|||
|
||||
#[test]
|
||||
fn cast_u32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -453,7 +453,7 @@ fn cast_u32() {
|
|||
|
||||
#[test]
|
||||
fn cast_u8() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -484,7 +484,7 @@ fn cast_u8() {
|
|||
|
||||
#[test]
|
||||
fn cast_i64() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
|
@ -911,7 +911,7 @@ fn softmax() {
|
|||
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -922,7 +922,7 @@ fn softmax() {
|
|||
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() {
|
|||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_gemm<T: Clone>(
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: Vec<usize>,
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: Vec<usize>,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
|
@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>(
|
|||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_stride,
|
||||
lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
&rhs_stride,
|
||||
rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
|
@ -1105,10 +1106,10 @@ fn gemm() {
|
|||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
|
@ -1125,10 +1126,10 @@ fn gemm() {
|
|||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
|
@ -1150,10 +1151,10 @@ fn gemm() {
|
|||
"sgemm",
|
||||
(1, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
12 * 4,
|
||||
);
|
||||
assert_eq!(
|
||||
|
@ -1172,10 +1173,10 @@ fn gemm() {
|
|||
"bgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
|
@ -1194,10 +1195,10 @@ fn gemm() {
|
|||
"hgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
|
@ -1206,6 +1207,204 @@ fn gemm() {
|
|||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_mlx_gemm<T: Clone>(
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(rhs) as u64,
|
||||
options,
|
||||
);
|
||||
let length = b * m * n;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
|
||||
use rand::SeedableRng;
|
||||
use rand_distr::Distribution;
|
||||
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
|
||||
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
|
||||
|
||||
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let v1: Vec<f32> = run_mlx_gemm(
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
let v2: Vec<f32> = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
||||
let diff = (a - b).abs();
|
||||
assert_eq!((diff * 1e4).round(), 0.)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compares the MLX and "sgemm" results on a handful of f32 shapes, including
/// sizes that are not multiples of the 32x32x16 tile and one batched case.
#[test]
fn mlx_vs_mfa() {
    // (b, m, n, k)
    let cases = [
        (1, 32, 32, 25),
        (1, 128, 128, 100),
        (1, 256, 256, 256),
        (1, 192, 200, 75),
        (3, 27, 67, 64),
    ];
    for (b, m, n, k) in cases {
        mlx_vs_mfa_one(b, m, n, k, GemmDType::F32);
    }
}
|
||||
|
||||
/// Exercises the MLX gemm kernel on small hand-checked inputs: a single
/// matrix, a batch of two, a byte offset into the rhs buffer, and bf16/f16
/// sanity checks.
#[test]
fn mlx_gemm() {
    let (b, m, n, k) = (1, 2, 4, 3);
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_mlx_gemm(
        GemmDType::F32,
        (b, m, n, k),
        &lhs,
        &[m * k, k, 1],
        0,
        &rhs,
        &[n * k, n, 1],
        0,
    );
    assert_eq!(
        approx(results, 4),
        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
    );

    let (b, m, n, k) = (2, 2, 4, 3);
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let results = run_mlx_gemm(
        GemmDType::F32,
        (b, m, n, k),
        &lhs,
        &[m * k, k, 1],
        0,
        &rhs,
        &[n * k, n, 1],
        0,
    );
    assert_eq!(
        approx(results, 4),
        vec![
            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
            518.0, 548.0, 578.0
        ]
    );

    // OFFSET
    let (b, m, n, k) = (2, 2, 4, 3);
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    // Manually set batch_size=1 and skip the first 12 rhs elements; the offset
    // is expressed in bytes, hence the multiplication by the size of an f32.
    let results = run_mlx_gemm(
        GemmDType::F32,
        (1, m, n, k),
        &lhs,
        &[m * k, k, 1],
        0,
        &rhs,
        &[n * k, n, 1],
        12 * std::mem::size_of::<f32>(),
    );
    assert_eq!(
        approx(results, 4),
        vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
    );

    // bgemm sanity test
    {
        let (b, m, n, k) = (1, 2, 4, 3);
        let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
        let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
        let results = run_mlx_gemm(
            GemmDType::BF16,
            (b, m, n, k),
            &lhs,
            &[m * k, k, 1],
            0,
            &rhs,
            &[n * k, n, 1],
            0,
        );
        assert_eq!(
            approx_bf16(results, 4),
            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
        );
    }

    // hgemm sanity test
    {
        let (b, m, n, k) = (1, 2, 4, 3);
        let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
        let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
        let results = run_mlx_gemm(
            GemmDType::F16,
            (b, m, n, k),
            &lhs,
            &[m * k, k, 1],
            0,
            &rhs,
            &[n * k, n, 1],
            0,
        );
        assert_eq!(
            approx_f16(results, 4),
            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
        );
    }
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
|
@ -1280,7 +1479,7 @@ fn random() {
|
|||
variance.sqrt()
|
||||
}
|
||||
|
||||
let shape = vec![1024, 10];
|
||||
let shape = [1024, 10];
|
||||
|
||||
let length = shape.iter().product::<usize>();
|
||||
let seed = 299792458;
|
||||
|
@ -1636,7 +1835,7 @@ fn max_pool2d_f16() {
|
|||
&strides,
|
||||
"max_pool2d_f16",
|
||||
);
|
||||
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1656,7 +1855,7 @@ fn max_pool2d_f16() {
|
|||
&strides,
|
||||
"max_pool2d_f16",
|
||||
);
|
||||
let expected = vec![5.0, 7.0, 13.0, 15.0]
|
||||
let expected = [5.0, 7.0, 13.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() {
|
|||
&strides,
|
||||
"max_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() {
|
|||
&strides,
|
||||
"max_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![5.0, 7.0, 13.0, 15.0]
|
||||
let expected = [5.0, 7.0, 13.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() {
|
|||
&strides,
|
||||
"avg_pool2d_f16",
|
||||
);
|
||||
let expected = vec![
|
||||
let expected = [
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
|
@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() {
|
|||
&strides,
|
||||
"avg_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![
|
||||
let expected = [
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
|
@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() {
|
|||
|
||||
#[test]
|
||||
fn conv_transpose1d_f16() {
|
||||
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
|
@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() {
|
|||
"conv_transpose1d_f16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
let expected = [1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() {
|
|||
|
||||
#[test]
|
||||
fn conv_transpose1d_bf16() {
|
||||
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
|
@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() {
|
|||
"conv_transpose1d_bf16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
let expected = [1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
|
|
@ -165,7 +165,7 @@ pub trait EncoderProvider {
|
|||
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
|
||||
fn encoder(&self) -> Self::Encoder<'_>;
|
||||
}
|
||||
|
||||
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
|
||||
|
@ -178,7 +178,7 @@ impl<'a> Drop for WrappedEncoder<'a> {
|
|||
|
||||
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
|
||||
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
||||
&self.0
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,7 +186,7 @@ impl EncoderProvider for &metal::CommandBuffer {
|
|||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
}
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
|
|||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue