mirror of https://github.com/tracel-ai/burn.git
Get resources from server (#1861)
This commit is contained in:
parent
75e26d03c3
commit
4b174a88bd
|
@ -1,4 +1,7 @@
|
|||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::{
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
||||
|
@ -8,6 +11,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
|
|||
/// Given a binding, returns owned resource as bytes
|
||||
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>>;
|
||||
|
||||
/// Given a resource handle, return the storage resource.
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the resource handle
|
||||
fn create(&self, data: &[u8]) -> Handle<Server>;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
@ -46,6 +47,13 @@ where
|
|||
self.server.borrow_mut().read(binding)
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.server.borrow_mut().get_resource(binding)
|
||||
}
|
||||
|
||||
fn create(&self, resource: &[u8]) -> Handle<Server> {
|
||||
self.server.borrow_mut().create(resource)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,10 @@ use std::{
|
|||
use burn_common::reader::Reader;
|
||||
|
||||
use super::ComputeChannel;
|
||||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::{
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
};
|
||||
|
||||
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
|
||||
/// the compute server spawn on its own thread.
|
||||
|
@ -34,6 +37,10 @@ where
|
|||
Server: ComputeServer,
|
||||
{
|
||||
Read(Binding<Server>, Callback<Reader<Vec<u8>>>),
|
||||
GetResource(
|
||||
Binding<Server>,
|
||||
Callback<<Server::Storage as ComputeStorage>::Resource>,
|
||||
),
|
||||
Create(Vec<u8>, Callback<Handle<Server>>),
|
||||
Empty(usize, Callback<Handle<Server>>),
|
||||
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
|
||||
|
@ -55,6 +62,10 @@ where
|
|||
let data = server.read(binding);
|
||||
callback.send(data).unwrap();
|
||||
}
|
||||
Message::GetResource(binding, callback) => {
|
||||
let data = server.get_resource(binding);
|
||||
callback.send(data).unwrap();
|
||||
}
|
||||
Message::Create(data, callback) => {
|
||||
let handle = server.create(&data);
|
||||
callback.send(handle).unwrap();
|
||||
|
@ -103,6 +114,20 @@ where
|
|||
self.response(response)
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send(Message::GetResource(binding, callback))
|
||||
.unwrap();
|
||||
|
||||
self.response(response)
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
let (callback, response) = mpsc::channel();
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
@ -39,6 +40,13 @@ where
|
|||
self.server.lock().read(handle)
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.server.lock().get_resource(binding)
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.server.lock().create(data)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::{
|
||||
channel::ComputeChannel,
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
tune::{AutotuneOperationSet, Tuner},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
@ -44,6 +45,14 @@ where
|
|||
self.channel.read(binding)
|
||||
}
|
||||
|
||||
/// Given a resource handle, returns the storage resource.
|
||||
pub fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.channel.get_resource(binding)
|
||||
}
|
||||
|
||||
/// Given a resource, stores it and returns the resource handle.
|
||||
pub fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.channel.create(data)
|
||||
|
|
|
@ -27,6 +27,12 @@ where
|
|||
/// Given a handle, returns the owned resource as bytes.
|
||||
fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>>;
|
||||
|
||||
/// Given a resource handle, returns the storage resource.
|
||||
fn get_resource(
|
||||
&mut self,
|
||||
binding: Binding<Self>,
|
||||
) -> <Self::Storage as ComputeStorage>::Resource;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the memory handle.
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self>;
|
||||
|
||||
|
|
|
@ -2,10 +2,9 @@ use std::sync::Arc;
|
|||
|
||||
use burn_common::reader::Reader;
|
||||
use burn_compute::{
|
||||
memory_management::simple::SimpleMemoryManagement,
|
||||
memory_management::{MemoryHandle, MemoryManagement},
|
||||
memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::BytesStorage,
|
||||
storage::{BytesResource, BytesStorage},
|
||||
};
|
||||
use derive_new::new;
|
||||
|
||||
|
@ -33,6 +32,10 @@ where
|
|||
Reader::Concrete(bytes.read().to_vec())
|
||||
}
|
||||
|
||||
fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {
|
||||
self.memory_management.get(binding.memory)
|
||||
}
|
||||
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self> {
|
||||
let handle = self.memory_management.reserve(data.len());
|
||||
let resource = self.memory_management.get(handle.clone().binding());
|
||||
|
|
|
@ -114,6 +114,14 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
|||
let ctx = self.get_context();
|
||||
ctx.sync();
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&mut self,
|
||||
binding: server::Binding<Self>,
|
||||
) -> <Self::Storage as burn_compute::storage::ComputeStorage>::Resource {
|
||||
let ctx = self.get_context();
|
||||
ctx.memory_management.get(binding.memory)
|
||||
}
|
||||
}
|
||||
|
||||
impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
|
||||
|
|
|
@ -218,6 +218,13 @@ where
|
|||
Reader::Concrete(self.buffer_reader(binding).read(&self.device))
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&mut self,
|
||||
binding: server::Binding<Self>,
|
||||
) -> <Self::Storage as burn_compute::storage::ComputeStorage>::Resource {
|
||||
self.memory_management.get(binding.memory)
|
||||
}
|
||||
|
||||
/// When we create a new handle from existing data, we use custom allocations so that we don't
|
||||
/// have to execute the current pending tasks.
|
||||
///
|
||||
|
|
Loading…
Reference in New Issue