mirror of https://github.com/tracel-ai/burn.git
slice assign in candle (#1095)
This commit is contained in:
parent
e132e21816
commit
ab67b6b036
|
@ -14,6 +14,7 @@ default = ["std"]
|
|||
std = []
|
||||
candle-cpu = ["burn/candle"]
|
||||
candle-cuda = ["burn/candle", "burn/cuda"]
|
||||
candle-metal = ["burn/candle", "burn/metal"]
|
||||
candle-accelerate = ["burn/candle", "burn/accelerate"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
|
||||
|
|
|
@ -68,5 +68,14 @@ macro_rules! bench_on_backend {
|
|||
let device = CandleDevice::Cuda(0);
|
||||
bench::<Candle>(&device);
|
||||
}
|
||||
|
||||
#[cfg(feature = "candle-metal")]
|
||||
{
|
||||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::Candle;
|
||||
|
||||
let device = CandleDevice::Metal(0);
|
||||
bench::<Candle>(&device);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ version.workspace = true
|
|||
default = ["std"]
|
||||
std = []
|
||||
cuda = ["candle-core/cuda"]
|
||||
metal = ["candle-core/metal"]
|
||||
accelerate = ["candle-core/accelerate"]
|
||||
|
||||
[dependencies]
|
||||
|
@ -21,8 +22,7 @@ derive-new = { workspace = true }
|
|||
burn-tensor = { path = "../burn-tensor", version = "0.12.0", default-features = false }
|
||||
half = { workspace = true }
|
||||
|
||||
candle-core = { version = "0.3.1" }
|
||||
|
||||
candle-core = { version = "0.3.2" }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.12.0", default-features = false, features = [
|
||||
|
|
|
@ -10,8 +10,8 @@ use crate::{
|
|||
|
||||
/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
|
||||
///
|
||||
/// It is compatible with a wide range of hardware configurations, including CPUs and Nvidia GPUs
|
||||
/// that support CUDA. Additionally, the backend can be compiled to `wasm` when using the CPU.
|
||||
/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
|
||||
/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct Candle<F = f32, I = i64>
|
||||
where
|
||||
|
@ -34,6 +34,10 @@ pub enum CandleDevice {
|
|||
/// Cuda device with the given index. The index is the index of the Cuda device in the list of
|
||||
/// all Cuda devices found on the system.
|
||||
Cuda(usize),
|
||||
|
||||
/// Metal device with the given index. The index is the index of the Metal device in the list of
|
||||
/// all Metal devices found on the system.
|
||||
Metal(usize),
|
||||
}
|
||||
|
||||
impl From<CandleDevice> for candle_core::Device {
|
||||
|
@ -41,6 +45,7 @@ impl From<CandleDevice> for candle_core::Device {
|
|||
match device {
|
||||
CandleDevice::Cpu => candle_core::Device::Cpu,
|
||||
CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal).unwrap(),
|
||||
CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -50,7 +55,7 @@ impl From<candle_core::Device> for CandleDevice {
|
|||
match device.location() {
|
||||
DeviceLocation::Cpu => CandleDevice::Cpu,
|
||||
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
|
||||
_ => panic!("Device unsupported: {device:?}"),
|
||||
DeviceLocation::Metal { gpu_id } => CandleDevice::Metal(gpu_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,11 +78,11 @@ mod tests {
|
|||
burn_tensor::testgen_neg!();
|
||||
burn_tensor::testgen_powf!();
|
||||
burn_tensor::testgen_random!();
|
||||
// burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_sin!();
|
||||
// burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_sqrt!();
|
||||
burn_tensor::testgen_abs!();
|
||||
burn_tensor::testgen_squeeze!();
|
||||
|
@ -126,7 +126,7 @@ mod tests {
|
|||
burn_autodiff::testgen_ad_div!();
|
||||
burn_autodiff::testgen_ad_erf!();
|
||||
burn_autodiff::testgen_ad_exp!();
|
||||
// burn_autodiff::testgen_ad_slice!();
|
||||
burn_autodiff::testgen_ad_slice!();
|
||||
burn_autodiff::testgen_ad_gather_scatter!();
|
||||
burn_autodiff::testgen_ad_select!();
|
||||
burn_autodiff::testgen_ad_log!();
|
||||
|
|
|
@ -86,7 +86,7 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
|
|||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: CandleTensor<E, D1>,
|
||||
) -> CandleTensor<E, D1> {
|
||||
panic!("slice_assign not supported by Candle")
|
||||
CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap())
|
||||
}
|
||||
|
||||
pub fn narrow<E: CandleElement, const D: usize>(
|
||||
|
|
|
@ -51,6 +51,7 @@ fusion = ["burn-fusion", "burn-wgpu?/fusion"]
|
|||
|
||||
## Backend features
|
||||
cuda = ["burn-candle?/cuda"]
|
||||
metal = ["burn-candle?/metal"]
|
||||
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
|
||||
openblas = ["burn-ndarray?/blas-openblas"]
|
||||
openblas-system = ["burn-ndarray?/blas-openblas-system"]
|
||||
|
|
|
@ -39,6 +39,7 @@ fusion = ["burn-core/fusion"]
|
|||
|
||||
## Backend features
|
||||
cuda = ["burn-core/cuda"]
|
||||
metal = ["burn-core/metal"]
|
||||
accelerate = ["burn-core/accelerate"]
|
||||
openblas = ["burn-core/openblas"]
|
||||
openblas-system = ["burn-core/openblas-system"]
|
||||
|
@ -73,5 +74,5 @@ features = [
|
|||
"wgpu",
|
||||
"candle",
|
||||
"fusion",
|
||||
"experimental-named-tensor"
|
||||
"experimental-named-tensor",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue