slice assign in candle (#1095)

Author: Louis Fortier-Dubois, 2024-01-08 16:41:34 -05:00 (committed by GitHub)
Parent: e132e21816
Commit: ab67b6b036
8 changed files with 27 additions and 10 deletions

View File

@@ -14,6 +14,7 @@ default = ["std"]
std = []
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle", "burn/cuda"]
+candle-metal = ["burn/candle", "burn/metal"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]

View File

@@ -68,5 +68,14 @@ macro_rules! bench_on_backend {
            let device = CandleDevice::Cuda(0);
            bench::<Candle>(&device);
        }
+        #[cfg(feature = "candle-metal")]
+        {
+            use burn::backend::candle::CandleDevice;
+            use burn::backend::Candle;
+            let device = CandleDevice::Metal(0);
+            bench::<Candle>(&device);
+        }
    };
}
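For illustration only (not part of the commit): a minimal sketch of what the new `candle-metal` arm amounts to when that feature is enabled. The `bench` stub and the `main` wrapper are assumptions standing in for backend-comparison's real benchmark runner; the device selection mirrors the added block above.

use burn::backend::candle::CandleDevice;
use burn::backend::Candle;
use burn::tensor::backend::Backend;

// Stub standing in for the real benchmark runner (assumption).
fn bench<B: Backend>(device: &B::Device) {
    let _ = device;
}

fn main() {
    // Mirrors the new `candle-metal` arm: pick the first Metal GPU and
    // run the benchmark against the Candle backend.
    let device = CandleDevice::Metal(0);
    bench::<Candle>(&device);
}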

View File

@@ -14,6 +14,7 @@ version.workspace = true
default = ["std"]
std = []
cuda = ["candle-core/cuda"]
+metal = ["candle-core/metal"]
accelerate = ["candle-core/accelerate"]
[dependencies]
@@ -21,8 +22,7 @@ derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.12.0", default-features = false }
half = { workspace = true }
-candle-core = { version = "0.3.1" }
+candle-core = { version = "0.3.2" }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.12.0", default-features = false, features = [

View File

@@ -10,8 +10,8 @@ use crate::{
/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
///
-/// It is compatible with a wide range of hardware configurations, including CPUs and Nvidia GPUs
-/// that support CUDA. Additionally, the backend can be compiled to `wasm` when using the CPU.
+/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
+/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
#[derive(Clone, Copy, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
@@ -34,6 +34,10 @@ pub enum CandleDevice {
    /// Cuda device with the given index. The index is the index of the Cuda device in the list of
    /// all Cuda devices found on the system.
    Cuda(usize),
+    /// Metal device with the given index. The index is the index of the Metal device in the list of
+    /// all Metal devices found on the system.
+    Metal(usize),
}
impl From<CandleDevice> for candle_core::Device {
@@ -41,6 +45,7 @@ impl From<CandleDevice> for candle_core::Device {
        match device {
            CandleDevice::Cpu => candle_core::Device::Cpu,
            CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
+            CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
        }
    }
}
@@ -50,7 +55,7 @@ impl From<candle_core::Device> for CandleDevice {
        match device.location() {
            DeviceLocation::Cpu => CandleDevice::Cpu,
            DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
-            _ => panic!("Device unsupported: {device:?}"),
+            DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
        }
    }
}
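As a usage sketch (not in the diff): the new `Metal` variant converts into a `candle_core::Device` through the `From` impl above. This assumes `candle-core` is available as a direct dependency; on a machine without a Metal GPU, `Device::new_metal` fails and the `unwrap` in the impl panics.

use burn::backend::candle::CandleDevice;

fn main() {
    let burn_device = CandleDevice::Metal(0);
    // Relies on the `From<CandleDevice> for candle_core::Device` impl shown above.
    let candle_device: candle_core::Device = burn_device.into();
    let _ = candle_device;
}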

View File

@@ -78,11 +78,11 @@ mod tests {
    burn_tensor::testgen_neg!();
    burn_tensor::testgen_powf!();
    burn_tensor::testgen_random!();
-    // burn_tensor::testgen_repeat!();
+    burn_tensor::testgen_repeat!();
    burn_tensor::testgen_reshape!();
    burn_tensor::testgen_select!();
    burn_tensor::testgen_sin!();
-    // burn_tensor::testgen_slice!();
+    burn_tensor::testgen_slice!();
    burn_tensor::testgen_sqrt!();
    burn_tensor::testgen_abs!();
    burn_tensor::testgen_squeeze!();
@@ -126,7 +126,7 @@ mod tests {
    burn_autodiff::testgen_ad_div!();
    burn_autodiff::testgen_ad_erf!();
    burn_autodiff::testgen_ad_exp!();
-    // burn_autodiff::testgen_ad_slice!();
+    burn_autodiff::testgen_ad_slice!();
    burn_autodiff::testgen_ad_gather_scatter!();
    burn_autodiff::testgen_ad_select!();
    burn_autodiff::testgen_ad_log!();

View File

@@ -86,7 +86,7 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
    ranges: [std::ops::Range<usize>; D2],
    value: CandleTensor<E, D1>,
) -> CandleTensor<E, D1> {
-    panic!("slice_assign not supported by Candle")
+    CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap())
}
pub fn narrow<E: CandleElement, const D: usize>(
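For context, a small sketch of the candle-core call this wrapper now delegates to (`Tensor::slice_assign`, available as of candle-core 0.3.2 per the version bump above). The shapes, values, and the CPU device here are illustrative assumptions, not taken from the commit.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // A 4x4 base tensor and a 2x2 patch to write into it.
    let base = Tensor::zeros((4, 4), DType::F32, &device)?;
    let patch = Tensor::ones((2, 2), DType::F32, &device)?;

    // Write the patch into rows 1..3 and columns 1..3, returning a new
    // tensor (candle tensors are immutable; slice_assign does not mutate).
    let ranges = [1..3, 1..3];
    let updated = base.slice_assign(&ranges, &patch)?;

    println!("{updated}");
    Ok(())
}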

View File

@@ -51,6 +51,7 @@ fusion = ["burn-fusion", "burn-wgpu?/fusion"]
## Backend features
cuda = ["burn-candle?/cuda"]
+metal = ["burn-candle?/metal"]
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]

View File

@@ -39,6 +39,7 @@ fusion = ["burn-core/fusion"]
## Backend features
cuda = ["burn-core/cuda"]
+metal = ["burn-core/metal"]
accelerate = ["burn-core/accelerate"]
openblas = ["burn-core/openblas"]
openblas-system = ["burn-core/openblas-system"]
@@ -73,5 +74,5 @@ features = [
    "wgpu",
    "candle",
    "fusion",
-    "experimental-named-tensor"
+    "experimental-named-tensor",
]
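Not part of the diff: with `metal` now exposed through burn, burn-core, burn-candle, and candle-core, a downstream crate that defines its own `metal` feature and forwards it to burn's could select a device along these lines. The feature wiring and the helper below are hypothetical.

use burn::backend::candle::CandleDevice;

// Hypothetical helper: prefer the first Metal GPU when this crate's own
// `metal` feature (forwarded to burn's) is enabled, otherwise use the CPU.
fn pick_device() -> CandleDevice {
    if cfg!(feature = "metal") {
        CandleDevice::Metal(0)
    } else {
        CandleDevice::Cpu
    }
}

fn main() {
    let device = pick_device();
    let _ = device;
}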