diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 5f9fb23fa..4c9b6eb16 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -114,7 +114,7 @@ macro_rules! bench_on_backend { use burn::backend::candle::CandleDevice; use burn::backend::Candle; - let device = CandleDevice::Cuda(0); + let device = CandleDevice::cuda(0); bench::(&device, feature_name, url, token); } @@ -123,7 +123,7 @@ macro_rules! bench_on_backend { use burn::backend::candle::CandleDevice; use burn::backend::Candle; - let device = CandleDevice::Metal(0); + let device = CandleDevice::metal(0); bench::(&device, feature_name, url, token); } diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index 28d8fbee7..fe42039dd 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -5,7 +5,7 @@ use burn_tensor::{ quantization::{QTensorPrimitive, QuantizationStrategy}, Device, }; -use candle_core::DeviceLocation; +use candle_core::{backend::BackendDevice, DeviceLocation}; use crate::{ element::{CandleElement, FloatCandleElement, IntCandleElement}, @@ -16,7 +16,7 @@ use crate::{ /// /// It is compatible with a wide range of hardware configurations, including CPUs and GPUs /// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU. -#[derive(Clone, Copy, Default, Debug)] +#[derive(Clone, Default, Debug)] pub struct Candle where F: FloatCandleElement, @@ -27,29 +27,89 @@ where } /// The device type for the candle backend. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq)] /// The device struct when using the `candle` backend. /// -/// Note that you need to provide the device index when using Cuda. +/// To create a Cuda or Metal device from the index, use the associated methods to create the variant: +/// ```no_run +/// use burn_candle::CandleDevice; +/// +/// // Create a Cuda device from its index +/// let device = CandleDevice::cuda(0); +/// // Create a Metal device from its index +/// let device = CandleDevice::metal(0); +/// ``` pub enum CandleDevice { /// CPU device. Cpu, /// Cuda device with the given index. The index is the index of the Cuda device in the list of /// all Cuda devices found on the system. - Cuda(usize), + Cuda(CudaDevice), /// Metal device with the given index. The index is the index of the Metal device in the list of /// all Metal devices found on the system. - Metal(usize), + Metal(MetalDevice), } +impl CandleDevice { + /// Create a Cuda device with the given index. + /// The index is the index of the Cuda device in the list of all Cuda devices found on the system. + pub fn cuda(index: usize) -> Self { + CandleDevice::Cuda(CudaDevice { + device: candle_core::CudaDevice::new(index).unwrap(), + index, + }) + } + + /// Create a Metal device with the given index. + /// The index is the index of the Metal device in the list of all Metal devices found on the system. + pub fn metal(index: usize) -> Self { + CandleDevice::Metal(MetalDevice { + device: candle_core::MetalDevice::new(index).unwrap(), + index, + }) + } +} + +#[derive(Clone, Debug)] +/// A Cuda device for the `candle` backend. +pub struct CudaDevice { + pub(crate) device: candle_core::CudaDevice, + /// The index of the Cuda device in the list of all devices on the system. + pub index: usize, +} + +impl PartialEq for CudaDevice { + fn eq(&self, other: &Self) -> bool { + self.device.same_device(&other.device) && self.index == other.index + } +} + +impl Eq for CudaDevice {} + +#[derive(Clone, Debug)] +/// A Metal device for the `candle` backend. +pub struct MetalDevice { + pub(crate) device: candle_core::MetalDevice, + /// The index of the Metal device in the list of all devices on the system. + pub index: usize, +} + +impl PartialEq for MetalDevice { + fn eq(&self, other: &Self) -> bool { + self.device.same_device(&other.device) && self.index == other.index + } +} + +impl Eq for MetalDevice {} + impl From for candle_core::Device { fn from(device: CandleDevice) -> Self { match device { CandleDevice::Cpu => candle_core::Device::Cpu, - CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), - CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(), + CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device), + CandleDevice::Metal(device) => candle_core::Device::Metal(device.device), } } } @@ -58,8 +118,26 @@ impl From for CandleDevice { fn from(device: candle_core::Device) -> Self { match device.location() { DeviceLocation::Cpu => CandleDevice::Cpu, - DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), - DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id), + DeviceLocation::Cuda { gpu_id } => { + if let candle_core::Device::Cuda(device) = device { + CandleDevice::Cuda(CudaDevice { + device, + index: gpu_id, + }) + } else { + panic!("Expected CUDA device."); + } + } + DeviceLocation::Metal { gpu_id } => { + if let candle_core::Device::Metal(device) = device { + CandleDevice::Metal(MetalDevice { + device, + index: gpu_id, + }) + } else { + panic!("Expected Metal device."); + } + } } } } @@ -68,8 +146,8 @@ impl DeviceOps for CandleDevice { fn id(&self) -> burn_tensor::backend::DeviceId { match self { CandleDevice::Cpu => DeviceId::new(0, 0), - CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32), - CandleDevice::Metal(index) => DeviceId::new(2, *index as u32), + CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32), + CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32), } } } @@ -111,7 +189,7 @@ impl Backend for Candle { fn sync(device: &Device, sync_type: SyncType) { match sync_type { SyncType::Wait => { - let device: candle_core::Device = (*device).into(); + let device: candle_core::Device = (device.clone()).into(); match device { candle_core::Device::Cpu => (), diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index ada5a2202..4d2a4e00e 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -15,7 +15,7 @@ pub fn cat(tensors: Vec>, dim: usize) -> Candl } pub fn from_data(data: TensorData, device: &CandleDevice) -> CandleTensor { - CandleTensor::from_data(data, *device) + CandleTensor::from_data(data, device.clone()) } pub fn into_data(tensor: CandleTensor) -> TensorData { TensorData::new( @@ -28,11 +28,13 @@ pub fn to_device( tensor: CandleTensor, device: &CandleDevice, ) -> CandleTensor { - CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap()) + CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap()) } pub fn empty(shape: Shape, device: &CandleDevice) -> CandleTensor { - CandleTensor::new(candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(*device).into()).unwrap()) + CandleTensor::new( + candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(), + ) } pub fn swap_dims( diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index b34875819..5d8df491d 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -230,13 +230,13 @@ impl IntTensorOps for Candle) -> IntTensor { CandleTensor::new( - candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(*device).into()).unwrap(), + candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(), ) } fn int_ones(shape: Shape, device: &Device) -> IntTensor { CandleTensor::new( - candle_core::Tensor::ones(shape.dims, I::DTYPE, &(*device).into()).unwrap(), + candle_core::Tensor::ones(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(), ) } @@ -324,7 +324,7 @@ impl IntTensorOps for Candle, ) -> IntTensor { let shape = shape.dims; - let device = &(*device).into(); + let device = &(device.clone()).into(); match distribution { Distribution::Default => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 255.elem::(), shape, device) diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index 663ca9bf4..22fa97dd7 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -15,7 +15,7 @@ use super::base::{expand, permute, sign}; impl FloatTensorOps for Candle { fn float_from_data(data: TensorData, device: &Device) -> CandleTensor { - CandleTensor::from_data(data, *device) + CandleTensor::from_data(data, device.clone()) } fn float_random( @@ -24,7 +24,7 @@ impl FloatTensorOps for Candle device: &Device, ) -> FloatTensor { let shape = shape.dims; - let device = &(*device).into(); + let device = &(device.clone()).into(); match distribution { Distribution::Default => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 1.elem::(), shape, device)