diff --git a/crates/burn-compute/src/channel/base.rs b/crates/burn-compute/src/channel/base.rs index 14e1a8e2c..b11bdbb91 100644 --- a/crates/burn-compute/src/channel/base.rs +++ b/crates/burn-compute/src/channel/base.rs @@ -1,4 +1,7 @@ -use crate::server::{Binding, ComputeServer, Handle}; +use crate::{ + server::{Binding, ComputeServer, Handle}, + storage::ComputeStorage, +}; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -8,6 +11,12 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send /// Given a binding, returns owned resource as bytes fn read(&self, binding: Binding) -> Reader>; + /// Given a resource handle, return the storage resource. + fn get_resource( + &self, + binding: Binding, + ) -> ::Resource; + /// Given a resource as bytes, stores it and returns the resource handle fn create(&self, data: &[u8]) -> Handle; diff --git a/crates/burn-compute/src/channel/cell.rs b/crates/burn-compute/src/channel/cell.rs index 769f6bcc0..c41f895fd 100644 --- a/crates/burn-compute/src/channel/cell.rs +++ b/crates/burn-compute/src/channel/cell.rs @@ -1,5 +1,6 @@ use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; +use crate::storage::ComputeStorage; use alloc::sync::Arc; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -46,6 +47,13 @@ where self.server.borrow_mut().read(binding) } + fn get_resource( + &self, + binding: Binding, + ) -> ::Resource { + self.server.borrow_mut().get_resource(binding) + } + fn create(&self, resource: &[u8]) -> Handle { self.server.borrow_mut().create(resource) } diff --git a/crates/burn-compute/src/channel/mpsc.rs b/crates/burn-compute/src/channel/mpsc.rs index 689c2d578..6ca9e96a7 100644 --- a/crates/burn-compute/src/channel/mpsc.rs +++ b/crates/burn-compute/src/channel/mpsc.rs @@ -6,7 +6,10 @@ use std::{ use burn_common::reader::Reader; use super::ComputeChannel; -use crate::server::{Binding, ComputeServer, Handle}; +use crate::{ + server::{Binding, ComputeServer, Handle}, + storage::ComputeStorage, +}; /// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with /// the compute server spawn on its own thread. @@ -34,6 +37,10 @@ where Server: ComputeServer, { Read(Binding, Callback>>), + GetResource( + Binding, + Callback<::Resource>, + ), Create(Vec, Callback>), Empty(usize, Callback>), ExecuteKernel(Server::Kernel, Vec>), @@ -55,6 +62,10 @@ where let data = server.read(binding); callback.send(data).unwrap(); } + Message::GetResource(binding, callback) => { + let data = server.get_resource(binding); + callback.send(data).unwrap(); + } Message::Create(data, callback) => { let handle = server.create(&data); callback.send(handle).unwrap(); @@ -103,6 +114,20 @@ where self.response(response) } + fn get_resource( + &self, + binding: Binding, + ) -> ::Resource { + let (callback, response) = mpsc::channel(); + + self.state + .sender + .send(Message::GetResource(binding, callback)) + .unwrap(); + + self.response(response) + } + fn create(&self, data: &[u8]) -> Handle { let (callback, response) = mpsc::channel(); diff --git a/crates/burn-compute/src/channel/mutex.rs b/crates/burn-compute/src/channel/mutex.rs index 422539829..e6db60904 100644 --- a/crates/burn-compute/src/channel/mutex.rs +++ b/crates/burn-compute/src/channel/mutex.rs @@ -1,5 +1,6 @@ use super::ComputeChannel; use crate::server::{Binding, ComputeServer, Handle}; +use crate::storage::ComputeStorage; use alloc::sync::Arc; use alloc::vec::Vec; use burn_common::reader::Reader; @@ -39,6 +40,13 @@ where self.server.lock().read(handle) } + fn get_resource( + &self, + binding: Binding, + ) -> ::Resource { + self.server.lock().get_resource(binding) + } + fn create(&self, data: &[u8]) -> Handle { self.server.lock().create(data) } diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs index d3ae348f2..b8ea92434 100644 --- a/crates/burn-compute/src/client.rs +++ b/crates/burn-compute/src/client.rs @@ -1,6 +1,7 @@ use crate::{ channel::ComputeChannel, server::{Binding, ComputeServer, Handle}, + storage::ComputeStorage, tune::{AutotuneOperationSet, Tuner}, }; use alloc::vec::Vec; @@ -44,6 +45,14 @@ where self.channel.read(binding) } + /// Given a resource handle, returns the storage resource. + pub fn get_resource( + &self, + binding: Binding, + ) -> ::Resource { + self.channel.get_resource(binding) + } + /// Given a resource, stores it and returns the resource handle. pub fn create(&self, data: &[u8]) -> Handle { self.channel.create(data) diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs index 3ce2c738a..aa0360815 100644 --- a/crates/burn-compute/src/server.rs +++ b/crates/burn-compute/src/server.rs @@ -27,6 +27,12 @@ where /// Given a handle, returns the owned resource as bytes. fn read(&mut self, binding: Binding) -> Reader>; + /// Given a resource handle, returns the storage resource. + fn get_resource( + &mut self, + binding: Binding, + ) -> ::Resource; + /// Given a resource as bytes, stores it and returns the memory handle. fn create(&mut self, data: &[u8]) -> Handle; diff --git a/crates/burn-compute/tests/dummy/server.rs b/crates/burn-compute/tests/dummy/server.rs index 77b78c036..c49d2e067 100644 --- a/crates/burn-compute/tests/dummy/server.rs +++ b/crates/burn-compute/tests/dummy/server.rs @@ -2,10 +2,9 @@ use std::sync::Arc; use burn_common::reader::Reader; use burn_compute::{ - memory_management::simple::SimpleMemoryManagement, - memory_management::{MemoryHandle, MemoryManagement}, + memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, - storage::BytesStorage, + storage::{BytesResource, BytesStorage}, }; use derive_new::new; @@ -33,6 +32,10 @@ where Reader::Concrete(bytes.read().to_vec()) } + fn get_resource(&mut self, binding: Binding) -> BytesResource { + self.memory_management.get(binding.memory) + } + fn create(&mut self, data: &[u8]) -> Handle { let handle = self.memory_management.reserve(data.len()); let resource = self.memory_management.get(handle.clone().binding()); diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs index eb83a0d7a..c49cbb22e 100644 --- a/crates/burn-cuda/src/compute/server.rs +++ b/crates/burn-cuda/src/compute/server.rs @@ -114,6 +114,14 @@ impl> ComputeServer for CudaServer { let ctx = self.get_context(); ctx.sync(); } + + fn get_resource( + &mut self, + binding: server::Binding, + ) -> ::Resource { + let ctx = self.get_context(); + ctx.memory_management.get(binding.memory) + } } impl> CudaContext { diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 56873cfc4..7c52fa6e7 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -218,6 +218,13 @@ where Reader::Concrete(self.buffer_reader(binding).read(&self.device)) } + fn get_resource( + &mut self, + binding: server::Binding, + ) -> ::Resource { + self.memory_management.get(binding.memory) + } + /// When we create a new handle from existing data, we use custom allocations so that we don't /// have to execute the current pending tasks. ///