Really unique identifier for metal device ids. (#1932)
* Really unique identifier for metal device ids. * Same device.
This commit is contained in:
parent
e7f8e72588
commit
cd254074f3
|
@ -10,6 +10,19 @@ use std::ffi::c_void;
|
|||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
||||
|
||||
/// Unique identifier for metal devices.
///
/// The wrapped value is opaque: it is only meant to be compared, hashed or printed
/// through the derived traits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    /// Returns a fresh identifier taken from a process-wide atomic counter.
    ///
    /// The counter starts at 1 and is incremented by one on every call, so each call
    /// yields a value different from the previous ones (ignoring wrap-around of the
    /// `usize` counter).
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        // `fetch_add` returns the value held before the increment; the increment itself
        // is atomic, and no other memory is published through this counter.
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}
|
||||
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
|
@ -64,6 +77,10 @@ type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
|
||||
/// the device itself.
|
||||
id: DeviceId,
|
||||
|
||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||
device: metal::Device,
|
||||
|
||||
|
@ -108,7 +125,7 @@ pub struct MetalDevice {
|
|||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
||||
write!(f, "MetalDevice({:?})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -121,8 +138,8 @@ impl std::ops::Deref for MetalDevice {
|
|||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn id(&self) -> NSUInteger {
|
||||
self.registry_id()
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn metal_device(&self) -> &metal::Device {
|
||||
|
@ -1117,8 +1134,8 @@ impl BackendStorage for MetalStorage {
|
|||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
c_out: params.c_out,
|
||||
out_h: out_h,
|
||||
out_w: out_w,
|
||||
out_h,
|
||||
out_w,
|
||||
b_size: params.b_size,
|
||||
input_dims: l.dims(),
|
||||
input_stride: l.stride(),
|
||||
|
@ -1867,6 +1884,7 @@ impl BackendDevice for MetalDevice {
|
|||
MTLResourceOptions::StorageModeManaged,
|
||||
)));
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
|
@ -1885,7 +1903,7 @@ impl BackendDevice for MetalDevice {
|
|||
}
|
||||
|
||||
fn same_device(&self, rhs: &Self) -> bool {
|
||||
self.device.registry_id() == rhs.device.registry_id()
|
||||
self.id == rhs.id
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
|
|
|
@ -44,9 +44,19 @@ impl Storage {
|
|||
}
|
||||
|
||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
let lhs = self.device().location();
|
||||
let rhs = rhs.device().location();
|
||||
if lhs != rhs {
|
||||
let lhs_device = self.device();
|
||||
let rhs_device = rhs.device();
|
||||
let lhs = lhs_device.location();
|
||||
let rhs = rhs_device.location();
|
||||
let same_device = if self.device().is_metal() {
|
||||
// On metal, we require the device to be exactly the same rather than
|
||||
// having the same location. In cuda this is not necessary as all CudaDevice on the
|
||||
// same GPU will use the same cuda stream.
|
||||
lhs_device.same_device(&rhs_device)
|
||||
} else {
|
||||
lhs == rhs
|
||||
};
|
||||
if !same_device {
|
||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
||||
} else {
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in New Issue