Consistent sync/async handling, allow more functions to be async for wasm. (#1936)

Arthur Brussee 2024-07-02 13:25:28 +01:00 committed by GitHub
parent 6f2ba34382
commit 849c8f453b
66 changed files with 456 additions and 986 deletions

Cargo.lock (generated): 67 lines changed
View File

@ -190,14 +190,15 @@ dependencies = [
]
[[package]]
name = "async-trait"
version = "0.1.80"
name = "async-channel"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.68",
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]]
@ -461,12 +462,12 @@ dependencies = [
name = "burn-common"
version = "0.14.0"
dependencies = [
"async-trait",
"dashmap",
"data-encoding",
"derive-new",
"getrandom",
"indicatif",
"pollster",
"rand",
"reqwest 0.12.5",
"serde",
@ -479,12 +480,14 @@ dependencies = [
name = "burn-compute"
version = "0.14.0"
dependencies = [
"async-channel",
"burn-common",
"derive-new",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
"md5",
"pollster",
"rand",
"serde",
"serde_json",
@ -761,6 +764,7 @@ dependencies = [
name = "burn-wgpu"
version = "0.14.0"
dependencies = [
"async-channel",
"burn-common",
"burn-compute",
"burn-cube",
@ -769,7 +773,6 @@ dependencies = [
"burn-tensor",
"bytemuck",
"derive-new",
"futures-intrusive",
"hashbrown 0.14.5",
"log",
"pollster",
@ -1143,6 +1146,15 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "console"
version = "0.15.8"
@ -1733,6 +1745,27 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
[[package]]
name = "event-listener"
version = "5.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.72.0"
@ -1931,17 +1964,6 @@ dependencies = [
"futures-util",
]
[[package]]
name = "futures-intrusive"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
dependencies = [
"futures-core",
"lock_api",
"parking_lot 0.12.2",
]
[[package]]
name = "futures-io"
version = "0.3.30"
@ -3274,7 +3296,6 @@ dependencies = [
"burn",
"console_error_panic_hook",
"js-sys",
"pollster",
"serde",
"wasm-bindgen",
"wasm-bindgen-futures",
@ -3778,6 +3799,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]]
name = "parking_lot"
version = "0.11.2"

View File

@ -26,7 +26,6 @@ readme = "README.md"
license = "MIT OR Apache-2.0"
[workspace.dependencies]
async-trait = "0.1.80"
bytemuck = "1.16.1"
candle-core = { version = "0.5.1" }
clap = { version = "4.5.8", features = ["derive"] }
@ -83,14 +82,16 @@ tracing-subscriber = "0.3.18"
web-time = "1.1.0"
zip = "2.1.3"
# Async handling
pollster = "0.3"
async-channel = "2.3"
# Terminal UI
ratatui = "0.26.3"
crossterm = "0.27.0"
# WGPU stuff
futures-intrusive = "0.5.0"
text_placeholder = "0.5.0"
pollster = "0.3.0"
wgpu = "0.20.1"
# Benchmarks and Burnbench

View File

@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, BoolTensorOps, IntTensor},
Device, Reader, Shape, TensorData,
Device, Shape, TensorData,
};
impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
@ -15,12 +15,8 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_shape(tensor)
}
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<TensorData> {
B::bool_to_data(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<TensorData> {
B::bool_into_data(tensor)
async fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> TensorData {
B::bool_into_data(tensor).await
}
fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
@ -121,14 +117,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_flip(tensor, axes)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
B::bool_argwhere(tensor)
async fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
B::bool_argwhere(tensor).await
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
B::bool_nonzero(tensor)
async fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
B::bool_nonzero(tensor).await
}
fn bool_expand<const D: usize, const D2: usize>(

View File

@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, IntTensor, IntTensorOps},
Device, Distribution, Reader, Shape, TensorData,
Device, Distribution, Shape, TensorData,
};
impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
@ -15,12 +15,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_shape(tensor)
}
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<TensorData> {
B::int_to_data(tensor)
}
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<TensorData> {
B::int_into_data(tensor)
async fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> TensorData {
B::int_into_data(tensor).await
}
fn int_to_device<const D: usize>(
@ -380,7 +376,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_expand(tensor, shape)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_sort<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
@ -389,7 +384,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sort(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_sort_with_indices<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
@ -398,7 +392,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sort_with_indices(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_argsort<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,

View File

@ -17,7 +17,7 @@ use crate::{
use burn_tensor::{
backend::Backend,
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
Device, ElementConversion, Reader, Shape, Tensor, TensorData,
Device, ElementConversion, Shape, Tensor, TensorData,
};
use super::maxmin::MaxMinDim;
@ -50,12 +50,8 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_shape(&tensor.primitive)
}
fn float_to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Reader<TensorData> {
B::float_to_data(&tensor.primitive)
}
fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
B::float_into_data(tensor.primitive)
async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
B::float_into_data(tensor.primitive).await
}
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
@ -2364,7 +2360,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_sort<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
@ -2387,7 +2382,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_sort_with_indices<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
@ -2416,7 +2410,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_argsort<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Reader, Shape, TensorData};
use burn_tensor::{backend::Backend, Shape, TensorData};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},

View File

@ -1,6 +1,6 @@
use burn_tensor::{
ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor},
Device, Reader, Shape, TensorData,
Device, Shape, TensorData,
};
use crate::{
@ -19,12 +19,10 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
super::base::shape(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<TensorData> {
async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
let x: Vec<u8> = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap();
let y = x.iter().map(|b| !matches!(b, 0)).collect();
let data = TensorData::new(y, tensor.shape());
Reader::Concrete(data)
TensorData::new(y, tensor.shape())
}
fn bool_from_data<const D: usize>(

View File

@ -1,6 +1,6 @@
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
Bool, Device, Distribution, ElementConversion, Reader, Shape, TensorData,
Bool, Device, Distribution, ElementConversion, Shape, TensorData,
};
use crate::{
@ -19,8 +19,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
super::base::shape(tensor)
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
Reader::Concrete(super::base::into_data(tensor))
async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
super::base::into_data(tensor)
}
fn int_from_data<const D: usize>(

View File

@ -2,7 +2,7 @@ use std::borrow::Borrow;
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
Device, Distribution, ElementConversion, Reader, Shape, TensorData,
Device, Distribution, ElementConversion, Shape, TensorData,
};
use candle_core::{backend::BackendStorage, shape, Tensor};
@ -59,8 +59,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
super::base::shape(tensor)
}
fn float_into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<TensorData> {
Reader::Concrete(super::base::into_data(tensor))
async fn float_into_data<const D: usize>(tensor: CandleTensor<F, D>) -> TensorData {
super::base::into_data(tensor)
}
fn float_device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {

View File

@ -12,25 +12,23 @@ version.workspace = true
[features]
default = ["std"]
std = ["rand/std", "data-encoding/std"]
std = ["rand/std", "data-encoding/std", "dep:pollster"]
doc = ["default"]
wasm-sync = []
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
getrandom = { workspace = true, features = ["js"] }
web-time = { version = "1.1.0" }
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
data-encoding = { workspace = true }
pollster = { workspace = true, optional = true }
# Network downloader
indicatif = { workspace = true, optional = true }

View File

@ -1,115 +1,54 @@
use alloc::boxed::Box;
use core::marker::PhantomData;
use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
#[async_trait::async_trait]
/// Allows to create async reader.
pub trait AsyncReader<T>: Send {
/// Read asynchronously.
async fn read(self: Box<Self>) -> T;
/// A future that is used to read resources from a compute server.
pub type Reader = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
/// Create a reader from a concrete value.
pub fn reader_from_concrete(val: Vec<u8>) -> Reader {
Box::pin(async move { val })
}
/// Define how data is read, sync or async.
pub enum Reader<T> {
/// Concrete variant.
Concrete(T),
/// Sync data variant.
Sync(Box<dyn SyncReader<T>>),
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Async data variant.
Async(Box<dyn AsyncReader<T>>),
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Future data variant.
Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
struct DummyWaker;
impl Wake for DummyWaker {
fn wake(self: Arc<Self>) {}
fn wake_by_ref(self: &Arc<Self>) {}
}
/// Allows to create sync reader.
pub trait SyncReader<T>: Send {
/// Read synchronously.
fn read(self: Box<Self>) -> T;
/// Read a future synchronously.
///
/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
/// If you want to handle this error, please use
/// try_read_sync instead.
pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T {
try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.")
}
#[derive(new)]
struct MappedReader<I, O, F> {
reader: Reader<I>,
mapper: F,
_output: PhantomData<O>,
}
/// Read a future synchronously.
///
/// On WASM futures cannot block, so this only succeeds if the future returns immediately;
/// otherwise this returns None.
pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> {
// Create a dummy context.
let waker = Waker::from(Arc::new(DummyWaker));
let mut context = Context::from_waker(&waker);
impl<I, O, F> SyncReader<O> for MappedReader<I, O, F>
where
I: Send,
O: Send,
F: Send + FnOnce(I) -> O,
{
fn read(self: Box<Self>) -> O {
let input = self
.reader
.read_sync()
.expect("Only sync data supported in a sync reader.");
// Pin & poll the future. Some backends don't do async readbacks, and instead immediately get
// the data. This lets us detect when a future is synchronous and doesn't require any waiting.
let mut pinned = core::pin::pin!(f);
(self.mapper)(input)
}
}
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
#[async_trait::async_trait]
impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
where
I: Send,
O: Send,
F: Send + FnOnce(I) -> O,
{
async fn read(self: Box<Self>) -> O {
let input = self.reader.read().await;
(self.mapper)(input)
}
}
impl<T> Reader<T> {
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Read the data.
pub async fn read(self) -> T {
match self {
Self::Concrete(data) => data,
Self::Sync(reader) => reader.read(),
Self::Async(func) => func.read().await,
Self::Future(future) => future.await,
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Read the data.
pub fn read(self) -> T {
match self {
Self::Concrete(data) => data,
Self::Sync(reader) => reader.read(),
}
}
/// Read the data only if sync, returns None if an async reader.
pub fn read_sync(self) -> Option<T> {
match self {
Self::Concrete(data) => Some(data),
Self::Sync(reader) => Some(reader.read()),
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
Self::Async(_func) => return None,
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
Self::Future(_future) => return None,
}
}
/// Map the current reader to another type.
pub fn map<O, F>(self, mapper: F) -> Reader<O>
where
T: 'static + Send,
O: 'static + Send,
F: FnOnce(T) -> O + 'static + Send,
{
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
Reader::Sync(Box::new(MappedReader::new(self, mapper)))
match pinned.as_mut().poll(&mut context) {
Poll::Ready(output) => Some(output),
// On platforms that support it, now just block on the future and drive it to completion.
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
Poll::Pending => Some(pollster::block_on(pinned)),
// Otherwise, just bail and return None: the future will have to be read back asynchronously.
#[cfg(any(target_family = "wasm", not(feature = "std")))]
Poll::Pending => None,
}
}
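
The rewritten reader reduces to: poll the future once with a no-op waker, and if it is not already resolved, block on it via pollster where blocking is allowed (std, non-wasm) or give up (wasm, no_std). A minimal usage sketch, assuming only the burn_common::reader module shown above:

use burn_common::reader::{reader_from_concrete, try_read_sync};

fn main() {
    // A reader built from a concrete value resolves on its first poll,
    // so try_read_sync succeeds without ever blocking.
    let reader = reader_from_concrete(vec![1u8, 2, 3]);
    let bytes = try_read_sync(reader).expect("resolved on first poll");
    assert_eq!(bytes, vec![1, 2, 3]);
}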

View File

@ -22,7 +22,7 @@ default = [
std = ["burn-common/std"]
channel-mutex = []
channel-cell = []
channel-mpsc = [] # Assume std
channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
storage-bytes = []
autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std
@ -36,6 +36,8 @@ dirs = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["std"], optional = true }
md5 = { workspace = true, optional = true }
pollster = { workspace = true, optional = true }
async-channel = { workspace = true, optional = true }
[target.'cfg(target_family = "wasm")'.dependencies]
web-time = { workspace = true }

View File

@ -9,7 +9,7 @@ use burn_common::{reader::Reader, sync_type::SyncType};
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
/// Given a binding, returns owned resource as bytes
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>>;
fn read(&self, binding: Binding<Server>) -> Reader;
/// Given a resource handle, return the storage resource.
fn get_resource(

View File

@ -42,9 +42,9 @@ where
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
Server: ComputeServer,
Server: ComputeServer + Send,
{
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
fn read(&self, binding: Binding<Server>) -> Reader {
self.server.borrow_mut().read(binding)
}

View File

@ -1,9 +1,5 @@
use std::{
sync::{mpsc, Arc},
thread,
};
use burn_common::{reader::Reader, sync_type::SyncType};
use std::{sync::Arc, thread};
use super::ComputeChannel;
use crate::{
@ -11,7 +7,7 @@ use crate::{
storage::ComputeStorage,
};
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
/// Create a channel using a multi-producer, single-consumer channel to communicate with
/// the compute server spawned on its own thread.
#[derive(Debug)]
pub struct MpscComputeChannel<Server>
@ -27,16 +23,16 @@ where
Server: ComputeServer,
{
_handle: thread::JoinHandle<()>,
sender: mpsc::Sender<Message<Server>>,
sender: async_channel::Sender<Message<Server>>,
}
type Callback<Response> = mpsc::Sender<Response>;
type Callback<Response> = async_channel::Sender<Response>;
enum Message<Server>
where
Server: ComputeServer,
{
Read(Binding<Server>, Callback<Reader<Vec<u8>>>),
Read(Binding<Server>, Callback<Vec<u8>>),
GetResource(
Binding<Server>,
Callback<<Server::Storage as ComputeStorage>::Resource>,
@ -53,36 +49,40 @@ where
{
/// Create a new mpsc compute channel.
pub fn new(mut server: Server) -> Self {
let (sender, receiver) = mpsc::channel();
let (sender, receiver) = async_channel::unbounded();
let _handle = thread::spawn(move || {
while let Ok(message) = receiver.recv() {
match message {
Message::Read(binding, callback) => {
let data = server.read(binding);
callback.send(data).unwrap();
}
Message::GetResource(binding, callback) => {
let data = server.get_resource(binding);
callback.send(data).unwrap();
}
Message::Create(data, callback) => {
let handle = server.create(&data);
callback.send(handle).unwrap();
}
Message::Empty(size, callback) => {
let handle = server.empty(size);
callback.send(handle).unwrap();
}
Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings);
}
Message::Sync(sync_type, callback) => {
server.sync(sync_type);
callback.send(()).unwrap();
}
};
}
// Run the whole procedure as one blocking future. This is much simpler than trying
// to use some multithreaded executor.
pollster::block_on(async {
while let Ok(message) = receiver.recv().await {
match message {
Message::Read(binding, callback) => {
let data = server.read(binding).await;
callback.send(data).await.unwrap();
}
Message::GetResource(binding, callback) => {
let data = server.get_resource(binding);
callback.send(data).await.unwrap();
}
Message::Create(data, callback) => {
let handle = server.create(&data);
callback.send(handle).await.unwrap();
}
Message::Empty(size, callback) => {
let handle = server.empty(size);
callback.send(handle).await.unwrap();
}
Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings);
}
Message::Sync(sync_type, callback) => {
server.sync(sync_type);
callback.send(()).await.unwrap();
}
};
}
});
});
let state = Arc::new(MpscComputeChannelState { sender, _handle });
@ -103,75 +103,71 @@ impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
{
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
let (callback, response) = mpsc::channel();
fn read(&self, binding: Binding<Server>) -> Reader {
let sender = self.state.sender.clone();
self.state
.sender
.send(Message::Read(binding, callback))
.unwrap();
self.response(response)
Box::pin(async move {
let (callback, response) = async_channel::unbounded();
sender.send(Message::Read(binding, callback)).await.unwrap();
handle_response(response.recv().await)
})
}
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource {
let (callback, response) = mpsc::channel();
let (callback, response) = async_channel::unbounded();
self.state
.sender
.send(Message::GetResource(binding, callback))
.send_blocking(Message::GetResource(binding, callback))
.unwrap();
self.response(response)
handle_response(response.recv_blocking())
}
fn create(&self, data: &[u8]) -> Handle<Server> {
let (callback, response) = mpsc::channel();
let (callback, response) = async_channel::unbounded();
self.state
.sender
.send(Message::Create(data.to_vec(), callback))
.send_blocking(Message::Create(data.to_vec(), callback))
.unwrap();
self.response(response)
handle_response(response.recv_blocking())
}
fn empty(&self, size: usize) -> Handle<Server> {
let (callback, response) = mpsc::channel();
let (callback, response) = async_channel::unbounded();
self.state
.sender
.send(Message::Empty(size, callback))
.send_blocking(Message::Empty(size, callback))
.unwrap();
self.response(response)
handle_response(response.recv_blocking())
}
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
self.state
.sender
.send(Message::ExecuteKernel(kernel, bindings))
.send_blocking(Message::ExecuteKernel(kernel, bindings))
.unwrap()
}
fn sync(&self, sync_type: SyncType) {
let (callback, response) = mpsc::channel();
let (callback, response) = async_channel::unbounded();
self.state
.sender
.send(Message::Sync(sync_type, callback))
.send_blocking(Message::Sync(sync_type, callback))
.unwrap();
self.response(response)
handle_response(response.recv_blocking())
}
}
impl<Server: ComputeServer> MpscComputeChannel<Server> {
fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
match response.recv() {
Ok(val) => val,
Err(err) => panic!("Can't connect to the server correctly {err:?}"),
}
fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
match response {
Ok(val) => val,
Err(err) => panic!("Can't connect to the server correctly {err:?}"),
}
}
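
The server thread now drives a single pollster::block_on loop, and each request carries its own reply channel. A condensed, self-contained sketch of that round-trip pattern, using only the async-channel and pollster crates (all names here are illustrative, not from the diff):

use std::thread;

fn main() {
    // Each request carries the value to process plus a channel to answer on.
    let (tx, rx) = async_channel::unbounded::<(u32, async_channel::Sender<u32>)>();

    // The server thread services requests inside one blocking future,
    // mirroring MpscComputeChannel::new above.
    let server = thread::spawn(move || {
        pollster::block_on(async {
            while let Ok((value, reply)) = rx.recv().await {
                reply.send(value * 2).await.unwrap();
            }
        })
    });

    // Synchronous callers use the *_blocking variants, as the channel
    // implementation above does.
    let (reply_tx, reply_rx) = async_channel::unbounded();
    tx.send_blocking((21, reply_tx)).unwrap();
    assert_eq!(reply_rx.recv_blocking().unwrap(), 42);

    drop(tx); // closing the last sender ends the server loop
    server.join().unwrap();
}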

View File

@ -37,7 +37,7 @@ impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: Binding<Server>) -> Reader<Vec<u8>> {
fn read(&self, handle: Binding<Server>) -> Reader {
self.server.lock().read(handle)
}

View File

@ -7,7 +7,7 @@ use crate::{
use alloc::vec::Vec;
use alloc::{boxed::Box, sync::Arc};
use burn_common::stub::RwLock;
use burn_common::{reader::Reader, sync_type::SyncType};
use burn_common::sync_type::SyncType;
/// The ComputeClient is the entry point to request tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
@ -41,8 +41,16 @@ where
}
/// Given a binding, returns owned resource as bytes.
pub fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
self.channel.read(binding)
pub async fn read_async(&self, binding: Binding<Server>) -> Vec<u8> {
self.channel.read(binding).await
}
/// Given a binding, returns owned resource as bytes.
///
/// # Remarks
/// Panics if the read operation fails.
pub fn read(&self, binding: Binding<Server>) -> Vec<u8> {
burn_common::reader::read_sync(self.channel.read(binding))
}
/// Given a resource handle, returns the storage resource.
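
Caller-side, the split is: read keeps the old blocking ergonomics but can now panic where blocking is impossible, while read_async works on every target. A hypothetical generic helper (module paths assumed from the crate layout, not shown in the diff):

use burn_compute::channel::ComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::server::{Binding, ComputeServer};

async fn read_bytes<S: ComputeServer, C: ComputeChannel<S>>(
    client: &ComputeClient<S, C>,
    binding: Binding<S>,
) -> Vec<u8> {
    // client.read(binding) would work too on native targets, but panics
    // where a future cannot be blocked on (e.g. wasm).
    client.read_async(binding).await
}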

View File

@ -25,7 +25,7 @@ where
type AutotuneKey: AutotuneKey;
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>>;
fn read(&mut self, binding: Binding<Self>) -> Reader;
/// Given a resource handle, returns the storage resource.
fn get_resource(

View File

@ -1,6 +1,6 @@
use std::sync::Arc;
use burn_common::{reader::Reader, sync_type::SyncType};
use burn_common::{reader::reader_from_concrete, sync_type::SyncType};
use burn_compute::{
memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
server::{Binding, ComputeServer, Handle},
@ -26,10 +26,9 @@ where
type MemoryManagement = MM;
type AutotuneKey = String;
fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>> {
fn read(&mut self, binding: Binding<Self>) -> burn_common::reader::Reader {
let bytes = self.memory_management.get(binding.memory);
Reader::Concrete(bytes.read().to_vec())
reader_from_concrete(bytes.read().to_vec())
}
fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {

View File

@ -16,7 +16,7 @@ fn created_resource_is_the_same_when_read() {
let obtained_resource = client.read(resource_description.binding());
assert_eq!(resource, obtained_resource.read())
assert_eq!(resource, obtained_resource)
}
#[test]
@ -26,7 +26,7 @@ fn empty_allocates_memory() {
let resource_description = client.empty(size);
let empty_resource = client.read(resource_description.binding());
assert_eq!(empty_resource.read().len(), 4);
assert_eq!(empty_resource.len(), 4);
}
#[test]
@ -43,7 +43,7 @@ fn execute_elementwise_addition() {
let obtained_resource = client.read(out.binding());
assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]))
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
}
#[test]
@ -65,7 +65,7 @@ fn autotune_basic_addition_execution() {
let obtained_resource = client.read(out.binding());
// If slow kernel was selected it would output [0, 1, 2]
assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]));
assert_eq!(obtained_resource, Vec::from([4, 5, 6]));
}
#[test]
@ -87,7 +87,7 @@ fn autotune_basic_multiplication_execution() {
let obtained_resource = client.read(out.binding());
// If slow kernel was selected it would output [0, 1, 2]
assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8]));
assert_eq!(obtained_resource, Vec::from([0, 4, 8]));
}
#[test]
@ -126,7 +126,7 @@ fn autotune_cache_same_key_return_a_cache_hit() {
let obtained_resource = client.read(out_2.binding());
// Cache should be hit, so CacheTestFastOn3 should be used, returning lhs
assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3]));
assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3]));
}
#[test]
@ -167,7 +167,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
let obtained_resource = client.read(out_2.binding());
// Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9]));
assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}
#[test]
@ -175,6 +175,8 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
#[cfg(feature = "std")]
fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
// delete the cache file
use burn_common::sync_type::SyncType;
let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
let parent_dir = file_path
.parent()
@ -198,7 +200,7 @@ fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles);
client.autotune_execute(Box::new(cache_test_autotune_kernel));
// ensure that the autotune operations are run and cached
let _obtained_resource = client.read(out.binding());
client.sync(SyncType::Wait);
assert!(
parent_dir.exists(),
@ -237,7 +239,7 @@ fn autotune_cache_different_keys_return_a_cache_miss() {
let obtained_resource = client.read(out_2.binding());
// Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9]));
assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}
#[test]
@ -285,5 +287,5 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
// Cache should be missed because the checksum on 4 is generated randomly
// and thus is always different,
// so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs
assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8]));
assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8]));
}

View File

@ -65,8 +65,6 @@ sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
vision = ["burn-dataset?/vision", "burn-common/network"]
wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]
# Backend
autodiff = ["burn-autodiff"]
fusion = ["burn-wgpu?/fusion"]

View File

@ -69,16 +69,6 @@ impl GradientClipping {
clipped_grad.mask_fill(lower_mask, -threshold)
}
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
_grad: Tensor<B, D>,
_threshold: f32,
) -> Tensor<B, D> {
todo!("Not yet supported on wasm");
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
@ -97,11 +87,9 @@ impl GradientClipping {
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
let squared = tensor.powf_scalar(2.0);
let sum = squared.sum();
sum.sqrt()
}
}

View File

@ -135,10 +135,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording float tensors isn't yet supported on wasm.");
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
FloatTensorSerde::new(self.into_data().convert::<S::FloatElem>())
}
@ -152,10 +148,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording int tensors isn't yet supported on wasm.");
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
}
@ -169,10 +161,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
type Item<S: PrecisionSettings> = BoolTensorSerde;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
todo!("Recording bool tensors isn't yet supported on wasm.");
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
BoolTensorSerde::new(self.into_data())
}

View File

@ -25,7 +25,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
ArrayHandle::new(&handle, 2),
);
let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);
assert_eq!(actual[0], 5.0);
@ -41,7 +41,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
ArrayHandle::new(&handle, 2),
);
let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);
assert_eq!(actual[0], 5.0);

View File

@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
TensorHandle::new(&handle, &strides, &shape),
);
let actual = client.read(handle.binding()).read_sync().unwrap();
let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);
assert_eq!(actual, expected);

View File

@ -8,6 +8,8 @@ use burn_cube::ir::CubeDim;
use burn_cube::prelude::*;
use burn_jit::JitAutotuneKey;
use burn_tensor::backend::SyncType;
use burn_tensor::reader_from_concrete;
use burn_tensor::Reader;
use cudarc::driver::sys::CUctx_st;
use cudarc::driver::sys::CUfunc_st;
use std::collections::HashMap;
@ -58,18 +60,17 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
fn read(&mut self, binding: server::Binding<Self>) -> burn_tensor::Reader<Vec<u8>> {
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
let ctx = self.get_context();
let resource = ctx.memory_management.get(binding.memory);
// TODO: Check if it is possible to make this faster
let mut data = vec![0; resource.size() as usize];
unsafe {
cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
};
ctx.sync();
burn_tensor::Reader::Concrete(data)
reader_from_concrete(data)
}
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {

View File

@ -1,10 +1,12 @@
use std::future::Future;
use crate::{
stream::{execution::Operation, StreamId},
FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
};
use burn_tensor::{
repr::{OperationDescription, TensorDescription, TensorId},
DType, Reader, TensorData,
DType, TensorData,
};
/// Define how to interact with the fusion server.
@ -37,7 +39,7 @@ where
&self,
tensor: TensorDescription,
stream: StreamId,
) -> Reader<TensorData>
) -> impl Future<Output = TensorData> + Send
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by an int tensor.
@ -45,7 +47,7 @@ where
&self,
tensor: TensorDescription,
stream: StreamId,
) -> Reader<TensorData>
) -> impl Future<Output = TensorData> + Send
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a bool tensor.
@ -53,7 +55,7 @@ where
&self,
tensor: TensorDescription,
stream: StreamId,
) -> Reader<TensorData>
) -> impl Future<Output = TensorData> + Send
where
B: FusionBackend<FusionRuntime = R>;
/// Change the client of the given float tensor.

View File

@ -78,37 +78,37 @@ where
FusionTensor::new(id, shape, dtype, self.clone(), stream)
}
fn read_tensor_float<B, const D: usize>(
async fn read_tensor_float<B, const D: usize>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_float::<B, D>(tensor, stream)
self.server.lock().read_float::<B, D>(tensor, stream).await
}
fn read_tensor_int<B, const D: usize>(
async fn read_tensor_int<B, const D: usize>(
&self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_int::<B, D>(tensor, id)
self.server.lock().read_int::<B, D>(tensor, id).await
}
fn read_tensor_bool<B, const D: usize>(
async fn read_tensor_bool<B, const D: usize>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_bool::<B, D>(tensor, stream)
self.server.lock().read_bool::<B, D>(tensor, stream).await
}
fn change_client_float<B, const D: usize>(

View File

@ -1,4 +1,4 @@
use burn_tensor::{DType, Element};
use burn_tensor::{DType, Element, TensorData};
use std::marker::PhantomData;
use crate::{
@ -37,10 +37,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
tensor.shape()
}
fn bool_into_data<const D: usize>(
tensor: BoolTensor<Self, D>,
) -> burn_tensor::Reader<burn_tensor::TensorData> {
tensor.bool_into_data::<B, D>()
async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
tensor.bool_into_data::<B, D>().await
}
fn bool_from_data<const D: usize>(

View File

@ -10,7 +10,7 @@ use crate::{
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
repr::*,
DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
};
use std::{marker::PhantomData, ops::Range};
@ -169,8 +169,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
tensor.shape()
}
fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
tensor.into_data::<B, D>()
async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
tensor.into_data::<B, D>().await
}
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {

View File

@ -10,7 +10,7 @@ use crate::{
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
repr::{self, *},
DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
};
use core::ops::Range;
use std::marker::PhantomData;
@ -33,8 +33,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
tensor.shape()
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
tensor.int_into_data::<B, D>()
async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
tensor.int_into_data::<B, D>().await
}
fn int_from_data<const D: usize>(

View File

@ -39,11 +39,11 @@ where
self.handles.create_tensor_uninit()
}
pub fn read_float<B, const D: usize>(
pub async fn read_float<B, const D: usize>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -52,14 +52,14 @@ where
self.drain_stream(id);
let tensor = self.handles.get_float_tensor::<B, D>(&tensor);
B::float_into_data(tensor)
B::float_into_data(tensor).await
}
pub fn read_int<B, const D: usize>(
pub async fn read_int<B, const D: usize>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -68,14 +68,14 @@ where
self.drain_stream(id);
let tensor = self.handles.get_int_tensor::<B, D>(&tensor);
B::int_into_data(tensor)
B::int_into_data(tensor).await
}
pub fn read_bool<B, const D: usize>(
pub async fn read_bool<B, const D: usize>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::TensorData>
) -> burn_tensor::TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -84,7 +84,7 @@ where
self.drain_stream(id);
let tensor = self.handles.get_bool_tensor::<B, D>(&tensor);
B::bool_into_data(tensor)
B::bool_into_data(tensor).await
}
pub fn change_server_float<B, const D: usize>(

View File

@ -1,7 +1,7 @@
use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
use burn_tensor::{
repr::{TensorDescription, TensorId, TensorStatus},
DType, Reader, Shape, TensorData,
DType, Shape, TensorData,
};
use std::sync::Arc;
@ -108,7 +108,7 @@ impl<R: FusionRuntime> FusionTensor<R> {
}
}
pub(crate) fn into_data<B, const D: usize>(self) -> Reader<TensorData>
pub(crate) async fn into_data<B, const D: usize>(self) -> TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -116,9 +116,10 @@ impl<R: FusionRuntime> FusionTensor<R> {
self.client
.clone()
.read_tensor_float::<B, D>(self.into_description(), id)
.await
}
pub(crate) fn int_into_data<B, const D: usize>(self) -> Reader<TensorData>
pub(crate) async fn int_into_data<B, const D: usize>(self) -> TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -126,9 +127,10 @@ impl<R: FusionRuntime> FusionTensor<R> {
self.client
.clone()
.read_tensor_int::<B, D>(self.into_description(), id)
.await
}
pub(crate) fn bool_into_data<B, const D: usize>(self) -> Reader<TensorData>
pub(crate) async fn bool_into_data<B, const D: usize>(self) -> TensorData
where
B: FusionBackend<FusionRuntime = R>,
{
@ -136,6 +138,7 @@ impl<R: FusionRuntime> FusionTensor<R> {
self.client
.clone()
.read_tensor_bool::<B, D>(self.into_description(), id)
.await
}
}

View File

@ -1,6 +1,6 @@
use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime};
use burn_cube::CubeElement;
use burn_tensor::{Reader, Shape, TensorData};
use burn_tensor::{Shape, TensorData};
use std::marker::PhantomData;
pub(crate) fn from_data<R: JitRuntime, E: JitElement, const D: usize>(
@ -14,28 +14,24 @@ pub(crate) fn from_data<R: JitRuntime, E: JitElement, const D: usize>(
JitTensor::new(client, device.clone(), shape, buffer)
}
pub(crate) fn into_data<R: JitRuntime, E: JitElement, const D: usize>(
pub(crate) async fn into_data<R: JitRuntime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
) -> Reader<TensorData> {
) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
tensor
.client
.read(tensor.handle.binding())
.map(|bytes| TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape))
let bytes = tensor.client.read_async(tensor.handle.binding()).await;
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}
pub(crate) fn bool_into_data<R: JitRuntime, const D: usize>(
pub(crate) async fn bool_into_data<R: JitRuntime, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> Reader<TensorData> {
) -> TensorData {
let tensor = kernel::into_contiguous(tensor);
tensor.client.read(tensor.handle.binding()).map(|bytes| {
TensorData::new(
u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
tensor.shape,
)
})
let bytes = tensor.client.read_async(tensor.handle.binding()).await;
TensorData::new(
u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
tensor.shape,
)
}
pub(crate) fn to_device<R: JitRuntime, E: JitElement, const D: usize>(

View File

@ -1,6 +1,5 @@
use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_tensor::Reader;
use burn_tensor::{ops::BoolTensorOps, Shape, TensorData};
use std::ops::Range;
@ -20,8 +19,8 @@ where
tensor.shape.clone()
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<TensorData> {
super::bool_into_data(tensor)
async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
super::bool_into_data(tensor).await
}
fn bool_from_data<const D: usize>(

View File

@ -7,8 +7,8 @@ use crate::{FloatElement, IntElement, JitRuntime};
use burn_cube::ir::{BinaryOperator, Elem, Operator, Scope, UnaryOperator, Variable};
use burn_cube::Runtime;
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use burn_tensor::ElementConversion;
use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData};
use burn_tensor::{ElementConversion, Reader};
use std::ops::Range;
impl<R, F, I> FloatTensorOps<Self> for JitBackend<R, F, I>
@ -45,8 +45,8 @@ where
tensor.shape.clone()
}
fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
super::into_data(tensor)
async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
super::into_data(tensor).await
}
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {

View File

@ -4,7 +4,7 @@ use crate::{kernel, unary, FloatElement, IntElement, JitBackend, JitRuntime};
use burn_cube::ir::{Elem, Item, Operator, Scope, UnaryOperator, Variable};
use burn_cube::Runtime;
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Reader, Shape, TensorData};
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
use std::ops::Range;
impl<R, F, I> IntTensorOps<Self> for JitBackend<R, F, I>
@ -21,8 +21,8 @@ where
tensor.shape.clone()
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
super::into_data(tensor)
async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
super::into_data(tensor).await
}
fn int_from_data<const D: usize>(

View File

@ -101,11 +101,10 @@ where
client: ComputeClient<R::Server, R::Channel>,
device: R::Device,
) -> Self {
let bytes = self
.client
.read(self.handle.clone().binding())
.read_sync()
.expect("Can only change client synchronously");
let bytes = burn_common::reader::try_read_sync(
self.client.read_async(self.handle.clone().binding()),
)
.expect("Can only change client synchronously");
let handle = client.create(&bytes);
Self {

View File

@ -2,7 +2,7 @@
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
use burn_tensor::{ElementConversion, Reader};
use burn_tensor::ElementConversion;
use core::ops::Range;
use ndarray::IntoDimension;
@ -30,13 +30,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
tensor.shape()
}
fn bool_into_data<const D: usize>(
async fn bool_into_data<const D: usize>(
tensor: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
) -> Reader<TensorData> {
) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Reader::Concrete(TensorData::new(values, shape))
TensorData::new(values, shape)
}
fn bool_to_device<const D: usize>(
@ -63,10 +62,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
fn bool_into_int<const D: usize>(
tensor: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
) -> NdArrayTensor<i64, D> {
let data = Self::bool_into_data(tensor)
.read_sync()
.expect("Always sync with ndarray");
NdArray::<E>::int_from_data(data.convert::<i64>(), &NdArrayDevice::Cpu)
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
NdArray::<E>::int_from_data(
TensorData::new(values, shape).convert::<i64>(),
&NdArrayDevice::Cpu,
)
}
fn bool_device<const D: usize>(

View File

@ -3,7 +3,7 @@ use alloc::vec;
use alloc::vec::Vec;
use burn_common::rand::get_seeded_rng;
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Distribution, Reader};
use burn_tensor::Distribution;
use burn_tensor::ElementConversion;
use core::ops::Range;
@ -32,11 +32,10 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
tensor.shape()
}
fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Reader<TensorData> {
async fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Reader::Concrete(TensorData::new(values, shape))
TensorData::new(values, shape)
}
fn int_to_device<const D: usize>(

View File

@ -11,8 +11,8 @@ use crate::{NdArrayDevice, SEED};
// Workspace crates
use burn_common::rand::get_seeded_rng;
use burn_tensor::Distribution;
use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData};
use burn_tensor::{Distribution, Reader};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
@ -51,11 +51,10 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
tensor.shape()
}
fn float_into_data<const D: usize>(tensor: NdArrayTensor<E, D>) -> Reader<TensorData> {
async fn float_into_data<const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Reader::Concrete(TensorData::new(values, shape))
TensorData::new(values, shape)
}
fn float_device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {

View File

@ -1,6 +1,6 @@
use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchTensor};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Reader, Shape, TensorData};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData};
use std::ops::Range;
impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
@ -23,12 +23,11 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
TchOps::repeat(tensor, dim, times)
}
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Reader<TensorData> {
async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {
let shape = Self::bool_shape(&tensor);
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Reader::Concrete(TensorData::new(values.unwrap(), shape))
TensorData::new(values.unwrap(), shape)
}
fn bool_to_device<const D: usize>(
@ -143,13 +142,13 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
TchOps::flip(tensor, axes)
}
fn bool_argwhere<const D: usize>(
async fn bool_argwhere<const D: usize>(
tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
) -> TchTensor<i64, 2> {
TchTensor::new(tensor.tensor.argwhere())
}
fn bool_nonzero<const D: usize>(
async fn bool_nonzero<const D: usize>(
tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
) -> Vec<TchTensor<i64, 1>> {
tensor

View File

@ -1,6 +1,6 @@
use std::ops::Range;
use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Reader, Shape, TensorData};
use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Shape, TensorData};
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
@ -26,12 +26,11 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
TchOps::repeat(tensor, dim, times)
}
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Reader<TensorData> {
async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {
let shape = Self::int_shape(&tensor);
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Reader::Concrete(TensorData::new(values.unwrap(), shape))
TensorData::new(values.unwrap(), shape)
}
fn int_to_device<const D: usize>(

View File

@ -1,8 +1,7 @@
use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
use burn_tensor::{
backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Reader, Shape,
TensorData,
backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Shape, TensorData,
};
use std::ops::Range;
@ -71,14 +70,14 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
tensor.shape()
}
fn float_into_data<const D: usize>(
async fn float_into_data<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
) -> Reader<TensorData> {
) -> TensorData {
let shape = Self::float_shape(&tensor);
let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.try_into();
Reader::Concrete(TensorData::new(values.unwrap(), shape))
TensorData::new(values.unwrap(), shape)
}
fn float_device<const D: usize>(tensor: &TchTensor<E, D>) -> LibTorchDevice {

View File

@ -295,21 +295,6 @@ impl<E: tch::kind::Element + Default + Element, const D: usize> TchTensor<E, D>
}
}
#[cfg(test)]
mod utils {
use super::*;
use crate::{backend::LibTorch, element::TchElement};
impl<P: TchElement, const D: usize> TchTensor<P, D> {
pub(crate) fn into_data(self) -> TensorData
where
P: tch::kind::Element,
{
<LibTorch<P> as FloatTensorOps<LibTorch<P>>>::float_into_data(self).read()
}
}
}
impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
/// Creates an empty tensor from a shape and a device.
///
@ -345,7 +330,7 @@ mod tests {
);
let tensor = TchTensor::<f32, 1>::from_data(data_expected.clone(), tch::Device::Cpu);
let data_actual = tensor.into_data();
let data_actual = Tensor::<LibTorch<f32>, 1>::from_primitive(tensor).into_data();
assert_eq!(data_expected, data_actual);
}
@ -359,7 +344,7 @@ mod tests {
);
let tensor = TchTensor::<f32, 2>::from_data(data_expected.clone(), tch::Device::Cpu);
let data_actual = tensor.into_data();
let data_actual = Tensor::<LibTorch<f32>, 2>::from_primitive(tensor).into_data();
assert_eq!(data_expected, data_actual);
}

View File

@ -17,7 +17,6 @@ experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
std = ["rand/std", "half/std", "num-traits/std"]
repr = []
wasm-sync = []
[dependencies]
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
@ -27,7 +26,7 @@ derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
bytemuck = { workspace = true }
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)

View File

@ -25,4 +25,4 @@ pub use half::{bf16, f16};
pub(crate) use tensor::check::macros::check;
pub use tensor::*;
pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as
pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency.

View File

@ -1,15 +1,11 @@
use crate::{
backend::Backend,
ops::{BoolTensor, IntTensor},
Device, ElementConversion, Shape, TensorData,
};
use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData};
use alloc::vec::Vec;
/// Compute the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `data` - The input tensor data.
///
/// # Returns
///
@ -22,45 +18,7 @@ use alloc::vec::Vec;
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
// Size of each output tensor is variable (= number of nonzero elements in the tensor).
// Reading the data to count the number of truth values might cause sync but is required.
// let dims = B::bool_shape(&tensor).dims;
let device = B::bool_device(&tensor);
let data = B::bool_into_data(tensor).read();
argwhere_data::<B, D>(data, &device)
}
/// Compute the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
///
/// # Remarks
///
/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
// Size of each output tensor is variable (= number of nonzero elements in the tensor).
// Reading the data to count the number of truth values might cause sync but is required.
let device = B::bool_device(&tensor);
let data = B::bool_into_data(tensor).read().await;
argwhere_data::<B, D>(data, &device)
}
fn argwhere_data<B: Backend, const D: usize>(
pub fn argwhere_data<B: Backend, const D: usize>(
data: TensorData,
device: &Device<B>,
) -> IntTensor<B, 2> {

View File

@ -2,19 +2,16 @@
use alloc::vec::Vec;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::format;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::string::String;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use burn_common::stub::Mutex;
use core::future::Future;
use core::iter::repeat;
use core::{fmt::Debug, ops::Range};
use serde::{Deserialize, Deserializer};
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use serde::{Serialize, Serializer};
use crate::check::TensorCheck;
@ -675,28 +672,28 @@ where
Self::new(K::to_device(self.primitive, device))
}
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Returns the data of the current tensor.
pub async fn into_data(self) -> TensorData {
K::into_data(self.primitive).read().await
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Returns the data of the current tensor.
/// Converts the current tensor into its data.
pub fn into_data(self) -> TensorData {
K::into_data(self.primitive).read()
crate::try_read_sync(self.into_data_async()).expect(
"Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
If possible, try using into_data_async instead.",
)
}
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
/// Returns the data of the current tensor.
pub async fn to_data(&self) -> TensorData {
K::into_data(self.primitive.clone()).read().await
pub fn to_data(&self) -> TensorData {
self.clone().into_data()
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// Returns the data of the current tensor without taking ownership.
pub fn to_data(&self) -> TensorData {
Self::into_data(self.clone())
/// Returns the data of the current tensor.
pub async fn into_data_async(self) -> TensorData {
K::into_data_async(self.primitive).await
}
/// Returns the data of the current tensor.
pub async fn to_data_async(&self) -> TensorData {
self.clone().into_data_async().await
}
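Callers now choose the flavor explicitly: the sync names resolve the read internally (and panic where blocking is impossible), while the `_async` names work everywhere. A hedged usage sketch for a generic `Tensor<B, 2>` (the `log` call on wasm is illustrative):

#[cfg(not(target_family = "wasm"))]
fn dump<B: burn_tensor::backend::Backend>(t: burn_tensor::Tensor<B, 2>) {
    // Blocks until the backend produces the data; panics on wasm-like targets.
    let data = t.into_data();
    println!("{data:?}");
}

#[cfg(target_family = "wasm")]
async fn dump<B: burn_tensor::backend::Backend>(t: burn_tensor::Tensor<B, 2>) {
    // Awaits instead of blocking the browser's event loop.
    let data = t.into_data_async().await;
    log::info!("{data:?}");
}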
/// Create a tensor from the given data on the given device.
@ -875,11 +872,12 @@ where
/// # Panics
///
/// If the tensor doesn't have one element.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
/// If the backend fails to read the tensor data synchronously.
pub fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let x = self.into_data().iter().next().unwrap();
x
crate::try_read_sync(self.into_scalar_async()).expect(
"Failed to read tensor data synchronously. This can happen on platforms
that don't support blocking futures like WASM. Try into_scalar_async instead.",
)
}
/// Convert the tensor into a scalar.
@ -887,10 +885,9 @@ where
/// # Panics
///
/// If the tensor doesn't have one element.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn into_scalar(self) -> K::Elem {
pub async fn into_scalar_async(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let x = self.into_data().await.iter().next().unwrap();
let x = self.into_data_async().await.iter().next().unwrap();
x
}
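The scalar accessors follow the same split; a short sketch, assuming `loss` is a single-element float tensor (which the `check!` above enforces):

let value: f32 = loss.into_scalar().elem(); // blocks on native, panics on wasm
// wasm-friendly path:
// let value: f32 = loss.into_scalar_async().await.elem();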
@ -989,7 +986,6 @@ where
K: BasicOps<B>,
<K as BasicOps<B>>::Elem: Debug,
{
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
#[inline]
fn push_newline_indent(acc: &mut String, indent: usize) {
acc.push('\n');
@ -998,7 +994,6 @@ where
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn fmt_inner_tensor(
&self,
acc: &mut String,
@ -1015,18 +1010,18 @@ where
let range: [core::ops::Range<usize>; D] =
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
let elem = &self
.clone()
.slice(range)
.into_data()
.iter::<<K as BasicOps<B>>::Elem>()
.next()
.unwrap();
acc.push_str(&format!("{elem:?}"));
let data =
burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
if let Some(data) = data {
let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
acc.push_str(&format!("{elem:?}"));
} else {
acc.push_str("<Tensor data not available>");
}
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn fmt_outer_tensor(
&self,
acc: &mut String,
@ -1061,7 +1056,6 @@ where
/// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
/// * `depth` - The current depth of the tensor dimensions being processed.
/// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn display_recursive(
&self,
acc: &mut String,
@ -1172,7 +1166,6 @@ where
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
writeln!(f, "Tensor {{")?;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
{
let po = PRINT_OPTS.lock().unwrap();
let mut acc = String::new();
@ -1435,7 +1428,7 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
device: &B::Device,
) -> Self::Primitive<D>;
/// Extracts the data from the tensor.
/// Extracts the data from the tensor asynchronously.
///
/// # Arguments
///
@ -1453,7 +1446,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
///
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
/// which is more high-level and designed for public use.
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData>;
fn into_data_async<const D: usize>(
tensor: Self::Primitive<D>,
) -> impl Future<Output = TensorData> + Send;
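Declaring the method as `impl Future<Output = …> + Send` (return-position impl Trait in traits, stable since Rust 1.75) lets each backend impl below be written as a plain `async fn`. A minimal standalone sketch of the pattern, with illustrative names:

use core::future::Future;

trait Fetch {
    // The trait spells out the `Send` bound on the returned future...
    fn fetch(&self) -> impl Future<Output = u32> + Send;
}

struct Constant(u32);

impl Fetch for Constant {
    // ...while the impl can use ordinary async syntax.
    async fn fetch(&self) -> u32 {
        self.0
    }
}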
/// Creates a tensor from the given data.
///
@ -1724,8 +1719,8 @@ impl<B: Backend> BasicOps<B> for Float {
B::float_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
B::float_into_data(tensor)
async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
B::float_into_data(tensor).await
}
fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
@ -1846,8 +1841,8 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
B::int_into_data(tensor)
async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
B::int_into_data(tensor).await
}
fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
@ -1968,8 +1963,8 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
B::bool_into_data(tensor)
async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
B::bool_into_data(tensor).await
}
fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {
@ -2230,7 +2225,6 @@ impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
}
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
B: Backend,

View File

@ -1,8 +1,7 @@
use crate::{backend::Backend, Bool, Int, Shape, Tensor, TensorData};
use alloc::vec::Vec;
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::argwhere;
use crate::try_read_sync;
/// The part of the tensor to keep when creating a triangular mask.
enum TriPart {
@ -46,27 +45,21 @@ where
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
B::bool_nonzero(self.primitive)
.into_iter()
.map(Tensor::new)
.collect()
try_read_sync(self.nonzero_async())
.expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
}
/// Compute the indices of the elements that are true.
/// Compute the indices of the elements that are non-zero.
///
/// # Returns
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
let indices = self.argwhere().await.primitive;
let dims = B::int_shape(&indices).dims;
B::int_chunk(indices, dims[1], 1)
pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
B::bool_nonzero(self.primitive)
.await
.into_iter()
.map(|t| B::int_reshape(t, Shape::new([dims[0]])))
.map(Tensor::new)
.collect()
}
@ -77,9 +70,9 @@ where
///
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
/// result contains the indices of a non-zero element.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argwhere(self) -> Tensor<B, 2, Int> {
Tensor::new(B::bool_argwhere(self.primitive))
try_read_sync(self.argwhere_async())
.expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
}
/// Compute the indices of the elements that are true, grouped by element.
@ -88,9 +81,8 @@ where
///
/// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
/// result contains the indices of a non-zero element.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argwhere(self) -> Tensor<B, 2, Int> {
Tensor::new(argwhere::<B, D>(self.primitive).await)
pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
Tensor::new(B::bool_argwhere(self.primitive).await)
}
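Usage mirrors the data accessors; a hedged sketch with a boolean `mask` tensor:

let coords = mask.clone().argwhere(); // sync: blocks on native, panics on wasm
let per_dim = mask.clone().nonzero(); // one index tensor per dimension
// wasm-friendly variants:
// let coords = mask.clone().argwhere_async().await;
// let per_dim = mask.nonzero_async().await;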
/// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to

View File

@ -10,9 +10,6 @@ use crate::tensor::{Distribution, Shape, TensorData};
use crate::Int;
use crate::Tensor;
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::{argsort, sort, sort_with_indices, Float};
impl<const D: usize, B> Tensor<B, D>
where
B: Backend,
@ -265,88 +262,4 @@ where
.matmul(centered)
.div_scalar(n as f32 - correction_factor as f32)
}
/// Sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort(self, dim: usize) -> Tensor<B, D> {
Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
}
/// Sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending(self, dim: usize) -> Tensor<B, D> {
Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await)
}
/// Sort the elements by value in ascending order along a given dimension.
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ false).await;
(Tensor::new(values), Tensor::new(indices))
}
/// Sort the elements by value in descending order along a given dimension.
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending_with_indices(
self,
dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ true).await;
(Tensor::new(values), Tensor::new(indices))
}
/// Returns the indices that sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
}
/// Returns the indices that sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await)
}
/// Returns the `k` largest elements of the given input tensor along a given dimension.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D> {
let k_indices = Tensor::arange(0..k as i64, &self.device());
self.sort_descending(dim).await.select(dim, k_indices)
}
/// Returns the `k` largest elements of the given input tensor along a given dimension.
/// Also returns the indices.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk_with_indices(
self,
k: usize,
dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
let k_indices = Tensor::arange(0..k as i64, &self.device());
let (values, indices) = self.sort_descending_with_indices(dim).await;
(
values.select(dim, k_indices.clone()),
indices.select(dim, k_indices),
)
}
}
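With backends now expected to sort natively, the wasm-only async copies above are deleted outright; the synchronous methods that remain further down are the single public API. A brief usage sketch on a float tensor `t`:

let ascending = t.clone().sort(1);
let (values, indices) = t.clone().sort_descending_with_indices(0);
let top3 = t.topk(3, /*dim*/ 1);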

View File

@ -2,9 +2,6 @@ use crate::{backend::Backend, Float, Int, Shape, Tensor, TensorData};
use core::ops::Range;
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::{argsort, check, check::TensorCheck, sort, sort_with_indices};
impl<B> Tensor<B, 1, Int>
where
B: Backend,
@ -100,88 +97,4 @@ where
) -> Tensor<B, D2, Int> {
Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device))
}
/// Sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort(self, dim: usize) -> Tensor<B, D, Int> {
Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ false).await)
}
/// Sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending(self, dim: usize) -> Tensor<B, D, Int> {
Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ true).await)
}
/// Sort the elements by value in ascending order along a given dimension.
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ false).await;
(Tensor::new(values), Tensor::new(indices))
}
/// Sort the elements by value in descending order along a given dimension.
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending_with_indices(
self,
dim: usize,
) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ true).await;
(Tensor::new(values), Tensor::new(indices))
}
/// Returns the indices that sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(argsort::<B, D, Int>(self.primitive, dim, /*descending*/ false).await)
}
/// Returns the indices that sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(argsort::<B, D, Int>(self.primitive, dim, /*descending*/ true).await)
}
/// Returns the `k` largest elements of the given input tensor along a given dimension.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D, Int> {
let k_indices = Tensor::arange(0..k as i64, &self.device());
self.sort_descending(dim).await.select(dim, k_indices)
}
/// Returns the `k` largest elements of the given input tensor along a given dimension.
/// Also returns the indices.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk_with_indices(
self,
k: usize,
dim: usize,
) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
let k_indices = Tensor::arange(0..k as i64, &self.device());
let (values, indices) = self.sort_descending_with_indices(dim).await;
(
values.select(dim, k_indices.clone()),
indices.select(dim, k_indices),
)
}
}
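The integer variants go the same way. A small worked example against the sync API that remains (`from_ints` and a `device` are assumed in scope):

let t = Tensor::<B, 2, Int>::from_ints([[3, 1, 2], [9, 7, 8]], &device);
let (values, indices) = t.topk_with_indices(2, /*dim*/ 1);
// values:  [[3, 2], [9, 8]]
// indices: [[0, 2], [0, 2]]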

View File

@ -13,7 +13,7 @@ mod narrow;
mod numeric;
mod sort;
pub use argwhere::argwhere;
pub use argwhere::argwhere_data;
pub use autodiff::*;
pub use base::*;
pub use cartesian_grid::cartesian_grid;

View File

@ -642,9 +642,6 @@ where
/// A boolean scalar.
///
/// # Remarks
///
/// This method is only available for non-wasm targets or when the `wasm-sync` feature is enabled.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
self.is_close(other, rtol, atol).all().into_scalar()
}
@ -671,7 +668,6 @@ where
/// Sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
check!(TensorCheck::sort_dim::<D>("Sort", dim));
Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
@ -680,7 +676,6 @@ where
/// Sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
check!(TensorCheck::sort_dim::<D>("Sort", dim));
Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
@ -690,7 +685,6 @@ where
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
@ -702,7 +696,6 @@ where
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
@ -712,7 +705,6 @@ where
/// Returns the indices that sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
@ -721,14 +713,12 @@ where
/// Returns the indices that sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
}
/// Returns the `k` largest elements of the given input tensor along a given dimension.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
let k_indices = Tensor::arange(0..k as i64, &self.device());
self.sort_descending(dim).select(dim, k_indices)
@ -736,7 +726,6 @@ where
/// Returns the `k` largest elements of the given input tensor along a given dimension.
/// Also returns the indices.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
let k_indices = Tensor::arange(0..k as i64, &self.device());
let (values, indices) = self.sort_descending_with_indices(dim);
@ -2029,7 +2018,6 @@ where
///
/// Users should prefer the [Tensor::sort](Tensor::sort) function,
/// which is more high-level and designed for public use.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2059,7 +2047,6 @@ where
/// For sorting the elements of a tensor, users should prefer the
/// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
/// and designed for public use.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2087,7 +2074,6 @@ where
///
/// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
/// which is more high-level and designed for public use.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn argsort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2411,7 +2397,6 @@ impl<B: Backend> Numeric<B> for Int {
B::int_sign(tensor)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2420,7 +2405,6 @@ impl<B: Backend> Numeric<B> for Int {
B::int_sort(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2429,7 +2413,6 @@ impl<B: Backend> Numeric<B> for Int {
B::int_sort_with_indices(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn argsort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2758,7 +2741,6 @@ impl<B: Backend> Numeric<B> for Float {
B::float_sign(tensor)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2767,7 +2749,6 @@ impl<B: Backend> Numeric<B> for Float {
B::float_sort(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn sort_with_indices<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
@ -2776,7 +2757,6 @@ impl<B: Backend> Numeric<B> for Float {
B::float_sort_with_indices(tensor, dim, descending)
}
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn argsort<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,

View File

@ -6,6 +6,7 @@ use crate::{
BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind,
};
use alloc::vec::Vec;
use burn_common::reader::try_read_sync;
/// Sort the elements of the input `tensor` by value along a given dimension.
///
@ -27,7 +28,6 @@ use alloc::vec::Vec;
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
@ -37,43 +37,7 @@ where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read();
sort_data::<B, D, K>(data, dim, &device, descending)
}
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have the corresponding implementation.
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
descending: bool,
) -> K::Primitive<D>
where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read().await;
let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
sort_data::<B, D, K>(data, dim, &device, descending)
}
@ -119,7 +83,6 @@ where
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn sort_with_indices<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
@ -129,44 +92,7 @@ where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read();
sort_data_with_indices::<B, D, K>(data, dim, &device, descending)
}
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have the corresponding implementation.
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_with_indices<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
descending: bool,
) -> (K::Primitive<D>, IntTensor<B, D>)
where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read().await;
let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
sort_data_with_indices::<B, D, K>(data, dim, &device, descending)
}
@ -253,7 +179,6 @@ where
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argsort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
@ -263,42 +188,7 @@ where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read();
argsort_data::<B, D, K>(data, dim, &device, descending)
}
/// Returns the indices that sort the elements of the input `tensor` along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have the corresponding implementation.
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensor: K::Primitive<D>,
dim: usize,
descending: bool,
) -> IntTensor<B, D>
where
<K as BasicOps<B>>::Elem: Element,
{
let device = K::device(&tensor);
let data = K::into_data(tensor).read().await;
let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
argsort_data::<B, D, K>(data, dim, &device, descending)
}
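All three fallbacks funnel into the `*_data` helpers, which sort on the CPU after the read. For intuition, a self-contained sketch of the core of an argsort over one contiguous lane (the real helpers handle arbitrary dims and layouts):

/// Returns the permutation that sorts `values`; unstable, like the docs say.
fn argsort_lane<T: PartialOrd + Copy>(values: &[T], descending: bool) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    idx.sort_unstable_by(|&a, &b| {
        let ord = values[a].partial_cmp(&values[b]).expect("comparable elements");
        if descending { ord.reverse() } else { ord }
    });
    idx
}

// argsort_lane(&[0.3, 0.1, 0.2], false) == vec![1, 2, 0]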

View File

@ -3,14 +3,11 @@ use super::{
IntTensor,
};
use crate::{
backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, TensorData,
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
TensorData,
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use crate::argwhere;
use core::{future::Future, ops::Range};
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
@ -47,21 +44,9 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<TensorData>;
/// Gets the data from the tensor.
///
///
/// # Arguments
///
/// * `data` - The data structure.
///
/// # Returns
///
/// The data cloned from the data structure.
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<TensorData> {
Self::bool_into_data(tensor.clone())
}
fn bool_into_data<const D: usize>(
tensor: BoolTensor<B, D>,
) -> impl Future<Output = TensorData> + Send;
/// Creates a tensor from the data structure.
///
@ -420,9 +405,16 @@ pub trait BoolTensorOps<B: Backend> {
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
argwhere::<B, D>(tensor)
fn bool_argwhere<const D: usize>(
tensor: BoolTensor<B, D>,
) -> impl Future<Output = IntTensor<B, 2>> + Send {
async {
// Size of each output tensor is variable (= number of nonzero elements in the tensor).
// Reading the data to count the number of truth values might cause sync but is required.
let device = B::bool_device(&tensor);
let data = B::bool_into_data(tensor).await;
argwhere_data::<B, D>(data, &device)
}
}
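The default impl must await the full read because the output size equals the number of `true` elements. For intuition, a CPU-only sketch of the coordinate computation `argwhere_data` performs (row-major layout assumed; details illustrative):

fn argwhere_cpu(data: &[bool], shape: &[usize]) -> Vec<Vec<usize>> {
    let mut out = Vec::new();
    for (flat, &hit) in data.iter().enumerate() {
        if hit {
            // Decompose the flat offset into per-dimension coordinates.
            let mut coords = vec![0; shape.len()];
            let mut rem = flat;
            for d in (0..shape.len()).rev() {
                coords[d] = rem % shape[d];
                rem /= shape[d];
            }
            out.push(coords);
        }
    }
    out
}

// argwhere_cpu(&[false, true, true, false], &[2, 2]) == [[0, 1], [1, 0]]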
/// Compute the indices of the elements that are non-zero.
@ -435,14 +427,17 @@ pub trait BoolTensorOps<B: Backend> {
///
/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
/// the non-zero elements in that dimension.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
let indices = B::bool_argwhere(tensor);
let dims = B::int_shape(&indices).dims;
B::int_chunk(indices, dims[1], 1)
.into_iter()
.map(|t| B::int_reshape(t, Shape::new([dims[0]])))
.collect()
fn bool_nonzero<const D: usize>(
tensor: BoolTensor<B, D>,
) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
async {
let indices = B::bool_argwhere(tensor).await;
let dims = B::int_shape(&indices).dims;
B::int_chunk(indices, dims[1], 1)
.into_iter()
.map(|t| B::int_reshape(t, Shape::new([dims[0]])))
.collect()
}
}
/// Broadcasts the bool `tensor` to the given `shape`.

View File

@ -6,10 +6,9 @@ use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, In
use crate::{cartesian_grid, Tensor};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::future::Future;
use core::ops::Range;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use crate::{argsort, sort, sort_with_indices};
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
@ -47,20 +46,9 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<TensorData>;
/// Gets the data from the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data cloned from the data structure.
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<TensorData> {
Self::int_into_data(tensor.clone())
}
fn int_into_data<const D: usize>(
tensor: IntTensor<B, D>,
) -> impl Future<Output = TensorData> + Send;
/// Creates a tensor from the data structure.
///
@ -1241,7 +1229,6 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_sort<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
@ -1263,7 +1250,6 @@ pub trait IntTensorOps<B: Backend> {
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_sort_with_indices<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
@ -1286,7 +1272,6 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn int_argsort<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,

View File

@ -7,10 +7,9 @@ use crate::Tensor;
use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::future::Future;
use core::ops::Range;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use crate::{argsort, sort, sort_with_indices};
/// Operations on float tensors.
@ -111,20 +110,9 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn float_to_data<const D: usize>(tensor: &FloatTensor<B, D>) -> Reader<TensorData> {
Self::float_into_data(tensor.clone())
}
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure with the tensor's data.
fn float_into_data<const D: usize>(tensor: FloatTensor<B, D>) -> Reader<TensorData>;
fn float_into_data<const D: usize>(
tensor: FloatTensor<B, D>,
) -> impl Future<Output = TensorData> + Send;
/// Gets the device of the tensor.
///
@ -1389,7 +1377,6 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_sort<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
@ -1412,7 +1399,6 @@ pub trait FloatTensorOps<B: Backend> {
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_sort_with_indices<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
@ -1434,7 +1420,6 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
fn float_argsort<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,

View File

@ -31,7 +31,7 @@ wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
pollster = { workspace = true }
log = { workspace = true }
futures-intrusive = { workspace = true }
async-channel = { workspace = true }
derive-new = { workspace = true }
hashbrown = { workspace = true }

View File

@ -113,7 +113,7 @@ where
)
}
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
fn create_read_buffer(&mut self, handle: server::Binding<Self>) -> wgpu::Buffer {
let resource = self.memory_management.get(handle.memory);
let size = resource.size();
@ -134,50 +134,7 @@ where
self.tasks_count += 1;
self.sync(SyncType::Flush);
BufferReader::new(buffer_dest)
}
}
#[derive(new)]
struct BufferReader {
buffer: wgpu::Buffer,
}
impl BufferReader {
#[cfg(target_family = "wasm")]
async fn read(self, device: alloc::sync::Arc<wgpu::Device>) -> Vec<u8> {
self.read_async(&device).await
}
#[cfg(not(target_family = "wasm"))]
fn read(self, device: &wgpu::Device) -> Vec<u8> {
pollster::block_on(self.read_async(device))
}
async fn read_async(&self, device: &wgpu::Device) -> Vec<u8> {
let buffer_slice = self.buffer.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender
.send(v)
.expect("Unable to send buffer slice result to async channel.")
});
device.poll(wgpu::Maintain::Wait);
let result = receiver.receive().await;
if let Some(Ok(())) = result {
let data = buffer_slice.get_mapped_range();
let result = bytemuck::cast_slice(&data).to_vec();
drop(data);
self.buffer.unmap();
result
} else {
panic!("Unable to read buffer {:?}", result)
}
buffer_dest
}
}
@ -190,15 +147,38 @@ where
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
fn read(&mut self, binding: server::Binding<Self>) -> Reader<Vec<u8>> {
#[cfg(target_family = "wasm")]
{
let future = self.buffer_reader(binding).read(self.device.clone());
return Reader::Future(Box::pin(future));
}
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
let device = self.device.clone();
let buffer = self.create_read_buffer(binding);
#[cfg(not(target_family = "wasm"))]
Reader::Concrete(self.buffer_reader(binding).read(&self.device))
Box::pin(async move {
let buffer_slice = buffer.slice(..);
let (sender, receiver) = async_channel::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender
.try_send(v)
.expect("Unable to send buffer slice result to async channel.")
});
device.poll(wgpu::Maintain::Wait);
let result = receiver
.recv()
.await
.expect("Unable to receive buffer slice result.");
if let Ok(()) = result {
let data = buffer_slice.get_mapped_range();
let result = bytemuck::cast_slice(&data).to_vec();
drop(data);
buffer.unmap();
result
} else {
panic!("Unable to read buffer {:?}", result)
}
})
}
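`async-channel` replaces `futures-intrusive` for the map-callback handshake. A standalone sketch of the same pattern, assuming an initialized `device` and a staging `buffer` created with `MAP_READ`:

async fn read_mapped(device: &wgpu::Device, buffer: &wgpu::Buffer) -> Vec<u8> {
    let slice = buffer.slice(..);
    // Capacity 1 suffices: map_async fires its callback exactly once.
    let (sender, receiver) = async_channel::bounded(1);
    slice.map_async(wgpu::MapMode::Read, move |res| {
        sender.try_send(res).expect("receiver alive"); // never blocks
    });
    device.poll(wgpu::Maintain::Wait);
    receiver
        .recv()
        .await
        .expect("sender not dropped")
        .expect("buffer mapping failed");
    let view = slice.get_mapped_range();
    let bytes = view.to_vec();
    drop(view); // release the mapped range before unmapping
    buffer.unmap();
    bytes
}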
fn get_resource(

View File

@ -25,9 +25,6 @@ tui = ["burn-train?/tui"]
## Includes system info metrics (CPU/GPU usage, etc)
metrics = ["burn-train?/metrics"]
# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]
# Datasets
dataset = ["burn-core/dataset"]

View File

@ -73,7 +73,6 @@
//! - `blas-netlib`: If supported, Blas Netlib will be used
//! - `openblas`: If supported, Openblas will be used
//! - `openblas-system`: If supported, Openblas installed on the system will be used
//! - `wasm-sync`: When targeting wasm, but want a sync API (won't work with WGPU)
//! - `autotune`: Enable running benchmarks to select the best kernel in backends that support it.
//! - `fusion`: Enable operation fusion in backends that support it.
//! - Backend decorators

View File

@ -28,7 +28,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
ArrayHandle::new(&output_handle, input.len()),
);
let output = client.read(output_handle.binding()).read_sync().unwrap();
let output = client.read(output_handle.binding());
let output = f32::from_bytes(&output);
// Should be [-0.1587, 0.0000, 0.8413, 5.0000]

View File

@ -150,19 +150,13 @@ impl<B: Backend> Model<B> {
// Convert the model output into probability distribution using softmax formula
let probabilities = softmax(output, 1);
#[cfg(not(target_family = "wasm"))]
let result = probabilities.into_data().convert::<f32>().to_vec().unwrap();
// Forces the result to be computed
#[cfg(target_family = "wasm")]
let result = probabilities
.into_data()
probabilities
.into_data_async()
.await
.convert::<f32>()
.to_vec()
.unwrap();
result
.unwrap()
}
}

View File

@ -24,6 +24,3 @@ console_error_panic_hook = { workspace = true }
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4"
js-sys = "0.3"
[dev-dependencies]
pollster = { workspace = true }

View File

@ -68,11 +68,7 @@ impl Mnist {
let output = burn::tensor::activation::softmax(output, 1);
// Flatten output tensor with [1, 10] shape into boxed slice of [f32]
#[cfg(not(target_family = "wasm"))]
let output = output.into_data();
#[cfg(target_family = "wasm")]
let output = output.into_data().await;
let output = output.into_data_async().await;
let array = Array::new();
for value in output.iter::<f32>() {