Metal commands refactoring (#2489)
* Split out the commands part of the metal device.
* Make most fields private.
* Move the allocator back.
* Rework the encoder provider type.
parent 5fc4f17727, commit af2104078f
```diff
@@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
+use std::sync::{Arc, Mutex, RwLock};
 
 use super::MetalError;
 
```
```diff
@@ -22,7 +22,73 @@ impl DeviceId {
 }
 
 type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
+pub(crate) struct Commands {
+    /// Single command queue for the entire device.
+    command_queue: CommandQueue,
+    /// One command buffer at a time.
+    /// The scheduler works by allowing multiple
+    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+    /// prevents overlapping of CPU and GPU commands (because the command buffer needs to be
+    /// committed to start to work).
+    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
+    /// by their START time, but there is no guarantee that command buffer 1 will finish before
+    /// command buffer 2 starts (or there are Metal bugs there).
+    command_buffer: CommandBuffer,
+    /// Keeps track of the current amount of compute command encoders on the current
+    /// command buffer.
+    command_buffer_index: usize,
+    /// The maximum amount of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc).
+    compute_per_buffer: usize,
+}
+
+impl Commands {
+    pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 50,
+        };
+        Ok(Self {
+            command_queue,
+            command_buffer,
+            command_buffer_index: 0,
+            compute_per_buffer,
+        })
+    }
+
+    pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
+        let mut command_buffer = self.command_buffer.to_owned();
+        let mut flushed = false;
+        if self.command_buffer_index > self.compute_per_buffer {
+            self.command_buffer.commit();
+            command_buffer = self.command_queue.new_command_buffer().to_owned();
+            self.command_buffer = command_buffer.clone();
+            self.command_buffer_index = 0;
+            flushed = true;
+        }
+        self.command_buffer_index += 1;
+        Ok((flushed, command_buffer))
+    }
+
+    pub fn wait_until_completed(&mut self) -> Result<()> {
+        match self.command_buffer.status() {
+            metal::MTLCommandBufferStatus::Committed
+            | metal::MTLCommandBufferStatus::Scheduled
+            | metal::MTLCommandBufferStatus::Completed => {
+                panic!("Already committed");
+            }
+            _ => {}
+        }
+        self.command_buffer.commit();
+        self.command_buffer.wait_until_completed();
+        self.command_buffer = self.command_queue.new_command_buffer().to_owned();
+
+        Ok(())
+    }
+}
+
 #[derive(Clone)]
 pub struct MetalDevice {
```
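The heart of the new type is the rotation in `Commands::command_buffer`: encoders accumulate on the current command buffer until `compute_per_buffer` is exceeded, then the buffer is committed and swapped for a fresh one, and the `flushed` flag reports that a rotation happened. A dependency-free sketch of that counter logic, with a plain `Buf` standing in for `metal::CommandBuffer` (all names here are illustrative, not candle API):

```rust
// Stand-in for metal::CommandBuffer, just for this sketch.
#[derive(Clone, Debug)]
struct Buf(u32);

struct Commands {
    current: Buf,
    next_id: u32,
    index: usize,              // encoders issued on `current`
    compute_per_buffer: usize, // flush threshold
}

impl Commands {
    // Mirrors Commands::command_buffer: hand out the current buffer,
    // rotating to a fresh one (and reporting flushed = true) once the
    // threshold is exceeded.
    fn command_buffer(&mut self) -> (bool, Buf) {
        let mut flushed = false;
        if self.index > self.compute_per_buffer {
            // The real code commits the old buffer and asks the queue
            // for a new one here.
            self.next_id += 1;
            self.current = Buf(self.next_id);
            self.index = 0;
            flushed = true;
        }
        self.index += 1;
        (flushed, self.current.clone())
    }
}

fn main() {
    let mut commands = Commands { current: Buf(0), next_id: 0, index: 0, compute_per_buffer: 2 };
    for step in 0..7 {
        let (flushed, buf) = commands.command_buffer();
        println!("step {step}: {buf:?} flushed={flushed}");
    }
}
```

With a threshold of 2 the sketch rotates on every third request, the same cadence at which `MetalDevice` gets told to sweep its buffer cache.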
```diff
@@ -33,27 +99,8 @@ pub struct MetalDevice {
     /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
     pub(crate) device: metal::Device,
 
-    /// Single command queue for the entire device.
-    pub(crate) command_queue: CommandQueue,
-    /// One command buffer at a time.
-    /// The scheduler works by allowing multiple
-    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
-    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
-    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
-    /// to start to work).
-    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
-    /// for their START time, but there's no guarantee that command buffer1 will finish before
-    /// command buffer2 starts (or there are metal bugs there)
-    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
-    /// Keeps track of the current amount of compute command encoders on the current
-    /// command buffer
-    /// Arc, RwLock because of the interior mutability.
-    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
-    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
-    pub(crate) compute_per_buffer: usize,
-    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`]
-    pub(crate) kernels: Arc<Kernels>,
+    pub(crate) commands: Arc<RwLock<Commands>>,
+
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
     /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
```
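One plausible motivation for collapsing `command_queue`, `command_buffer`, and `command_buffer_index` into the single `commands: Arc<RwLock<Commands>>` field: the old layout needed two independent lock acquisitions to swap the buffer and reset the index, while the new one covers both under a single write guard. A minimal sketch of the new shape, with `String` standing in for the Metal command buffer:

```rust
use std::sync::{Arc, RwLock};

// The fields that used to be separate Arc<RwLock<_>> members now change
// together under one guard.
struct Commands {
    command_buffer: String, // stand-in for metal::CommandBuffer
    command_buffer_index: usize,
}

struct MetalDevice {
    commands: Arc<RwLock<Commands>>,
}

fn main() {
    let device = MetalDevice {
        commands: Arc::new(RwLock::new(Commands {
            command_buffer: "cb-0".to_string(),
            command_buffer_index: 0,
        })),
    };
    // One write guard now covers the buffer swap *and* the index reset,
    // where the old layout took two separate locks.
    let mut guard = device.commands.write().unwrap();
    guard.command_buffer = "cb-1".to_string();
    guard.command_buffer_index = 0;
    drop(guard);
}
```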
```diff
@@ -67,7 +114,11 @@ pub struct MetalDevice {
     ///
     /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
     /// (strong_count = 1).
-    pub(crate) buffers: AllocatedBuffers,
+    pub(crate) buffers: Arc<RwLock<BufferMap>>,
+
+    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+    /// Heavily used by [`candle_metal_kernels`]
+    pub(crate) kernels: Arc<Kernels>,
     /// Seed for random number generation.
     pub(crate) seed: Arc<Mutex<Buffer>>,
     /// Whether to use the MLX matmul kernels instead of the MFA ones.
```
```diff
@@ -101,44 +152,31 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn command_queue(&self) -> &CommandQueue {
-        &self.command_queue
+    fn drop_unused_buffers(&self) -> Result<()> {
+        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(*s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
+        Ok(())
     }
 
     pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
-        let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self
-            .command_buffer_index
-            .write()
-            .map_err(MetalError::from)?;
-        if *index > self.compute_per_buffer {
-            command_buffer.commit();
-            command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffer_lock = command_buffer.clone();
-            *index = 0;
-
-            self.drop_unused_buffers()?;
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let (flushed, command_buffer) = commands.command_buffer()?;
+        if flushed {
+            self.drop_unused_buffers()?
         }
-        *index += 1;
         Ok(command_buffer)
     }
 
     pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
-        match command_buffer.status() {
-            metal::MTLCommandBufferStatus::Committed
-            | metal::MTLCommandBufferStatus::Scheduled
-            | metal::MTLCommandBufferStatus::Completed => {
-                panic!("Already committed");
-            }
-            _ => {}
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-        Ok(())
+        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        commands.wait_until_completed()
     }
 
     pub fn kernels(&self) -> &Kernels {
```
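`drop_unused_buffers`, now run whenever `Commands` reports a flush, hinges on `Arc::strong_count`: the allocator holds exactly one strong reference to each cached buffer, so a count of 1 means no tensor still uses it and the buffer can be dropped. A self-contained illustration of that invariant, with `Vec<u8>` standing in for `metal::Buffer`:

```rust
use std::sync::Arc;

fn main() {
    // The allocator's size bucket holds one strong reference per buffer.
    let mut bucket: Vec<Arc<Vec<u8>>> = vec![Arc::new(vec![0u8; 64]), Arc::new(vec![0u8; 64])];

    // Handing a buffer out (as MetalDevice does when a kernel runs) bumps
    // its strong count to 2.
    let in_use = Arc::clone(&bucket[0]);
    assert_eq!(Arc::strong_count(&bucket[0]), 2);
    assert_eq!(Arc::strong_count(&bucket[1]), 1);

    // The sweep keeps only buffers still referenced elsewhere, exactly
    // like the `filter(|s| Arc::strong_count(*s) > 1)` in the diff.
    bucket.retain(|b| Arc::strong_count(b) > 1);
    assert_eq!(bucket.len(), 1);

    drop(in_use); // the last user goes away; the buffer is freed with it
}
```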
```diff
@@ -186,6 +224,7 @@ impl MetalDevice {
             MTLResourceOptions::StorageModeManaged,
         );
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
+
         let subbuffers = buffers
             .entry((size, MTLResourceOptions::StorageModeManaged))
             .or_insert(vec![]);
```
```diff
@@ -216,40 +255,6 @@ impl MetalDevice {
         Ok(buffer)
     }
 
-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        best_buffer.cloned()
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
```
```diff
@@ -258,7 +263,7 @@ impl MetalDevice {
         _name: &str,
     ) -> Result<Arc<Buffer>> {
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
+        if let Some(b) = find_available_buffer(size, option, &buffers) {
             // Cloning also ensures we increment the strong count
             return Ok(b.clone());
         }
```
```diff
@@ -297,3 +302,23 @@ impl MetalDevice {
 fn buf_size(size: NSUInteger) -> NSUInteger {
     size.saturating_sub(1).next_power_of_two() as NSUInteger
 }
+
+fn find_available_buffer(
+    size: NSUInteger,
+    option: MTLResourceOptions,
+    buffers: &BufferMap,
+) -> Option<Arc<Buffer>> {
+    let mut best_buffer: Option<&Arc<Buffer>> = None;
+    let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
+    for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
+        if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
+            for sub in subbuffers {
+                if Arc::strong_count(sub) == 1 {
+                    best_buffer = Some(sub);
+                    best_buffer_size = *buffer_size;
+                }
+            }
+        }
+    }
+    best_buffer.cloned()
+}
```
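Moved out of `impl MetalDevice`, `find_available_buffer` also sheds the `RwLockWriteGuard` from its signature and takes a plain `&BufferMap`. The search is a best-fit scan: among all buckets at least as large as the request, pick the smallest one holding a free buffer. A simplified runnable version keyed by size alone (the real key also carries `MTLResourceOptions`):

```rust
use std::collections::HashMap;
use std::sync::Arc;

type Buffer = Vec<u8>; // stand-in for metal::Buffer
type BufferMap = HashMap<u64, Vec<Arc<Buffer>>>;

// Best-fit: among all buckets at least `size` bytes big, keep the smallest
// one containing a buffer whose only owner is the allocator.
fn find_available_buffer(size: u64, buffers: &BufferMap) -> Option<Arc<Buffer>> {
    let mut best_buffer: Option<&Arc<Buffer>> = None;
    let mut best_buffer_size = u64::MAX;
    for (&bucket_size, subbuffers) in buffers.iter() {
        if bucket_size >= size && bucket_size < best_buffer_size {
            for sub in subbuffers {
                if Arc::strong_count(sub) == 1 {
                    best_buffer = Some(sub);
                    best_buffer_size = bucket_size;
                }
            }
        }
    }
    best_buffer.cloned() // cloning bumps the strong count, marking it in use
}

fn main() {
    let mut buffers: BufferMap = HashMap::new();
    buffers.insert(64, vec![Arc::new(vec![0; 64])]);
    buffers.insert(256, vec![Arc::new(vec![0; 256])]);
    // Asking for 100 bytes skips the 64-byte bucket and reuses the 256-byte one.
    assert_eq!(find_available_buffer(100, &buffers).unwrap().len(), 256);
    assert!(find_available_buffer(512, &buffers).is_none());
}
```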
```diff
@@ -1864,33 +1864,22 @@ impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
         let command_queue = device.new_command_queue();
-        let command_buffer = command_queue.new_command_buffer().to_owned();
-        command_buffer.enqueue();
-        let command_buffer = Arc::new(RwLock::new(command_buffer));
-        let command_buffer_index = Arc::new(RwLock::new(0));
         let kernels = Arc::new(Kernels::new());
-        let buffers = Arc::new(RwLock::new(HashMap::new()));
         let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
             Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
             Ok(_) => true,
         };
-        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
-            Ok(val) => val.parse()?,
-            _ => 50,
-        };
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
             4,
             MTLResourceOptions::StorageModeManaged,
         )));
+        let commands = device::Commands::new(command_queue)?;
         Ok(Self {
             id: DeviceId::new(),
             device,
-            command_queue,
-            command_buffer,
-            command_buffer_index,
-            compute_per_buffer,
-            buffers,
+            commands: Arc::new(RwLock::new(commands)),
+            buffers: Arc::new(RwLock::new(HashMap::new())),
             kernels,
             seed,
             use_mlx_mm,
```
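`BackendDevice::new` now hands the queue to `device::Commands::new` and keeps only the env-flag handling. The `CANDLE_USE_MLX_MM` match is easy to exercise in isolation; a small sketch of the same parsing (the free function is ours, for illustration):

```rust
// Mirrors the match in BackendDevice::new: explicit "false"/"0" spellings
// (or an unset variable) select the MFA kernels; anything else opts into MLX.
fn use_mlx_mm(var: Result<&str, std::env::VarError>) -> bool {
    match var {
        Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
        Ok(_) => true,
    }
}

fn main() {
    assert!(!use_mlx_mm(Err(std::env::VarError::NotPresent)));
    assert!(!use_mlx_mm(Ok("0")));
    assert!(use_mlx_mm(Ok("1")));
    assert!(use_mlx_mm(Ok("true")));
    // In candle itself the input comes from
    // std::env::var("CANDLE_USE_MLX_MM").as_deref().
}
```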
```diff
@@ -168,17 +168,22 @@ pub trait EncoderProvider {
     fn encoder(&self) -> Self::Encoder<'_>;
 }
 
-pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+pub struct WrappedEncoder<'a> {
+    inner: &'a ComputeCommandEncoderRef,
+    end_encoding_on_drop: bool,
+}
 
 impl<'a> Drop for WrappedEncoder<'a> {
     fn drop(&mut self) {
-        self.0.end_encoding()
+        if self.end_encoding_on_drop {
+            self.inner.end_encoding()
+        }
     }
 }
 
 impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
     fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
-        self.0
+        self.inner
     }
 }
```
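The tuple struct becomes a named struct so the wrapper can record whether it owns the encoding session: encoders the provider created itself get `end_encoding()` on drop, while a borrowed encoder is left for its owner to finish (which is what the new `impl EncoderProvider for &ComputeCommandEncoderRef` below relies on). A stub-based sketch of the pattern, with `Encoder` standing in for the metal type:

```rust
struct Encoder;

impl Encoder {
    fn end_encoding(&self) {
        println!("end_encoding called");
    }
}

struct WrappedEncoder<'a> {
    inner: &'a Encoder,
    end_encoding_on_drop: bool,
}

impl Drop for WrappedEncoder<'_> {
    fn drop(&mut self) {
        // Matches the new Drop impl: finalize only if we started the encoding.
        if self.end_encoding_on_drop {
            self.inner.end_encoding()
        }
    }
}

fn main() {
    let encoder = Encoder;
    {
        // Provider created the encoder (CommandBuffer case): end it on drop.
        let _owned = WrappedEncoder { inner: &encoder, end_encoding_on_drop: true };
    } // prints "end_encoding called"
    {
        // Caller passed in an existing encoder (&ComputeCommandEncoderRef
        // case): the wrapper must not end it.
        let _borrowed = WrappedEncoder { inner: &encoder, end_encoding_on_drop: false };
    } // prints nothing
}
```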
```diff
@@ -187,7 +192,10 @@ impl EncoderProvider for &metal::CommandBuffer {
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
     }
 }
```
```diff
@@ -196,6 +204,21 @@ impl EncoderProvider for &metal::CommandBufferRef {
     where
         Self: 'a;
     fn encoder(&self) -> Self::Encoder<'_> {
-        WrappedEncoder(self.new_compute_command_encoder())
+        WrappedEncoder {
+            inner: self.new_compute_command_encoder(),
+            end_encoding_on_drop: true,
+        }
+    }
+}
+
+impl EncoderProvider for &ComputeCommandEncoderRef {
+    type Encoder<'a> = WrappedEncoder<'a>
+    where
+        Self: 'a;
+    fn encoder(&self) -> Self::Encoder<'_> {
+        WrappedEncoder {
+            inner: self,
+            end_encoding_on_drop: false,
+        }
     }
 }
```