From 0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 27 Oct 2024 15:20:37 +0100 Subject: [PATCH] UG metal integration. (#2580) --- Cargo.toml | 1 + candle-core/Cargo.toml | 3 +- candle-core/src/custom_op.rs | 48 ++++++++++++++++++++++--- candle-core/src/device.rs | 8 +++++ candle-core/src/metal_backend/device.rs | 22 ++++++++++++ candle-core/tests/custom_op_tests.rs | 16 ++++++--- candle-metal-kernels/src/lib.rs | 2 +- candle-metal-kernels/src/utils.rs | 10 ++---- 8 files changed, 92 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 64e1460e..f27ec933 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" ug = "0.0.2" ug-cuda = "0.0.2" +ug-metal = "0.0.2" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 8ea2b08c..4ffc869f 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -30,6 +30,7 @@ safetensors = { workspace = true } thiserror = { workspace = true } ug = { workspace = true } ug-cuda = { workspace = true, optional = true } +ug-metal = { workspace = true, optional = true } yoke = { workspace = true } zip = { workspace = true } @@ -45,7 +46,7 @@ cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal", "dep:candle-metal-kernels"] +metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"] [[bench]] name = "bench_main" diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 276e3658..c0d97d67 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -380,6 +380,8 @@ pub struct UgIOp1 { name: &'static str, #[cfg(feature = "cuda")] func: 
cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, } impl UgIOp1 { @@ -395,7 +397,13 @@ impl UgIOp1 { let func = device.compile(name, kernel)?; Ok(Self { name, func }) } - #[cfg(not(feature = "cuda"))] + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] { Ok(Self { name }) } @@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 { } fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + crate::bail!("ug ops are only supported on metal/cuda at the moment") } - fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use candle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. 
+ crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) } #[cfg(feature = "cuda")] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 91925b57..18aa61af 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -138,6 +138,14 @@ impl Device { } } + pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> { + match self { + Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"), + Self::Cpu => crate::bail!("expected a metal device, got cpu"), + Self::Metal(d) => Ok(d), + } + } + pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> { Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 29b8995b..46be6ce4 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -144,6 +144,28 @@ impl MetalDevice { self.use_mlx_mm = use_mlx_mm } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result<metal::ComputePipelineState> { + let mut buf = vec![]; + ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?; + let metal_code = String::from_utf8(buf)?; + let lib = self + .device + .new_library_with_source(&metal_code, 
&metal::CompileOptions::new()) + .map_err(MetalError::from)?; + let func = lib + .get_function(func_name, None) + .map_err(MetalError::from)?; + let pl = self + .device + .new_compute_pipeline_state_with_function(&func) + .map_err(MetalError::from)?; + Ok(pl) + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index f2c01aca..3572a4c9 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> { Ok(()) } -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "metal"))] #[allow(clippy::approx_constant)] #[test] fn ug_op() -> Result<()> { @@ -160,15 +160,21 @@ fn ug_op() -> Result<()> { let opts: ug::lower_op::Opts = Default::default(); kernel.lower(&opts.with_global(0, 12))? }; - let device = Device::new_cuda(0)?; + let device = if candle_core::utils::cuda_is_available() { + Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + Device::new_metal(0)? 
+ } else { + candle_core::bail!("metal/cuda is mandatory for this test") + }; let op = candle_core::UgIOp1::new("test", kernel, &device)?; let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; t.inplace_op1(&op)?; assert_eq!( - to_vec1_round(&t, 4)?, + to_vec1_round(&t, 2)?, &[ - 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578, - 8103.0806, 22026.469, 59874.133 + 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47, + 59874.13 ] ); Ok(()) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index be616009..222ae8ad 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; -mod utils; +pub mod utils; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderProvider}; diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index d2cc09f4..0092ecfa 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M } // https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 -pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { +pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { let mut pows0 = 0u64; let mut pows1 = 0u64; let mut pows2 = 0u64; @@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { } } -pub(crate) fn set_param<P: EncoderParam>( - encoder: &ComputeCommandEncoderRef, - position: u64, - data: P, -) { +pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

<P as EncoderParam>::set_param(encoder, position, data) } /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -pub(crate) trait EncoderParam { +pub trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive {