Get resources from server (#1861)

This commit is contained in:
Arthur Brussee 2024-06-06 22:33:57 +01:00 committed by GitHub
parent 75e26d03c3
commit 4b174a88bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 88 additions and 5 deletions

View File

@ -1,4 +1,7 @@
use crate::server::{Binding, ComputeServer, Handle};
use crate::{
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
@ -8,6 +11,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
/// Given a binding, returns owned resource as bytes
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>>;
/// Given a resource handle, return the storage resource.
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource;
/// Given a resource as bytes, stores it and returns the resource handle
fn create(&self, data: &[u8]) -> Handle<Server>;

View File

@ -1,5 +1,6 @@
use super::ComputeChannel;
use crate::server::{Binding, ComputeServer, Handle};
use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
@ -46,6 +47,13 @@ where
self.server.borrow_mut().read(binding)
}
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource {
self.server.borrow_mut().get_resource(binding)
}
fn create(&self, resource: &[u8]) -> Handle<Server> {
self.server.borrow_mut().create(resource)
}

View File

@ -6,7 +6,10 @@ use std::{
use burn_common::reader::Reader;
use super::ComputeChannel;
use crate::server::{Binding, ComputeServer, Handle};
use crate::{
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
};
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
/// the compute server spawn on its own thread.
@ -34,6 +37,10 @@ where
Server: ComputeServer,
{
Read(Binding<Server>, Callback<Reader<Vec<u8>>>),
GetResource(
Binding<Server>,
Callback<<Server::Storage as ComputeStorage>::Resource>,
),
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
@ -55,6 +62,10 @@ where
let data = server.read(binding);
callback.send(data).unwrap();
}
Message::GetResource(binding, callback) => {
let data = server.get_resource(binding);
callback.send(data).unwrap();
}
Message::Create(data, callback) => {
let handle = server.create(&data);
callback.send(handle).unwrap();
@ -103,6 +114,20 @@ where
self.response(response)
}
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource {
let (callback, response) = mpsc::channel();
self.state
.sender
.send(Message::GetResource(binding, callback))
.unwrap();
self.response(response)
}
fn create(&self, data: &[u8]) -> Handle<Server> {
let (callback, response) = mpsc::channel();

View File

@ -1,5 +1,6 @@
use super::ComputeChannel;
use crate::server::{Binding, ComputeServer, Handle};
use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
@ -39,6 +40,13 @@ where
self.server.lock().read(handle)
}
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource {
self.server.lock().get_resource(binding)
}
fn create(&self, data: &[u8]) -> Handle<Server> {
self.server.lock().create(data)
}

View File

@ -1,6 +1,7 @@
use crate::{
channel::ComputeChannel,
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
tune::{AutotuneOperationSet, Tuner},
};
use alloc::vec::Vec;
@ -44,6 +45,14 @@ where
self.channel.read(binding)
}
/// Given a resource handle, returns the storage resource.
pub fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource {
self.channel.get_resource(binding)
}
/// Given a resource, stores it and returns the resource handle.
pub fn create(&self, data: &[u8]) -> Handle<Server> {
self.channel.create(data)

View File

@ -27,6 +27,12 @@ where
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>>;
/// Given a resource handle, returns the storage resource.
fn get_resource(
&mut self,
binding: Binding<Self>,
) -> <Self::Storage as ComputeStorage>::Resource;
/// Given a resource as bytes, stores it and returns the memory handle.
fn create(&mut self, data: &[u8]) -> Handle<Self>;

View File

@ -2,10 +2,9 @@ use std::sync::Arc;
use burn_common::reader::Reader;
use burn_compute::{
memory_management::simple::SimpleMemoryManagement,
memory_management::{MemoryHandle, MemoryManagement},
memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
server::{Binding, ComputeServer, Handle},
storage::BytesStorage,
storage::{BytesResource, BytesStorage},
};
use derive_new::new;
@ -33,6 +32,10 @@ where
Reader::Concrete(bytes.read().to_vec())
}
fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {
self.memory_management.get(binding.memory)
}
fn create(&mut self, data: &[u8]) -> Handle<Self> {
let handle = self.memory_management.reserve(data.len());
let resource = self.memory_management.get(handle.clone().binding());

View File

@ -114,6 +114,14 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
let ctx = self.get_context();
ctx.sync();
}
fn get_resource(
&mut self,
binding: server::Binding<Self>,
) -> <Self::Storage as burn_compute::storage::ComputeStorage>::Resource {
let ctx = self.get_context();
ctx.memory_management.get(binding.memory)
}
}
impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {

View File

@ -218,6 +218,13 @@ where
Reader::Concrete(self.buffer_reader(binding).read(&self.device))
}
fn get_resource(
&mut self,
binding: server::Binding<Self>,
) -> <Self::Storage as burn_compute::storage::ComputeStorage>::Resource {
self.memory_management.get(binding.memory)
}
/// When we create a new handle from existing data, we use custom allocations so that we don't
/// have to execute the current pending tasks.
///