Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time

This commit is contained in:
Guillaume Lagrange 2024-09-20 11:05:55 -04:00
parent aa79e36a8d
commit 1d2b68ecef
4 changed files with 99 additions and 21 deletions

View File

@ -5,7 +5,7 @@ use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
use candle_core::DeviceLocation;
use candle_core::{backend::BackendDevice, DeviceLocation};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
@ -16,7 +16,7 @@ use crate::{
///
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
#[derive(Clone, Copy, Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
F: FloatCandleElement,
@ -27,29 +27,87 @@ where
}
/// The device type for the candle backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
/// The device struct when using the `candle` backend.
///
/// Note that you need to provide the device index when using Cuda.
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
/// ```no_run
/// // Create a Cuda device from its index
/// let device = CandleDevice::cuda(0);
/// // Create a Metal device from its index
/// let device = CandleDevice::metal(0);
/// ```
pub enum CandleDevice {
/// CPU device.
Cpu,
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
/// all Cuda devices found on the system.
Cuda(usize),
Cuda(CudaDevice),
/// Metal device with the given index. The index is the index of the Metal device in the list of
/// all Metal devices found on the system.
Metal(usize),
Metal(MetalDevice),
}
impl CandleDevice {
/// Create a Cuda device with the given index.
/// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
pub fn cuda(index: usize) -> Self {
CandleDevice::Cuda(CudaDevice {
device: candle_core::CudaDevice::new(index).unwrap(),
index,
})
}
/// Create a Metal device with the given index.
/// The index is the index of the Metal device in the list of all Metal devices found on the system.
pub fn metal(index: usize) -> Self {
CandleDevice::Metal(MetalDevice {
device: candle_core::MetalDevice::new(index).unwrap(),
index,
})
}
}
#[derive(Clone, Debug)]
/// A Cuda device for the `candle` backend.
pub struct CudaDevice {
pub(crate) device: candle_core::CudaDevice,
/// The index of the Cuda device in the list of all devices on the system.
pub index: usize,
}
impl PartialEq for CudaDevice {
fn eq(&self, other: &Self) -> bool {
self.device.same_device(&other.device) && self.index == other.index
}
}
impl Eq for CudaDevice {}
#[derive(Clone, Debug)]
/// A Metal device for the `candle` backend.
pub struct MetalDevice {
pub(crate) device: candle_core::MetalDevice,
/// The index of the Metal device in the list of all devices on the system.
pub index: usize,
}
impl PartialEq for MetalDevice {
fn eq(&self, other: &Self) -> bool {
self.device.same_device(&other.device) && self.index == other.index
}
}
impl Eq for MetalDevice {}
impl From<CandleDevice> for candle_core::Device {
fn from(device: CandleDevice) -> Self {
match device {
CandleDevice::Cpu => candle_core::Device::Cpu,
CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
}
}
}
@ -58,8 +116,26 @@ impl From<candle_core::Device> for CandleDevice {
fn from(device: candle_core::Device) -> Self {
match device.location() {
DeviceLocation::Cpu => CandleDevice::Cpu,
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
DeviceLocation::Cuda { gpu_id } => {
if let candle_core::Device::Cuda(device) = device {
CandleDevice::Cuda(CudaDevice {
device,
index: gpu_id,
})
} else {
panic!("Expected CUDA device.");
}
}
DeviceLocation::Metal { gpu_id } => {
if let candle_core::Device::Metal(device) = device {
CandleDevice::Metal(MetalDevice {
device,
index: gpu_id,
})
} else {
panic!("Expected Metal device.");
}
}
}
}
}
@ -68,8 +144,8 @@ impl DeviceOps for CandleDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
match self {
CandleDevice::Cpu => DeviceId::new(0, 0),
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32),
CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32),
}
}
}
@ -111,7 +187,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (*device).into();
let device: candle_core::Device = (device.clone()).into();
match device {
candle_core::Device::Cpu => (),

View File

@ -21,7 +21,7 @@ pub fn from_data<E: CandleElement, const D: usize>(
data: TensorData,
device: &CandleDevice,
) -> CandleTensor<E, D> {
CandleTensor::from_data(data, *device)
CandleTensor::from_data(data, device.clone())
}
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> TensorData {
TensorData::new(
@ -34,14 +34,16 @@ pub fn to_device<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
device: &CandleDevice,
) -> CandleTensor<E, D> {
CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap())
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
}
pub fn empty<E: CandleElement, const D: usize>(
shape: Shape<D>,
device: &CandleDevice,
) -> CandleTensor<E, D> {
CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap())
CandleTensor::new(
candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
)
}
pub fn swap_dims<E: CandleElement, const D: usize>(

View File

@ -299,13 +299,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
CandleTensor::new(
candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap(),
candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
)
}
fn int_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
CandleTensor::new(
candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap(),
candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
)
}
@ -401,7 +401,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
device: &Device<Self>,
) -> IntTensor<Self, D> {
let shape = &shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)

View File

@ -18,7 +18,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
data: TensorData,
device: &Device<Self>,
) -> CandleTensor<F, D> {
CandleTensor::from_data(data, *device)
CandleTensor::from_data(data, device.clone())
}
fn float_random<const D: usize>(
@ -27,7 +27,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
device: &Device<Self>,
) -> FloatTensor<Self, D> {
let shape = &shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)