make candle available (#886)

This commit is contained in:
Louis Fortier-Dubois 2023-10-23 10:00:39 -04:00 committed by GitHub
parent 07c0cf146d
commit e4d9d67526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 34 additions and 7 deletions

View File

@ -6,6 +6,7 @@ resolver = "2"
members = [
"burn",
"burn-autodiff",
"burn-candle",
"burn-common",
"burn-compute",
"burn-core",

View File

@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models.
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
browser-inclusive, GPU-based computations 🌐
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
differentiability across all backends 🌟
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range

View File

@ -12,6 +12,7 @@ version = "0.10.0"
[features]
default = ["std"]
std = []
candle = ["burn/candle"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]

View File

@ -40,5 +40,14 @@ macro_rules! bench_on_backend {
let device = NdArrayDevice::Cpu;
bench::<NdArrayBackend>(&device);
}
#[cfg(feature = "candle")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
let device = CandleDevice::Cpu;
bench::<CandleBackend>(&device);
}
};
}

View File

@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"]
tch = ["burn-tch"]
candle = ["burn-candle"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true,
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true }
derive-new = { workspace = true }
libm = { workspace = true }

View File

@ -23,6 +23,18 @@ pub type WgpuBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = wgpu::WgpuBa
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;
#[cfg(feature = "candle")]
/// Candle module.
pub use burn_candle as candle;
#[cfg(feature = "candle")]
/// A CandleBackend with a default type of f32/i64.
pub type CandleBackend = candle::CandleBackend<f32, i64>;
#[cfg(all(feature = "candle", feature = "autodiff"))]
/// A CandleBackend with autodiffing enabled.
pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator<CandleBackend>;
#[cfg(feature = "tch")]
/// Tch module.
pub use burn_tch as tch;

View File

@ -5,7 +5,7 @@
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the
[wgpu](https://github.com/gfx-rs/wgpu).
The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.

View File

@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]
# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
@ -53,8 +54,8 @@ experimental-named-tensor = ["burn-core/experimental-named-tensor"]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-core = {path = "../burn-core", version = "0.10.0", default-features = false}
burn-train = {path = "../burn-train", version = "0.10.0", optional = true, default-features = false }
burn-core = { path = "../burn-core", version = "0.10.0", default-features = false }
burn-train = { path = "../burn-train", version = "0.10.0", optional = true, default-features = false }
[package.metadata.docs.rs]
all-features = true

View File

@ -16,9 +16,9 @@ ndarray = ["burn/ndarray-no-std"]
wgpu = ["burn/wgpu"]
[dependencies]
burn = {path = "../../burn", default-features = false}
serde = {workspace = true}
wasm-bindgen = { version = "0.2.87" }
burn = { path = "../../burn", default-features = false }
serde = { workspace = true }
wasm-bindgen = { version = "0.2.87" }
wasm-bindgen-futures = "0.4"
js-sys = "0.3.64"