mirror of https://github.com/tracel-ai/burn.git
Feat/async read (#833)
This commit is contained in:
parent
aa90fe8efb
commit
ca787d6446
|
@ -63,6 +63,7 @@ thiserror = "1.0.40"
|
|||
tracing-subscriber = "0.3.17"
|
||||
tracing-core = "0.1.31"
|
||||
tracing-appender = "0.2.2"
|
||||
async-trait = "0.1.73"
|
||||
|
||||
# WGPU stuff
|
||||
futures-intrusive = "0.5"
|
||||
|
|
|
@ -16,7 +16,7 @@ export_tests = ["burn-tensor-testgen"]
|
|||
|
||||
[dependencies]
|
||||
burn-common = {path = "../burn-common", version = "0.10.0" }
|
||||
burn-tensor = {path = "../burn-tensor", version = "0.10.0" }
|
||||
burn-tensor = {path = "../burn-tensor", version = "0.10.0", default-features = false }
|
||||
burn-tensor-testgen = {path = "../burn-tensor-testgen", version = "0.10.0", optional = true}
|
||||
|
||||
derive-new = {workspace = true}
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
/// Gradients module.
|
||||
pub mod grads;
|
||||
/// Operation module.
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::{
|
|||
ADBackendDecorator,
|
||||
};
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape};
|
||||
|
||||
impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
|
||||
|
@ -14,11 +14,11 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
|
|||
B::bool_shape(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
|
||||
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Reader<Data<bool, D>> {
|
||||
B::bool_to_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Reader<Data<bool, D>> {
|
||||
B::bool_into_data(tensor)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::{
|
|||
ADBackendDecorator,
|
||||
};
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Reader, Shape};
|
||||
|
||||
impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||
fn int_from_data<const D: usize>(
|
||||
|
@ -17,11 +17,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::int_shape(tensor)
|
||||
}
|
||||
|
||||
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Data<B::IntElem, D> {
|
||||
fn int_to_data<const D: usize>(tensor: &IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
|
||||
B::int_to_data(tensor)
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Data<B::IntElem, D> {
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<B, D>) -> Reader<Data<B::IntElem, D>> {
|
||||
B::int_into_data(tensor)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ use crate::{
|
|||
ADBackendDecorator,
|
||||
};
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, Tensor};
|
||||
use burn_tensor::{
|
||||
backend::Backend, ops::TensorOps, Data, ElementConversion, Reader, Shape, Tensor,
|
||||
};
|
||||
|
||||
use super::maxmin::MaxMinDim;
|
||||
|
||||
|
@ -41,11 +43,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::shape(&tensor.primitive)
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(tensor: &ADTensor<B, D>) -> Data<FloatElem<B>, D> {
|
||||
fn to_data<const D: usize>(tensor: &ADTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
|
||||
B::to_data(&tensor.primitive)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(tensor: ADTensor<B, D>) -> Data<FloatElem<B>, D> {
|
||||
fn into_data<const D: usize>(tensor: ADTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
|
||||
B::into_data(tensor.primitive)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{backend::Backend, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, Data, Reader, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
|
@ -36,7 +36,7 @@ pub fn from_data<E: CandleElement, const D: usize>(
|
|||
CandleTensor::from_data(data, *device)
|
||||
}
|
||||
|
||||
pub fn to_data<E: CandleElement, const D: usize>(tensor: &CandleTensor<E, D>) -> Data<E, D> {
|
||||
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> Data<E, D> {
|
||||
Data::new(
|
||||
tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(),
|
||||
tensor.shape(),
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
|
||||
use burn_tensor::{ops::BoolTensorOps, Data, Reader, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
|
@ -18,10 +18,12 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<CandleBackend<F,
|
|||
super::base::shape(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<Data<bool, D>> {
|
||||
let x: Vec<u8> = tensor.tensor.flatten_all().unwrap().to_vec1().unwrap();
|
||||
let y = x.iter().map(|b| !matches!(b, 0)).collect();
|
||||
Data::new(y, tensor.shape())
|
||||
let data = Data::new(y, tensor.shape());
|
||||
|
||||
Reader::Concrete(data)
|
||||
}
|
||||
|
||||
fn bool_from_data<const D: usize>(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn_tensor::{ops::IntTensorOps, Bool, Data, Shape};
|
||||
use burn_tensor::{ops::IntTensorOps, Bool, Data, Reader, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{CandleElement, FloatCandleElement, IntCandleElement},
|
||||
|
@ -18,8 +18,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<CandleBackend<F, I
|
|||
super::base::shape(tensor)
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<IntElem<Self>, D> {
|
||||
super::base::to_data(&tensor)
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<Data<IntElem<Self>, D>> {
|
||||
Reader::Concrete(super::base::into_data(tensor))
|
||||
}
|
||||
|
||||
fn int_from_data<const D: usize>(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::borrow::Borrow;
|
||||
|
||||
use burn_tensor::{ops::TensorOps, Data, Distribution, ElementConversion, Shape};
|
||||
use burn_tensor::{ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape};
|
||||
use candle_core::{backend::BackendStorage, shape, Tensor};
|
||||
|
||||
use crate::{
|
||||
|
@ -54,8 +54,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
|
|||
super::base::shape(tensor)
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(tensor: &CandleTensor<F, D>) -> Data<F, D> {
|
||||
super::base::to_data(tensor)
|
||||
fn into_data<const D: usize>(tensor: CandleTensor<F, D>) -> Reader<Data<F, D>> {
|
||||
Reader::Concrete(super::base::into_data(tensor))
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &CandleTensor<F, D>) -> Device<Self> {
|
||||
|
|
|
@ -15,14 +15,19 @@ default = ["std"]
|
|||
|
||||
std = ["rand/std"]
|
||||
|
||||
[dependencies]
|
||||
|
||||
[target.'cfg(target_family = "wasm")'.dependencies]
|
||||
async-trait = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dependencies]
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
|
||||
const-random = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
spin = { workspace = true } # using in place of use std::sync::Mutex;
|
||||
spin = { workspace = true } # using in place of use std::sync::Mutex;
|
||||
uuid = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
dashmap = { workspace = true }
|
||||
|
|
|
@ -5,6 +5,9 @@
|
|||
//!
|
||||
//! This library contains common types used by other Burn crates that must be shared.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// Id module contains types for unique identifiers.
|
||||
pub mod id;
|
||||
|
||||
|
@ -15,4 +18,8 @@ pub mod rand;
|
|||
/// Stub module contains types for stubs for non-std environments and for std environments.
|
||||
pub mod stub;
|
||||
|
||||
/// Useful when you need to read async data without having to decorate each function with async
|
||||
/// notation.
|
||||
pub mod reader;
|
||||
|
||||
extern crate alloc;
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
use alloc::boxed::Box;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
#[async_trait::async_trait]
|
||||
/// Allows creating an async reader.
///
/// Only compiled when `target_family = "wasm"`. Implementors must be `Send`.
|
||||
pub trait AsyncReader<T>: Send {
|
||||
/// Read asynchronously, consuming the boxed reader and yielding the value.
|
||||
async fn read(self: Box<Self>) -> T;
|
||||
}
|
||||
|
||||
/// Define how data is read, sync or async.
///
/// The `Async` and `Future` variants only exist when `target_family = "wasm"`.
|
||||
pub enum Reader<T> {
|
||||
/// Concrete variant: the value is already available and is returned as-is when read.
|
||||
Concrete(T),
|
||||
/// Sync data variant: a boxed [`SyncReader`] that produces the value when read.
|
||||
Sync(Box<dyn SyncReader<T>>),
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Async data variant: a boxed [`AsyncReader`] that must be awaited to get the value.
|
||||
Async(Box<dyn AsyncReader<T>>),
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Future data variant: a pinned, boxed `Send` future resolving to the value.
|
||||
Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
|
||||
}
|
||||
|
||||
/// Allows creating a sync reader.
///
/// Implementors must be `Send`.
|
||||
pub trait SyncReader<T>: Send {
|
||||
/// Read synchronously, consuming the boxed reader and returning the value.
|
||||
fn read(self: Box<Self>) -> T;
|
||||
}
|
||||
|
||||
/// Private adapter that pairs an inner [`Reader`] with a function applied to the value it
/// yields. Used by `Reader::map`.
///
/// NOTE(review): in this file `MappedReader::new` is called with only `reader` and `mapper`;
/// presumably `derive_new` fills the `PhantomData` field itself — confirm against the
/// derive-new documentation.
#[derive(new)]
|
||||
struct MappedReader<I, O, F> {
|
||||
// Inner reader producing the input value of type `I`.
reader: Reader<I>,
|
||||
// Function turning the input `I` into the output `O`.
mapper: F,
|
||||
// Marker tying the otherwise unused output type parameter `O` to the struct.
_output: PhantomData<O>,
|
||||
}
|
||||
|
||||
/// Synchronous read path of a mapped reader: read the inner reader, then apply the mapper.
impl<I, O, F> SyncReader<O> for MappedReader<I, O, F>
|
||||
where
|
||||
I: Send,
|
||||
O: Send,
|
||||
F: Send + FnOnce(I) -> O,
|
||||
{
|
||||
/// Reads the inner reader synchronously and returns the mapped value.
///
/// # Panics
///
/// Panics if the inner reader cannot be read synchronously, i.e. `read_sync` returns `None`.
fn read(self: Box<Self>) -> O {
|
||||
// `read_sync` yields `None` for the async variants; a sync reader cannot await them.
let input = self
|
||||
.reader
|
||||
.read_sync()
|
||||
.expect("Only sync data supported in a sync reader.");
|
||||
|
||||
(self.mapper)(input)
|
||||
}
|
||||
}
|
||||
|
||||
// Asynchronous read path of a mapped reader; only compiled when `target_family = "wasm"`.
#[cfg(target_family = "wasm")]
|
||||
#[async_trait::async_trait]
|
||||
impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
|
||||
where
|
||||
I: Send,
|
||||
O: Send,
|
||||
F: Send + FnOnce(I) -> O,
|
||||
{
|
||||
/// Awaits the inner reader, then applies the mapper to the value it produced.
async fn read(self: Box<Self>) -> O {
|
||||
// The async `Reader::read` handles every variant, so no variant is rejected here.
let input = self.reader.read().await;
|
||||
(self.mapper)(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Reader<T> {
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Read the data.
///
/// This is the version compiled when `target_family = "wasm"`: it is `async` and handles
/// all four variants, awaiting the `Async` and `Future` ones.
|
||||
pub async fn read(self) -> T {
|
||||
match self {
|
||||
Self::Concrete(data) => data,
|
||||
Self::Sync(reader) => reader.read(),
|
||||
Self::Async(func) => func.read().await,
|
||||
Self::Future(future) => future.await,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
/// Read the data.
///
/// This is the version compiled when the target is not wasm: it is a plain blocking
/// function, and only the `Concrete` and `Sync` variants exist.
|
||||
pub fn read(self) -> T {
|
||||
match self {
|
||||
Self::Concrete(data) => data,
|
||||
Self::Sync(reader) => reader.read(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read the data only if sync, returns None if an async reader.
///
/// `Concrete` and `Sync` yield `Some(value)`; the wasm-only `Async` and `Future` variants
/// yield `None`.
|
||||
pub fn read_sync(self) -> Option<T> {
|
||||
match self {
|
||||
Self::Concrete(data) => Some(data),
|
||||
Self::Sync(reader) => Some(reader.read()),
|
||||
#[cfg(target_family = "wasm")]
|
||||
Self::Async(_func) => return None,
|
||||
#[cfg(target_family = "wasm")]
|
||||
Self::Future(_future) => return None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Map the current reader to another type.
///
/// The mapper is not called here: the reader and mapper are stored in a `MappedReader`,
/// and the mapper runs when the returned reader is read. On wasm the result is a
/// `Reader::Async`; on every other target it is a `Reader::Sync`.
|
||||
pub fn map<O, F: FnOnce(T) -> O>(self, mapper: F) -> Reader<O>
|
||||
where
|
||||
T: 'static + Send,
|
||||
O: 'static + Send,
|
||||
F: 'static + Send,
|
||||
{
|
||||
#[cfg(target_family = "wasm")]
|
||||
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
Reader::Sync(Box::new(MappedReader::new(self, mapper)))
|
||||
}
|
||||
}
|
|
@ -1,11 +1,12 @@
|
|||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
||||
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
|
||||
/// while ensuring thread-safety
|
||||
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
|
||||
/// Given a handle, returns owned resource as bytes
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8>;
|
||||
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>>;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the resource handle
|
||||
fn create(&self, data: &[u8]) -> Handle<Server>;
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::ComputeChannel;
|
|||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
||||
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
|
||||
///
|
||||
|
@ -40,7 +41,7 @@ impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
|
|||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
|
||||
let mut server = self.server.borrow_mut();
|
||||
|
||||
server.read(handle)
|
||||
|
|
|
@ -6,9 +6,9 @@ mod mutex;
|
|||
#[cfg(feature = "channel-mutex")]
|
||||
pub use mutex::*;
|
||||
|
||||
#[cfg(feature = "channel-mpsc")]
|
||||
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
|
||||
mod mpsc;
|
||||
#[cfg(feature = "channel-mpsc")]
|
||||
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
|
||||
pub use mpsc::*;
|
||||
|
||||
#[cfg(feature = "channel-cell")]
|
||||
|
|
|
@ -3,6 +3,8 @@ use std::{
|
|||
thread,
|
||||
};
|
||||
|
||||
use burn_common::reader::Reader;
|
||||
|
||||
use super::ComputeChannel;
|
||||
use crate::server::{ComputeServer, Handle};
|
||||
|
||||
|
@ -31,7 +33,7 @@ enum Message<Server>
|
|||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
Read(Handle<Server>, Callback<Vec<u8>>),
|
||||
Read(Handle<Server>, Callback<Reader<Vec<u8>>>),
|
||||
Create(Vec<u8>, Callback<Handle<Server>>),
|
||||
Empty(usize, Callback<Handle<Server>>),
|
||||
Execute(Server::Kernel, Vec<Handle<Server>>),
|
||||
|
@ -91,7 +93,7 @@ impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
|
|||
where
|
||||
Server: ComputeServer + 'static,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
|
||||
let (callback, response) = mpsc::sync_channel(1);
|
||||
|
||||
self.state
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::ComputeChannel;
|
|||
use crate::server::{ComputeServer, Handle};
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use spin::Mutex;
|
||||
|
||||
/// The MutexComputeChannel ensures thread-safety by locking the server
|
||||
|
@ -34,7 +35,7 @@ impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
|
|||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
|
||||
self.server.lock().read(handle)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::{
|
|||
server::{ComputeServer, Handle},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
||||
|
@ -40,7 +41,7 @@ where
|
|||
}
|
||||
|
||||
/// Given a handle, returns owned resource as bytes.
|
||||
pub fn read(&self, handle: &Handle<Server>) -> Vec<u8> {
|
||||
pub fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
|
||||
self.channel.read(handle)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,41 @@ use crate::{
|
|||
storage::ComputeStorage,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
||||
/// The compute server is responsible for handling resources and computations over resources.
|
||||
///
|
||||
/// Everything in the server is mutable, therefore it should be solely accessed through the
|
||||
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
|
||||
pub trait ComputeServer: Send + core::fmt::Debug
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// The kernel type defines the computation algorithms.
|
||||
type Kernel: Send;
|
||||
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
|
||||
type Storage: ComputeStorage;
|
||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||
type MemoryManagement: MemoryManagement<Self::Storage>;
|
||||
|
||||
/// Given a handle, returns the owned resource as bytes.
///
/// The bytes are wrapped in a [`Reader`], so the caller reads them through the reader
/// rather than receiving the `Vec<u8>` directly.
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the memory handle.
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self>;
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them.
|
||||
fn empty(&mut self, size: usize) -> Handle<Self>;
|
||||
|
||||
/// Executes the `kernel` over the given memory `handles`.
|
||||
///
|
||||
/// Kernels have mutable access to every resource they are given
|
||||
/// and are responsible for determining which should be read or written.
|
||||
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
fn sync(&mut self);
|
||||
}
|
||||
|
||||
/// Server handle containing the [memory handle](MemoryManagement::Handle).
|
||||
#[derive(new, Debug)]
|
||||
|
@ -25,37 +60,3 @@ impl<Server: ComputeServer> Clone for Handle<Server> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The compute server is responsible for handling resources and computations over resources.
|
||||
///
|
||||
/// Everything in the server is mutable, therefore it should be solely accessed through the
|
||||
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
|
||||
pub trait ComputeServer: Send + core::fmt::Debug
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// The kernel type defines the computation algorithms.
|
||||
type Kernel: Send;
|
||||
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
|
||||
type Storage: ComputeStorage;
|
||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||
type MemoryManagement: MemoryManagement<Self::Storage>;
|
||||
|
||||
/// Given a handle, returns the owned resource as bytes.
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8>;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the memory handle.
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self>;
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them.
|
||||
fn empty(&mut self, size: usize) -> Handle<Self>;
|
||||
|
||||
/// Executes the `kernel` over the given memory `handles`.
|
||||
///
|
||||
/// Kernels have mutable access to every resource they are given
|
||||
/// and are responsible of determining which should be read or written.
|
||||
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
fn sync(&mut self);
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ impl core::fmt::Debug for BytesStorage {
|
|||
}
|
||||
}
|
||||
|
||||
/// Can send to other threads, but can't sync.
|
||||
/// Can send to other threads.
|
||||
unsafe impl Send for BytesStorage {}
|
||||
unsafe impl Send for BytesResource {}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use burn_common::reader::Reader;
|
||||
use burn_compute::{
|
||||
memory_management::{MemoryManagement, SimpleMemoryManagement},
|
||||
server::{ComputeServer, Handle},
|
||||
|
@ -22,10 +23,10 @@ where
|
|||
type Storage = BytesStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
|
||||
let bytes = self.memory_management.get(&handle.memory);
|
||||
|
||||
bytes.read().to_vec()
|
||||
Reader::Concrete(bytes.read().to_vec())
|
||||
}
|
||||
|
||||
fn create(&mut self, data: &[u8]) -> Handle<Self> {
|
||||
|
|
|
@ -10,7 +10,7 @@ fn created_resource_is_the_same_when_read() {
|
|||
|
||||
let obtained_resource = client.read(&resource_description);
|
||||
|
||||
assert_eq!(resource, obtained_resource)
|
||||
assert_eq!(resource, obtained_resource.read())
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -20,7 +20,7 @@ fn empty_allocates_memory() {
|
|||
let resource_description = client.empty(size);
|
||||
let empty_resource = client.read(&resource_description);
|
||||
|
||||
assert_eq!(empty_resource.len(), 4);
|
||||
assert_eq!(empty_resource.read().len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -34,5 +34,5 @@ fn execute_elementwise_addition() {
|
|||
|
||||
let obtained_resource = client.read(&out);
|
||||
|
||||
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
|
||||
assert_eq!(obtained_resource.read(), Vec::from([4, 5, 6]))
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@ std = [
|
|||
"half/std",
|
||||
"derive-new/std",
|
||||
]
|
||||
|
||||
dataset = ["burn-dataset/default"]
|
||||
dataset-minimal = ["burn-dataset"]
|
||||
dataset-sqlite = ["burn-dataset/sqlite"]
|
||||
|
@ -42,7 +41,7 @@ ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
|
|||
ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas-system"]
|
||||
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
|
||||
|
||||
wgpu = ["burn-wgpu"]
|
||||
wgpu = ["burn-wgpu/default"]
|
||||
|
||||
tch = ["burn-tch"]
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate as burn;
|
||||
|
||||
use crate::{config::Config, tensor::Tensor};
|
||||
use burn_tensor::{backend::Backend, ElementConversion};
|
||||
use burn_tensor::backend::Backend;
|
||||
|
||||
/// Gradient Clipping provides a way to mitigate exploding gradients
|
||||
#[derive(Config)]
|
||||
|
@ -68,20 +68,26 @@ impl GradientClipping {
|
|||
clipped_grad.mask_fill(lower_mask, -threshold)
|
||||
}
|
||||
|
||||
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
|
||||
let squared = tensor.powf(2.0);
|
||||
let sum = squared.sum();
|
||||
|
||||
sum.sqrt()
|
||||
#[cfg(target_family = "wasm")]
|
||||
fn clip_by_norm<B: Backend, const D: usize>(
|
||||
&self,
|
||||
_grad: Tensor<B, D>,
|
||||
_threshold: f32,
|
||||
) -> Tensor<B, D> {
|
||||
todo!("Not yet supported on wasm");
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn clip_by_norm<B: Backend, const D: usize>(
|
||||
&self,
|
||||
grad: Tensor<B, D>,
|
||||
threshold: f32,
|
||||
) -> Tensor<B, D> {
|
||||
use burn_tensor::ElementConversion;
|
||||
|
||||
let norm = Self::l2_norm(grad.clone());
|
||||
let norm_float = norm.into_scalar().elem::<f32>();
|
||||
|
||||
if norm_float > threshold {
|
||||
let scale = threshold / norm_float;
|
||||
grad.mul_scalar(scale)
|
||||
|
@ -89,6 +95,14 @@ impl GradientClipping {
|
|||
grad
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
|
||||
let squared = tensor.powf(2.0);
|
||||
let sum = squared.sum();
|
||||
|
||||
sum.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -100,7 +100,7 @@ pub struct ParamSerde<T> {
|
|||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
|
||||
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<B, D, S>>;
|
||||
type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
ParamSerde::new(self.id.into_string(), self.value.into_item())
|
||||
|
|
|
@ -5,6 +5,7 @@ use alloc::string::{String, ToString};
|
|||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
||||
use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use super::{
|
||||
BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,
|
||||
|
|
|
@ -5,136 +5,131 @@ use serde::{Deserialize, Serialize};
|
|||
/// This struct implements serde to lazily serialize and deserialize a float tensor
|
||||
/// using the given [record settings](RecordSettings).
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
|
||||
tensor: Tensor<B, D>,
|
||||
elem: core::marker::PhantomData<S>,
|
||||
pub struct FloatTensorSerde<S: PrecisionSettings> {
|
||||
data: DataSerialize<S::FloatElem>,
|
||||
}
|
||||
|
||||
/// This struct implements serde to lazily serialize and deserialize an int tensor
|
||||
/// using the given [record settings](RecordSettings).
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct IntTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
|
||||
tensor: Tensor<B, D, Int>,
|
||||
elem: core::marker::PhantomData<S>,
|
||||
pub struct IntTensorSerde<S: PrecisionSettings> {
|
||||
data: DataSerialize<S::IntElem>,
|
||||
}
|
||||
|
||||
/// This struct implements serde to lazily serialize and deserialize a bool tensor.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct BoolTensorSerde<B: Backend, const D: usize> {
|
||||
tensor: Tensor<B, D, Bool>,
|
||||
pub struct BoolTensorSerde {
|
||||
data: DataSerialize<bool>,
|
||||
}
|
||||
|
||||
// --- SERDE IMPLEMENTATIONS --- //
|
||||
|
||||
impl<B: Backend, const D: usize, S: PrecisionSettings> Serialize for FloatTensorSerde<B, D, S> {
|
||||
impl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.tensor
|
||||
.to_data()
|
||||
.convert::<S::FloatElem>()
|
||||
.serialize()
|
||||
.serialize(serializer)
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, B: Backend, const D: usize, S: PrecisionSettings> Deserialize<'de>
|
||||
for FloatTensorSerde<B, D, S>
|
||||
{
|
||||
impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = DataSerialize::<S::FloatElem>::deserialize(deserializer)?;
|
||||
let tensor = Tensor::from_data(data.convert::<B::FloatElem>());
|
||||
|
||||
Ok(Self::new(tensor))
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, S: PrecisionSettings> Serialize for IntTensorSerde<B, D, S> {
|
||||
// #[cfg(not(target_family = "wasm"))]
|
||||
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.tensor
|
||||
.to_data()
|
||||
.convert::<S::IntElem>()
|
||||
.serialize()
|
||||
.serialize(serializer)
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, B: Backend, const D: usize, S: PrecisionSettings> Deserialize<'de>
|
||||
for IntTensorSerde<B, D, S>
|
||||
{
|
||||
impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = DataSerialize::<S::IntElem>::deserialize(deserializer)?;
|
||||
let tensor = Tensor::from_data(data.convert::<B::IntElem>());
|
||||
|
||||
Ok(Self::new(tensor))
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Serialize for BoolTensorSerde<B, D> {
|
||||
impl Serialize for BoolTensorSerde {
|
||||
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
|
||||
where
|
||||
Se: serde::Serializer,
|
||||
{
|
||||
self.tensor.to_data().serialize().serialize(serializer)
|
||||
self.data.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, B: Backend, const D: usize> Deserialize<'de> for BoolTensorSerde<B, D> {
|
||||
impl<'de> Deserialize<'de> for BoolTensorSerde {
|
||||
fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
|
||||
where
|
||||
De: serde::Deserializer<'de>,
|
||||
{
|
||||
let data = DataSerialize::<bool>::deserialize(deserializer)?;
|
||||
let tensor = Tensor::from_data(data);
|
||||
|
||||
Ok(Self::new(tensor))
|
||||
Ok(Self::new(data))
|
||||
}
|
||||
}
|
||||
|
||||
// --- RECORD IMPLEMENTATIONS --- //
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D> {
|
||||
type Item<S: PrecisionSettings> = FloatTensorSerde<B, D, S>;
|
||||
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
FloatTensorSerde::new(self)
|
||||
#[cfg(target_family = "wasm")]
|
||||
todo!("Recording float tensors isn't yet supported on wasm.");
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
FloatTensorSerde::new(self.into_data().convert().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
item.tensor
|
||||
Tensor::from_data(item.data.convert::<B::FloatElem>())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
|
||||
type Item<S: PrecisionSettings> = IntTensorSerde<B, D, S>;
|
||||
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
IntTensorSerde::new(self)
|
||||
#[cfg(target_family = "wasm")]
|
||||
todo!("Recording int tensors isn't yet supported on wasm.");
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
IntTensorSerde::new(self.into_data().convert().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
item.tensor
|
||||
Tensor::from_data(item.data.convert())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
|
||||
type Item<S: PrecisionSettings> = BoolTensorSerde<B, D>;
|
||||
type Item<S: PrecisionSettings> = BoolTensorSerde;
|
||||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
BoolTensorSerde::new(self)
|
||||
#[cfg(target_family = "wasm")]
|
||||
todo!("Recording bool tensors isn't yet supported on wasm.");
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
BoolTensorSerde::new(self.into_data().serialize())
|
||||
}
|
||||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
|
||||
item.tensor
|
||||
Tensor::from_data(item.data)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ElementConversion, Reader};
|
||||
use core::ops::Range;
|
||||
|
||||
// Current crate
|
||||
|
@ -29,19 +29,13 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(
|
||||
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let values = tensor.array.iter().map(Clone::clone).collect();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
) -> Reader<Data<bool, D>> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
Data::new(values, shape)
|
||||
|
||||
Reader::Concrete(Data::new(values, shape))
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
|
@ -68,7 +62,9 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
|
|||
fn bool_into_int<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> NdArrayTensor<i64, D> {
|
||||
let data = Self::bool_into_data(tensor);
|
||||
let data = Self::bool_into_data(tensor)
|
||||
.read_sync()
|
||||
.expect("Always sync with ndarray");
|
||||
NdArrayBackend::<E>::int_from_data(data.convert(), &NdArrayDevice::Cpu)
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::ops::IntTensorOps;
|
||||
use burn_tensor::Reader;
|
||||
|
||||
use burn_tensor::ElementConversion;
|
||||
use core::ops::Range;
|
||||
|
||||
|
@ -28,15 +30,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn int_to_data<const D: usize>(tensor: &NdArrayTensor<i64, D>) -> Data<i64, D> {
|
||||
let values = tensor.array.iter().map(Clone::clone).collect();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Data<i64, D> {
|
||||
fn int_into_data<const D: usize>(tensor: NdArrayTensor<i64, D>) -> Reader<Data<i64, D>> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
Data::new(values, shape)
|
||||
|
||||
Reader::Concrete(Data::new(values, shape))
|
||||
}
|
||||
|
||||
fn int_to_device<const D: usize>(
|
||||
|
|
|
@ -10,8 +10,8 @@ use crate::{NdArrayDevice, SEED};
|
|||
|
||||
// Workspace crates
|
||||
use burn_common::rand::get_seeded_rng;
|
||||
use burn_tensor::Distribution;
|
||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
|
||||
use burn_tensor::{Distribution, Reader};
|
||||
|
||||
// External crates
|
||||
use libm::{cos, erf, sin, tanh};
|
||||
|
@ -45,19 +45,13 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(
|
||||
tensor: &NdArrayTensor<E, D>,
|
||||
) -> Data<<NdArrayBackend<E> as Backend>::FloatElem, D> {
|
||||
let values = tensor.array.iter().map(Clone::clone).collect();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
) -> Data<<NdArrayBackend<E> as Backend>::FloatElem, D> {
|
||||
) -> Reader<Data<<NdArrayBackend<E> as Backend>::FloatElem, D>> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
Data::new(values, shape)
|
||||
|
||||
Reader::Concrete(Data::new(values, shape))
|
||||
}
|
||||
|
||||
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
|
||||
|
|
|
@ -16,8 +16,7 @@ impl<E, const D: usize> NdArrayTensor<E, D> {
|
|||
#[cfg(test)]
|
||||
mod utils {
|
||||
use super::*;
|
||||
use crate::{element::FloatNdArrayElement, NdArrayBackend};
|
||||
use burn_tensor::ops::TensorOps;
|
||||
use crate::element::FloatNdArrayElement;
|
||||
|
||||
impl<E, const D: usize> NdArrayTensor<E, D>
|
||||
where
|
||||
|
@ -27,7 +26,10 @@ mod utils {
|
|||
where
|
||||
E: FloatNdArrayElement,
|
||||
{
|
||||
<NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
|
||||
let shape = self.shape();
|
||||
let values = self.array.into_iter().collect();
|
||||
|
||||
Data::new(values, shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Reader, Shape};
|
||||
|
||||
use crate::{element::TchElement, TchBackend, TchDevice, TchTensor};
|
||||
|
||||
|
@ -26,16 +26,12 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchOps::repeat(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
|
||||
let shape = Self::bool_shape(tensor);
|
||||
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Reader<Data<bool, D>> {
|
||||
let shape = Self::bool_shape(&tensor);
|
||||
let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
|
||||
Self::bool_to_data(&tensor)
|
||||
Reader::Concrete(Data::new(values.unwrap(), shape))
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Reader, Shape};
|
||||
|
||||
use crate::{element::TchElement, TchBackend, TchDevice, TchShape, TchTensor};
|
||||
|
||||
|
@ -23,16 +23,12 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchOps::repeat(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn int_to_data<const D: usize>(tensor: &TchTensor<i64, D>) -> Data<i64, D> {
|
||||
let shape = Self::int_shape(tensor);
|
||||
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Reader<Data<i64, D>> {
|
||||
let shape = Self::int_shape(&tensor);
|
||||
let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> Data<i64, D> {
|
||||
Self::int_to_data(&tensor)
|
||||
Reader::Concrete(Data::new(values.unwrap(), shape))
|
||||
}
|
||||
|
||||
fn int_to_device<const D: usize>(
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use super::TchOps;
|
||||
use crate::{element::TchElement, TchBackend, TchDevice, TchShape, TchTensor};
|
||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
|
||||
use burn_tensor::{
|
||||
backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Reader, Shape,
|
||||
};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||
|
@ -79,20 +81,14 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(
|
||||
tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>,
|
||||
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
|
||||
let shape = Self::shape(tensor);
|
||||
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
|
||||
|
||||
Data::new(values.unwrap(), shape)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(
|
||||
tensor: <TchBackend<E> as Backend>::TensorPrimitive<D>,
|
||||
) -> Data<<TchBackend<E> as Backend>::FloatElem, D> {
|
||||
Self::to_data(&tensor)
|
||||
) -> Reader<Data<<TchBackend<E> as Backend>::FloatElem, D>> {
|
||||
let shape = Self::shape(&tensor);
|
||||
let tensor = Self::reshape(tensor.clone(), Shape::new([shape.num_elements()]));
|
||||
let values: Result<Vec<E>, tch::TchError> = tensor.tensor.try_into();
|
||||
|
||||
Reader::Concrete(Data::new(values.unwrap(), shape))
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
|
||||
|
|
|
@ -200,7 +200,7 @@ mod utils {
|
|||
where
|
||||
P: tch::kind::Element,
|
||||
{
|
||||
<TchBackend<P> as TensorOps<TchBackend<P>>>::into_data(self)
|
||||
<TchBackend<P> as TensorOps<TchBackend<P>>>::into_data(self).read()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ std = ["rand/std", "half/std"]
|
|||
benchmark = []
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }
|
||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.10.0", optional = true }
|
||||
|
||||
derive-new = { workspace = true }
|
||||
|
|
|
@ -18,6 +18,9 @@ mod tests;
|
|||
pub use half::{bf16, f16};
|
||||
pub use tensor::*;
|
||||
|
||||
pub use burn_common::reader::Reader; // Useful so that backends don't have to add `burn_common` as
|
||||
// a dependency so that they can implement the traits.
|
||||
|
||||
#[cfg(feature = "benchmark")]
|
||||
/// This module provides benchmark utilities to easily and reliably run
|
||||
/// benches on any function that is generic over a backend.
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
#![allow(clippy::single_range_in_vec_init)]
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use alloc::format;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use alloc::string::String;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use alloc::vec;
|
||||
|
||||
use burn_common::reader::Reader;
|
||||
use core::{fmt::Debug, ops::Range};
|
||||
|
||||
use crate::{
|
||||
|
@ -222,8 +229,6 @@ where
|
|||
|
||||
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the output size is higher than the current tensor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -320,11 +325,25 @@ where
|
|||
Self::new(K::to_device(self.primitive, device))
|
||||
}
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Returns the data of the current tensor.
|
||||
pub fn into_data(self) -> Data<K::Elem, D> {
|
||||
K::into_data(self.primitive)
|
||||
pub async fn into_data(self) -> Data<K::Elem, D> {
|
||||
K::into_data(self.primitive).read().await
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
/// Returns the data of the current tensor.
|
||||
pub fn into_data(self) -> Data<K::Elem, D> {
|
||||
K::into_data(self.primitive).read()
|
||||
}
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Returns the data of the current tensor.
|
||||
pub async fn to_data(&self) -> Data<K::Elem, D> {
|
||||
K::into_data(self.primitive.clone()).read().await
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
/// Returns the data of the current tensor without taking ownership.
|
||||
pub fn to_data(&self) -> Data<K::Elem, D> {
|
||||
Self::into_data(self.clone())
|
||||
|
@ -448,6 +467,7 @@ where
|
|||
K: BasicOps<B>,
|
||||
<K as BasicOps<B>>::Elem: Debug,
|
||||
{
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
/// Recursively formats the tensor data for display and appends it to the provided accumulator string.
|
||||
///
|
||||
/// This function is designed to work with tensors of any dimensionality.
|
||||
|
@ -474,7 +494,8 @@ where
|
|||
multi_index[depth] = i;
|
||||
let range: [core::ops::Range<usize>; D] =
|
||||
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
|
||||
let elem = &self.clone().slice(range).to_data().value[0];
|
||||
|
||||
let elem = &self.clone().slice(range).into_data().value[0];
|
||||
acc.push_str(&format!("{elem:?}"));
|
||||
}
|
||||
} else {
|
||||
|
@ -506,13 +527,18 @@ where
|
|||
{
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
writeln!(f, "Tensor {{")?;
|
||||
write!(f, " data: ")?;
|
||||
|
||||
let mut acc = String::new();
|
||||
let mut multi_index = vec![0; D];
|
||||
self.display_recursive(&mut acc, 0, &mut multi_index);
|
||||
write!(f, "{acc}")?;
|
||||
writeln!(f, ",")?;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
write!(f, " data: ")?;
|
||||
let mut acc = String::new();
|
||||
let mut multi_index = vec![0; D];
|
||||
|
||||
self.display_recursive(&mut acc, 0, &mut multi_index);
|
||||
write!(f, "{acc}")?;
|
||||
writeln!(f, ",")?;
|
||||
}
|
||||
|
||||
writeln!(f, " shape: {:?},", self.dims())?;
|
||||
writeln!(f, " device: {:?},", self.device())?;
|
||||
writeln!(f, " backend: {:?},", B::name())?;
|
||||
|
@ -755,7 +781,7 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
///
|
||||
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D>;
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>>;
|
||||
|
||||
/// Creates a tensor from the given data.
|
||||
///
|
||||
|
@ -914,7 +940,7 @@ impl<B: Backend> BasicOps<B> for Float {
|
|||
B::to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
|
||||
B::into_data(tensor)
|
||||
}
|
||||
|
||||
|
@ -1001,7 +1027,7 @@ impl<B: Backend> BasicOps<B> for Int {
|
|||
B::int_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
|
||||
B::int_into_data(tensor)
|
||||
}
|
||||
|
||||
|
@ -1088,7 +1114,7 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
B::bool_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D> {
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Reader<Data<Self::Elem, D>> {
|
||||
B::bool_into_data(tensor)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ where
|
|||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
/// Convert the tensor into a scalar.
|
||||
///
|
||||
/// # Panics
|
||||
|
@ -19,6 +20,19 @@ where
|
|||
let data = self.into_data();
|
||||
data.value[0]
|
||||
}
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
/// Convert the tensor into a scalar.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the tensor doesn't have one element.
|
||||
pub async fn into_scalar(self) -> K::Elem {
|
||||
check!(TensorCheck::into_scalar(&self.shape()));
|
||||
let data = self.into_data().await;
|
||||
data.value[0]
|
||||
}
|
||||
|
||||
/// Applies element wise addition operation.
|
||||
///
|
||||
/// `y = x2 + x1`
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::{backend::Backend, tensor::Shape, Data};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
||||
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
|
||||
/// for documentation on each function.
|
||||
|
@ -39,7 +39,7 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
|
||||
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Reader<Data<bool, D>>;
|
||||
|
||||
/// Gets the data from the tensor.
|
||||
///
|
||||
|
@ -51,7 +51,7 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data cloned from the data structure.
|
||||
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
|
||||
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Reader<Data<bool, D>> {
|
||||
Self::bool_into_data(tensor.clone())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
||||
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
|
||||
/// for documentation on each function.
|
||||
|
@ -38,7 +38,9 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn int_into_data<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> Data<B::IntElem, D>;
|
||||
fn int_into_data<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
) -> Reader<Data<B::IntElem, D>>;
|
||||
|
||||
/// Gets the data from the tensor.
|
||||
///
|
||||
|
@ -49,7 +51,9 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data cloned from the data structure.
|
||||
fn int_to_data<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Data<B::IntElem, D> {
|
||||
fn int_to_data<const D: usize>(
|
||||
tensor: &B::IntTensorPrimitive<D>,
|
||||
) -> Reader<Data<B::IntElem, D>> {
|
||||
Self::int_into_data(tensor.clone())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
|
||||
/// Operations on float tensors.
|
||||
pub trait TensorOps<B: Backend> {
|
||||
|
@ -104,7 +104,9 @@ pub trait TensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Reader<Data<B::FloatElem, D>> {
|
||||
Self::into_data(tensor.clone())
|
||||
}
|
||||
|
||||
/// Converts the tensor to a data structure.
|
||||
///
|
||||
|
@ -115,9 +117,7 @@ pub trait TensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
|
||||
Self::to_data(&tensor)
|
||||
}
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Reader<Data<B::FloatElem, D>>;
|
||||
|
||||
/// Gets the device of the tensor.
|
||||
///
|
||||
|
|
|
@ -26,7 +26,7 @@ spin = { workspace = true }
|
|||
# WGPU stuff
|
||||
futures-intrusive = { workspace = true }
|
||||
pollster = { workspace = true }
|
||||
wgpu = { workspace = true }
|
||||
wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
|
||||
|
||||
# Template
|
||||
serde = { workspace = true }
|
||||
|
|
|
@ -7,7 +7,7 @@ use burn_compute::{
|
|||
memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
|
||||
Compute,
|
||||
};
|
||||
use wgpu::{DeviceDescriptor, DeviceType};
|
||||
use wgpu::DeviceDescriptor;
|
||||
|
||||
type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
|
||||
type Server = WgpuServer<MemoryManagement>;
|
||||
|
@ -26,41 +26,58 @@ pub fn compute_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Serv
|
|||
let device = Arc::new(device);
|
||||
|
||||
COMPUTE.client(&device, move || {
|
||||
let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(&device));
|
||||
|
||||
log::info!(
|
||||
"Created wgpu compute server on device {:?} => {:?}",
|
||||
device,
|
||||
info
|
||||
);
|
||||
|
||||
// TODO: Support a way to modify max_tasks without std.
|
||||
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
|
||||
Ok(value) => value
|
||||
.parse::<usize>()
|
||||
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
|
||||
Err(_) => 64, // 64 tasks by default
|
||||
};
|
||||
|
||||
let device = Arc::new(device_wgpu);
|
||||
let storage = WgpuStorage::new(device.clone());
|
||||
let memory_management = SimpleMemoryManagement::new(
|
||||
storage,
|
||||
DeallocStrategy::new_period_tick(1000),
|
||||
SliceStrategy::Ratio(0.9),
|
||||
);
|
||||
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
|
||||
let channel = Channel::new(server);
|
||||
|
||||
ComputeClient::new(channel)
|
||||
pollster::block_on(create_client::<G>(&device))
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize the client asynchronously, which is necessary for wasm.
|
||||
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice) {
|
||||
let device = Arc::new(device);
|
||||
let client = create_client::<G>(&device).await;
|
||||
|
||||
COMPUTE.register(&device, client)
|
||||
}
|
||||
|
||||
async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
|
||||
let (device_wgpu, queue, info) = select_device::<G>(device).await;
|
||||
|
||||
log::info!(
|
||||
"Created wgpu compute server on device {:?} => {:?}",
|
||||
device,
|
||||
info
|
||||
);
|
||||
|
||||
// TODO: Support a way to modify max_tasks without std.
|
||||
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
|
||||
Ok(value) => value
|
||||
.parse::<usize>()
|
||||
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
|
||||
Err(_) => 64, // 64 tasks by default
|
||||
};
|
||||
|
||||
let device = Arc::new(device_wgpu);
|
||||
let storage = WgpuStorage::new(device.clone());
|
||||
let memory_management = SimpleMemoryManagement::new(
|
||||
storage,
|
||||
DeallocStrategy::new_period_tick(1000),
|
||||
SliceStrategy::Ratio(0.9),
|
||||
);
|
||||
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
|
||||
let channel = Channel::new(server);
|
||||
|
||||
ComputeClient::new(channel)
|
||||
}
|
||||
|
||||
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
|
||||
pub async fn select_device<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
|
||||
#[cfg(target_family = "wasm")]
|
||||
let adapter = select_adapter::<G>(device).await;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let adapter = select_adapter::<G>(device);
|
||||
|
||||
let limits = adapter.limits();
|
||||
|
||||
let (device, queue) = adapter
|
||||
|
@ -85,9 +102,21 @@ pub async fn select_device<G: GraphicsApi>(
|
|||
(device, queue, adapter.get_info())
|
||||
}
|
||||
|
||||
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
|
||||
#[cfg(target_family = "wasm")]
|
||||
async fn select_adapter<G: GraphicsApi>(_device: &WgpuDevice) -> wgpu::Adapter {
|
||||
let instance = wgpu::Instance::default();
|
||||
|
||||
instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptionsBase::default())
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
|
||||
use wgpu::DeviceType;
|
||||
|
||||
let instance = wgpu::Instance::default();
|
||||
let mut adapters_other = Vec::new();
|
||||
let mut adapters = Vec::new();
|
||||
|
||||
|
|
|
@ -98,7 +98,7 @@ mod tests {
|
|||
|
||||
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
|
||||
|
||||
let data = client.read(&out);
|
||||
let data = client.read(&out).read_sync().unwrap();
|
||||
let output: &[f32] = bytemuck::cast_slice(&data);
|
||||
|
||||
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
|
||||
|
|
|
@ -5,6 +5,7 @@ use burn_compute::{
|
|||
memory_management::MemoryManagement,
|
||||
server::{self, ComputeServer},
|
||||
};
|
||||
use burn_tensor::Reader;
|
||||
use hashbrown::HashMap;
|
||||
use wgpu::{
|
||||
util::{BufferInitDescriptor, DeviceExt},
|
||||
|
@ -175,17 +176,8 @@ where
|
|||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<MM> ComputeServer for WgpuServer<MM>
|
||||
where
|
||||
MM: MemoryManagement<WgpuStorage>,
|
||||
{
|
||||
type Kernel = Box<dyn Kernel>;
|
||||
type Storage = WgpuStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
||||
fn read(&mut self, handle: &server::Handle<Self>) -> Vec<u8> {
|
||||
fn buffer_reader(&mut self, handle: &server::Handle<Self>) -> BufferReader {
|
||||
// Register previous tasks before reading the buffer so that it is up to date.
|
||||
self.register_tasks();
|
||||
|
||||
|
@ -209,7 +201,28 @@ where
|
|||
|
||||
self.submit();
|
||||
|
||||
let buffer_slice = buffer_dest.slice(..);
|
||||
BufferReader::new(buffer_dest)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct BufferReader {
|
||||
buffer: wgpu::Buffer,
|
||||
}
|
||||
|
||||
impl BufferReader {
|
||||
#[cfg(target_family = "wasm")]
|
||||
async fn read(self, device: alloc::sync::Arc<wgpu::Device>) -> Vec<u8> {
|
||||
self.read_async(&device).await
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn read(self, device: &wgpu::Device) -> Vec<u8> {
|
||||
pollster::block_on(self.read_async(device))
|
||||
}
|
||||
|
||||
async fn read_async(&self, device: &wgpu::Device) -> Vec<u8> {
|
||||
let buffer_slice = self.buffer.slice(..);
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
|
||||
sender
|
||||
|
@ -217,21 +230,41 @@ where
|
|||
.expect("Unable to send buffer slice result to async channel.")
|
||||
});
|
||||
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
device.poll(wgpu::Maintain::Wait);
|
||||
|
||||
let result = pollster::block_on(receiver.receive());
|
||||
let result = receiver.receive().await;
|
||||
|
||||
if let Some(Ok(())) = result {
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
let result = bytemuck::cast_slice(&data).to_vec();
|
||||
|
||||
drop(data);
|
||||
buffer_dest.unmap();
|
||||
self.buffer.unmap();
|
||||
result
|
||||
} else {
|
||||
panic!("Unable to read buffer {:?}", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<MM> ComputeServer for WgpuServer<MM>
|
||||
where
|
||||
MM: MemoryManagement<WgpuStorage>,
|
||||
{
|
||||
type Kernel = Box<dyn Kernel>;
|
||||
type Storage = WgpuStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
||||
fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
|
||||
#[cfg(target_family = "wasm")]
|
||||
{
|
||||
let future = self.buffer_reader(handle).read(self.device.clone());
|
||||
return Reader::Future(Box::pin(future));
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
Reader::Concrete(self.buffer_reader(handle).read(&self.device))
|
||||
}
|
||||
|
||||
/// When we create a new handle from existing data, we use custom allocations so that we don't
|
||||
/// have to execute the current pending tasks.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use burn_tensor::Element;
|
||||
|
||||
/// The base element trait for the wgpu backend.
|
||||
pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone
|
||||
pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
|
|
|
@ -6,6 +6,11 @@ use crate::{
|
|||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
pub(crate) const WORKGROUP_DEFAULT: usize = 16;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub(crate) const WORKGROUP_DEFAULT: usize = 32;
|
||||
|
||||
/// Static wgpu kernel to create a [source template](SourceTemplate).
|
||||
pub trait StaticKernelSource: Send + 'static {
|
||||
/// Source template for the kernel.
|
||||
|
@ -49,8 +54,6 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
|
|||
return tensor;
|
||||
}
|
||||
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let handle = tensor.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(
|
||||
|
@ -64,8 +67,11 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
|
|||
|
||||
tensor.client.execute(
|
||||
Box::new(StaticKernel::<
|
||||
KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP))),
|
||||
KernelSettings<ContiguousRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
num_elems,
|
||||
WORKGROUP_DEFAULT,
|
||||
))),
|
||||
&[&tensor.handle, &output.handle, &info_handle],
|
||||
);
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource};
|
||||
use super::{
|
||||
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
};
|
||||
use crate::compute::StaticKernel;
|
||||
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
use burn_tensor::Shape;
|
||||
|
@ -54,7 +56,7 @@ pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: u
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise::<K, E, D, 32>(lhs, rhs)
|
||||
binary_elemwise::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Execute a binary kernel using the provided WORKGROUP.
|
||||
|
@ -105,7 +107,7 @@ pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, co
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
binary_elemwise_inplace::<K, E, D, 32>(lhs, rhs)
|
||||
binary_elemwise_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Execute a binary inplace kernel using the provided WORKGROUP.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{KernelSettings, SourceTemplate, StaticKernelSource};
|
||||
use super::{KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{
|
||||
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -30,12 +30,17 @@ pub fn cast<InputElem: WgpuElement, OutputElem: WgpuElement, const D: usize>(
|
|||
return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
|
||||
}
|
||||
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Cast<InputElem, OutputElem>, f32, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP));
|
||||
KernelSettings<
|
||||
Cast<InputElem, OutputElem>,
|
||||
f32,
|
||||
i32,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
1,
|
||||
>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
let handle = tensor
|
||||
.client
|
||||
|
|
|
@ -6,14 +6,14 @@ use crate::{
|
|||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
use super::WORKGROUP_DEFAULT;
|
||||
|
||||
kernel_wgsl!(Cat, "../template/cat.wgsl");
|
||||
|
||||
pub fn cat<E: WgpuElement, const D: usize>(
|
||||
inputs: Vec<WgpuTensor<E, D>>,
|
||||
dim: usize,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let first_input = inputs.get(0).unwrap();
|
||||
let client = &first_input.client;
|
||||
let mut shape_output = first_input.shape.clone();
|
||||
|
@ -38,9 +38,12 @@ pub fn cat<E: WgpuElement, const D: usize>(
|
|||
info.push(dim_cat_index as u32);
|
||||
dim_cat_index += input.shape.dims[dim];
|
||||
let info_buffer = client.create(bytemuck::cast_slice(&info));
|
||||
let kernel = StaticKernel::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Cat, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
input.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource},
|
||||
kernel::{
|
||||
build_info, elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -66,8 +68,6 @@ pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
|
@ -83,9 +83,10 @@ pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|||
|
||||
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
let info = build_info(&[&lhs, &rhs, &output]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
|
@ -101,13 +102,12 @@ pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
|
||||
);
|
||||
let info = build_info(&[&lhs, &rhs]);
|
||||
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource},
|
||||
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
@ -58,14 +58,14 @@ pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
let num_elems = lhs.shape.num_elements();
|
||||
|
||||
let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);
|
||||
|
@ -77,11 +77,10 @@ pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: u
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<u32, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP_DEFAULT),
|
||||
);
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -19,8 +19,6 @@ pub(crate) fn conv2d<E: WgpuElement + Element>(
|
|||
bias: Option<WgpuTensor<E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let input = kernel::into_contiguous(input);
|
||||
let weight = kernel::into_contiguous(weight);
|
||||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||
|
@ -64,9 +62,12 @@ pub(crate) fn conv2d<E: WgpuElement + Element>(
|
|||
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Conv2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -16,8 +16,6 @@ pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
|
|||
bias: Option<WgpuTensor<E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let input = kernel::into_contiguous(input);
|
||||
let weight = kernel::into_contiguous(weight);
|
||||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||
|
@ -58,10 +56,9 @@ pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
|
|||
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<ConvTranspose2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -14,8 +14,6 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
tensor: WgpuTensor<E, D>,
|
||||
indices: WgpuTensor<I, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let shape_output = indices.shape.clone();
|
||||
let num_elems = shape_output.num_elements();
|
||||
let indices = kernel::into_contiguous(indices);
|
||||
|
@ -25,9 +23,9 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
info.push(dim as u32);
|
||||
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Gather, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
tensor.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
@ -14,8 +14,6 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
indices: WgpuTensor<I, D>,
|
||||
value: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let indices = kernel::into_contiguous(indices);
|
||||
let tensor = kernel::into_contiguous(tensor);
|
||||
let value = kernel::into_contiguous(value);
|
||||
|
@ -51,9 +49,12 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
|
||||
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Scatter, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
num_elems_per_workgroup,
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
tensor.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -18,8 +18,6 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
dim: usize,
|
||||
indices: WgpuTensor<I, 1>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut output_shape = tensor.shape.clone();
|
||||
output_shape.dims[dim] = indices.shape.dims[0];
|
||||
|
||||
|
@ -30,9 +28,9 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
|
|||
info.push(dim as u32);
|
||||
|
||||
let info_handle = output.client.create(bytemuck::cast_slice(&info));
|
||||
let kernel = StaticKernel::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<IndexSelect, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
tensor.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -19,8 +19,6 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
|
|||
tensor: WgpuTensor<E, D1>,
|
||||
indices: [Range<usize>; D2],
|
||||
) -> WgpuTensor<E, D1> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut dims = tensor.shape.dims;
|
||||
for i in 0..D2 {
|
||||
dims[i] = indices[i].end - indices[i].start;
|
||||
|
@ -38,9 +36,9 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
|
|||
|
||||
let info_handle = output.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<IndexRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
tensor.client.execute(
|
||||
Box::new(kernel),
|
||||
|
@ -55,8 +53,6 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
|
|||
indices: [Range<usize>; D2],
|
||||
value: WgpuTensor<E, D1>,
|
||||
) -> WgpuTensor<E, D1> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
|
@ -72,8 +68,8 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
|
|||
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP));
|
||||
KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
tensor.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -15,8 +15,6 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
|
|||
mask: WgpuTensor<u32, D>,
|
||||
value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let num_elems = input.shape.num_elements();
|
||||
let output = empty_device(
|
||||
input.client.clone(),
|
||||
|
@ -25,9 +23,9 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
|
|||
);
|
||||
|
||||
let value_handle = output.client.create(E::as_bytes(&[value]));
|
||||
let kernel = StaticKernel::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskFill, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &mask, &output]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
@ -51,14 +49,11 @@ pub fn mask_fill_inplace<E: WgpuElement, const D: usize>(
|
|||
mask: WgpuTensor<u32, D>,
|
||||
value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let num_elems = input.shape.num_elements();
|
||||
let value_handle = input.client.create(E::as_bytes(&[value]));
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskFillInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &mask]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -15,8 +15,6 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
|
|||
mask: WgpuTensor<u32, D>,
|
||||
value: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let num_elems = input.shape.num_elements();
|
||||
let output = empty_device(
|
||||
input.client.clone(),
|
||||
|
@ -24,9 +22,9 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
|
|||
input.shape.clone(),
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskWhere, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let info = build_info(&[&input, &value, &mask, &output]);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
@ -51,12 +49,12 @@ pub fn mask_where_inplace<E: WgpuElement, const D: usize>(
|
|||
value: WgpuTensor<E, D>,
|
||||
reverse: bool,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaskWhereInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
input.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
|
||||
let mut info = build_info(&[&input, &value, &mask]);
|
||||
info.push(match reverse {
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::{
|
|||
element::WgpuElement,
|
||||
kernel::{
|
||||
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
|
||||
WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
|
@ -43,7 +44,7 @@ pub fn matmul_mem_coalescing_default<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
matmul_mem_coalescing::<E, D>(lhs, rhs, 16, 16)
|
||||
matmul_mem_coalescing::<E, D>(lhs, rhs, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
|
||||
}
|
||||
|
||||
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compute::{StaticKernel, WgpuHandle},
|
||||
element::WgpuElement,
|
||||
kernel::{elemwise_workgroup, KernelSettings},
|
||||
kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -21,17 +21,17 @@ pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
|
|||
x: WgpuTensor<E, 4>,
|
||||
output_size: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let [batch_size, channels, _, _] = x.shape.dims;
|
||||
|
||||
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
let info_handle = build_info(&x, &output);
|
||||
x.client
|
||||
|
@ -44,8 +44,6 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
|
|||
x: WgpuTensor<E, 4>,
|
||||
out_grad: WgpuTensor<E, 4>,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let output_shape = x.shape.clone();
|
||||
let num_elems = output_shape.num_elements();
|
||||
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
|
@ -57,8 +55,11 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
|
|||
);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
|
||||
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
let info_handle = build_info(&x, &out_grad);
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
kernel::{
|
||||
self, elemwise_workgroup,
|
||||
pool::{build_output_and_info_pool2d, build_pool2d_info},
|
||||
KernelSettings, StaticKernelSource,
|
||||
KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
|
@ -39,18 +39,16 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
|
|||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_handle, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
|
||||
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
|
||||
let kernel: Box<dyn Kernel> = match count_include_pad {
|
||||
true => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup)),
|
||||
false => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup)),
|
||||
};
|
||||
|
||||
|
@ -68,19 +66,31 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
|
|||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let grad = kernel::into_contiguous(grad);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
|
||||
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
|
||||
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT);
|
||||
|
||||
let kernel: Box<dyn Kernel> = match count_include_pad {
|
||||
true => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2dBackward<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
KernelSettings<
|
||||
AvgPool2dBackward<true>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
1,
|
||||
>,
|
||||
>::new(workgroup)),
|
||||
false => Box::new(StaticKernel::<
|
||||
KernelSettings<AvgPool2dBackward<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
KernelSettings<
|
||||
AvgPool2dBackward<false>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
1,
|
||||
>,
|
||||
>::new(workgroup)),
|
||||
};
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
kernel::{
|
||||
self, elemwise_workgroup,
|
||||
pool::{build_output_and_info_pool2d, build_pool2d_info},
|
||||
KernelSettings,
|
||||
KernelSettings, WORKGROUP_DEFAULT,
|
||||
},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
|
@ -28,13 +28,14 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_handle, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
|
||||
let kernel = StaticKernel::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaxPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
x.client
|
||||
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
|
||||
|
@ -49,15 +50,16 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let (info_handle, output) =
|
||||
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
|
||||
let indices = empty_device(x.client.clone(), x.device, output.shape.clone());
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
|
||||
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
x.client.execute(
|
||||
Box::new(kernel),
|
||||
|
@ -76,8 +78,6 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> WgpuTensor<E, 4> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let grad = kernel::into_contiguous(grad);
|
||||
let indices = kernel::into_contiguous(indices);
|
||||
|
||||
|
@ -88,8 +88,11 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
|
|||
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP, WORKGROUP, 1>,
|
||||
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
|
||||
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
output.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
x.client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
element::WgpuElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -34,18 +34,16 @@ pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
|||
device: &WgpuDevice,
|
||||
prob: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
let client = compute_client::<G>(device);
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer(client.clone(), &[prob]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
workgroup,
|
||||
);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<BernoulliPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
element::WgpuElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -37,16 +37,16 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
|||
mean: E,
|
||||
std: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
const N_VALUES_PER_THREAD: usize = 128; // must be even
|
||||
|
||||
let client = compute_client::<G>(device);
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer(client.clone(), &[mean, std]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(workgroup);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<NormalPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
element::WgpuElement,
|
||||
kernel::{
|
||||
prng::base::{make_args_buffer, make_info_buffer},
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
|
||||
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -32,17 +32,16 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
|||
low: E,
|
||||
high: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
const N_VALUES_PER_THREAD: usize = 128;
|
||||
|
||||
let client = compute_client::<G>(device);
|
||||
let output = empty_device(client.clone(), device.clone(), shape.clone());
|
||||
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
|
||||
let args_handle = make_args_buffer(client.clone(), &[low, high]);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
workgroup,
|
||||
);
|
||||
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<UniformPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource};
|
||||
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{
|
||||
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
|
@ -50,10 +50,8 @@ impl StaticKernelSource for ArgsMin {
|
|||
|
||||
/// Sum all elements in the input buffer.
|
||||
pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut input_handle = input.handle;
|
||||
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);
|
||||
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP_DEFAULT);
|
||||
|
||||
loop {
|
||||
let num_invocations = workgroup.num_invocations();
|
||||
|
@ -61,10 +59,9 @@ pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTenso
|
|||
.client
|
||||
.empty(core::mem::size_of::<E>() * num_invocations);
|
||||
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
workgroup,
|
||||
);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(workgroup);
|
||||
|
||||
input
|
||||
.client
|
||||
|
@ -75,7 +72,7 @@ pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTenso
|
|||
}
|
||||
|
||||
input_handle = handle;
|
||||
workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
|
||||
workgroup = elemwise_workgroup(num_invocations, WORKGROUP_DEFAULT);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,8 +96,6 @@ fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|||
input: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut shape_out = input.shape.clone();
|
||||
shape_out.dims[dim] = 1;
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
@ -112,9 +107,10 @@ fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
|||
handle,
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
|
||||
let mut info = build_info(&[&input, &output]);
|
||||
info.push(dim as u32);
|
||||
|
@ -148,8 +144,6 @@ fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, con
|
|||
input: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> WgpuTensor<I, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut shape_out = input.shape.clone();
|
||||
shape_out.dims[dim] = 1;
|
||||
let num_elems = shape_out.num_elements();
|
||||
|
@ -161,9 +155,10 @@ fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, con
|
|||
buffer,
|
||||
);
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let kernel =
|
||||
StaticKernel::<KernelSettings<K, E, I, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP_DEFAULT),
|
||||
);
|
||||
let mut info = build_info(&[&input, &output]);
|
||||
info.push(dim as u32);
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
|
||||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
|
||||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
||||
|
@ -98,14 +98,14 @@ macro_rules! unary_inplace {
|
|||
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary::<K, E, D, 32>(input)
|
||||
unary::<K, E, D, WORKGROUP_DEFAULT>(input)
|
||||
}
|
||||
|
||||
/// Execute a unary inplace kernel using the default settings.
|
||||
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_inplace::<K, E, D, 32>(input)
|
||||
unary_inplace::<K, E, D, WORKGROUP_DEFAULT>(input)
|
||||
}
|
||||
|
||||
/// Execute a unary inplace kernel using the provided WORKGROUP.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
|
||||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
|
||||
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
|
||||
|
@ -108,7 +108,7 @@ pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usiz
|
|||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar::<K, E, D, 32>(lhs, scalar)
|
||||
unary_scalar::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
||||
}
|
||||
|
||||
/// Execute a unary scalar kernel using the provided WORKGROUP.
|
||||
|
@ -142,7 +142,7 @@ pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const
|
|||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar_inplace::<K, E, D, 32>(lhs, scalar)
|
||||
unary_scalar_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
||||
}
|
||||
|
||||
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi,
|
||||
WgpuDevice,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Data, Shape};
|
||||
use burn_tensor::{backend::Backend, Data, Reader, Shape};
|
||||
|
||||
pub type FloatElem<B> = <B as Backend>::FloatElem;
|
||||
pub type Device<B> = <B as Backend>::Device;
|
||||
|
@ -25,12 +25,24 @@ pub fn from_data<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
|||
WgpuTensor::new(client, device.clone(), data.shape, buffer)
|
||||
}
|
||||
|
||||
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
|
||||
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Reader<Data<E, D>> {
|
||||
let tensor = kernel::into_contiguous(tensor);
|
||||
let bytes = tensor.client.read(&tensor.handle);
|
||||
let values = E::from_bytes(&bytes);
|
||||
|
||||
Data::new(values.to_vec(), tensor.shape)
|
||||
tensor
|
||||
.client
|
||||
.read(&tensor.handle)
|
||||
.map(|bytes| Data::new(E::from_bytes(&bytes).to_vec(), tensor.shape))
|
||||
}
|
||||
|
||||
pub fn bool_into_data<const D: usize>(tensor: WgpuTensor<u32, D>) -> Reader<Data<bool, D>> {
|
||||
let tensor = kernel::into_contiguous(tensor);
|
||||
|
||||
tensor.client.read(&tensor.handle).map(|bytes| {
|
||||
Data::new(
|
||||
u32::from_bytes(&bytes).iter().map(|i| *i != 0).collect(),
|
||||
tensor.shape,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_device<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
||||
|
|
|
@ -5,7 +5,8 @@ use crate::{
|
|||
tensor::WgpuTensor,
|
||||
GraphicsApi, WgpuBackend,
|
||||
};
|
||||
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
|
||||
use burn_tensor::{ops::BoolTensorOps, Data, Shape};
|
||||
use burn_tensor::{ops::IntTensorOps, Reader};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<G, F, I> BoolTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
|
||||
|
@ -22,10 +23,8 @@ where
|
|||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Data<bool, D> {
|
||||
let data = super::into_data(tensor);
|
||||
|
||||
Data::new(data.value.into_iter().map(|i| i != 0).collect(), data.shape)
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<Self, D>) -> Reader<Data<bool, D>> {
|
||||
super::bool_into_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_from_data<const D: usize>(
|
||||
|
@ -51,7 +50,10 @@ where
|
|||
}
|
||||
|
||||
let device = Self::bool_device(&tensor);
|
||||
let data = Self::bool_into_data(tensor).convert::<I>();
|
||||
let data = Self::bool_into_data(tensor)
|
||||
.read_sync()
|
||||
.expect("Can't convert bool to int with a different type size async")
|
||||
.convert::<I>();
|
||||
|
||||
Self::int_from_data(data, &device)
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ use crate::{
|
|||
element::{FloatElement, IntElement},
|
||||
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
|
||||
use burn_tensor::{ElementConversion, Reader};
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
|
@ -48,11 +48,7 @@ where
|
|||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn to_data<const D: usize>(tensor: &FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
|
||||
super::into_data(tensor.clone())
|
||||
}
|
||||
|
||||
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Data<FloatElem<Self>, D> {
|
||||
fn into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> Reader<Data<FloatElem<Self>, D>> {
|
||||
super::into_data(tensor)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@ use crate::{
|
|||
element::{FloatElement, IntElement},
|
||||
kernel, unary, unary_inplace, GraphicsApi, WgpuBackend,
|
||||
};
|
||||
|
||||
use burn_tensor::Reader;
|
||||
use burn_tensor::{ops::IntTensorOps, Data, Shape};
|
||||
use std::ops::Range;
|
||||
|
||||
|
@ -21,7 +23,7 @@ where
|
|||
tensor.shape.clone()
|
||||
}
|
||||
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
|
||||
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Reader<Data<I, D>> {
|
||||
super::into_data(tensor)
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ fn main(
|
|||
|
||||
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
|
||||
let b = ((z << s1) ^ z) >> s2;
|
||||
return (z & m) << s3 ^ b;
|
||||
return ((z & m) << s3) ^ b;
|
||||
}
|
||||
|
||||
fn taus_step_0(z: u32) -> u32 {
|
||||
|
|
|
@ -94,8 +94,12 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
|||
|
||||
/// Change the context of the current tensor and return the newly transferred tensor.
|
||||
pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self {
|
||||
let data = self.client.read(&self.handle);
|
||||
let handle = client.create(&data);
|
||||
let bytes = self
|
||||
.client
|
||||
.read(&self.handle)
|
||||
.read_sync()
|
||||
.expect("Can only change client synchronously");
|
||||
let handle = client.create(&bytes);
|
||||
|
||||
Self {
|
||||
client,
|
||||
|
@ -106,6 +110,7 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
|||
elem: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
|
||||
if !self.handle.can_mut() {
|
||||
return false;
|
||||
|
|
|
@ -44,7 +44,6 @@ ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"]
|
|||
ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
|
||||
|
||||
wgpu = ["burn-core/wgpu"]
|
||||
|
||||
tch = ["burn-core/tch"]
|
||||
|
||||
# Experimental
|
||||
|
|
|
@ -10,9 +10,17 @@ version = "0.10.0"
|
|||
crate-type = ["cdylib"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["ndarray"]
|
||||
|
||||
ndarray = ["burn/ndarray-no-std"]
|
||||
wgpu = ["burn/wgpu"]
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn", default-features = false, features = ["ndarray-no-std"]}
|
||||
burn = {path = "../../burn", default-features = false}
|
||||
serde = {workspace = true}
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen = { version = "0.2.87" }
|
||||
wasm-bindgen-futures = "0.4"
|
||||
js-sys = "0.3.64"
|
||||
|
||||
[dev-dependencies]
|
||||
pollster = { workspace = true }
|
||||
|
|
|
@ -9,9 +9,11 @@ This crate demonstrates how to run an MNIST-trained model in the browser for inf
|
|||
1. Build
|
||||
|
||||
```shell
|
||||
./build-for-web.sh
|
||||
./build-for-web.sh {backend}
|
||||
```
|
||||
|
||||
The backend can be either `ndarray` or `wgpu`. Note that `wgpu` only works in browsers that support WebGPU.
|
||||
|
||||
2. Run the server
|
||||
|
||||
```shell
|
||||
|
|
|
@ -10,9 +10,9 @@ then
|
|||
fi
|
||||
|
||||
# Set optimization flags
|
||||
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3"
|
||||
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis"
|
||||
|
||||
# Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory.
|
||||
mkdir -p pkg
|
||||
wasm-pack build --out-dir pkg --release --target web --no-typescript
|
||||
wasm-pack build --out-dir pkg --release --target web --no-typescript --no-default-features --features $1
|
||||
|
||||
|
|
|
@ -123,13 +123,13 @@
|
|||
wasm().then((module) => {
|
||||
const mnist = new Mnist();
|
||||
|
||||
function fireOffInference() {
|
||||
async function fireOffInference() {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = setTimeout(() => {
|
||||
timeoutId = setTimeout(async () => {
|
||||
isTimeOutSet = true;
|
||||
fabricCanvas.freeDrawingBrush._finalizeAndAddPath();
|
||||
const data = cropScaleGetImageData(mainContext, cropContext, scaledContext);
|
||||
const output = mnist.inference(data);
|
||||
const output = await mnist.inference(data);
|
||||
chart.data.datasets[0].data = output;
|
||||
chart.update();
|
||||
isTimeOutSet = false;
|
||||
|
@ -140,14 +140,14 @@
|
|||
fabricCanvas.on("mouse:down", function (event) {
|
||||
isDrawing = true;
|
||||
});
|
||||
fabricCanvas.on("mouse:up", function (event) {
|
||||
fabricCanvas.on("mouse:up", async function (event) {
|
||||
isDrawing = false;
|
||||
fireOffInference();
|
||||
await fireOffInference();
|
||||
});
|
||||
|
||||
fabricCanvas.on("mouse:move", function (event) {
|
||||
fabricCanvas.on("mouse:move", async function (event) {
|
||||
if (isDrawing && isTimeOutSet == false) {
|
||||
fireOffInference();
|
||||
await fireOffInference();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
use crate::model::Model;
|
||||
use burn::backend::ndarray::NdArrayBackend;
|
||||
use burn::module::Module;
|
||||
use burn::record::BinBytesRecorder;
|
||||
use burn::record::FullPrecisionSettings;
|
||||
use burn::record::Recorder;
|
||||
|
||||
pub type Backend = NdArrayBackend<f32>;
|
||||
#[cfg(feature = "wgpu")]
|
||||
use burn::backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuBackend, WgpuDevice};
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type Backend = burn::backend::ndarray::NdArrayBackend<f32>;
|
||||
|
||||
static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
|
||||
|
||||
/// Builds and loads trained parameters into the model.
|
||||
pub fn build_and_load_model() -> Model<Backend> {
|
||||
pub async fn build_and_load_model() -> Model<Backend> {
|
||||
#[cfg(feature = "wgpu")]
|
||||
init_async::<AutoGraphicsApi>(&WgpuDevice::default()).await;
|
||||
|
||||
let model: Model<Backend> = Model::new();
|
||||
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
|
||||
.load(STATE_ENCODED.to_vec())
|
||||
|
|
|
@ -1,29 +1,29 @@
|
|||
#![allow(clippy::new_without_default)]
|
||||
|
||||
use alloc::{boxed::Box, string::String};
|
||||
use alloc::string::String;
|
||||
use js_sys::Array;
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
use crate::model::Model;
|
||||
use crate::state::{build_and_load_model, Backend};
|
||||
|
||||
use burn::tensor::Tensor;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Mnist structure that corresponds to JavaScript class.
|
||||
/// See: [exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html)
|
||||
#[wasm_bindgen]
|
||||
#[cfg_attr(target_family = "wasm", wasm_bindgen)]
|
||||
pub struct Mnist {
|
||||
model: Model<Backend>,
|
||||
model: Option<Model<Backend>>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
#[cfg_attr(target_family = "wasm", wasm_bindgen)]
|
||||
impl Mnist {
|
||||
/// Constructor called by JavaScript with the `new` keyword.
|
||||
#[wasm_bindgen(constructor)]
|
||||
#[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
model: build_and_load_model(),
|
||||
}
|
||||
Self { model: None }
|
||||
}
|
||||
|
||||
/// Returns the inference results.
|
||||
|
@ -38,7 +38,13 @@ impl Mnist {
|
|||
/// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html)
|
||||
/// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html)
|
||||
///
|
||||
pub fn inference(&self, input: &[f32]) -> Result<Box<[f32]>, String> {
|
||||
pub async fn inference(&mut self, input: &[f32]) -> Result<Array, String> {
|
||||
if self.model.is_none() {
|
||||
self.model = Some(build_and_load_model().await);
|
||||
}
|
||||
|
||||
let model = self.model.as_ref().unwrap();
|
||||
|
||||
// Reshape from the 1D array to 3d tensor [batch, height, width]
|
||||
let input: Tensor<Backend, 3> = Tensor::from_floats(input).reshape([1, 28, 28]);
|
||||
|
||||
|
@ -49,77 +55,23 @@ impl Mnist {
|
|||
let input = ((input / 255) - 0.1307) / 0.3081;
|
||||
|
||||
// Run the tensor input through the model
|
||||
let output: Tensor<Backend, 2> = self.model.forward(input);
|
||||
let output: Tensor<Backend, 2> = model.forward(input);
|
||||
|
||||
// Convert the model output into probability distribution using softmax formula
|
||||
let output: Tensor<Backend, 2> = output.clone().exp() / output.exp().sum_dim(1);
|
||||
let output = burn::tensor::activation::softmax(output, 1);
|
||||
|
||||
// Flatten output tensor with [1, 10] shape into boxed slice of [f32]
|
||||
Ok(output.to_data().value.into_boxed_slice())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::Mnist;
|
||||
|
||||
#[test]
|
||||
fn inference_manual_from_test_data() {
|
||||
let mnist = Mnist::new();
|
||||
let input: Vec<f32> = vec![
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 22.0, 97.0, 181.0, 254.0, 255.0, 221.0, 106.0, 3.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 28.0, 128.0, 213.0,
|
||||
245.0, 254.0, 254.0, 246.0, 239.0, 254.0, 254.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 27.0, 151.0, 239.0, 254.0, 254.0, 222.0,
|
||||
204.0, 189.0, 70.0, 27.0, 215.0, 254.0, 98.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 43.0, 248.0, 254.0, 233.0, 40.0, 15.0, 0.0, 0.0,
|
||||
0.0, 0.0, 96.0, 254.0, 111.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 121.0, 72.0, 33.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 76.0, 254.0,
|
||||
187.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 94.0, 254.0, 149.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 124.0, 254.0, 40.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 203.0,
|
||||
174.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 102.0, 254.0, 122.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 40.0, 223.0, 205.0, 254.0, 115.0, 33.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 248.0,
|
||||
254.0, 254.0, 254.0, 242.0, 191.0, 71.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 140.0, 254.0, 254.0, 248.0,
|
||||
209.0, 254.0, 213.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 186.0, 208.0, 71.0, 49.0, 74.0, 191.0, 134.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 134.0, 254.0, 119.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 197.0,
|
||||
168.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 140.0, 225.0, 43.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 51.0, 233.0, 106.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 161.0,
|
||||
227.0, 17.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 228.0, 124.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 168.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
|
||||
let output = mnist.inference(input.as_slice()).unwrap();
|
||||
|
||||
assert!(output[7] > 0.9);
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let output = output.into_data().convert::<f32>().value;
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
let output = output.into_data().await.convert::<f32>().value;
|
||||
|
||||
let array = Array::new();
|
||||
for value in output {
|
||||
array.push(&value.into());
|
||||
}
|
||||
|
||||
Ok(array)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue