diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 3deb465b..29b8995b 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
+use std::sync::{Arc, Mutex, RwLock};
 
 use super::MetalError;
 
@@ -22,7 +22,73 @@ impl DeviceId {
 }
 
 type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
+pub(crate) struct Commands {
+    /// Single command queue for the entire device.
+    command_queue: CommandQueue,
+    /// One command buffer at a time.
+    /// The scheduler works by allowing multiple
+    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
+    /// to start to work).
+    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
+    /// for their START time, but there's no guarantee that command buffer1 will finish before
+    /// command buffer2 starts (or there are metal bugs there)
+    command_buffer: CommandBuffer,
+    /// Keeps track of the current amount of compute command encoders on the current
+    /// command buffer
+    /// Arc, RwLock because of the interior mutability.
+    command_buffer_index: usize,
+    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
+    compute_per_buffer: usize,
+}
+
+impl Commands {
+    pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 50,
+        };
+        Ok(Self {
+            command_queue,
+            command_buffer,
+            command_buffer_index: 0,
+            compute_per_buffer,
+        })
+    }
+
+    pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
+        let mut command_buffer = self.command_buffer.to_owned();
+        let mut flushed = false;
+        if self.command_buffer_index > self.compute_per_buffer {
+            self.command_buffer.commit();
+            command_buffer = self.command_queue.new_command_buffer().to_owned();
+            self.command_buffer = command_buffer.clone();
+            self.command_buffer_index = 0;
+            flushed = true;
+        }
+        self.command_buffer_index += 1;
+        Ok((flushed, command_buffer))
+    }
+
+    pub fn wait_until_completed(&mut self) -> Result<()> {
+        match self.command_buffer.status() {
+            metal::MTLCommandBufferStatus::Committed
+            | metal::MTLCommandBufferStatus::Scheduled
+            | metal::MTLCommandBufferStatus::Completed => {
+                panic!("Already committed");
+            }
+            _ => {}
+        }
+        self.command_buffer.commit();
+        self.command_buffer.wait_until_completed();
+        self.command_buffer = self.command_queue.new_command_buffer().to_owned();
+
+        Ok(())
+    }
+}
 
 #[derive(Clone)]
 pub struct MetalDevice {
@@ -33,27 +99,8 @@ pub struct MetalDevice {
     /// Raw metal device:
     pub(crate) device: metal::Device,
 
-    /// Single command queue for the entire device.
-    pub(crate) command_queue: CommandQueue,
-    /// One command buffer at a time.
-    /// The scheduler works by allowing multiple
-    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
-    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
-    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
-    /// to start to work).
-    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
-    /// for their START time, but there's no guarantee that command buffer1 will finish before
-    /// command buffer2 starts (or there are metal bugs there)
-    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
-    /// Keeps track of the current amount of compute command encoders on the current
-    /// command buffer
-    /// Arc, RwLock because of the interior mutability.
-    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
-    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
-    pub(crate) compute_per_buffer: usize,
-    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`]
-    pub(crate) kernels: Arc<Kernels>,
+    pub(crate) commands: Arc<RwLock<Commands>>,
+
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
     /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@@ -67,7 +114,11 @@ pub struct MetalDevice {
     ///
     /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
     /// (strong_count = 1).
-    pub(crate) buffers: AllocatedBuffers,
+    pub(crate) buffers: Arc<RwLock<BufferMap>>,
+
+    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+    /// Heavily used by [`candle_metal_kernels`]
+    pub(crate) kernels: Arc<Kernels>,
     /// Seed for random number generation.
     pub(crate) seed: Arc<Mutex<Buffer>>,
     /// Whether to use the MLX matmul kernels instead of the MFA ones.
@@ -101,44 +152,31 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn command_queue(&self) -> &CommandQueue {
-        &self.command_queue
+    fn drop_unused_buffers(&self) -> Result<()> {
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(*s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
+        Ok(())
     }
 
     pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
-        let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self
-            .command_buffer_index
-            .write()
-            .map_err(MetalError::from)?;
-        if *index > self.compute_per_buffer {
-            command_buffer.commit();
-            command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffer_lock = command_buffer.clone();
-            *index = 0;
-
-            self.drop_unused_buffers()?;
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let (flushed, command_buffer) = commands.command_buffer()?;
+        if flushed {
+            self.drop_unused_buffers()?
         }
-        *index += 1;
         Ok(command_buffer)
     }
 
     pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
-        match command_buffer.status() {
-            metal::MTLCommandBufferStatus::Committed
-            | metal::MTLCommandBufferStatus::Scheduled
-            | metal::MTLCommandBufferStatus::Completed => {
-                panic!("Already committed");
-            }
-            _ => {}
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-        Ok(())
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        commands.wait_until_completed()
     }
 
     pub fn kernels(&self) -> &Kernels {
@@ -186,6 +224,7 @@ impl MetalDevice {
             MTLResourceOptions::StorageModeManaged,
         );
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+
         let subbuffers = buffers
             .entry((size, MTLResourceOptions::StorageModeManaged))
             .or_insert(vec![]);
@@ -216,40 +255,6 @@ impl MetalDevice {
         Ok(buffer)
     }
 
-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        best_buffer.cloned()
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
@@ -258,7 +263,7 @@ impl MetalDevice {
         _name: &str,
    ) -> Result<Arc<Buffer>> {
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
+        if let Some(b) = find_available_buffer(size, option, &buffers) {
             // Cloning also ensures we increment the strong count
             return Ok(b.clone());
         }
@@ -297,3 +302,23 @@ impl MetalDevice {
 fn buf_size(size: NSUInteger) -> NSUInteger {
     size.saturating_sub(1).next_power_of_two() as NSUInteger
 }
+
+fn find_available_buffer(
+    size: NSUInteger,
+    option: MTLResourceOptions,
+    buffers: &BufferMap,
+) -> Option<Arc<Buffer>> {
+    let mut best_buffer: Option<&Arc<Buffer>> = None;
+    let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
+    for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
+        if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
+            for sub in subbuffers {
+                if Arc::strong_count(sub) == 1 {
+                    best_buffer = Some(sub);
+                    best_buffer_size = *buffer_size;
+                }
+            }
+        }
+    }
+    best_buffer.cloned()
+}
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 9c980db8..69edd2d1 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1864,33 +1864,22 @@ impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
         let command_queue = device.new_command_queue();
-        let command_buffer = command_queue.new_command_buffer().to_owned();
-        command_buffer.enqueue();
-        let command_buffer = Arc::new(RwLock::new(command_buffer));
-        let command_buffer_index = Arc::new(RwLock::new(0));
         let kernels = Arc::new(Kernels::new());
-        let buffers = Arc::new(RwLock::new(HashMap::new()));
         let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
             Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
             Ok(_) => true,
         };
-        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
-            Ok(val) => val.parse()?,
-            _ => 50,
-        };
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
             4,
             MTLResourceOptions::StorageModeManaged,
         )));
+        let commands = device::Commands::new(command_queue)?;
         Ok(Self {
             id: DeviceId::new(),
             device,
-            command_queue,
-            command_buffer,
-            command_buffer_index,
-            compute_per_buffer,
-            buffers,
+            commands: Arc::new(RwLock::new(commands)),
+            buffers: Arc::new(RwLock::new(HashMap::new())),
             kernels,
             seed,
             use_mlx_mm,
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 2ddd610b..d2cc09f4 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -168,17 +168,22 @@ pub trait EncoderProvider {
     fn encoder(&self) -> Self::Encoder<'_>;
 }
 
-pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+pub struct WrappedEncoder<'a> {
+    inner: &'a ComputeCommandEncoderRef,
+    end_encoding_on_drop: bool,
+}
 
 impl<'a> Drop for WrappedEncoder<'a> {
     fn drop(&mut self) {
-        self.0.end_encoding()
+        if self.end_encoding_on_drop {
+            self.inner.end_encoding()
+        }
     }
 }
 
 impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
     fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
-        self.0
+        self.inner
     }
 }
 
@@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer {
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
     }
 }
 
@@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef {
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
+    }
+}
+
+impl EncoderProvider for &ComputeCommandEncoderRef {
+    type Encoder<'a> = WrappedEncoder<'a>
+    where
+        Self: 'a;
+    fn encoder(&self) -> Self::Encoder<'_> {
+        WrappedEncoder {
+            inner: self,
+            end_encoding_on_drop: false,
+        }
     }
 }
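
Usage note (illustration only, not part of the patch): a minimal sketch of how the new `impl EncoderProvider for &ComputeCommandEncoderRef` can be used so several kernel launches share one encoder. The `record_work` helper, the `example` driver, and the import path are hypothetical names made up for this sketch; only the trait and the three impls above come from the diff. The sketch also assumes the trait's `Encoder` associated type is bounded by `AsRef<ComputeCommandEncoderRef>`, as the existing kernel-call code relies on.

    use candle_metal_kernels::utils::EncoderProvider; // assumed path
    use metal::{CommandBuffer, ComputeCommandEncoderRef};

    // Hypothetical helper: generic code only asks for "something that can hand out
    // a compute command encoder" and does not care who owns or ends it.
    fn record_work(ep: impl EncoderProvider) {
        let wrapped = ep.encoder();
        let encoder: &ComputeCommandEncoderRef = wrapped.as_ref();
        // ... set the pipeline state, bind buffers, dispatch threadgroups ...
        let _ = encoder;
        // When `wrapped` drops here, end_encoding() only runs if the provider
        // created the encoder itself (end_encoding_on_drop == true).
    }

    fn example(command_buffer: &CommandBuffer) {
        // Previous behaviour, still supported: each call opens and ends its own encoder.
        record_work(command_buffer);
        record_work(command_buffer);

        // New impl: reuse a single encoder across several calls and end it once,
        // explicitly, when all the work has been encoded.
        let encoder = command_buffer.new_compute_command_encoder();
        record_work(encoder);
        record_work(encoder);
        encoder.end_encoding();
    }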