Feat/async read (#833)

Nathaniel Simard 2023-09-28 17:09:58 -04:00 committed by GitHub
parent aa90fe8efb
commit ca787d6446
85 changed files with 799 additions and 567 deletions

View File

@ -63,6 +63,7 @@ thiserror = "1.0.40"
tracing-subscriber = "0.3.17"
tracing-core = "0.1.31"
tracing-appender = "0.2.2"
async-trait = "0.1.73"
# WGPU stuff
futures-intrusive = "0.5"

View File

@ -16,7 +16,7 @@ export_tests = ["burn-tensor-testgen"]
[dependencies]
burn-common = {path = "../burn-common", version = "0.10.0" }
burn-tensor = {path = "../burn-tensor", version = "0.10.0" }
burn-tensor = {path = "../burn-tensor", version = "0.10.0", default-features = false }
burn-tensor-testgen = {path = "../burn-tensor-testgen", version = "0.10.0", optional = true}
derive-new = {workspace = true}

View File

@ -10,6 +10,8 @@
#[macro_use]
extern crate derive_new;
extern crate alloc;
/// Gradients module.
pub mod grads;
/// Operation module.

View File

@ -3,7 +3,7 @@ use crate::{
ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape};
impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
@ -14,11 +14,11 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
B::bool_shape(tensor)
}
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<Data<bool, D>> {
B::bool_to_data(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<Data<bool, D>> {
B::bool_into_data(tensor)
}

View File

@ -3,7 +3,7 @@ use crate::{
ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Reader, Shape};
impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn int_from_data<const D: usize>(
@ -17,11 +17,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::int_shape(tensor)
}
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Data<B::IntElem, D> {
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
B::int_to_data(tensor)
}
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Data<B::IntElem, D> {
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
B::int_into_data(tensor)
}

View File

@ -9,7 +9,9 @@ use crate::{
ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, Tensor};
use burn_tensor::{
backend::Backend, ops::TensorOps, Data, ElementConversion, Reader, Shape, Tensor,
};
use super::maxmin::MaxMinDim;
@ -41,11 +43,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::shape(&tensor.primitive)
}
fn to_data<const D: usize>(tensor: &ADTensor<B, D>) -> Data<FloatElem<B>, D> {
fn to_data<const D: usize>(tensor: &ADTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
B::to_data(&tensor.primitive)
}
fn into_data<const D: usize>(tensor: ADTensor<B, D>) -> Data<FloatElem<B>, D> {
fn into_data<const D: usize>(tensor: ADTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
B::into_data(tensor.primitive)
}

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Data, Shape};
use burn_tensor::{backend::Backend, Data, Reader, Shape};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
@ -36,7 +36,7 @@ pub fn from_data<E: CandleElement, const D: usize>(
CandleTensor::from_data(data, *device)
}
pub fn to_data<E: CandleElement, const D: usize>(tensor: &CandleTensor<E, D>) -> Data<E, D> {
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> Data<E, D> {
Data::new(
tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(),
tensor.shape(),

View File

@ -1,4 +1,4 @@
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
use burn_tensor::{ops::BoolTensorOps, Data, Reader, Shape};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
@ -18,10 +18,12 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<CandleBackend<F,
super::base::shape(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<Data<bool, D>> {
let x: Vec<u8> = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap();
let y = x.iter().map(|b| !matches!(b, 0)).collect();
Data::new(y, tensor.shape())
let data = Data::new(y, tensor.shape());
Reader::Concrete(data)
}
fn bool_from_data<const D: usize>(

View File

@ -1,4 +1,4 @@
use burn_tensor::{ops::IntTensorOps, Bool, Data, Shape};
use burn_tensor::{ops::IntTensorOps, Bool, Data, Reader, Shape};
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
@ -18,8 +18,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<CandleBackend<F, I
super::base::shape(tensor)
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<IntElem<Self>, D> {
super::base::to_data(&tensor)
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<Data<IntElem<Self>, D>> {
Reader::Concrete(super::base::into_data(tensor))
}
fn int_from_data<const D: usize>(

View File

@ -1,6 +1,6 @@
use std::borrow::Borrow;
use burn_tensor::{ops::TensorOps, Data, Distribution, ElementConversion, Shape};
use burn_tensor::{ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape};
use candle_core::{backend::BackendStorage, shape, Tensor};
use crate::{
@ -54,8 +54,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
super::base::shape(tensor)
}
fn to_data<const D: usize>(tensor: &CandleTensor<F, D>) -> Data<F, D> {
super::base::to_data(tensor)
fn into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<Data<F, D>> {
Reader::Concrete(super::base::into_data(tensor))
}
fn device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {

View File

@ -15,14 +15,19 @@ default = ["std"]
std = ["rand/std"]
[dependencies]
[target.'cfg(target_family = "wasm")'.dependencies]
async-trait = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
const-random = { workspace = true }
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
spin = { workspace = true } # using in place of use std::sync::Mutex;
uuid = { workspace = true }
derive-new = { workspace = true }
[dev-dependencies]
dashmap = { workspace = true }

View File

@ -5,6 +5,9 @@
//!
//! This library contains common types used by other Burn crates that must be shared.
#[macro_use]
extern crate derive_new;
/// Id module contains types for unique identifiers.
pub mod id;
@ -15,4 +18,8 @@ pub mod rand;
/// Stub module contains types for stubs for non-std environments and for std environments.
pub mod stub;
/// Useful when you need to read async data without having to decorate each function with async
/// notation.
pub mod reader;
extern crate alloc;

burn-common/src/reader.rs (new file, +115 lines)
View File

@ -0,0 +1,115 @@
use alloc::boxed::Box;
use core::marker::PhantomData;
#[cfg(target_family = "wasm")]
#[async_trait::async_trait]
/// Allows creating an async reader.
pub trait AsyncReader<T>: Send {
/// Read asynchronously.
async fn read(self: Box<Self>) -> T;
}
/// Define how data is read, sync or async.
pub enum Reader<T> {
/// Concrete variant.
Concrete(T),
/// Sync data variant.
Sync(Box<dyn SyncReader<T>>),
#[cfg(target_family = "wasm")]
/// Async data variant.
Async(Box<dyn AsyncReader<T>>),
#[cfg(target_family = "wasm")]
/// Future data variant.
Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
}
/// Allows creating a sync reader.
pub trait SyncReader<T>: Send {
/// Read synchronously.
fn read(self: Box<Self>) -> T;
}
#[derive(new)]
struct MappedReader<I, O, F> {
reader: Reader<I>,
mapper: F,
_output: PhantomData<O>,
}
impl<I, O, F> SyncReader<O> for MappedReader<I, O, F>
where
I: Send,
O: Send,
F: Send + FnOnce(I) -> O,
{
fn read(self: Box<Self>) -> O {
let input = self
.reader
.read_sync()
.expect("Only sync data supported in a sync reader.");
(self.mapper)(input)
}
}
#[cfg(target_family = "wasm")]
#[async_trait::async_trait]
impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
where
I: Send,
O: Send,
F: Send + FnOnce(I) -> O,
{
async fn read(self: Box<Self>) -> O {
let input = self.reader.read().await;
(self.mapper)(input)
}
}
impl<T> Reader<T> {
#[cfg(target_family = "wasm")]
/// Read the data.
pub async fn read(self) -> T {
match self {
Self::Concrete(data) => data,
Self::Sync(reader) => reader.read(),
Self::Async(func) => func.read().await,
Self::Future(future) => future.await,
}
}
#[cfg(not(target_family = "wasm"))]
/// Read the data.
pub fn read(self) -> T {
match self {
Self::Concrete(data) => data,
Self::Sync(reader) => reader.read(),
}
}
/// Read the data only if sync, returns None if an async reader.
pub fn read_sync(self) -> Option<T> {
match self {
Self::Concrete(data) => Some(data),
Self::Sync(reader) => Some(reader.read()),
#[cfg(target_family = "wasm")]
Self::Async(_func) => return None,
#[cfg(target_family = "wasm")]
Self::Future(_future) => return None,
}
}
/// Map the current reader to another type.
pub fn map<O, F: FnOnce(T) -> O>(self, mapper: F) -> Reader<O>
where
T: 'static + Send,
O: 'static + Send,
F: 'static + Send,
{
#[cfg(target_family = "wasm")]
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
#[cfg(not(target_family = "wasm"))]
Reader::Sync(Box::new(MappedReader::new(self, mapper)))
}
}
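A minimal usage sketch of the `Reader` API defined above, for a non-wasm target (the values and the `demo` function are illustrative, not part of this commit):

use burn_common::reader::Reader;

fn demo() {
    // Data that is already available is wrapped in the `Concrete` variant.
    let bytes: Reader<Vec<u8>> = Reader::Concrete(vec![1, 2, 3]);

    // `map` defers the transformation; on non-wasm targets it yields a `Reader::Sync`.
    let sum: Reader<u32> = bytes.map(|b| b.iter().map(|&v| v as u32).sum());

    // `read_sync` returns `Some` because nothing async is involved here.
    assert_eq!(sum.read_sync(), Some(6));
}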

View File

@ -1,11 +1,12 @@
use crate::server::{ComputeServer, Handle};
use alloc::vec::Vec;
use burn_common::reader::Reader;
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
/// Given a handle, returns owned resource as bytes
fn read(&self, handle: &Handle<Server>) -> Vec<u8>;
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>>;
/// Given a resource as bytes, stores it and returns the resource handle
fn create(&self, data: &[u8]) -> Handle<Server>;

View File

@ -2,6 +2,7 @@ use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
///
@ -40,7 +41,7 @@ impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
let mut server = self.server.borrow_mut();
server.read(handle)

View File

@ -6,9 +6,9 @@ mod mutex;
#[cfg(feature = "channel-mutex")]
pub use mutex::*;
#[cfg(feature = "channel-mpsc")]
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
mod mpsc;
#[cfg(feature = "channel-mpsc")]
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
pub use mpsc::*;
#[cfg(feature = "channel-cell")]

View File

@ -3,6 +3,8 @@ use std::{
thread,
};
use burn_common::reader::Reader;
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
@ -31,7 +33,7 @@ enum Message<Server>
where
Server: ComputeServer,
{
Read(Handle<Server>, Callback<Vec<u8>>),
Read(Handle<Server>, Callback<Reader<Vec<u8>>>),
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
Execute(Server::Kernel, Vec<Handle<Server>>),
@ -91,7 +93,7 @@ impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
let (callback, response) = mpsc::sync_channel(1);
self.state

View File

@ -2,6 +2,7 @@ use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
use spin::Mutex;
/// The MutexComputeChannel ensures thread-safety by locking the server
@ -34,7 +35,7 @@ impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
self.server.lock().read(handle)
}

View File

@ -3,6 +3,7 @@ use crate::{
server::{ComputeServer, Handle},
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::marker::PhantomData;
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
@ -40,7 +41,7 @@ where
}
/// Given a handle, returns owned resource as bytes.
pub fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
pub fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
self.channel.read(handle)
}
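// Hedged usage sketch (not part of this diff): callers now unwrap the returned
// reader explicitly instead of receiving the bytes directly, e.g.
//
//     let bytes = client.read(&handle).read();        // native targets: synchronous
//     let bytes = client.read(&handle).read().await;  // wasm targets: asynchronous
//
// or `client.read(&handle).read_sync()` when the data is known to be concrete.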

View File

@ -3,6 +3,41 @@ use crate::{
storage::ComputeStorage,
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
where
Self: Sized,
{
/// The kernel type defines the computation algorithms.
type Kernel: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
type Storage: ComputeStorage;
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
type MemoryManagement: MemoryManagement<Self::Storage>;
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
/// Given a resource as bytes, stores it and returns the memory handle.
fn create(&mut self, data: &[u8]) -> Handle<Self>;
/// Reserves `size` bytes in the storage, and returns a handle over them.
fn empty(&mut self, size: usize) -> Handle<Self>;
/// Executes the `kernel` over the given memory `handles`.
///
/// Kernels have mutable access to every resource they are given
/// and are responsible for determining which should be read or written.
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
/// Wait for the completion of every task in the server.
fn sync(&mut self);
}
/// Server handle containing the [memory handle](MemoryManagement::Handle).
#[derive(new, Debug)]
@ -25,37 +60,3 @@ impl<Server: ComputeServer> Clone for Handle<Server> {
}
}
}
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
where
Self: Sized,
{
/// The kernel type defines the computation algorithms.
type Kernel: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
type Storage: ComputeStorage;
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
type MemoryManagement: MemoryManagement<Self::Storage>;
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8>;
/// Given a resource as bytes, stores it and returns the memory handle.
fn create(&mut self, data: &[u8]) -> Handle<Self>;
/// Reserves `size` bytes in the storage, and returns a handle over them.
fn empty(&mut self, size: usize) -> Handle<Self>;
/// Executes the `kernel` over the given memory `handles`.
///
/// Kernels have mutable access to every resource they are given
/// and are responsible for determining which should be read or written.
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
/// Wait for the completion of every task in the server.
fn sync(&mut self);
}

View File

@ -14,7 +14,7 @@ impl core::fmt::Debug for BytesStorage {
}
}
/// Can send to other threads, but can't sync.
/// Can send to other threads.
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}

View File

@ -1,3 +1,4 @@
use burn_common::reader::Reader;
use burn_compute::{
memory_management::{MemoryManagement, SimpleMemoryManagement},
server::{ComputeServer, Handle},
@ -22,10 +23,10 @@ where
type Storage = BytesStorage;
type MemoryManagement = MM;
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
let bytes = self.memory_management.get(&handle.memory);
bytes.read().to_vec()
Reader::Concrete(bytes.read().to_vec())
}
fn create(&mut self, data: &[u8]) -> Handle<Self> {

View File

@ -10,7 +10,7 @@ fn created_resource_is_the_same_when_read() {
let obtained_resource = client.read(&resource_description);
assert_eq!(resource, obtained_resource)
assert_eq!(resource, obtained_resource.read())
}
#[test]
@ -20,7 +20,7 @@ fn empty_allocates_memory() {
let resource_description = client.empty(size);
let empty_resource = client.read(&resource_description);
assert_eq!(empty_resource.len(), 4);
assert_eq!(empty_resource.read().len(), 4);
}
#[test]
@ -34,5 +34,5 @@ fn execute_elementwise_addition() {
let obtained_resource = client.read(&out);
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]))
}

View File

@ -25,7 +25,6 @@ std = [
"half/std",
"derive-new/std",
]
dataset = ["burn-dataset/default"]
dataset-minimal = ["burn-dataset"]
dataset-sqlite = ["burn-dataset/sqlite"]
@ -42,7 +41,7 @@ ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas-system"]
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
wgpu = ["burn-wgpu"]
wgpu = ["burn-wgpu/default"]
tch = ["burn-tch"]

View File

@ -1,7 +1,7 @@
use crate as burn;
use crate::{config::Config, tensor::Tensor};
use burn_tensor::{backend::Backend, ElementConversion};
use burn_tensor::backend::Backend;
/// Gradient Clipping provides a way to mitigate exploding gradients
#[derive(Config)]
@ -68,20 +68,26 @@ impl GradientClipping {
clipped_grad.mask_fill(lower_mask, -threshold)
}
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
let squared = tensor.powf(2.0);
let sum = squared.sum();
sum.sqrt()
#[cfg(target_family = "wasm")]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
_grad: Tensor<B, D>,
_threshold: f32,
) -> Tensor<B, D> {
todo!("Not yet supported on wasm");
}
#[cfg(not(target_family = "wasm"))]
fn clip_by_norm<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
threshold: f32,
) -> Tensor<B, D> {
use burn_tensor::ElementConversion;
let norm = Self::l2_norm(grad.clone());
let norm_float = norm.into_scalar().elem::<f32>();
if norm_float > threshold {
let scale = threshold / norm_float;
grad.mul_scalar(scale)
@ -89,6 +95,14 @@ impl GradientClipping {
grad
}
}
#[cfg(not(target_family = "wasm"))]
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
let squared = tensor.powf(2.0);
let sum = squared.sum();
sum.sqrt()
}
}
#[cfg(test)]
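For reference, a sketch of the math implemented by `clip_by_norm` above, for a gradient tensor $g$ and threshold $\tau$:

$$
\hat{g} = \begin{cases} g \cdot \dfrac{\tau}{\lVert g \rVert_2} & \text{if } \lVert g \rVert_2 > \tau \\ g & \text{otherwise} \end{cases},
\qquad \lVert g \rVert_2 = \sqrt{\textstyle\sum_i g_i^2}.
$$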

View File

@ -100,7 +100,7 @@ pub struct ParamSerde<T> {
}
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<B, D, S>>;
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item())

View File

@ -5,6 +5,7 @@ use alloc::string::{String, ToString};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
#[cfg(feature = "std")]
use super::{
BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,

View File

@ -5,136 +5,131 @@ use serde::{Deserialize, Serialize};
/// This struct implements serde to lazily serialize and deserialize a float tensor
/// using the given [record settings](RecordSettings).
#[derive(new, Clone, Debug)]
pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
tensor: Tensor<B, D>,
elem: core::marker::PhantomData<S>,
pub struct FloatTensorSerde<S: PrecisionSettings> {
data: DataSerialize<S::FloatElem>,
}
/// This struct implements serde to lazily serialize and deserialize an int tensor
/// using the given [record settings](RecordSettings).
#[derive(new, Clone, Debug)]
pub struct IntTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
tensor: Tensor<B, D, Int>,
elem: core::marker::PhantomData<S>,
pub struct IntTensorSerde<S: PrecisionSettings> {
data: DataSerialize<S::IntElem>,
}
/// This struct implements serde to lazily serialize and deserialize a bool tensor.
#[derive(new, Clone, Debug)]
pub struct BoolTensorSerde<B: Backend, const D: usize> {
tensor: Tensor<B, D, Bool>,
pub struct BoolTensorSerde {
data: DataSerialize<bool>,
}
// --- SERDE IMPLEMENTATIONS --- //
impl<B: Backend, const D: usize, S: PrecisionSettings> Serialize for FloatTensorSerde<B, D, S> {
impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
where
Se: serde::Serializer,
{
self.tensor
.to_data()
.convert::<S::FloatElem>()
.serialize()
.serialize(serializer)
self.data.serialize(serializer)
}
}
impl<'de, B: Backend, const D: usize, S: PrecisionSettings> Deserialize<'de>
for FloatTensorSerde<B, D, S>
{
impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
where
De: serde::Deserializer<'de>,
{
let data = DataSerialize::<S::FloatElem>::deserialize(deserializer)?;
let tensor = Tensor::from_data(data.convert::<B::FloatElem>());
Ok(Self::new(tensor))
Ok(Self::new(data))
}
}
impl<B: Backend, const D: usize, S: PrecisionSettings> Serialize for IntTensorSerde<B, D, S> {
// #[cfg(not(target_family = "wasm"))]
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
where
Se: serde::Serializer,
{
self.tensor
.to_data()
.convert::<S::IntElem>()
.serialize()
.serialize(serializer)
self.data.serialize(serializer)
}
}
impl<'de, B: Backend, const D: usize, S: PrecisionSettings> Deserialize<'de>
for IntTensorSerde<B, D, S>
{
impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
where
De: serde::Deserializer<'de>,
{
let data = DataSerialize::<S::IntElem>::deserialize(deserializer)?;
let tensor = Tensor::from_data(data.convert::<B::IntElem>());
Ok(Self::new(tensor))
Ok(Self::new(data))
}
}
impl<B: Backend, const D: usize> Serialize for BoolTensorSerde<B, D> {
impl Serialize for BoolTensorSerde {
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
where
Se: serde::Serializer,
{
self.tensor.to_data().serialize().serialize(serializer)
self.data.serialize(serializer)
}
}
impl<'de, B: Backend, const D: usize> Deserialize<'de> for BoolTensorSerde<B, D> {
impl<'de> Deserialize<'de> for BoolTensorSerde {
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
where
De: serde::Deserializer<'de>,
{
let data = DataSerialize::<bool>::deserialize(deserializer)?;
let tensor = Tensor::from_data(data);
Ok(Self::new(tensor))
Ok(Self::new(data))
}
}
// --- RECORD IMPLEMENTATIONS --- //
impl<B: Backend, const D: usize> Record for Tensor<B, D> {
type Item<S: PrecisionSettings> = FloatTensorSerde<B, D, S>;
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
FloatTensorSerde::new(self)
#[cfg(target_family = "wasm")]
todo!("Recording float tensors isn't yet supported on wasm.");
#[cfg(not(target_family = "wasm"))]
FloatTensorSerde::new(self.into_data().convert().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.tensor
Tensor::from_data(item.data.convert::<B::FloatElem>())
}
}
impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
type Item<S: PrecisionSettings> = IntTensorSerde<B, D, S>;
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
IntTensorSerde::new(self)
#[cfg(target_family = "wasm")]
todo!("Recording int tensors isn't yet supported on wasm.");
#[cfg(not(target_family = "wasm"))]
IntTensorSerde::new(self.into_data().convert().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.tensor
Tensor::from_data(item.data.convert())
}
}
impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
type Item<S: PrecisionSettings> = BoolTensorSerde<B, D>;
type Item<S: PrecisionSettings> = BoolTensorSerde;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
BoolTensorSerde::new(self)
#[cfg(target_family = "wasm")]
todo!("Recording bool tensors isn't yet supported on wasm.");
#[cfg(not(target_family = "wasm"))]
BoolTensorSerde::new(self.into_data().serialize())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.tensor
Tensor::from_data(item.data)
}
}

View File

@ -2,7 +2,7 @@
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
use burn_tensor::ElementConversion;
use burn_tensor::{ElementConversion, Reader};
use core::ops::Range;
// Current crate
@ -29,19 +29,13 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
tensor.shape()
}
fn bool_to_data<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let values = tensor.array.iter().map(Clone::clone).collect();
Data::new(values, tensor.shape())
}
fn bool_into_data<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
) -> Reader<Data<bool, D>> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Data::new(values, shape)
Reader::Concrete(Data::new(values, shape))
}
fn bool_to_device<const D: usize>(
@ -68,7 +62,9 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
fn bool_into_int<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> NdArrayTensor<i64, D> {
let data = Self::bool_into_data(tensor);
let data = Self::bool_into_data(tensor)
.read_sync()
.expect("Always sync with ndarray");
NdArrayBackend::<E>::int_from_data(data.convert(), &NdArrayDevice::Cpu)
}

View File

@ -2,6 +2,8 @@
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::IntTensorOps;
use burn_tensor::Reader;
use burn_tensor::ElementConversion;
use core::ops::Range;
@ -28,15 +30,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
tensor.shape()
}
fn int_to_data<const D: usize>(tensor: &NdArrayTensor<i64, D>) -> Data<i64, D> {
let values = tensor.array.iter().map(Clone::clone).collect();
Data::new(values, tensor.shape())
}
fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Data<i64, D> {
fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Reader<Data<i64, D>> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Data::new(values, shape)
Reader::Concrete(Data::new(values, shape))
}
fn int_to_device<const D: usize>(

View File

@ -10,8 +10,8 @@ use crate::{NdArrayDevice, SEED};
// Workspace crates
use burn_common::rand::get_seeded_rng;
use burn_tensor::Distribution;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use burn_tensor::{Distribution, Reader};
// External crates
use libm::{cos, erf, sin, tanh};
@ -45,19 +45,13 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
tensor.shape()
}
fn to_data<const D: usize>(
tensor: &NdArrayTensor<E, D>,
) -> Data<<NdArrayBackend<E> as Backend>::FloatElem, D> {
let values = tensor.array.iter().map(Clone::clone).collect();
Data::new(values, tensor.shape())
}
fn into_data<const D: usize>(
tensor: NdArrayTensor<E, D>,
) -> Data<<NdArrayBackend<E> as Backend>::FloatElem, D> {
) -> Reader<Data<<NdArrayBackend<E> as Backend>::FloatElem, D>> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Data::new(values, shape)
Reader::Concrete(Data::new(values, shape))
}
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {

View File

@ -16,8 +16,7 @@ impl<E, const D: usize> NdArrayTensor<E, D> {
#[cfg(test)]
mod utils {
use super::*;
use crate::{element::FloatNdArrayElement, NdArrayBackend};
use burn_tensor::ops::TensorOps;
use crate::element::FloatNdArrayElement;
impl<E, const D: usize> NdArrayTensor<E, D>
where
@ -27,7 +26,10 @@ mod utils {
where
E: FloatNdArrayElement,
{
<NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
let shape = self.shape();
let values = self.array.into_iter().collect();
Data::new(values, shape)
}
}
}

View File

@ -1,6 +1,6 @@
use std::ops::Range;
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape};
use crate::{element::TchElement, TchBackend, TchDevice, TchTensor};
@ -26,16 +26,12 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::repeat(tensor, dim, times)
}
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
let shape = Self::bool_shape(tensor);
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Reader<Data<bool, D>> {
let shape = Self::bool_shape(&tensor);
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
}
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
Self::bool_to_data(&tensor)
Reader::Concrete(Data::new(values.unwrap(), shape))
}
fn bool_to_device<const D: usize>(

View File

@ -1,6 +1,6 @@
use std::ops::Range;
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Reader, Shape};
use crate::{element::TchElement, TchBackend, TchDevice, TchShape, TchTensor};
@ -23,16 +23,12 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::repeat(tensor, dim, times)
}
fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> {
let shape = Self::int_shape(tensor);
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Reader<Data<i64, D>> {
let shape = Self::int_shape(&tensor);
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
}
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> {
Self::int_to_data(&tensor)
Reader::Concrete(Data::new(values.unwrap(), shape))
}
fn int_to_device<const D: usize>(

View File

@ -1,6 +1,8 @@
use super::TchOps;
use crate::{element::TchElement, TchBackend, TchDevice, TchShape, TchTensor};
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
use burn_tensor::{
backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape,
};
use std::ops::Range;
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
@ -79,20 +81,14 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
tensor.shape()
}
fn to_data<const D: usize>(
tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>,
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
let shape = Self::shape(tensor);
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
Data::new(values.unwrap(), shape)
}
fn into_data<const D: usize>(
tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>,
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
Self::to_data(&tensor)
) -> Reader<Data<<TchBackend<E> as Backend>::FloatElem, D>> {
let shape = Self::shape(&tensor);
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.try_into();
Reader::Concrete(Data::new(values.unwrap(), shape))
}
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {

View File

@ -200,7 +200,7 @@ mod utils {
where
P: tch::kind::Element,
{
<TchBackend<P> as TensorOps<TchBackend<P>>>::into_data(self)
<TchBackend<P> as TensorOps<TchBackend<P>>>::into_data(self).read()
}
}
}

View File

@ -18,6 +18,7 @@ std = ["rand/std", "half/std"]
benchmark = []
[dependencies]
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.10.0", optional = true }
derive-new = { workspace = true }

View File

@ -18,6 +18,9 @@ mod tests;
pub use half::{bf16, f16};
pub use tensor::*;
pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as
// a dependency so that they can implement the traits.
#[cfg(feature = "benchmark")]
/// This module provides benchmark utilities for easily and reliably run
/// benches on any function that is generic over a backend.

View File

@ -1,8 +1,15 @@
#![allow(clippy::single_range_in_vec_init)]
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
#[cfg(not(target_family = "wasm"))]
use alloc::format;
#[cfg(not(target_family = "wasm"))]
use alloc::string::String;
#[cfg(not(target_family = "wasm"))]
use alloc::vec;
use burn_common::reader::Reader;
use core::{fmt::Debug, ops::Range};
use crate::{
@ -222,8 +229,6 @@ where
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
///
/// # Panics
///
/// If the output size is higher than the current tensor.
///
/// # Example
@ -320,11 +325,25 @@ where
Self::new(K::to_device(self.primitive, device))
}
#[cfg(target_family = "wasm")]
/// Returns the data of the current tensor.
pub fn into_data(self) -> Data<K::Elem, D> {
K::into_data(self.primitive)
pub async fn into_data(self) -> Data<K::Elem, D> {
K::into_data(self.primitive).read().await
}
#[cfg(not(target_family = "wasm"))]
/// Returns the data of the current tensor.
pub fn into_data(self) -> Data<K::Elem, D> {
K::into_data(self.primitive).read()
}
#[cfg(target_family = "wasm")]
/// Returns the data of the current tensor.
pub async fn to_data(&self) -> Data<K::Elem, D> {
K::into_data(self.primitive.clone()).read().await
}
#[cfg(not(target_family = "wasm"))]
/// Returns the data of the current tensor without taking ownership.
pub fn to_data(&self) -> Data<K::Elem, D> {
Self::into_data(self.clone())
@ -448,6 +467,7 @@ where
K: BasicOps<B>,
<K as BasicOps<B>>::Elem: Debug,
{
#[cfg(not(target_family = "wasm"))]
/// Recursively formats the tensor data for display and appends it to the provided accumulator string.
///
/// This function is designed to work with tensors of any dimensionality.
@ -474,7 +494,8 @@ where
multi_index[depth] = i;
let range: [core::ops::Range<usize>; D] =
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
let elem = &self.clone().slice(range).to_data().value[0];
let elem = &self.clone().slice(range).into_data().value[0];
acc.push_str(&format!("{elem:?}"));
}
} else {
@ -506,13 +527,18 @@ where
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
writeln!(f, "Tensor {{")?;
write!(f, " data: ")?;
let mut acc = String::new();
let mut multi_index = vec![0; D];
self.display_recursive(&mut acc, 0, &mut multi_index);
write!(f, "{acc}")?;
writeln!(f, ",")?;
#[cfg(not(target_family = "wasm"))]
{
write!(f, " data: ")?;
let mut acc = String::new();
let mut multi_index = vec![0; D];
self.display_recursive(&mut acc, 0, &mut multi_index);
write!(f, "{acc}")?;
writeln!(f, ",")?;
}
writeln!(f, " shape: {:?},", self.dims())?;
writeln!(f, " device: {:?},", self.device())?;
writeln!(f, " backend: {:?},", B::name())?;
@ -755,7 +781,7 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
///
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
/// which is more high-level and designed for public use.
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D>;
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>>;
/// Creates a tensor from the given data.
///
@ -914,7 +940,7 @@ impl<B: Backend> BasicOps<B> for Float {
B::to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
B::into_data(tensor)
}
@ -1001,7 +1027,7 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
B::int_into_data(tensor)
}
@ -1088,7 +1114,7 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_to_device(tensor, device)
}
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
B::bool_into_data(tensor)
}
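A short sketch of what the cfg-gated accessors above mean for user code (`read_values` is an illustrative helper, not part of this commit):

use burn_tensor::{backend::Backend, Data, Tensor};

#[cfg(not(target_family = "wasm"))]
fn read_values<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Data<B::FloatElem, D> {
    // Native targets keep the synchronous API.
    tensor.into_data()
}

#[cfg(target_family = "wasm")]
async fn read_values<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Data<B::FloatElem, D> {
    // On wasm the same call returns a future and must be awaited.
    tensor.into_data().await
}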

View File

@ -9,6 +9,7 @@ where
K: Numeric<B>,
K::Elem: Element,
{
#[cfg(not(target_family = "wasm"))]
/// Convert the tensor into a scalar.
///
/// # Panics
@ -19,6 +20,19 @@ where
let data = self.into_data();
data.value[0]
}
#[cfg(target_family = "wasm")]
/// Convert the tensor into a scalar.
///
/// # Panics
///
/// If the tensor doesn't have one element.
pub async fn into_scalar(self) -> K::Elem {
check!(TensorCheck::into_scalar(&self.shape()));
let data = self.into_data().await;
data.value[0]
}
/// Applies element wise addition operation.
///
/// `y = x2 + x1`

View File

@ -1,7 +1,7 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{backend::Backend, tensor::Shape, Data};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
@ -39,7 +39,7 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Reader<Data<bool, D>>;
/// Gets the data from the tensor.
///
@ -51,7 +51,7 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// The data cloned from the data structure.
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Reader<Data<bool, D>> {
Self::bool_into_data(tensor.clone())
}

View File

@ -1,7 +1,7 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
/// for documentation on each function.
@ -38,7 +38,9 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn int_into_data<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> Data<B::IntElem, D>;
fn int_into_data<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
) -> Reader<Data<B::IntElem, D>>;
/// Gets the data from the tensor.
///
@ -49,7 +51,9 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The data cloned from the data structure.
fn int_to_data<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Data<B::IntElem, D> {
fn int_to_data<const D: usize>(
tensor: &B::IntTensorPrimitive<D>,
) -> Reader<Data<B::IntElem, D>> {
Self::int_into_data(tensor.clone())
}

View File

@ -1,7 +1,7 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
/// Operations on float tensors.
pub trait TensorOps<B: Backend> {
@ -104,7 +104,9 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Reader<Data<B::FloatElem, D>> {
Self::into_data(tensor.clone())
}
/// Converts the tensor to a data structure.
///
@ -115,9 +117,7 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The data structure with the tensor's data.
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
Self::to_data(&tensor)
}
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Reader<Data<B::FloatElem, D>>;
/// Gets the device of the tensor.
///

View File

@ -26,7 +26,7 @@ spin = { workspace = true }
# WGPU stuff
futures-intrusive = { workspace = true }
pollster = { workspace = true }
wgpu = { workspace = true }
wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
# Template
serde = { workspace = true }

View File

@ -7,7 +7,7 @@ use burn_compute::{
memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
Compute,
};
use wgpu::{DeviceDescriptor, DeviceType};
use wgpu::DeviceDescriptor;
type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
type Server = WgpuServer<MemoryManagement>;
@ -26,41 +26,58 @@ pub fn compute_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Serv
let device = Arc::new(device);
COMPUTE.client(&device, move || {
let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(&device));
log::info!(
"Created wgpu compute server on device {:?} => {:?}",
device,
info
);
// TODO: Support a way to modify max_tasks without std.
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
Ok(value) => value
.parse::<usize>()
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
Err(_) => 64, // 64 tasks by default
};
let device = Arc::new(device_wgpu);
let storage = WgpuStorage::new(device.clone());
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(1000),
SliceStrategy::Ratio(0.9),
);
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
let channel = Channel::new(server);
ComputeClient::new(channel)
pollster::block_on(create_client::<G>(&device))
})
}
/// Init the client async, necessary for wasm.
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice) {
let device = Arc::new(device);
let client = create_client::<G>(&device).await;
COMPUTE.register(&device, client)
}
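// Hedged usage sketch (not part of this diff): on wasm, register the client
// asynchronously before any tensor operation triggers a blocking lookup.
// `G` stands for whichever `GraphicsApi` implementation the application uses,
// and `example_setup` is an illustrative name.
#[cfg(target_family = "wasm")]
pub async fn example_setup<G: GraphicsApi>(device: &WgpuDevice) {
    init_async::<G>(device).await;
}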
async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
let (device_wgpu, queue, info) = select_device::<G>(device).await;
log::info!(
"Created wgpu compute server on device {:?} => {:?}",
device,
info
);
// TODO: Support a way to modify max_tasks without std.
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
Ok(value) => value
.parse::<usize>()
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
Err(_) => 64, // 64 tasks by default
};
let device = Arc::new(device_wgpu);
let storage = WgpuStorage::new(device.clone());
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(1000),
SliceStrategy::Ratio(0.9),
);
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
let channel = Channel::new(server);
ComputeClient::new(channel)
}
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
pub async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
#[cfg(target_family = "wasm")]
let adapter = select_adapter::<G>(device).await;
#[cfg(not(target_family = "wasm"))]
let adapter = select_adapter::<G>(device);
let limits = adapter.limits();
let (device, queue) = adapter
@ -85,9 +102,21 @@ pub async fn select_device<G: GraphicsApi>(
(device, queue, adapter.get_info())
}
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
#[cfg(target_family = "wasm")]
async fn select_adapter<G: GraphicsApi>(_device: &WgpuDevice) -> wgpu::Adapter {
let instance = wgpu::Instance::default();
instance
.request_adapter(&wgpu::RequestAdapterOptionsBase::default())
.await
.unwrap()
}
#[cfg(not(target_family = "wasm"))]
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
use wgpu::DeviceType;
let instance = wgpu::Instance::default();
let mut adapters_other = Vec::new();
let mut adapters = Vec::new();

View File

@ -98,7 +98,7 @@ mod tests {
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
let data = client.read(&out);
let data = client.read(&out).read_sync().unwrap();
let output: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);

View File

@ -5,6 +5,7 @@ use burn_compute::{
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use burn_tensor::Reader;
use hashbrown::HashMap;
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
@ -175,17 +176,8 @@ where
}),
)
}
}
impl<MM> ComputeServer for WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
{
type Kernel = Box<dyn Kernel>;
type Storage = WgpuStorage;
type MemoryManagement = MM;
fn read(&mut self, handle: &server::Handle<Self>) -> Vec<u8> {
fn buffer_reader(&mut self, handle: &server::Handle<Self>) -> BufferReader {
// Register previous tasks before reading the buffer so that it is up to date.
self.register_tasks();
@ -209,7 +201,28 @@ where
self.submit();
let buffer_slice = buffer_dest.slice(..);
BufferReader::new(buffer_dest)
}
}
#[derive(new)]
struct BufferReader {
buffer: wgpu::Buffer,
}
impl BufferReader {
#[cfg(target_family = "wasm")]
async fn read(self, device: alloc::sync::Arc<wgpu::Device>) -> Vec<u8> {
self.read_async(&device).await
}
#[cfg(not(target_family = "wasm"))]
fn read(self, device: &wgpu::Device) -> Vec<u8> {
pollster::block_on(self.read_async(device))
}
async fn read_async(&self, device: &wgpu::Device) -> Vec<u8> {
let buffer_slice = self.buffer.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender
@ -217,21 +230,41 @@ where
.expect("Unable to send buffer slice result to async channel.")
});
self.device.poll(wgpu::Maintain::Wait);
device.poll(wgpu::Maintain::Wait);
let result = pollster::block_on(receiver.receive());
let result = receiver.receive().await;
if let Some(Ok(())) = result {
let data = buffer_slice.get_mapped_range();
let result = bytemuck::cast_slice(&data).to_vec();
drop(data);
buffer_dest.unmap();
self.buffer.unmap();
result
} else {
panic!("Unable to read buffer {:?}", result)
}
}
}
impl<MM> ComputeServer for WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
{
type Kernel = Box<dyn Kernel>;
type Storage = WgpuStorage;
type MemoryManagement = MM;
fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
#[cfg(target_family = "wasm")]
{
let future = self.buffer_reader(handle).read(self.device.clone());
return Reader::Future(Box::pin(future));
}
#[cfg(not(target_family = "wasm"))]
Reader::Concrete(self.buffer_reader(handle).read(&self.device))
}
/// When we create a new handle from existing data, we use custom allocations so that we don't
/// have to execute the current pending tasks.

View File

@ -1,7 +1,7 @@
use burn_tensor::Element;
/// The base element trait for the wgpu backend.
pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone
pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod
where
Self: Sized,
{

View File

@ -6,6 +6,11 @@ use crate::{
};
use std::marker::PhantomData;
#[cfg(target_family = "wasm")]
pub(crate) const WORKGROUP_DEFAULT: usize = 16;
#[cfg(not(target_family = "wasm"))]
pub(crate) const WORKGROUP_DEFAULT: usize = 32;
/// Static wgpu kernel to create a [source template](SourceTemplate).
pub trait StaticKernelSource: Send + 'static {
/// Source template for the kernel.
@ -49,8 +54,6 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
return tensor;
}
const WORKGROUP: usize = 32;
let num_elems = tensor.shape.num_elements();
let handle = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
@ -64,8 +67,11 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
tensor.client.execute(
Box::new(StaticKernel::<
KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP))),
KernelSettings<ContiguousRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
num_elems,
WORKGROUP_DEFAULT,
))),
&[&tensor.handle, &output.handle, &info_handle],
);

View File

@ -1,4 +1,6 @@
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource};
use super::{
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
};
use crate::compute::StaticKernel;
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use burn_tensor::Shape;
@ -54,7 +56,7 @@ pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: u
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise::<K, E, D, 32>(lhs, rhs)
binary_elemwise::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
}
/// Execute a binary kernel using the provided WORKGROUP.
@ -105,7 +107,7 @@ pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, co
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
binary_elemwise_inplace::<K, E, D, 32>(lhs, rhs)
binary_elemwise_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
}
/// Execute a binary inplace kernel using the provided WORKGROUP.

View File

@ -1,4 +1,4 @@
use super::{KernelSettings, SourceTemplate, StaticKernelSource};
use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
tensor::WgpuTensor,
@ -30,12 +30,17 @@ pub fn cast<InputElem: WgpuElement, OutputElem: WgpuElement, const D: usize>(
return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}
const WORKGROUP: usize = 32;
let num_elems = tensor.shape.num_elements();
let kernel = StaticKernel::<
KernelSettings<Cast<InputElem, OutputElem>, f32, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP));
KernelSettings<
Cast<InputElem, OutputElem>,
f32,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let handle = tensor
.client

View File

@ -6,14 +6,14 @@ use crate::{
tensor::WgpuTensor,
};
use super::WORKGROUP_DEFAULT;
kernel_wgsl!(Cat, "../template/cat.wgsl");
pub fn cat<E: WgpuElement, const D: usize>(
inputs: Vec<WgpuTensor<E, D>>,
dim: usize,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let first_input = inputs.get(0).unwrap();
let client = &first_input.client;
let mut shape_output = first_input.shape.clone();
@ -38,9 +38,12 @@ pub fn cat<E: WgpuElement, const D: usize>(
info.push(dim_cat_index as u32);
dim_cat_index += input.shape.dims[dim];
let info_buffer = client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<Cat, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
input.shape.num_elements(),
WORKGROUP_DEFAULT,
));
client.execute(
Box::new(kernel),

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource},
kernel::{
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -66,8 +68,6 @@ pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
lhs.assert_is_on_same_device(&rhs);
const WORKGROUP: usize = 32;
let mut shape_out = [0; D];
lhs.shape
.dims
@ -83,9 +83,10 @@ pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
@ -101,13 +102,12 @@ pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
const WORKGROUP: usize = 32;
lhs.assert_is_on_same_device(&rhs);
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
);
let info = build_info(&[&lhs, &rhs]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource},
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT},
kernel_wgsl,
tensor::WgpuTensor,
};
@ -58,14 +58,14 @@ pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
const WORKGROUP: usize = 32;
let num_elems = lhs.shape.num_elements();
let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);
@ -77,11 +77,10 @@ pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: u
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
const WORKGROUP: usize = 32;
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
);
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -19,8 +19,6 @@ pub(crate) fn conv2d<E: WgpuElement + Element>(
bias: Option<WgpuTensor<E, 1>>,
options: ConvOptions<2>,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let input = kernel::into_contiguous(input);
let weight = kernel::into_contiguous(weight);
let [batch_size, _, in_height, in_width] = input.shape.dims;
@ -64,9 +62,12 @@ pub(crate) fn conv2d<E: WgpuElement + Element>(
let info_handle = input.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<Conv2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
input.client.execute(
Box::new(kernel),

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -16,8 +16,6 @@ pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
bias: Option<WgpuTensor<E, 1>>,
options: ConvTransposeOptions<2>,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let input = kernel::into_contiguous(input);
let weight = kernel::into_contiguous(weight);
let [batch_size, _, in_height, in_width] = input.shape.dims;
@ -58,10 +56,9 @@ pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
let info_handle = input.client.create(bytemuck::cast_slice(&info));
let kernel =
StaticKernel::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
input.client.execute(
Box::new(kernel),
&[

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -14,8 +14,6 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
indices: WgpuTensor<I, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let shape_output = indices.shape.clone();
let num_elems = shape_output.num_elements();
let indices = kernel::into_contiguous(indices);
@ -25,9 +23,9 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
info.push(dim as u32);
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<Gather, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
tensor.client.execute(
Box::new(kernel),

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
tensor::WgpuTensor,
};
@ -14,8 +14,6 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
indices: WgpuTensor<I, D>,
value: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let indices = kernel::into_contiguous(indices);
let tensor = kernel::into_contiguous(tensor);
let value = kernel::into_contiguous(value);
@ -51,9 +49,12 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<Scatter, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
num_elems_per_workgroup,
WORKGROUP_DEFAULT,
));
tensor.client.execute(
Box::new(kernel),

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -18,8 +18,6 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
dim: usize,
indices: WgpuTensor<I, 1>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let mut output_shape = tensor.shape.clone();
output_shape.dims[dim] = indices.shape.dims[0];
@ -30,9 +28,9 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
info.push(dim as u32);
let info_handle = output.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<IndexSelect, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
tensor.client.execute(
Box::new(kernel),

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -19,8 +19,6 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
tensor: WgpuTensor<E, D1>,
indices: [Range<usize>; D2],
) -> WgpuTensor<E, D1> {
const WORKGROUP: usize = 32;
let mut dims = tensor.shape.dims;
for i in 0..D2 {
dims[i] = indices[i].end - indices[i].start;
@ -38,9 +36,9 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
let info_handle = output.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<IndexRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
tensor.client.execute(
Box::new(kernel),
@ -55,8 +53,6 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
indices: [Range<usize>; D2],
value: WgpuTensor<E, D1>,
) -> WgpuTensor<E, D1> {
const WORKGROUP: usize = 32;
let tensor = match tensor.can_mut() {
true => tensor,
false => tensor.copy(),
@ -72,8 +68,8 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<
KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP));
KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
tensor.client.execute(
Box::new(kernel),

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -15,8 +15,6 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
mask: WgpuTensor<u32, D>,
value: E,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let output = empty_device(
input.client.clone(),
@ -25,9 +23,9 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
);
let value_handle = output.client.create(E::as_bytes(&[value]));
let kernel = StaticKernel::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<MaskFill, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
@ -51,14 +49,11 @@ pub fn mask_fill_inplace<E: WgpuElement, const D: usize>(
mask: WgpuTensor<u32, D>,
value: E,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let value_handle = input.client.create(E::as_bytes(&[value]));
let kernel =
StaticKernel::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<MaskFillInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));

View File

@ -1,7 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -15,8 +15,6 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
mask: WgpuTensor<u32, D>,
value: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let output = empty_device(
input.client.clone(),
@ -24,9 +22,9 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
input.shape.clone(),
);
let kernel = StaticKernel::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<MaskWhere, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &value, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
@ -51,12 +49,12 @@ pub fn mask_where_inplace<E: WgpuElement, const D: usize>(
value: WgpuTensor<E, D>,
reverse: bool,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let kernel =
StaticKernel::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<MaskWhereInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
input.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let mut info = build_info(&[&input, &value, &mask]);
info.push(match reverse {

View File

@ -6,6 +6,7 @@ use crate::{
element::WgpuElement,
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
WORKGROUP_DEFAULT,
},
kernel_wgsl,
ops::numeric::empty_device,
@ -43,7 +44,7 @@ pub fn matmul_mem_coalescing_default<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
matmul_mem_coalescing::<E, D>(lhs, rhs, 16, 16)
matmul_mem_coalescing::<E, D>(lhs, rhs, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
}
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes

View File

@ -1,7 +1,7 @@
use crate::{
compute::{StaticKernel, WgpuHandle},
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings},
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -21,17 +21,17 @@ pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
x: WgpuTensor<E, 4>,
output_size: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let [batch_size, channels, _, _] = x.shape.dims;
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
let kernel =
StaticKernel::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let info_handle = build_info(&x, &output);
x.client
@ -44,8 +44,6 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
x: WgpuTensor<E, 4>,
out_grad: WgpuTensor<E, 4>,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let output_shape = x.shape.clone();
let num_elems = output_shape.num_elements();
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
@ -57,8 +55,11 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
);
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
let info_handle = build_info(&x, &out_grad);

View File

@ -4,7 +4,7 @@ use crate::{
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings, StaticKernelSource,
KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
},
kernel_wgsl,
ops::numeric::empty_device,
@ -39,18 +39,16 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
padding: [usize; 2],
count_include_pad: bool,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup)),
};
@ -68,19 +66,31 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
padding: [usize; 2],
count_include_pad: bool,
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let grad = kernel::into_contiguous(grad);
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<AvgPool2dBackward<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
KernelSettings<
AvgPool2dBackward<true>,
E,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<AvgPool2dBackward<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
KernelSettings<
AvgPool2dBackward<false>,
E,
i32,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
1,
>,
>::new(workgroup)),
};

View File

@ -4,7 +4,7 @@ use crate::{
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings,
KernelSettings, WORKGROUP_DEFAULT,
},
kernel_wgsl,
ops::numeric::empty_device,
@ -28,13 +28,14 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
padding: [usize; 2],
dilation: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let kernel = StaticKernel::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
);
let kernel = StaticKernel::<
KernelSettings<MaxPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
x.client
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
@ -49,15 +50,16 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
padding: [usize; 2],
dilation: [usize; 2],
) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
const WORKGROUP: usize = 32;
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let indices = empty_device(x.client.clone(), x.device, output.shape.clone());
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
x.client.execute(
Box::new(kernel),
@ -76,8 +78,6 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
padding: [usize; 2],
dilation: [usize; 2],
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let grad = kernel::into_contiguous(grad);
let indices = kernel::into_contiguous(indices);
@ -88,8 +88,11 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
output.shape.num_elements(),
WORKGROUP_DEFAULT,
));
x.client.execute(
Box::new(kernel),

View File

@ -5,7 +5,7 @@ use crate::{
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -34,18 +34,16 @@ pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
device: &WgpuDevice,
prob: E,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128;
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[prob]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel =
StaticKernel::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<BernoulliPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),

View File

@ -5,7 +5,7 @@ use crate::{
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -37,16 +37,16 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
mean: E,
std: E,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128; // must be even
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[mean, std]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel =
StaticKernel::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(workgroup);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<NormalPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),

View File

@ -5,7 +5,7 @@ use crate::{
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
},
ops::numeric::empty_device,
tensor::WgpuTensor,
@ -32,17 +32,16 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
low: E,
high: E,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128;
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[low, high]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<
KernelSettings<UniformPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
client.execute(
Box::new(kernel),

View File

@ -1,4 +1,4 @@
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource};
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
tensor::WgpuTensor,
@ -50,10 +50,8 @@ impl StaticKernelSource for ArgsMin {
/// Sum all elements in the input buffer.
pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
const WORKGROUP: usize = 32;
let mut input_handle = input.handle;
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT);
loop {
let num_invocations = workgroup.num_invocations();
@ -61,10 +59,9 @@ pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTenso
.client
.empty(core::mem::size_of::<E>() * num_invocations);
let kernel =
StaticKernel::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
let kernel = StaticKernel::<
KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(workgroup);
input
.client
@ -75,7 +72,7 @@ pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTenso
}
input_handle = handle;
workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT);
}
}
@ -99,8 +96,6 @@ fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
@ -112,9 +107,10 @@ fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
handle,
);
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
@ -148,8 +144,6 @@ fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, con
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
const WORKGROUP: usize = 32;
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
@ -161,9 +155,10 @@ fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, con
buffer,
);
let kernel = StaticKernel::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let kernel =
StaticKernel::<KernelSettings<K, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
);
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
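
The `sum` kernel above reduces recursively: each pass writes one partial sum per workgroup, then the loop runs again on the partial sums until a single value remains. Below is a standalone CPU sketch of the same idea, with plain `f32` slices standing in for GPU buffers; none of these names exist in the crate.

```rust
// One reduction pass: each "workgroup" (chunk) collapses its elements into a
// single partial sum, mirroring what RecursiveSumRaw does on the device.
fn reduce_pass(values: &[f32], invocations_per_group: usize) -> Vec<f32> {
    values
        .chunks(invocations_per_group)
        .map(|chunk| chunk.iter().sum())
        .collect()
}

// Repeat passes until one value is left, like the `loop` in `sum` above that
// re-dispatches with a smaller workgroup count each iteration.
fn recursive_sum(mut values: Vec<f32>, invocations_per_group: usize) -> f32 {
    loop {
        if values.len() <= 1 {
            return values.first().copied().unwrap_or(0.0);
        }
        values = reduce_pass(&values, invocations_per_group);
    }
}

fn main() {
    let data: Vec<f32> = (1..=10).map(|i| i as f32).collect();
    assert_eq!(recursive_sum(data, 4), 55.0);
}
```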

View File

@ -1,4 +1,4 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
@ -98,14 +98,14 @@ macro_rules! unary_inplace {
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary::<K, E, D, 32>(input)
unary::<K, E, D, WORKGROUP_DEFAULT>(input)
}
/// Execute a unary inplace kernel using the default settings.
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary_inplace::<K, E, D, 32>(input)
unary_inplace::<K, E, D, WORKGROUP_DEFAULT>(input)
}
/// Execute a unary inplace kernel using the provided WORKGROUP.

View File

@ -1,4 +1,4 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
@ -108,7 +108,7 @@ pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usiz
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
unary_scalar::<K, E, D, 32>(lhs, scalar)
unary_scalar::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
}
/// Execute a unary scalar kernel using the provided WORKGROUP.
@ -142,7 +142,7 @@ pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
unary_scalar_inplace::<K, E, D, 32>(lhs, scalar)
unary_scalar_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
}
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.

View File

@ -2,7 +2,7 @@ use crate::{
compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi,
WgpuDevice,
};
use burn_tensor::{backend::Backend, Data, Shape};
use burn_tensor::{backend::Backend, Data, Reader, Shape};
pub type FloatElem<B> = <B as Backend>::FloatElem;
pub type Device<B> = <B as Backend>::Device;
@ -25,12 +25,24 @@ pub fn from_data<G: GraphicsApi, E: WgpuElement, const D: usize>(
WgpuTensor::new(client, device.clone(), data.shape, buffer)
}
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Reader<Data<E, D>> {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read(&tensor.handle);
let values = E::from_bytes(&bytes);
Data::new(values.to_vec(), tensor.shape)
tensor
.client
.read(&tensor.handle)
.map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape))
}
pub fn bool_into_data<const D: usize>(tensor: WgpuTensor<u32, D>) -> Reader<Data<bool, D>> {
let tensor = kernel::into_contiguous(tensor);
tensor.client.read(&tensor.handle).map(|bytes| {
Data::new(
u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
tensor.shape,
)
})
}
pub fn to_device<G: GraphicsApi, E: WgpuElement, const D: usize>(

View File

@ -5,7 +5,8 @@ use crate::{
tensor::WgpuTensor,
GraphicsApi, WgpuBackend,
};
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
use burn_tensor::{ops::IntTensorOps, Reader};
use std::ops::Range;
impl<G, F, I> BoolTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
@ -22,10 +23,8 @@ where
tensor.shape.clone()
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
let data = super::into_data(tensor);
Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape)
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<Data<bool, D>> {
super::bool_into_data(tensor)
}
fn bool_from_data<const D: usize>(
@ -51,7 +50,10 @@ where
}
let device = Self::bool_device(&tensor);
let data = Self::bool_into_data(tensor).convert::<I>();
let data = Self::bool_into_data(tensor)
.read_sync()
.expect("Can't convert bool to int with a different type size async")
.convert::<I>();
Self::int_from_data(data, &device)
}
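
The data-reading paths above now return `Reader<Data<_, D>>`: `into_data` maps the raw bytes into `Data`, and call sites that must stay synchronous (such as the bool-to-int conversion here) unwrap it with `read_sync`. The snippet below is an illustrative sketch only, not burn-tensor's actual `Reader` definition, showing just the two operations this commit relies on.

```rust
// Illustrative sketch (not the real type): a value that is either available
// immediately or would arrive asynchronously (e.g. on wasm targets).
enum Reader<T> {
    Sync(T),
    // An async variant holding a future would live here on wasm targets.
}

impl<T> Reader<T> {
    // Transform the payload once it is available, as `into_data` does with bytes.
    fn map<U>(self, f: impl FnOnce(T) -> U) -> Reader<U> {
        match self {
            Reader::Sync(value) => Reader::Sync(f(value)),
        }
    }

    // Yield the value only if the read did not have to be asynchronous.
    fn read_sync(self) -> Option<T> {
        match self {
            Reader::Sync(value) => Some(value),
        }
    }
}

fn main() {
    // Mirrors `bool_into_data`: device bytes are mapped into booleans.
    let bytes = Reader::Sync(vec![1u32, 0, 1]);
    let bools = bytes.map(|b| b.iter().map(|v| *v != 0).collect::<Vec<bool>>());
    assert_eq!(bools.read_sync(), Some(vec![true, false, true]));
}
```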

View File

@ -9,8 +9,8 @@ use crate::{
element::{FloatElement, IntElement},
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend,
};
use burn_tensor::ElementConversion;
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
use burn_tensor::{ElementConversion, Reader};
use std::ops::Range;
@ -48,11 +48,7 @@ where
tensor.shape.clone()
}
fn to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
super::into_data(tensor.clone())
}
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<Data<FloatElem<Self>, D>> {
super::into_data(tensor)
}

View File

@ -4,6 +4,8 @@ use crate::{
element::{FloatElement, IntElement},
kernel, unary, unary_inplace, GraphicsApi, WgpuBackend,
};
use burn_tensor::Reader;
use burn_tensor::{ops::IntTensorOps, Data, Shape};
use std::ops::Range;
@ -21,7 +23,7 @@ where
tensor.shape.clone()
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<Data<I, D>> {
super::into_data(tensor)
}

View File

@ -36,7 +36,7 @@ fn main(
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return (z & m) << s3 ^ b;
return ((z & m) << s3) ^ b;
}
fn taus_step_0(z: u32) -> u32 {

View File

@ -94,8 +94,12 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
/// Change the context of the current tensor and return the newly transferred tensor.
pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self {
let data = self.client.read(&self.handle);
let handle = client.create(&data);
let bytes = self
.client
.read(&self.handle)
.read_sync()
.expect("Can only change client synchronously");
let handle = client.create(&bytes);
Self {
client,
@ -106,6 +110,7 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
elem: PhantomData,
}
}
pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
if !self.handle.can_mut() {
return false;

View File

@ -44,7 +44,6 @@ ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"]
ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
# Experimental

View File

@ -10,9 +10,17 @@ version = "0.10.0"
crate-type = ["cdylib"]
[features]
default = []
default = ["ndarray"]
ndarray = ["burn/ndarray-no-std"]
wgpu = ["burn/wgpu"]
[dependencies]
burn = {path = "../../burn", default-features = false, features = ["ndarray-no-std"]}
burn = {path = "../../burn", default-features = false}
serde = {workspace = true}
wasm-bindgen = "0.2.87"
wasm-bindgen = { version = "0.2.87" }
wasm-bindgen-futures = "0.4"
js-sys = "0.3.64"
[dev-dependencies]
pollster = { workspace = true }

View File

@ -9,9 +9,11 @@ This crate demonstrates how to run an MNIST-trained model in the browser for inf
1. Build
```shell
./build-for-web.sh
./build-for-web.sh {backend}
```
The backend can be either `ndarray` or `wgpu`. Note that `wgpu` only works in browsers that support WebGPU.
2. Run the server
```shell

View File

@ -10,9 +10,9 @@ then
fi
# Set optimization flags
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3"
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis"
# Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory.
mkdir -p pkg
wasm-pack build --out-dir pkg --release --target web --no-typescript
wasm-pack build --out-dir pkg --release --target web --no-typescript --no-default-features --features $1

View File

@ -123,13 +123,13 @@
wasm().then((module) => {
const mnist = new Mnist();
function fireOffInference() {
async function fireOffInference() {
clearTimeout(timeoutId);
timeoutId = setTimeout(() => {
timeoutId = setTimeout(async () => {
isTimeOutSet = true;
fabricCanvas.freeDrawingBrush._finalizeAndAddPath();
const data = cropScaleGetImageData(mainContext, cropContext, scaledContext);
const output = mnist.inference(data);
const output = await mnist.inference(data);
chart.data.datasets[0].data = output;
chart.update();
isTimeOutSet = false;
@ -140,14 +140,14 @@
fabricCanvas.on("mouse:down", function (event) {
isDrawing = true;
});
fabricCanvas.on("mouse:up", function (event) {
fabricCanvas.on("mouse:up", async function (event) {
isDrawing = false;
fireOffInference();
await fireOffInference();
});
fabricCanvas.on("mouse:move", function (event) {
fabricCanvas.on("mouse:move", async function (event) {
if (isDrawing && isTimeOutSet == false) {
fireOffInference();
await fireOffInference();
}
});
});

View File

@ -1,16 +1,25 @@
use crate::model::Model;
use burn::backend::ndarray::NdArrayBackend;
use burn::module::Module;
use burn::record::BinBytesRecorder;
use burn::record::FullPrecisionSettings;
use burn::record::Recorder;
pub type Backend = NdArrayBackend<f32>;
#[cfg(feature = "wgpu")]
use burn::backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuBackend, WgpuDevice};
#[cfg(feature = "wgpu")]
pub type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
#[cfg(feature = "ndarray")]
pub type Backend = burn::backend::ndarray::NdArrayBackend<f32>;
static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
/// Builds and loads trained parameters into the model.
pub fn build_and_load_model() -> Model<Backend> {
pub async fn build_and_load_model() -> Model<Backend> {
#[cfg(feature = "wgpu")]
init_async::<AutoGraphicsApi>(&WgpuDevice::default()).await;
let model: Model<Backend> = Model::new();
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(STATE_ENCODED.to_vec())

View File

@ -1,29 +1,29 @@
#![allow(clippy::new_without_default)]
use alloc::{boxed::Box, string::String};
use alloc::string::String;
use js_sys::Array;
#[cfg(target_family = "wasm")]
use wasm_bindgen::prelude::*;
use crate::model::Model;
use crate::state::{build_and_load_model, Backend};
use burn::tensor::Tensor;
use wasm_bindgen::prelude::*;
/// Mnist structure that corresponds to JavaScript class.
/// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html)
#[wasm_bindgen]
#[cfg_attr(target_family = "wasm", wasm_bindgen)]
pub struct Mnist {
model: Model<Backend>,
model: Option<Model<Backend>>,
}
#[wasm_bindgen]
#[cfg_attr(target_family = "wasm", wasm_bindgen)]
impl Mnist {
/// Constructor called by JavaScripts with the new keyword.
#[wasm_bindgen(constructor)]
#[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))]
pub fn new() -> Self {
Self {
model: build_and_load_model(),
}
Self { model: None }
}
/// Returns the inference results.
@ -38,7 +38,13 @@ impl Mnist {
/// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html)
/// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html)
///
pub fn inference(&self, input: &[f32]) -> Result<Box<[f32]>, String> {
pub async fn inference(&mut self, input: &[f32]) -> Result<Array, String> {
if self.model.is_none() {
self.model = Some(build_and_load_model().await);
}
let model = self.model.as_ref().unwrap();
// Reshape from the 1D array to 3d tensor [batch, height, width]
let input: Tensor<Backend, 3> = Tensor::from_floats(input).reshape([1, 28, 28]);
@ -49,77 +55,23 @@ impl Mnist {
let input = ((input / 255) - 0.1307) / 0.3081;
// Run the tensor input through the model
let output: Tensor<Backend, 2> = self.model.forward(input);
let output: Tensor<Backend, 2> = model.forward(input);
// Convert the model output into probability distribution using softmax formula
let output: Tensor<Backend, 2> = output.clone().exp() / output.exp().sum_dim(1);
let output = burn::tensor::activation::softmax(output, 1);
// Flatten output tensor with [1, 10] shape into boxed slice of [f32]
Ok(output.to_data().value.into_boxed_slice())
}
}
#[cfg(test)]
mod tests {
use super::Mnist;
#[test]
fn inference_manual_from_test_data() {
let mnist = Mnist::new();
let input: Vec<f32> = vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 22.0, 97.0, 181.0, 254.0, 255.0, 221.0, 106.0, 3.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 28.0, 128.0, 213.0,
245.0, 254.0, 254.0, 246.0, 239.0, 254.0, 254.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 27.0, 151.0, 239.0, 254.0, 254.0, 222.0,
204.0, 189.0, 70.0, 27.0, 215.0, 254.0, 98.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 43.0, 248.0, 254.0, 233.0, 40.0, 15.0, 0.0, 0.0,
0.0, 0.0, 96.0, 254.0, 111.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 121.0, 72.0, 33.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 76.0, 254.0,
187.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 94.0, 254.0, 149.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 124.0, 254.0, 40.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 203.0,
174.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 102.0, 254.0, 122.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 40.0, 223.0, 205.0, 254.0, 115.0, 33.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 248.0,
254.0, 254.0, 254.0, 242.0, 191.0, 71.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 140.0, 254.0, 254.0, 248.0,
209.0, 254.0, 213.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 186.0, 208.0, 71.0, 49.0, 74.0, 191.0, 134.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 134.0, 254.0, 119.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 197.0,
168.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 140.0, 225.0, 43.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 51.0, 233.0, 106.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 161.0,
227.0, 17.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 228.0, 124.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 168.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
let output = mnist.inference(input.as_slice()).unwrap();
assert!(output[7] > 0.9);
#[cfg(not(target_family = "wasm"))]
let output = output.into_data().convert::<f32>().value;
#[cfg(target_family = "wasm")]
let output = output.into_data().await.convert::<f32>().value;
let array = Array::new();
for value in output {
array.push(&value.into());
}
Ok(array)
}
}
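
The `Mnist` wrapper above now stores `Option<Model<Backend>>` and builds the model lazily inside the async `inference` method, since the wgpu backend's `init_async` cannot be awaited from a constructor. The snippet below is a standalone sketch of that lazy-initialization shape with placeholder types; nothing here is the example's real code, and `pollster::block_on` corresponds to the dev-dependency added in this commit for driving such a future outside of wasm.

```rust
// Standalone sketch of the lazy-initialization pattern: the expensive resource is
// created on first use inside an async method rather than in `new`. All names are
// placeholders for the example's `Mnist`, `Model`, and `build_and_load_model`.
struct Model(u32);

struct Service {
    model: Option<Model>,
}

impl Service {
    fn new() -> Self {
        Self { model: None }
    }

    async fn build_model() -> Model {
        // Stand-in for `build_and_load_model().await`, which may need async device setup.
        Model(7)
    }

    async fn inference(&mut self) -> u32 {
        if self.model.is_none() {
            self.model = Some(Self::build_model().await);
        }
        self.model.as_ref().unwrap().0
    }
}

fn main() {
    // `pollster::block_on` drives the future to completion on native targets.
    let mut service = Service::new();
    assert_eq!(pollster::block_on(service.inference()), 7);
}
```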