diff --git a/Cargo.toml b/Cargo.toml index b06cf329d..fe5079484 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ resolver = "2" members = [ "burn", "burn-autodiff", + "burn-candle", "burn-common", "burn-compute", "burn-core", diff --git a/README.md b/README.md index a3dba90de..ef8092498 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models. [`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌 - [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform, browser-inclusive, GPU-based computations 🌐 - - [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️ + - [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️ - [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables differentiability across all backends 🌟 - [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index e2d39848f..8e2928fb0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -12,6 +12,7 @@ version = "0.10.0" [features] default = ["std"] std = [] +candle = ["burn/candle"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"] ndarray-blas-netlib = ["burn/ndarray-blas-netlib"] diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index d2dbd3e25..064a87cbe 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -40,5 +40,14 @@ macro_rules! bench_on_backend { let device = NdArrayDevice::Cpu; bench::(&device); } + + #[cfg(feature = "candle")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::CandleBackend; + + let device = CandleDevice::Cpu; + bench::(&device); + } }; } diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index 10a4a5e47..4c1abca92 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"] tch = ["burn-tch"] +candle = ["burn-candle"] + # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] @@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true, burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true } burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true } derive-new = { workspace = true } libm = { workspace = true } diff --git a/burn-core/src/backend.rs b/burn-core/src/backend.rs index 2e0ad5fb6..d22948d95 100644 --- a/burn-core/src/backend.rs +++ b/burn-core/src/backend.rs @@ -23,6 +23,18 @@ pub type WgpuBackend = wgpu::WgpuBa pub type WgpuAutodiffBackend = crate::autodiff::ADBackendDecorator>; +#[cfg(feature = "candle")] +/// Candle module. +pub use burn_candle as candle; + +#[cfg(feature = "candle")] +/// A CandleBackend with a default type of f32/i64. +pub type CandleBackend = candle::CandleBackend; + +#[cfg(all(feature = "candle", feature = "autodiff"))] +/// A CandleBackend with autodiffing enabled. +pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator; + #[cfg(feature = "tch")] /// Tch module. pub use burn_tch as tch; diff --git a/burn-wgpu/README.md b/burn-wgpu/README.md index 51368b2c3..e0b7286b8 100644 --- a/burn-wgpu/README.md +++ b/burn-wgpu/README.md @@ -5,7 +5,7 @@ [![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md) -This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the +This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the [wgpu](https://github.com/gfx-rs/wgpu). The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU. diff --git a/burn/Cargo.toml b/burn/Cargo.toml index f704d15d7..04dd15f6f 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"] wgpu = ["burn-core/wgpu"] tch = ["burn-core/tch"] +candle = ["burn-core/candle"] # Experimental experimental-named-tensor = ["burn-core/experimental-named-tensor"] @@ -53,8 +54,8 @@ experimental-named-tensor = ["burn-core/experimental-named-tensor"] # ** Please make sure all dependencies support no_std when std is disabled ** -burn-core = {path = "../burn-core", version = "0.10.0", default-features = false} -burn-train = {path = "../burn-train", version = "0.10.0", optional = true, default-features = false } +burn-core = { path = "../burn-core", version = "0.10.0", default-features = false } +burn-train = { path = "../burn-train", version = "0.10.0", optional = true, default-features = false } [package.metadata.docs.rs] all-features = true diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index e0c49cddb..a896b2097 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -16,9 +16,9 @@ ndarray = ["burn/ndarray-no-std"] wgpu = ["burn/wgpu"] [dependencies] -burn = {path = "../../burn", default-features = false} -serde = {workspace = true} -wasm-bindgen = { version = "0.2.87" } +burn = { path = "../../burn", default-features = false } +serde = { workspace = true } +wasm-bindgen = { version = "0.2.87" } wasm-bindgen-futures = "0.4" js-sys = "0.3.64"