Add candle `CudaDevice` and `MetalDevice` to avoid creating a new unique device each time (#2290)

* Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time

* Fix doc example

* Change enum usage
This commit is contained in:
Guillaume Lagrange 2024-09-25 13:08:49 -04:00 committed by GitHub
parent 37d87956e2
commit 112f09ebaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 103 additions and 23 deletions

View File

@ -114,7 +114,7 @@ macro_rules! bench_on_backend {
use burn::backend::candle::CandleDevice;
use burn::backend::Candle;
let device = CandleDevice::Cuda(0);
let device = CandleDevice::cuda(0);
bench::<Candle>(&device, feature_name, url, token);
}
@ -123,7 +123,7 @@ macro_rules! bench_on_backend {
use burn::backend::candle::CandleDevice;
use burn::backend::Candle;
let device = CandleDevice::Metal(0);
let device = CandleDevice::metal(0);
bench::<Candle>(&device, feature_name, url, token);
}

View File

@ -5,7 +5,7 @@ use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
use candle_core::DeviceLocation;
use candle_core::{backend::BackendDevice, DeviceLocation};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
@ -16,7 +16,7 @@ use crate::{
///
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
#[derive(Clone, Copy, Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
F: FloatCandleElement,
@ -27,29 +27,89 @@ where
}
/// The device type for the candle backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
/// The device struct when using the `candle` backend.
///
/// Note that you need to provide the device index when using Cuda.
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
/// ```no_run
/// use burn_candle::CandleDevice;
///
/// // Create a Cuda device from its index
/// let device = CandleDevice::cuda(0);
/// // Create a Metal device from its index
/// let device = CandleDevice::metal(0);
/// ```
pub enum CandleDevice {
/// CPU device.
Cpu,
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
/// all Cuda devices found on the system.
Cuda(usize),
Cuda(CudaDevice),
/// Metal device with the given index. The index is the index of the Metal device in the list of
/// all Metal devices found on the system.
Metal(usize),
Metal(MetalDevice),
}
impl CandleDevice {
    /// Create a Cuda device with the given index.
    ///
    /// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
    ///
    /// # Panics
    ///
    /// Panics if the underlying candle Cuda device cannot be created for `index`.
    pub fn cuda(index: usize) -> Self {
        // Surface the failing index and the candle error instead of a bare `unwrap` panic.
        let device = candle_core::CudaDevice::new(index)
            .unwrap_or_else(|err| panic!("Failed to create Cuda device {index}: {err:?}"));
        CandleDevice::Cuda(CudaDevice { device, index })
    }

    /// Create a Metal device with the given index.
    ///
    /// The index is the index of the Metal device in the list of all Metal devices found on the system.
    ///
    /// # Panics
    ///
    /// Panics if the underlying candle Metal device cannot be created for `index`.
    pub fn metal(index: usize) -> Self {
        // Surface the failing index and the candle error instead of a bare `unwrap` panic.
        let device = candle_core::MetalDevice::new(index)
            .unwrap_or_else(|err| panic!("Failed to create Metal device {index}: {err:?}"));
        CandleDevice::Metal(MetalDevice { device, index })
    }
}
#[derive(Clone, Debug)]
/// A Cuda device for the `candle` backend.
///
/// Pairs the underlying `candle_core::CudaDevice` handle with the device index it refers to.
pub struct CudaDevice {
    /// The underlying candle Cuda device handle (visible within this crate only).
    pub(crate) device: candle_core::CudaDevice,
    /// The index of the Cuda device in the list of all devices on the system.
    pub index: usize,
}
impl PartialEq for CudaDevice {
    /// Two Cuda devices are equal when candle reports them as the same device
    /// (`same_device`) and their indices match.
    fn eq(&self, other: &Self) -> bool {
        if !self.device.same_device(&other.device) {
            return false;
        }
        self.index == other.index
    }
}
// Marker impl: `Eq` adds no methods; equality itself is defined by the `PartialEq` impl.
impl Eq for CudaDevice {}
#[derive(Clone, Debug)]
/// A Metal device for the `candle` backend.
///
/// Pairs the underlying `candle_core::MetalDevice` handle with the device index it refers to.
pub struct MetalDevice {
    /// The underlying candle Metal device handle (visible within this crate only).
    pub(crate) device: candle_core::MetalDevice,
    /// The index of the Metal device in the list of all devices on the system.
    pub index: usize,
}
impl PartialEq for MetalDevice {
    /// Two Metal devices are equal when candle reports them as the same device
    /// (`same_device`) and their indices match.
    fn eq(&self, other: &Self) -> bool {
        if !self.device.same_device(&other.device) {
            return false;
        }
        self.index == other.index
    }
}
// Marker impl: `Eq` adds no methods; equality itself is defined by the `PartialEq` impl.
impl Eq for MetalDevice {}
impl From<CandleDevice> for candle_core::Device {
fn from(device: CandleDevice) -> Self {
match device {
CandleDevice::Cpu => candle_core::Device::Cpu,
CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
}
}
}
@ -58,8 +118,26 @@ impl From<candle_core::Device> for CandleDevice {
fn from(device: candle_core::Device) -> Self {
match device.location() {
DeviceLocation::Cpu => CandleDevice::Cpu,
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
DeviceLocation::Cuda { gpu_id } => {
if let candle_core::Device::Cuda(device) = device {
CandleDevice::Cuda(CudaDevice {
device,
index: gpu_id,
})
} else {
panic!("Expected CUDA device.");
}
}
DeviceLocation::Metal { gpu_id } => {
if let candle_core::Device::Metal(device) = device {
CandleDevice::Metal(MetalDevice {
device,
index: gpu_id,
})
} else {
panic!("Expected Metal device.");
}
}
}
}
}
@ -68,8 +146,8 @@ impl DeviceOps for CandleDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
match self {
CandleDevice::Cpu => DeviceId::new(0, 0),
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32),
CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32),
}
}
}
@ -111,7 +189,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
fn sync(device: &Device<Self>, sync_type: SyncType) {
match sync_type {
SyncType::Wait => {
let device: candle_core::Device = (*device).into();
let device: candle_core::Device = (device.clone()).into();
match device {
candle_core::Device::Cpu => (),

View File

@ -15,7 +15,7 @@ pub fn cat<E: CandleElement>(tensors: Vec<CandleTensor<E>>, dim: usize) -> Candl
}
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor<E> {
CandleTensor::from_data(data, *device)
CandleTensor::from_data(data, device.clone())
}
pub fn into_data<E: CandleElement>(tensor: CandleTensor<E>) -> TensorData {
TensorData::new(
@ -28,11 +28,13 @@ pub fn to_device<E: CandleElement>(
tensor: CandleTensor<E>,
device: &CandleDevice,
) -> CandleTensor<E> {
CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap())
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
}
pub fn empty<E: CandleElement>(shape: Shape, device: &CandleDevice) -> CandleTensor<E> {
CandleTensor::new(candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(*device).into()).unwrap())
CandleTensor::new(
candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
)
}
pub fn swap_dims<E: CandleElement>(

View File

@ -230,13 +230,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
CandleTensor::new(
candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(*device).into()).unwrap(),
candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
)
}
fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
CandleTensor::new(
candle_core::Tensor::ones(shape.dims, I::DTYPE, &(*device).into()).unwrap(),
candle_core::Tensor::ones(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
)
}
@ -324,7 +324,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
device: &Device<Self>,
) -> IntTensor<Self> {
let shape = shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)

View File

@ -15,7 +15,7 @@ use super::base::{expand, permute, sign};
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor<F> {
CandleTensor::from_data(data, *device)
CandleTensor::from_data(data, device.clone())
}
fn float_random(
@ -24,7 +24,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
device: &Device<Self>,
) -> FloatTensor<Self> {
let shape = shape.dims;
let device = &(*device).into();
let device = &(device.clone()).into();
match distribution {
Distribution::Default => CandleTensor::new(
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)