Really unique identifier for metal device ids. (#1932)

* Really unique identifier for metal device ids.

* Same device.
This commit is contained in:
Laurent Mazare 2024-03-25 11:48:16 +01:00 committed by GitHub
parent e7f8e72588
commit cd254074f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 9 deletions

View File

@ -10,6 +10,19 @@ use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
/// Unique identifier for metal devices.
///
/// Every id is drawn from a single process-wide atomic counter, so two ids obtained from
/// separate calls to `DeviceId::new` compare unequal (barring wrap-around of the counter).
/// Copies of the same id compare equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    /// Hands out the next value of the global counter; the first id issued is `1`.
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic::{AtomicUsize, Ordering};

        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

        // `fetch_add` returns the value held *before* the increment, which becomes this id.
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        DeviceId(raw)
    }
}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
@ -64,6 +77,10 @@ type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier. The registryID is not sufficient, as it identifies the GPU rather than
/// the device itself.
id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
device: metal::Device,
@ -108,7 +125,7 @@ pub struct MetalDevice {
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.device.registry_id())
write!(f, "MetalDevice({:?})", self.id)
}
}
@ -121,8 +138,8 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn id(&self) -> NSUInteger {
self.registry_id()
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
@ -1117,8 +1134,8 @@ impl BackendStorage for MetalStorage {
padding: params.padding,
output_padding: params.output_padding,
c_out: params.c_out,
out_h: out_h,
out_w: out_w,
out_h,
out_w,
b_size: params.b_size,
input_dims: l.dims(),
input_stride: l.stride(),
@ -1867,6 +1884,7 @@ impl BackendDevice for MetalDevice {
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
@ -1885,7 +1903,7 @@ impl BackendDevice for MetalDevice {
}
fn same_device(&self, rhs: &Self) -> bool {
self.device.registry_id() == rhs.device.registry_id()
self.id == rhs.id
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {

View File

@ -44,9 +44,19 @@ impl Storage {
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
let lhs_device = self.device();
let rhs_device = rhs.device();
let lhs = lhs_device.location();
let rhs = rhs_device.location();
let same_device = if self.device().is_metal() {
// On metal, we require the device to be exactly the same rather than
// having the same location. In cuda this is not necessary as all CudaDevice on the
// same GPU will use the same cuda stream.
lhs_device.same_device(&rhs_device)
} else {
lhs == rhs
};
if !same_device {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
} else {
Ok(())