mirror of https://github.com/tracel-ai/burn.git
make candle available (#886)
This commit is contained in:
parent
07c0cf146d
commit
e4d9d67526
|
@ -6,6 +6,7 @@ resolver = "2"
|
|||
members = [
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-candle",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-core",
|
||||
|
|
|
@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models.
|
|||
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
|
||||
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
|
||||
browser-inclusive, GPU-based computations 🌐
|
||||
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️
|
||||
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️
|
||||
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
|
||||
differentiability across all backends 🌟
|
||||
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range
|
||||
|
|
|
@ -12,6 +12,7 @@ version = "0.10.0"
|
|||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
candle = ["burn/candle"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
||||
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
||||
|
|
|
@ -40,5 +40,14 @@ macro_rules! bench_on_backend {
|
|||
let device = NdArrayDevice::Cpu;
|
||||
bench::<NdArrayBackend>(&device);
|
||||
}
|
||||
|
||||
#[cfg(feature = "candle")]
|
||||
{
|
||||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::CandleBackend;
|
||||
|
||||
let device = CandleDevice::Cpu;
|
||||
bench::<CandleBackend>(&device);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"]
|
|||
|
||||
tch = ["burn-tch"]
|
||||
|
||||
candle = ["burn-candle"]
|
||||
|
||||
# Serialization formats
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
|
||||
|
@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true,
|
|||
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
|
||||
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
|
||||
burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true }
|
||||
|
||||
derive-new = { workspace = true }
|
||||
libm = { workspace = true }
|
||||
|
|
|
@ -23,6 +23,18 @@ pub type WgpuBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = wgpu::WgpuBa
|
|||
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
|
||||
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;
|
||||
|
||||
#[cfg(feature = "candle")]
|
||||
/// Candle module.
|
||||
pub use burn_candle as candle;
|
||||
|
||||
#[cfg(feature = "candle")]
|
||||
/// A CandleBackend with a default type of f32/i64.
|
||||
pub type CandleBackend = candle::CandleBackend<f32, i64>;
|
||||
|
||||
#[cfg(all(feature = "candle", feature = "autodiff"))]
|
||||
/// A CandleBackend with autodiffing enabled.
|
||||
pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator<CandleBackend>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
/// Tch module.
|
||||
pub use burn_tch as tch;
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
|
||||
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)
|
||||
|
||||
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the
|
||||
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the
|
||||
[wgpu](https://github.com/gfx-rs/wgpu).
|
||||
|
||||
The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
|
||||
|
|
|
@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
|
|||
|
||||
wgpu = ["burn-core/wgpu"]
|
||||
tch = ["burn-core/tch"]
|
||||
candle = ["burn-core/candle"]
|
||||
|
||||
# Experimental
|
||||
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
||||
|
@ -53,8 +54,8 @@ experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
|||
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
||||
burn-core = {path = "../burn-core", version = "0.10.0", default-features = false}
|
||||
burn-train = {path = "../burn-train", version = "0.10.0", optional = true, default-features = false }
|
||||
burn-core = { path = "../burn-core", version = "0.10.0", default-features = false }
|
||||
burn-train = { path = "../burn-train", version = "0.10.0", optional = true, default-features = false }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
|
|
@ -16,9 +16,9 @@ ndarray = ["burn/ndarray-no-std"]
|
|||
wgpu = ["burn/wgpu"]
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn", default-features = false}
|
||||
serde = {workspace = true}
|
||||
wasm-bindgen = { version = "0.2.87" }
|
||||
burn = { path = "../../burn", default-features = false }
|
||||
serde = { workspace = true }
|
||||
wasm-bindgen = { version = "0.2.87" }
|
||||
wasm-bindgen-futures = "0.4"
|
||||
js-sys = "0.3.64"
|
||||
|
||||
|
|
Loading…
Reference in New Issue