mirror of https://github.com/tracel-ai/burn.git
Consistent sync/async handling, allow more functions to be async for wasm. (#1936)
This commit is contained in:
parent 6f2ba34382
commit 849c8f453b
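At a high level, this change replaces the old `Reader<T>` enum (synchronous on native targets, asynchronous behind `wasm-sync` cfg gates) with plain `async fn`s plus a boxed-future `Reader` type; native callers block on those futures via `pollster`. A minimal sketch of the resulting calling convention, using a hypothetical `load` function rather than code from this diff:

async fn load() -> Vec<u8> {
    // Stand-in for the new async read APIs introduced below.
    vec![1, 2, 3]
}

fn main() {
    // On native targets pollster drives the future to completion synchronously;
    // in a wasm context the same future is awaited instead.
    let bytes = pollster::block_on(load());
    assert_eq!(bytes, vec![1, 2, 3]);
}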
@@ -190,14 +190,15 @@ dependencies = [
 ]
 
 [[package]]
-name = "async-trait"
-version = "0.1.80"
+name = "async-channel"
+version = "2.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
+checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
 dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.68",
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
 ]
 
 [[package]]
@@ -461,12 +462,12 @@ dependencies = [
 name = "burn-common"
 version = "0.14.0"
 dependencies = [
- "async-trait",
 "dashmap",
 "data-encoding",
 "derive-new",
 "getrandom",
 "indicatif",
+ "pollster",
 "rand",
 "reqwest 0.12.5",
 "serde",
@@ -479,12 +480,14 @@ dependencies = [
 name = "burn-compute"
 version = "0.14.0"
 dependencies = [
+ "async-channel",
 "burn-common",
 "derive-new",
 "dirs 5.0.1",
 "hashbrown 0.14.5",
 "log",
 "md5",
+ "pollster",
 "rand",
 "serde",
 "serde_json",
@@ -761,6 +764,7 @@ dependencies = [
 name = "burn-wgpu"
 version = "0.14.0"
 dependencies = [
+ "async-channel",
 "burn-common",
 "burn-compute",
 "burn-cube",
@@ -769,7 +773,6 @@ dependencies = [
 "burn-tensor",
 "bytemuck",
 "derive-new",
- "futures-intrusive",
 "hashbrown 0.14.5",
 "log",
 "pollster",
@@ -1143,6 +1146,15 @@ dependencies = [
 "static_assertions",
 ]
 
+[[package]]
+name = "concurrent-queue"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "console"
 version = "0.15.8"
@@ -1733,6 +1745,27 @@ version = "0.1.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
 
+[[package]]
+name = "event-listener"
+version = "5.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "event-listener-strategy"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
+dependencies = [
+ "event-listener",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "exr"
 version = "1.72.0"
@@ -1931,17 +1964,6 @@ dependencies = [
 "futures-util",
 ]
 
-[[package]]
-name = "futures-intrusive"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
-dependencies = [
- "futures-core",
- "lock_api",
- "parking_lot 0.12.2",
-]
-
 [[package]]
 name = "futures-io"
 version = "0.3.30"
@@ -3274,7 +3296,6 @@ dependencies = [
 "burn",
 "console_error_panic_hook",
 "js-sys",
- "pollster",
 "serde",
 "wasm-bindgen",
 "wasm-bindgen-futures",
@@ -3778,6 +3799,12 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
 
+[[package]]
+name = "parking"
+version = "2.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
+
 [[package]]
 name = "parking_lot"
 version = "0.11.2"

@@ -26,7 +26,6 @@ readme = "README.md"
 license = "MIT OR Apache-2.0"
 
 [workspace.dependencies]
-async-trait = "0.1.80"
 bytemuck = "1.16.1"
 candle-core = { version = "0.5.1" }
 clap = { version = "4.5.8", features = ["derive"] }
@@ -83,14 +82,16 @@ tracing-subscriber = "0.3.18"
 web-time = "1.1.0"
 zip = "2.1.3"
 
+# Async handling
+pollster = "0.3"
+async-channel = "2.3"
+
 # Terminal UI
 ratatui = "0.26.3"
 crossterm = "0.27.0"
 
 # WGPU stuff
-futures-intrusive = "0.5.0"
 text_placeholder = "0.5.0"
-pollster = "0.3.0"
 wgpu = "0.20.1"
 
 # Benchmarks and Burnbench

@@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
 use burn_tensor::{
     backend::Backend,
     ops::{BoolTensor, BoolTensorOps, IntTensor},
-    Device, Reader, Shape, TensorData,
+    Device, Shape, TensorData,
 };
 
 impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
@@ -15,12 +15,8 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
         B::bool_shape(tensor)
     }
 
-    fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<TensorData> {
-        B::bool_to_data(tensor)
-    }
-
-    fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<TensorData> {
-        B::bool_into_data(tensor)
+    async fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> TensorData {
+        B::bool_into_data(tensor).await
     }
 
     fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
@@ -121,14 +117,12 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
         B::bool_flip(tensor, axes)
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
-        B::bool_argwhere(tensor)
+    async fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
+        B::bool_argwhere(tensor).await
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
-        B::bool_nonzero(tensor)
+    async fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
+        B::bool_nonzero(tensor).await
     }
 
     fn bool_expand<const D: usize, const D2: usize>(

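With the cfg gates gone, `bool_into_data`, `bool_argwhere`, and `bool_nonzero` are `async fn`s on every target. A sketch of calling one generically — the `bool_data_of` helper is hypothetical, not part of this diff:

use burn_tensor::backend::Backend;
use burn_tensor::ops::{BoolTensor, BoolTensorOps};
use burn_tensor::TensorData;

// Compiles for native and wasm alike; a native caller that needs to block can
// wrap the returned future in burn_common::reader::read_sync (added later in
// this diff).
async fn bool_data_of<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> TensorData {
    B::bool_into_data(tensor).await
}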
@@ -3,7 +3,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Au
 use burn_tensor::{
     backend::Backend,
     ops::{BoolTensor, IntTensor, IntTensorOps},
-    Device, Distribution, Reader, Shape, TensorData,
+    Device, Distribution, Shape, TensorData,
 };
 
 impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
@@ -15,12 +15,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_shape(tensor)
     }
 
-    fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<TensorData> {
-        B::int_to_data(tensor)
-    }
-
-    fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<TensorData> {
-        B::int_into_data(tensor)
+    async fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> TensorData {
+        B::int_into_data(tensor).await
     }
 
     fn int_to_device<const D: usize>(
@@ -380,7 +376,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_expand(tensor, shape)
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_sort<const D: usize>(
         tensor: IntTensor<Self, D>,
         dim: usize,
@@ -389,7 +384,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_sort(tensor, dim, descending)
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_sort_with_indices<const D: usize>(
         tensor: IntTensor<Self, D>,
         dim: usize,
@@ -398,7 +392,6 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_sort_with_indices(tensor, dim, descending)
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_argsort<const D: usize>(
         tensor: IntTensor<Self, D>,
         dim: usize,

@@ -17,7 +17,7 @@ use crate::{
 use burn_tensor::{
     backend::Backend,
     ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
-    Device, ElementConversion, Reader, Shape, Tensor, TensorData,
+    Device, ElementConversion, Shape, Tensor, TensorData,
 };
 
 use super::maxmin::MaxMinDim;
@@ -50,12 +50,8 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         B::float_shape(&tensor.primitive)
     }
 
-    fn float_to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Reader<TensorData> {
-        B::float_to_data(&tensor.primitive)
-    }
-
-    fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
-        B::float_into_data(tensor.primitive)
+    async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
+        B::float_into_data(tensor.primitive).await
     }
 
     fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
@@ -2364,7 +2360,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_sort<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
@@ -2387,7 +2382,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_sort_with_indices<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
@@ -2416,7 +2410,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_argsort<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,

@@ -1,6 +1,6 @@
 use std::marker::PhantomData;
 
-use burn_tensor::{backend::Backend, Reader, Shape, TensorData};
+use burn_tensor::{backend::Backend, Shape, TensorData};
 
 use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},

@@ -1,6 +1,6 @@
 use burn_tensor::{
     ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor},
-    Device, Reader, Shape, TensorData,
+    Device, Shape, TensorData,
 };
 
 use crate::{
@@ -19,12 +19,10 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
         super::base::shape(tensor)
     }
 
-    fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<TensorData> {
+    async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
         let x: Vec<u8> = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap();
         let y = x.iter().map(|b| !matches!(b, 0)).collect();
-        let data = TensorData::new(y, tensor.shape());
-
-        Reader::Concrete(data)
+        TensorData::new(y, tensor.shape())
     }
 
     fn bool_from_data<const D: usize>(

@@ -1,6 +1,6 @@
 use burn_tensor::{
     ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
-    Bool, Device, Distribution, ElementConversion, Reader, Shape, TensorData,
+    Bool, Device, Distribution, ElementConversion, Shape, TensorData,
 };
 
 use crate::{
@@ -19,8 +19,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
         super::base::shape(tensor)
     }
 
-    fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
-        Reader::Concrete(super::base::into_data(tensor))
+    async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
+        super::base::into_data(tensor)
     }
 
     fn int_from_data<const D: usize>(

@@ -2,7 +2,7 @@ use std::borrow::Borrow;
 
 use burn_tensor::{
     ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
-    Device, Distribution, ElementConversion, Reader, Shape, TensorData,
+    Device, Distribution, ElementConversion, Shape, TensorData,
 };
 use candle_core::{backend::BackendStorage, shape, Tensor};
 
@@ -59,8 +59,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
         super::base::shape(tensor)
     }
 
-    fn float_into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<TensorData> {
-        Reader::Concrete(super::base::into_data(tensor))
+    async fn float_into_data<const D: usize>(tensor: CandleTensor<F, D>) -> TensorData {
+        super::base::into_data(tensor)
     }
 
     fn float_device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {

@@ -12,25 +12,23 @@ version.workspace = true
 
 [features]
 default = ["std"]
-std = ["rand/std", "data-encoding/std"]
+std = ["rand/std", "data-encoding/std", "dep:pollster"]
 doc = ["default"]
-wasm-sync = []
 network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
 
 [target.'cfg(target_family = "wasm")'.dependencies]
-async-trait = { workspace = true }
 getrandom = { workspace = true, features = ["js"] }
 web-time = { version = "1.1.0" }
 
-
 [dependencies]
 # ** Please make sure all dependencies support no_std when std is disabled **
 
 rand = { workspace = true }
 spin = { workspace = true } # using in place of use std::sync::Mutex;
 derive-new = { workspace = true }
 serde = { workspace = true }
 data-encoding = { workspace = true }
+pollster = { workspace = true, optional = true }
 
 # Network downloader
 indicatif = { workspace = true, optional = true }

@@ -1,115 +1,54 @@
-use alloc::boxed::Box;
-use core::marker::PhantomData;
-
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-#[async_trait::async_trait]
-/// Allows to create async reader.
-pub trait AsyncReader<T>: Send {
-    /// Read asynchronously.
-    async fn read(self: Box<Self>) -> T;
-}
-
-/// Define how data is read, sync or async.
-pub enum Reader<T> {
-    /// Concrete variant.
-    Concrete(T),
-    /// Sync data variant.
-    Sync(Box<dyn SyncReader<T>>),
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    /// Async data variant.
-    Async(Box<dyn AsyncReader<T>>),
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    /// Future data variant.
-    Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
-}
-
-/// Allows to create sync reader.
-pub trait SyncReader<T>: Send {
-    /// Read synchronously.
-    fn read(self: Box<Self>) -> T;
-}
-
-#[derive(new)]
-struct MappedReader<I, O, F> {
-    reader: Reader<I>,
-    mapper: F,
-    _output: PhantomData<O>,
-}
-
-impl<I, O, F> SyncReader<O> for MappedReader<I, O, F>
-where
-    I: Send,
-    O: Send,
-    F: Send + FnOnce(I) -> O,
-{
-    fn read(self: Box<Self>) -> O {
-        let input = self
-            .reader
-            .read_sync()
-            .expect("Only sync data supported in a sync reader.");
-
-        (self.mapper)(input)
-    }
-}
-
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-#[async_trait::async_trait]
-impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
-where
-    I: Send,
-    O: Send,
-    F: Send + FnOnce(I) -> O,
-{
-    async fn read(self: Box<Self>) -> O {
-        let input = self.reader.read().await;
-        (self.mapper)(input)
-    }
-}
-
-impl<T> Reader<T> {
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    /// Read the data.
-    pub async fn read(self) -> T {
-        match self {
-            Self::Concrete(data) => data,
-            Self::Sync(reader) => reader.read(),
-            Self::Async(func) => func.read().await,
-            Self::Future(future) => future.await,
-        }
-    }
-
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    /// Read the data.
-    pub fn read(self) -> T {
-        match self {
-            Self::Concrete(data) => data,
-            Self::Sync(reader) => reader.read(),
-        }
-    }
-
-    /// Read the data only if sync, returns None if an async reader.
-    pub fn read_sync(self) -> Option<T> {
-        match self {
-            Self::Concrete(data) => Some(data),
-            Self::Sync(reader) => Some(reader.read()),
-            #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-            Self::Async(_func) => return None,
-            #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-            Self::Future(_future) => return None,
-        }
-    }
-
-    /// Map the current reader to another type.
-    pub fn map<O, F>(self, mapper: F) -> Reader<O>
-    where
-        T: 'static + Send,
-        O: 'static + Send,
-        F: FnOnce(T) -> O + 'static + Send,
-    {
-        #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-        return Reader::Async(Box::new(MappedReader::new(self, mapper)));
-
-        #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-        Reader::Sync(Box::new(MappedReader::new(self, mapper)))
-    }
-}
+use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec};
+use core::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll, Waker},
+};
+
+/// A future that is used to read resources from a compute server.
+pub type Reader = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
+
+/// Create a reader from a concrete value.
+pub fn reader_from_concrete(val: Vec<u8>) -> Reader {
+    Box::pin(async move { val })
+}
+
+struct DummyWaker;
+
+impl Wake for DummyWaker {
+    fn wake(self: Arc<Self>) {}
+    fn wake_by_ref(self: &Arc<Self>) {}
+}
+
+/// Read a future synchronously.
+///
+/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
+/// If you want to handle this error, please use
+/// try_read_sync instead.
+pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T {
+    try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.")
+}
+
+/// Read a future synchronously.
+///
+/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
+/// otherwise this returns None.
+pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> {
+    // Create a dummy context.
+    let waker = Waker::from(Arc::new(DummyWaker));
+    let mut context = Context::from_waker(&waker);
+
+    // Pin & poll the future. Some backends don't do async readbacks, and instead immediately get
+    // the data. This let's us detect when a future is synchronous and doesn't require any waiting.
+    let mut pinned = core::pin::pin!(f);
+
+    match pinned.as_mut().poll(&mut context) {
+        Poll::Ready(output) => Some(output),
+        // On platforms that support it, now just block on the future and drive it to completion.
+        #[cfg(all(not(target_family = "wasm"), feature = "std"))]
+        Poll::Pending => Some(pollster::block_on(pinned)),
+        // Otherwise, just bail and return None - this futures will have to be read back asynchronously.
+        #[cfg(any(target_family = "wasm", not(feature = "std")))]
+        Poll::Pending => None,
+    }
+}

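For reference, a self-contained sketch of how the new helpers behave, using only the APIs defined above: an already-ready future succeeds on every target, while a pending one is driven to completion with `pollster::block_on` on native `std` builds and yields `None` on wasm or `no_std`:

use burn_common::reader::try_read_sync;

fn main() {
    // An immediately-ready future resolves everywhere, including wasm.
    assert_eq!(try_read_sync(async { 42 }), Some(42));
}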
@@ -22,7 +22,7 @@ default = [
 std = ["burn-common/std"]
 channel-mutex = []
 channel-cell = []
-channel-mpsc = [] # Assume std
+channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
 storage-bytes = []
 autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std
 
@@ -36,6 +36,8 @@ dirs = { workspace = true, optional = true }
 serde = { workspace = true, optional = true }
 serde_json = { workspace = true, features = ["std"], optional = true }
 md5 = { workspace = true, optional = true }
+pollster = { workspace = true, optional = true }
+async-channel = { workspace = true, optional = true }
 
 [target.'cfg(target_family = "wasm")'.dependencies]
 web-time = { workspace = true }

@@ -9,7 +9,7 @@ use burn_common::{reader::Reader, sync_type::SyncType};
 /// while ensuring thread-safety
 pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
     /// Given a binding, returns owned resource as bytes
-    fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>>;
+    fn read(&self, binding: Binding<Server>) -> Reader;
 
     /// Given a resource handle, return the storage resource.
     fn get_resource(

@@ -42,9 +42,9 @@ where
 
 impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
 where
-    Server: ComputeServer,
+    Server: ComputeServer + Send,
 {
-    fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
+    fn read(&self, binding: Binding<Server>) -> Reader {
         self.server.borrow_mut().read(binding)
     }
 

@@ -1,9 +1,5 @@
-use std::{
-    sync::{mpsc, Arc},
-    thread,
-};
-
 use burn_common::{reader::Reader, sync_type::SyncType};
+use std::{sync::Arc, thread};
 
 use super::ComputeChannel;
 use crate::{
@@ -11,7 +7,7 @@ use crate::{
     storage::ComputeStorage,
 };
 
-/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
+/// Create a channel using a [multi-producer, single-consumer channel to communicate with
 /// the compute server spawn on its own thread.
 #[derive(Debug)]
 pub struct MpscComputeChannel<Server>
@@ -27,16 +23,16 @@ where
     Server: ComputeServer,
 {
     _handle: thread::JoinHandle<()>,
-    sender: mpsc::Sender<Message<Server>>,
+    sender: async_channel::Sender<Message<Server>>,
 }
 
-type Callback<Response> = mpsc::Sender<Response>;
+type Callback<Response> = async_channel::Sender<Response>;
 
 enum Message<Server>
 where
     Server: ComputeServer,
 {
-    Read(Binding<Server>, Callback<Reader<Vec<u8>>>),
+    Read(Binding<Server>, Callback<Vec<u8>>),
     GetResource(
         Binding<Server>,
         Callback<<Server::Storage as ComputeStorage>::Resource>,
@@ -53,36 +49,40 @@ where
 {
     /// Create a new mpsc compute channel.
     pub fn new(mut server: Server) -> Self {
-        let (sender, receiver) = mpsc::channel();
+        let (sender, receiver) = async_channel::unbounded();
 
         let _handle = thread::spawn(move || {
-            while let Ok(message) = receiver.recv() {
-                match message {
-                    Message::Read(binding, callback) => {
-                        let data = server.read(binding);
-                        callback.send(data).unwrap();
-                    }
-                    Message::GetResource(binding, callback) => {
-                        let data = server.get_resource(binding);
-                        callback.send(data).unwrap();
-                    }
-                    Message::Create(data, callback) => {
-                        let handle = server.create(&data);
-                        callback.send(handle).unwrap();
-                    }
-                    Message::Empty(size, callback) => {
-                        let handle = server.empty(size);
-                        callback.send(handle).unwrap();
-                    }
-                    Message::ExecuteKernel(kernel, bindings) => {
-                        server.execute(kernel, bindings);
-                    }
-                    Message::Sync(sync_type, callback) => {
-                        server.sync(sync_type);
-                        callback.send(()).unwrap();
-                    }
-                };
-            }
+            // Run the whole procedure as one blocking future. This is much simpler than trying
+            // to use some multithreaded executor.
+            pollster::block_on(async {
+                while let Ok(message) = receiver.recv().await {
+                    match message {
+                        Message::Read(binding, callback) => {
+                            let data = server.read(binding).await;
+                            callback.send(data).await.unwrap();
+                        }
+                        Message::GetResource(binding, callback) => {
+                            let data = server.get_resource(binding);
+                            callback.send(data).await.unwrap();
+                        }
+                        Message::Create(data, callback) => {
+                            let handle = server.create(&data);
+                            callback.send(handle).await.unwrap();
+                        }
+                        Message::Empty(size, callback) => {
+                            let handle = server.empty(size);
+                            callback.send(handle).await.unwrap();
+                        }
+                        Message::ExecuteKernel(kernel, bindings) => {
+                            server.execute(kernel, bindings);
+                        }
+                        Message::Sync(sync_type, callback) => {
+                            server.sync(sync_type);
+                            callback.send(()).await.unwrap();
+                        }
+                    };
+                }
+            });
         });
 
         let state = Arc::new(MpscComputeChannelState { sender, _handle });
@@ -103,75 +103,71 @@ impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
 where
     Server: ComputeServer + 'static,
 {
-    fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
-        let (callback, response) = mpsc::channel();
-
-        self.state
-            .sender
-            .send(Message::Read(binding, callback))
-            .unwrap();
-
-        self.response(response)
+    fn read(&self, binding: Binding<Server>) -> Reader {
+        let sender = self.state.sender.clone();
+
+        Box::pin(async move {
+            let (callback, response) = async_channel::unbounded();
+            sender.send(Message::Read(binding, callback)).await.unwrap();
+            handle_response(response.recv().await)
+        })
     }
 
     fn get_resource(
         &self,
         binding: Binding<Server>,
     ) -> <Server::Storage as ComputeStorage>::Resource {
-        let (callback, response) = mpsc::channel();
+        let (callback, response) = async_channel::unbounded();
 
         self.state
             .sender
-            .send(Message::GetResource(binding, callback))
+            .send_blocking(Message::GetResource(binding, callback))
             .unwrap();
 
-        self.response(response)
+        handle_response(response.recv_blocking())
     }
 
     fn create(&self, data: &[u8]) -> Handle<Server> {
-        let (callback, response) = mpsc::channel();
+        let (callback, response) = async_channel::unbounded();
 
         self.state
             .sender
-            .send(Message::Create(data.to_vec(), callback))
+            .send_blocking(Message::Create(data.to_vec(), callback))
             .unwrap();
 
-        self.response(response)
+        handle_response(response.recv_blocking())
     }
 
     fn empty(&self, size: usize) -> Handle<Server> {
-        let (callback, response) = mpsc::channel();
-
+        let (callback, response) = async_channel::unbounded();
         self.state
             .sender
-            .send(Message::Empty(size, callback))
+            .send_blocking(Message::Empty(size, callback))
            .unwrap();
 
-        self.response(response)
+        handle_response(response.recv_blocking())
     }
 
     fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
         self.state
             .sender
-            .send(Message::ExecuteKernel(kernel, bindings))
+            .send_blocking(Message::ExecuteKernel(kernel, bindings))
             .unwrap()
     }
 
     fn sync(&self, sync_type: SyncType) {
-        let (callback, response) = mpsc::channel();
+        let (callback, response) = async_channel::unbounded();
         self.state
             .sender
-            .send(Message::Sync(sync_type, callback))
+            .send_blocking(Message::Sync(sync_type, callback))
             .unwrap();
-        self.response(response)
+        handle_response(response.recv_blocking())
     }
 }
 
-impl<Server: ComputeServer> MpscComputeChannel<Server> {
-    fn response<Response>(&self, response: mpsc::Receiver<Response>) -> Response {
-        match response.recv() {
-            Ok(val) => val,
-            Err(err) => panic!("Can't connect to the server correctly {err:?}"),
-        }
-    }
+fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
+    match response {
+        Ok(val) => val,
+        Err(err) => panic!("Can't connect to the server correctly {err:?}"),
+    }
 }

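The server loop above is the one genuinely multithreaded piece: a dedicated thread blocks on a single future that services an async-channel queue, so synchronous callers can use the `*_blocking` sender methods while async callers await. A standalone sketch of the same pattern, with a hypothetical `Msg` type:

use std::thread;

enum Msg {
    Add(i32, async_channel::Sender<i32>),
}

fn main() {
    let (sender, receiver) = async_channel::unbounded::<Msg>();

    let handle = thread::spawn(move || {
        // One blocking future drives the whole message loop, as in the diff.
        pollster::block_on(async {
            while let Ok(Msg::Add(x, callback)) = receiver.recv().await {
                callback.send(x + 1).await.unwrap();
            }
        })
    });

    // Sync callers use the blocking variants of the same channel.
    let (cb, resp) = async_channel::unbounded();
    sender.send_blocking(Msg::Add(41, cb)).unwrap();
    assert_eq!(resp.recv_blocking().unwrap(), 42);

    drop(sender); // closing the channel lets the worker loop end
    handle.join().unwrap();
}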
@@ -37,7 +37,7 @@ impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
 where
     Server: ComputeServer,
 {
-    fn read(&self, handle: Binding<Server>) -> Reader<Vec<u8>> {
+    fn read(&self, handle: Binding<Server>) -> Reader {
         self.server.lock().read(handle)
     }
 

@@ -7,7 +7,7 @@ use crate::{
 use alloc::vec::Vec;
 use alloc::{boxed::Box, sync::Arc};
 use burn_common::stub::RwLock;
-use burn_common::{reader::Reader, sync_type::SyncType};
+use burn_common::sync_type::SyncType;
 
 /// The ComputeClient is the entry point to require tasks from the ComputeServer.
 /// It should be obtained for a specific device via the Compute struct.
@@ -41,8 +41,16 @@ where
     }
 
     /// Given a binding, returns owned resource as bytes.
-    pub fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
-        self.channel.read(binding)
+    pub async fn read_async(&self, binding: Binding<Server>) -> Vec<u8> {
+        self.channel.read(binding).await
+    }
+
+    /// Given a binding, returns owned resource as bytes.
+    ///
+    /// # Remarks
+    /// Panics if the read operation fails.
+    pub fn read(&self, binding: Binding<Server>) -> Vec<u8> {
+        burn_common::reader::read_sync(self.channel.read(binding))
     }
 
     /// Given a resource handle, returns the storage resource.

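Callers now pick a path explicitly: `read` blocks (and panics where blocking a pending future is impossible, i.e. on wasm), while `read_async` can always be awaited. A sketch of a generic async helper — hypothetical, not part of this diff:

use burn_compute::channel::ComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::server::{Binding, ComputeServer};

// In an async context (e.g. wasm), prefer the awaitable form.
async fn read_bytes<S: ComputeServer, C: ComputeChannel<S>>(
    client: &ComputeClient<S, C>,
    binding: Binding<S>,
) -> Vec<u8> {
    client.read_async(binding).await
}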
@@ -25,7 +25,7 @@ where
     type AutotuneKey: AutotuneKey;
 
     /// Given a handle, returns the owned resource as bytes.
-    fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>>;
+    fn read(&mut self, binding: Binding<Self>) -> Reader;
 
     /// Given a resource handle, returns the storage resource.
     fn get_resource(

@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use burn_common::{reader::Reader, sync_type::SyncType};
+use burn_common::{reader::reader_from_concrete, sync_type::SyncType};
 use burn_compute::{
     memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
     server::{Binding, ComputeServer, Handle},
@@ -26,10 +26,9 @@ where
     type MemoryManagement = MM;
     type AutotuneKey = String;
 
-    fn read(&mut self, binding: Binding<Self>) -> Reader<Vec<u8>> {
+    fn read(&mut self, binding: Binding<Self>) -> burn_common::reader::Reader {
         let bytes = self.memory_management.get(binding.memory);
-
-        Reader::Concrete(bytes.read().to_vec())
+        reader_from_concrete(bytes.read().to_vec())
     }
 
     fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {

@@ -16,7 +16,7 @@ fn created_resource_is_the_same_when_read() {
 
     let obtained_resource = client.read(resource_description.binding());
 
-    assert_eq!(resource, obtained_resource.read())
+    assert_eq!(resource, obtained_resource)
 }
 
 #[test]
@@ -26,7 +26,7 @@ fn empty_allocates_memory() {
     let resource_description = client.empty(size);
     let empty_resource = client.read(resource_description.binding());
 
-    assert_eq!(empty_resource.read().len(), 4);
+    assert_eq!(empty_resource.len(), 4);
 }
 
 #[test]
@@ -43,7 +43,7 @@ fn execute_elementwise_addition() {
 
     let obtained_resource = client.read(out.binding());
 
-    assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]))
+    assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
 }
 
 #[test]
@@ -65,7 +65,7 @@ fn autotune_basic_addition_execution() {
     let obtained_resource = client.read(out.binding());
 
     // If slow kernel was selected it would output [0, 1, 2]
-    assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]));
+    assert_eq!(obtained_resource, Vec::from([4, 5, 6]));
 }
 
 #[test]
@@ -87,7 +87,7 @@ fn autotune_basic_multiplication_execution() {
     let obtained_resource = client.read(out.binding());
 
     // If slow kernel was selected it would output [0, 1, 2]
-    assert_eq!(obtained_resource.read(), Vec::from([0, 4, 8]));
+    assert_eq!(obtained_resource, Vec::from([0, 4, 8]));
 }
 
 #[test]
@@ -126,7 +126,7 @@ fn autotune_cache_same_key_return_a_cache_hit() {
     let obtained_resource = client.read(out_2.binding());
 
     // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs
-    assert_eq!(obtained_resource.read(), Vec::from([0, 1, 2, 3]));
+    assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3]));
 }
 
 #[test]
@@ -167,7 +167,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
     let obtained_resource = client.read(out_2.binding());
 
     // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
-    assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9]));
+    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
 }
 
 #[test]
@@ -175,6 +175,8 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
 #[cfg(feature = "std")]
 fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
     // delete the cache file
+
+    use burn_common::sync_type::SyncType;
     let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
     let parent_dir = file_path
         .parent()
@@ -198,7 +200,7 @@ fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
         dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles);
     client.autotune_execute(Box::new(cache_test_autotune_kernel));
     // ensure that the autotune operations are run and cached
-    let _obtained_resource = client.read(out.binding());
+    client.sync(SyncType::Wait);
 
     assert!(
         parent_dir.exists(),
@@ -237,7 +239,7 @@ fn autotune_cache_different_keys_return_a_cache_miss() {
     let obtained_resource = client.read(out_2.binding());
 
     // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
-    assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8, 9]));
+    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
 }
 
 #[test]
@@ -285,5 +287,5 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
     // Cache should be missed because the checksum on 4 is generated randomly
     // and thus is always different,
     // so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs
-    assert_eq!(obtained_resource.read(), Vec::from([5, 6, 7, 8]));
+    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8]));
 }

@@ -65,8 +65,6 @@ sqlite = ["burn-dataset?/sqlite"]
 sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
 vision = ["burn-dataset?/vision", "burn-common/network"]
 
-wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]
-
 # Backend
 autodiff = ["burn-autodiff"]
 fusion = ["burn-wgpu?/fusion"]

@@ -69,16 +69,6 @@ impl GradientClipping {
         clipped_grad.mask_fill(lower_mask, -threshold)
     }
 
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    fn clip_by_norm<B: Backend, const D: usize>(
-        &self,
-        _grad: Tensor<B, D>,
-        _threshold: f32,
-    ) -> Tensor<B, D> {
-        todo!("Not yet supported on wasm");
-    }
-
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn clip_by_norm<B: Backend, const D: usize>(
         &self,
         grad: Tensor<B, D>,
@@ -97,11 +87,9 @@ impl GradientClipping {
         }
     }
 
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
         let squared = tensor.powf_scalar(2.0);
         let sum = squared.sum();
-
         sum.sqrt()
     }
 }

@@ -135,10 +135,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {
     type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
 
     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
-        #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-        todo!("Recording float tensors isn't yet supported on wasm.");
-
-        #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
         FloatTensorSerde::new(self.into_data().convert::<S::FloatElem>())
     }
 
@@ -152,10 +148,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {
     type Item<S: PrecisionSettings> = IntTensorSerde<S>;
 
     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
-        #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-        todo!("Recording int tensors isn't yet supported on wasm.");
-
-        #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
         IntTensorSerde::new(self.into_data().convert::<S::IntElem>())
     }
 
@@ -169,10 +161,6 @@ impl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {
     type Item<S: PrecisionSettings> = BoolTensorSerde;
 
     fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
-        #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-        todo!("Recording bool tensors isn't yet supported on wasm.");
-
-        #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
         BoolTensorSerde::new(self.into_data())
     }
 

@@ -25,7 +25,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
         ArrayHandle::new(&handle, 2),
     );
 
-    let actual = client.read(handle.binding()).read_sync().unwrap();
+    let actual = client.read(handle.binding());
     let actual = f32::from_bytes(&actual);
 
     assert_eq!(actual[0], 5.0);
@@ -41,7 +41,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
         ArrayHandle::new(&handle, 2),
     );
 
-    let actual = client.read(handle.binding()).read_sync().unwrap();
+    let actual = client.read(handle.binding());
     let actual = f32::from_bytes(&actual);
 
     assert_eq!(actual[0], 5.0);
@@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
         TensorHandle::new(&handle, &strides, &shape),
     );
 
-    let actual = client.read(handle.binding()).read_sync().unwrap();
+    let actual = client.read(handle.binding());
     let actual = f32::from_bytes(&actual);
 
     assert_eq!(actual, expected);

@@ -8,6 +8,8 @@ use burn_cube::ir::CubeDim;
 use burn_cube::prelude::*;
 use burn_jit::JitAutotuneKey;
 use burn_tensor::backend::SyncType;
+use burn_tensor::reader_from_concrete;
+use burn_tensor::Reader;
 use cudarc::driver::sys::CUctx_st;
 use cudarc::driver::sys::CUfunc_st;
 use std::collections::HashMap;
@@ -58,18 +60,17 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
     type MemoryManagement = MM;
     type AutotuneKey = JitAutotuneKey;
 
-    fn read(&mut self, binding: server::Binding<Self>) -> burn_tensor::Reader<Vec<u8>> {
+    fn read(&mut self, binding: server::Binding<Self>) -> Reader {
         let ctx = self.get_context();
         let resource = ctx.memory_management.get(binding.memory);
 
         // TODO: Check if it is possible to make this faster
         let mut data = vec![0; resource.size() as usize];
         unsafe {
             cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
         };
 
         ctx.sync();
-
-        burn_tensor::Reader::Concrete(data)
+        reader_from_concrete(data)
     }
 
     fn create(&mut self, data: &[u8]) -> server::Handle<Self> {

@@ -1,10 +1,12 @@
+use std::future::Future;
+
 use crate::{
     stream::{execution::Operation, StreamId},
     FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionTensor,
 };
 use burn_tensor::{
     repr::{OperationDescription, TensorDescription, TensorId},
-    DType, Reader, TensorData,
+    DType, TensorData,
 };
 
 /// Define how to interact with the fusion server.
@@ -37,7 +39,7 @@ where
         &self,
         tensor: TensorDescription,
         stream: StreamId,
-    ) -> Reader<TensorData>
+    ) -> impl Future<Output = TensorData> + Send
     where
         B: FusionBackend<FusionRuntime = R>;
     /// Read the values contained by an int tensor.
@@ -45,7 +47,7 @@ where
         &self,
         tensor: TensorDescription,
         stream: StreamId,
-    ) -> Reader<TensorData>
+    ) -> impl Future<Output = TensorData> + Send
     where
         B: FusionBackend<FusionRuntime = R>;
     /// Read the values contained by a bool tensor.
@@ -53,7 +55,7 @@ where
         &self,
         tensor: TensorDescription,
         stream: StreamId,
-    ) -> Reader<TensorData>
+    ) -> impl Future<Output = TensorData> + Send
     where
         B: FusionBackend<FusionRuntime = R>;
     /// Change the client of the given float tensor.

@@ -78,37 +78,37 @@ where
         FusionTensor::new(id, shape, dtype, self.clone(), stream)
     }
 
-    fn read_tensor_float<B, const D: usize>(
+    async fn read_tensor_float<B, const D: usize>(
         &self,
         tensor: TensorDescription,
         stream: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
-        self.server.lock().read_float::<B, D>(tensor, stream)
+        self.server.lock().read_float::<B, D>(tensor, stream).await
     }
 
-    fn read_tensor_int<B, const D: usize>(
+    async fn read_tensor_int<B, const D: usize>(
         &self,
         tensor: TensorDescription,
         id: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
-        self.server.lock().read_int::<B, D>(tensor, id)
+        self.server.lock().read_int::<B, D>(tensor, id).await
    }
 
-    fn read_tensor_bool<B, const D: usize>(
+    async fn read_tensor_bool<B, const D: usize>(
         &self,
         tensor: TensorDescription,
         stream: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
-        self.server.lock().read_bool::<B, D>(tensor, stream)
+        self.server.lock().read_bool::<B, D>(tensor, stream).await
     }
 
     fn change_client_float<B, const D: usize>(

@@ -1,4 +1,4 @@
-use burn_tensor::{DType, Element};
+use burn_tensor::{DType, Element, TensorData};
 use std::marker::PhantomData;
 
 use crate::{
@@ -37,10 +37,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         tensor.shape()
     }
 
-    fn bool_into_data<const D: usize>(
-        tensor: BoolTensor<Self, D>,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData> {
-        tensor.bool_into_data::<B, D>()
+    async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
+        tensor.bool_into_data::<B, D>().await
     }
 
     fn bool_from_data<const D: usize>(

@@ -10,7 +10,7 @@ use crate::{
 use burn_tensor::{
     ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
     repr::*,
-    DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData,
+    DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
 };
 use std::{marker::PhantomData, ops::Range};
 
@@ -169,8 +169,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         tensor.shape()
     }
 
-    fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
-        tensor.into_data::<B, D>()
+    async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
+        tensor.into_data::<B, D>().await
     }
 
     fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {

@@ -10,7 +10,7 @@ use crate::{
 use burn_tensor::{
     ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
     repr::{self, *},
-    DType, Device, Distribution, Element, ElementConversion, Reader, Shape, TensorData,
+    DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
 };
 use core::ops::Range;
 use std::marker::PhantomData;
@@ -33,8 +33,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         tensor.shape()
     }
 
-    fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
-        tensor.int_into_data::<B, D>()
+    async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
+        tensor.int_into_data::<B, D>().await
    }
 
     fn int_from_data<const D: usize>(

@@ -39,11 +39,11 @@ where
         self.handles.create_tensor_uninit()
     }
 
-    pub fn read_float<B, const D: usize>(
+    pub async fn read_float<B, const D: usize>(
         &mut self,
         tensor: TensorDescription,
         id: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -52,14 +52,14 @@ where
         self.drain_stream(id);
 
         let tensor = self.handles.get_float_tensor::<B, D>(&tensor);
-        B::float_into_data(tensor)
+        B::float_into_data(tensor).await
     }
 
-    pub fn read_int<B, const D: usize>(
+    pub async fn read_int<B, const D: usize>(
         &mut self,
         tensor: TensorDescription,
         id: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -68,14 +68,14 @@ where
         self.drain_stream(id);
 
         let tensor = self.handles.get_int_tensor::<B, D>(&tensor);
-        B::int_into_data(tensor)
+        B::int_into_data(tensor).await
     }
 
-    pub fn read_bool<B, const D: usize>(
+    pub async fn read_bool<B, const D: usize>(
         &mut self,
         tensor: TensorDescription,
         id: StreamId,
-    ) -> burn_tensor::Reader<burn_tensor::TensorData>
+    ) -> burn_tensor::TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -84,7 +84,7 @@ where
         self.drain_stream(id);
 
         let tensor = self.handles.get_bool_tensor::<B, D>(&tensor);
-        B::bool_into_data(tensor)
+        B::bool_into_data(tensor).await
     }
 
     pub fn change_server_float<B, const D: usize>(

@@ -1,7 +1,7 @@
 use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
 use burn_tensor::{
     repr::{TensorDescription, TensorId, TensorStatus},
-    DType, Reader, Shape, TensorData,
+    DType, Shape, TensorData,
 };
 use std::sync::Arc;
 
@@ -108,7 +108,7 @@ impl<R: FusionRuntime> FusionTensor<R> {
         }
     }
 
-    pub(crate) fn into_data<B, const D: usize>(self) -> Reader<TensorData>
+    pub(crate) async fn into_data<B, const D: usize>(self) -> TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -116,9 +116,10 @@ impl<R: FusionRuntime> FusionTensor<R> {
         self.client
             .clone()
             .read_tensor_float::<B, D>(self.into_description(), id)
+            .await
     }
 
-    pub(crate) fn int_into_data<B, const D: usize>(self) -> Reader<TensorData>
+    pub(crate) async fn int_into_data<B, const D: usize>(self) -> TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -126,9 +127,10 @@ impl<R: FusionRuntime> FusionTensor<R> {
         self.client
             .clone()
             .read_tensor_int::<B, D>(self.into_description(), id)
+            .await
     }
 
-    pub(crate) fn bool_into_data<B, const D: usize>(self) -> Reader<TensorData>
+    pub(crate) async fn bool_into_data<B, const D: usize>(self) -> TensorData
     where
         B: FusionBackend<FusionRuntime = R>,
     {
@@ -136,6 +138,7 @@ impl<R: FusionRuntime> FusionTensor<R> {
         self.client
             .clone()
             .read_tensor_bool::<B, D>(self.into_description(), id)
+            .await
     }
 }
 

@@ -1,6 +1,6 @@
 use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime};
 use burn_cube::CubeElement;
-use burn_tensor::{Reader, Shape, TensorData};
+use burn_tensor::{Shape, TensorData};
 use std::marker::PhantomData;
 
 pub(crate) fn from_data<R: JitRuntime, E: JitElement, const D: usize>(
@@ -14,28 +14,24 @@ pub(crate) fn from_data<R: JitRuntime, E: JitElement, const D: usize>(
     JitTensor::new(client, device.clone(), shape, buffer)
 }
 
-pub(crate) fn into_data<R: JitRuntime, E: JitElement, const D: usize>(
+pub(crate) async fn into_data<R: JitRuntime, E: JitElement, const D: usize>(
     tensor: JitTensor<R, E, D>,
-) -> Reader<TensorData> {
+) -> TensorData {
     let tensor = kernel::into_contiguous(tensor);
 
-    tensor
-        .client
-        .read(tensor.handle.binding())
-        .map(|bytes| TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape))
+    let bytes = tensor.client.read_async(tensor.handle.binding()).await;
+    TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
 }
 
-pub(crate) fn bool_into_data<R: JitRuntime, const D: usize>(
+pub(crate) async fn bool_into_data<R: JitRuntime, const D: usize>(
     tensor: JitTensor<R, u32, D>,
-) -> Reader<TensorData> {
+) -> TensorData {
     let tensor = kernel::into_contiguous(tensor);
 
-    tensor.client.read(tensor.handle.binding()).map(|bytes| {
-        TensorData::new(
-            u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
-            tensor.shape,
-        )
-    })
+    let bytes = tensor.client.read_async(tensor.handle.binding()).await;
+    TensorData::new(
+        u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
+        tensor.shape,
+    )
 }
 
 pub(crate) fn to_device<R: JitRuntime, E: JitElement, const D: usize>(

@@ -1,6 +1,5 @@
 use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
 use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor};
-use burn_tensor::Reader;
 use burn_tensor::{ops::BoolTensorOps, Shape, TensorData};
 use std::ops::Range;
 
@@ -20,8 +19,8 @@ where
         tensor.shape.clone()
     }
 
-    fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<TensorData> {
-        super::bool_into_data(tensor)
+    async fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> TensorData {
+        super::bool_into_data(tensor).await
     }
 
     fn bool_from_data<const D: usize>(

|
@ -7,8 +7,8 @@ use crate::{FloatElement, IntElement, JitRuntime};
|
|||
use burn_cube::ir::{BinaryOperator, Elem, Operator, Scope, UnaryOperator, Variable};
|
||||
use burn_cube::Runtime;
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData};
|
||||
use burn_tensor::{ElementConversion, Reader};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<R, F, I> FloatTensorOps<Self> for JitBackend<R, F, I>
|
||||
|
@ -45,8 +45,8 @@ where
|
|||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<TensorData> {
|
||||
super::into_data(tensor)
|
||||
async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
|
||||
super::into_data(tensor).await
|
||||
}
|
||||
|
||||
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{kernel, unary, FloatElement, IntElement, JitBackend, JitRuntime};
|
|||
use burn_cube::ir::{Elem, Item, Operator, Scope, UnaryOperator, Variable};
|
||||
use burn_cube::Runtime;
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Reader, Shape, TensorData};
|
||||
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<R, F, I> IntTensorOps<Self> for JitBackend<R, F, I>
|
||||
|
@ -21,8 +21,8 @@ where
|
|||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<TensorData> {
|
||||
super::into_data(tensor)
|
||||
async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
|
||||
super::into_data(tensor).await
|
||||
}
|
||||
|
||||
fn int_from_data<const D: usize>(
|
||||
|
|
|
@ -101,11 +101,10 @@ where
|
|||
client: ComputeClient<R::Server, R::Channel>,
|
||||
device: R::Device,
|
||||
) -> Self {
|
||||
let bytes = self
|
||||
.client
|
||||
.read(self.handle.clone().binding())
|
||||
.read_sync()
|
||||
.expect("Can only change client synchronously");
|
||||
let bytes = burn_common::reader::try_read_sync(
|
||||
self.client.read_async(self.handle.clone().binding()),
|
||||
)
|
||||
.expect("Can only change client synchronously");
|
||||
let handle = client.create(&bytes);
|
||||
|
||||
Self {
|
||||
|
|
|
@@ -2,7 +2,7 @@
 use alloc::vec;
 use alloc::vec::Vec;
 use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
-use burn_tensor::{ElementConversion, Reader};
+use burn_tensor::ElementConversion;
 use core::ops::Range;
 use ndarray::IntoDimension;
 
@@ -30,13 +30,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
         tensor.shape()
     }
 
-    fn bool_into_data<const D: usize>(
+    async fn bool_into_data<const D: usize>(
         tensor: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
-    ) -> Reader<TensorData> {
+    ) -> TensorData {
         let shape = tensor.shape();
         let values = tensor.array.into_iter().collect();
-
-        Reader::Concrete(TensorData::new(values, shape))
+        TensorData::new(values, shape)
     }
 
     fn bool_to_device<const D: usize>(
@@ -63,10 +62,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
     fn bool_into_int<const D: usize>(
         tensor: <NdArray<E> as Backend>::BoolTensorPrimitive<D>,
     ) -> NdArrayTensor<i64, D> {
-        let data = Self::bool_into_data(tensor)
-            .read_sync()
-            .expect("Always sync with ndarray");
-        NdArray::<E>::int_from_data(data.convert::<i64>(), &NdArrayDevice::Cpu)
+        let shape = tensor.shape();
+        let values = tensor.array.into_iter().collect();
+        NdArray::<E>::int_from_data(
+            TensorData::new(values, shape).convert::<i64>(),
+            &NdArrayDevice::Cpu,
+        )
     }
 
     fn bool_device<const D: usize>(

@@ -3,7 +3,7 @@ use alloc::vec;
 use alloc::vec::Vec;
 use burn_common::rand::get_seeded_rng;
 use burn_tensor::ops::IntTensorOps;
-use burn_tensor::{Distribution, Reader};
+use burn_tensor::Distribution;

 use burn_tensor::ElementConversion;
 use core::ops::Range;

@@ -32,11 +32,10 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
         tensor.shape()
     }

-    fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Reader<TensorData> {
+    async fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> TensorData {
         let shape = tensor.shape();
         let values = tensor.array.into_iter().collect();
-
-        Reader::Concrete(TensorData::new(values, shape))
+        TensorData::new(values, shape)
     }

     fn int_to_device<const D: usize>(

@@ -11,8 +11,8 @@ use crate::{NdArrayDevice, SEED};

 // Workspace crates
 use burn_common::rand::get_seeded_rng;
+use burn_tensor::Distribution;
 use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData};
-use burn_tensor::{Distribution, Reader};

 #[cfg(not(feature = "std"))]
 #[allow(unused_imports)]

@@ -51,11 +51,10 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
         tensor.shape()
     }

-    fn float_into_data<const D: usize>(tensor: NdArrayTensor<E, D>) -> Reader<TensorData> {
+    async fn float_into_data<const D: usize>(tensor: NdArrayTensor<E, D>) -> TensorData {
         let shape = tensor.shape();
         let values = tensor.array.into_iter().collect();
-
-        Reader::Concrete(TensorData::new(values, shape))
+        TensorData::new(values, shape)
     }

     fn float_device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {

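For the CPU backend there is nothing actually asynchronous here: an `async fn` whose body never awaits compiles to a future that is immediately ready, so making ndarray's accessors async costs nothing at runtime. A small sketch demonstrating the idea:

```rust
// An async fn with no .await resolves on the first poll; callers on native
// targets can unwrap it without the task ever yielding.
async fn into_data(values: Vec<i64>) -> Vec<i64> {
    values // purely synchronous work
}

fn main() {
    // pollster::block_on returns immediately because the future is ready.
    let data = pollster::block_on(into_data(vec![1, 2, 3]));
    assert_eq!(data, vec![1, 2, 3]);
}
```
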
@@ -1,6 +1,6 @@
 use super::TchOps;
 use crate::{element::TchElement, LibTorch, LibTorchDevice, TchTensor};
-use burn_tensor::{backend::Backend, ops::BoolTensorOps, Reader, Shape, TensorData};
+use burn_tensor::{backend::Backend, ops::BoolTensorOps, Shape, TensorData};
 use std::ops::Range;

 impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {

@@ -23,12 +23,11 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
         TchOps::repeat(tensor, dim, times)
     }

-    fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Reader<TensorData> {
+    async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {
         let shape = Self::bool_shape(&tensor);
         let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
         let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
-
-        Reader::Concrete(TensorData::new(values.unwrap(), shape))
+        TensorData::new(values.unwrap(), shape)
     }

     fn bool_to_device<const D: usize>(

@@ -143,13 +142,13 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
         TchOps::flip(tensor, axes)
     }

-    fn bool_argwhere<const D: usize>(
+    async fn bool_argwhere<const D: usize>(
         tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
     ) -> TchTensor<i64, 2> {
         TchTensor::new(tensor.tensor.argwhere())
     }

-    fn bool_nonzero<const D: usize>(
+    async fn bool_nonzero<const D: usize>(
         tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
     ) -> Vec<TchTensor<i64, 1>> {
         tensor

@@ -1,6 +1,6 @@
 use std::ops::Range;

-use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Reader, Shape, TensorData};
+use burn_tensor::{backend::Backend, ops::IntTensorOps, Distribution, Shape, TensorData};

 use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};

@@ -26,12 +26,11 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
         TchOps::repeat(tensor, dim, times)
     }

-    fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Reader<TensorData> {
+    async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {
         let shape = Self::int_shape(&tensor);
         let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
         let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
-
-        Reader::Concrete(TensorData::new(values.unwrap(), shape))
+        TensorData::new(values.unwrap(), shape)
     }

     fn int_to_device<const D: usize>(

@@ -1,8 +1,7 @@
 use super::TchOps;
 use crate::{element::TchElement, LibTorch, LibTorchDevice, TchShape, TchTensor};
 use burn_tensor::{
-    backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Reader, Shape,
-    TensorData,
+    backend::Backend, ops::FloatTensorOps, Distribution, ElementConversion, Shape, TensorData,
 };
 use std::ops::Range;

@@ -71,14 +70,14 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
         tensor.shape()
     }

-    fn float_into_data<const D: usize>(
+    async fn float_into_data<const D: usize>(
         tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
-    ) -> Reader<TensorData> {
+    ) -> TensorData {
         let shape = Self::float_shape(&tensor);
         let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
         let values: Result<Vec<E>, tch::TchError> = tensor.tensor.try_into();

-        Reader::Concrete(TensorData::new(values.unwrap(), shape))
+        TensorData::new(values.unwrap(), shape)
     }

     fn float_device<const D: usize>(tensor: &TchTensor<E, D>) -> LibTorchDevice {

@@ -295,21 +295,6 @@ impl<E: tch::kind::Element + Default + Element, const D: usize> TchTensor<E, D>
     }
 }

-#[cfg(test)]
-mod utils {
-    use super::*;
-    use crate::{backend::LibTorch, element::TchElement};
-
-    impl<P: TchElement, const D: usize> TchTensor<P, D> {
-        pub(crate) fn into_data(self) -> TensorData
-        where
-            P: tch::kind::Element,
-        {
-            <LibTorch<P> as FloatTensorOps<LibTorch<P>>>::float_into_data(self).read()
-        }
-    }
-}
-
 impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
     /// Creates an empty tensor from a shape and a device.
     ///

@@ -345,7 +330,7 @@ mod tests {
         );
         let tensor = TchTensor::<f32, 1>::from_data(data_expected.clone(), tch::Device::Cpu);

-        let data_actual = tensor.into_data();
+        let data_actual = Tensor::<LibTorch<f32>, 1>::from_primitive(tensor).into_data();

         assert_eq!(data_expected, data_actual);
     }

@@ -359,7 +344,7 @@ mod tests {
         );
         let tensor = TchTensor::<f32, 2>::from_data(data_expected.clone(), tch::Device::Cpu);

-        let data_actual = tensor.into_data();
+        let data_actual = Tensor::<LibTorch<f32>, 2>::from_primitive(tensor).into_data();

         assert_eq!(data_expected, data_actual);
     }

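With the test-only sync helper removed, tests read data through the public tensor API, which resolves the async read internally on native targets. A sketch of the updated idiom (the data values and `TensorData::from` construction are illustrative):

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use burn_tensor::{Tensor, TensorData};

    #[test]
    fn reads_data_through_public_api() {
        let expected = TensorData::from([1.0f32, 2.0, 3.0]);
        let tensor = TchTensor::<f32, 1>::from_data(expected.clone(), tch::Device::Cpu);

        // from_primitive wraps the backend tensor in the public Tensor type;
        // into_data blocks on the async read internally on native targets.
        let actual = Tensor::<LibTorch<f32>, 1>::from_primitive(tensor).into_data();

        assert_eq!(expected, actual);
    }
}
```
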
@@ -17,7 +17,6 @@ experimental-named-tensor = []
 export_tests = ["burn-tensor-testgen"]
 std = ["rand/std", "half/std", "num-traits/std"]
 repr = []
-wasm-sync = []

 [dependencies]
 burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }

@@ -27,7 +26,7 @@ derive-new = { workspace = true }
 half = { workspace = true, features = ["bytemuck"] }
 num-traits = { workspace = true }
 rand = { workspace = true }
-rand_distr = { workspace = true }  # use instead of statrs because it supports no_std
+rand_distr = { workspace = true } # use instead of statrs because it supports no_std
 bytemuck = { workspace = true }

 # The same implementation of HashMap in std but with no_std support (only needs alloc crate)

@@ -25,4 +25,4 @@ pub use half::{bf16, f16};
 pub(crate) use tensor::check::macros::check;
 pub use tensor::*;

-pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as a dependency.
+pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency.

@@ -1,15 +1,11 @@
-use crate::{
-    backend::Backend,
-    ops::{BoolTensor, IntTensor},
-    Device, ElementConversion, Shape, TensorData,
-};
+use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData};
 use alloc::vec::Vec;

 /// Compute the indices of the elements that are non-zero, grouped by element.
 ///
 /// # Arguments
 ///
-/// * `tensor` - The input tensor.
+/// * `data` - The input tensor data.
 ///
 /// # Returns
 ///

@@ -22,45 +18,7 @@ use alloc::vec::Vec;
 /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
 /// by static dispatch. It is not designed for direct usage by users, and not recommended to import
 /// or use this function directly.
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-pub fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
-    // Size of each output tensor is variable (= number of nonzero elements in the tensor).
-    // Reading the data to count the number of truth values might cause sync but is required.
-    // let dims = B::bool_shape(&tensor).dims;
-    let device = B::bool_device(&tensor);
-    let data = B::bool_into_data(tensor).read();
-
-    argwhere_data::<B, D>(data, &device)
-}
-
-/// Compute the indices of the elements that are non-zero, grouped by element.
-///
-/// # Arguments
-///
-/// * `tensor` - The input tensor.
-///
-/// # Returns
-///
-/// A vector of tensors, one for each dimension of the given tensor, containing the indices of
-/// the non-zero elements in that dimension.
-///
-/// # Remarks
-///
-/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
-/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
-/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
-/// or use this function directly.
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-pub async fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
-    // Size of each output tensor is variable (= number of nonzero elements in the tensor).
-    // Reading the data to count the number of truth values might cause sync but is required.
-    let device = B::bool_device(&tensor);
-    let data = B::bool_into_data(tensor).read().await;
-
-    argwhere_data::<B, D>(data, &device)
-}
-
-fn argwhere_data<B: Backend, const D: usize>(
+pub fn argwhere_data<B: Backend, const D: usize>(
     data: TensorData,
     device: &Device<B>,
 ) -> IntTensor<B, 2> {

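Making `argwhere_data` public lets any backend share this CPU fallback once it has the raw data; the default `bool_argwhere` added to `BoolTensorOps` later in this diff does exactly that. Sketched in isolation (import paths assumed from the re-exports in this commit):

```rust
use burn_tensor::{
    argwhere_data,
    backend::Backend,
    ops::{BoolTensor, IntTensor},
};

// Sketch of a backend-side fallback built on the now-public argwhere_data.
async fn bool_argwhere_fallback<B: Backend, const D: usize>(
    tensor: BoolTensor<B, D>,
) -> IntTensor<B, 2> {
    let device = B::bool_device(&tensor);
    // The only await: pull the data across the device boundary once.
    let data = B::bool_into_data(tensor).await;
    argwhere_data::<B, D>(data, &device)
}
```
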
@@ -2,19 +2,16 @@

 use alloc::vec::Vec;

-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use alloc::format;
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use alloc::string::String;
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use alloc::vec;

-use burn_common::{reader::Reader, stub::Mutex};
+use burn_common::stub::Mutex;
+use core::future::Future;
 use core::iter::repeat;
 use core::{fmt::Debug, ops::Range};
 use serde::{Deserialize, Deserializer};

-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use serde::{Serialize, Serializer};

 use crate::check::TensorCheck;

@@ -675,28 +672,28 @@ where
         Self::new(K::to_device(self.primitive, device))
     }

-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    /// Returns the data of the current tensor.
-    pub async fn into_data(self) -> TensorData {
-        K::into_data(self.primitive).read().await
-    }
-
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    /// Returns the data of the current tensor.
+    /// Converts the data of the current tensor.
     pub fn into_data(self) -> TensorData {
-        K::into_data(self.primitive).read()
+        crate::try_read_sync(self.into_data_async()).expect(
+            "Failed to read tensor data synchronously.
+            This can happen on platforms that don't support blocking futures like WASM.
+            If possible, try using into_data_async instead.",
+        )
     }

-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
     /// Returns the data of the current tensor.
-    pub async fn to_data(&self) -> TensorData {
-        K::into_data(self.primitive.clone()).read().await
+    pub fn to_data(&self) -> TensorData {
+        self.clone().into_data()
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    /// Returns the data of the current tensor without taking ownership.
-    pub fn to_data(&self) -> TensorData {
-        Self::into_data(self.clone())
+    /// Returns the data of the current tensor.
+    pub async fn into_data_async(self) -> TensorData {
+        K::into_data_async(self.primitive).await
+    }
+
+    /// Returns the data of the current tensor.
+    pub async fn to_data_async(&self) -> TensorData {
+        self.clone().into_data_async().await
     }

     /// Create a tensor from the given data on the given device.

@@ -875,11 +872,12 @@ where
     /// # Panics
     ///
     /// If the tensor doesn't have one element.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
+    /// If the backend fails to read the tensor data synchronously.
     pub fn into_scalar(self) -> K::Elem {
         check!(TensorCheck::into_scalar(&self.shape()));
-        let x = self.into_data().iter().next().unwrap();
-        x
+        crate::try_read_sync(self.into_scalar_async()).expect(
+            "Failed to read tensor data synchronously. This can happen on platforms
+            that don't support blocking futures like WASM. Try into_scalar_async instead.",
+        )
     }

     /// Convert the tensor into a scalar.

@@ -887,10 +885,9 @@ where
     /// # Panics
     ///
     /// If the tensor doesn't have one element.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn into_scalar(self) -> K::Elem {
+    pub async fn into_scalar_async(self) -> K::Elem {
         check!(TensorCheck::into_scalar(&self.shape()));
-        let x = self.into_data().await.iter().next().unwrap();
+        let x = self.into_data_async().await.iter().next().unwrap();
         x
     }

@@ -989,7 +986,6 @@ where
     K: BasicOps<B>,
     <K as BasicOps<B>>::Elem: Debug,
 {
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     #[inline]
     fn push_newline_indent(acc: &mut String, indent: usize) {
         acc.push('\n');

@@ -998,7 +994,6 @@ where
         }
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn fmt_inner_tensor(
         &self,
         acc: &mut String,

@@ -1015,18 +1010,18 @@ where
             let range: [core::ops::Range<usize>; D] =
                 core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

-            let elem = &self
-                .clone()
-                .slice(range)
-                .into_data()
-                .iter::<<K as BasicOps<B>>::Elem>()
-                .next()
-                .unwrap();
-            acc.push_str(&format!("{elem:?}"));
+            let data =
+                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());
+
+            if let Some(data) = data {
+                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
+                acc.push_str(&format!("{elem:?}"));
+            } else {
+                acc.push_str("<Tensor data not available>");
+            }
         }
     }

@@ -1061,7 +1056,6 @@ where
     /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
     /// * `depth` - The current depth of the tensor dimensions being processed.
     /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn display_recursive(
         &self,
         acc: &mut String,

@@ -1172,7 +1166,6 @@ where
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         writeln!(f, "Tensor {{")?;

-        #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
         {
             let po = PRINT_OPTS.lock().unwrap();
             let mut acc = String::new();

@@ -1435,7 +1428,7 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
         device: &B::Device,
     ) -> Self::Primitive<D>;

-    /// Extracts the data from the tensor.
+    /// Extracts the data from the tensor asynchronously.
     ///
     /// # Arguments
     ///

@@ -1453,7 +1446,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
     ///
     /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
     /// which is more high-level and designed for public use.
-    fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData>;
+    fn into_data_async<const D: usize>(
+        tensor: Self::Primitive<D>,
+    ) -> impl Future<Output = TensorData> + Send;

     /// Creates a tensor from the given data.
     ///

@@ -1724,8 +1719,8 @@ impl<B: Backend> BasicOps<B> for Float {
         B::float_to_device(tensor, device)
     }

-    fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
-        B::float_into_data(tensor)
+    async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
+        B::float_into_data(tensor).await
     }

     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {

@@ -1846,8 +1841,8 @@ impl<B: Backend> BasicOps<B> for Int {
         B::int_to_device(tensor, device)
     }

-    fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
-        B::int_into_data(tensor)
+    async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
+        B::int_into_data(tensor).await
     }

     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {

@@ -1968,8 +1963,8 @@ impl<B: Backend> BasicOps<B> for Bool {
         B::bool_to_device(tensor, device)
     }

-    fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<TensorData> {
-        B::bool_into_data(tensor)
+    async fn into_data_async<const D: usize>(tensor: Self::Primitive<D>) -> TensorData {
+        B::bool_into_data(tensor).await
     }

     fn from_data<const D: usize>(data: TensorData, device: &B::Device) -> Self::Primitive<D> {

@@ -2230,7 +2225,6 @@ impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
     }
 }

-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
 where
     B: Backend,

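The net effect on the public `Tensor` API: the sync names stay (and now panic only when the platform truly cannot block), while the `_async` variants are available on every target instead of being wasm-only. Usage, sketched for an arbitrary backend:

```rust
use burn_tensor::{backend::Backend, Tensor};

async fn demo<B: Backend>(tensor: Tensor<B, 2>) {
    // Portable path: works on native and wasm alike.
    let _data = tensor.clone().into_data_async().await;

    // Convenience path: blocks internally; panics on targets (like wasm)
    // that cannot resolve a future synchronously.
    let _data = tensor.into_data();
}
```
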
@@ -1,8 +1,7 @@
 use crate::{backend::Backend, Bool, Int, Shape, Tensor, TensorData};
 use alloc::vec::Vec;

-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-use crate::argwhere;
+use crate::try_read_sync;

 /// The part of the tensor to keep when creating a triangular mask.
 enum TriPart {

@@ -46,27 +45,21 @@ where
     ///
     /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
     /// the non-zero elements in that dimension.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
-        B::bool_nonzero(self.primitive)
-            .into_iter()
-            .map(Tensor::new)
-            .collect()
+        try_read_sync(self.nonzero_async())
+            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
     }

-    /// Compute the indices of the elements that are true.
+    /// Compute the indices of the elements that are non-zero.
     ///
     /// # Returns
     ///
     /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
     /// the non-zero elements in that dimension.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
-        let indices = self.argwhere().await.primitive;
-        let dims = B::int_shape(&indices).dims;
-        B::int_chunk(indices, dims[1], 1)
+    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
+        B::bool_nonzero(self.primitive)
+            .await
             .into_iter()
-            .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
             .map(Tensor::new)
             .collect()
     }

@@ -77,9 +70,9 @@ where
     ///
     /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
     /// result contains the indices of a non-zero element.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn argwhere(self) -> Tensor<B, 2, Int> {
-        Tensor::new(B::bool_argwhere(self.primitive))
+        try_read_sync(self.argwhere_async())
+            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
     }

     /// Compute the indices of the elements that are true, grouped by element.

@@ -88,9 +81,8 @@ where
     ///
     /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
     /// result contains the indices of a non-zero element.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn argwhere(self) -> Tensor<B, 2, Int> {
-        Tensor::new(argwhere::<B, D>(self.primitive).await)
+    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
+        Tensor::new(B::bool_argwhere(self.primitive).await)
     }

     /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to

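The same renaming applies to the boolean index queries: `argwhere` and `nonzero` stay as synchronous wrappers, and the `_async` spellings become the portable ones. Illustrative usage:

```rust
use burn_tensor::{backend::Backend, Bool, Int, Tensor};

async fn indices_of<B: Backend, const D: usize>(mask: Tensor<B, D, Bool>) {
    // One row per true element, one column per dimension.
    let _rows: Tensor<B, 2, Int> = mask.clone().argwhere_async().await;

    // One index tensor per dimension of the mask.
    let _per_dim: Vec<Tensor<B, 1, Int>> = mask.nonzero_async().await;
}
```
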
@@ -10,9 +10,6 @@ use crate::tensor::{Distribution, Shape, TensorData};
 use crate::Int;
 use crate::Tensor;

-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-use crate::{argsort, sort, sort_with_indices, Float};
-
 impl<const D: usize, B> Tensor<B, D>
 where
     B: Backend,

@@ -265,88 +262,4 @@ where
             .matmul(centered)
             .div_scalar(n as f32 - correction_factor as f32)
     }
-
-    /// Sort the elements by value in ascending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort(self, dim: usize) -> Tensor<B, D> {
-        Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
-    }
-
-    /// Sort the elements by value in descending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_descending(self, dim: usize) -> Tensor<B, D> {
-        Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await)
-    }
-
-    /// Sort the elements by value in ascending order along a given dimension.
-    /// Also returns the indices.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
-        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
-        let (values, indices) =
-            sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ false).await;
-        (Tensor::new(values), Tensor::new(indices))
-    }
-
-    /// Sort the elements by value in descending order along a given dimension.
-    /// Also returns the indices.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_descending_with_indices(
-        self,
-        dim: usize,
-    ) -> (Tensor<B, D>, Tensor<B, D, Int>) {
-        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
-        let (values, indices) =
-            sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ true).await;
-        (Tensor::new(values), Tensor::new(indices))
-    }
-
-    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
-        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
-        Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
-    }
-
-    /// Returns the indices that sort the elements by value in descending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
-        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
-        Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await)
-    }
-
-    /// Returns the `k` largest elements of the given input tensor along a given dimension.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D> {
-        let k_indices = Tensor::arange(0..k as i64, &self.device());
-        self.sort_descending(dim).await.select(dim, k_indices)
-    }
-
-    /// Returns the `k` largest elements of the given input tensor along a given dimension.
-    /// Also returns the indices.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn topk_with_indices(
-        self,
-        k: usize,
-        dim: usize,
-    ) -> (Tensor<B, D>, Tensor<B, D, Int>) {
-        let k_indices = Tensor::arange(0..k as i64, &self.device());
-        let (values, indices) = self.sort_descending_with_indices(dim).await;
-        (
-            values.select(dim, k_indices.clone()),
-            indices.select(dim, k_indices),
-        )
-    }
 }

@@ -2,9 +2,6 @@ use crate::{backend::Backend, Float, Int, Shape, Tensor, TensorData};

 use core::ops::Range;

-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-use crate::{argsort, check, check::TensorCheck, sort, sort_with_indices};
-
 impl<B> Tensor<B, 1, Int>
 where
     B: Backend,

@@ -100,88 +97,4 @@ where
     ) -> Tensor<B, D2, Int> {
         Tensor::new(B::int_cartesian_grid::<S, D, D2>(shape, device))
     }
-
-    /// Sort the elements by value in ascending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort(self, dim: usize) -> Tensor<B, D, Int> {
-        Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ false).await)
-    }
-
-    /// Sort the elements by value in descending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_descending(self, dim: usize) -> Tensor<B, D, Int> {
-        Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ true).await)
-    }
-
-    /// Sort the elements by value in ascending order along a given dimension.
-    /// Also returns the indices.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
-        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
-        let (values, indices) =
-            sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ false).await;
-        (Tensor::new(values), Tensor::new(indices))
-    }
-
-    /// Sort the elements by value in descending order along a given dimension.
-    /// Also returns the indices.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn sort_descending_with_indices(
-        self,
-        dim: usize,
-    ) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
-        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
-        let (values, indices) =
-            sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ true).await;
-        (Tensor::new(values), Tensor::new(indices))
-    }
-
-    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
-        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
-        Tensor::new(argsort::<B, D, Int>(self.primitive, dim, /*descending*/ false).await)
-    }
-
-    /// Returns the indices that sort the elements by value in descending order along a given dimension.
-    ///
-    /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
-        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
-        Tensor::new(argsort::<B, D, Int>(self.primitive, dim, /*descending*/ true).await)
-    }
-
-    /// Returns the `k` largest elements of the given input tensor along a given dimension.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D, Int> {
-        let k_indices = Tensor::arange(0..k as i64, &self.device());
-        self.sort_descending(dim).await.select(dim, k_indices)
-    }
-
-    /// Returns the `k` largest elements of the given input tensor along a given dimension.
-    /// Also returns the indices.
-    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-    pub async fn topk_with_indices(
-        self,
-        k: usize,
-        dim: usize,
-    ) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
-        let k_indices = Tensor::arange(0..k as i64, &self.device());
-        let (values, indices) = self.sort_descending_with_indices(dim).await;
-        (
-            values.select(dim, k_indices.clone()),
-            indices.select(dim, k_indices),
-        )
-    }
 }

@@ -13,7 +13,7 @@ mod narrow;
 mod numeric;
 mod sort;

-pub use argwhere::argwhere;
+pub use argwhere::argwhere_data;
 pub use autodiff::*;
 pub use base::*;
 pub use cartesian_grid::cartesian_grid;

@@ -642,9 +642,6 @@ where
     /// A boolean scalar.
-    ///
-    /// # Remarks
-    ///
-    /// This method is only available for non-wasm targets or when the `wasm-sync` feature is enabled.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
         self.is_close(other, rtol, atol).all().into_scalar()
     }

@@ -671,7 +668,6 @@ where
     /// Sort the elements by value in ascending order along a given dimension.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
         check!(TensorCheck::sort_dim::<D>("Sort", dim));
         Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))

@@ -680,7 +676,6 @@ where
     /// Sort the elements by value in descending order along a given dimension.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
         check!(TensorCheck::sort_dim::<D>("Sort", dim));
         Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))

@@ -690,7 +685,6 @@ where
     /// Also returns the indices.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
         check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
         let (values, indices) =

@@ -702,7 +696,6 @@ where
     /// Also returns the indices.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
         check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
         let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);

@@ -712,7 +705,6 @@ where
     /// Returns the indices that sort the elements by value in ascending order along a given dimension.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
         check!(TensorCheck::sort_dim::<D>("Argsort", dim));
         Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))

@@ -721,14 +713,12 @@ where
     /// Returns the indices that sort the elements by value in descending order along a given dimension.
     ///
     /// This sort is unstable (i.e., may reorder equal elements).
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
         check!(TensorCheck::sort_dim::<D>("Argsort", dim));
         Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
     }

     /// Returns the `k` largest elements of the given input tensor along a given dimension.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
         let k_indices = Tensor::arange(0..k as i64, &self.device());
         self.sort_descending(dim).select(dim, k_indices)

@@ -736,7 +726,6 @@ where

     /// Returns the `k` largest elements of the given input tensor along a given dimension.
     /// Also returns the indices.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
         let k_indices = Tensor::arange(0..k as i64, &self.device());
         let (values, indices) = self.sort_descending_with_indices(dim);

@@ -2029,7 +2018,6 @@ where
     ///
     /// Users should prefer the [Tensor::sort](Tensor::sort) function,
     /// which is more high-level and designed for public use.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2059,7 +2047,6 @@ where
     /// For sorting the elements of a tensor, users should prefer the
     /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
     /// and designed for public use.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort_with_indices<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2087,7 +2074,6 @@ where
     ///
     /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
     /// which is more high-level and designed for public use.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn argsort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2411,7 +2397,6 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_sign(tensor)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2420,7 +2405,6 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_sort(tensor, dim, descending)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort_with_indices<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2429,7 +2413,6 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_sort_with_indices(tensor, dim, descending)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn argsort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2758,7 +2741,6 @@ impl<B: Backend> Numeric<B> for Float {
         B::float_sign(tensor)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2767,7 +2749,6 @@ impl<B: Backend> Numeric<B> for Float {
         B::float_sort(tensor, dim, descending)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn sort_with_indices<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -2776,7 +2757,6 @@ impl<B: Backend> Numeric<B> for Float {
         B::float_sort_with_indices(tensor, dim, descending)
     }

-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn argsort<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,

@@ -6,6 +6,7 @@ use crate::{
     BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind,
 };
 use alloc::vec::Vec;
+use burn_common::reader::try_read_sync;

 /// Sort the elements of the input `tensor` by value along a given dimension.
 ///

@@ -27,7 +28,6 @@ use alloc::vec::Vec;
 /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
 /// by static dispatch. It is not designed for direct usage by users, and not recommended to import
 /// or use this function directly.
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 pub fn sort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
     tensor: K::Primitive<D>,
     dim: usize,

@@ -37,43 +37,7 @@ where
     <K as BasicOps<B>>::Elem: Element,
 {
     let device = K::device(&tensor);
-    let data = K::into_data(tensor).read();
-
-    sort_data::<B, D, K>(data, dim, &device, descending)
-}
-
-/// Sort the elements of the input `tensor` by value along a given dimension.
-///
-/// This sort is unstable (i.e., may reorder equal elements).
-///
-/// # Arguments
-///
-/// * `tensor` - The input tensor.
-/// * `dim` - The axis along which to sort.
-/// * `descending` - The sorting order.
-///
-/// # Returns
-///
-/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
-///
-/// # Remarks
-///
-/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
-/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
-/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
-/// or use this function directly.
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-pub async fn sort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
-    tensor: K::Primitive<D>,
-    dim: usize,
-    descending: bool,
-) -> K::Primitive<D>
-where
-    <K as BasicOps<B>>::Elem: Element,
-{
-    let device = K::device(&tensor);
-    let data = K::into_data(tensor).read().await;
-
+    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
     sort_data::<B, D, K>(data, dim, &device, descending)
 }

@@ -119,7 +83,6 @@ where
 /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
 /// by static dispatch. It is not designed for direct usage by users, and not recommended to import
 /// or use this function directly.
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 pub fn sort_with_indices<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
     tensor: K::Primitive<D>,
     dim: usize,

@@ -129,44 +92,7 @@ where
     <K as BasicOps<B>>::Elem: Element,
 {
     let device = K::device(&tensor);
-    let data = K::into_data(tensor).read();
-
-    sort_data_with_indices::<B, D, K>(data, dim, &device, descending)
-}
-
-/// Sort the elements of the input `tensor` by value along a given dimension.
-///
-/// This sort is unstable (i.e., may reorder equal elements).
-///
-/// # Arguments
-///
-/// * `tensor` - The input tensor.
-/// * `dim` - The axis along which to sort.
-/// * `descending` - The sorting order.
-///
-/// # Returns
-///
-/// A tensor with the same shape as the input tensor and corresponding indices, where
-/// the elements are sorted by value and the indices map back to the original input tensor.
-///
-/// # Remarks
-///
-/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
-/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
-/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
-/// or use this function directly.
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-pub async fn sort_with_indices<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
-    tensor: K::Primitive<D>,
-    dim: usize,
-    descending: bool,
-) -> (K::Primitive<D>, IntTensor<B, D>)
-where
-    <K as BasicOps<B>>::Elem: Element,
-{
-    let device = K::device(&tensor);
-    let data = K::into_data(tensor).read().await;
-
+    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
     sort_data_with_indices::<B, D, K>(data, dim, &device, descending)
 }

@@ -253,7 +179,6 @@ where
 /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
 /// by static dispatch. It is not designed for direct usage by users, and not recommended to import
 /// or use this function directly.
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 pub fn argsort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
     tensor: K::Primitive<D>,
     dim: usize,

@@ -263,42 +188,7 @@ where
     <K as BasicOps<B>>::Elem: Element,
 {
     let device = K::device(&tensor);
-    let data = K::into_data(tensor).read();
-
-    argsort_data::<B, D, K>(data, dim, &device, descending)
-}
-
-/// Returns the indices that sort the elements of the input `tensor` along a given dimension.
-///
-/// This sort is unstable (i.e., may reorder equal elements).
-///
-/// # Arguments
-///
-/// * `tensor` - The input tensor.
-/// * `dim` - The axis along which to sort.
-/// * `descending` - The sorting order.
-///
-/// # Returns
-///
-/// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
-///
-/// # Remarks
-///
-/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
-/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
-/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
-/// or use this function directly.
-#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
-pub async fn argsort<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
-    tensor: K::Primitive<D>,
-    dim: usize,
-    descending: bool,
-) -> IntTensor<B, D>
-where
-    <K as BasicOps<B>>::Elem: Element,
-{
-    let device = K::device(&tensor);
-    let data = K::into_data(tensor).read().await;
+    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");

     argsort_data::<B, D, K>(data, dim, &device, descending)
 }

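The sort fallbacks collapse the cfg-gated sync/wasm duplicates into one synchronous function that eagerly resolves the async read and fails loudly where that is impossible. The shape of that refactor, reduced to its essentials (a sketch, not the actual burn code):

```rust
use burn_common::reader::try_read_sync;
use core::future::Future;

// One function replaces two cfg-gated copies: resolve the async data read
// up front, and panic with a clear message on platforms that cannot block.
fn read_for_fallback<T>(read: impl Future<Output = T>) -> T {
    try_read_sync(read).expect(
        "Failed to synchronously read tensor data. This operation is not \
         supported until this backend has a GPU sorting implementation.",
    )
}
```
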
@@ -3,14 +3,11 @@ use super::{
     IntTensor,
 };
 use crate::{
-    backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, TensorData,
+    argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
+    TensorData,
 };
 use alloc::vec::Vec;
-use burn_common::reader::Reader;
-use core::ops::Range;
-
-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-use crate::argwhere;
+use core::{future::Future, ops::Range};

 /// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
 /// for documentation on each function.

@@ -47,21 +44,9 @@ pub trait BoolTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The data structure with the tensor's data.
-    fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<TensorData>;
-
-    /// Gets the data from the tensor.
-    ///
-    ///
-    /// # Arguments
-    ///
-    /// * `data` - The data structure.
-    ///
-    /// # Returns
-    ///
-    /// The data cloned from the data structure.
-    fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<TensorData> {
-        Self::bool_into_data(tensor.clone())
-    }
+    fn bool_into_data<const D: usize>(
+        tensor: BoolTensor<B, D>,
+    ) -> impl Future<Output = TensorData> + Send;

     /// Creates a tensor from the data structure.
     ///

@@ -420,9 +405,16 @@ pub trait BoolTensorOps<B: Backend> {
     ///
     /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
     /// the non-zero elements in that dimension.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
-        argwhere::<B, D>(tensor)
+    fn bool_argwhere<const D: usize>(
+        tensor: BoolTensor<B, D>,
+    ) -> impl Future<Output = IntTensor<B, 2>> + Send {
+        async {
+            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
+            // Reading the data to count the number of truth values might cause sync but is required.
+            let device = B::bool_device(&tensor);
+            let data = B::bool_into_data(tensor).await;
+            argwhere_data::<B, D>(data, &device)
+        }
     }

     /// Compute the indices of the elements that are non-zero.

@@ -435,14 +427,17 @@ pub trait BoolTensorOps<B: Backend> {
     ///
     /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
     /// the non-zero elements in that dimension.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
-    fn bool_nonzero<const D: usize>(tensor: BoolTensor<B, D>) -> Vec<IntTensor<B, 1>> {
-        let indices = B::bool_argwhere(tensor);
-        let dims = B::int_shape(&indices).dims;
-        B::int_chunk(indices, dims[1], 1)
-            .into_iter()
-            .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
-            .collect()
+    fn bool_nonzero<const D: usize>(
+        tensor: BoolTensor<B, D>,
+    ) -> impl Future<Output = Vec<IntTensor<B, 1>>> + Send {
+        async {
+            let indices = B::bool_argwhere(tensor).await;
+            let dims = B::int_shape(&indices).dims;
+            B::int_chunk(indices, dims[1], 1)
+                .into_iter()
+                .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
+                .collect()
+        }
    }

     /// Broadcasts the bool `tensor` to the given `shape`.

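The trait methods above use return-position `impl Trait` in traits: a required method can return `impl Future<Output = T> + Send`, and a default method can satisfy that bound with a plain `async { ... }` block (stable since Rust 1.75). A reduced sketch of the pattern, with illustrative names:

```rust
use core::future::Future;

trait DataOps: Sized + Send {
    // Required: each implementor returns some future yielding the data.
    fn into_data(self) -> impl Future<Output = Vec<u8>> + Send;

    // Default method: an async block is itself an `impl Future`, so shared
    // fallback logic can live here, just like bool_argwhere above.
    fn first_byte(self) -> impl Future<Output = Option<u8>> + Send {
        async move { self.into_data().await.first().copied() }
    }
}
```
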
@@ -6,10 +6,9 @@ use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, In
 use crate::{cartesian_grid, Tensor};
 use crate::{tensor::api::chunk, tensor::api::narrow};
 use alloc::vec::Vec;
-use burn_common::reader::Reader;
+use core::future::Future;
 use core::ops::Range;

-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use crate::{argsort, sort, sort_with_indices};

 /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)

@@ -47,20 +46,9 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The data structure with the tensor's data.
-    fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<TensorData>;
-
-    /// Gets the data from the tensor.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor.
-    ///
-    /// # Returns
-    ///
-    /// The data cloned from the data structure.
-    fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<TensorData> {
-        Self::int_into_data(tensor.clone())
-    }
+    fn int_into_data<const D: usize>(
+        tensor: IntTensor<B, D>,
+    ) -> impl Future<Output = TensorData> + Send;

     /// Creates a tensor from the data structure.
     ///

@@ -1241,7 +1229,6 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_sort<const D: usize>(
         tensor: IntTensor<B, D>,
         dim: usize,

@@ -1263,7 +1250,6 @@ pub trait IntTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as the input tensor and corresponding indices, where
     /// the elements are sorted by value and the indices map back to the original input tensor.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_sort_with_indices<const D: usize>(
         tensor: IntTensor<B, D>,
         dim: usize,

@@ -1286,7 +1272,6 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn int_argsort<const D: usize>(
         tensor: IntTensor<B, D>,
         dim: usize,

@@ -7,10 +7,9 @@ use crate::Tensor;
 use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData};
 use crate::{tensor::api::chunk, tensor::api::narrow};
 use alloc::vec::Vec;
-use burn_common::reader::Reader;
+use core::future::Future;
 use core::ops::Range;

-#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
 use crate::{argsort, sort, sort_with_indices};

 /// Operations on float tensors.

@@ -111,20 +110,9 @@ pub trait FloatTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The data structure with the tensor's data.
-    fn float_to_data<const D: usize>(tensor: &FloatTensor<B, D>) -> Reader<TensorData> {
-        Self::float_into_data(tensor.clone())
-    }
-
-    /// Converts the tensor to a data structure.
-    ///
-    /// # Arguments
-    ///
-    /// * `tensor` - The tensor.
-    ///
-    /// # Returns
-    ///
-    /// The data structure with the tensor's data.
-    fn float_into_data<const D: usize>(tensor: FloatTensor<B, D>) -> Reader<TensorData>;
+    fn float_into_data<const D: usize>(
+        tensor: FloatTensor<B, D>,
+    ) -> impl Future<Output = TensorData> + Send;

     /// Gets the device of the tensor.
     ///

@@ -1389,7 +1377,6 @@ pub trait FloatTensorOps<B: Backend> {
     /// # Returns
     ///
     /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_sort<const D: usize>(
         tensor: FloatTensor<B, D>,
         dim: usize,

@@ -1412,7 +1399,6 @@ pub trait FloatTensorOps<B: Backend> {
     ///
     /// A tensor with the same shape as the input tensor and corresponding indices, where
     /// the elements are sorted by value and the indices map back to the original input tensor.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_sort_with_indices<const D: usize>(
         tensor: FloatTensor<B, D>,
         dim: usize,

@@ -1434,7 +1420,6 @@ pub trait FloatTensorOps<B: Backend> {
     /// # Returns
     ///
     /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
-    #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
     fn float_argsort<const D: usize>(
         tensor: FloatTensor<B, D>,
         dim: usize,

@@ -31,7 +31,7 @@ wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
 pollster = { workspace = true }

 log = { workspace = true }
-futures-intrusive = { workspace = true }
+async-channel = { workspace = true }
 derive-new = { workspace = true }
 hashbrown = { workspace = true }

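Swapping `futures-intrusive` for `async-channel` replaces the oneshot primitive with a bounded channel of capacity 1, which serves the same role and also works on wasm. A minimal sketch of the pattern in isolation (assumes only the `async-channel` and `pollster` crates):

    // A bounded(1) channel used as a oneshot: the first try_send cannot fail,
    // and recv().await resolves once that message arrives.
    fn oneshot_demo() {
        let (sender, receiver) = async_channel::bounded::<u32>(1);
        sender.try_send(42).expect("capacity 1, first send cannot fail");
        let value = pollster::block_on(receiver.recv()).expect("sender alive");
        assert_eq!(value, 42);
    }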
@@ -113,7 +113,7 @@ where
         )
     }

-    fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
+    fn create_read_buffer(&mut self, handle: server::Binding<Self>) -> wgpu::Buffer {
         let resource = self.memory_management.get(handle.memory);

         let size = resource.size();
@@ -134,50 +134,7 @@ where
         self.tasks_count += 1;

         self.sync(SyncType::Flush);

-        BufferReader::new(buffer_dest)
-    }
-}
-
-#[derive(new)]
-struct BufferReader {
-    buffer: wgpu::Buffer,
-}
-
-impl BufferReader {
-    #[cfg(target_family = "wasm")]
-    async fn read(self, device: alloc::sync::Arc<wgpu::Device>) -> Vec<u8> {
-        self.read_async(&device).await
-    }
-
-    #[cfg(not(target_family = "wasm"))]
-    fn read(self, device: &wgpu::Device) -> Vec<u8> {
-        pollster::block_on(self.read_async(device))
-    }
-
-    async fn read_async(&self, device: &wgpu::Device) -> Vec<u8> {
-        let buffer_slice = self.buffer.slice(..);
-        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
-        buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
-            sender
-                .send(v)
-                .expect("Unable to send buffer slice result to async channel.")
-        });
-
-        device.poll(wgpu::Maintain::Wait);
-
-        let result = receiver.receive().await;
-
-        if let Some(Ok(())) = result {
-            let data = buffer_slice.get_mapped_range();
-            let result = bytemuck::cast_slice(&data).to_vec();
-
-            drop(data);
-            self.buffer.unmap();
-            result
-        } else {
-            panic!("Unable to read buffer {:?}", result)
-        }
+        buffer_dest
     }
 }
@@ -190,15 +147,38 @@ where
     type MemoryManagement = MM;
     type AutotuneKey = JitAutotuneKey;

-    fn read(&mut self, binding: server::Binding<Self>) -> Reader<Vec<u8>> {
-        #[cfg(target_family = "wasm")]
-        {
-            let future = self.buffer_reader(binding).read(self.device.clone());
-            return Reader::Future(Box::pin(future));
-        }
+    fn read(&mut self, binding: server::Binding<Self>) -> Reader {
+        let device = self.device.clone();
+        let buffer = self.create_read_buffer(binding);

-        #[cfg(not(target_family = "wasm"))]
-        Reader::Concrete(self.buffer_reader(binding).read(&self.device))
+        Box::pin(async move {
+            let buffer_slice = buffer.slice(..);
+            let (sender, receiver) = async_channel::bounded(1);
+
+            buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
+                sender
+                    .try_send(v)
+                    .expect("Unable to send buffer slice result to async channel.")
+            });
+
+            device.poll(wgpu::Maintain::Wait);
+
+            let result = receiver
+                .recv()
+                .await
+                .expect("Unable to receive buffer slice result.");
+
+            if let Ok(()) = result {
+                let data = buffer_slice.get_mapped_range();
+                let result = bytemuck::cast_slice(&data).to_vec();
+
+                drop(data);
+                buffer.unmap();
+                result
+            } else {
+                panic!("Unable to read buffer {:?}", result)
+            }
+        })
     }

     fn get_resource(
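A note on ordering in the new `read` body: `map_async` only registers the callback. On native targets `device.poll(wgpu::Maintain::Wait)` blocks until the queue is drained and the callback has fired, so the channel already holds a value when `recv().await` runs and the future resolves immediately. On the web backend `poll` is effectively a no-op and the browser drives the device, which is why an awaitable channel (rather than a blocking oneshot receive) is needed: the single `async move` block serves both targets with no `cfg` split.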
@@ -25,9 +25,6 @@ tui = ["burn-train?/tui"]
 ## Includes system info metrics (CPU/GPU usage, etc)
 metrics = ["burn-train?/metrics"]

-# Useful when targeting WASM and not using WGPU.
-wasm-sync = ["burn-core/wasm-sync"]
-
 # Datasets
 dataset = ["burn-core/dataset"]

@@ -73,7 +73,6 @@
 //! - `blas-netlib`: If supported, Blas Netlib will be used
 //! - `openblas`: If supported, OpenBLAS will be used
 //! - `openblas-system`: If supported, OpenBLAS installed on the system will be used
-//! - `wasm-sync`: When targeting wasm, but want a sync API (won't work with WGPU)
 //! - `autotune`: Enable running benchmarks to select the best kernel in backends that support it.
 //! - `fusion`: Enable operation fusion in backends that support it.
 //! - Backend decorators
@@ -28,7 +28,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
         ArrayHandle::new(&output_handle, input.len()),
     );

-    let output = client.read(output_handle.binding()).read_sync().unwrap();
+    let output = client.read(output_handle.binding());
    let output = f32::from_bytes(&output);

    // Should be [-0.1587, 0.0000, 0.8413, 5.0000]
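With `read_sync()` gone from this example, `client.read` hands back the bytes directly in the native context; presumably the client resolves the returned future internally (pollster remains a dependency in the Cargo.toml hunk above), while wasm callers go through the async path.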
@@ -150,19 +150,13 @@ impl<B: Backend> Model<B> {
         // Convert the model output into probability distribution using softmax formula
         let probabilities = softmax(output, 1);

-        #[cfg(not(target_family = "wasm"))]
-        let result = probabilities.into_data().convert::<f32>().to_vec().unwrap();
-
-        // Forces the result to be computed
-        #[cfg(target_family = "wasm")]
-        let result = probabilities
-            .into_data()
+        probabilities
+            .into_data_async()
             .await
             .convert::<f32>()
             .to_vec()
-            .unwrap();
-
-        result
+            .unwrap()
     }
 }
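The cfg-split readback collapses into one call chain that both targets share. A sketch of the resulting pattern in isolation, using only calls visible in this hunk (`into_data_async`, `convert`, `to_vec`):

    use burn::tensor::{backend::Backend, Tensor};

    // One async readback path for native and wasm alike.
    async fn probabilities_vec<B: Backend>(probabilities: Tensor<B, 2>) -> Vec<f32> {
        probabilities
            .into_data_async()
            .await
            .convert::<f32>()
            .to_vec()
            .unwrap()
    }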
@@ -24,6 +24,3 @@ console_error_panic_hook = { workspace = true }
 wasm-bindgen = "0.2"
 wasm-bindgen-futures = "0.4"
 js-sys = "0.3"
-
-[dev-dependencies]
-pollster = { workspace = true }
@@ -68,11 +68,7 @@ impl Mnist {
         let output = burn::tensor::activation::softmax(output, 1);

         // Flatten output tensor with [1, 10] shape into boxed slice of [f32]
-        #[cfg(not(target_family = "wasm"))]
-        let output = output.into_data();
-
-        #[cfg(target_family = "wasm")]
-        let output = output.into_data().await;
+        let output = output.into_data_async().await;

         let array = Array::new();
         for value in output.iter::<f32>() {
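The truncated context above continues into JS interop. A hedged sketch of that step, assuming `js_sys` and `wasm_bindgen` as declared in the example's Cargo.toml hunk (the loop body here is illustrative, not quoted from the file):

    // Copy the f32 probabilities into a js_sys::Array to hand back to JS.
    let array = js_sys::Array::new();
    for value in output.iter::<f32>() {
        array.push(&wasm_bindgen::JsValue::from(value));
    }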