mirror of https://github.com/tracel-ai/burn.git
chore(candle): Allow enabling accelerate (#1009)
* chore(candle): Allow enabling accelerate * Temporarily disable test for accelerate feature * Allow enabling accelerate from upstream * Update the README * Have xtask also test using accelerate * Renable failing test * Fix matmul on candle when using accelerate * Add additional comment to xtask method
This commit is contained in:
parent
1d4e91ad32
commit
f73136e3df
|
@ -14,6 +14,7 @@ default = ["std"]
|
||||||
std = []
|
std = []
|
||||||
candle-cpu = ["burn/candle"]
|
candle-cpu = ["burn/candle"]
|
||||||
candle-cuda = ["burn/candle-cuda"]
|
candle-cuda = ["burn/candle-cuda"]
|
||||||
|
candle-accelerate = ["burn/candle-accelerate"]
|
||||||
ndarray = ["burn/ndarray"]
|
ndarray = ["burn/ndarray"]
|
||||||
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
||||||
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
||||||
|
|
|
@ -12,6 +12,7 @@ version = "0.11.0"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
cuda = ["candle-core/cuda"]
|
cuda = ["candle-core/cuda"]
|
||||||
|
accelerate = ["candle-core/accelerate"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
derive-new = { workspace = true }
|
derive-new = { workspace = true }
|
||||||
|
|
|
@ -4,4 +4,11 @@ This crate provides a backend for [Burn](https://github.com/burn-rs/burn) based
|
||||||
|
|
||||||
It is still in alpha stage, not all operations are supported. It is usable for some use cases, like for inference.
|
It is still in alpha stage, not all operations are supported. It is usable for some use cases, like for inference.
|
||||||
|
|
||||||
It can be used with CPU or CUDA.
|
It can be used with CPU or CUDA. On macOS computations can be accelerated by using the Accelerate framework.
|
||||||
|
|
||||||
|
## Feature Flags
|
||||||
|
|
||||||
|
The following features are supported:
|
||||||
|
|
||||||
|
- `cuda` - Cuda GPU device (NVIDIA only)
|
||||||
|
- `accelerate` - Accelerate framework (macOS only)
|
||||||
|
|
|
@ -137,7 +137,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
|
||||||
lhs: FloatTensor<Self, D>,
|
lhs: FloatTensor<Self, D>,
|
||||||
rhs: FloatTensor<Self, D>,
|
rhs: FloatTensor<Self, D>,
|
||||||
) -> FloatTensor<Self, D> {
|
) -> FloatTensor<Self, D> {
|
||||||
CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap())
|
let rhs_contiguous = rhs.tensor.contiguous().unwrap();
|
||||||
|
CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs_contiguous).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn swap_dims<const D: usize>(
|
fn swap_dims<const D: usize>(
|
||||||
|
|
|
@ -55,6 +55,7 @@ tch = ["burn-tch"]
|
||||||
|
|
||||||
candle = ["burn-candle"]
|
candle = ["burn-candle"]
|
||||||
candle-cuda = ["candle", "burn-candle/cuda"]
|
candle-cuda = ["candle", "burn-candle/cuda"]
|
||||||
|
candle-accelerate = ["candle", "burn-candle/accelerate"]
|
||||||
|
|
||||||
# Serialization formats
|
# Serialization formats
|
||||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||||
|
|
|
@ -51,6 +51,7 @@ wgpu = ["burn-core/wgpu"]
|
||||||
tch = ["burn-core/tch"]
|
tch = ["burn-core/tch"]
|
||||||
candle = ["burn-core/candle"]
|
candle = ["burn-core/candle"]
|
||||||
candle-cuda = ["burn-core/candle-cuda"]
|
candle-cuda = ["burn-core/candle-cuda"]
|
||||||
|
candle-accelerate = ["burn-core/candle-accelerate"]
|
||||||
|
|
||||||
# Experimental
|
# Experimental
|
||||||
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
||||||
|
|
|
@ -249,6 +249,13 @@ fn burn_dataset_features_std() {
|
||||||
cargo_doc(["-p", "burn-dataset", "--all-features"].into());
|
cargo_doc(["-p", "burn-dataset", "--all-features"].into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test burn-candle with accelerate (macOS only)
|
||||||
|
// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
fn burn_candle_accelerate() {
|
||||||
|
cargo_test(["-p", "burn-candle", "--features", "accelerate"].into());
|
||||||
|
}
|
||||||
|
|
||||||
fn std_checks() {
|
fn std_checks() {
|
||||||
// Set RUSTDOCFLAGS environment variable to treat warnings as errors
|
// Set RUSTDOCFLAGS environment variable to treat warnings as errors
|
||||||
// for the documentation build
|
// for the documentation build
|
||||||
|
@ -284,6 +291,10 @@ fn std_checks() {
|
||||||
// Test each workspace
|
// Test each workspace
|
||||||
cargo_test(["--workspace"].into());
|
cargo_test(["--workspace"].into());
|
||||||
|
|
||||||
|
// Test burn-candle with accelerate (macOS only)
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
burn_candle_accelerate();
|
||||||
|
|
||||||
// Test burn-dataset features
|
// Test burn-dataset features
|
||||||
burn_dataset_features_std();
|
burn_dataset_features_std();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue