mirror of https://github.com/tracel-ai/burn.git
Remote Backend (#2463)
parent 9b9b03c959
commit 099b6dcae0
@@ -229,6 +229,17 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "async-trait"
+version = "0.1.83"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
+
 [[package]]
 name = "atoi"
 version = "2.0.0"

@@ -296,6 +307,64 @@ dependencies = [
  "arrayvec",
 ]
 
+[[package]]
+name = "axum"
+version = "0.7.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
+dependencies = [
+ "async-trait",
+ "axum-core",
+ "base64 0.22.1",
+ "bytes",
+ "futures-util",
+ "http 1.1.0",
+ "http-body 1.0.1",
+ "http-body-util",
+ "hyper 1.5.0",
+ "hyper-util",
+ "itoa",
+ "matchit",
+ "memchr",
+ "mime",
+ "percent-encoding",
+ "pin-project-lite",
+ "rustversion",
+ "serde",
+ "serde_json",
+ "serde_path_to_error",
+ "serde_urlencoded",
+ "sha1",
+ "sync_wrapper 1.0.1",
+ "tokio",
+ "tokio-tungstenite",
+ "tower",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
+[[package]]
+name = "axum-core"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "futures-util",
+ "http 1.1.0",
+ "http-body 1.0.1",
+ "http-body-util",
+ "mime",
+ "pin-project-lite",
+ "rustversion",
+ "sync_wrapper 1.0.1",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
 [[package]]
 name = "backend-comparison"
 version = "0.16.0"

@@ -524,6 +593,7 @@ dependencies = [
  "indicatif",
  "rayon",
  "reqwest 0.12.9",
+ "serde",
  "tokio",
  "web-time",
 ]

@@ -542,6 +612,7 @@ dependencies = [
  "burn-derive",
  "burn-hip",
  "burn-ndarray",
+ "burn-remote",
  "burn-tch",
  "burn-tensor",
  "burn-wgpu",

@@ -647,6 +718,7 @@ dependencies = [
  "derive-new 0.7.0",
  "half",
  "log",
+ "paste",
 ]
 
 [[package]]

@@ -729,15 +801,38 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "burn-remote"
+version = "0.16.0"
+dependencies = [
+ "axum",
+ "burn-common",
+ "burn-remote",
+ "burn-router",
+ "burn-tensor",
+ "derive-new 0.7.0",
+ "futures-util",
+ "log",
+ "rmp-serde",
+ "serde",
+ "serde_bytes",
+ "tokio",
+ "tokio-tungstenite",
+ "tracing-core",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "burn-router"
 version = "0.16.0"
 dependencies = [
  "burn-autodiff",
+ "burn-common",
  "burn-ndarray",
  "burn-tensor",
  "burn-wgpu",
  "hashbrown 0.15.0",
+ "log",
  "spin",
 ]
 

@@ -3149,6 +3244,7 @@ dependencies = [
  "http 1.1.0",
  "http-body 1.0.1",
  "httparse",
+ "httpdate",
  "itoa",
  "pin-project-lite",
  "smallvec",

@@ -3693,6 +3789,12 @@ dependencies = [
  "regex-automata 0.1.10",
 ]
 
+[[package]]
+name = "matchit"
+version = "0.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
+
 [[package]]
 name = "matrixmultiply"
 version = "0.3.9"

@@ -6098,6 +6200,16 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "serde_path_to_error"
+version = "0.1.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
+dependencies = [
+ "itoa",
+ "serde",
+]
+
 [[package]]
 name = "serde_rusqlite"
 version = "0.36.0"

@@ -6154,6 +6266,14 @@ dependencies = [
  "syn 2.0.87",
 ]
 
+[[package]]
+name = "server"
+version = "0.16.0"
+dependencies = [
+ "burn",
+ "cfg-if",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.6"

@@ -6874,6 +6994,18 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "tokio-tungstenite"
+version = "0.24.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
+dependencies = [
+ "futures-util",
+ "log",
+ "tokio",
+ "tungstenite",
+]
+
 [[package]]
 name = "tokio-util"
 version = "0.7.12"

@@ -6936,6 +7068,28 @@ dependencies = [
  "zip 0.6.6",
 ]
 
+[[package]]
+name = "tower"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
+dependencies = [
+ "futures-core",
+ "futures-util",
+ "pin-project-lite",
+ "sync_wrapper 0.1.2",
+ "tokio",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
+[[package]]
+name = "tower-layer"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
+
 [[package]]
 name = "tower-service"
 version = "0.3.3"

@@ -6978,6 +7132,7 @@ version = "0.1.40"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
 dependencies = [
+ "log",
 "pin-project-lite",
 "tracing-attributes",
 "tracing-core",

@@ -7051,6 +7206,24 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "tungstenite"
+version = "0.24.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
+dependencies = [
+ "byteorder",
+ "bytes",
+ "data-encoding",
+ "http 1.1.0",
+ "httparse",
+ "log",
+ "rand",
+ "sha1",
+ "thiserror",
+ "utf-8",
+]
+
 [[package]]
 name = "typenum"
 version = "1.17.0"

@@ -7174,6 +7347,12 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
 [[package]]
 name = "utf8parse"
 version = "0.2.2"

@@ -119,6 +119,7 @@ bincode = { version = "2.0.0-rc.3", features = [
 # The following packages disable the "std" feature for no_std compatibility
 #
 derive-new = { version = "0.7.0", default-features = false }
+cfg-if = "1.0.0"
 
 blas-src = { version = "0.10.0", default-features = false }
 half = { version = "2.4.1", features = [

@@ -23,6 +23,8 @@ getrandom = { workspace = true, features = ["js"] }
 web-time = { version = "1.1.0" }
 
 [dependencies]
+serde = { workspace = true }
+
 # Network downloader
 indicatif = { workspace = true, optional = true }
 reqwest = { workspace = true, optional = true }

@@ -1,4 +1,5 @@
 use crate::rand::gen_random;
+use serde::{Deserialize, Serialize};
 
 /// Simple ID generator.
 pub struct IdGenerator {}

@@ -64,3 +65,49 @@ mod tests {
         assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
     }
 }
+
+/// Unique identifier that can represent a stream based on the current thread id.
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
+pub struct StreamId {
+    /// The value representing the thread id.
+    pub value: u64,
+}
+
+impl StreamId {
+    /// Get the current thread id.
+    pub fn current() -> Self {
+        Self {
+            #[cfg(feature = "std")]
+            value: Self::from_current_thread(),
+            #[cfg(not(feature = "std"))]
+            value: 0,
+        }
+    }
+
+    #[cfg(feature = "std")]
+    fn from_current_thread() -> u64 {
+        use core::hash::Hash;
+
+        std::thread_local! {
+            static ID: std::cell::OnceCell::<u64> = const { std::cell::OnceCell::new() };
+        };
+
+        // Getting the current thread is expensive, so we cache the value into a thread local
+        // variable, which is very fast.
+        ID.with(|cell| {
+            *cell.get_or_init(|| {
+                // A way to get a thread id encoded as u64.
+                let mut hasher = std::hash::DefaultHasher::default();
+                let id = std::thread::current().id();
+                id.hash(&mut hasher);
+                std::hash::Hasher::finish(&hasher)
+            })
+        })
+    }
+}
+
+impl core::fmt::Display for StreamId {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_fmt(format_args!("StreamId({:?})", self.value))
+    }
+}

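As an aside on the new `StreamId` above, here is a small illustrative sketch (not part of this commit) of the behavior it is meant to provide: repeated calls on one thread agree because the value is cached in a thread-local, while another thread observes a different id. The function name `stream_id_demo` is made up for the example.

```rust
use burn_common::id::StreamId;

fn stream_id_demo() {
    // Same thread: the cached thread-local value is returned both times.
    let a = StreamId::current();
    let b = StreamId::current();
    assert_eq!(a, b);

    // A spawned thread hashes a different ThreadId, so it gets another StreamId.
    let other = std::thread::spawn(StreamId::current).join().unwrap();
    assert_ne!(a, other);
}
```
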
@@ -39,6 +39,8 @@ doc = [
     "hip-jit",
     "vision",
     "autodiff",
+    "remote",
+    "server",
     # Doc features
     "burn-candle/doc",
     "burn-common/doc",

@@ -86,6 +88,8 @@ metal = ["burn-candle?/metal"]
 openblas = ["burn-ndarray?/blas-openblas"]
 openblas-system = ["burn-ndarray?/blas-openblas-system"]
 template = ["burn-wgpu?/template"]
+remote = ["burn-remote/client"]
+server = ["burn-remote/server"]
 
 candle = ["burn-candle"]
 candle-cuda = ["candle", "burn-candle/cuda"]

@@ -131,6 +135,7 @@ burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-
 burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
 burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
 burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
+burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
 
 data-encoding = { workspace = true }
 uuid = { workspace = true }

@@ -7,6 +7,11 @@ pub use ndarray::NdArray;
 #[cfg(feature = "autodiff")]
 pub use burn_autodiff as autodiff;
 
+#[cfg(feature = "remote")]
+pub use burn_remote as remote;
+#[cfg(feature = "remote")]
+pub use burn_remote::RemoteBackend;
+
 #[cfg(feature = "autodiff")]
 pub use burn_autodiff::Autodiff;
 

@@ -43,6 +43,9 @@ pub mod tensor;
 /// Backend module.
 pub mod backend;
 
+#[cfg(feature = "server")]
+pub use burn_remote::server;
+
 extern crate alloc;
 
 #[cfg(all(

@@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
 version.workspace = true
 
 [features]
+default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
 autotune = ["burn-jit/autotune"]
-default = ["fusion", "burn-jit/default", "cubecl/default"]
 doc = ["burn-jit/doc"]
 fusion = ["burn-fusion", "burn-jit/fusion"]
 std = ["burn-jit/std", "cubecl/std"]

@@ -34,6 +34,7 @@ derive-new = { workspace = true }
 burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
     "export_tests",
 ] }
+paste = { workspace = true }
 
 [package.metadata.docs.rs]
 features = ["doc"]

@@ -24,6 +24,7 @@ mod tests {
     use burn_jit::JitBackend;
 
     pub type TestRuntime = cubecl::hip::HipRuntime;
+    pub use half::{bf16, f16};
 
     burn_jit::testgen_all!();
 }

@@ -0,0 +1,51 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Backend router decorator over websocket."
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-remote"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote"
documentation = "https://docs.rs/burn-router-remote"
version.workspace = true

[features]
default = []
doc = []
client = ["tokio-tungstenite"]
server = ["axum", "tracing-core", "tracing-subscriber"]


[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]}
burn-common = { path = "../burn-common", version = "0.16.0", default-features = true}
burn-router = { path = "../burn-router", version = "0.16.0", default-features = true}

# Basic dependencies
derive-new = {workspace = true }
log = { workspace = true }

# Shared dependencies
tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] }
serde = { workspace = true, features = ["derive"] }
serde_bytes = { workspace = true }
rmp-serde = { workspace = true }
futures-util = { version = "0.3" }

# Client dependencies
tokio-tungstenite = { version = "0.24", optional = true }

# Server dependencies
axum = { version = "0.7.5", features = ["ws"], optional = true }
tracing-core = { workspace = true, optional = true }
tracing-subscriber = { workspace = true, optional = true }

[dev-dependencies]
# We activate the features client and server during dev.
burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] }

[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

@@ -0,0 +1,99 @@
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent};
use burn_common::id::StreamId;
use burn_tensor::repr::TensorId;
use std::{
    future::Future,
    sync::{atomic::AtomicU64, Arc},
};
use tokio::sync::mpsc::Sender;

pub use super::WsDevice;

#[derive(Clone)]
pub struct WsClient {
    pub(crate) device: WsDevice,
    pub(crate) sender: Arc<WsSender>,
    pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}

impl WsClient {
    pub fn init(device: WsDevice) -> Self {
        ClientWorker::start(device)
    }

    pub(crate) fn new(
        device: WsDevice,
        sender: Sender<ClientRequest>,
        runtime: Arc<tokio::runtime::Runtime>,
    ) -> Self {
        Self {
            device,
            runtime,
            sender: Arc::new(WsSender {
                sender,
                position_counter: AtomicU64::new(0),
                tensor_id_counter: AtomicU64::new(0),
            }),
        }
    }
}

pub(crate) struct WsSender {
    sender: Sender<ClientRequest>,
    position_counter: AtomicU64,
    tensor_id_counter: AtomicU64,
}

impl WsSender {
    pub(crate) fn send(&self, task: ComputeTask) -> impl Future<Output = ()> + Send {
        let position = self
            .position_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let stream_id = StreamId::current();
        let sender = self.sender.clone();

        async move {
            sender
                .send(ClientRequest::WithoutCallback(Task::Compute(
                    task,
                    ConnectionId::new(position, stream_id),
                )))
                .await
                .unwrap();
        }
    }

    pub(crate) fn new_tensor_id(&self) -> TensorId {
        let val = self
            .tensor_id_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        TensorId::new(val)
    }
    pub(crate) fn send_callback(
        &self,
        task: ComputeTask,
    ) -> impl Future<Output = TaskResponseContent> + Send {
        let position = self
            .position_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let stream_id = StreamId::current();
        let sender = self.sender.clone();
        let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1);

        async move {
            sender
                .send(ClientRequest::WithSyncCallback(
                    Task::Compute(task, ConnectionId::new(position, stream_id)),
                    callback_sender,
                ))
                .await
                .unwrap();

            match callback_recv.recv().await {
                Some(val) => val,
                None => panic!(""),
            }
        }
    }
}

@@ -0,0 +1,45 @@
use burn_router::{RouterTensor, RunnerChannel, TensorHandle};
use burn_tensor::repr::TensorDescription;

use super::{
    runner::{WsBridge, WsDevice},
    WsClient,
};

/// A local channel with direct connection to the backend runner clients.
#[derive(Clone)]
pub struct WsChannel;

impl RunnerChannel for WsChannel {
    type Device = WsDevice;
    type Bridge = WsBridge;
    type Client = WsClient;

    type FloatElem = f32;

    type IntElem = i32;

    fn name() -> String {
        "remote".into()
    }

    fn init_client(device: &Self::Device) -> Self::Client {
        WsClient::init(device.clone())
    }

    fn get_tensor_handle(
        _tensor: &TensorDescription,
        _client: &Self::Client,
    ) -> TensorHandle<Self::Bridge> {
        panic!("Unsupported")
    }

    fn register_tensor(
        _client: &Self::Client,
        _handle: TensorHandle<Self::Bridge>,
        _shape: Vec<usize>,
        _dtype: burn_tensor::DType,
    ) -> RouterTensor<Self::Client> {
        panic!("Unsupported")
    }
}

@@ -0,0 +1,8 @@
mod base;
mod channel;
mod runner;
mod worker;

pub use base::*;
pub use channel::*;
pub use runner::WsDevice;

@@ -0,0 +1,175 @@
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient};
use burn_tensor::{
    backend::{DeviceId, DeviceOps},
    DType, TensorData,
};
use std::sync::Arc;

use crate::shared::{ComputeTask, TaskResponseContent};

use super::WsClient;

// It is very important to block on any request made with the sender, since ordering is crucial
// when registering operations or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be sent to the async
// channel, but not sent to the websocket server and even less processed by the server.
impl RunnerClient for WsClient {
    type Device = WsDevice;

    fn register(&self, op: burn_tensor::repr::OperationDescription) {
        let fut = self
            .sender
            .send(ComputeTask::RegisterOperation(Box::new(op)));
        self.runtime.block_on(fut);
    }

    fn read_tensor(
        &self,
        tensor: burn_tensor::repr::TensorDescription,
    ) -> impl std::future::Future<Output = TensorData> + Send {
        // Important for ordering to call the creation of the future sync.
        let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));

        async move {
            match fut.await {
                TaskResponseContent::ReadTensor(data) => data,
                _ => panic!("Invalid message type"),
            }
        }
    }

    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
        let id = self.sender.new_tensor_id();
        let shape = data.shape.clone();
        let dtype = data.dtype;

        let fut = self.sender.send(ComputeTask::RegisterTensor(id, data));

        self.runtime.block_on(fut);

        RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
    }

    fn register_empty_tensor(
        &self,
        shape: Vec<usize>,
        dtype: burn_tensor::DType,
    ) -> RouterTensor<Self> {
        let id = self.sender.new_tensor_id();

        RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
    }

    fn register_float_tensor(
        &self,
        shape: Vec<usize>,
        _full_precision: bool,
    ) -> RouterTensor<Self> {
        self.register_empty_tensor(shape, DType::F32)
    }

    fn device(&self) -> Self::Device {
        self.device.clone()
    }

    fn register_orphan(&self, id: &burn_tensor::repr::TensorId) {
        let fut = self.sender.send(ComputeTask::RegisterOrphan(*id));
        self.runtime.block_on(fut);
    }

    fn sync(&self) {
        // Important for ordering to call the creation of the future sync.
        let fut = self.sender.send_callback(ComputeTask::SyncBackend);

        let fut = async move {
            match fut.await {
                TaskResponseContent::SyncBackend => {}
                _ => panic!("Invalid message type"),
            };
        };

        self.runtime.block_on(fut)
    }

    fn seed(&self, _seed: u64) {
        // TODO
    }
}

#[derive(Clone, PartialEq, Eq, Debug)]
/// The device contains the connection information of the server.
pub struct WsDevice {
    pub(crate) address: Arc<String>,
}

impl WsDevice {
    /// Create a device from a url.
    pub fn new(url: &str) -> Self {
        let mut address = String::new();

        if !url.starts_with("ws://") {
            address += "ws://";
            address += url;
        } else {
            address += url;
        };

        Self {
            address: Arc::new(address),
        }
    }
}

impl Default for WsDevice {
    fn default() -> Self {
        let address = match std::env::var("BURN_REMOTE_ADDRESS") {
            Ok(address) => address,
            Err(_) => String::from("ws://127.0.0.1:3000"),
        };

        Self {
            address: Arc::new(address),
        }
    }
}

impl DeviceOps for WsDevice {
    fn id(&self) -> DeviceId {
        DeviceId {
            type_id: 0,
            index_id: 0,
        }
    }
}

pub struct WsBridge;

impl MultiBackendBridge for WsBridge {
    type TensorHandle = TensorData;
    type Device = WsDevice;

    fn change_backend_float(
        tensor: Self::TensorHandle,
        _shape: burn_tensor::Shape,
        _target_device: &Self::Device,
    ) -> Self::TensorHandle {
        tensor
    }

    fn change_backend_int(
        tensor: Self::TensorHandle,
        _shape: burn_tensor::Shape,
        _target_device: &Self::Device,
    ) -> Self::TensorHandle {
        tensor
    }

    fn change_backend_bool(
        tensor: Self::TensorHandle,
        _shape: burn_tensor::Shape,
        _target_device: &Self::Device,
    ) -> Self::TensorHandle {
        tensor
    }
}

@@ -0,0 +1,140 @@
use super::{runner::WsDevice, WsClient};
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
use futures_util::{SinkExt, StreamExt};
use std::{collections::HashMap, sync::Arc};
use tokio_tungstenite::{
    connect_async_with_config,
    tungstenite::protocol::{Message, WebSocketConfig},
};

pub type CallbackSender = tokio::sync::mpsc::Sender<TaskResponseContent>;

pub enum ClientRequest {
    WithSyncCallback(Task, CallbackSender),
    WithoutCallback(Task),
}

#[derive(Default)]
pub(crate) struct ClientWorker {
    requests: HashMap<ConnectionId, CallbackSender>,
}

impl ClientWorker {
    async fn on_response(&mut self, response: TaskResponse) {
        match self.requests.remove(&response.id) {
            Some(request) => {
                request.send(response.content).await.unwrap();
            }
            None => {
                panic!("Can't ignore message from the server.");
            }
        }
    }

    fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
        self.requests.insert(id, callback);
    }
}

impl ClientWorker {
    pub fn start(device: WsDevice) -> WsClient {
        let runtime = Arc::new(
            tokio::runtime::Builder::new_multi_thread()
                .enable_io()
                .build()
                .unwrap(),
        );

        let (sender, mut rec) = tokio::sync::mpsc::channel(10);
        let address_request = format!("{}/{}", device.address.as_str(), "request");
        let address_response = format!("{}/{}", device.address.as_str(), "response");

        const MB: usize = 1024 * 1024;

        #[allow(deprecated)]
        runtime.spawn(async move {
            log::info!("Connecting to {address_request} ...");
            let (mut stream_request, _) = connect_async_with_config(
                address_request.clone(),
                Some(WebSocketConfig {
                    max_send_queue: None,
                    write_buffer_size: 0,
                    max_write_buffer_size: usize::MAX,
                    max_message_size: None,
                    max_frame_size: Some(MB * 512),
                    accept_unmasked_frames: true,
                }),
                true,
            )
            .await
            .expect("Failed to connect");
            let (mut stream_response, _) = connect_async_with_config(
                address_response,
                Some(WebSocketConfig {
                    max_send_queue: None,
                    write_buffer_size: 0,
                    max_write_buffer_size: usize::MAX,
                    max_message_size: None,
                    max_frame_size: Some(MB * 512),
                    accept_unmasked_frames: true,
                }),
                true,
            )
            .await
            .expect("Failed to connect");

            let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::default()));

            // Init the connection.
            let session_id = SessionId::new();
            let bytes = rmp_serde::to_vec(&Task::Init(session_id)).expect("Can serialize tasks to bytes.");
            stream_request.send(Message::Binary(bytes.clone())).await.expect("Can send the message on the websocket.");
            stream_response.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");

            // Websocket async worker loading callback from the server.
            let state_ws = state.clone();
            tokio::spawn(async move {
                while let Some(msg) = stream_response.next().await {
                    let msg = match msg {
                        Ok(msg) => msg,
                        Err(err) => panic!("An error happened while receiving messages from the websocket: {err:?}"),
                    };

                    match msg {
                        Message::Binary(bytes) => {
                            let response: TaskResponse = rmp_serde::from_slice(&bytes).expect("Can deserialize messages from the websocket.");
                            let mut state = state_ws.lock().await;
                            state.on_response(response).await;
                        }
                        Message::Close(_) => {
                            log::warn!("Closed connection");
                            return;
                        },
                        _ => panic!("Unsupported websocket message: {msg:?}"),
                    };
                }
            });

            // Channel async worker sending operations to the server.
            tokio::spawn(async move {
                while let Some(req) = rec.recv().await {
                    let task = match req {
                        ClientRequest::WithSyncCallback(task, callback) => {
                            let mut state = state.lock().await;
                            if let Task::Compute(_content, id) = &task {
                                state.register_callback(*id, callback);
                            }
                            task
                        }
                        ClientRequest::WithoutCallback(task) => task,

                    };
                    let bytes = rmp_serde::to_vec(&task).expect("Can serialize tasks to bytes.");
                    stream_request.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");
                }
            });
        });

        WsClient::new(device, sender, runtime)
    }
}

@@ -0,0 +1,37 @@
#[macro_use]
extern crate derive_new;

#[cfg(feature = "client")]
pub(crate) mod client;

#[cfg(feature = "server")]
pub mod server;

pub(crate) mod shared;

#[cfg(feature = "client")]
mod __client {
    use super::*;

    use burn_router::BackendRouter;
    use client::WsChannel;

    /// The remote backend allows you to run computation on a remote device.
    ///
    /// Make sure there is a running server before trying to connect to it.
    ///
    /// ```rust, ignore
    /// fn main() {
    ///     let device = Default::default();
    ///     let port = 3000;
    ///
    ///     // You need to activate the `server` feature flag to have access to this function.
    ///     burn::server::start::<burn::backend::Wgpu>(device, port);
    /// }
    ///```
    pub type RemoteBackend = BackendRouter<WsChannel>;

    pub use client::WsDevice as RemoteDevice;
}
#[cfg(feature = "client")]
pub use __client::*;

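Complementing the doc comment above, here is a hedged end-to-end sketch of how the two new feature flags are meant to be combined, based on the re-exports added to the `burn` crate earlier in this diff. The choice of the Wgpu backend, the port, the tensor shape, and the function names are illustrative assumptions, not part of this commit.

```rust
// Illustrative only. Server process: requires the `server` feature on `burn`.
fn serve() {
    let device = Default::default();
    // Exposes a local Wgpu backend over websocket on port 3000.
    burn::server::start::<burn::backend::Wgpu>(device, 3000);
}

// Client process: requires the `remote` feature on `burn`.
fn compute_remotely() {
    use burn::backend::{remote::RemoteDevice, RemoteBackend};
    use burn::tensor::Tensor;

    // Falls back to ws://127.0.0.1:3000 when BURN_REMOTE_ADDRESS is not set.
    let device = RemoteDevice::default();
    let tensor = Tensor::<RemoteBackend, 2>::ones([2, 2], &device);
    println!("{}", tensor + 1.0);
}
```
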
@@ -0,0 +1,196 @@
use std::{net::SocketAddr, sync::Arc};

use axum::{
    extract::{
        ws::{self, WebSocket, WebSocketUpgrade},
        State,
    },
    response::IntoResponse,
    routing::any,
    Router,
};

use burn_tensor::{
    backend::{Backend, BackendBridge},
    repr::ReprBackend,
    Device,
};
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::prelude::*;
use tracing_subscriber::{filter::filter_fn, registry};

use crate::shared::{ComputeTask, Task};

use super::session::SessionManager;

#[derive(Clone)]
pub struct WsServer<B: ReprBackend> {
    state: Arc<SessionManager<B>>,
}

impl<B: ReprBackend> WsServer<B>
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    /// Start the server on the given address.
    pub async fn start(device: Device<B>, port: u16) {
        let layer = tracing_subscriber::fmt::layer()
            .with_filter(LevelFilter::INFO)
            .with_filter(filter_fn(|m| {
                if let Some(path) = m.module_path() {
                    // The wgpu crate is logging too much, so we skip `info` level.
                    if path.starts_with("wgpu") && *m.level() >= Level::INFO {
                        return false;
                    }
                }
                true
            }));
        registry().with(layer).init();

        let address = format!("0.0.0.0:{port}");
        log::info!("Start server {address} on device {device:?}");

        let state = SessionManager::<B>::new(device);
        let state = Self {
            state: Arc::new(state),
        };

        // build our application with some routes
        let app = Router::new()
            .route("/response", any(Self::handler_response))
            .route("/request", any(Self::handler_request))
            .with_state(state);

        // run it with hyper
        let listener = tokio::net::TcpListener::bind(address).await.unwrap();
        axum::serve(
            listener,
            app.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .await
        .unwrap();
    }

    async fn handler_response(
        ws: WebSocketUpgrade,
        State(state): State<Self>,
    ) -> impl IntoResponse {
        ws.on_upgrade(move |socket| state.handle_socket_response(socket))
    }

    async fn handler_request(ws: WebSocketUpgrade, State(state): State<Self>) -> impl IntoResponse {
        ws.on_upgrade(move |socket| state.handle_socket_request(socket))
    }

    async fn handle_socket_response(self, mut socket: WebSocket) {
        log::info!("[Response Handler] On new connection.");

        let packet = socket.recv().await;
        let msg = match packet {
            Some(msg) => msg,
            None => {
                log::info!("Still no message");
                panic!("");
            }
        };

        if let Ok(ws::Message::Binary(bytes)) = msg {
            let task = match rmp_serde::from_slice::<Task>(&bytes) {
                Ok(val) => val,
                Err(err) => {
                    log::info!("Only bytes messages are supported {err:?}");
                    panic!("");
                }
            };
            let id = match task {
                Task::Init(id) => id,
                _ => panic!(""),
            };

            let receiver = self.state.register_responder(id).await;

            log::info!("Response handler connection active");

            while let Ok(callback) = receiver.recv() {
                let response = callback.recv().unwrap();
                let bytes = rmp_serde::to_vec(&response).unwrap();

                socket.send(ws::Message::Binary(bytes)).await.unwrap();
            }
        } else {
            panic!("");
        }
    }

    async fn handle_socket_request(self, mut socket: WebSocket) {
        log::info!("[Request Handler] On new connection.");
        let mut session_id = None;

        loop {
            let packet = socket.recv().await;
            let msg = match packet {
                Some(msg) => msg,
                None => {
                    log::info!("Still no message");
                    continue;
                }
            };

            if let Ok(ws::Message::Binary(bytes)) = msg {
                let task = match rmp_serde::from_slice::<Task>(&bytes) {
                    Ok(val) => val,
                    Err(err) => {
                        log::info!("Only bytes message in the json format are supported {err:?}");
                        break;
                    }
                };

                let (stream, connection_id, task) =
                    match self.state.stream(&mut session_id, task).await {
                        Some(val) => val,
                        None => {
                            log::info!("Ops session activated {session_id:?}");
                            continue;
                        }
                    };

                match task {
                    ComputeTask::RegisterOperation(op) => {
                        stream.register_operation(op);
                    }
                    ComputeTask::RegisterTensor(id, data) => {
                        stream.register_tensor(id, data);
                    }
                    ComputeTask::RegisterOrphan(id) => {
                        stream.register_orphan(id);
                    }
                    ComputeTask::ReadTensor(tensor) => {
                        stream.read_tensor(connection_id, tensor);
                    }
                    ComputeTask::SyncBackend => {
                        stream.sync(connection_id);
                    }
                }
            } else {
                log::info!("Not a binary message, closing, received {msg:?}");
                break;
            };
        }

        log::info!("Closing connection");
        self.state.close(session_id).await;
    }
}

#[tokio::main]
/// Start the server on the given port and [device](Device).
pub async fn start<B: ReprBackend>(device: Device<B>, port: u16)
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    WsServer::<B>::start(device, port).await;
}

@@ -0,0 +1,7 @@
pub(crate) mod processor;
pub(crate) mod session;
pub(crate) mod stream;

mod base;

pub use base::start;

@@ -0,0 +1,84 @@
use burn_router::{Runner, RunnerClient};
use burn_tensor::{
    backend::{Backend, BackendBridge},
    repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
    TensorData,
};
use core::marker::PhantomData;
use std::sync::mpsc::Sender;

use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent};

/// The goal of the processor is to asynchronously process compute tasks on its own thread.
pub struct Processor<B: ReprBackend> {
    p: PhantomData<B>,
}

pub type Callback<M> = Sender<M>;

pub enum ProcessorTask {
    RegisterOperation(Box<OperationDescription>),
    RegisterTensor(TensorId, TensorData),
    ReadTensor(ConnectionId, TensorDescription, Callback<TaskResponse>),
    Sync(ConnectionId, Callback<TaskResponse>),
    Fence(Callback<()>),
    RegisterOrphan(TensorId),
    Close,
}

impl<B: ReprBackend> Processor<B>
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    pub fn start(runner: Runner<B>) -> Sender<ProcessorTask> {
        let (sender, rec) = std::sync::mpsc::channel();

        std::thread::spawn(move || {
            for item in rec.iter() {
                match item {
                    ProcessorTask::RegisterOperation(op) => {
                        runner.register(*op);
                    }
                    ProcessorTask::RegisterOrphan(id) => {
                        runner.register_orphan(&id);
                    }
                    ProcessorTask::Sync(id, callback) => {
                        runner.sync();
                        callback
                            .send(TaskResponse {
                                content: TaskResponseContent::SyncBackend,
                                id,
                            })
                            .unwrap();
                    }
                    ProcessorTask::RegisterTensor(id, data) => {
                        runner.register_tensor_data_id(id, data);
                    }
                    ProcessorTask::ReadTensor(id, tensor, callback) => {
                        let tensor = burn_common::future::block_on(runner.read_tensor(tensor));
                        callback
                            .send(TaskResponse {
                                content: TaskResponseContent::ReadTensor(tensor),
                                id,
                            })
                            .unwrap();
                    }
                    ProcessorTask::Close => {
                        let device = runner.device();
                        runner.sync();
                        core::mem::drop(runner);
                        B::sync(&device);
                        return;
                    }
                    ProcessorTask::Fence(sender) => {
                        sender.send(()).unwrap();
                    }
                }
            }
        });

        sender
    }
}

@@ -0,0 +1,242 @@
use burn_common::id::StreamId;
use burn_router::Runner;
use burn_tensor::{
    backend::{Backend, BackendBridge},
    repr::{ReprBackend, TensorDescription, TensorId, TensorStatus},
    Device,
};
use std::{
    collections::HashMap,
    sync::mpsc::{Receiver, Sender},
};
use tokio::sync::Mutex;

use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};

use super::stream::Stream;

/// A session manager controls the creation of sessions.
///
/// Each session manages its own stream, spawning one thread per stream to mimic the same behavior
/// a native backend would have.
pub struct SessionManager<B: ReprBackend> {
    runner: Runner<B>,
    sessions: tokio::sync::Mutex<HashMap<SessionId, Session<B>>>,
}

struct Session<B: ReprBackend> {
    runner: Runner<B>,
    tensors: HashMap<TensorId, Vec<StreamId>>,
    streams: HashMap<StreamId, Stream<B>>,
    sender: Sender<Receiver<TaskResponse>>,
    receiver: Option<Receiver<Receiver<TaskResponse>>>,
}

impl<B: ReprBackend> SessionManager<B>
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    pub fn new(device: Device<B>) -> Self {
        Self {
            runner: Runner::new(device),
            sessions: Mutex::new(Default::default()),
        }
    }

    /// Register a new responder for the session. Only one responder can exist for a session for
    /// now.
    pub async fn register_responder(
        &self,
        session_id: SessionId,
    ) -> Receiver<Receiver<TaskResponse>> {
        log::info!("Register responder for session {session_id}");
        let mut sessions = self.sessions.lock().await;
        self.register_session(&mut sessions, session_id);

        let session = sessions.get_mut(&session_id).unwrap();
        session.init_responder()
    }

    /// Get the stream for the current session and task.
    pub async fn stream(
        &self,
        session_id: &mut Option<SessionId>,
        task: Task,
    ) -> Option<(Stream<B>, ConnectionId, ComputeTask)> {
        let mut sessions = self.sessions.lock().await;

        let session_id = match session_id {
            Some(id) => *id,
            None => match task {
                Task::Init(id) => {
                    log::info!("Init requester for session {id}");
                    *session_id = Some(id);
                    self.register_session(&mut sessions, id);
                    return None;
                }
                _ => panic!("The first message should initialize the session"),
            },
        };

        match sessions.get_mut(&session_id) {
            Some(session) => {
                let (task, connection_id) = match task {
                    Task::Compute(task, connection_id) => (task, connection_id),
                    _ => panic!("Only support compute tasks."),
                };
                let stream = session.select(connection_id.stream_id, &task);
                Some((stream, connection_id, task))
            }
            None => {
                panic!("To be initialized");
            }
        }
    }

    /// Close the session with the given id.
    pub async fn close(&self, session_id: Option<SessionId>) {
        if let Some(id) = session_id {
            let mut sessions = self.sessions.lock().await;
            if let Some(session) = sessions.get_mut(&id) {
                session.close();
            }
        }
    }

    fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B>>, id: SessionId) {
        sessions.entry(id).or_insert_with(|| {
            log::info!("Creating a new session {id}");

            Session::new(self.runner.clone())
        });
    }
}

impl<B: ReprBackend> Session<B>
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    fn new(runner: Runner<B>) -> Self {
        let (sender, receiver) = std::sync::mpsc::channel();
        Self {
            runner,
            tensors: Default::default(),
            streams: Default::default(),
            sender,
            receiver: Some(receiver),
        }
    }

    fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {
        let mut receiver = None;
        core::mem::swap(&mut receiver, &mut self.receiver);
        receiver.expect("Only one responder per session is possible.")
    }

    /// Select the current [stream](Stream) based on the given task.
    fn select(&mut self, stream_id: StreamId, task: &ComputeTask) -> Stream<B> {
        // We have to check every stream involved in the last operation, making
        // sure the backend is up-to-date with those operations.
        //
        // 1. We update the tensor status of all tensors in the task.
        // 2. We don't keep track of tensors that are used for the last time.
        let mut fences = Vec::new();
        for (tensor_id, status) in task.tensors_info() {
            let tensor_stream_ids = match self.tensors.get(&tensor_id) {
                Some(val) => val,
                None => {
                    if status != TensorStatus::ReadWrite {
                        // Add the first stream that created the tensor that may be used by other
                        // streams later.
                        self.register_tensor(tensor_id, stream_id);
                    }
                    continue;
                }
            };

            let current_stream_already_synced = tensor_stream_ids.contains(&stream_id);

            if !current_stream_already_synced {
                // We only need to sync to the first stream that created the tensor.
                if let Some(id) = tensor_stream_ids.iter().next() {
                    fences.push(*id);
                }
            }

            // We add the stream to the list of updated streams to avoid needing to flush other
            // operations that might use this tensor.
            self.register_tensor(tensor_id, stream_id);

            // If the tensor has the status `read_write`, it means no other stream can reuse it
            // afterward, so we remove it from the state.
            if status == TensorStatus::ReadWrite {
                self.tensors.remove(&tensor_id);
            }
        }

        // Cleanup orphans.
        if let ComputeTask::RegisterOrphan(tensor_id) = task {
            self.tensors.remove(tensor_id);
        }

        // We have to wait for the streams to be updated.
        for stream_id in fences {
            if let Some(stream) = self.streams.get(&stream_id) {
                stream.fence_sync();
            }
        }

        // We return the stream.
        match self.streams.get(&stream_id) {
            Some(stream) => stream.clone(),
            None => {
                let stream = Stream::<B>::new(self.runner.clone(), self.sender.clone());
                self.streams.insert(stream_id, stream.clone());
                stream
            }
        }
    }

    fn register_tensor(&mut self, tensor_id: TensorId, stream_id: StreamId) {
        match self.tensors.get_mut(&tensor_id) {
            Some(ids) => {
                ids.push(stream_id);
            }
            None => {
                self.tensors.insert(tensor_id, vec![stream_id]);
            }
        }
    }

    // Close all streams created in the session.
    fn close(&mut self) {
        for (id, stream) in self.streams.drain() {
            log::info!("Closing stream {id}");
            stream.close();
        }
    }
}

impl ComputeTask {
    fn tensors_info(&self) -> Vec<(TensorId, TensorStatus)> {
        fn from_descriptions(desc: &[&TensorDescription]) -> Vec<(TensorId, TensorStatus)> {
            desc.iter().map(|t| (t.id, t.status.clone())).collect()
        }

        match self {
            ComputeTask::RegisterOperation(op) => from_descriptions(&op.nodes()),
            ComputeTask::RegisterTensor(tensor_id, _tensor_data) => {
                vec![(*tensor_id, TensorStatus::NotInit)]
            }
            ComputeTask::RegisterOrphan(tensor_id) => {
                vec![(*tensor_id, TensorStatus::ReadWrite)]
            }
            ComputeTask::ReadTensor(tensor_description) => from_descriptions(&[tensor_description]),
            ComputeTask::SyncBackend => vec![],
        }
    }
}

@@ -0,0 +1,94 @@
use core::marker::PhantomData;
use std::sync::mpsc::{Receiver, Sender};

use crate::shared::{ConnectionId, TaskResponse};

use super::processor::{Processor, ProcessorTask};
use burn_router::Runner;
use burn_tensor::{
    backend::{Backend, BackendBridge},
    repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
    TensorData,
};

/// A stream makes sure all operations registered are executed in the order they were sent to the
/// server, potentially waiting to reconstruct consistency.
#[derive(Clone)]
pub struct Stream<B: ReprBackend> {
    compute_sender: Sender<ProcessorTask>,
    writer_sender: Sender<Receiver<TaskResponse>>,
    _p: PhantomData<B>,
}

impl<B: ReprBackend> Stream<B>
where
    // Restrict full precision backend handle to be the same
    <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
        ReprBackend<Handle = B::Handle>,
{
    pub fn new(runner: Runner<B>, writer_sender: Sender<Receiver<TaskResponse>>) -> Self {
        let sender = Processor::start(runner);

        Self {
            compute_sender: sender,
            writer_sender,
            _p: PhantomData,
        }
    }

    pub fn register_operation(&self, op: Box<OperationDescription>) {
        self.compute_sender
            .send(ProcessorTask::RegisterOperation(op))
            .unwrap();
    }

    pub fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {
        self.compute_sender
            .send(ProcessorTask::RegisterTensor(tensor_id, data))
            .unwrap()
    }

    pub fn register_orphan(&self, tensor_id: TensorId) {
        self.compute_sender
            .send(ProcessorTask::RegisterOrphan(tensor_id))
            .unwrap()
    }

    pub fn read_tensor(&self, id: ConnectionId, desc: TensorDescription) {
        let (callback_sender, callback_rec) = std::sync::mpsc::channel();

        self.compute_sender
            .send(ProcessorTask::ReadTensor(id, desc, callback_sender))
            .unwrap();

        self.writer_sender.send(callback_rec).unwrap();
    }

    pub fn sync(&self, id: ConnectionId) {
        let (callback_sender, callback_rec) = std::sync::mpsc::channel();

        self.compute_sender
            .send(ProcessorTask::Sync(id, callback_sender))
            .unwrap();

        self.writer_sender.send(callback_rec).unwrap();
    }

    // Ensure that all tasks are sent to the backend.
    //
    // It doesn't mean that the computation is done, but it means the backend has received the
    // tasks, which may be queued.
    pub fn fence_sync(&self) {
        let (callback_sender, callback_rec) = std::sync::mpsc::channel();

        self.compute_sender
            .send(ProcessorTask::Fence(callback_sender.clone()))
            .unwrap();

        callback_rec.recv().unwrap();
    }

    pub fn close(&self) {
        self.compute_sender.send(ProcessorTask::Close).unwrap();
    }
}
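Editorial aside, not part of this commit: the fence pattern used by `fence_sync` above can be illustrated with plain `std::sync::mpsc` channels. The caller pushes a marker task carrying a callback sender onto the same queue as the real work, then blocks on the callback receiver; because the worker drains the queue in order, receiving the acknowledgement guarantees every previously queued task has at least been received. The task enum and worker loop below are hypothetical stand-ins, not the burn-remote `Processor`.

use std::sync::mpsc;
use std::thread;

// Hypothetical task type standing in for `ProcessorTask`.
enum Task {
    Work(u32),
    // The fence carries a channel used to acknowledge that the worker reached it.
    Fence(mpsc::Sender<()>),
}

fn main() {
    let (sender, receiver) = mpsc::channel::<Task>();

    // Worker thread draining tasks in FIFO order, like the processor behind a `Stream`.
    let worker = thread::spawn(move || {
        for task in receiver {
            match task {
                Task::Work(n) => println!("processing task {n}"),
                Task::Fence(ack) => ack.send(()).unwrap(),
            }
        }
    });

    for n in 0..3 {
        sender.send(Task::Work(n)).unwrap();
    }

    // Fence: once recv() returns, the worker has received all tasks sent before it,
    // which mirrors the "received, possibly still queued" guarantee described above.
    let (ack_sender, ack_receiver) = mpsc::channel();
    sender.send(Task::Fence(ack_sender)).unwrap();
    ack_receiver.recv().unwrap();

    drop(sender);
    worker.join().unwrap();
}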
@@ -0,0 +1,2 @@
mod task;
pub(crate) use task::*;
@@ -0,0 +1,68 @@
use std::fmt::Display;

use burn_common::id::{IdGenerator, StreamId};
use burn_tensor::{
    repr::{OperationDescription, TensorDescription, TensorId},
    TensorData,
};
use serde::{Deserialize, Serialize};

#[allow(missing_docs)]
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ConnectionId {
    pub position: u64,
    pub stream_id: StreamId,
}

/// Unique identifier that can represent a session.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct SessionId {
    id: u64,
}

impl Display for SessionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "SessionId({})", self.id)
    }
}

impl SessionId {
    /// Create a new [session id](SessionId).
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            id: IdGenerator::generate(),
        }
    }
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum Task {
    Compute(ComputeTask, ConnectionId),
    Init(SessionId),
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum ComputeTask {
    RegisterOperation(Box<OperationDescription>),
    RegisterTensor(TensorId, TensorData),
    RegisterOrphan(TensorId),
    ReadTensor(TensorDescription),
    SyncBackend,
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub struct TaskResponse {
    pub content: TaskResponseContent,
    pub id: ConnectionId,
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponseContent {
    ReadTensor(TensorData),
    SyncBackend,
}
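Editorial illustration, not from this commit: because `Task` and `SessionId` derive `Serialize`/`Deserialize`, any serde-compatible format can carry them over the wire; the concrete transport and wire format used by the remote client and server are defined elsewhere in the commit and are not shown in this hunk. A minimal round-trip sketch, assuming the types above are importable from the crate's `shared` module and that `serde_json` is added as a dependency:

// Sketch only: `crate::shared` and `serde_json` are assumptions for illustration.
use crate::shared::{SessionId, Task};

fn round_trip() -> serde_json::Result<()> {
    // A client would typically open a session before sending compute tasks.
    let init = Task::Init(SessionId::new());

    let encoded = serde_json::to_string(&init)?;
    let decoded: Task = serde_json::from_str(&encoded)?;

    // `Task` derives `Debug`, so the decoded message can be inspected directly.
    println!("{decoded:?}");
    Ok(())
}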
@@ -13,14 +13,15 @@ version.workspace = true
 [features]
 default = ["std"]
-std = []
+std = ["burn-tensor/std", "burn-common/std"]
 doc = ["default"]

 [dependencies]
 burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]}
+burn-common = { path = "../burn-common", version = "0.16.0", default-features = false}
 hashbrown = { workspace = true }
 spin = { workspace = true }
+log = { workspace = true }

 [dev-dependencies]
 burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [
@@ -31,7 +32,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features =
 ] }

 burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
-burn-wgpu = { path = "../burn-wgpu", version = "0.16.0" }
+burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false }


 [package.metadata.docs.rs]
@@ -14,26 +14,3 @@ impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
         }
     }
 }
-
-// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type)
-// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B1>>>
-//     for RouterTensor<MultiRunnerClient2<B1, B2>>
-// {
-//     fn from(value: RouterTensor<Runner<B1>>) -> Self {
-//         RouterTensor {
-//             desc: value.desc,
-//             client: MultiRunnerClient2::RunnerClient1(value.client),
-//         }
-//     }
-// }
-
-// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B2>>>
-//     for RouterTensor<MultiRunnerClient2<B1, B2>>
-// {
-//     fn from(value: RouterTensor<Runner<B2>>) -> Self {
-//         RouterTensor {
-//             desc: value.desc,
-//             client: MultiRunnerClient2::RunnerClient2(value.client),
-//         }
-//     }
-// }
@@ -58,7 +58,8 @@ where
     <<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
         ReprBackend<Handle = B::Handle>,
 {
-    pub(crate) fn new(device: B::Device) -> Self {
+    /// Create a new runner.
+    pub fn new(device: B::Device) -> Self {
         Self {
             context: Arc::new(Mutex::new(RunnerContext {
                 handles: HandleContainer::new(),
@@ -90,7 +91,29 @@ where
         RouterTensor::new(id, shape, dtype, client)
     }

-    pub(crate) fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
+    /// Register a tensor from its data and id.
+    pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) {
+        let mut ctx = self.context.lock();
+        let dtype = data.dtype;
+
+        if dtype.is_float() {
+            let tensor = B::float_from_data(data, &self.device);
+            ctx.handles.register_float_tensor::<B>(&id, tensor)
+        } else if dtype.is_int() {
+            let tensor = B::int_from_data(data, &self.device);
+            ctx.handles.register_int_tensor::<B>(&id, tensor)
+        } else if dtype.is_bool() {
+            let tensor = B::bool_from_data(data, &self.device);
+            ctx.handles.register_bool_tensor::<B>(&id, tensor)
+        } else if let DType::QFloat(_) = dtype {
+            todo!();
+        }
+
+        core::mem::drop(ctx);
+    }
+
+    /// Register a tensor and returns its description.
+    pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
         let mut ctx = self.context.lock();
         let id = ctx.create_empty_handle();
         let shape = data.shape.clone();
@@ -119,11 +142,8 @@ where
         }
     }

-    pub(crate) fn register_empty_tensor_desc(
-        &self,
-        shape: Vec<usize>,
-        dtype: DType,
-    ) -> TensorDescription {
+    /// Register an empty tensor and returns its description.
+    pub fn register_empty_tensor_desc(&self, shape: Vec<usize>, dtype: DType) -> TensorDescription {
         let mut ctx = self.context.lock();
         let id = ctx.create_empty_handle();
         core::mem::drop(ctx);
@@ -21,7 +21,8 @@ pub struct RouterTensor<C: RunnerClient> {
 }

 impl<C: RunnerClient> RouterTensor<C> {
-    pub(crate) fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
+    /// Create a new router tensor.
+    pub fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
         Self {
             id,
             shape,
@@ -55,6 +55,8 @@ ndarray = ["burn-core/ndarray"]
 tch = ["burn-core/tch"]
 wgpu = ["burn-core/wgpu"]
 wgpu-spirv = ["burn-core/wgpu-spirv"]
+remote = ["burn-core/remote"]
+server = ["burn-core/server"]

 # Network utils
 network = ["burn-core/network"]
@@ -92,6 +92,7 @@
 //! - `autodiff`: Makes available the Autodiff backend
 //! - Others:
 //! - `std`: Activates the standard library (deactivate for no_std)
+//! - `server`: Enables the remote server.
 //! - `network`: Enables network utilities (currently, only a file downloader with progress bar)
 //! - `experimental-named-tensor`: Enables named tensors (experimental)
 //!
@@ -1,5 +1,5 @@
-use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
+use burn::backend::{Autodiff, Wgpu};

 fn main() {
-    custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());
+    custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
 }
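Editorial aside, not from the commit: the simplification above works because the turbofish pins the backend type, so `Default::default()` resolves to the concrete device without naming it. A standalone sketch of the same inference, assuming the `wgpu` and `autodiff` features of burn and a hypothetical `run` function mirroring `custom_training_loop::run`:

use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
use burn::tensor::backend::Backend;

// Hypothetical generic entry point; the device type is fixed by the chosen backend.
fn run<B: Backend>(_device: B::Device) {
    // ... training loop would go here ...
}

fn main() {
    // Both calls are equivalent: with B = Autodiff<Wgpu>, the device type is
    // WgpuDevice, so `Default::default()` is unambiguous.
    run::<Autodiff<Wgpu>>(WgpuDevice::default());
    run::<Autodiff<Wgpu>>(Default::default());
}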
@@ -0,0 +1,18 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "server"
publish = false
version.workspace = true

[features]
default = ["wgpu"]
cuda-jit = ["burn/cuda-jit"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
ndarray = ["burn/ndarray"]

[dependencies]
cfg-if = { workspace = true }
burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] }
@@ -0,0 +1,3 @@
fn main() {
    server::start();
}
@@ -0,0 +1,20 @@
pub fn start() {
    let port = std::env::var("REMOTE_BACKEND_PORT")
        .map(|port| match port.parse::<u16>() {
            Ok(val) => val,
            Err(err) => panic!("Invalid port, got {port} with error {err}"),
        })
        .unwrap_or(3000);

    cfg_if::cfg_if! {
        if #[cfg(feature = "ndarray")]{
            burn::server::start::<burn::backend::NdArray>(Default::default(), port);
        } else if #[cfg(feature = "wgpu")] {
            burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
        } else if #[cfg(feature = "cuda-jit")]{
            burn::server::start::<burn::backend::CudaJit>(Default::default(), port);
        } else {
            panic!("No backend selected, can't start server on port {port}");
        }
    }
}
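Usage note (editorial, not part of the commit): `start` reads `REMOTE_BACKEND_PORT` from the environment and falls back to port 3000, so the port can be chosen either in the shell when launching the example binary or programmatically before calling `start`. The snippet below is a hypothetical variant of the example's `main.rs` shown above.

fn main() {
    // Hypothetical variant: pin the port explicitly instead of relying on the
    // default of 3000 that `server::start` falls back to.
    std::env::set_var("REMOTE_BACKEND_PORT", "3010");
    server::start();
}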
@@ -17,6 +17,7 @@ tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
 wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
+remote = ["burn/remote"]
 cuda-jit = ["burn/cuda-jit"]
 hip-jit = ["burn/hip-jit"]

@@ -91,6 +91,16 @@ mod wgpu {
     }
 }

+#[cfg(feature = "remote")]
+mod remote {
+    use crate::{launch, ElemType};
+    use burn::backend::{Autodiff, RemoteBackend};
+
+    pub fn run() {
+        launch::<Autodiff<RemoteBackend>>(vec![Default::default()]);
+    }
+}
+
 #[cfg(feature = "cuda-jit")]
 mod cuda_jit {
     use crate::{launch, ElemType};
@@ -129,4 +139,6 @@ fn main() {
     cuda_jit::run();
     #[cfg(feature = "hip-jit")]
     hip_jit::run();
+    #[cfg(feature = "remote")]
+    remote::run();
 }
@@ -96,8 +96,6 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
         .metric_valid_numeric(AccuracyMetric::new())
         .metric_train_numeric(LossMetric::new())
         .metric_valid_numeric(LossMetric::new())
-        .metric_train_numeric(LossMetric::new())
-        .metric_valid_numeric(LossMetric::new())
         .metric_train_numeric(LearningRateMetric::new())
         .with_file_checkpointer(CompactRecorder::new())
         .devices(devices)