From ab67b6b0366e5cfc4a052c004ea87c435cedc915 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Mon, 8 Jan 2024 16:41:34 -0500 Subject: [PATCH] slice assign in candle (#1095) --- backend-comparison/Cargo.toml | 1 + backend-comparison/src/lib.rs | 9 +++++++++ burn-candle/Cargo.toml | 4 ++-- burn-candle/src/backend.rs | 11 ++++++++--- burn-candle/src/lib.rs | 6 +++--- burn-candle/src/ops/base.rs | 2 +- burn-core/Cargo.toml | 1 + burn/Cargo.toml | 3 ++- 8 files changed, 27 insertions(+), 10 deletions(-) diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index f1228d7c1..f82ce8ca6 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -14,6 +14,7 @@ default = ["std"] std = [] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle", "burn/cuda"] +candle-metal = ["burn/candle", "burn/metal"] candle-accelerate = ["burn/candle", "burn/accelerate"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 77ac9cf44..0cb471296 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -68,5 +68,14 @@ macro_rules! bench_on_backend { let device = CandleDevice::Cuda(0); bench::(&device); } + + #[cfg(feature = "candle-metal")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::Candle; + + let device = CandleDevice::Metal(0); + bench::(&device); + } }; } diff --git a/burn-candle/Cargo.toml b/burn-candle/Cargo.toml index 29be44b9e..6bc429b9b 100644 --- a/burn-candle/Cargo.toml +++ b/burn-candle/Cargo.toml @@ -14,6 +14,7 @@ version.workspace = true default = ["std"] std = [] cuda = ["candle-core/cuda"] +metal = ["candle-core/metal"] accelerate = ["candle-core/accelerate"] [dependencies] @@ -21,8 +22,7 @@ derive-new = { workspace = true } burn-tensor = { path = "../burn-tensor", version = "0.12.0", default-features = false } half = { workspace = true } -candle-core = { version = "0.3.1" } - +candle-core = { version = "0.3.2" } [dev-dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.12.0", default-features = false, features = [ diff --git a/burn-candle/src/backend.rs b/burn-candle/src/backend.rs index f26e889c3..b418d3173 100644 --- a/burn-candle/src/backend.rs +++ b/burn-candle/src/backend.rs @@ -10,8 +10,8 @@ use crate::{ /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations. /// -/// It is compatible with a wide range of hardware configurations, including CPUs and Nvidia GPUs -/// that support CUDA. Additionally, the backend can be compiled to `wasm` when using the CPU. +/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs +/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU. #[derive(Clone, Copy, Default, Debug)] pub struct Candle where @@ -34,6 +34,10 @@ pub enum CandleDevice { /// Cuda device with the given index. The index is the index of the Cuda device in the list of /// all Cuda devices found on the system. Cuda(usize), + + /// Metal device with the given index. The index is the index of the Metal device in the list of + /// all Metal devices found on the system. + Metal(usize), } impl From for candle_core::Device { @@ -41,6 +45,7 @@ impl From for candle_core::Device { match device { CandleDevice::Cpu => candle_core::Device::Cpu, CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(), + CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(), } } } @@ -50,7 +55,7 @@ impl From for CandleDevice { match device.location() { DeviceLocation::Cpu => CandleDevice::Cpu, DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id), - _ => panic!("Device unsupported: {device:?}"), + DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id), } } } diff --git a/burn-candle/src/lib.rs b/burn-candle/src/lib.rs index 8cd2cc8a8..2e1979924 100644 --- a/burn-candle/src/lib.rs +++ b/burn-candle/src/lib.rs @@ -78,11 +78,11 @@ mod tests { burn_tensor::testgen_neg!(); burn_tensor::testgen_powf!(); burn_tensor::testgen_random!(); - // burn_tensor::testgen_repeat!(); + burn_tensor::testgen_repeat!(); burn_tensor::testgen_reshape!(); burn_tensor::testgen_select!(); burn_tensor::testgen_sin!(); - // burn_tensor::testgen_slice!(); + burn_tensor::testgen_slice!(); burn_tensor::testgen_sqrt!(); burn_tensor::testgen_abs!(); burn_tensor::testgen_squeeze!(); @@ -126,7 +126,7 @@ mod tests { burn_autodiff::testgen_ad_div!(); burn_autodiff::testgen_ad_erf!(); burn_autodiff::testgen_ad_exp!(); - // burn_autodiff::testgen_ad_slice!(); + burn_autodiff::testgen_ad_slice!(); burn_autodiff::testgen_ad_gather_scatter!(); burn_autodiff::testgen_ad_select!(); burn_autodiff::testgen_ad_log!(); diff --git a/burn-candle/src/ops/base.rs b/burn-candle/src/ops/base.rs index 7e6f3f833..16750ffcb 100644 --- a/burn-candle/src/ops/base.rs +++ b/burn-candle/src/ops/base.rs @@ -86,7 +86,7 @@ pub fn slice_assign( ranges: [std::ops::Range; D2], value: CandleTensor, ) -> CandleTensor { - panic!("slice_assign not supported by Candle") + CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap()) } pub fn narrow( diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index f338e54ba..b251ec3c7 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -51,6 +51,7 @@ fusion = ["burn-fusion", "burn-wgpu?/fusion"] ## Backend features cuda = ["burn-candle?/cuda"] +metal = ["burn-candle?/metal"] accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 24e433093..54ee4c8db 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -39,6 +39,7 @@ fusion = ["burn-core/fusion"] ## Backend features cuda = ["burn-core/cuda"] +metal = ["burn-core/metal"] accelerate = ["burn-core/accelerate"] openblas = ["burn-core/openblas"] openblas-system = ["burn-core/openblas-system"] @@ -73,5 +74,5 @@ features = [ "wgpu", "candle", "fusion", - "experimental-named-tensor" + "experimental-named-tensor", ]