mirror of https://github.com/tracel-ai/burn.git
Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time
This commit is contained in:
parent
aa79e36a8d
commit
1d2b68ecef
|
@ -5,7 +5,7 @@ use burn_tensor::{
|
|||
quantization::{QTensorPrimitive, QuantizationStrategy},
|
||||
Device,
|
||||
};
|
||||
use candle_core::DeviceLocation;
|
||||
use candle_core::{backend::BackendDevice, DeviceLocation};
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
|
@ -16,7 +16,7 @@ use crate::{
|
|||
///
|
||||
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
|
||||
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct Candle<F = f32, I = i64>
|
||||
where
|
||||
F: FloatCandleElement,
|
||||
|
@ -27,29 +27,87 @@ where
|
|||
}
|
||||
|
||||
/// The device type for the candle backend.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// The device struct when using the `candle` backend.
|
||||
///
|
||||
/// Note that you need to provide the device index when using Cuda.
|
||||
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
|
||||
/// ```no_run
|
||||
/// // Create a Cuda device from its index
|
||||
/// let device = CandleDevice::cuda(0);
|
||||
/// // Create a Metal device from its index
|
||||
/// let device = CandleDevice::metal(0);
|
||||
/// ```
|
||||
pub enum CandleDevice {
|
||||
/// CPU device.
|
||||
Cpu,
|
||||
|
||||
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||
/// all Cuda devices found on the system.
|
||||
Cuda(usize),
|
||||
Cuda(CudaDevice),
|
||||
|
||||
/// Metal device with the given index. The index is the index of the Metal device in the list of
|
||||
/// all Metal devices found on the system.
|
||||
Metal(usize),
|
||||
Metal(MetalDevice),
|
||||
}
|
||||
|
||||
impl CandleDevice {
|
||||
/// Create a Cuda device with the given index.
|
||||
/// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
|
||||
pub fn cuda(index: usize) -> Self {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device: candle_core::CudaDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Metal device with the given index.
|
||||
/// The index is the index of the Metal device in the list of all Metal devices found on the system.
|
||||
pub fn metal(index: usize) -> Self {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device: candle_core::MetalDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Cuda device for the `candle` backend.
|
||||
pub struct CudaDevice {
|
||||
pub(crate) device: candle_core::CudaDevice,
|
||||
/// The index of the Cuda device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for CudaDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for CudaDevice {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Metal device for the `candle` backend.
|
||||
pub struct MetalDevice {
|
||||
pub(crate) device: candle_core::MetalDevice,
|
||||
/// The index of the Metal device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for MetalDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for MetalDevice {}
|
||||
|
||||
impl From<CandleDevice> for candle_core::Device {
|
||||
fn from(device: CandleDevice) -> Self {
|
||||
match device {
|
||||
CandleDevice::Cpu => candle_core::Device::Cpu,
|
||||
CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
|
||||
CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
|
||||
CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
|
||||
CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -58,8 +116,26 @@ impl From<candle_core::Device> for CandleDevice {
|
|||
fn from(device: candle_core::Device) -> Self {
|
||||
match device.location() {
|
||||
DeviceLocation::Cpu => CandleDevice::Cpu,
|
||||
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
|
||||
DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
|
||||
DeviceLocation::Cuda { gpu_id } => {
|
||||
if let candle_core::Device::Cuda(device) = device {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected CUDA device.");
|
||||
}
|
||||
}
|
||||
DeviceLocation::Metal { gpu_id } => {
|
||||
if let candle_core::Device::Metal(device) = device {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected Metal device.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,8 +144,8 @@ impl DeviceOps for CandleDevice {
|
|||
fn id(&self) -> burn_tensor::backend::DeviceId {
|
||||
match self {
|
||||
CandleDevice::Cpu => DeviceId::new(0, 0),
|
||||
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
|
||||
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
|
||||
CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32),
|
||||
CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +187,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
|
|||
fn sync(device: &Device<Self>, sync_type: SyncType) {
|
||||
match sync_type {
|
||||
SyncType::Wait => {
|
||||
let device: candle_core::Device = (*device).into();
|
||||
let device: candle_core::Device = (device.clone()).into();
|
||||
|
||||
match device {
|
||||
candle_core::Device::Cpu => (),
|
||||
|
|
|
@ -21,7 +21,7 @@ pub fn from_data<E: CandleElement, const D: usize>(
|
|||
data: TensorData,
|
||||
device: &CandleDevice,
|
||||
) -> CandleTensor<E, D> {
|
||||
CandleTensor::from_data(data, *device)
|
||||
CandleTensor::from_data(data, device.clone())
|
||||
}
|
||||
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> TensorData {
|
||||
TensorData::new(
|
||||
|
@ -34,14 +34,16 @@ pub fn to_device<E: CandleElement, const D: usize>(
|
|||
tensor: CandleTensor<E, D>,
|
||||
device: &CandleDevice,
|
||||
) -> CandleTensor<E, D> {
|
||||
CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap())
|
||||
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
|
||||
}
|
||||
|
||||
pub fn empty<E: CandleElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &CandleDevice,
|
||||
) -> CandleTensor<E, D> {
|
||||
CandleTensor::new(candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(*device).into()).unwrap())
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(&shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn swap_dims<E: CandleElement, const D: usize>(
|
||||
|
|
|
@ -299,13 +299,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
|
||||
fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(*device).into()).unwrap(),
|
||||
candle_core::Tensor::zeros(&shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(*device).into()).unwrap(),
|
||||
candle_core::Tensor::ones(&shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -401,7 +401,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
device: &Device<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
let shape = &shape.dims;
|
||||
let device = &(*device).into();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)
|
||||
|
|
|
@ -18,7 +18,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
|||
data: TensorData,
|
||||
device: &Device<Self>,
|
||||
) -> CandleTensor<F, D> {
|
||||
CandleTensor::from_data(data, *device)
|
||||
CandleTensor::from_data(data, device.clone())
|
||||
}
|
||||
|
||||
fn float_random<const D: usize>(
|
||||
|
@ -27,7 +27,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
|||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
let shape = &shape.dims;
|
||||
let device = &(*device).into();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)
|
||||
|
|
Loading…
Reference in New Issue