mirror of https://github.com/tracel-ai/burn.git
Remote Backend (#2463)
This commit is contained in:
parent
9b9b03c959
commit
099b6dcae0
|
@ -229,6 +229,17 @@ dependencies = [
|
|||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.83"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atoi"
|
||||
version = "2.0.0"
|
||||
|
@ -296,6 +307,64 @@ dependencies = [
|
|||
"arrayvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.7.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.5.0",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.1.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backend-comparison"
|
||||
version = "0.16.0"
|
||||
|
@ -524,6 +593,7 @@ dependencies = [
|
|||
"indicatif",
|
||||
"rayon",
|
||||
"reqwest 0.12.9",
|
||||
"serde",
|
||||
"tokio",
|
||||
"web-time",
|
||||
]
|
||||
|
@ -542,6 +612,7 @@ dependencies = [
|
|||
"burn-derive",
|
||||
"burn-hip",
|
||||
"burn-ndarray",
|
||||
"burn-remote",
|
||||
"burn-tch",
|
||||
"burn-tensor",
|
||||
"burn-wgpu",
|
||||
|
@ -647,6 +718,7 @@ dependencies = [
|
|||
"derive-new 0.7.0",
|
||||
"half",
|
||||
"log",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -729,15 +801,38 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-remote"
|
||||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"burn-common",
|
||||
"burn-remote",
|
||||
"burn-router",
|
||||
"burn-tensor",
|
||||
"derive-new 0.7.0",
|
||||
"futures-util",
|
||||
"log",
|
||||
"rmp-serde",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tracing-core",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-router"
|
||||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-ndarray",
|
||||
"burn-tensor",
|
||||
"burn-wgpu",
|
||||
"hashbrown 0.15.0",
|
||||
"log",
|
||||
"spin",
|
||||
]
|
||||
|
||||
|
@ -3149,6 +3244,7 @@ dependencies = [
|
|||
"http 1.1.0",
|
||||
"http-body 1.0.1",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
|
@ -3693,6 +3789,12 @@ dependencies = [
|
|||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.9"
|
||||
|
@ -6098,6 +6200,16 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_rusqlite"
|
||||
version = "0.36.0"
|
||||
|
@ -6154,6 +6266,14 @@ dependencies = [
|
|||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "server"
|
||||
version = "0.16.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
|
@ -6874,6 +6994,18 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.12"
|
||||
|
@ -6936,6 +7068,28 @@ dependencies = [
|
|||
"zip 0.6.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper 0.1.2",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.3"
|
||||
|
@ -6978,6 +7132,7 @@ version = "0.1.40"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
@ -7051,6 +7206,24 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.1.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
@ -7174,6 +7347,12 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
|
|
|
@ -119,6 +119,7 @@ bincode = { version = "2.0.0-rc.3", features = [
|
|||
# The following packages disable the "std" feature for no_std compatibility
|
||||
#
|
||||
derive-new = { version = "0.7.0", default-features = false }
|
||||
cfg-if = "1.0.0"
|
||||
|
||||
blas-src = { version = "0.10.0", default-features = false }
|
||||
half = { version = "2.4.1", features = [
|
||||
|
|
|
@ -23,6 +23,8 @@ getrandom = { workspace = true, features = ["js"] }
|
|||
web-time = { version = "1.1.0" }
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
|
||||
# Network downloader
|
||||
indicatif = { workspace = true, optional = true }
|
||||
reqwest = { workspace = true, optional = true }
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::rand::gen_random;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Simple ID generator.
|
||||
pub struct IdGenerator {}
|
||||
|
@ -64,3 +65,49 @@ mod tests {
|
|||
assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier that can represent a stream based on the current thread id.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct StreamId {
|
||||
/// The value representing the thread id.
|
||||
pub value: u64,
|
||||
}
|
||||
|
||||
impl StreamId {
|
||||
/// Get the current thread id.
|
||||
pub fn current() -> Self {
|
||||
Self {
|
||||
#[cfg(feature = "std")]
|
||||
value: Self::from_current_thread(),
|
||||
#[cfg(not(feature = "std"))]
|
||||
value: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
fn from_current_thread() -> u64 {
|
||||
use core::hash::Hash;
|
||||
|
||||
std::thread_local! {
|
||||
static ID: std::cell::OnceCell::<u64> = const { std::cell::OnceCell::new() };
|
||||
};
|
||||
|
||||
// Getting the current thread is expensive, so we cache the value into a thread local
|
||||
// variable, which is very fast.
|
||||
ID.with(|cell| {
|
||||
*cell.get_or_init(|| {
|
||||
// A way to get a thread id encoded as u64.
|
||||
let mut hasher = std::hash::DefaultHasher::default();
|
||||
let id = std::thread::current().id();
|
||||
id.hash(&mut hasher);
|
||||
std::hash::Hasher::finish(&hasher)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for StreamId {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_fmt(format_args!("StreamId({:?})", self.value))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,8 @@ doc = [
|
|||
"hip-jit",
|
||||
"vision",
|
||||
"autodiff",
|
||||
"remote",
|
||||
"server",
|
||||
# Doc features
|
||||
"burn-candle/doc",
|
||||
"burn-common/doc",
|
||||
|
@ -86,6 +88,8 @@ metal = ["burn-candle?/metal"]
|
|||
openblas = ["burn-ndarray?/blas-openblas"]
|
||||
openblas-system = ["burn-ndarray?/blas-openblas-system"]
|
||||
template = ["burn-wgpu?/template"]
|
||||
remote = ["burn-remote/client"]
|
||||
server = ["burn-remote/server"]
|
||||
|
||||
candle = ["burn-candle"]
|
||||
candle-cuda = ["candle", "burn-candle/cuda"]
|
||||
|
@ -131,6 +135,7 @@ burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-
|
|||
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
|
||||
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
|
||||
burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
|
||||
|
||||
data-encoding = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
|
|
@ -7,6 +7,11 @@ pub use ndarray::NdArray;
|
|||
#[cfg(feature = "autodiff")]
|
||||
pub use burn_autodiff as autodiff;
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
pub use burn_remote as remote;
|
||||
#[cfg(feature = "remote")]
|
||||
pub use burn_remote::RemoteBackend;
|
||||
|
||||
#[cfg(feature = "autodiff")]
|
||||
pub use burn_autodiff::Autodiff;
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ pub mod tensor;
|
|||
/// Backend module.
|
||||
pub mod backend;
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub use burn_remote::server;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
#[cfg(all(
|
||||
|
|
|
@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
|
||||
autotune = ["burn-jit/autotune"]
|
||||
default = ["fusion", "burn-jit/default", "cubecl/default"]
|
||||
doc = ["burn-jit/doc"]
|
||||
fusion = ["burn-fusion", "burn-jit/fusion"]
|
||||
std = ["burn-jit/std", "cubecl/std"]
|
||||
|
|
|
@ -34,6 +34,7 @@ derive-new = { workspace = true }
|
|||
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
paste = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
|
|
|
@ -24,6 +24,7 @@ mod tests {
|
|||
use burn_jit::JitBackend;
|
||||
|
||||
pub type TestRuntime = cubecl::hip::HipRuntime;
|
||||
pub use half::{bf16, f16};
|
||||
|
||||
burn_jit::testgen_all!();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Backend router decorator over websocket."
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-remote"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote"
|
||||
documentation = "https://docs.rs/burn-router-remote"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
doc = []
|
||||
client = ["tokio-tungstenite"]
|
||||
server = ["axum", "tracing-core", "tracing-subscriber"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]}
|
||||
burn-common = { path = "../burn-common", version = "0.16.0", default-features = true}
|
||||
burn-router = { path = "../burn-router", version = "0.16.0", default-features = true}
|
||||
|
||||
# Basic dependencies
|
||||
derive-new = {workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# Shared dependencies
|
||||
tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_bytes = { workspace = true }
|
||||
rmp-serde = { workspace = true }
|
||||
futures-util = { version = "0.3" }
|
||||
|
||||
# Client dependencies
|
||||
tokio-tungstenite = { version = "0.24", optional = true }
|
||||
|
||||
# Server dependencies
|
||||
axum = { version = "0.7.5", features = ["ws"], optional = true }
|
||||
tracing-core = { workspace = true, optional = true }
|
||||
tracing-subscriber = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
# We activate the features client and server during dev.
|
||||
burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
|
@ -0,0 +1,99 @@
|
|||
use super::worker::{ClientRequest, ClientWorker};
|
||||
use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent};
|
||||
use burn_common::id::StreamId;
|
||||
use burn_tensor::repr::TensorId;
|
||||
use std::{
|
||||
future::Future,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
};
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
pub use super::WsDevice;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsClient {
|
||||
pub(crate) device: WsDevice,
|
||||
pub(crate) sender: Arc<WsSender>,
|
||||
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
||||
impl WsClient {
|
||||
pub fn init(device: WsDevice) -> Self {
|
||||
ClientWorker::start(device)
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
device: WsDevice,
|
||||
sender: Sender<ClientRequest>,
|
||||
runtime: Arc<tokio::runtime::Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
device,
|
||||
runtime,
|
||||
sender: Arc::new(WsSender {
|
||||
sender,
|
||||
position_counter: AtomicU64::new(0),
|
||||
tensor_id_counter: AtomicU64::new(0),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct WsSender {
|
||||
sender: Sender<ClientRequest>,
|
||||
position_counter: AtomicU64,
|
||||
tensor_id_counter: AtomicU64,
|
||||
}
|
||||
|
||||
impl WsSender {
|
||||
pub(crate) fn send(&self, task: ComputeTask) -> impl Future<Output = ()> + Send {
|
||||
let position = self
|
||||
.position_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let stream_id = StreamId::current();
|
||||
let sender = self.sender.clone();
|
||||
|
||||
async move {
|
||||
sender
|
||||
.send(ClientRequest::WithoutCallback(Task::Compute(
|
||||
task,
|
||||
ConnectionId::new(position, stream_id),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_tensor_id(&self) -> TensorId {
|
||||
let val = self
|
||||
.tensor_id_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
TensorId::new(val)
|
||||
}
|
||||
pub(crate) fn send_callback(
|
||||
&self,
|
||||
task: ComputeTask,
|
||||
) -> impl Future<Output = TaskResponseContent> + Send {
|
||||
let position = self
|
||||
.position_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let stream_id = StreamId::current();
|
||||
let sender = self.sender.clone();
|
||||
let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
async move {
|
||||
sender
|
||||
.send(ClientRequest::WithSyncCallback(
|
||||
Task::Compute(task, ConnectionId::new(position, stream_id)),
|
||||
callback_sender,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match callback_recv.recv().await {
|
||||
Some(val) => val,
|
||||
None => panic!(""),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
use burn_router::{RouterTensor, RunnerChannel, TensorHandle};
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
|
||||
use super::{
|
||||
runner::{WsBridge, WsDevice},
|
||||
WsClient,
|
||||
};
|
||||
|
||||
/// A local channel with direct connection to the backend runner clients.
|
||||
#[derive(Clone)]
|
||||
pub struct WsChannel;
|
||||
|
||||
impl RunnerChannel for WsChannel {
|
||||
type Device = WsDevice;
|
||||
type Bridge = WsBridge;
|
||||
type Client = WsClient;
|
||||
|
||||
type FloatElem = f32;
|
||||
|
||||
type IntElem = i32;
|
||||
|
||||
fn name() -> String {
|
||||
"remote".into()
|
||||
}
|
||||
|
||||
fn init_client(device: &Self::Device) -> Self::Client {
|
||||
WsClient::init(device.clone())
|
||||
}
|
||||
|
||||
fn get_tensor_handle(
|
||||
_tensor: &TensorDescription,
|
||||
_client: &Self::Client,
|
||||
) -> TensorHandle<Self::Bridge> {
|
||||
panic!("Unsupported")
|
||||
}
|
||||
|
||||
fn register_tensor(
|
||||
_client: &Self::Client,
|
||||
_handle: TensorHandle<Self::Bridge>,
|
||||
_shape: Vec<usize>,
|
||||
_dtype: burn_tensor::DType,
|
||||
) -> RouterTensor<Self::Client> {
|
||||
panic!("Unsupported")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
mod base;
|
||||
mod channel;
|
||||
mod runner;
|
||||
mod worker;
|
||||
|
||||
pub use base::*;
|
||||
pub use channel::*;
|
||||
pub use runner::WsDevice;
|
|
@ -0,0 +1,175 @@
|
|||
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient};
|
||||
use burn_tensor::{
|
||||
backend::{DeviceId, DeviceOps},
|
||||
DType, TensorData,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::shared::{ComputeTask, TaskResponseContent};
|
||||
|
||||
use super::WsClient;
|
||||
|
||||
// It is very important to block on any request made with the sender, since ordering is crucial
|
||||
// when registering operation or creating tensors.
|
||||
//
|
||||
// The overhead is minimal, since we only wait for the task to be sent to the async
|
||||
// channel, but not sent to the websocket server and even less processed by the server.
|
||||
impl RunnerClient for WsClient {
|
||||
type Device = WsDevice;
|
||||
|
||||
fn register(&self, op: burn_tensor::repr::OperationDescription) {
|
||||
let fut = self
|
||||
.sender
|
||||
.send(ComputeTask::RegisterOperation(Box::new(op)));
|
||||
self.runtime.block_on(fut);
|
||||
}
|
||||
|
||||
fn read_tensor(
|
||||
&self,
|
||||
tensor: burn_tensor::repr::TensorDescription,
|
||||
) -> impl std::future::Future<Output = TensorData> + Send {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
|
||||
|
||||
async move {
|
||||
match fut.await {
|
||||
TaskResponseContent::ReadTensor(data) => data,
|
||||
_ => panic!("Invalid message type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
|
||||
let id = self.sender.new_tensor_id();
|
||||
let shape = data.shape.clone();
|
||||
let dtype = data.dtype;
|
||||
|
||||
let fut = self.sender.send(ComputeTask::RegisterTensor(id, data));
|
||||
|
||||
self.runtime.block_on(fut);
|
||||
|
||||
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
|
||||
}
|
||||
|
||||
fn register_empty_tensor(
|
||||
&self,
|
||||
shape: Vec<usize>,
|
||||
dtype: burn_tensor::DType,
|
||||
) -> RouterTensor<Self> {
|
||||
let id = self.sender.new_tensor_id();
|
||||
|
||||
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
|
||||
}
|
||||
|
||||
fn register_float_tensor(
|
||||
&self,
|
||||
shape: Vec<usize>,
|
||||
_full_precision: bool,
|
||||
) -> RouterTensor<Self> {
|
||||
self.register_empty_tensor(shape, DType::F32)
|
||||
}
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
self.device.clone()
|
||||
}
|
||||
|
||||
fn register_orphan(&self, id: &burn_tensor::repr::TensorId) {
|
||||
let fut = self.sender.send(ComputeTask::RegisterOrphan(*id));
|
||||
self.runtime.block_on(fut);
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
// Important for ordering to call the creation of the future sync.
|
||||
let fut = self.sender.send_callback(ComputeTask::SyncBackend);
|
||||
|
||||
let fut = async move {
|
||||
match fut.await {
|
||||
TaskResponseContent::SyncBackend => {}
|
||||
_ => panic!("Invalid message type"),
|
||||
};
|
||||
};
|
||||
|
||||
self.runtime.block_on(fut)
|
||||
}
|
||||
|
||||
fn seed(&self, _seed: u64) {
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
/// The device contains the connection information of the server.
|
||||
pub struct WsDevice {
|
||||
pub(crate) address: Arc<String>,
|
||||
}
|
||||
|
||||
impl WsDevice {
|
||||
/// Create a device from an url.
|
||||
pub fn new(url: &str) -> Self {
|
||||
let mut address = String::new();
|
||||
|
||||
if !url.starts_with("ws://") {
|
||||
address += "ws://";
|
||||
address += url;
|
||||
} else {
|
||||
address += url;
|
||||
};
|
||||
|
||||
Self {
|
||||
address: Arc::new(address),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WsDevice {
|
||||
fn default() -> Self {
|
||||
let address = match std::env::var("BURN_REMOTE_ADDRESS") {
|
||||
Ok(address) => address,
|
||||
Err(_) => String::from("ws://127.0.0.1:3000"),
|
||||
};
|
||||
|
||||
Self {
|
||||
address: Arc::new(address),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceOps for WsDevice {
|
||||
fn id(&self) -> DeviceId {
|
||||
DeviceId {
|
||||
type_id: 0,
|
||||
index_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WsBridge;
|
||||
|
||||
impl MultiBackendBridge for WsBridge {
|
||||
type TensorHandle = TensorData;
|
||||
type Device = WsDevice;
|
||||
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_tensor::Shape,
|
||||
_target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_tensor::Shape,
|
||||
_target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
_shape: burn_tensor::Shape,
|
||||
_target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
tensor
|
||||
}
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
use super::{runner::WsDevice, WsClient};
|
||||
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio_tungstenite::{
|
||||
connect_async_with_config,
|
||||
tungstenite::protocol::{Message, WebSocketConfig},
|
||||
};
|
||||
|
||||
pub type CallbackSender = tokio::sync::mpsc::Sender<TaskResponseContent>;
|
||||
|
||||
pub enum ClientRequest {
|
||||
WithSyncCallback(Task, CallbackSender),
|
||||
WithoutCallback(Task),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ClientWorker {
|
||||
requests: HashMap<ConnectionId, CallbackSender>,
|
||||
}
|
||||
|
||||
impl ClientWorker {
|
||||
async fn on_response(&mut self, response: TaskResponse) {
|
||||
match self.requests.remove(&response.id) {
|
||||
Some(request) => {
|
||||
request.send(response.content).await.unwrap();
|
||||
}
|
||||
None => {
|
||||
panic!("Can't ignore message from the server.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
|
||||
self.requests.insert(id, callback);
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientWorker {
|
||||
pub fn start(device: WsDevice) -> WsClient {
|
||||
let runtime = Arc::new(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_io()
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let (sender, mut rec) = tokio::sync::mpsc::channel(10);
|
||||
let address_request = format!("{}/{}", device.address.as_str(), "request");
|
||||
let address_response = format!("{}/{}", device.address.as_str(), "response");
|
||||
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
#[allow(deprecated)]
|
||||
runtime.spawn(async move {
|
||||
log::info!("Connecting to {address_request} ...");
|
||||
let (mut stream_request, _) = connect_async_with_config(
|
||||
address_request.clone(),
|
||||
Some(WebSocketConfig {
|
||||
max_send_queue: None,
|
||||
write_buffer_size: 0,
|
||||
max_write_buffer_size: usize::MAX,
|
||||
max_message_size: None,
|
||||
max_frame_size: Some(MB * 512),
|
||||
accept_unmasked_frames: true,
|
||||
}),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to connect");
|
||||
let (mut stream_response, _) = connect_async_with_config(
|
||||
address_response,
|
||||
Some(WebSocketConfig {
|
||||
max_send_queue: None,
|
||||
write_buffer_size: 0,
|
||||
max_write_buffer_size: usize::MAX,
|
||||
max_message_size: None,
|
||||
max_frame_size: Some(MB * 512),
|
||||
accept_unmasked_frames: true,
|
||||
}),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to connect");
|
||||
|
||||
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::default()));
|
||||
|
||||
// Init the connection.
|
||||
let session_id = SessionId::new();
|
||||
let bytes = rmp_serde::to_vec(&Task::Init(session_id)).expect("Can serialize tasks to bytes.");
|
||||
stream_request.send(Message::Binary(bytes.clone())).await.expect("Can send the message on the websocket.");
|
||||
stream_response.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");
|
||||
|
||||
// Websocket async worker loading callback from the server.
|
||||
let state_ws = state.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = stream_response.next().await {
|
||||
let msg = match msg {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => panic!("An error happened while receiving messages from the websocket: {err:?}"),
|
||||
};
|
||||
|
||||
match msg {
|
||||
Message::Binary(bytes) => {
|
||||
let response: TaskResponse = rmp_serde::from_slice(&bytes).expect("Can deserialize messages from the websocket.");
|
||||
let mut state = state_ws.lock().await;
|
||||
state.on_response(response).await;
|
||||
}
|
||||
Message::Close(_) => {
|
||||
log::warn!("Closed connection");
|
||||
return;
|
||||
},
|
||||
_ => panic!("Unsupported websocket message: {msg:?}"),
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
// Channel async worker sending operations to the server.
|
||||
tokio::spawn(async move {
|
||||
while let Some(req) = rec.recv().await {
|
||||
let task = match req {
|
||||
ClientRequest::WithSyncCallback(task, callback) => {
|
||||
let mut state = state.lock().await;
|
||||
if let Task::Compute(_content, id) = &task {
|
||||
state.register_callback(*id, callback);
|
||||
}
|
||||
task
|
||||
}
|
||||
ClientRequest::WithoutCallback(task) => task,
|
||||
|
||||
};
|
||||
let bytes = rmp_serde::to_vec(&task).expect("Can serialize tasks to bytes.");
|
||||
stream_request.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
WsClient::new(device, sender, runtime)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
pub(crate) mod client;
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub mod server;
|
||||
|
||||
pub(crate) mod shared;
|
||||
|
||||
#[cfg(feature = "client")]
|
||||
mod __client {
|
||||
use super::*;
|
||||
|
||||
use burn_router::BackendRouter;
|
||||
use client::WsChannel;
|
||||
|
||||
/// The remote backend allows you to run computation on a remote device.
|
||||
///
|
||||
/// Make sure there is a running server before trying to connect to it.
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// fn main() {
|
||||
/// let device = Default::default();
|
||||
/// let port = 3000;
|
||||
///
|
||||
/// // You need to activate the `server` feature flag to have access to this function.
|
||||
/// burn::server::start::<burn::backend::Wgpu>(device, port);
|
||||
/// }
|
||||
///```
|
||||
pub type RemoteBackend = BackendRouter<WsChannel>;
|
||||
|
||||
pub use client::WsDevice as RemoteDevice;
|
||||
}
|
||||
#[cfg(feature = "client")]
|
||||
pub use __client::*;
|
|
@ -0,0 +1,196 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{self, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
response::IntoResponse,
|
||||
routing::any,
|
||||
Router,
|
||||
};
|
||||
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge},
|
||||
repr::ReprBackend,
|
||||
Device,
|
||||
};
|
||||
use tracing_core::{Level, LevelFilter};
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{filter::filter_fn, registry};
|
||||
|
||||
use crate::shared::{ComputeTask, Task};
|
||||
|
||||
use super::session::SessionManager;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsServer<B: ReprBackend> {
|
||||
state: Arc<SessionManager<B>>,
|
||||
}
|
||||
|
||||
impl<B: ReprBackend> WsServer<B>
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
/// Start the server on the given address.
|
||||
pub async fn start(device: Device<B>, port: u16) {
|
||||
let layer = tracing_subscriber::fmt::layer()
|
||||
.with_filter(LevelFilter::INFO)
|
||||
.with_filter(filter_fn(|m| {
|
||||
if let Some(path) = m.module_path() {
|
||||
// The wgpu crate is logging too much, so we skip `info` level.
|
||||
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}));
|
||||
registry().with(layer).init();
|
||||
|
||||
let address = format!("0.0.0.0:{port}");
|
||||
log::info!("Start server {address} on device {device:?}");
|
||||
|
||||
let state = SessionManager::<B>::new(device);
|
||||
let state = Self {
|
||||
state: Arc::new(state),
|
||||
};
|
||||
|
||||
// build our application with some routes
|
||||
let app = Router::new()
|
||||
.route("/response", any(Self::handler_response))
|
||||
.route("/request", any(Self::handler_request))
|
||||
.with_state(state);
|
||||
|
||||
// run it with hyper
|
||||
let listener = tokio::net::TcpListener::bind(address).await.unwrap();
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn handler_response(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Self>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| state.handle_socket_response(socket))
|
||||
}
|
||||
|
||||
async fn handler_request(ws: WebSocketUpgrade, State(state): State<Self>) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| state.handle_socket_request(socket))
|
||||
}
|
||||
|
||||
async fn handle_socket_response(self, mut socket: WebSocket) {
|
||||
log::info!("[Response Handler] On new connection.");
|
||||
|
||||
let packet = socket.recv().await;
|
||||
let msg = match packet {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
log::info!("Still no message");
|
||||
panic!("");
|
||||
}
|
||||
};
|
||||
|
||||
if let Ok(ws::Message::Binary(bytes)) = msg {
|
||||
let task = match rmp_serde::from_slice::<Task>(&bytes) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
log::info!("Only bytes messages are supported {err:?}");
|
||||
panic!("");
|
||||
}
|
||||
};
|
||||
let id = match task {
|
||||
Task::Init(id) => id,
|
||||
_ => panic!(""),
|
||||
};
|
||||
|
||||
let receiver = self.state.register_responder(id).await;
|
||||
|
||||
log::info!("Response handler connection active");
|
||||
|
||||
while let Ok(callback) = receiver.recv() {
|
||||
let response = callback.recv().unwrap();
|
||||
let bytes = rmp_serde::to_vec(&response).unwrap();
|
||||
|
||||
socket.send(ws::Message::Binary(bytes)).await.unwrap();
|
||||
}
|
||||
} else {
|
||||
panic!("");
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_socket_request(self, mut socket: WebSocket) {
|
||||
log::info!("[Request Handler] On new connection.");
|
||||
let mut session_id = None;
|
||||
|
||||
loop {
|
||||
let packet = socket.recv().await;
|
||||
let msg = match packet {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
log::info!("Still no message");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Ok(ws::Message::Binary(bytes)) = msg {
|
||||
let task = match rmp_serde::from_slice::<Task>(&bytes) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
log::info!("Only bytes message in the json format are supported {err:?}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let (stream, connection_id, task) =
|
||||
match self.state.stream(&mut session_id, task).await {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
log::info!("Ops session activated {session_id:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match task {
|
||||
ComputeTask::RegisterOperation(op) => {
|
||||
stream.register_operation(op);
|
||||
}
|
||||
ComputeTask::RegisterTensor(id, data) => {
|
||||
stream.register_tensor(id, data);
|
||||
}
|
||||
ComputeTask::RegisterOrphan(id) => {
|
||||
stream.register_orphan(id);
|
||||
}
|
||||
ComputeTask::ReadTensor(tensor) => {
|
||||
stream.read_tensor(connection_id, tensor);
|
||||
}
|
||||
ComputeTask::SyncBackend => {
|
||||
stream.sync(connection_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::info!("Not a binary message, closing, received {msg:?}");
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
log::info!("Closing connection");
|
||||
self.state.close(session_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
/// Start the server on the given port and [device](Device).
|
||||
pub async fn start<B: ReprBackend>(device: Device<B>, port: u16)
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
WsServer::<B>::start(device, port).await;
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
pub(crate) mod processor;
|
||||
pub(crate) mod session;
|
||||
pub(crate) mod stream;
|
||||
|
||||
mod base;
|
||||
|
||||
pub use base::start;
|
|
@ -0,0 +1,84 @@
|
|||
use burn_router::{Runner, RunnerClient};
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge},
|
||||
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
|
||||
TensorData,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent};
|
||||
|
||||
/// The goal of the processor is to asynchronously process compute tasks on its own thread.
pub struct Processor<B: ReprBackend> {
    // Zero-sized marker: the backend type is only used by associated functions.
    p: PhantomData<B>,
}

/// Channel used to send a single response back to the task submitter.
pub type Callback<M> = Sender<M>;
|
||||
|
||||
/// Commands executed by the processor thread, in the order they are received.
pub enum ProcessorTask {
    /// Register a backend operation to be executed.
    RegisterOperation(Box<OperationDescription>),
    /// Register a new tensor from raw data under the given id.
    RegisterTensor(TensorId, TensorData),
    /// Read back a tensor and send the result through the callback.
    ReadTensor(ConnectionId, TensorDescription, Callback<TaskResponse>),
    /// Synchronize the backend and acknowledge through the callback.
    Sync(ConnectionId, Callback<TaskResponse>),
    /// Acknowledge that every previously sent task has been received.
    Fence(Callback<()>),
    /// Mark a tensor as no longer used so its handle can be freed.
    RegisterOrphan(TensorId),
    /// Flush pending work and shut down the processor thread.
    Close,
}
|
||||
|
||||
impl<B: ReprBackend> Processor<B>
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
pub fn start(runner: Runner<B>) -> Sender<ProcessorTask> {
|
||||
let (sender, rec) = std::sync::mpsc::channel();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
for item in rec.iter() {
|
||||
match item {
|
||||
ProcessorTask::RegisterOperation(op) => {
|
||||
runner.register(*op);
|
||||
}
|
||||
ProcessorTask::RegisterOrphan(id) => {
|
||||
runner.register_orphan(&id);
|
||||
}
|
||||
ProcessorTask::Sync(id, callback) => {
|
||||
runner.sync();
|
||||
callback
|
||||
.send(TaskResponse {
|
||||
content: TaskResponseContent::SyncBackend,
|
||||
id,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
ProcessorTask::RegisterTensor(id, data) => {
|
||||
runner.register_tensor_data_id(id, data);
|
||||
}
|
||||
ProcessorTask::ReadTensor(id, tensor, callback) => {
|
||||
let tensor = burn_common::future::block_on(runner.read_tensor(tensor));
|
||||
callback
|
||||
.send(TaskResponse {
|
||||
content: TaskResponseContent::ReadTensor(tensor),
|
||||
id,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
ProcessorTask::Close => {
|
||||
let device = runner.device();
|
||||
runner.sync();
|
||||
core::mem::drop(runner);
|
||||
B::sync(&device);
|
||||
return;
|
||||
}
|
||||
ProcessorTask::Fence(sender) => {
|
||||
sender.send(()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
sender
|
||||
}
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
use burn_common::id::StreamId;
|
||||
use burn_router::Runner;
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge},
|
||||
repr::{ReprBackend, TensorDescription, TensorId, TensorStatus},
|
||||
Device,
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::mpsc::{Receiver, Sender},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};
|
||||
|
||||
use super::stream::Stream;
|
||||
|
||||
/// A session manager controls the creation of sessions.
///
/// Each session manages its own streams, spawning one thread per stream to mimic the same
/// behavior a native backend would have.
pub struct SessionManager<B: ReprBackend> {
    // Template runner cloned into every newly created session.
    runner: Runner<B>,
    // Active sessions, keyed by the id the client provides in `Task::Init`.
    sessions: tokio::sync::Mutex<HashMap<SessionId, Session<B>>>,
}
|
||||
|
||||
/// State of a single client session: its streams and the tensors they share.
struct Session<B: ReprBackend> {
    // Runner cloned into every stream of this session.
    runner: Runner<B>,
    // For each live tensor, the streams that have registered it
    // (the first entry is the stream that created it).
    tensors: HashMap<TensorId, Vec<StreamId>>,
    // One compute stream per client-side stream id.
    streams: HashMap<StreamId, Stream<B>>,
    // Forwards per-task response receivers to the responder connection.
    sender: Sender<Receiver<TaskResponse>>,
    // Taken exactly once by the responder connection; `None` afterward.
    receiver: Option<Receiver<Receiver<TaskResponse>>>,
}
|
||||
|
||||
impl<B: ReprBackend> SessionManager<B>
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
pub fn new(device: Device<B>) -> Self {
|
||||
Self {
|
||||
runner: Runner::new(device),
|
||||
sessions: Mutex::new(Default::default()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new responder for the session. Only one responder can exist for a session for
|
||||
/// now.
|
||||
pub async fn register_responder(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
) -> Receiver<Receiver<TaskResponse>> {
|
||||
log::info!("Register responder for session {session_id}");
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
self.register_session(&mut sessions, session_id);
|
||||
|
||||
let session = sessions.get_mut(&session_id).unwrap();
|
||||
session.init_responder()
|
||||
}
|
||||
|
||||
/// Get the stream for the current session and task.
|
||||
pub async fn stream(
|
||||
&self,
|
||||
session_id: &mut Option<SessionId>,
|
||||
task: Task,
|
||||
) -> Option<(Stream<B>, ConnectionId, ComputeTask)> {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
|
||||
let session_id = match session_id {
|
||||
Some(id) => *id,
|
||||
None => match task {
|
||||
Task::Init(id) => {
|
||||
log::info!("Init requester for session {id}");
|
||||
*session_id = Some(id);
|
||||
self.register_session(&mut sessions, id);
|
||||
return None;
|
||||
}
|
||||
_ => panic!("The first message should initialize the session"),
|
||||
},
|
||||
};
|
||||
|
||||
match sessions.get_mut(&session_id) {
|
||||
Some(session) => {
|
||||
let (task, connection_id) = match task {
|
||||
Task::Compute(task, connection_id) => (task, connection_id),
|
||||
_ => panic!("Only support compute tasks."),
|
||||
};
|
||||
let stream = session.select(connection_id.stream_id, &task);
|
||||
Some((stream, connection_id, task))
|
||||
}
|
||||
None => {
|
||||
panic!("To be initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Close the session with the given id.
|
||||
pub async fn close(&self, session_id: Option<SessionId>) {
|
||||
if let Some(id) = session_id {
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
if let Some(session) = sessions.get_mut(&id) {
|
||||
session.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B>>, id: SessionId) {
|
||||
sessions.entry(id).or_insert_with(|| {
|
||||
log::info!("Creating a new session {id}");
|
||||
|
||||
Session::new(self.runner.clone())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ReprBackend> Session<B>
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
fn new(runner: Runner<B>) -> Self {
|
||||
let (sender, reveiver) = std::sync::mpsc::channel();
|
||||
Self {
|
||||
runner,
|
||||
tensors: Default::default(),
|
||||
streams: Default::default(),
|
||||
sender,
|
||||
receiver: Some(reveiver),
|
||||
}
|
||||
}
|
||||
|
||||
fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {
|
||||
let mut receiver = None;
|
||||
core::mem::swap(&mut receiver, &mut self.receiver);
|
||||
receiver.expect("Only one responder per session is possible.")
|
||||
}
|
||||
|
||||
/// Select the current [stream](Stream) based on the given task.
|
||||
fn select(&mut self, stream_id: StreamId, task: &ComputeTask) -> Stream<B> {
|
||||
// We have to check every streams involved in the last operation, making
|
||||
// sure the backend is up-to-date with those operations.
|
||||
//
|
||||
// 1. We update the tensor status of all tensors in the task.
|
||||
// 2. We don't keep track of tensors that are used for the last time.
|
||||
let mut fences = Vec::new();
|
||||
for (tensor_id, status) in task.tensors_info() {
|
||||
let tensor_stream_ids = match self.tensors.get(&tensor_id) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
if status != TensorStatus::ReadWrite {
|
||||
// Add the first stream that created the tensor that may be used by other
|
||||
// streams later.
|
||||
self.register_tensor(tensor_id, stream_id);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let current_stream_already_synced = tensor_stream_ids.contains(&stream_id);
|
||||
|
||||
if !current_stream_already_synced {
|
||||
// We only need to sync to the first stream that created the tensor.
|
||||
if let Some(id) = tensor_stream_ids.iter().next() {
|
||||
fences.push(*id);
|
||||
}
|
||||
}
|
||||
|
||||
// We add the stream to the list of updated stream to avoid needed to flush other
|
||||
// operations that might use this tensor.
|
||||
self.register_tensor(tensor_id, stream_id);
|
||||
|
||||
// If the tensor has the status `read_write`, it means no other stream can reuse it
|
||||
// afterward, so we remove it from the state.
|
||||
if status == TensorStatus::ReadWrite {
|
||||
self.tensors.remove(&tensor_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup orphans.
|
||||
if let ComputeTask::RegisterOrphan(tensor_id) = task {
|
||||
self.tensors.remove(tensor_id);
|
||||
}
|
||||
|
||||
// We have to wait for the streams to be updated.
|
||||
for stream_id in fences {
|
||||
if let Some(stream) = self.streams.get(&stream_id) {
|
||||
stream.fence_sync();
|
||||
}
|
||||
}
|
||||
|
||||
// We return the stream.
|
||||
match self.streams.get(&stream_id) {
|
||||
Some(stream) => stream.clone(),
|
||||
None => {
|
||||
let stream = Stream::<B>::new(self.runner.clone(), self.sender.clone());
|
||||
self.streams.insert(stream_id, stream.clone());
|
||||
stream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor(&mut self, tensor_id: TensorId, stream_id: StreamId) {
|
||||
match self.tensors.get_mut(&tensor_id) {
|
||||
Some(ids) => {
|
||||
ids.push(stream_id);
|
||||
}
|
||||
None => {
|
||||
self.tensors.insert(tensor_id, vec![stream_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close all streams created in the session.
|
||||
fn close(&mut self) {
|
||||
for (id, stream) in self.streams.drain() {
|
||||
log::info!("Closing stream {id}");
|
||||
stream.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ComputeTask {
|
||||
fn tensors_info(&self) -> Vec<(TensorId, TensorStatus)> {
|
||||
fn from_descriptions(desc: &[&TensorDescription]) -> Vec<(TensorId, TensorStatus)> {
|
||||
desc.iter().map(|t| (t.id, t.status.clone())).collect()
|
||||
}
|
||||
|
||||
match self {
|
||||
ComputeTask::RegisterOperation(op) => from_descriptions(&op.nodes()),
|
||||
ComputeTask::RegisterTensor(tensor_id, _tensor_data) => {
|
||||
vec![(*tensor_id, TensorStatus::NotInit)]
|
||||
}
|
||||
ComputeTask::RegisterOrphan(tensor_id) => {
|
||||
vec![(*tensor_id, TensorStatus::ReadWrite)]
|
||||
}
|
||||
ComputeTask::ReadTensor(tensor_description) => from_descriptions(&[tensor_description]),
|
||||
ComputeTask::SyncBackend => vec![],
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
use core::marker::PhantomData;
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
use crate::shared::{ConnectionId, TaskResponse};
|
||||
|
||||
use super::processor::{Processor, ProcessorTask};
|
||||
use burn_router::Runner;
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge},
|
||||
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
|
||||
TensorData,
|
||||
};
|
||||
|
||||
/// A stream makes sure all operations registered are executed in the order they were sent to the
/// server, potentially waiting to reconstruct consistency.
#[derive(Clone)]
pub struct Stream<B: ReprBackend> {
    // Submits tasks to the processor thread that owns the runner.
    compute_sender: Sender<ProcessorTask>,
    // Forwards per-task response receivers to the session responder.
    writer_sender: Sender<Receiver<TaskResponse>>,
    _p: PhantomData<B>,
}
|
||||
|
||||
impl<B: ReprBackend> Stream<B>
|
||||
where
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
pub fn new(runner: Runner<B>, writer_sender: Sender<Receiver<TaskResponse>>) -> Self {
|
||||
let sender = Processor::start(runner);
|
||||
|
||||
Self {
|
||||
compute_sender: sender,
|
||||
writer_sender,
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_operation(&self, op: Box<OperationDescription>) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterOperation(op))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterTensor(tensor_id, data))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn register_orphan(&self, tensor_id: TensorId) {
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::RegisterOrphan(tensor_id))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn read_tensor(&self, id: ConnectionId, desc: TensorDescription) {
|
||||
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::ReadTensor(id, desc, callback_sender))
|
||||
.unwrap();
|
||||
|
||||
self.writer_sender.send(callback_rec).unwrap();
|
||||
}
|
||||
|
||||
pub fn sync(&self, id: ConnectionId) {
|
||||
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::Sync(id, callback_sender))
|
||||
.unwrap();
|
||||
|
||||
self.writer_sender.send(callback_rec).unwrap();
|
||||
}
|
||||
|
||||
// Ensure that all tasks are sent to the backend.
|
||||
//
|
||||
// It doesn't mean that the computation is done, but it means the backend has received the
|
||||
// tasks, which may be queued.
|
||||
pub fn fence_sync(&self) {
|
||||
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
|
||||
|
||||
self.compute_sender
|
||||
.send(ProcessorTask::Fence(callback_sender.clone()))
|
||||
.unwrap();
|
||||
|
||||
callback_rec.recv().unwrap();
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.compute_sender.send(ProcessorTask::Close).unwrap();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
mod task;
|
||||
pub(crate) use task::*;
|
|
@ -0,0 +1,68 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use burn_common::id::{IdGenerator, StreamId};
|
||||
use burn_tensor::{
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
TensorData,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[allow(missing_docs)]
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ConnectionId {
    // Position of the task within its stream.
    pub position: u64,
    // Client-side stream the task belongs to.
    pub stream_id: StreamId,
}
|
||||
|
||||
/// Unique identifier that can represent a session.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct SessionId {
    // Generated via `IdGenerator` — see `SessionId::new`.
    id: u64,
}
|
||||
|
||||
impl Display for SessionId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "SessionId({})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionId {
    /// Create a new [session id](SessionId).
    // NOTE(review): `dead_code` suggests only some feature combinations construct
    // ids here (presumably the client side) — confirm against callers.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            id: IdGenerator::generate(),
        }
    }
}
|
||||
|
||||
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum Task {
    // A compute task tagged with the connection that issued it.
    Compute(ComputeTask, ConnectionId),
    // First message of a connection: binds it to a session.
    Init(SessionId),
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum ComputeTask {
    // Execute an operation on the server backend.
    RegisterOperation(Box<OperationDescription>),
    // Upload tensor data under the given id.
    RegisterTensor(TensorId, TensorData),
    // Mark a tensor as no longer used by the client.
    RegisterOrphan(TensorId),
    // Download the described tensor; answered on the response socket.
    ReadTensor(TensorDescription),
    // Synchronize the server backend; answered on the response socket.
    SyncBackend,
}
|
||||
|
||||
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub struct TaskResponse {
    // Payload of the response.
    pub content: TaskResponseContent,
    // Identifies which request this response answers.
    pub id: ConnectionId,
}

#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponseContent {
    // Result of a `ComputeTask::ReadTensor` request.
    ReadTensor(TensorData),
    // Acknowledgment of a `ComputeTask::SyncBackend` request.
    SyncBackend,
}
|
|
@ -13,14 +13,15 @@ version.workspace = true
|
|||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
std = ["burn-tensor/std", "burn-common/std"]
|
||||
doc = ["default"]
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]}
|
||||
burn-common = { path = "../burn-common", version = "0.16.0", default-features = false}
|
||||
hashbrown = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [
|
||||
|
@ -31,7 +32,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features =
|
|||
] }
|
||||
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0" }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
|
|
|
@ -14,26 +14,3 @@ impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type)
|
||||
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B1>>>
|
||||
// for RouterTensor<MultiRunnerClient2<B1, B2>>
|
||||
// {
|
||||
// fn from(value: RouterTensor<Runner<B1>>) -> Self {
|
||||
// RouterTensor {
|
||||
// desc: value.desc,
|
||||
// client: MultiRunnerClient2::RunnerClient1(value.client),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B2>>>
|
||||
// for RouterTensor<MultiRunnerClient2<B1, B2>>
|
||||
// {
|
||||
// fn from(value: RouterTensor<Runner<B2>>) -> Self {
|
||||
// RouterTensor {
|
||||
// desc: value.desc,
|
||||
// client: MultiRunnerClient2::RunnerClient2(value.client),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -58,7 +58,8 @@ where
|
|||
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
|
||||
ReprBackend<Handle = B::Handle>,
|
||||
{
|
||||
pub(crate) fn new(device: B::Device) -> Self {
|
||||
/// Create a new runner.
|
||||
pub fn new(device: B::Device) -> Self {
|
||||
Self {
|
||||
context: Arc::new(Mutex::new(RunnerContext {
|
||||
handles: HandleContainer::new(),
|
||||
|
@ -90,7 +91,29 @@ where
|
|||
RouterTensor::new(id, shape, dtype, client)
|
||||
}
|
||||
|
||||
pub(crate) fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
|
||||
/// Register a tensor from its data and id.
|
||||
pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) {
|
||||
let mut ctx = self.context.lock();
|
||||
let dtype = data.dtype;
|
||||
|
||||
if dtype.is_float() {
|
||||
let tensor = B::float_from_data(data, &self.device);
|
||||
ctx.handles.register_float_tensor::<B>(&id, tensor)
|
||||
} else if dtype.is_int() {
|
||||
let tensor = B::int_from_data(data, &self.device);
|
||||
ctx.handles.register_int_tensor::<B>(&id, tensor)
|
||||
} else if dtype.is_bool() {
|
||||
let tensor = B::bool_from_data(data, &self.device);
|
||||
ctx.handles.register_bool_tensor::<B>(&id, tensor)
|
||||
} else if let DType::QFloat(_) = dtype {
|
||||
todo!();
|
||||
}
|
||||
|
||||
core::mem::drop(ctx);
|
||||
}
|
||||
|
||||
/// Register a tensor and returns its description.
|
||||
pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
|
||||
let mut ctx = self.context.lock();
|
||||
let id = ctx.create_empty_handle();
|
||||
let shape = data.shape.clone();
|
||||
|
@ -119,11 +142,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn register_empty_tensor_desc(
|
||||
&self,
|
||||
shape: Vec<usize>,
|
||||
dtype: DType,
|
||||
) -> TensorDescription {
|
||||
/// Register an empty tensor and returns its description.
|
||||
pub fn register_empty_tensor_desc(&self, shape: Vec<usize>, dtype: DType) -> TensorDescription {
|
||||
let mut ctx = self.context.lock();
|
||||
let id = ctx.create_empty_handle();
|
||||
core::mem::drop(ctx);
|
||||
|
|
|
@ -21,7 +21,8 @@ pub struct RouterTensor<C: RunnerClient> {
|
|||
}
|
||||
|
||||
impl<C: RunnerClient> RouterTensor<C> {
|
||||
pub(crate) fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
|
||||
/// Create a new router tensor.
|
||||
pub fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
|
||||
Self {
|
||||
id,
|
||||
shape,
|
||||
|
|
|
@ -55,6 +55,8 @@ ndarray = ["burn-core/ndarray"]
|
|||
tch = ["burn-core/tch"]
|
||||
wgpu = ["burn-core/wgpu"]
|
||||
wgpu-spirv = ["burn-core/wgpu-spirv"]
|
||||
remote = ["burn-core/remote"]
|
||||
server = ["burn-core/server"]
|
||||
|
||||
# Network utils
|
||||
network = ["burn-core/network"]
|
||||
|
|
|
@ -92,6 +92,7 @@
|
|||
//! - `autodiff`: Makes available the Autodiff backend
|
||||
//! - Others:
|
||||
//! - `std`: Activates the standard library (deactivate for no_std)
|
||||
//! - `server`: Enables the remote server.
|
||||
//! - `network`: Enables network utilities (currently, only a file downloader with progress bar)
|
||||
//! - `experimental-named-tensor`: Enables named tensors (experimental)
|
||||
//!
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
|
||||
use burn::backend::{Autodiff, Wgpu};
|
||||
|
||||
fn main() {
|
||||
custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());
|
||||
custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "server"
|
||||
publish = false
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["wgpu"]
|
||||
cuda-jit = ["burn/cuda-jit"]
|
||||
wgpu = ["burn/wgpu"]
|
||||
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
|
||||
[dependencies]
|
||||
cfg-if = { workspace = true }
|
||||
burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] }
|
|
@ -0,0 +1,3 @@
|
|||
fn main() {
    // Start the remote-backend websocket server; blocks until the process exits.
    server::start();
}
|
|
@ -0,0 +1,20 @@
|
|||
/// Entry point for the standalone server binary.
///
/// Reads the listening port from the `REMOTE_BACKEND_PORT` environment variable
/// (default: 3000), then starts the burn remote-backend server on the backend
/// selected at compile time via cargo features.
pub fn start() {
    let port = std::env::var("REMOTE_BACKEND_PORT")
        .map(|port| match port.parse::<u16>() {
            Ok(val) => val,
            // An explicitly-set but unparsable port is a configuration error.
            Err(err) => panic!("Invalid port, got {port} with error {err}"),
        })
        .unwrap_or(3000);

    cfg_if::cfg_if! {
        // Note: `ndarray` takes precedence over `wgpu` (the crate's default
        // feature) when both are enabled.
        if #[cfg(feature = "ndarray")]{
            burn::server::start::<burn::backend::NdArray>(Default::default(), port);
        } else if #[cfg(feature = "wgpu")] {
            burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
        } else if #[cfg(feature = "cuda-jit")]{
            burn::server::start::<burn::backend::CudaJit>(Default::default(), port);
        } else {
            panic!("No backend selected, can't start server on port {port}");
        }
    }
}
|
|
@ -17,6 +17,7 @@ tch-cpu = ["burn/tch"]
|
|||
tch-gpu = ["burn/tch"]
|
||||
wgpu = ["burn/wgpu"]
|
||||
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
|
||||
remote = ["burn/remote"]
|
||||
cuda-jit = ["burn/cuda-jit"]
|
||||
hip-jit = ["burn/hip-jit"]
|
||||
|
||||
|
|
|
@ -91,6 +91,16 @@ mod wgpu {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
mod remote {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::{Autodiff, RemoteBackend};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<RemoteBackend>>(vec![Default::default()]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda-jit")]
|
||||
mod cuda_jit {
|
||||
use crate::{launch, ElemType};
|
||||
|
@ -129,4 +139,6 @@ fn main() {
|
|||
cuda_jit::run();
|
||||
#[cfg(feature = "hip-jit")]
|
||||
hip_jit::run();
|
||||
#[cfg(feature = "remote")]
|
||||
remote::run();
|
||||
}
|
||||
|
|
|
@ -96,8 +96,6 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
|
|||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.metric_train_numeric(LearningRateMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(devices)
|
||||
|
|
Loading…
Reference in New Issue