Metal commands refactoring (#2489)

* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
This commit is contained in:
Laurent Mazare 2024-09-21 13:18:42 +02:00 committed by GitHub
parent 5fc4f17727
commit af2104078f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 145 additions and 108 deletions

View File

@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard}; use std::sync::{Arc, Mutex, RwLock};
use super::MetalError; use super::MetalError;
@ -22,7 +22,73 @@ impl DeviceId {
} }
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>; type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>; pub(crate) struct Commands {
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
command_buffer: CommandBuffer,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: usize,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
}
impl Commands {
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
Ok(Self {
command_queue,
command_buffer,
command_buffer_index: 0,
compute_per_buffer,
})
}
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
let mut command_buffer = self.command_buffer.to_owned();
let mut flushed = false;
if self.command_buffer_index > self.compute_per_buffer {
self.command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
self.command_buffer = command_buffer.clone();
self.command_buffer_index = 0;
flushed = true;
}
self.command_buffer_index += 1;
Ok((flushed, command_buffer))
}
pub fn wait_until_completed(&mut self) -> Result<()> {
match self.command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct MetalDevice { pub struct MetalDevice {
@ -33,27 +99,8 @@ pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc> /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device, pub(crate) device: metal::Device,
/// Single command queue for the entire device. pub(crate) commands: Arc<RwLock<Commands>>,
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct. /// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over. /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@ -67,7 +114,11 @@ pub struct MetalDevice {
/// ///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1). /// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers, pub(crate) buffers: Arc<RwLock<BufferMap>>,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Seed for random number generation. /// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>, pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones. /// Whether to use the MLX matmul kernels instead of the MFA ones.
@ -101,44 +152,31 @@ impl MetalDevice {
&self.device &self.device
} }
pub fn command_queue(&self) -> &CommandQueue { fn drop_unused_buffers(&self) -> Result<()> {
&self.command_queue let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
} }
pub fn command_buffer(&self) -> Result<CommandBuffer> { pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?; let mut commands = self.commands.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned(); let (flushed, command_buffer) = commands.command_buffer()?;
let mut index = self if flushed {
.command_buffer_index self.drop_unused_buffers()?
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
} }
*index += 1;
Ok(command_buffer) Ok(command_buffer)
} }
pub fn wait_until_completed(&self) -> Result<()> { pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?; let mut commands = self.commands.write().map_err(MetalError::from)?;
match command_buffer.status() { commands.wait_until_completed()
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
} }
pub fn kernels(&self) -> &Kernels { pub fn kernels(&self) -> &Kernels {
@ -186,6 +224,7 @@ impl MetalDevice {
MTLResourceOptions::StorageModeManaged, MTLResourceOptions::StorageModeManaged,
); );
let mut buffers = self.buffers.write().map_err(MetalError::from)?; let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged)) .entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]); .or_insert(vec![]);
@ -216,40 +255,6 @@ impl MetalDevice {
Ok(buffer) Ok(buffer)
} }
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm /// The critical allocator algorithm
fn allocate_buffer( fn allocate_buffer(
&self, &self,
@ -258,7 +263,7 @@ impl MetalDevice {
_name: &str, _name: &str,
) -> Result<Arc<Buffer>> { ) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?; let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) { if let Some(b) = find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count // Cloning also ensures we increment the strong count
return Ok(b.clone()); return Ok(b.clone());
} }
@ -297,3 +302,23 @@ impl MetalDevice {
fn buf_size(size: NSUInteger) -> NSUInteger { fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger size.saturating_sub(1).next_power_of_two() as NSUInteger
} }
fn find_available_buffer(
size: NSUInteger,
option: MTLResourceOptions,
buffers: &BufferMap,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

View File

@ -1864,33 +1864,22 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> { fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() { let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false, Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true, Ok(_) => true,
}; };
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data( let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void, [299792458].as_ptr() as *const c_void,
4, 4,
MTLResourceOptions::StorageModeManaged, MTLResourceOptions::StorageModeManaged,
))); )));
let commands = device::Commands::new(command_queue)?;
Ok(Self { Ok(Self {
id: DeviceId::new(), id: DeviceId::new(),
device, device,
command_queue, commands: Arc::new(RwLock::new(commands)),
command_buffer, buffers: Arc::new(RwLock::new(HashMap::new())),
command_buffer_index,
compute_per_buffer,
buffers,
kernels, kernels,
seed, seed,
use_mlx_mm, use_mlx_mm,

View File

@ -168,17 +168,22 @@ pub trait EncoderProvider {
fn encoder(&self) -> Self::Encoder<'_>; fn encoder(&self) -> Self::Encoder<'_>;
} }
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef); pub struct WrappedEncoder<'a> {
inner: &'a ComputeCommandEncoderRef,
end_encoding_on_drop: bool,
}
impl<'a> Drop for WrappedEncoder<'a> { impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) { fn drop(&mut self) {
self.0.end_encoding() if self.end_encoding_on_drop {
self.inner.end_encoding()
}
} }
} }
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> { impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
self.0 self.inner
} }
} }
@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer {
where where
Self: 'a; Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> { fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder()) WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
} }
} }
@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef {
where where
Self: 'a; Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> { fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder(self.new_compute_command_encoder()) WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
}
}
impl EncoderProvider for &ComputeCommandEncoderRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self,
end_encoding_on_drop: false,
}
} }
} }