diff --git a/Cargo.lock b/Cargo.lock index d2cf1e5eb..30e47d900 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -190,14 +190,15 @@ dependencies = [ ] [[package]] -name = "async-trait" -version = "0.1.80" +name = "async-channel" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", ] [[package]] @@ -461,12 +462,12 @@ dependencies = [ name = "burn-common" version = "0.14.0" dependencies = [ - "async-trait", "dashmap", "data-encoding", "derive-new", "getrandom", "indicatif", + "pollster", "rand", "reqwest 0.12.5", "serde", @@ -479,12 +480,14 @@ dependencies = [ name = "burn-compute" version = "0.14.0" dependencies = [ + "async-channel", "burn-common", "derive-new", "dirs 5.0.1", "hashbrown 0.14.5", "log", "md5", + "pollster", "rand", "serde", "serde_json", @@ -761,6 +764,7 @@ dependencies = [ name = "burn-wgpu" version = "0.14.0" dependencies = [ + "async-channel", "burn-common", "burn-compute", "burn-cube", @@ -769,7 +773,6 @@ dependencies = [ "burn-tensor", "bytemuck", "derive-new", - "futures-intrusive", "hashbrown 0.14.5", "log", "pollster", @@ -1143,6 +1146,15 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.8" @@ -1733,6 +1745,27 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +[[package]] +name = "event-listener" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -1931,17 +1964,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-intrusive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot 0.12.2", -] - [[package]] name = "futures-io" version = "0.3.30" @@ -3274,7 +3296,6 @@ dependencies = [ "burn", "console_error_panic_hook", "js-sys", - "pollster", "serde", "wasm-bindgen", "wasm-bindgen-futures", @@ -3778,6 +3799,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.11.2" diff --git a/Cargo.toml b/Cargo.toml 
index 42031dfa8..7da1823e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ readme = "README.md" license = "MIT OR Apache-2.0" [workspace.dependencies] -async-trait = "0.1.80" bytemuck = "1.16.1" candle-core = { version = "0.5.1" } clap = { version = "4.5.8", features = ["derive"] } @@ -83,14 +82,16 @@ tracing-subscriber = "0.3.18" web-time = "1.1.0" zip = "2.1.3" +# Async handling +pollster = "0.3" +async-channel = "2.3" + # Terminal UI ratatui = "0.26.3" crossterm = "0.27.0" # WGPU stuff -futures-intrusive = "0.5.0" text_placeholder = "0.5.0" -pollster = "0.3.0" wgpu = "0.20.1" # Benchmarks and Burnbench diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index 7cb834e9d..6d3823c4d 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au use burn_tensor::{ backend::Backend, ops::{BoolTensor, BoolTensorOps, IntTensor}, - Device, Reader, Shape, TensorData, + Device, Shape, TensorData, }; impl BoolTensorOps for Autodiff { @@ -15,12 +15,8 @@ impl BoolTensorOps for Autodiff { B::bool_shape(tensor) } - fn bool_to_data(tensor: &BoolTensor) -> Reader { - B::bool_to_data(tensor) - } - - fn bool_into_data(tensor: BoolTensor) -> Reader { - B::bool_into_data(tensor) + async fn bool_into_data(tensor: BoolTensor) -> TensorData { + B::bool_into_data(tensor).await } fn bool_into_int(tensor: BoolTensor) -> IntTensor { @@ -121,14 +117,12 @@ impl BoolTensorOps for Autodiff { B::bool_flip(tensor, axes) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn bool_argwhere(tensor: BoolTensor) -> IntTensor { - B::bool_argwhere(tensor) + async fn bool_argwhere(tensor: BoolTensor) -> IntTensor { + B::bool_argwhere(tensor).await } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn bool_nonzero(tensor: BoolTensor) -> Vec> { - B::bool_nonzero(tensor) + async fn bool_nonzero(tensor: BoolTensor) -> Vec> { + B::bool_nonzero(tensor).await } fn bool_expand( diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 1243a12ce..9f187bcf9 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au use burn_tensor::{ backend::Backend, ops::{BoolTensor, IntTensor, IntTensorOps}, - Device, Distribution, Reader, Shape, TensorData, + Device, Distribution, Shape, TensorData, }; impl IntTensorOps for Autodiff { @@ -15,12 +15,8 @@ impl IntTensorOps for Autodiff { B::int_shape(tensor) } - fn int_to_data(tensor: &IntTensor) -> Reader { - B::int_to_data(tensor) - } - - fn int_into_data(tensor: IntTensor) -> Reader { - B::int_into_data(tensor) + async fn int_into_data(tensor: IntTensor) -> TensorData { + B::int_into_data(tensor).await } fn int_to_device( @@ -380,7 +376,6 @@ impl IntTensorOps for Autodiff { B::int_expand(tensor, shape) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort( tensor: IntTensor, dim: usize, @@ -389,7 +384,6 @@ impl IntTensorOps for Autodiff { B::int_sort(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort_with_indices( tensor: IntTensor, dim: usize, @@ -398,7 +392,6 @@ impl IntTensorOps for Autodiff { B::int_sort_with_indices(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", 
not(target_family = "wasm")))] fn int_argsort( tensor: IntTensor, dim: usize, diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index fa25bd37a..c8b7d8605 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -17,7 +17,7 @@ use crate::{ use burn_tensor::{ backend::Backend, ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, - Device, ElementConversion, Reader, Shape, Tensor, TensorData, + Device, ElementConversion, Shape, Tensor, TensorData, }; use super::maxmin::MaxMinDim; @@ -50,12 +50,8 @@ impl FloatTensorOps for Autodiff B::float_shape(&tensor.primitive) } - fn float_to_data(tensor: &FloatTensor) -> Reader { - B::float_to_data(&tensor.primitive) - } - - fn float_into_data(tensor: FloatTensor) -> Reader { - B::float_into_data(tensor.primitive) + async fn float_into_data(tensor: FloatTensor) -> TensorData { + B::float_into_data(tensor.primitive).await } fn float_device(tensor: &FloatTensor) -> Device { @@ -2364,7 +2360,6 @@ impl FloatTensorOps for Autodiff } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort( tensor: FloatTensor, dim: usize, @@ -2387,7 +2382,6 @@ impl FloatTensorOps for Autodiff } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort_with_indices( tensor: FloatTensor, dim: usize, @@ -2416,7 +2410,6 @@ impl FloatTensorOps for Autodiff } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_argsort( tensor: FloatTensor, dim: usize, diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index 16b56227b..45a849e1a 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use burn_tensor::{backend::Backend, Reader, Shape, TensorData}; +use burn_tensor::{backend::Backend, Shape, TensorData}; use crate::{ element::{CandleElement, FloatCandleElement, IntCandleElement}, diff --git a/crates/burn-candle/src/ops/bool_tensor.rs b/crates/burn-candle/src/ops/bool_tensor.rs index c7997e610..7f6c4847d 100644 --- a/crates/burn-candle/src/ops/bool_tensor.rs +++ b/crates/burn-candle/src/ops/bool_tensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor}, - Device, Reader, Shape, TensorData, + Device, Shape, TensorData, }; use crate::{ @@ -19,12 +19,10 @@ impl BoolTensorOps for Candle< super::base::shape(tensor) } - fn bool_into_data(tensor: BoolTensor) -> Reader { + async fn bool_into_data(tensor: BoolTensor) -> TensorData { let x: Vec = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(); let y = x.iter().map(|b| !matches!(b, 0)).collect(); - let data = TensorData::new(y, tensor.shape()); - - Reader::Concrete(data) + TensorData::new(y, tensor.shape()) } fn bool_from_data( diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 35adc2129..3502bea5c 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, - Bool, Device, Distribution, ElementConversion, Reader, Shape, TensorData, + Bool, Device, Distribution, ElementConversion, Shape, TensorData, }; use crate::{ @@ -19,8 +19,8 @@ impl IntTensorOps for Candle(tensor: IntTensor) -> Reader { - Reader::Concrete(super::base::into_data(tensor)) + async fn int_into_data(tensor: IntTensor) -> TensorData { + 
super::base::into_data(tensor) } fn int_from_data( diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index e021c2e47..b50dfc193 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -2,7 +2,7 @@ use std::borrow::Borrow; use burn_tensor::{ ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor}, - Device, Distribution, ElementConversion, Reader, Shape, TensorData, + Device, Distribution, ElementConversion, Shape, TensorData, }; use candle_core::{backend::BackendStorage, shape, Tensor}; @@ -59,8 +59,8 @@ impl FloatTensorOps for Candle super::base::shape(tensor) } - fn float_into_data(tensor: CandleTensor) -> Reader { - Reader::Concrete(super::base::into_data(tensor)) + async fn float_into_data(tensor: CandleTensor) -> TensorData { + super::base::into_data(tensor) } fn float_device(tensor: &CandleTensor) -> Device { diff --git a/crates/burn-common/Cargo.toml b/crates/burn-common/Cargo.toml index ec35ffbf8..402b0215e 100644 --- a/crates/burn-common/Cargo.toml +++ b/crates/burn-common/Cargo.toml @@ -12,25 +12,23 @@ version.workspace = true [features] default = ["std"] -std = ["rand/std", "data-encoding/std"] +std = ["rand/std", "data-encoding/std", "dep:pollster"] doc = ["default"] -wasm-sync = [] network = ["dep:indicatif", "dep:reqwest", "dep:tokio"] [target.'cfg(target_family = "wasm")'.dependencies] -async-trait = { workspace = true } getrandom = { workspace = true, features = ["js"] } web-time = { version = "1.1.0" } [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** - rand = { workspace = true } -spin = { workspace = true } # using in place of use std::sync::Mutex; +spin = { workspace = true } # using in place of use std::sync::Mutex; derive-new = { workspace = true } serde = { workspace = true } data-encoding = { workspace = true } +pollster = { workspace = true, optional = true } # Network downloader indicatif = { workspace = true, optional = true } diff --git a/crates/burn-common/src/reader.rs b/crates/burn-common/src/reader.rs index 408b44c11..d459a6e9c 100644 --- a/crates/burn-common/src/reader.rs +++ b/crates/burn-common/src/reader.rs @@ -1,115 +1,54 @@ -use alloc::boxed::Box; -use core::marker::PhantomData; +use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec}; +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, +}; -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -#[async_trait::async_trait] -/// Allows to create async reader. -pub trait AsyncReader: Send { - /// Read asynchronously. - async fn read(self: Box) -> T; +/// A future that is used to read resources from a compute server. +pub type Reader = Pin> + Send>>; + +/// Create a reader from a concrete value. +pub fn reader_from_concrete(val: Vec) -> Reader { + Box::pin(async move { val }) } -/// Define how data is read, sync or async. -pub enum Reader { - /// Concrete variant. - Concrete(T), - /// Sync data variant. - Sync(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Async data variant. - Async(Box>), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Future data variant. - Future(core::pin::Pin + Send>>), +struct DummyWaker; + +impl Wake for DummyWaker { + fn wake(self: Arc) {} + fn wake_by_ref(self: &Arc) {} } -/// Allows to create sync reader. -pub trait SyncReader: Send { - /// Read synchronously. - fn read(self: Box) -> T; +/// Read a future synchronously. 
+/// +/// On WASM futures cannot block, so this only succeeds if the future returns immediately. +/// If you want to handle this case, please use +/// try_read_sync instead. +pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T { + try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.") } -#[derive(new)] -struct MappedReader { - reader: Reader, - mapper: F, - _output: PhantomData, -} +/// Read a future synchronously. +/// +/// On WASM futures cannot block, so this only succeeds if the future returns immediately; +/// otherwise this returns None. +pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> { + // Create a dummy context. + let waker = Waker::from(Arc::new(DummyWaker)); + let mut context = Context::from_waker(&waker); -impl SyncReader for MappedReader -where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, -{ - fn read(self: Box) -> O { - let input = self - .reader - .read_sync() - .expect("Only sync data supported in a sync reader."); + // Pin & poll the future. Some backends don't do async readbacks, and instead immediately get + // the data. This lets us detect when a future is synchronous and doesn't require any waiting. + let mut pinned = core::pin::pin!(f); - (self.mapper)(input) - } -} - -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -#[async_trait::async_trait] -impl AsyncReader for MappedReader -where - I: Send, - O: Send, - F: Send + FnOnce(I) -> O, -{ - async fn read(self: Box) -> O { - let input = self.reader.read().await; - (self.mapper)(input) - } -} - -impl Reader { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Read the data. - pub async fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), - Self::Async(func) => func.read().await, - Self::Future(future) => future.await, - } - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Read the data. - pub fn read(self) -> T { - match self { - Self::Concrete(data) => data, - Self::Sync(reader) => reader.read(), - } - } - - /// Read the data only if sync, returns None if an async reader. - pub fn read_sync(self) -> Option { - match self { - Self::Concrete(data) => Some(data), - Self::Sync(reader) => Some(reader.read()), - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Async(_func) => return None, - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - Self::Future(_future) => return None, - } - } - - /// Map the current reader to another type. - pub fn map(self, mapper: F) -> Reader - where - T: 'static + Send, - O: 'static + Send, - F: FnOnce(T) -> O + 'static + Send, - { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - return Reader::Async(Box::new(MappedReader::new(self, mapper))); - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - Reader::Sync(Box::new(MappedReader::new(self, mapper))) + match pinned.as_mut().poll(&mut context) { + Poll::Ready(output) => Some(output), + // On platforms that support it, block on the future and drive it to completion. + #[cfg(all(not(target_family = "wasm"), feature = "std"))] + Poll::Pending => Some(pollster::block_on(pinned)), + // Otherwise, bail and return None - this future will have to be read back asynchronously. 
+ #[cfg(any(target_family = "wasm", not(feature = "std")))] + Poll::Pending => None, } } diff --git a/crates/burn-compute/Cargo.toml b/crates/burn-compute/Cargo.toml index b5f9309e4..30a0a602b 100644 --- a/crates/burn-compute/Cargo.toml +++ b/crates/burn-compute/Cargo.toml @@ -22,7 +22,7 @@ default = [ std = ["burn-common/std"] channel-mutex = [] channel-cell = [] -channel-mpsc = [] # Assume std +channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std storage-bytes = [] autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std @@ -36,6 +36,8 @@ dirs = { workspace = true, optional = true } serde = { workspace = true, optional = true } serde_json = { workspace = true, features = ["std"], optional = true } md5 = { workspace = true, optional = true } +pollster = { workspace = true, optional = true } +async-channel = { workspace = true, optional = true } [target.'cfg(target_family = "wasm")'.dependencies] web-time = { workspace = true } diff --git a/crates/burn-compute/src/channel/base.rs b/crates/burn-compute/src/channel/base.rs index d74a23f9b..0a6910337 100644 --- a/crates/burn-compute/src/channel/base.rs +++ b/crates/burn-compute/src/channel/base.rs @@ -9,7 +9,7 @@ use burn_common::{reader::Reader, sync_type::SyncType}; /// while ensuring thread-safety pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { /// Given a binding, returns owned resource as bytes - fn read(&self, binding: Binding) -> Reader>; + fn read(&self, binding: Binding) -> Reader; /// Given a resource handle, return the storage resource. fn get_resource( diff --git a/crates/burn-compute/src/channel/cell.rs b/crates/burn-compute/src/channel/cell.rs index 4d809694d..631947966 100644 --- a/crates/burn-compute/src/channel/cell.rs +++ b/crates/burn-compute/src/channel/cell.rs @@ -42,9 +42,9 @@ where impl ComputeChannel for RefCellComputeChannel where - Server: ComputeServer, + Server: ComputeServer + Send, { - fn read(&self, binding: Binding) -> Reader> { + fn read(&self, binding: Binding) -> Reader { self.server.borrow_mut().read(binding) } diff --git a/crates/burn-compute/src/channel/mpsc.rs b/crates/burn-compute/src/channel/mpsc.rs index b0af33e01..8eeb04326 100644 --- a/crates/burn-compute/src/channel/mpsc.rs +++ b/crates/burn-compute/src/channel/mpsc.rs @@ -1,9 +1,5 @@ -use std::{ - sync::{mpsc, Arc}, - thread, -}; - use burn_common::{reader::Reader, sync_type::SyncType}; +use std::{sync::Arc, thread}; use super::ComputeChannel; use crate::{ @@ -11,7 +7,7 @@ use crate::{ storage::ComputeStorage, }; -/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with +/// Create a channel using a multi-producer, single-consumer channel to communicate with /// the compute server spawned on its own thread. #[derive(Debug)] pub struct MpscComputeChannel where Server: ComputeServer, { @@ -27,16 +23,16 @@ where Server: ComputeServer, { _handle: thread::JoinHandle<()>, - sender: mpsc::Sender>, + sender: async_channel::Sender>, } -type Callback = mpsc::Sender; +type Callback = async_channel::Sender; enum Message where Server: ComputeServer, { - Read(Binding, Callback>>), + Read(Binding, Callback>), GetResource( Binding, Callback<::Resource>, @@ -53,36 +49,40 @@ where { /// Create a new mpsc compute channel.
pub fn new(mut server: Server) -> Self { - let (sender, receiver) = mpsc::channel(); + let (sender, receiver) = async_channel::unbounded(); let _handle = thread::spawn(move || { - while let Ok(message) = receiver.recv() { - match message { - Message::Read(binding, callback) => { - let data = server.read(binding); - callback.send(data).unwrap(); - } - Message::GetResource(binding, callback) => { - let data = server.get_resource(binding); - callback.send(data).unwrap(); - } - Message::Create(data, callback) => { - let handle = server.create(&data); - callback.send(handle).unwrap(); - } - Message::Empty(size, callback) => { - let handle = server.empty(size); - callback.send(handle).unwrap(); - } - Message::ExecuteKernel(kernel, bindings) => { - server.execute(kernel, bindings); - } - Message::Sync(sync_type, callback) => { - server.sync(sync_type); - callback.send(()).unwrap(); - } - }; - } + // Run the whole procedure as one blocking future. This is much simpler than trying + // to use some multithreaded executor. + pollster::block_on(async { + while let Ok(message) = receiver.recv().await { + match message { + Message::Read(binding, callback) => { + let data = server.read(binding).await; + callback.send(data).await.unwrap(); + } + Message::GetResource(binding, callback) => { + let data = server.get_resource(binding); + callback.send(data).await.unwrap(); + } + Message::Create(data, callback) => { + let handle = server.create(&data); + callback.send(handle).await.unwrap(); + } + Message::Empty(size, callback) => { + let handle = server.empty(size); + callback.send(handle).await.unwrap(); + } + Message::ExecuteKernel(kernel, bindings) => { + server.execute(kernel, bindings); + } + Message::Sync(sync_type, callback) => { + server.sync(sync_type); + callback.send(()).await.unwrap(); + } + }; + } + }); }); let state = Arc::new(MpscComputeChannelState { sender, _handle }); @@ -103,75 +103,71 @@ impl ComputeChannel for MpscComputeChannel where Server: ComputeServer + 'static, { - fn read(&self, binding: Binding) -> Reader> { - let (callback, response) = mpsc::channel(); + fn read(&self, binding: Binding) -> Reader { + let sender = self.state.sender.clone(); - self.state - .sender - .send(Message::Read(binding, callback)) - .unwrap(); - - self.response(response) + Box::pin(async move { + let (callback, response) = async_channel::unbounded(); + sender.send(Message::Read(binding, callback)).await.unwrap(); + handle_response(response.recv().await) + }) } fn get_resource( &self, binding: Binding, ) -> ::Resource { - let (callback, response) = mpsc::channel(); + let (callback, response) = async_channel::unbounded(); self.state .sender - .send(Message::GetResource(binding, callback)) + .send_blocking(Message::GetResource(binding, callback)) .unwrap(); - self.response(response) + handle_response(response.recv_blocking()) } fn create(&self, data: &[u8]) -> Handle { - let (callback, response) = mpsc::channel(); + let (callback, response) = async_channel::unbounded(); self.state .sender - .send(Message::Create(data.to_vec(), callback)) + .send_blocking(Message::Create(data.to_vec(), callback)) .unwrap(); - self.response(response) + handle_response(response.recv_blocking()) } fn empty(&self, size: usize) -> Handle { - let (callback, response) = mpsc::channel(); - + let (callback, response) = async_channel::unbounded(); self.state .sender - .send(Message::Empty(size, callback)) + .send_blocking(Message::Empty(size, callback)) .unwrap(); - self.response(response) + handle_response(response.recv_blocking()) } fn 
execute(&self, kernel: Server::Kernel, bindings: Vec>) { self.state .sender - .send(Message::ExecuteKernel(kernel, bindings)) + .send_blocking(Message::ExecuteKernel(kernel, bindings)) .unwrap() } fn sync(&self, sync_type: SyncType) { - let (callback, response) = mpsc::channel(); + let (callback, response) = async_channel::unbounded(); self.state .sender - .send(Message::Sync(sync_type, callback)) + .send_blocking(Message::Sync(sync_type, callback)) .unwrap(); - self.response(response) + handle_response(response.recv_blocking()) } } -impl MpscComputeChannel { - fn response(&self, response: mpsc::Receiver) -> Response { - match response.recv() { - Ok(val) => val, - Err(err) => panic!("Can't connect to the server correctly {err:?}"), - } +fn handle_response(response: Result) -> Response { + match response { + Ok(val) => val, + Err(err) => panic!("Can't connect to the server correctly {err:?}"), } } diff --git a/crates/burn-compute/src/channel/mutex.rs b/crates/burn-compute/src/channel/mutex.rs index 141cfca7b..a063ab1f1 100644 --- a/crates/burn-compute/src/channel/mutex.rs +++ b/crates/burn-compute/src/channel/mutex.rs @@ -37,7 +37,7 @@ impl ComputeChannel for MutexComputeChannel where Server: ComputeServer, { - fn read(&self, handle: Binding) -> Reader> { + fn read(&self, handle: Binding) -> Reader { self.server.lock().read(handle) } diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs index c946682b9..c5ea1d7f4 100644 --- a/crates/burn-compute/src/client.rs +++ b/crates/burn-compute/src/client.rs @@ -7,7 +7,7 @@ use crate::{ use alloc::vec::Vec; use alloc::{boxed::Box, sync::Arc}; use burn_common::stub::RwLock; -use burn_common::{reader::Reader, sync_type::SyncType}; +use burn_common::sync_type::SyncType; /// The ComputeClient is the entry point to require tasks from the ComputeServer. /// It should be obtained for a specific device via the Compute struct. @@ -41,8 +41,16 @@ where } /// Given a binding, returns owned resource as bytes. - pub fn read(&self, binding: Binding) -> Reader> { - self.channel.read(binding) + pub async fn read_async(&self, binding: Binding) -> Vec { + self.channel.read(binding).await + } + + /// Given a binding, returns owned resource as bytes. + /// + /// # Remarks + /// Panics if the read operation fails. + pub fn read(&self, binding: Binding) -> Vec { + burn_common::reader::read_sync(self.channel.read(binding)) } /// Given a resource handle, returns the storage resource. diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs index cb495b34a..41e0b56cf 100644 --- a/crates/burn-compute/src/server.rs +++ b/crates/burn-compute/src/server.rs @@ -25,7 +25,7 @@ where type AutotuneKey: AutotuneKey; /// Given a handle, returns the owned resource as bytes. - fn read(&mut self, binding: Binding) -> Reader>; + fn read(&mut self, binding: Binding) -> Reader; /// Given a resource handle, returns the storage resource. 
fn get_resource( diff --git a/crates/burn-compute/tests/dummy/server.rs b/crates/burn-compute/tests/dummy/server.rs index 77c83cc5e..999590a2f 100644 --- a/crates/burn-compute/tests/dummy/server.rs +++ b/crates/burn-compute/tests/dummy/server.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use burn_common::{reader::Reader, sync_type::SyncType}; +use burn_common::{reader::reader_from_concrete, sync_type::SyncType}; use burn_compute::{ memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, server::{Binding, ComputeServer, Handle}, @@ -26,10 +26,9 @@ where type MemoryManagement = MM; type AutotuneKey = String; - fn read(&mut self, binding: Binding) -> Reader> { + fn read(&mut self, binding: Binding) -> burn_common::reader::Reader { let bytes = self.memory_management.get(binding.memory); - - Reader::Concrete(bytes.read().to_vec()) + reader_from_concrete(bytes.read().to_vec()) } fn get_resource(&mut self, binding: Binding) -> BytesResource { diff --git a/crates/burn-compute/tests/integration_test.rs b/crates/burn-compute/tests/integration_test.rs index 14539ef80..b134090b8 100644 --- a/crates/burn-compute/tests/integration_test.rs +++ b/crates/burn-compute/tests/integration_test.rs @@ -16,7 +16,7 @@ fn created_resource_is_the_same_when_read() { let obtained_resource = client.read(resource_description.binding()); - assert_eq!(resource, obtained_resource.read()) + assert_eq!(resource, obtained_resource) } #[test] @@ -26,7 +26,7 @@ fn empty_allocates_memory() { let resource_description = client.empty(size); let empty_resource = client.read(resource_description.binding()); - assert_eq!(empty_resource.read().len(), 4); + assert_eq!(empty_resource.len(), 4); } #[test] @@ -43,7 +43,7 @@ fn execute_elementwise_addition() { let obtained_resource = client.read(out.binding()); - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])) + assert_eq!(obtained_resource, Vec::from([4, 5, 6])) } #[test] @@ -65,7 +65,7 @@ fn autotune_basic_addition_execution() { let obtained_resource = client.read(out.binding()); // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6])); + assert_eq!(obtained_resource, Vec::from([4, 5, 6])); } #[test] @@ -87,7 +87,7 @@ fn autotune_basic_multiplication_execution() { let obtained_resource = client.read(out.binding()); // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8])); + assert_eq!(obtained_resource, Vec::from([0, 4, 8])); } #[test] @@ -126,7 +126,7 @@ fn autotune_cache_same_key_return_a_cache_hit() { let obtained_resource = client.read(out_2.binding()); // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs - assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3])); + assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3])); } #[test] @@ -167,7 +167,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() { let obtained_resource = client.read(out_2.binding()); // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); + assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9])); } #[test] @@ -175,6 +175,8 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() { #[cfg(feature = "std")] fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() { // delete the cache file + + use burn_common::sync_type::SyncType; let file_path = 
burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX); let parent_dir = file_path .parent() @@ -198,7 +200,7 @@ fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() { dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles); client.autotune_execute(Box::new(cache_test_autotune_kernel)); // ensure that the autotune operations are run and cached - let _obtained_resource = client.read(out.binding()); + client.sync(SyncType::Wait); assert!( parent_dir.exists(), @@ -237,7 +239,7 @@ fn autotune_cache_different_keys_return_a_cache_miss() { let obtained_resource = client.read(out_2.binding()); // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9])); + assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9])); } #[test] @@ -285,5 +287,5 @@ fn autotune_cache_different_checksums_return_a_cache_miss() { // Cache should be missed because the checksum on 4 is generated randomly // and thus is always different, // so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs - assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8])); + assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8])); } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e4aad8f8f..a33fab299 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -65,8 +65,6 @@ sqlite = ["burn-dataset?/sqlite"] sqlite-bundled = ["burn-dataset?/sqlite-bundled"] vision = ["burn-dataset?/vision", "burn-common/network"] -wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"] - # Backend autodiff = ["burn-autodiff"] fusion = ["burn-wgpu?/fusion"] diff --git a/crates/burn-core/src/grad_clipping/base.rs b/crates/burn-core/src/grad_clipping/base.rs index ecb5f2da8..b60bd1a7e 100644 --- a/crates/burn-core/src/grad_clipping/base.rs +++ b/crates/burn-core/src/grad_clipping/base.rs @@ -69,16 +69,6 @@ impl GradientClipping { clipped_grad.mask_fill(lower_mask, -threshold) } - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - fn clip_by_norm( - &self, - _grad: Tensor, - _threshold: f32, - ) -> Tensor { - todo!("Not yet supported on wasm"); - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn clip_by_norm( &self, grad: Tensor, @@ -97,11 +87,9 @@ impl GradientClipping { } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn l2_norm(tensor: Tensor) -> Tensor { let squared = tensor.powf_scalar(2.0); let sum = squared.sum(); - sum.sqrt() } } diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs index b783b71d6..97bdf8796 100644 --- a/crates/burn-core/src/record/tensor.rs +++ b/crates/burn-core/src/record/tensor.rs @@ -135,10 +135,6 @@ impl Record for Tensor { type Item = FloatTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording float tensors isn't yet supported on wasm."); - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] FloatTensorSerde::new(self.into_data().convert::()) } @@ -152,10 +148,6 @@ impl Record for Tensor { type Item = IntTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording int tensors isn't yet supported on wasm."); - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] IntTensorSerde::new(self.into_data().convert::()) } @@ -169,10 +161,6 @@ 
impl Record for Tensor { type Item = BoolTensorSerde; fn into_item(self) -> Self::Item { - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - todo!("Recording bool tensors isn't yet supported on wasm."); - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] BoolTensorSerde::new(self.into_data()) } diff --git a/crates/burn-cube/src/runtime_tests/launch.rs b/crates/burn-cube/src/runtime_tests/launch.rs index 6f3a646b7..f69fb820b 100644 --- a/crates/burn-cube/src/runtime_tests/launch.rs +++ b/crates/burn-cube/src/runtime_tests/launch.rs @@ -25,7 +25,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(client: ComputeClient( TensorHandle::new(&handle, &strides, &shape), ); - let actual = client.read(handle.binding()).read_sync().unwrap(); + let actual = client.read(handle.binding()); let actual = f32::from_bytes(&actual); assert_eq!(actual, expected); diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs index f88b2e3ce..8864ff457 100644 --- a/crates/burn-cuda/src/compute/server.rs +++ b/crates/burn-cuda/src/compute/server.rs @@ -8,6 +8,8 @@ use burn_cube::ir::CubeDim; use burn_cube::prelude::*; use burn_jit::JitAutotuneKey; use burn_tensor::backend::SyncType; +use burn_tensor::reader_from_concrete; +use burn_tensor::Reader; use cudarc::driver::sys::CUctx_st; use cudarc::driver::sys::CUfunc_st; use std::collections::HashMap; @@ -58,18 +60,17 @@ impl> ComputeServer for CudaServer { type MemoryManagement = MM; type AutotuneKey = JitAutotuneKey; - fn read(&mut self, binding: server::Binding) -> burn_tensor::Reader> { + fn read(&mut self, binding: server::Binding) -> Reader { let ctx = self.get_context(); let resource = ctx.memory_management.get(binding.memory); + // TODO: Check if it is possible to make this faster let mut data = vec![0; resource.size() as usize]; unsafe { cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap(); }; - ctx.sync(); - - burn_tensor::Reader::Concrete(data) + reader_from_concrete(data) } fn create(&mut self, data: &[u8]) -> server::Handle { diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index dc5b33a1d..a98b958dc 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -1,10 +1,12 @@ +use std::future::Future; + use crate::{ stream::{execution::Operation, StreamId}, FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor, }; use burn_tensor::{ repr::{OperationDescription, TensorDescription, TensorId}, - DType, Reader, TensorData, + DType, TensorData, }; /// Define how to interact with the fusion server. @@ -37,7 +39,7 @@ where &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader + ) -> impl Future + Send where B: FusionBackend; /// Read the values contained by an int tensor. @@ -45,7 +47,7 @@ where &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader + ) -> impl Future + Send where B: FusionBackend; /// Read the values contained by a bool tensor. @@ -53,7 +55,7 @@ where &self, tensor: TensorDescription, stream: StreamId, - ) -> Reader + ) -> impl Future + Send where B: FusionBackend; /// Change the client of the given float tensor. 
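The sync entry points above (`ComputeClient::read`, `read_sync`) all funnel through `try_read_sync` in crates/burn-common/src/reader.rs: the readback future is polled once with a no-op waker, and only if it is still pending does the code fall back to `pollster::block_on` (never on wasm). A self-contained sketch of that single-poll pattern; the names `NoopWaker` and `poll_once` are illustrative, not part of this patch:

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing: the future is polled exactly once here,
// so wake-ups never need to be delivered.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Poll a future once; returns Some(value) only if it is already complete.
fn poll_once<F: Future>(fut: F) -> Option<F::Output> {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(value) => Some(value),
        Poll::Pending => None,
    }
}

fn main() {
    // An async block with no await points resolves on the first poll,
    // which is exactly how a backend with synchronous readback behaves.
    assert_eq!(poll_once(async { 1 + 2 }), Some(3));
}
```

Backends that already hold the data in host memory (ndarray, candle, tch) return futures with no await points, so the first poll is always `Ready` and the blocking fallback is never reached.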
diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 481ac05fb..89b470cce 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -78,37 +78,37 @@ where FusionTensor::new(id, shape, dtype, self.clone(), stream) } - fn read_tensor_float( + async fn read_tensor_float( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { - self.server.lock().read_float::(tensor, stream) + self.server.lock().read_float::(tensor, stream).await } - fn read_tensor_int( + async fn read_tensor_int( &self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { - self.server.lock().read_int::(tensor, id) + self.server.lock().read_int::(tensor, id).await } - fn read_tensor_bool( + async fn read_tensor_bool( &self, tensor: TensorDescription, stream: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { - self.server.lock().read_bool::(tensor, stream) + self.server.lock().read_bool::(tensor, stream).await } fn change_client_float( diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index da7e4df19..c97899487 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,4 +1,4 @@ -use burn_tensor::{DType, Element}; +use burn_tensor::{DType, Element, TensorData}; use std::marker::PhantomData; use crate::{ @@ -37,10 +37,8 @@ impl BoolTensorOps for Fusion { tensor.shape() } - fn bool_into_data( - tensor: BoolTensor, - ) -> burn_tensor::Reader { - tensor.bool_into_data::() + async fn bool_into_data(tensor: BoolTensor) -> TensorData { + tensor.bool_into_data::().await } fn bool_from_data( diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index ea28bd1a0..89bc36c18 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -10,7 +10,7 @@ use crate::{ use burn_tensor::{ ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, repr::*, - DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData, + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, }; use std::{marker::PhantomData, ops::Range}; @@ -169,8 +169,8 @@ impl FloatTensorOps for Fusion { tensor.shape() } - fn float_into_data(tensor: FloatTensor) -> Reader { - tensor.into_data::() + async fn float_into_data(tensor: FloatTensor) -> TensorData { + tensor.into_data::().await } fn float_device(tensor: &FloatTensor) -> Device { diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 42c1ba8c9..643dfbe9e 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -10,7 +10,7 @@ use crate::{ use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, repr::{self, *}, - DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData, + DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, }; use core::ops::Range; use std::marker::PhantomData; @@ -33,8 +33,8 @@ impl IntTensorOps for Fusion { tensor.shape() } - fn int_into_data(tensor: IntTensor) -> Reader { - tensor.int_into_data::() + async fn int_into_data(tensor: IntTensor) -> TensorData { + tensor.int_into_data::().await } fn int_from_data( diff --git a/crates/burn-fusion/src/server.rs 
b/crates/burn-fusion/src/server.rs index acaf6891d..32990cb32 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -39,11 +39,11 @@ where self.handles.create_tensor_uninit() } - pub fn read_float( + pub async fn read_float( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { @@ -52,14 +52,14 @@ where self.drain_stream(id); let tensor = self.handles.get_float_tensor::(&tensor); - B::float_into_data(tensor) + B::float_into_data(tensor).await } - pub fn read_int( + pub async fn read_int( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { @@ -68,14 +68,14 @@ where self.drain_stream(id); let tensor = self.handles.get_int_tensor::(&tensor); - B::int_into_data(tensor) + B::int_into_data(tensor).await } - pub fn read_bool( + pub async fn read_bool( &mut self, tensor: TensorDescription, id: StreamId, - ) -> burn_tensor::Reader + ) -> burn_tensor::TensorData where B: FusionBackend, { @@ -84,7 +84,7 @@ where self.drain_stream(id); let tensor = self.handles.get_bool_tensor::(&tensor); - B::bool_into_data(tensor) + B::bool_into_data(tensor).await } pub fn change_server_float( diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 9e9d24621..54630a991 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -1,7 +1,7 @@ use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime}; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, - DType, Reader, Shape, TensorData, + DType, Shape, TensorData, }; use std::sync::Arc; @@ -108,7 +108,7 @@ impl FusionTensor { } } - pub(crate) fn into_data(self) -> Reader + pub(crate) async fn into_data(self) -> TensorData where B: FusionBackend, { @@ -116,9 +116,10 @@ impl FusionTensor { self.client .clone() .read_tensor_float::(self.into_description(), id) + .await } - pub(crate) fn int_into_data(self) -> Reader + pub(crate) async fn int_into_data(self) -> TensorData where B: FusionBackend, { @@ -126,9 +127,10 @@ impl FusionTensor { self.client .clone() .read_tensor_int::(self.into_description(), id) + .await } - pub(crate) fn bool_into_data(self) -> Reader + pub(crate) async fn bool_into_data(self) -> TensorData where B: FusionBackend, { @@ -136,6 +138,7 @@ impl FusionTensor { self.client .clone() .read_tensor_bool::(self.into_description(), id) + .await } } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index f88bd326f..5cc0cc1cf 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -1,6 +1,6 @@ use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime}; use burn_cube::CubeElement; -use burn_tensor::{Reader, Shape, TensorData}; +use burn_tensor::{Shape, TensorData}; use std::marker::PhantomData; pub(crate) fn from_data( @@ -14,28 +14,24 @@ pub(crate) fn from_data( JitTensor::new(client, device.clone(), shape, buffer) } -pub(crate) fn into_data( +pub(crate) async fn into_data( tensor: JitTensor, -) -> Reader { +) -> TensorData { let tensor = kernel::into_contiguous(tensor); - tensor - .client - .read(tensor.handle.binding()) - .map(|bytes| TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)) + let bytes = tensor.client.read_async(tensor.handle.binding()).await; + TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) } -pub(crate) fn bool_into_data( +pub(crate) async fn 
bool_into_data( tensor: JitTensor, -) -> Reader { +) -> TensorData { let tensor = kernel::into_contiguous(tensor); - - tensor.client.read(tensor.handle.binding()).map(|bytes| { - TensorData::new( - u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), - tensor.shape, - ) - }) + let bytes = tensor.client.read_async(tensor.handle.binding()).await; + TensorData::new( + u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(), + tensor.shape, + ) } pub(crate) fn to_device( diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index aca338dbe..9f0082ee1 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -1,6 +1,5 @@ use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; -use burn_tensor::Reader; use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; @@ -20,8 +19,8 @@ where tensor.shape.clone() } - fn bool_into_data(tensor: BoolTensor) -> Reader { - super::bool_into_data(tensor) + async fn bool_into_data(tensor: BoolTensor) -> TensorData { + super::bool_into_data(tensor).await } fn bool_from_data( diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 383c3c658..ff10c97fb 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -7,8 +7,8 @@ use crate::{FloatElement, IntElement, JitRuntime}; use burn_cube::ir::{BinaryOperator, Elem, Operator, Scope, UnaryOperator, Variable}; use burn_cube::Runtime; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; +use burn_tensor::ElementConversion; use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData}; -use burn_tensor::{ElementConversion, Reader}; use std::ops::Range; impl FloatTensorOps for JitBackend @@ -45,8 +45,8 @@ where tensor.shape.clone() } - fn float_into_data(tensor: FloatTensor) -> Reader { - super::into_data(tensor) + async fn float_into_data(tensor: FloatTensor) -> TensorData { + super::into_data(tensor).await } fn float_device(tensor: &FloatTensor) -> Device { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 7b61faa19..efeafe3d7 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -4,7 +4,7 @@ use crate::{kernel, unary, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_cube::ir::{Elem, Item, Operator, Scope, UnaryOperator, Variable}; use burn_cube::Runtime; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; -use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Reader, Shape, TensorData}; +use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData}; use std::ops::Range; impl IntTensorOps for JitBackend @@ -21,8 +21,8 @@ where tensor.shape.clone() } - fn int_into_data(tensor: IntTensor) -> Reader { - super::into_data(tensor) + async fn int_into_data(tensor: IntTensor) -> TensorData { + super::into_data(tensor).await } fn int_from_data( diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 24bf34d49..9a4b4f0a8 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -101,11 +101,10 @@ where client: ComputeClient, device: R::Device, ) -> Self { - let bytes = self - .client - .read(self.handle.clone().binding()) - .read_sync() - .expect("Can only change client synchronously"); + let bytes = 
burn_common::reader::try_read_sync( + self.client.read_async(self.handle.clone().binding()), + ) + .expect("Can only change client synchronously"); let handle = client.create(&bytes); Self { diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs b/crates/burn-ndarray/src/ops/bool_tensor.rs index d07949563..432327b49 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -2,7 +2,7 @@ use alloc::vec; use alloc::vec::Vec; use burn_tensor::ops::{BoolTensorOps, IntTensorOps}; -use burn_tensor::{ElementConversion, Reader}; +use burn_tensor::ElementConversion; use core::ops::Range; use ndarray::IntoDimension; @@ -30,13 +30,12 @@ impl BoolTensorOps for NdArray { tensor.shape() } - fn bool_into_data( + async fn bool_into_data( tensor: as Backend>::BoolTensorPrimitive, - ) -> Reader { + ) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); - - Reader::Concrete(TensorData::new(values, shape)) + TensorData::new(values, shape) } fn bool_to_device( @@ -63,10 +62,12 @@ impl BoolTensorOps for NdArray { fn bool_into_int( tensor: as Backend>::BoolTensorPrimitive, ) -> NdArrayTensor { - let data = Self::bool_into_data(tensor) - .read_sync() - .expect("Always sync with ndarray"); - NdArray::::int_from_data(data.convert::(), &NdArrayDevice::Cpu) + let shape = tensor.shape(); + let values = tensor.array.into_iter().collect(); + NdArray::::int_from_data( + TensorData::new(values, shape).convert::(), + &NdArrayDevice::Cpu, + ) } fn bool_device( diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 417e61567..058915abb 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -3,7 +3,7 @@ use alloc::vec; use alloc::vec::Vec; use burn_common::rand::get_seeded_rng; use burn_tensor::ops::IntTensorOps; -use burn_tensor::{Distribution, Reader}; +use burn_tensor::Distribution; use burn_tensor::ElementConversion; use core::ops::Range; @@ -32,11 +32,10 @@ impl IntTensorOps for NdArray { tensor.shape() } - fn int_into_data(tensor: NdArrayTensor) -> Reader { + async fn int_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); - - Reader::Concrete(TensorData::new(values, shape)) + TensorData::new(values, shape) } fn int_to_device( diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index cb406183e..deb260f40 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -11,8 +11,8 @@ use crate::{NdArrayDevice, SEED}; // Workspace crates use burn_common::rand::get_seeded_rng; +use burn_tensor::Distribution; use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData}; -use burn_tensor::{Distribution, Reader}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] @@ -51,11 +51,10 @@ impl FloatTensorOps for NdArray { tensor.shape() } - fn float_into_data(tensor: NdArrayTensor) -> Reader { + async fn float_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); - - Reader::Concrete(TensorData::new(values, shape)) + TensorData::new(values, shape) } fn float_device(_tensor: &NdArrayTensor) -> NdArrayDevice { diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index bc304cee5..792269e6c 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ 
b/crates/burn-tch/src/ops/bool_tensor.rs @@ -1,6 +1,6 @@ use super::TchOps; use crate::{element::TchElement, LibTorch, LibTorchDevice, TchTensor}; -use burn_tensor::{backend::Backend, ops::BoolTensorOps, Reader, Shape, TensorData}; +use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData}; use std::ops::Range; impl BoolTensorOps for LibTorch { @@ -23,12 +23,11 @@ impl BoolTensorOps for LibTorch { TchOps::repeat(tensor, dim, times) } - fn bool_into_data(tensor: TchTensor) -> Reader { + async fn bool_into_data(tensor: TchTensor) -> TensorData { let shape = Self::bool_shape(&tensor); let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - Reader::Concrete(TensorData::new(values.unwrap(), shape)) + TensorData::new(values.unwrap(), shape) } fn bool_to_device( @@ -143,13 +142,13 @@ impl BoolTensorOps for LibTorch { TchOps::flip(tensor, axes) } - fn bool_argwhere( + async fn bool_argwhere( tensor: as Backend>::BoolTensorPrimitive, ) -> TchTensor { TchTensor::new(tensor.tensor.argwhere()) } - fn bool_nonzero( + async fn bool_nonzero( tensor: as Backend>::BoolTensorPrimitive, ) -> Vec> { tensor diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 91ebd1b35..7a4a246d0 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -1,6 +1,6 @@ use std::ops::Range; -use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Reader, Shape, TensorData}; +use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Shape, TensorData}; use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; @@ -26,12 +26,11 @@ impl IntTensorOps for LibTorch { TchOps::repeat(tensor, dim, times) } - fn int_into_data(tensor: TchTensor) -> Reader { + async fn int_into_data(tensor: TchTensor) -> TensorData { let shape = Self::int_shape(&tensor); let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); - - Reader::Concrete(TensorData::new(values.unwrap(), shape)) + TensorData::new(values.unwrap(), shape) } fn int_to_device( diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index f827862ae..c9ce30c5c 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -1,8 +1,7 @@ use super::TchOps; use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor}; use burn_tensor::{ - backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Reader, Shape, - TensorData, + backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Shape, TensorData, }; use std::ops::Range; @@ -71,14 +70,14 @@ impl FloatTensorOps for LibTorch { tensor.shape() } - fn float_into_data( + async fn float_into_data( tensor: as Backend>::FloatTensorPrimitive, - ) -> Reader { + ) -> TensorData { let shape = Self::float_shape(&tensor); let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.try_into(); - Reader::Concrete(TensorData::new(values.unwrap(), shape)) + TensorData::new(values.unwrap(), shape) } fn float_device(tensor: &TchTensor) -> LibTorchDevice { diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index 8a15368da..5be4b7721 100644 --- a/crates/burn-tch/src/tensor.rs +++ 
b/crates/burn-tch/src/tensor.rs @@ -295,21 +295,6 @@ impl TchTensor } } -#[cfg(test)] -mod utils { - use super::*; - use crate::{backend::LibTorch, element::TchElement}; - - impl TchTensor { - pub(crate) fn into_data(self) -> TensorData - where - P: tch::kind::Element, - { - as FloatTensorOps>>::float_into_data(self).read() - } - } -} - impl TchTensor { /// Creates an empty tensor from a shape and a device. /// @@ -345,7 +330,7 @@ mod tests { ); let tensor = TchTensor::::from_data(data_expected.clone(), tch::Device::Cpu); - let data_actual = tensor.into_data(); + let data_actual = Tensor::, 1>::from_primitive(tensor).into_data(); assert_eq!(data_expected, data_actual); } @@ -359,7 +344,7 @@ mod tests { ); let tensor = TchTensor::::from_data(data_expected.clone(), tch::Device::Cpu); - let data_actual = tensor.into_data(); + let data_actual = Tensor::, 2>::from_primitive(tensor).into_data(); assert_eq!(data_expected, data_actual); } diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 39f05a6b8..3d5c91a34 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -17,7 +17,6 @@ experimental-named-tensor = [] export_tests = ["burn-tensor-testgen"] std = ["rand/std", "half/std", "num-traits/std"] repr = [] -wasm-sync = [] [dependencies] burn-common = { path = "../burn-common", version = "0.14.0", default-features = false } @@ -27,7 +26,7 @@ derive-new = { workspace = true } half = { workspace = true, features = ["bytemuck"] } num-traits = { workspace = true } rand = { workspace = true } -rand_distr = { workspace = true } # use instead of statrs because it supports no_std +rand_distr = { workspace = true } # use instead of statrs because it supports no_std bytemuck = { workspace = true } # The same implementation of HashMap in std but with no_std support (only needs alloc crate) diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index 14b01f4d5..dade906ac 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -25,4 +25,4 @@ pub use half::{bf16, f16}; pub(crate) use tensor::check::macros::check; pub use tensor::*; -pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as +pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency. diff --git a/crates/burn-tensor/src/tensor/api/argwhere.rs b/crates/burn-tensor/src/tensor/api/argwhere.rs index 76e0ea0f5..d3a722993 100644 --- a/crates/burn-tensor/src/tensor/api/argwhere.rs +++ b/crates/burn-tensor/src/tensor/api/argwhere.rs @@ -1,15 +1,11 @@ -use crate::{ - backend::Backend, - ops::{BoolTensor, IntTensor}, - Device, ElementConversion, Shape, TensorData, -}; +use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData}; use alloc::vec::Vec; /// Compute the indices of the elements that are non-zero, grouped by element. /// /// # Arguments /// -/// * `tensor` - The input tensor. +/// * `data` - The input tensor data. /// /// # Returns /// @@ -22,45 +18,7 @@ use alloc::vec::Vec; /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] -pub fn argwhere(tensor: BoolTensor) -> IntTensor { - // Size of each output tensor is variable (= number of nonzero elements in the tensor). 
- // Reading the data to count the number of truth values might cause sync but is required. - // let dims = B::bool_shape(&tensor).dims; - let device = B::bool_device(&tensor); - let data = B::bool_into_data(tensor).read(); - - argwhere_data::(data, &device) -} - -/// Compute the indices of the elements that are non-zero, grouped by element. -/// -/// # Arguments -/// -/// * `tensor` - The input tensor. -/// -/// # Returns -/// -/// A vector of tensors, one for each dimension of the given tensor, containing the indices of -/// the non-zero elements in that dimension. -/// -/// # Remarks -/// -/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. -/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved -/// by static dispatch. It is not designed for direct usage by users, and not recommended to import -/// or use this function directly. -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -pub async fn argwhere(tensor: BoolTensor) -> IntTensor { - // Size of each output tensor is variable (= number of nonzero elements in the tensor). - // Reading the data to count the number of truth values might cause sync but is required. - let device = B::bool_device(&tensor); - let data = B::bool_into_data(tensor).read().await; - - argwhere_data::(data, &device) -} - -fn argwhere_data( +pub fn argwhere_data( data: TensorData, device: &Device, ) -> IntTensor { diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 7d6d6eb69..3a1bc4d46 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -2,19 +2,16 @@ use alloc::vec::Vec; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::format; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::string::String; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use alloc::vec; -use burn_common::{reader::Reader, stub::Mutex}; +use burn_common::stub::Mutex; +use core::future::Future; use core::iter::repeat; use core::{fmt::Debug, ops::Range}; use serde::{Deserialize, Deserializer}; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use serde::{Serialize, Serializer}; use crate::check::TensorCheck; @@ -675,28 +672,28 @@ where Self::new(K::to_device(self.primitive, device)) } - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - /// Returns the data of the current tensor. - pub async fn into_data(self) -> TensorData { - K::into_data(self.primitive).read().await - } - - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor. + /// Converts the data of the current tensor. pub fn into_data(self) -> TensorData { - K::into_data(self.primitive).read() + crate::try_read_sync(self.into_data_async()).expect( + "Failed to read tensor data synchronously. + This can happen on platforms that don't support blocking futures like WASM. + If possible, try using into_data_async instead.", + ) } - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] /// Returns the data of the current tensor. - pub async fn to_data(&self) -> TensorData { - K::into_data(self.primitive.clone()).read().await + pub fn to_data(&self) -> TensorData { + self.clone().into_data() } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - /// Returns the data of the current tensor without taking ownership. 
- pub fn to_data(&self) -> TensorData { - Self::into_data(self.clone()) + /// Returns the data of the current tensor. + pub async fn into_data_async(self) -> TensorData { + K::into_data_async(self.primitive).await + } + + /// Returns the data of the current tensor. + pub async fn to_data_async(&self) -> TensorData { + self.clone().into_data_async().await + } /// Create a tensor from the given data on the given device. @@ -875,11 +872,12 @@ where /// # Panics /// /// If the tensor doesn't have one element. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] + /// If the backend fails to read the tensor data synchronously. pub fn into_scalar(self) -> K::Elem { - check!(TensorCheck::into_scalar(&self.shape())); - let x = self.into_data().iter().next().unwrap(); - x + crate::try_read_sync(self.into_scalar_async()).expect( + "Failed to read tensor data synchronously. This can happen on platforms + that don't support blocking futures like WASM. Try into_scalar_async instead.", + ) } /// Convert the tensor into a scalar. /// /// # Panics /// /// If the tensor doesn't have one element. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn into_scalar(self) -> K::Elem { + pub async fn into_scalar_async(self) -> K::Elem { check!(TensorCheck::into_scalar(&self.shape())); - let x = self.into_data().await.iter().next().unwrap(); + let x = self.into_data_async().await.iter().next().unwrap(); x } @@ -989,7 +986,6 @@ where K: BasicOps, >::Elem: Debug, { - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] #[inline] fn push_newline_indent(acc: &mut String, indent: usize) { acc.push('\n'); @@ -998,7 +994,6 @@ where } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn fmt_inner_tensor( &self, acc: &mut String, @@ -1015,18 +1010,18 @@ where let range: [core::ops::Range; D] = core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); - let elem = &self - .clone() - .slice(range) - .into_data() - .iter::<>::Elem>() - .next() - .unwrap(); - acc.push_str(&format!("{elem:?}")); + let data = + burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async()); + + if let Some(data) = data { + let elem = data.iter::<>::Elem>().next().unwrap(); + acc.push_str(&format!("{elem:?}")); + } else { + acc.push_str("<Tensor data not available>"); + } } } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn fmt_outer_tensor( &self, acc: &mut String, @@ -1061,7 +1056,6 @@ where /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. /// * `depth` - The current depth of the tensor dimensions being processed. /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn display_recursive( &self, acc: &mut String, @@ -1172,7 +1166,6 @@ where fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { writeln!(f, "Tensor {{")?; - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] { let po = PRINT_OPTS.lock().unwrap(); let mut acc = String::new(); @@ -1435,7 +1428,7 @@ pub trait BasicOps: TensorKind { device: &B::Device, ) -> Self::Primitive; - /// Extracts the data from the tensor. + /// Extracts the data from the tensor asynchronously.
/// /// # Arguments /// @@ -1453,7 +1446,9 @@ pub trait BasicOps: TensorKind { /// /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, /// which is more high-level and designed for public use. - fn into_data(tensor: Self::Primitive) -> Reader; + fn into_data_async( + tensor: Self::Primitive, + ) -> impl Future + Send; /// Creates a tensor from the given data. /// @@ -1724,8 +1719,8 @@ impl BasicOps for Float { B::float_to_device(tensor, device) } - fn into_data(tensor: Self::Primitive) -> Reader { - B::float_into_data(tensor) + async fn into_data_async(tensor: Self::Primitive) -> TensorData { + B::float_into_data(tensor).await } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { @@ -1846,8 +1841,8 @@ impl BasicOps for Int { B::int_to_device(tensor, device) } - fn into_data(tensor: Self::Primitive) -> Reader { - B::int_into_data(tensor) + async fn into_data_async(tensor: Self::Primitive) -> TensorData { + B::int_into_data(tensor).await } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { @@ -1968,8 +1963,8 @@ impl BasicOps for Bool { B::bool_to_device(tensor, device) } - fn into_data(tensor: Self::Primitive) -> Reader { - B::bool_into_data(tensor) + async fn into_data_async(tensor: Self::Primitive) -> TensorData { + B::bool_into_data(tensor).await } fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive { @@ -2230,7 +2225,6 @@ impl BroadcastArgs for [i32; D2] { } } -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] impl Serialize for Tensor where B: Backend, diff --git a/crates/burn-tensor/src/tensor/api/bool.rs b/crates/burn-tensor/src/tensor/api/bool.rs index 85a6a7aaa..27635c366 100644 --- a/crates/burn-tensor/src/tensor/api/bool.rs +++ b/crates/burn-tensor/src/tensor/api/bool.rs @@ -1,8 +1,7 @@ use crate::{backend::Backend, Bool, Int, Shape, Tensor, TensorData}; use alloc::vec::Vec; -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -use crate::argwhere; +use crate::try_read_sync; /// The part of the tensor to keep when creating a triangular mask. enum TriPart { @@ -46,27 +45,21 @@ where /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn nonzero(self) -> Vec> { - B::bool_nonzero(self.primitive) - .into_iter() - .map(Tensor::new) - .collect() + try_read_sync(self.nonzero_async()) + .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.") } - /// Compute the indices of the elements that are true. + /// Compute the indices of the elements that are non-zero. /// /// # Returns /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn nonzero(self) -> Vec> { - let indices = self.argwhere().await.primitive; - let dims = B::int_shape(&indices).dims; - B::int_chunk(indices, dims[1], 1) + pub async fn nonzero_async(self) -> Vec> { + B::bool_nonzero(self.primitive) + .await .into_iter() - .map(|t| B::int_reshape(t, Shape::new([dims[0]]))) .map(Tensor::new) .collect() } @@ -77,9 +70,9 @@ where /// /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the /// result contains the indices of a non-zero element. 
- #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn argwhere(self) -> Tensor { - Tensor::new(B::bool_argwhere(self.primitive)) + try_read_sync(self.argwhere_async()) + .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.") } /// Compute the indices of the elements that are true, grouped by element. @@ -88,9 +81,8 @@ where /// /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the /// result contains the indices of a non-zero element. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn argwhere(self) -> Tensor { - Tensor::new(argwhere::(self.primitive).await) + pub async fn argwhere_async(self) -> Tensor { + Tensor::new(B::bool_argwhere(self.primitive).await) } /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 1f2ee1a90..e0bcb36f0 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -10,9 +10,6 @@ use crate::tensor::{Distribution, Shape, TensorData}; use crate::Int; use crate::Tensor; -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -use crate::{argsort, sort, sort_with_indices, Float}; - impl Tensor where B: Backend, @@ -265,88 +262,4 @@ where .matmul(centered) .div_scalar(n as f32 - correction_factor as f32) } - - /// Sort the elements by value in ascending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort(self, dim: usize) -> Tensor { - Tensor::new(sort::(self.primitive, dim, /*descending*/ false).await) - } - - /// Sort the elements by value in descending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_descending(self, dim: usize) -> Tensor { - Tensor::new(sort::(self.primitive, dim, /*descending*/ true).await) - } - - /// Sort the elements by value in ascending order along a given dimension. - /// Also returns the indices. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); - let (values, indices) = - sort_with_indices::(self.primitive, dim, /*descending*/ false).await; - (Tensor::new(values), Tensor::new(indices)) - } - - /// Sort the elements by value in descending order along a given dimension. - /// Also returns the indices. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_descending_with_indices( - self, - dim: usize, - ) -> (Tensor, Tensor) { - check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); - let (values, indices) = - sort_with_indices::(self.primitive, dim, /*descending*/ true).await; - (Tensor::new(values), Tensor::new(indices)) - } - - /// Returns the indices that sort the elements by value in ascending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). 
- #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn argsort(self, dim: usize) -> Tensor { - check!(TensorCheck::sort_dim::("Argsort", dim)); - Tensor::new(argsort::(self.primitive, dim, /*descending*/ false).await) - } - - /// Returns the indices that sort the elements by value in descending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn argsort_descending(self, dim: usize) -> Tensor { - check!(TensorCheck::sort_dim::("Argsort", dim)); - Tensor::new(argsort::(self.primitive, dim, /*descending*/ true).await) - } - - /// Returns the `k` largest elements of the given input tensor along a given dimension. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn topk(self, k: usize, dim: usize) -> Tensor { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - self.sort_descending(dim).await.select(dim, k_indices) - } - - /// Returns the `k` largest elements of the given input tensor along a given dimension. - /// Also returns the indices. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn topk_with_indices( - self, - k: usize, - dim: usize, - ) -> (Tensor, Tensor) { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - let (values, indices) = self.sort_descending_with_indices(dim).await; - ( - values.select(dim, k_indices.clone()), - indices.select(dim, k_indices), - ) - } } diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index e8b91dca4..1accdd68e 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -2,9 +2,6 @@ use crate::{backend::Backend, Float, Int, Shape, Tensor, TensorData}; use core::ops::Range; -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -use crate::{argsort, check, check::TensorCheck, sort, sort_with_indices}; - impl Tensor where B: Backend, @@ -100,88 +97,4 @@ where ) -> Tensor { Tensor::new(B::int_cartesian_grid::(shape, device)) } - - /// Sort the elements by value in ascending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort(self, dim: usize) -> Tensor { - Tensor::new(sort::(self.primitive, dim, /* descending */ false).await) - } - - /// Sort the elements by value in descending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_descending(self, dim: usize) -> Tensor { - Tensor::new(sort::(self.primitive, dim, /* descending */ true).await) - } - - /// Sort the elements by value in ascending order along a given dimension. - /// Also returns the indices. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_with_indices(self, dim: usize) -> (Tensor, Tensor) { - check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); - let (values, indices) = - sort_with_indices::(self.primitive, dim, /*descending*/ false).await; - (Tensor::new(values), Tensor::new(indices)) - } - - /// Sort the elements by value in descending order along a given dimension. - /// Also returns the indices. - /// - /// This sort is unstable (i.e., may reorder equal elements). 
- #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn sort_descending_with_indices( - self, - dim: usize, - ) -> (Tensor, Tensor) { - check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); - let (values, indices) = - sort_with_indices::(self.primitive, dim, /*descending*/ true).await; - (Tensor::new(values), Tensor::new(indices)) - } - - /// Returns the indices that sort the elements by value in ascending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn argsort(self, dim: usize) -> Tensor { - check!(TensorCheck::sort_dim::("Argsort", dim)); - Tensor::new(argsort::(self.primitive, dim, /*descending*/ false).await) - } - - /// Returns the indices that sort the elements by value in descending order along a given dimension. - /// - /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn argsort_descending(self, dim: usize) -> Tensor { - check!(TensorCheck::sort_dim::("Argsort", dim)); - Tensor::new(argsort::(self.primitive, dim, /*descending*/ true).await) - } - - /// Returns the `k` largest elements of the given input tensor along a given dimension. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn topk(self, k: usize, dim: usize) -> Tensor { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - self.sort_descending(dim).await.select(dim, k_indices) - } - - /// Returns the `k` largest elements of the given input tensor along a given dimension. - /// Also returns the indices. - #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] - pub async fn topk_with_indices( - self, - k: usize, - dim: usize, - ) -> (Tensor, Tensor) { - let k_indices = Tensor::arange(0..k as i64, &self.device()); - let (values, indices) = self.sort_descending_with_indices(dim).await; - ( - values.select(dim, k_indices.clone()), - indices.select(dim, k_indices), - ) - } } diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index 62118c93a..60272d80b 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -13,7 +13,7 @@ mod narrow; mod numeric; mod sort; -pub use argwhere::argwhere; +pub use argwhere::argwhere_data; pub use autodiff::*; pub use base::*; pub use cartesian_grid::cartesian_grid; diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 7bbaa363e..a7c7b3cf6 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -642,9 +642,6 @@ where /// A boolean scalar. /// /// # Remarks - /// - /// This method is only available for non-wasm targets or when the `wasm-sync` feature is enabled. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn all_close(self, other: Self, rtol: Option, atol: Option) -> bool { self.is_close(other, rtol, atol).all().into_scalar() } @@ -671,7 +668,6 @@ where /// Sort the elements by value in ascending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). 
- #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Sort", dim)); Tensor::new(K::sort(self.primitive, dim, /*descending*/ false)) @@ -680,7 +676,6 @@ where /// Sort the elements by value in descending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort_descending(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Sort", dim)); Tensor::new(K::sort(self.primitive, dim, /*descending*/ true)) @@ -690,7 +685,6 @@ where /// Also returns the indices. /// /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort_with_indices(self, dim: usize) -> (Tensor, Tensor) { check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); let (values, indices) = @@ -702,7 +696,6 @@ where /// Also returns the indices. /// /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor, Tensor) { check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true); @@ -712,7 +705,6 @@ where /// Returns the indices that sort the elements by value in ascending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn argsort(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Argsort", dim)); Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false)) @@ -721,14 +713,12 @@ where /// Returns the indices that sort the elements by value in descending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn argsort_descending(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Argsort", dim)); Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true)) } /// Returns the `k` largest elements of the given input tensor along a given dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn topk(self, k: usize, dim: usize) -> Tensor { let k_indices = Tensor::arange(0..k as i64, &self.device()); self.sort_descending(dim).select(dim, k_indices) @@ -736,7 +726,6 @@ where /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); let (values, indices) = self.sort_descending_with_indices(dim); @@ -2029,7 +2018,6 @@ where /// /// Users should prefer the [Tensor::sort](Tensor::sort) function, /// which is more high-level and designed for public use. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort( tensor: Self::Primitive, dim: usize, @@ -2059,7 +2047,6 @@ where /// For sorting the elements of a tensor, users should prefer the /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level /// and designed for public use. 
- #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort_with_indices( tensor: Self::Primitive, dim: usize, @@ -2087,7 +2074,6 @@ where /// /// Users should prefer the [Tensor::argsort](Tensor::argsort) function, /// which is more high-level and designed for public use. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn argsort( tensor: Self::Primitive, dim: usize, @@ -2411,7 +2397,6 @@ impl Numeric for Int { B::int_sign(tensor) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort( tensor: Self::Primitive, dim: usize, @@ -2420,7 +2405,6 @@ impl Numeric for Int { B::int_sort(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort_with_indices( tensor: Self::Primitive, dim: usize, @@ -2429,7 +2413,6 @@ impl Numeric for Int { B::int_sort_with_indices(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn argsort( tensor: Self::Primitive, dim: usize, @@ -2758,7 +2741,6 @@ impl Numeric for Float { B::float_sign(tensor) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort( tensor: Self::Primitive, dim: usize, @@ -2767,7 +2749,6 @@ impl Numeric for Float { B::float_sort(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn sort_with_indices( tensor: Self::Primitive, dim: usize, @@ -2776,7 +2757,6 @@ impl Numeric for Float { B::float_sort_with_indices(tensor, dim, descending) } - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn argsort( tensor: Self::Primitive, dim: usize, diff --git a/crates/burn-tensor/src/tensor/api/sort.rs b/crates/burn-tensor/src/tensor/api/sort.rs index 2e214b7e8..799bcb86b 100644 --- a/crates/burn-tensor/src/tensor/api/sort.rs +++ b/crates/burn-tensor/src/tensor/api/sort.rs @@ -6,6 +6,7 @@ use crate::{ BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind, }; use alloc::vec::Vec; +use burn_common::reader::try_read_sync; /// Sort the elements of the input `tensor` by value along a given dimension. /// @@ -27,7 +28,6 @@ use alloc::vec::Vec; /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort + BasicOps>( tensor: K::Primitive, dim: usize, @@ -37,43 +37,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = K::into_data(tensor).read(); - - sort_data::(data, dim, &device, descending) -} - -/// Sort the elements of the input `tensor` by value along a given dimension. -/// -/// This sort is unstable (i.e., may reorder equal elements). -/// -/// # Arguments -/// -/// * `tensor` - The input tensor. -/// * `dim` - The axis along which to sort. -/// * `descending` - The sorting order. -/// -/// # Returns -/// -/// A tensor with the same shape as the input tensor, where the elements are sorted by value. -/// -/// # Remarks -/// -/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. -/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved -/// by static dispatch. It is not designed for direct usage by users, and not recommended to import -/// or use this function directly. 
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -pub async fn sort + BasicOps>( - tensor: K::Primitive, - dim: usize, - descending: bool, -) -> K::Primitive -where - >::Elem: Element, -{ - let device = K::device(&tensor); - let data = K::into_data(tensor).read().await; - + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); sort_data::(data, dim, &device, descending) } @@ -119,7 +83,6 @@ where /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn sort_with_indices + BasicOps>( tensor: K::Primitive, dim: usize, @@ -129,44 +92,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = K::into_data(tensor).read(); - - sort_data_with_indices::(data, dim, &device, descending) -} - -/// Sort the elements of the input `tensor` by value along a given dimension. -/// -/// This sort is unstable (i.e., may reorder equal elements). -/// -/// # Arguments -/// -/// * `tensor` - The input tensor. -/// * `dim` - The axis along which to sort. -/// * `descending` - The sorting order. -/// -/// # Returns -/// -/// A tensor with the same shape as the input tensor and corresponding indices, where -/// the elements are sorted by value and the indices map back to the original input tensor. -/// -/// # Remarks -/// -/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. -/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved -/// by static dispatch. It is not designed for direct usage by users, and not recommended to import -/// or use this function directly. -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -pub async fn sort_with_indices + BasicOps>( - tensor: K::Primitive, - dim: usize, - descending: bool, -) -> (K::Primitive, IntTensor) -where - >::Elem: Element, -{ - let device = K::device(&tensor); - let data = K::into_data(tensor).read().await; - + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); sort_data_with_indices::(data, dim, &device, descending) } @@ -253,7 +179,6 @@ where /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] pub fn argsort + BasicOps>( tensor: K::Primitive, dim: usize, @@ -263,42 +188,7 @@ where >::Elem: Element, { let device = K::device(&tensor); - let data = K::into_data(tensor).read(); - - argsort_data::(data, dim, &device, descending) -} - -/// Returns the indices that sort the elements of the input `tensor` along a given dimension. -/// -/// This sort is unstable (i.e., may reorder equal elements). -/// -/// # Arguments -/// -/// * `tensor` - The input tensor. -/// * `dim` - The axis along which to sort. -/// * `descending` - The sorting order.
-/// -/// # Returns -/// -/// A tensor with the same shape as the input tensor the indices map back to the original input tensor. -/// -/// # Remarks -/// -/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. -/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved -/// by static dispatch. It is not designed for direct usage by users, and not recommended to import -/// or use this function directly. -#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))] -pub async fn argsort + BasicOps>( - tensor: K::Primitive, - dim: usize, - descending: bool, -) -> IntTensor -where - >::Elem: Element, -{ - let device = K::device(&tensor); - let data = K::into_data(tensor).read().await; + let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); argsort_data::(data, dim, &device, descending) } diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index aa82714fd..8c591bb93 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -3,14 +3,11 @@ use super::{ IntTensor, }; use crate::{ - backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, TensorData, + argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, + TensorData, }; use alloc::vec::Vec; -use burn_common::reader::Reader; -use core::ops::Range; - -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] -use crate::argwhere; +use core::{future::Future, ops::Range}; /// Bool Tensor API for basic operations, see [tensor](crate::Tensor) /// for documentation on each function. @@ -47,21 +44,9 @@ pub trait BoolTensorOps { /// # Returns /// /// The data structure with the tensor's data. - fn bool_into_data(tensor: BoolTensor) -> Reader; - - /// Gets the data from the tensor. - /// - /// - /// # Arguments - /// - /// * `data` - The data structure. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn bool_to_data(tensor: &BoolTensor) -> Reader { - Self::bool_into_data(tensor.clone()) - } + fn bool_into_data( + tensor: BoolTensor, + ) -> impl Future + Send; /// Creates a tensor from the data structure. /// @@ -420,9 +405,16 @@ pub trait BoolTensorOps { /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn bool_argwhere(tensor: BoolTensor) -> IntTensor { - argwhere::(tensor) + fn bool_argwhere( + tensor: BoolTensor, + ) -> impl Future> + Send { + async { + // Size of each output tensor is variable (= number of nonzero elements in the tensor). + // Reading the data to count the number of truth values might cause sync but is required. + let device = B::bool_device(&tensor); + let data = B::bool_into_data(tensor).await; + argwhere_data::(data, &device) + } } /// Compute the indices of the elements that are non-zero. /// /// # Arguments /// /// * `tensor` - The tensor to compute the indices of the non-zero elements. /// /// # Returns /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension.
- #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] - fn bool_nonzero(tensor: BoolTensor) -> Vec> { - let indices = B::bool_argwhere(tensor); - let dims = B::int_shape(&indices).dims; - B::int_chunk(indices, dims[1], 1) - .into_iter() - .map(|t| B::int_reshape(t, Shape::new([dims[0]]))) - .collect() + fn bool_nonzero( + tensor: BoolTensor, + ) -> impl Future>> + Send { + async { + let indices = B::bool_argwhere(tensor).await; + let dims = B::int_shape(&indices).dims; + B::int_chunk(indices, dims[1], 1) + .into_iter() + .map(|t| B::int_reshape(t, Shape::new([dims[0]]))) + .collect() + } } /// Broadcasts the bool `tensor` to the given `shape`. diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 2f8aad84f..19cbb3494 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -6,10 +6,9 @@ use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, In use crate::{cartesian_grid, Tensor}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; -use burn_common::reader::Reader; +use core::future::Future; use core::ops::Range; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use crate::{argsort, sort, sort_with_indices}; /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor) @@ -47,20 +46,9 @@ pub trait IntTensorOps { /// # Returns /// /// The data structure with the tensor's data. - fn int_into_data(tensor: IntTensor) -> Reader; - - /// Gets the data from the tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data cloned from the data structure. - fn int_to_data(tensor: &IntTensor) -> Reader { - Self::int_into_data(tensor.clone()) - } + fn int_into_data( + tensor: IntTensor, + ) -> impl Future + Send; /// Creates a tensor from the data structure. /// @@ -1241,7 +1229,6 @@ pub trait IntTensorOps { /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort( tensor: IntTensor, dim: usize, @@ -1263,7 +1250,6 @@ pub trait IntTensorOps { /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_sort_with_indices( tensor: IntTensor, dim: usize, @@ -1286,7 +1272,6 @@ pub trait IntTensorOps { /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn int_argsort( tensor: IntTensor, dim: usize, diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 123893801..19bd446c3 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -7,10 +7,9 @@ use crate::Tensor; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; -use burn_common::reader::Reader; +use core::future::Future; use core::ops::Range; -#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use crate::{argsort, sort, sort_with_indices}; /// Operations on float tensors. 
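The ops-trait hunks above all follow one recipe: each `*_into_data` method becomes the only required read primitive and now returns `impl Future<Output = TensorData> + Send` (legal in trait signatures since Rust 1.75), while fallbacks such as `bool_argwhere` and `bool_nonzero` become default methods whose bodies are async blocks built on that primitive. A minimal, self-contained sketch of the recipe, using hypothetical names (`DataOps`, `nonzero_indices`, a toy `CpuBackend`) rather than Burn's actual traits:

    use core::future::Future;

    trait DataOps {
        // The one primitive each backend must implement, in the spirit of
        // `bool_into_data` above (static method, argument taken by value).
        fn into_data(tensor: Vec<bool>) -> impl Future<Output = Vec<bool>> + Send;

        // A fallback with a default body, in the spirit of `bool_argwhere`:
        // the signature promises `impl Future + Send`, the body is an async block.
        fn nonzero_indices(tensor: Vec<bool>) -> impl Future<Output = Vec<usize>> + Send {
            async move {
                let data = Self::into_data(tensor).await;
                data.into_iter()
                    .enumerate()
                    .filter_map(|(i, set)| set.then_some(i))
                    .collect()
            }
        }
    }

    struct CpuBackend;

    impl DataOps for CpuBackend {
        fn into_data(tensor: Vec<bool>) -> impl Future<Output = Vec<bool>> + Send {
            // A synchronous backend resolves immediately; a GPU backend would
            // await a mapped staging buffer here instead.
            async move { tensor }
        }
    }

    fn main() {
        // Native callers can simply block; `pollster` is already a workspace
        // dependency per the Cargo.toml changes above.
        let indices = pollster::block_on(CpuBackend::nonzero_indices(vec![true, false, true]));
        assert_eq!(indices, vec![0, 2]);
    }

Keeping the `Send` bound on the returned future (instead of boxing it) is what lets the same trait serve both native callers, which block on the future, and wasm callers, which await it.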
@@ -111,20 +110,9 @@ pub trait FloatTensorOps { /// # Returns /// /// The data structure with the tensor's data. - fn float_to_data(tensor: &FloatTensor) -> Reader { - Self::float_into_data(tensor.clone()) - } - - /// Converts the tensor to a data structure. - /// - /// # Arguments - /// - /// * `tensor` - The tensor. - /// - /// # Returns - /// - /// The data structure with the tensor's data. - fn float_into_data(tensor: FloatTensor) -> Reader; + fn float_into_data( + tensor: FloatTensor, + ) -> impl Future + Send; /// Gets the device of the tensor. /// @@ -1389,7 +1377,6 @@ pub trait FloatTensorOps { /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort( tensor: FloatTensor, dim: usize, @@ -1412,7 +1399,6 @@ pub trait FloatTensorOps { /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_sort_with_indices( tensor: FloatTensor, dim: usize, @@ -1434,7 +1420,6 @@ pub trait FloatTensorOps { /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. - #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] fn float_argsort( tensor: FloatTensor, dim: usize, diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index dd923391c..48dab9b8b 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -31,7 +31,7 @@ wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] } pollster = { workspace = true } log = { workspace = true } -futures-intrusive = { workspace = true } +async-channel = { workspace = true } derive-new = { workspace = true } hashbrown = { workspace = true } diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index a03eb74e3..5311dff50 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -113,7 +113,7 @@ where ) } - fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader { + fn create_read_buffer(&mut self, handle: server::Binding) -> wgpu::Buffer { let resource = self.memory_management.get(handle.memory); let size = resource.size(); @@ -134,50 +134,7 @@ where self.tasks_count += 1; self.sync(SyncType::Flush); - - BufferReader::new(buffer_dest) - } -} - -#[derive(new)] -struct BufferReader { - buffer: wgpu::Buffer, -} - -impl BufferReader { - #[cfg(target_family = "wasm")] - async fn read(self, device: alloc::sync::Arc) -> Vec { - self.read_async(&device).await - } - - #[cfg(not(target_family = "wasm"))] - fn read(self, device: &wgpu::Device) -> Vec { - pollster::block_on(self.read_async(device)) - } - - async fn read_async(&self, device: &wgpu::Device) -> Vec { - let buffer_slice = self.buffer.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver.receive().await; - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - self.buffer.unmap(); - result - } else { - panic!("Unable to read 
buffer {:?}", result) - } + buffer_dest } } @@ -190,15 +147,38 @@ where type MemoryManagement = MM; type AutotuneKey = JitAutotuneKey; - fn read(&mut self, binding: server::Binding) -> Reader> { - #[cfg(target_family = "wasm")] - { - let future = self.buffer_reader(binding).read(self.device.clone()); - return Reader::Future(Box::pin(future)); - } + fn read(&mut self, binding: server::Binding) -> Reader { + let device = self.device.clone(); + let buffer = self.create_read_buffer(binding); - #[cfg(not(target_family = "wasm"))] - Reader::Concrete(self.buffer_reader(binding).read(&self.device)) + Box::pin(async move { + let buffer_slice = buffer.slice(..); + let (sender, receiver) = async_channel::bounded(1); + + buffer_slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .try_send(v) + .expect("Unable to send buffer slice result to async channel.") + }); + + device.poll(wgpu::Maintain::Wait); + + let result = receiver + .recv() + .await + .expect("Unable to receive buffer slice result."); + + if let Ok(()) = result { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + buffer.unmap(); + result + } else { + panic!("Unable to read buffer {:?}", result) + } + }) } fn get_resource( diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 326a97630..4352559e8 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -25,9 +25,6 @@ tui = ["burn-train?/tui"] ## Includes system info metrics (CPU/GPU usage, etc) metrics = ["burn-train?/metrics"] -# Useful when targeting WASM and not using WGPU. -wasm-sync = ["burn-core/wasm-sync"] - # Datasets dataset = ["burn-core/dataset"] diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index b59716e76..a85269c1c 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -73,7 +73,6 @@ //! - `blas-netlib`: If supported, Blas Netlib will be use //! - `openblas`: If supported, Openblas will be use //! - `openblas-system`: If supported, Openblas installed on the system will be use -//! - `wasm-sync`: When targeting wasm, but want a sync API (won't work with WGPU) //! - `autotune`: Enable running benchmarks to select the best kernel in backends that support it. //! - `fusion`: Enable operation fusion in backends that support it. //! 
- Backend decorators diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index f8bf5ad67..df5d94c53 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -28,7 +28,7 @@ pub fn launch(device: &R::Device) { ArrayHandle::new(&output_handle, input.len()), ); - let output = client.read(output_handle.binding()).read_sync().unwrap(); + let output = client.read(output_handle.binding()); let output = f32::from_bytes(&output); // Should be [-0.1587, 0.0000, 0.8413, 5.0000] diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 7afb4e22b..8ca6cb660 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -150,19 +150,13 @@ impl Model { // Convert the model output into probability distribution using softmax formula let probabilities = softmax(output, 1); - #[cfg(not(target_family = "wasm"))] - let result = probabilities.into_data().convert::().to_vec().unwrap(); - // Forces the result to be computed - #[cfg(target_family = "wasm")] - let result = probabilities - .into_data() + probabilities + .into_data_async() .await .convert::() .to_vec() - .unwrap(); - - result + .unwrap() } } diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index 943a56428..a72b3d3c1 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -24,6 +24,3 @@ console_error_panic_hook = { workspace = true } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" js-sys = "0.3" - -[dev-dependencies] -pollster = { workspace = true } diff --git a/examples/mnist-inference-web/src/web.rs b/examples/mnist-inference-web/src/web.rs index 590925fa4..f4a8b4c40 100644 --- a/examples/mnist-inference-web/src/web.rs +++ b/examples/mnist-inference-web/src/web.rs @@ -68,11 +68,7 @@ impl Mnist { let output = burn::tensor::activation::softmax(output, 1); // Flatten output tensor with [1, 10] shape into boxed slice of [f32] - #[cfg(not(target_family = "wasm"))] - let output = output.into_data(); - - #[cfg(target_family = "wasm")] - let output = output.into_data().await; + let output = output.into_data_async().await; let array = Array::new(); for value in output.iter::() {
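Across the patch, the synchronous entry points (`into_data`, `into_scalar`, `nonzero`, `argwhere`, ...) become thin wrappers that pass the async variant to `try_read_sync` and panic with a pointer to the `*_async` method when blocking is impossible. As an illustration only (a stand-in with the semantics implied by the hunks above, not the actual `burn_common::reader::try_read_sync` source), such a helper can be sketched as:

    use core::future::Future;

    // Block on the future where that is possible; on wasm, where blocking the
    // main thread would hang the browser, report failure so callers can fall
    // back to the `*_async` variants.
    fn try_read_sync<F: Future>(future: F) -> Option<F::Output> {
        #[cfg(not(target_family = "wasm"))]
        return Some(pollster::block_on(future));

        #[cfg(target_family = "wasm")]
        {
            drop(future);
            None
        }
    }

    fn main() {
        // Sync wrappers are now `try_read_sync(...).expect(...)`; the values
        // here are just the expected gelu outputs quoted in the example above.
        let data = try_read_sync(async { vec![-0.1587_f32, 0.0000, 0.8413, 5.0000] })
            .expect("cannot block on this target; use the *_async variant instead");
        assert_eq!(data.len(), 4);
    }

This is why the wasm examples switch to `into_data_async().await`: on `target_family = "wasm"` the blocking path yields `None`, so the synchronous wrappers panic, while the async path works on every target.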