mirror of https://github.com/tracel-ai/burn.git
Add candle `CudaDevice` and `MetalDevice` to avoid creating a new unique device each time (#2290)
* Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time * Fix doc example * Change enum usage
This commit is contained in:
parent
37d87956e2
commit
112f09ebaf
|
@ -114,7 +114,7 @@ macro_rules! bench_on_backend {
|
|||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::Candle;
|
||||
|
||||
let device = CandleDevice::Cuda(0);
|
||||
let device = CandleDevice::cuda(0);
|
||||
bench::<Candle>(&device, feature_name, url, token);
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ macro_rules! bench_on_backend {
|
|||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::Candle;
|
||||
|
||||
let device = CandleDevice::Metal(0);
|
||||
let device = CandleDevice::metal(0);
|
||||
bench::<Candle>(&device, feature_name, url, token);
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ use burn_tensor::{
|
|||
quantization::{QTensorPrimitive, QuantizationStrategy},
|
||||
Device,
|
||||
};
|
||||
use candle_core::DeviceLocation;
|
||||
use candle_core::{backend::BackendDevice, DeviceLocation};
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
|
@ -16,7 +16,7 @@ use crate::{
|
|||
///
|
||||
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
|
||||
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct Candle<F = f32, I = i64>
|
||||
where
|
||||
F: FloatCandleElement,
|
||||
|
@ -27,29 +27,89 @@ where
|
|||
}
|
||||
|
||||
/// The device type for the candle backend.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// The device struct when using the `candle` backend.
|
||||
///
|
||||
/// Note that you need to provide the device index when using Cuda.
|
||||
/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
|
||||
/// ```no_run
|
||||
/// use burn_candle::CandleDevice;
|
||||
///
|
||||
/// // Create a Cuda device from its index
|
||||
/// let device = CandleDevice::cuda(0);
|
||||
/// // Create a Metal device from its index
|
||||
/// let device = CandleDevice::metal(0);
|
||||
/// ```
|
||||
pub enum CandleDevice {
|
||||
/// CPU device.
|
||||
Cpu,
|
||||
|
||||
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||
/// all Cuda devices found on the system.
|
||||
Cuda(usize),
|
||||
Cuda(CudaDevice),
|
||||
|
||||
/// Metal device with the given index. The index is the index of the Metal device in the list of
|
||||
/// all Metal devices found on the system.
|
||||
Metal(usize),
|
||||
Metal(MetalDevice),
|
||||
}
|
||||
|
||||
impl CandleDevice {
|
||||
/// Create a Cuda device with the given index.
|
||||
/// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
|
||||
pub fn cuda(index: usize) -> Self {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device: candle_core::CudaDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Metal device with the given index.
|
||||
/// The index is the index of the Metal device in the list of all Metal devices found on the system.
|
||||
pub fn metal(index: usize) -> Self {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device: candle_core::MetalDevice::new(index).unwrap(),
|
||||
index,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Cuda device for the `candle` backend.
|
||||
pub struct CudaDevice {
|
||||
pub(crate) device: candle_core::CudaDevice,
|
||||
/// The index of the Cuda device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for CudaDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for CudaDevice {}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
/// A Metal device for the `candle` backend.
|
||||
pub struct MetalDevice {
|
||||
pub(crate) device: candle_core::MetalDevice,
|
||||
/// The index of the Metal device in the list of all devices on the system.
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl PartialEq for MetalDevice {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.device.same_device(&other.device) && self.index == other.index
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for MetalDevice {}
|
||||
|
||||
impl From<CandleDevice> for candle_core::Device {
|
||||
fn from(device: CandleDevice) -> Self {
|
||||
match device {
|
||||
CandleDevice::Cpu => candle_core::Device::Cpu,
|
||||
CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
|
||||
CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
|
||||
CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
|
||||
CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -58,8 +118,26 @@ impl From<candle_core::Device> for CandleDevice {
|
|||
fn from(device: candle_core::Device) -> Self {
|
||||
match device.location() {
|
||||
DeviceLocation::Cpu => CandleDevice::Cpu,
|
||||
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
|
||||
DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
|
||||
DeviceLocation::Cuda { gpu_id } => {
|
||||
if let candle_core::Device::Cuda(device) = device {
|
||||
CandleDevice::Cuda(CudaDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected CUDA device.");
|
||||
}
|
||||
}
|
||||
DeviceLocation::Metal { gpu_id } => {
|
||||
if let candle_core::Device::Metal(device) = device {
|
||||
CandleDevice::Metal(MetalDevice {
|
||||
device,
|
||||
index: gpu_id,
|
||||
})
|
||||
} else {
|
||||
panic!("Expected Metal device.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,8 +146,8 @@ impl DeviceOps for CandleDevice {
|
|||
fn id(&self) -> burn_tensor::backend::DeviceId {
|
||||
match self {
|
||||
CandleDevice::Cpu => DeviceId::new(0, 0),
|
||||
CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32),
|
||||
CandleDevice::Metal(index) => DeviceId::new(2, *index as u32),
|
||||
CandleDevice::Cuda(device) => DeviceId::new(1, device.index as u32),
|
||||
CandleDevice::Metal(device) => DeviceId::new(2, device.index as u32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +189,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
|
|||
fn sync(device: &Device<Self>, sync_type: SyncType) {
|
||||
match sync_type {
|
||||
SyncType::Wait => {
|
||||
let device: candle_core::Device = (*device).into();
|
||||
let device: candle_core::Device = (device.clone()).into();
|
||||
|
||||
match device {
|
||||
candle_core::Device::Cpu => (),
|
||||
|
|
|
@ -15,7 +15,7 @@ pub fn cat<E: CandleElement>(tensors: Vec<CandleTensor<E>>, dim: usize) -> Candl
|
|||
}
|
||||
|
||||
pub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor<E> {
|
||||
CandleTensor::from_data(data, *device)
|
||||
CandleTensor::from_data(data, device.clone())
|
||||
}
|
||||
pub fn into_data<E: CandleElement>(tensor: CandleTensor<E>) -> TensorData {
|
||||
TensorData::new(
|
||||
|
@ -28,11 +28,13 @@ pub fn to_device<E: CandleElement>(
|
|||
tensor: CandleTensor<E>,
|
||||
device: &CandleDevice,
|
||||
) -> CandleTensor<E> {
|
||||
CandleTensor::new(tensor.tensor.to_device(&(*device).into()).unwrap())
|
||||
CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())
|
||||
}
|
||||
|
||||
pub fn empty<E: CandleElement>(shape: Shape, device: &CandleDevice) -> CandleTensor<E> {
|
||||
CandleTensor::new(candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(*device).into()).unwrap())
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(shape.dims, E::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn swap_dims<E: CandleElement>(
|
||||
|
|
|
@ -230,13 +230,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
|
||||
fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(*device).into()).unwrap(),
|
||||
candle_core::Tensor::zeros(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
|
||||
CandleTensor::new(
|
||||
candle_core::Tensor::ones(shape.dims, I::DTYPE, &(*device).into()).unwrap(),
|
||||
candle_core::Tensor::ones(shape.dims, I::DTYPE, &(device.clone()).into()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -324,7 +324,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
device: &Device<Self>,
|
||||
) -> IntTensor<Self> {
|
||||
let shape = shape.dims;
|
||||
let device = &(*device).into();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)
|
||||
|
|
|
@ -15,7 +15,7 @@ use super::base::{expand, permute, sign};
|
|||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
|
||||
fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor<F> {
|
||||
CandleTensor::from_data(data, *device)
|
||||
CandleTensor::from_data(data, device.clone())
|
||||
}
|
||||
|
||||
fn float_random(
|
||||
|
@ -24,7 +24,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
|||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let shape = shape.dims;
|
||||
let device = &(*device).into();
|
||||
let device = &(device.clone()).into();
|
||||
match distribution {
|
||||
Distribution::Default => CandleTensor::new(
|
||||
candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)
|
||||
|
|
Loading…
Reference in New Issue