Remote Backend (#2463)

Nathaniel Simard 2024-11-07 15:49:21 -05:00 committed by GitHub
parent 9b9b03c959
commit 099b6dcae0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 1585 additions and 39 deletions

Cargo.lock generated
View File

@ -229,6 +229,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "async-trait"
version = "0.1.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "atoi"
version = "2.0.0"
@ -296,6 +307,64 @@ dependencies = [
"arrayvec",
]
[[package]]
name = "axum"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [
"async-trait",
"axum-core",
"base64 0.22.1",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.5.0",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper 1.0.1",
"tokio",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 1.0.1",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backend-comparison"
version = "0.16.0"
@ -524,6 +593,7 @@ dependencies = [
"indicatif",
"rayon",
"reqwest 0.12.9",
"serde",
"tokio",
"web-time",
]
@ -542,6 +612,7 @@ dependencies = [
"burn-derive",
"burn-hip",
"burn-ndarray",
"burn-remote",
"burn-tch",
"burn-tensor",
"burn-wgpu",
@ -647,6 +718,7 @@ dependencies = [
"derive-new 0.7.0",
"half",
"log",
"paste",
]
[[package]]
@ -729,15 +801,38 @@ dependencies = [
"serde",
]
[[package]]
name = "burn-remote"
version = "0.16.0"
dependencies = [
"axum",
"burn-common",
"burn-remote",
"burn-router",
"burn-tensor",
"derive-new 0.7.0",
"futures-util",
"log",
"rmp-serde",
"serde",
"serde_bytes",
"tokio",
"tokio-tungstenite",
"tracing-core",
"tracing-subscriber",
]
[[package]]
name = "burn-router"
version = "0.16.0"
dependencies = [
"burn-autodiff",
"burn-common",
"burn-ndarray",
"burn-tensor",
"burn-wgpu",
"hashbrown 0.15.0",
"log",
"spin",
]
@ -3149,6 +3244,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@ -3693,6 +3789,12 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
@ -6098,6 +6200,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_rusqlite"
version = "0.36.0"
@ -6154,6 +6266,14 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "server"
version = "0.16.0"
dependencies = [
"burn",
"cfg-if",
]
[[package]]
name = "sha1"
version = "0.10.6"
@ -6874,6 +6994,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.12"
@ -6936,6 +7068,28 @@ dependencies = [
"zip 0.6.6",
]
[[package]]
name = "tower"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper 0.1.2",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
@ -6978,6 +7132,7 @@ version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@ -7051,6 +7206,24 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.1.0",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.17.0"
@ -7174,6 +7347,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8parse"
version = "0.2.2"

View File

@ -119,6 +119,7 @@ bincode = { version = "2.0.0-rc.3", features = [
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.7.0", default-features = false }
cfg-if = "1.0.0"
blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [

View File

@ -23,6 +23,8 @@ getrandom = { workspace = true, features = ["js"] }
web-time = { version = "1.1.0" }
[dependencies]
serde = { workspace = true }
# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }

View File

@ -1,4 +1,5 @@
use crate::rand::gen_random;
use serde::{Deserialize, Serialize};
/// Simple ID generator.
pub struct IdGenerator {}
@ -64,3 +65,49 @@ mod tests {
assert_eq!(set.len(), EXPECTED_TOTAL_IDS);
}
}
/// Unique identifier that can represent a stream based on the current thread id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct StreamId {
/// The value representing the thread id.
pub value: u64,
}
impl StreamId {
/// Get the stream id for the current thread.
pub fn current() -> Self {
Self {
#[cfg(feature = "std")]
value: Self::from_current_thread(),
#[cfg(not(feature = "std"))]
value: 0,
}
}
#[cfg(feature = "std")]
fn from_current_thread() -> u64 {
use core::hash::Hash;
std::thread_local! {
static ID: std::cell::OnceCell::<u64> = const { std::cell::OnceCell::new() };
};
// Getting the current thread is expensive, so we cache the value into a thread local
// variable, which is very fast.
ID.with(|cell| {
*cell.get_or_init(|| {
// A way to get a thread id encoded as u64.
let mut hasher = std::hash::DefaultHasher::default();
let id = std::thread::current().id();
id.hash(&mut hasher);
std::hash::Hasher::finish(&hasher)
})
})
}
}
impl core::fmt::Display for StreamId {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("StreamId({:?})", self.value))
}
}

View File

@ -39,6 +39,8 @@ doc = [
"hip-jit",
"vision",
"autodiff",
"remote",
"server",
# Doc features
"burn-candle/doc",
"burn-common/doc",
@ -86,6 +88,8 @@ metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]
template = ["burn-wgpu?/template"]
remote = ["burn-remote/client"]
server = ["burn-remote/server"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
@ -131,6 +135,7 @@ burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true }
data-encoding = { workspace = true }
uuid = { workspace = true }

View File

@ -7,6 +7,11 @@ pub use ndarray::NdArray;
#[cfg(feature = "autodiff")]
pub use burn_autodiff as autodiff;
#[cfg(feature = "remote")]
pub use burn_remote as remote;
#[cfg(feature = "remote")]
pub use burn_remote::RemoteBackend;
#[cfg(feature = "autodiff")]
pub use burn_autodiff::Autodiff;

View File

@ -43,6 +43,9 @@ pub mod tensor;
/// Backend module.
pub mod backend;
#[cfg(feature = "server")]
pub use burn_remote::server;
extern crate alloc;
#[cfg(all(

View File

@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
version.workspace = true
[features]
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
autotune = ["burn-jit/autotune"]
default = ["fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]

View File

@ -34,6 +34,7 @@ derive-new = { workspace = true }
burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [
"export_tests",
] }
paste = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]

View File

@ -24,6 +24,7 @@ mod tests {
use burn_jit::JitBackend;
pub type TestRuntime = cubecl::hip::HipRuntime;
pub use half::{bf16, f16};
burn_jit::testgen_all!();
}

View File

@ -0,0 +1,51 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Backend router decorator over websocket."
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-remote"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote"
documentation = "https://docs.rs/burn-router-remote"
version.workspace = true
[features]
default = []
doc = []
client = ["tokio-tungstenite"]
server = ["axum", "tracing-core", "tracing-subscriber"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]}
burn-common = { path = "../burn-common", version = "0.16.0", default-features = true}
burn-router = { path = "../burn-router", version = "0.16.0", default-features = true}
# Basic dependencies
derive-new = { workspace = true }
log = { workspace = true }
# Shared dependencies
tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] }
serde = { workspace = true, features = ["derive"] }
serde_bytes = { workspace = true }
rmp-serde = { workspace = true }
futures-util = { version = "0.3" }
# Client dependencies
tokio-tungstenite = { version = "0.24", optional = true }
# Server dependencies
axum = { version = "0.7.5", features = ["ws"], optional = true }
tracing-core = { workspace = true, optional = true }
tracing-subscriber = { workspace = true, optional = true }
[dev-dependencies]
# We activate the client and server features during dev.
burn-remote = { path = ".", version = "0.16.0", features = ["client", "server"] }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]

View File

View File

@ -0,0 +1,99 @@
use super::worker::{ClientRequest, ClientWorker};
use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent};
use burn_common::id::StreamId;
use burn_tensor::repr::TensorId;
use std::{
future::Future,
sync::{atomic::AtomicU64, Arc},
};
use tokio::sync::mpsc::Sender;
pub use super::WsDevice;
#[derive(Clone)]
pub struct WsClient {
pub(crate) device: WsDevice,
pub(crate) sender: Arc<WsSender>,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}
impl WsClient {
pub fn init(device: WsDevice) -> Self {
ClientWorker::start(device)
}
pub(crate) fn new(
device: WsDevice,
sender: Sender<ClientRequest>,
runtime: Arc<tokio::runtime::Runtime>,
) -> Self {
Self {
device,
runtime,
sender: Arc::new(WsSender {
sender,
position_counter: AtomicU64::new(0),
tensor_id_counter: AtomicU64::new(0),
}),
}
}
}
pub(crate) struct WsSender {
sender: Sender<ClientRequest>,
position_counter: AtomicU64,
tensor_id_counter: AtomicU64,
}
impl WsSender {
pub(crate) fn send(&self, task: ComputeTask) -> impl Future<Output = ()> + Send {
let position = self
.position_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let stream_id = StreamId::current();
let sender = self.sender.clone();
async move {
sender
.send(ClientRequest::WithoutCallback(Task::Compute(
task,
ConnectionId::new(position, stream_id),
)))
.await
.unwrap();
}
}
pub(crate) fn new_tensor_id(&self) -> TensorId {
let val = self
.tensor_id_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
TensorId::new(val)
}
pub(crate) fn send_callback(
&self,
task: ComputeTask,
) -> impl Future<Output = TaskResponseContent> + Send {
let position = self
.position_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let stream_id = StreamId::current();
let sender = self.sender.clone();
let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1);
async move {
sender
.send(ClientRequest::WithSyncCallback(
Task::Compute(task, ConnectionId::new(position, stream_id)),
callback_sender,
))
.await
.unwrap();
match callback_recv.recv().await {
Some(val) => val,
None => panic!(""),
}
}
}
}

View File

@ -0,0 +1,45 @@
use burn_router::{RouterTensor, RunnerChannel, TensorHandle};
use burn_tensor::repr::TensorDescription;
use super::{
runner::{WsBridge, WsDevice},
WsClient,
};
/// A remote channel that communicates with the backend server over websocket.
#[derive(Clone)]
pub struct WsChannel;
impl RunnerChannel for WsChannel {
type Device = WsDevice;
type Bridge = WsBridge;
type Client = WsClient;
type FloatElem = f32;
type IntElem = i32;
fn name() -> String {
"remote".into()
}
fn init_client(device: &Self::Device) -> Self::Client {
WsClient::init(device.clone())
}
fn get_tensor_handle(
_tensor: &TensorDescription,
_client: &Self::Client,
) -> TensorHandle<Self::Bridge> {
panic!("Unsupported")
}
fn register_tensor(
_client: &Self::Client,
_handle: TensorHandle<Self::Bridge>,
_shape: Vec<usize>,
_dtype: burn_tensor::DType,
) -> RouterTensor<Self::Client> {
panic!("Unsupported")
}
}

View File

@ -0,0 +1,8 @@
mod base;
mod channel;
mod runner;
mod worker;
pub use base::*;
pub use channel::*;
pub use runner::WsDevice;

View File

@ -0,0 +1,175 @@
use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient};
use burn_tensor::{
backend::{DeviceId, DeviceOps},
DType, TensorData,
};
use std::sync::Arc;
use crate::shared::{ComputeTask, TaskResponseContent};
use super::WsClient;
// It is very important to block on any request made with the sender, since ordering is crucial
// when registering operations or creating tensors.
//
// The overhead is minimal, since we only wait for the task to be sent to the async
// channel, not for it to reach the websocket server, let alone be processed by the server.
impl RunnerClient for WsClient {
type Device = WsDevice;
fn register(&self, op: burn_tensor::repr::OperationDescription) {
let fut = self
.sender
.send(ComputeTask::RegisterOperation(Box::new(op)));
self.runtime.block_on(fut);
}
fn read_tensor(
&self,
tensor: burn_tensor::repr::TensorDescription,
) -> impl std::future::Future<Output = TensorData> + Send {
// Important for ordering: the future must be created synchronously.
let fut = self.sender.send_callback(ComputeTask::ReadTensor(tensor));
async move {
match fut.await {
TaskResponseContent::ReadTensor(data) => data,
_ => panic!("Invalid message type"),
}
}
}
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
let id = self.sender.new_tensor_id();
let shape = data.shape.clone();
let dtype = data.dtype;
let fut = self.sender.send(ComputeTask::RegisterTensor(id, data));
self.runtime.block_on(fut);
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
}
fn register_empty_tensor(
&self,
shape: Vec<usize>,
dtype: burn_tensor::DType,
) -> RouterTensor<Self> {
let id = self.sender.new_tensor_id();
RouterTensor::new(Arc::new(id), shape, dtype, self.clone())
}
fn register_float_tensor(
&self,
shape: Vec<usize>,
_full_precision: bool,
) -> RouterTensor<Self> {
self.register_empty_tensor(shape, DType::F32)
}
fn device(&self) -> Self::Device {
self.device.clone()
}
fn register_orphan(&self, id: &burn_tensor::repr::TensorId) {
let fut = self.sender.send(ComputeTask::RegisterOrphan(*id));
self.runtime.block_on(fut);
}
fn sync(&self) {
// Important for ordering: the future must be created synchronously.
let fut = self.sender.send_callback(ComputeTask::SyncBackend);
let fut = async move {
match fut.await {
TaskResponseContent::SyncBackend => {}
_ => panic!("Invalid message type"),
};
};
self.runtime.block_on(fut)
}
fn seed(&self, _seed: u64) {
// TODO
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
/// The device contains the connection information of the server.
pub struct WsDevice {
pub(crate) address: Arc<String>,
}
impl WsDevice {
/// Create a device from a URL.
pub fn new(url: &str) -> Self {
let address = if url.starts_with("ws://") {
url.to_string()
} else {
format!("ws://{url}")
};
Self {
address: Arc::new(address),
}
}
}
impl Default for WsDevice {
fn default() -> Self {
let address = match std::env::var("BURN_REMOTE_ADDRESS") {
Ok(address) => address,
Err(_) => String::from("ws://127.0.0.1:3000"),
};
Self {
address: Arc::new(address),
}
}
}
impl DeviceOps for WsDevice {
fn id(&self) -> DeviceId {
DeviceId {
type_id: 0,
index_id: 0,
}
}
}
pub struct WsBridge;
impl MultiBackendBridge for WsBridge {
type TensorHandle = TensorData;
type Device = WsDevice;
fn change_backend_float(
tensor: Self::TensorHandle,
_shape: burn_tensor::Shape,
_target_device: &Self::Device,
) -> Self::TensorHandle {
tensor
}
fn change_backend_int(
tensor: Self::TensorHandle,
_shape: burn_tensor::Shape,
_target_device: &Self::Device,
) -> Self::TensorHandle {
tensor
}
fn change_backend_bool(
tensor: Self::TensorHandle,
_shape: burn_tensor::Shape,
_target_device: &Self::Device,
) -> Self::TensorHandle {
tensor
}
}

View File

@ -0,0 +1,140 @@
use super::{runner::WsDevice, WsClient};
use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};
use futures_util::{SinkExt, StreamExt};
use std::{collections::HashMap, sync::Arc};
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::protocol::{Message, WebSocketConfig},
};
pub type CallbackSender = tokio::sync::mpsc::Sender<TaskResponseContent>;
pub enum ClientRequest {
WithSyncCallback(Task, CallbackSender),
WithoutCallback(Task),
}
#[derive(Default)]
pub(crate) struct ClientWorker {
requests: HashMap<ConnectionId, CallbackSender>,
}
impl ClientWorker {
async fn on_response(&mut self, response: TaskResponse) {
match self.requests.remove(&response.id) {
Some(request) => {
request.send(response.content).await.unwrap();
}
None => {
panic!("Can't ignore message from the server.");
}
}
}
fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {
self.requests.insert(id, callback);
}
}
impl ClientWorker {
pub fn start(device: WsDevice) -> WsClient {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_io()
.build()
.unwrap(),
);
let (sender, mut rec) = tokio::sync::mpsc::channel(10);
let address_request = format!("{}/{}", device.address.as_str(), "request");
let address_response = format!("{}/{}", device.address.as_str(), "response");
const MB: usize = 1024 * 1024;
#[allow(deprecated)]
runtime.spawn(async move {
log::info!("Connecting to {address_request} ...");
let (mut stream_request, _) = connect_async_with_config(
address_request.clone(),
Some(WebSocketConfig {
max_send_queue: None,
write_buffer_size: 0,
max_write_buffer_size: usize::MAX,
max_message_size: None,
max_frame_size: Some(MB * 512),
accept_unmasked_frames: true,
}),
true,
)
.await
.expect("Failed to connect");
let (mut stream_response, _) = connect_async_with_config(
address_response,
Some(WebSocketConfig {
max_send_queue: None,
write_buffer_size: 0,
max_write_buffer_size: usize::MAX,
max_message_size: None,
max_frame_size: Some(MB * 512),
accept_unmasked_frames: true,
}),
true,
)
.await
.expect("Failed to connect");
let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::default()));
// Init the connection.
let session_id = SessionId::new();
let bytes = rmp_serde::to_vec(&Task::Init(session_id)).expect("Can serialize tasks to bytes.");
stream_request.send(Message::Binary(bytes.clone())).await.expect("Can send the message on the websocket.");
stream_response.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");
// Websocket async worker loading callback from the server.
let state_ws = state.clone();
tokio::spawn(async move {
while let Some(msg) = stream_response.next().await {
let msg = match msg {
Ok(msg) => msg,
Err(err) => panic!("An error happened while receiving messages from the websocket: {err:?}"),
};
match msg {
Message::Binary(bytes) => {
let response: TaskResponse = rmp_serde::from_slice(&bytes).expect("Can deserialize messages from the websocket.");
let mut state = state_ws.lock().await;
state.on_response(response).await;
}
Message::Close(_) => {
log::warn!("Closed connection");
return;
},
_ => panic!("Unsupported websocket message: {msg:?}"),
};
}
});
// Channel async worker sending operations to the server.
tokio::spawn(async move {
while let Some(req) = rec.recv().await {
let task = match req {
ClientRequest::WithSyncCallback(task, callback) => {
let mut state = state.lock().await;
if let Task::Compute(_content, id) = &task {
state.register_callback(*id, callback);
}
task
}
ClientRequest::WithoutCallback(task) => task,
};
let bytes = rmp_serde::to_vec(&task).expect("Can serialize tasks to bytes.");
stream_request.send(Message::Binary(bytes)).await.expect("Can send the message on the websocket.");
}
});
});
WsClient::new(device, sender, runtime)
}
}

View File

@ -0,0 +1,37 @@
#[macro_use]
extern crate derive_new;
#[cfg(feature = "client")]
pub(crate) mod client;
#[cfg(feature = "server")]
pub mod server;
pub(crate) mod shared;
#[cfg(feature = "client")]
mod __client {
use super::*;
use burn_router::BackendRouter;
use client::WsChannel;
/// The remote backend allows you to run computation on a remote device.
///
/// Make sure there is a running server before trying to connect to it.
///
/// ```rust, ignore
/// fn main() {
/// let device = Default::default();
/// let port = 3000;
///
/// // You need to activate the `server` feature flag to have access to this function.
/// burn::server::start::<burn::backend::Wgpu>(device, port);
/// }
///```
pub type RemoteBackend = BackendRouter<WsChannel>;
pub use client::WsDevice as RemoteDevice;
}
#[cfg(feature = "client")]
pub use __client::*;

View File

@ -0,0 +1,196 @@
use std::{net::SocketAddr, sync::Arc};
use axum::{
extract::{
ws::{self, WebSocket, WebSocketUpgrade},
State,
},
response::IntoResponse,
routing::any,
Router,
};
use burn_tensor::{
backend::{Backend, BackendBridge},
repr::ReprBackend,
Device,
};
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::prelude::*;
use tracing_subscriber::{filter::filter_fn, registry};
use crate::shared::{ComputeTask, Task};
use super::session::SessionManager;
#[derive(Clone)]
pub struct WsServer<B: ReprBackend> {
state: Arc<SessionManager<B>>,
}
impl<B: ReprBackend> WsServer<B>
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
/// Start the server on the given address.
pub async fn start(device: Device<B>, port: u16) {
let layer = tracing_subscriber::fmt::layer()
.with_filter(LevelFilter::INFO)
.with_filter(filter_fn(|m| {
if let Some(path) = m.module_path() {
// The wgpu crate is logging too much, so we skip `info` level.
if path.starts_with("wgpu") && *m.level() >= Level::INFO {
return false;
}
}
true
}));
registry().with(layer).init();
let address = format!("0.0.0.0:{port}");
log::info!("Start server {address} on device {device:?}");
let state = SessionManager::<B>::new(device);
let state = Self {
state: Arc::new(state),
};
// build our application with some routes
let app = Router::new()
.route("/response", any(Self::handler_response))
.route("/request", any(Self::handler_request))
.with_state(state);
// run it with hyper
let listener = tokio::net::TcpListener::bind(address).await.unwrap();
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
}
async fn handler_response(
ws: WebSocketUpgrade,
State(state): State<Self>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| state.handle_socket_response(socket))
}
async fn handler_request(ws: WebSocketUpgrade, State(state): State<Self>) -> impl IntoResponse {
ws.on_upgrade(move |socket| state.handle_socket_request(socket))
}
async fn handle_socket_response(self, mut socket: WebSocket) {
log::info!("[Response Handler] On new connection.");
let packet = socket.recv().await;
let msg = match packet {
Some(msg) => msg,
None => {
log::info!("Still no message");
panic!("");
}
};
if let Ok(ws::Message::Binary(bytes)) = msg {
let task = match rmp_serde::from_slice::<Task>(&bytes) {
Ok(val) => val,
Err(err) => {
log::info!("Only bytes messages are supported {err:?}");
panic!("");
}
};
let id = match task {
Task::Init(id) => id,
_ => panic!(""),
};
let receiver = self.state.register_responder(id).await;
log::info!("Response handler connection active");
while let Ok(callback) = receiver.recv() {
let response = callback.recv().unwrap();
let bytes = rmp_serde::to_vec(&response).unwrap();
socket.send(ws::Message::Binary(bytes)).await.unwrap();
}
} else {
panic!("");
}
}
async fn handle_socket_request(self, mut socket: WebSocket) {
log::info!("[Request Handler] On new connection.");
let mut session_id = None;
loop {
let packet = socket.recv().await;
let msg = match packet {
Some(msg) => msg,
None => {
log::info!("Still no message");
continue;
}
};
if let Ok(ws::Message::Binary(bytes)) = msg {
let task = match rmp_serde::from_slice::<Task>(&bytes) {
Ok(val) => val,
Err(err) => {
log::info!("Only bytes message in the json format are supported {err:?}");
break;
}
};
let (stream, connection_id, task) =
match self.state.stream(&mut session_id, task).await {
Some(val) => val,
None => {
log::info!("Ops session activated {session_id:?}");
continue;
}
};
match task {
ComputeTask::RegisterOperation(op) => {
stream.register_operation(op);
}
ComputeTask::RegisterTensor(id, data) => {
stream.register_tensor(id, data);
}
ComputeTask::RegisterOrphan(id) => {
stream.register_orphan(id);
}
ComputeTask::ReadTensor(tensor) => {
stream.read_tensor(connection_id, tensor);
}
ComputeTask::SyncBackend => {
stream.sync(connection_id);
}
}
} else {
log::info!("Not a binary message, closing, received {msg:?}");
break;
};
}
log::info!("Closing connection");
self.state.close(session_id).await;
}
}
/// Start the server on the given port and [device](Device).
#[tokio::main]
pub async fn start<B: ReprBackend>(device: Device<B>, port: u16)
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
WsServer::<B>::start(device, port).await;
}

View File

@ -0,0 +1,7 @@
pub(crate) mod processor;
pub(crate) mod session;
pub(crate) mod stream;
mod base;
pub use base::start;

View File

@ -0,0 +1,84 @@
use burn_router::{Runner, RunnerClient};
use burn_tensor::{
backend::{Backend, BackendBridge},
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
TensorData,
};
use core::marker::PhantomData;
use std::sync::mpsc::Sender;
use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent};
/// The goal of the processor is to asynchronously process compute tasks on its own thread.
pub struct Processor<B: ReprBackend> {
p: PhantomData<B>,
}
pub type Callback<M> = Sender<M>;
pub enum ProcessorTask {
RegisterOperation(Box<OperationDescription>),
RegisterTensor(TensorId, TensorData),
ReadTensor(ConnectionId, TensorDescription, Callback<TaskResponse>),
Sync(ConnectionId, Callback<TaskResponse>),
Fence(Callback<()>),
RegisterOrphan(TensorId),
Close,
}
impl<B: ReprBackend> Processor<B>
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
pub fn start(runner: Runner<B>) -> Sender<ProcessorTask> {
let (sender, rec) = std::sync::mpsc::channel();
std::thread::spawn(move || {
for item in rec.iter() {
match item {
ProcessorTask::RegisterOperation(op) => {
runner.register(*op);
}
ProcessorTask::RegisterOrphan(id) => {
runner.register_orphan(&id);
}
ProcessorTask::Sync(id, callback) => {
runner.sync();
callback
.send(TaskResponse {
content: TaskResponseContent::SyncBackend,
id,
})
.unwrap();
}
ProcessorTask::RegisterTensor(id, data) => {
runner.register_tensor_data_id(id, data);
}
ProcessorTask::ReadTensor(id, tensor, callback) => {
let tensor = burn_common::future::block_on(runner.read_tensor(tensor));
callback
.send(TaskResponse {
content: TaskResponseContent::ReadTensor(tensor),
id,
})
.unwrap();
}
ProcessorTask::Close => {
let device = runner.device();
runner.sync();
core::mem::drop(runner);
B::sync(&device);
return;
}
ProcessorTask::Fence(sender) => {
sender.send(()).unwrap();
}
}
}
});
sender
}
}

View File

@ -0,0 +1,242 @@
use burn_common::id::StreamId;
use burn_router::Runner;
use burn_tensor::{
backend::{Backend, BackendBridge},
repr::{ReprBackend, TensorDescription, TensorId, TensorStatus},
Device,
};
use std::{
collections::HashMap,
sync::mpsc::{Receiver, Sender},
};
use tokio::sync::Mutex;
use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};
use super::stream::Stream;
/// A session manager controls the creation of sessions.
///
/// Each session manages its own streams, spawning one thread per stream to mimic the behavior
/// a native backend would have.
pub struct SessionManager<B: ReprBackend> {
runner: Runner<B>,
sessions: tokio::sync::Mutex<HashMap<SessionId, Session<B>>>,
}
struct Session<B: ReprBackend> {
runner: Runner<B>,
tensors: HashMap<TensorId, Vec<StreamId>>,
streams: HashMap<StreamId, Stream<B>>,
sender: Sender<Receiver<TaskResponse>>,
receiver: Option<Receiver<Receiver<TaskResponse>>>,
}
impl<B: ReprBackend> SessionManager<B>
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
pub fn new(device: Device<B>) -> Self {
Self {
runner: Runner::new(device),
sessions: Mutex::new(Default::default()),
}
}
/// Register a new responder for the session. Only one responder can exist for a session for
/// now.
pub async fn register_responder(
&self,
session_id: SessionId,
) -> Receiver<Receiver<TaskResponse>> {
log::info!("Register responder for session {session_id}");
let mut sessions = self.sessions.lock().await;
self.register_session(&mut sessions, session_id);
let session = sessions.get_mut(&session_id).unwrap();
session.init_responder()
}
/// Get the stream for the current session and task.
pub async fn stream(
&self,
session_id: &mut Option<SessionId>,
task: Task,
) -> Option<(Stream<B>, ConnectionId, ComputeTask)> {
let mut sessions = self.sessions.lock().await;
let session_id = match session_id {
Some(id) => *id,
None => match task {
Task::Init(id) => {
log::info!("Init requester for session {id}");
*session_id = Some(id);
self.register_session(&mut sessions, id);
return None;
}
_ => panic!("The first message should initialize the session"),
},
};
match sessions.get_mut(&session_id) {
Some(session) => {
let (task, connection_id) = match task {
Task::Compute(task, connection_id) => (task, connection_id),
_ => panic!("Only support compute tasks."),
};
let stream = session.select(connection_id.stream_id, &task);
Some((stream, connection_id, task))
}
None => {
panic!("To be initialized");
}
}
}
/// Close the session with the given id.
pub async fn close(&self, session_id: Option<SessionId>) {
if let Some(id) = session_id {
let mut sessions = self.sessions.lock().await;
if let Some(session) = sessions.get_mut(&id) {
session.close();
}
}
}
fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B>>, id: SessionId) {
sessions.entry(id).or_insert_with(|| {
log::info!("Creating a new session {id}");
Session::new(self.runner.clone())
});
}
}
impl<B: ReprBackend> Session<B>
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
fn new(runner: Runner<B>) -> Self {
let (sender, receiver) = std::sync::mpsc::channel();
Self {
runner,
tensors: Default::default(),
streams: Default::default(),
sender,
receiver: Some(receiver),
}
}
fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {
let mut receiver = None;
core::mem::swap(&mut receiver, &mut self.receiver);
receiver.expect("Only one responder per session is possible.")
}
/// Select the current [stream](Stream) based on the given task.
fn select(&mut self, stream_id: StreamId, task: &ComputeTask) -> Stream<B> {
// We have to check every stream involved in the last operation, making
// sure the backend is up-to-date with those operations.
//
// 1. We update the tensor status of all tensors in the task.
// 2. We don't keep track of tensors that are used for the last time.
let mut fences = Vec::new();
for (tensor_id, status) in task.tensors_info() {
let tensor_stream_ids = match self.tensors.get(&tensor_id) {
Some(val) => val,
None => {
if status != TensorStatus::ReadWrite {
// Add the first stream that created the tensor that may be used by other
// streams later.
self.register_tensor(tensor_id, stream_id);
}
continue;
}
};
let current_stream_already_synced = tensor_stream_ids.contains(&stream_id);
if !current_stream_already_synced {
// We only need to sync to the first stream that created the tensor.
if let Some(id) = tensor_stream_ids.iter().next() {
fences.push(*id);
}
}
// We add the stream to the list of updated streams to avoid needing to flush other
// operations that might use this tensor.
self.register_tensor(tensor_id, stream_id);
// If the tensor has the status `read_write`, it means no other stream can reuse it
// afterward, so we remove it from the state.
if status == TensorStatus::ReadWrite {
self.tensors.remove(&tensor_id);
}
}
// Cleanup orphans.
if let ComputeTask::RegisterOrphan(tensor_id) = task {
self.tensors.remove(tensor_id);
}
// We have to wait for the streams to be updated.
for stream_id in fences {
if let Some(stream) = self.streams.get(&stream_id) {
stream.fence_sync();
}
}
// We return the stream.
match self.streams.get(&stream_id) {
Some(stream) => stream.clone(),
None => {
let stream = Stream::<B>::new(self.runner.clone(), self.sender.clone());
self.streams.insert(stream_id, stream.clone());
stream
}
}
}
fn register_tensor(&mut self, tensor_id: TensorId, stream_id: StreamId) {
match self.tensors.get_mut(&tensor_id) {
Some(ids) => {
ids.push(stream_id);
}
None => {
self.tensors.insert(tensor_id, vec![stream_id]);
}
}
}
// Close all streams created in the session.
fn close(&mut self) {
for (id, stream) in self.streams.drain() {
log::info!("Closing stream {id}");
stream.close();
}
}
}
impl ComputeTask {
fn tensors_info(&self) -> Vec<(TensorId, TensorStatus)> {
fn from_descriptions(desc: &[&TensorDescription]) -> Vec<(TensorId, TensorStatus)> {
desc.iter().map(|t| (t.id, t.status.clone())).collect()
}
match self {
ComputeTask::RegisterOperation(op) => from_descriptions(&op.nodes()),
ComputeTask::RegisterTensor(tensor_id, _tensor_data) => {
vec![(*tensor_id, TensorStatus::NotInit)]
}
ComputeTask::RegisterOrphan(tensor_id) => {
vec![(*tensor_id, TensorStatus::ReadWrite)]
}
ComputeTask::ReadTensor(tensor_description) => from_descriptions(&[tensor_description]),
ComputeTask::SyncBackend => vec![],
}
}
}

View File

@ -0,0 +1,94 @@
use core::marker::PhantomData;
use std::sync::mpsc::{Receiver, Sender};
use crate::shared::{ConnectionId, TaskResponse};
use super::processor::{Processor, ProcessorTask};
use burn_router::Runner;
use burn_tensor::{
backend::{Backend, BackendBridge},
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
TensorData,
};
/// A stream makes sure all operations registered are executed in the order they were sent to the
/// server, potentially waiting to reconstruct consistency.
#[derive(Clone)]
pub struct Stream<B: ReprBackend> {
compute_sender: Sender<ProcessorTask>,
writer_sender: Sender<Receiver<TaskResponse>>,
_p: PhantomData<B>,
}
impl<B: ReprBackend> Stream<B>
where
// Restrict full precision backend handle to be the same
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
pub fn new(runner: Runner<B>, writer_sender: Sender<Receiver<TaskResponse>>) -> Self {
let sender = Processor::start(runner);
Self {
compute_sender: sender,
writer_sender,
_p: PhantomData,
}
}
pub fn register_operation(&self, op: Box<OperationDescription>) {
self.compute_sender
.send(ProcessorTask::RegisterOperation(op))
.unwrap();
}
pub fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {
self.compute_sender
.send(ProcessorTask::RegisterTensor(tensor_id, data))
.unwrap()
}
pub fn register_orphan(&self, tensor_id: TensorId) {
self.compute_sender
.send(ProcessorTask::RegisterOrphan(tensor_id))
.unwrap()
}
pub fn read_tensor(&self, id: ConnectionId, desc: TensorDescription) {
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
self.compute_sender
.send(ProcessorTask::ReadTensor(id, desc, callback_sender))
.unwrap();
self.writer_sender.send(callback_rec).unwrap();
}
pub fn sync(&self, id: ConnectionId) {
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
self.compute_sender
.send(ProcessorTask::Sync(id, callback_sender))
.unwrap();
self.writer_sender.send(callback_rec).unwrap();
}
// Ensure that all tasks are sent to the backend.
//
// It doesn't mean that the computation is done, but it means the backend has received the
// tasks, which may be queued.
pub fn fence_sync(&self) {
let (callback_sender, callback_rec) = std::sync::mpsc::channel();
self.compute_sender
.send(ProcessorTask::Fence(callback_sender))
.unwrap();
callback_rec.recv().unwrap();
}
pub fn close(&self) {
self.compute_sender.send(ProcessorTask::Close).unwrap();
}
}

View File

@ -0,0 +1,2 @@
mod task;
pub(crate) use task::*;

View File

@ -0,0 +1,68 @@
use std::fmt::Display;
use burn_common::id::{IdGenerator, StreamId};
use burn_tensor::{
repr::{OperationDescription, TensorDescription, TensorId},
TensorData,
};
use serde::{Deserialize, Serialize};
#[allow(missing_docs)]
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ConnectionId {
pub position: u64,
pub stream_id: StreamId,
}
/// Unique identifier that can represent a session.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct SessionId {
id: u64,
}
impl Display for SessionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SessionId({})", self.id)
}
}
impl SessionId {
/// Create a new [session id](SessionId).
#[allow(dead_code)]
pub fn new() -> Self {
Self {
id: IdGenerator::generate(),
}
}
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum Task {
Compute(ComputeTask, ConnectionId),
Init(SessionId),
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum ComputeTask {
RegisterOperation(Box<OperationDescription>),
RegisterTensor(TensorId, TensorData),
RegisterOrphan(TensorId),
ReadTensor(TensorDescription),
SyncBackend,
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub struct TaskResponse {
pub content: TaskResponseContent,
pub id: ConnectionId,
}
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponseContent {
ReadTensor(TensorData),
SyncBackend,
}

View File

@ -13,14 +13,15 @@ version.workspace = true
[features]
default = ["std"]
std = []
std = ["burn-tensor/std", "burn-common/std"]
doc = ["default"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]}
burn-common = { path = "../burn-common", version = "0.16.0", default-features = false}
hashbrown = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [
@ -31,7 +32,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features =
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0" }
burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false }
[package.metadata.docs.rs]

View File

@ -14,26 +14,3 @@ impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
}
}
}
// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type)
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B1>>>
// for RouterTensor<MultiRunnerClient2<B1, B2>>
// {
// fn from(value: RouterTensor<Runner<B1>>) -> Self {
// RouterTensor {
// desc: value.desc,
// client: MultiRunnerClient2::RunnerClient1(value.client),
// }
// }
// }
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B2>>>
// for RouterTensor<MultiRunnerClient2<B1, B2>>
// {
// fn from(value: RouterTensor<Runner<B2>>) -> Self {
// RouterTensor {
// desc: value.desc,
// client: MultiRunnerClient2::RunnerClient2(value.client),
// }
// }
// }

View File

@ -58,7 +58,8 @@ where
<<B as Backend>::FullPrecisionBridge as BackendBridge<B>>::Target:
ReprBackend<Handle = B::Handle>,
{
pub(crate) fn new(device: B::Device) -> Self {
/// Create a new runner.
pub fn new(device: B::Device) -> Self {
Self {
context: Arc::new(Mutex::new(RunnerContext {
handles: HandleContainer::new(),
@ -90,7 +91,29 @@ where
RouterTensor::new(id, shape, dtype, client)
}
pub(crate) fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
/// Register a tensor from its data and id.
pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) {
let mut ctx = self.context.lock();
let dtype = data.dtype;
if dtype.is_float() {
let tensor = B::float_from_data(data, &self.device);
ctx.handles.register_float_tensor::<B>(&id, tensor)
} else if dtype.is_int() {
let tensor = B::int_from_data(data, &self.device);
ctx.handles.register_int_tensor::<B>(&id, tensor)
} else if dtype.is_bool() {
let tensor = B::bool_from_data(data, &self.device);
ctx.handles.register_bool_tensor::<B>(&id, tensor)
} else if let DType::QFloat(_) = dtype {
todo!();
}
core::mem::drop(ctx);
}
/// Register a tensor and return its description.
pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription {
let mut ctx = self.context.lock();
let id = ctx.create_empty_handle();
let shape = data.shape.clone();
@ -119,11 +142,8 @@ where
}
}
pub(crate) fn register_empty_tensor_desc(
&self,
shape: Vec<usize>,
dtype: DType,
) -> TensorDescription {
/// Register an empty tensor and return its description.
pub fn register_empty_tensor_desc(&self, shape: Vec<usize>, dtype: DType) -> TensorDescription {
let mut ctx = self.context.lock();
let id = ctx.create_empty_handle();
core::mem::drop(ctx);

View File

@ -21,7 +21,8 @@ pub struct RouterTensor<C: RunnerClient> {
}
impl<C: RunnerClient> RouterTensor<C> {
pub(crate) fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
/// Create a new router tensor.
pub fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
Self {
id,
shape,

View File

@ -55,6 +55,8 @@ ndarray = ["burn-core/ndarray"]
tch = ["burn-core/tch"]
wgpu = ["burn-core/wgpu"]
wgpu-spirv = ["burn-core/wgpu-spirv"]
remote = ["burn-core/remote"]
server = ["burn-core/server"]
# Network utils
network = ["burn-core/network"]

View File

@ -92,6 +92,7 @@
//! - `autodiff`: Makes available the Autodiff backend
//! - Others:
//! - `std`: Activates the standard library (deactivate for no_std)
//! - `server`: Enables the remote server.
//! - `network`: Enables network utilities (currently, only a file downloader with progress bar)
//! - `experimental-named-tensor`: Enables named tensors (experimental)
//!

View File

@ -1,5 +1,5 @@
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
use burn::backend::{Autodiff, Wgpu};
fn main() {
custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());
custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
}

View File

@ -0,0 +1,18 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "server"
publish = false
version.workspace = true
[features]
default = ["wgpu"]
cuda-jit = ["burn/cuda-jit"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
ndarray = ["burn/ndarray"]
[dependencies]
cfg-if = { workspace = true }
burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] }

View File

@ -0,0 +1,3 @@
fn main() {
server::start();
}

View File

@ -0,0 +1,20 @@
pub fn start() {
let port = std::env::var("REMOTE_BACKEND_PORT")
.map(|port| match port.parse::<u16>() {
Ok(val) => val,
Err(err) => panic!("Invalid port, got {port} with error {err}"),
})
.unwrap_or(3000);
cfg_if::cfg_if! {
if #[cfg(feature = "ndarray")]{
burn::server::start::<burn::backend::NdArray>(Default::default(), port);
} else if #[cfg(feature = "wgpu")] {
burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
} else if #[cfg(feature = "cuda-jit")]{
burn::server::start::<burn::backend::CudaJit>(Default::default(), port);
} else {
panic!("No backend selected, can't start server on port {port}");
}
}
}

View File

@ -17,6 +17,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
remote = ["burn/remote"]
cuda-jit = ["burn/cuda-jit"]
hip-jit = ["burn/hip-jit"]

View File

@ -91,6 +91,16 @@ mod wgpu {
}
}
#[cfg(feature = "remote")]
mod remote {
use crate::{launch, ElemType};
use burn::backend::{Autodiff, RemoteBackend};
pub fn run() {
launch::<Autodiff<RemoteBackend>>(vec![Default::default()]);
}
}
#[cfg(feature = "cuda-jit")]
mod cuda_jit {
use crate::{launch, ElemType};
@ -129,4 +139,6 @@ fn main() {
cuda_jit::run();
#[cfg(feature = "hip-jit")]
hip_jit::run();
#[cfg(feature = "remote")]
remote::run();
}

View File

@ -96,8 +96,6 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.metric_train_numeric(LearningRateMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.devices(devices)