From 95e660488eab4485581d32d8a22b6ce2e13fd228 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 25 Sep 2023 10:42:45 -0400 Subject: [PATCH] Refactor/burn compute wgpu (#826) --- burn-autodiff/src/backend.rs | 4 + burn-compute/src/channel/base.rs | 2 +- burn-compute/src/channel/cell.rs | 1 + burn-compute/src/channel/mpsc.rs | 2 + burn-compute/src/channel/mutex.rs | 1 + burn-compute/src/client.rs | 1 + burn-compute/src/id.rs | 2 +- burn-compute/src/memory_management/base.rs | 30 +- burn-compute/src/memory_management/simple.rs | 193 +++++++-- burn-compute/src/server.rs | 33 +- burn-compute/src/storage/bytes_cpu.rs | 6 + burn-compute/tests/dummy/compute.rs | 5 +- burn-compute/tests/dummy/kernel.rs | 4 +- burn-compute/tests/dummy/server.rs | 10 +- burn-core/Cargo.toml | 4 +- burn-core/src/record/memory.rs | 5 +- burn-core/src/record/tensor.rs | 5 +- burn-tch/src/backend.rs | 6 + burn-tensor/src/tensor/backend/base.rs | 3 + burn-wgpu/Cargo.toml | 13 +- burn-wgpu/benches/matmul.rs | 28 +- burn-wgpu/src/backend.rs | 6 + burn-wgpu/src/benchmark.rs | 11 +- burn-wgpu/src/compute/base.rs | 244 +++++++---- burn-wgpu/src/compute/kernel.rs | 106 +++++ burn-wgpu/src/compute/mod.rs | 2 + burn-wgpu/src/compute/server.rs | 90 +++- burn-wgpu/src/compute/storage.rs | 26 +- burn-wgpu/src/context/base.rs | 410 ------------------ burn-wgpu/src/context/client.rs | 190 -------- burn-wgpu/src/context/mod.rs | 6 - burn-wgpu/src/context/server.rs | 269 ------------ burn-wgpu/src/kernel/base.rs | 64 +-- burn-wgpu/src/kernel/binary_elemwise.rs | 65 ++- burn-wgpu/src/kernel/cast.rs | 50 ++- burn-wgpu/src/kernel/cat.rs | 27 +- burn-wgpu/src/kernel/comparison/binary.rs | 65 ++- burn-wgpu/src/kernel/comparison/elem.rs | 52 +-- burn-wgpu/src/kernel/conv/conv2d.rs | 47 +- burn-wgpu/src/kernel/conv/conv_transpose2d.rs | 51 ++- burn-wgpu/src/kernel/index/gather.rs | 30 +- burn-wgpu/src/kernel/index/scatter.rs | 18 +- burn-wgpu/src/kernel/index/select.rs | 51 +-- burn-wgpu/src/kernel/index/slice.rs | 46 +- burn-wgpu/src/kernel/mask/mask_fill.rs | 64 ++- burn-wgpu/src/kernel/mask/mask_where.rs | 60 ++- burn-wgpu/src/kernel/matmul/mem_coalescing.rs | 44 +- burn-wgpu/src/kernel/matmul/mod.rs | 2 - burn-wgpu/src/kernel/matmul/naive.rs | 55 ++- burn-wgpu/src/kernel/matmul/tiling2d/base.rs | 69 ++- .../src/kernel/matmul/tiling2d/contiguous.rs | 2 +- .../matmul/tiling2d/contiguous_vectorized.rs | 2 +- .../src/kernel/matmul/tiling2d/padding.rs | 12 +- burn-wgpu/src/kernel/matmul/tiling2d/tile.rs | 2 +- .../kernel/matmul/tiling2d/tile_vectorized.rs | 2 +- burn-wgpu/src/kernel/matmul/tune.rs | 396 ----------------- .../src/kernel/pool/adaptive_avg_pool2d.rs | 68 ++- burn-wgpu/src/kernel/pool/avg_pool2d.rs | 87 ++-- burn-wgpu/src/kernel/pool/base.rs | 21 +- burn-wgpu/src/kernel/pool/max_pool2d.rs | 65 +-- burn-wgpu/src/kernel/prng/base.rs | 33 +- burn-wgpu/src/kernel/prng/bernoulli.rs | 33 +- burn-wgpu/src/kernel/prng/normal.rs | 31 +- burn-wgpu/src/kernel/prng/uniform.rs | 32 +- burn-wgpu/src/kernel/reduction.rs | 126 +++--- burn-wgpu/src/kernel/unary.rs | 80 ++-- burn-wgpu/src/kernel/unary_scalar.rs | 93 ++-- burn-wgpu/src/lib.rs | 11 +- burn-wgpu/src/ops/base.rs | 25 +- burn-wgpu/src/ops/bool_ops.rs | 4 +- burn-wgpu/src/ops/float_ops.rs | 8 +- burn-wgpu/src/ops/int_ops.rs | 2 +- burn-wgpu/src/ops/numeric.rs | 32 +- burn-wgpu/src/pool.rs | 63 --- burn-wgpu/src/tensor/base.rs | 72 +-- burn-wgpu/src/tune/base.rs | 196 --------- burn-wgpu/src/tune/mod.rs | 3 - burn/Cargo.toml | 1 - examples/custom-wgpu-kernel/src/forward.rs | 44 +- 79 files changed, 1460 insertions(+), 2664 deletions(-) create mode 100644 burn-wgpu/src/compute/kernel.rs delete mode 100644 burn-wgpu/src/context/base.rs delete mode 100644 burn-wgpu/src/context/client.rs delete mode 100644 burn-wgpu/src/context/mod.rs delete mode 100644 burn-wgpu/src/context/server.rs delete mode 100644 burn-wgpu/src/kernel/matmul/tune.rs delete mode 100644 burn-wgpu/src/pool.rs delete mode 100644 burn-wgpu/src/tune/base.rs delete mode 100644 burn-wgpu/src/tune/mod.rs diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index 20929720f..6aaa281d1 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -32,6 +32,10 @@ impl Backend for ADBackendDecorator { fn seed(seed: u64) { B::seed(seed) } + + fn sync(device: &B::Device) { + B::sync(device); + } } impl ADBackend for ADBackendDecorator { diff --git a/burn-compute/src/channel/base.rs b/burn-compute/src/channel/base.rs index 23180cdb7..2ed11c349 100644 --- a/burn-compute/src/channel/base.rs +++ b/burn-compute/src/channel/base.rs @@ -3,7 +3,7 @@ use alloc::vec::Vec; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety -pub trait ComputeChannel: Clone { +pub trait ComputeChannel: Clone + core::fmt::Debug { /// Given a handle, returns owned resource as bytes fn read(&self, handle: &Handle) -> Vec; diff --git a/burn-compute/src/channel/cell.rs b/burn-compute/src/channel/cell.rs index ada9f0665..7af3e44b4 100644 --- a/burn-compute/src/channel/cell.rs +++ b/burn-compute/src/channel/cell.rs @@ -12,6 +12,7 @@ use alloc::vec::Vec; /// /// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer /// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. +#[derive(Debug)] pub struct RefCellComputeChannel { server: Arc>, } diff --git a/burn-compute/src/channel/mpsc.rs b/burn-compute/src/channel/mpsc.rs index 63392e0f0..1aa18a0bb 100644 --- a/burn-compute/src/channel/mpsc.rs +++ b/burn-compute/src/channel/mpsc.rs @@ -8,6 +8,7 @@ use crate::server::{ComputeServer, Handle}; /// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with /// the compute server spawn on its own thread. +#[derive(Debug)] pub struct MpscComputeChannel where Server: ComputeServer, @@ -15,6 +16,7 @@ where state: Arc>, } +#[derive(Debug)] struct MpscComputeChannelState where Server: ComputeServer, diff --git a/burn-compute/src/channel/mutex.rs b/burn-compute/src/channel/mutex.rs index 369b365fa..3c98c8686 100644 --- a/burn-compute/src/channel/mutex.rs +++ b/burn-compute/src/channel/mutex.rs @@ -6,6 +6,7 @@ use spin::Mutex; /// The MutexComputeChannel ensures thread-safety by locking the server /// on every operation +#[derive(Debug)] pub struct MutexComputeChannel { server: Arc>, } diff --git a/burn-compute/src/client.rs b/burn-compute/src/client.rs index 422603118..4d576cbb0 100644 --- a/burn-compute/src/client.rs +++ b/burn-compute/src/client.rs @@ -7,6 +7,7 @@ use core::marker::PhantomData; /// The ComputeClient is the entry point to require tasks from the ComputeServer. /// It should be obtained for a specific device via the Compute struct. +#[derive(Debug)] pub struct ComputeClient { channel: Channel, _server: PhantomData, diff --git a/burn-compute/src/id.rs b/burn-compute/src/id.rs index 1c71ccd84..33ba53c04 100644 --- a/burn-compute/src/id.rs +++ b/burn-compute/src/id.rs @@ -29,7 +29,7 @@ macro_rules! storage_id_type { /// Create a new memory ID type. macro_rules! memory_id_type { ($name:ident) => { - #[derive(Clone, Hash, PartialEq, Eq)] + #[derive(Clone, Hash, PartialEq, Eq, Debug)] /// Memory ID. pub struct $name { id: alloc::sync::Arc, diff --git a/burn-compute/src/memory_management/base.rs b/burn-compute/src/memory_management/base.rs index bf0203290..4a6310cf9 100644 --- a/burn-compute/src/memory_management/base.rs +++ b/burn-compute/src/memory_management/base.rs @@ -5,7 +5,7 @@ use crate::storage::ComputeStorage; /// /// It is responsible for determining if the memory segment can be mutated, /// for instance by keeping track of a reference count -pub trait MemoryHandle: Clone + Send { +pub trait MemoryHandle: Clone + Send + core::fmt::Debug { /// Checks if the underlying memory can be safely mutated. fn can_mut(&self) -> bool; } @@ -15,7 +15,7 @@ pub trait MemoryHandle: Clone + Send { /// /// The MemoryManagement can only reserve memory space or get the resource located at a space. /// Modification of the resource data should be done directly on the resource. -pub trait MemoryManagement: Send { +pub trait MemoryManagement: Send + core::fmt::Debug { /// The associated type Handle must implement MemoryHandle type Handle: MemoryHandle; @@ -24,4 +24,30 @@ pub trait MemoryManagement: Send { /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it fn reserve(&mut self, size: usize) -> Self::Handle; + + /// Bypass the memory allocation algorithm to allocate data directly. + /// + /// # Notes + /// + /// Can be useful for servers that want specific control over memory. + fn alloc(&mut self, size: usize) -> Self::Handle; + + /// Bypass the memory allocation algorithm to deallocate data directly. + /// + /// # Notes + /// + /// Can be useful for servers that want specific control over memory. + fn dealloc(&mut self, handle: &Self::Handle); + + /// Fetch the storage used by the memory manager. + /// + /// # Notes + /// + /// The storage should probably not be used for allocations since the handles won't be + /// compatible with the ones provided by the current trait. Prefer using the + /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions. + /// + /// This is useful if you need to time the deallocations based on async computation, or to + /// change the mode of storage for different reasons. + fn storage(&mut self) -> &mut Storage; } diff --git a/burn-compute/src/memory_management/simple.rs b/burn-compute/src/memory_management/simple.rs index 6cad5a062..e6bb4fb37 100644 --- a/burn-compute/src/memory_management/simple.rs +++ b/burn-compute/src/memory_management/simple.rs @@ -26,7 +26,7 @@ impl SliceId { } /// The SimpleHandle is a memory handle, referring to either a chunk or a slice. -#[derive(Clone)] +#[derive(Debug, Clone)] pub enum SimpleHandle { /// A whole chunk of memory. Chunk(ChunkId), @@ -35,34 +35,72 @@ pub enum SimpleHandle { } /// The strategy defines the frequency at which deallocation of unused memory chunks should occur. +#[derive(Debug)] pub enum DeallocStrategy { /// Once every n calls to reserve. - /// - /// First associated data is n, second is the state and should start at 0 - PeriodTick(usize, usize), + PeriodTick { + /// Number of calls to be executed before triggering the deallocation. + period: usize, + /// Current state. Should start at zero. + state: usize, + }, #[cfg(feature = "std")] /// Once every period of time - PeriodTime(std::time::Duration, std::time::Instant), + PeriodTime { + /// Number of time before triggering the deallocation. + period: std::time::Duration, + /// Current state. Should start at now. + state: std::time::Instant, + }, /// Never deallocate. Never, } +/// The strategy defines when to reuse chunk with slices. +#[derive(Debug)] +pub enum SliceStrategy { + /// Never use slices. + Never, + /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. + Ratio(f32), + /// When the reserved memory is at least {} bytes. + MinimumSize(usize), + /// When the reserved memory less than {} bytes. + MaximumSize(usize), +} + +impl SliceStrategy { + /// If the chunk can be used with a slice. + pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { + if chunk_size < reserved_size { + return false; + } + + match self { + SliceStrategy::Never => false, + SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, + SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, + SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, + } + } +} + impl DeallocStrategy { /// Create a new strategy with the given period. pub fn new_period_tick(period: usize) -> Self { - DeallocStrategy::PeriodTick(period, 0) + DeallocStrategy::PeriodTick { period, state: 0 } } fn should_dealloc(&mut self) -> bool { match self { - DeallocStrategy::PeriodTick(period, last) => { - *last = (*last + 1) % *period; - *last == 0 + DeallocStrategy::PeriodTick { period, state } => { + *state = (*state + 1) % *period; + *state == 0 } #[cfg(feature = "std")] - DeallocStrategy::PeriodTime(period, last) => { - if &last.elapsed() > period { - *last = std::time::Instant::now(); + DeallocStrategy::PeriodTime { period, state } => { + if &state.elapsed() > period { + *state = std::time::Instant::now(); true } else { false @@ -78,9 +116,23 @@ pub struct SimpleMemoryManagement { chunks: HashMap)>, slices: HashMap, dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, storage: Storage, } +impl core::fmt::Debug for SimpleMemoryManagement { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + alloc::format!( + "SimpleMemoryManagement {:?} - {:?}", + self.dealloc_strategy, + core::any::type_name::(), + ) + .as_str(), + ) + } +} + impl MemoryHandle for SimpleHandle { /// Returns true if referenced by only one tensor, and only once by the /// memory management hashmaps @@ -126,24 +178,43 @@ impl MemoryManagement for SimpleMemoryManageme handle } + + fn alloc(&mut self, size: usize) -> Self::Handle { + self.create_chunk(size) + } + + fn dealloc(&mut self, handle: &Self::Handle) { + match handle { + SimpleHandle::Chunk(id) => { + if let Some((handle, _slices)) = self.chunks.remove(id) { + self.storage.dealloc(handle.id); + } + } + SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"), + } + } + + fn storage(&mut self) -> &mut Storage { + &mut self.storage + } } impl SimpleMemoryManagement { - /// Creates a new instance using the given storage and deallocation strategy. - pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self { + /// Creates a new instance using the given storage, deallocation strategy and slice strategy. + pub fn new( + storage: Storage, + dealloc_strategy: DeallocStrategy, + slice_strategy: SliceStrategy, + ) -> Self { Self { chunks: HashMap::new(), slices: HashMap::new(), dealloc_strategy, + slice_strategy, storage, } } - /// Creates an new instance using the given storage without deallocation. - pub fn never_dealloc(storage: Storage) -> Self { - Self::new(storage, DeallocStrategy::Never) - } - fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { // Looks for a large enough, existing but unused chunk of memory. let chunk = self.find_free_chunk(size); @@ -169,19 +240,31 @@ impl SimpleMemoryManagement { let mut size_diff_current = usize::MAX; let mut current = None; - self.chunks - .iter() - .for_each(|(chunk_id, (resource, slices))| { - let is_free = slices.is_empty() && chunk_id.is_free(); + for (chunk_id, (resource, slices)) in self.chunks.iter() { + // If chunk is already used, we do not choose it + if !slices.is_empty() || !chunk_id.is_free() { + continue; + } - if is_free && resource.size() > size { - let size_diff = resource.size() - size; - if size_diff < size_diff_current { - current = Some((chunk_id, resource)); - size_diff_current = size_diff; - } + let resource_size = resource.size(); + + // If we find a chunk of exactly the right size, we stop searching altogether + if size == resource_size { + current = Some((chunk_id, resource)); + break; + } + + // Finds the smallest of the large enough chunks that can accept a slice + // of the given size + if self.slice_strategy.can_use_chunk(resource_size, size) { + let size_diff = resource_size - size; + + if size_diff < size_diff_current { + current = Some((chunk_id, resource)); + size_diff_current = size_diff; } - }); + } + } current.map(|(id, handle)| (id.clone(), handle.size())) } @@ -263,7 +346,7 @@ impl SimpleMemoryManagement { #[cfg(test)] mod tests { use crate::{ - memory_management::{MemoryHandle, MemoryManagement}, + memory_management::{MemoryHandle, MemoryManagement, SliceStrategy}, storage::BytesStorage, }; @@ -271,7 +354,11 @@ mod tests { #[test] fn can_mut_with_single_tensor_reference() { - let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); let chunk_size = 4; let simple_handle = memory_management.create_chunk(chunk_size); @@ -284,7 +371,11 @@ mod tests { #[test] fn two_tensor_references_remove_mutability() { - let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); let chunk_size = 4; let simple_handle = memory_management.create_chunk(chunk_size); @@ -297,7 +388,11 @@ mod tests { #[test] fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { - let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); let chunk_size = 4; let _chunk_handle = memory_management.reserve(chunk_size); let _new_handle = memory_management.reserve(chunk_size); @@ -307,7 +402,11 @@ mod tests { #[test] fn when_empty_chunk_is_cleaned_upexists_it_disappears() { - let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default()); + let mut memory_management = SimpleMemoryManagement::new( + BytesStorage::default(), + DeallocStrategy::Never, + SliceStrategy::Never, + ); let chunk_size = 4; let chunk_handle = memory_management.reserve(chunk_size); drop(chunk_handle); @@ -336,4 +435,28 @@ mod tests { assert!(period_tick_dealloc.should_dealloc()); } } + + #[test] + fn slice_strategy_minimum_bytes() { + let strategy = SliceStrategy::MinimumSize(100); + + assert!(strategy.can_use_chunk(200, 101)); + assert!(!strategy.can_use_chunk(200, 99)); + } + + #[test] + fn slice_strategy_maximum_bytes() { + let strategy = SliceStrategy::MaximumSize(100); + + assert!(strategy.can_use_chunk(200, 99)); + assert!(!strategy.can_use_chunk(200, 101)); + } + + #[test] + fn slice_strategy_ratio() { + let strategy = SliceStrategy::Ratio(0.9); + + assert!(strategy.can_use_chunk(200, 180)); + assert!(!strategy.can_use_chunk(200, 179)); + } } diff --git a/burn-compute/src/server.rs b/burn-compute/src/server.rs index 24fd5a59d..0a69e1513 100644 --- a/burn-compute/src/server.rs +++ b/burn-compute/src/server.rs @@ -1,18 +1,39 @@ +use crate::{ + memory_management::{MemoryHandle, MemoryManagement}, + storage::ComputeStorage, +}; use alloc::vec::Vec; -use crate::{memory_management::MemoryManagement, storage::ComputeStorage}; +/// Server handle containing the [memory handle](MemoryManagement::Handle). +#[derive(new, Debug)] +pub struct Handle { + /// Handle for the memory in use. + pub memory: >::Handle, +} -type _Storage = ::Storage; -type _MemoryManagement = ::MemoryManagement; +impl Handle { + /// If the tensor handle can be mut with an inplace operation. + pub fn can_mut(&self) -> bool { + self.memory.can_mut() + } +} -/// This alias for a [memory handle](MemoryManagement::Handle). -pub type Handle = <_MemoryManagement as MemoryManagement<_Storage>>::Handle; +impl Clone for Handle { + fn clone(&self) -> Self { + Self { + memory: self.memory.clone(), + } + } +} /// The compute server is responsible for handling resources and computations over resources. /// /// Everything in the server is mutable, therefore it should be solely accessed through the /// [compute channel](crate::channel::ComputeChannel) for thread safety. -pub trait ComputeServer: Send { +pub trait ComputeServer: Send + core::fmt::Debug +where + Self: Sized, +{ /// The kernel type defines the computation algorithms. type Kernel: Send; /// The [storage](ComputeStorage) type defines how data is stored and accessed. diff --git a/burn-compute/src/storage/bytes_cpu.rs b/burn-compute/src/storage/bytes_cpu.rs index a9b3f3753..6fcdfc1d6 100644 --- a/burn-compute/src/storage/bytes_cpu.rs +++ b/burn-compute/src/storage/bytes_cpu.rs @@ -8,6 +8,12 @@ pub struct BytesStorage { memory: HashMap, } +impl core::fmt::Debug for BytesStorage { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str("BytesStorage") + } +} + /// Can send to other threads, but can't sync. unsafe impl Send for BytesStorage {} unsafe impl Send for BytesResource {} diff --git a/burn-compute/tests/dummy/compute.rs b/burn-compute/tests/dummy/compute.rs index dbb4d7555..e83b09845 100644 --- a/burn-compute/tests/dummy/compute.rs +++ b/burn-compute/tests/dummy/compute.rs @@ -1,7 +1,7 @@ use super::DummyServer; use burn_compute::channel::MutexComputeChannel; use burn_compute::client::ComputeClient; -use burn_compute::memory_management::SimpleMemoryManagement; +use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}; use burn_compute::storage::BytesStorage; use burn_compute::Compute; @@ -17,7 +17,8 @@ pub fn client( ) -> ComputeClient> { COMPUTE.client(device, || { let storage = BytesStorage::default(); - let memory_management = SimpleMemoryManagement::never_dealloc(storage); + let memory_management = + SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); let server = DummyServer::new(memory_management); let channel = MutexComputeChannel::new(server); diff --git a/burn-compute/tests/dummy/kernel.rs b/burn-compute/tests/dummy/kernel.rs index b8212c6c1..5d6421b71 100644 --- a/burn-compute/tests/dummy/kernel.rs +++ b/burn-compute/tests/dummy/kernel.rs @@ -2,14 +2,14 @@ use burn_compute::storage::BytesResource; /// The DummyKernel trait should be implemented for every supported operation pub trait DummyKernel: Send { - fn compute<'a>(&self, resources: &mut [BytesResource]); + fn compute(&self, resources: &mut [BytesResource]); } /// Contains the algorithm for element-wise addition pub struct DummyElementwiseAddition; impl DummyKernel for DummyElementwiseAddition { - fn compute<'a>(&self, inputs: &mut [BytesResource]) { + fn compute(&self, inputs: &mut [BytesResource]) { // Notice how the kernel is responsible for determining which inputs // are read-only and which are writable. let lhs = &inputs[0].read(); diff --git a/burn-compute/tests/dummy/server.rs b/burn-compute/tests/dummy/server.rs index d7b0ea786..88a4d9fb3 100644 --- a/burn-compute/tests/dummy/server.rs +++ b/burn-compute/tests/dummy/server.rs @@ -9,7 +9,7 @@ use super::DummyKernel; /// The dummy server is used to test the burn-compute infrastructure. /// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks. -#[derive(new)] +#[derive(new, Debug)] pub struct DummyServer> { memory_management: MM, } @@ -23,7 +23,7 @@ where type MemoryManagement = MM; fn read(&mut self, handle: &Handle) -> Vec { - let bytes = self.memory_management.get(handle); + let bytes = self.memory_management.get(&handle.memory); bytes.read().to_vec() } @@ -38,17 +38,17 @@ where bytes[i] = *val; } - handle + Handle::new(handle) } fn empty(&mut self, size: usize) -> Handle { - self.memory_management.reserve(size) + Handle::new(self.memory_management.reserve(size)) } fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle]) { let mut resources = handles .iter() - .map(|handle| self.memory_management.get(handle)) + .map(|handle| self.memory_management.get(&handle.memory)) .collect::>(); kernel.compute(&mut resources); diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index b0c510f6c..28d263aff 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -23,6 +23,7 @@ std = [ "serde_json/std", "bincode/std", "half/std", + "derive-new/std", ] dataset = ["burn-dataset/default"] @@ -42,7 +43,6 @@ ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openb __ndarray = [] # Internal flag to know when one ndarray feature is enabled. wgpu = ["burn-wgpu"] -wgpu-autotune = ["wgpu", "burn-wgpu/autotune"] tch = ["burn-tch"] @@ -67,7 +67,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true } burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true } -derive-new = { workspace = true } +derive-new = { workspace = true, default-features = false } libm = { workspace = true } log = { workspace = true, optional = true } rand = { workspace = true, features = ["std_rng"] } # Default enables std diff --git a/burn-core/src/record/memory.rs b/burn-core/src/record/memory.rs index 708a22e43..545ad87e6 100644 --- a/burn-core/src/record/memory.rs +++ b/burn-core/src/record/memory.rs @@ -1,6 +1,5 @@ use super::{bin_config, PrecisionSettings, Recorder, RecorderError}; use alloc::vec::Vec; -use core::marker::PhantomData; use serde::{de::DeserializeOwned, Serialize}; /// Recorder trait specialized to save and load data to and from bytes. @@ -17,7 +16,7 @@ pub trait BytesRecorder: /// In memory recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinBytesRecorder { - _settings: PhantomData, + _settings: core::marker::PhantomData, } impl BytesRecorder for BinBytesRecorder {} @@ -45,7 +44,7 @@ impl Recorder for BinBytesRecorder { /// In memory recorder using the [Named MessagePack](rmp_serde). #[derive(new, Debug, Default, Clone)] pub struct NamedMpkBytesRecorder { - _settings: PhantomData, + _settings: core::marker::PhantomData, } #[cfg(feature = "std")] diff --git a/burn-core/src/record/tensor.rs b/burn-core/src/record/tensor.rs index 0238b149f..4cd8b31c5 100644 --- a/burn-core/src/record/tensor.rs +++ b/burn-core/src/record/tensor.rs @@ -1,6 +1,5 @@ use super::{PrecisionSettings, Record}; use burn_tensor::{backend::Backend, Bool, DataSerialize, Int, Tensor}; -use core::marker::PhantomData; use serde::{Deserialize, Serialize}; /// This struct implements serde to lazily serialize and deserialize a float tensor @@ -8,7 +7,7 @@ use serde::{Deserialize, Serialize}; #[derive(new, Clone, Debug)] pub struct FloatTensorSerde { tensor: Tensor, - elem: PhantomData, + elem: core::marker::PhantomData, } /// This struct implements serde to lazily serialize and deserialize an int tensor @@ -16,7 +15,7 @@ pub struct FloatTensorSerde { #[derive(new, Clone, Debug)] pub struct IntTensorSerde { tensor: Tensor, - elem: PhantomData, + elem: core::marker::PhantomData, } /// This struct implements serde to lazily serialize and deserialize an bool tensor. diff --git a/burn-tch/src/backend.rs b/burn-tch/src/backend.rs index 6dedf1989..10b8b9ac2 100644 --- a/burn-tch/src/backend.rs +++ b/burn-tch/src/backend.rs @@ -91,4 +91,10 @@ impl Backend for TchBackend { fn name() -> String { "tch".to_string() } + + fn sync(device: &Self::Device) { + if let TchDevice::Cuda(index) = device { + tch::Cuda::synchronize(*index as i64); + } + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 43e6037e9..f02e96770 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -94,6 +94,9 @@ pub trait Backend: /// Seed the backend. fn seed(seed: u64); + + /// Sync the backend, ensure that all computation are finished. + fn sync(_device: &Self::Device) {} } pub(crate) type ADBackendTensorPrimitive = diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml index aac01e056..6a8050041 100644 --- a/burn-wgpu/Cargo.toml +++ b/burn-wgpu/Cargo.toml @@ -11,10 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu" version = "0.10.0" [features] -default = ["async"] -async = [] -# Still experimental -autotune = [] +default = [] [dependencies] burn-common = { path = "../burn-common", version = "0.10.0" } @@ -35,6 +32,10 @@ wgpu = { workspace = true } serde = { workspace = true } text_placeholder = { version = "0.5.0", features = ["struct_context"] } +hashbrown = { workspace = true } +burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] } + + [dev-dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", default-features = false, features = [ "export_tests", @@ -45,10 +46,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" } serial_test = "2.0.0" -# Still only in dev mode -hashbrown = { workspace = true } -burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] } - [[bench]] name = "unary" harness = false diff --git a/burn-wgpu/benches/matmul.rs b/burn-wgpu/benches/matmul.rs index 4c7efa2ca..e5466a720 100644 --- a/burn-wgpu/benches/matmul.rs +++ b/burn-wgpu/benches/matmul.rs @@ -3,7 +3,7 @@ use burn_wgpu::{ benchmark::Benchmark, kernel::matmul::{ contiguous, contiguous_vectorized, matmul_mem_coalescing_default, matmul_naive_default, - tile, tile_vectorized, tune, + tile, tile_vectorized, }, run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice, }; @@ -50,8 +50,8 @@ where } fn prepare(&self, device: &WgpuDevice) -> Self::Args { - let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default).to_device(device); - let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default).to_device(device); + let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, device); + let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, device); (lhs, rhs) } @@ -88,22 +88,6 @@ benchmark!( contiguous_vectorized::matmul_tiling_2d_default ); -struct MatmulAutotune; - -impl MatmulFunction, D> - for MatmulAutotune -{ - fn run( - lhs: Tensor, D>, - rhs: Tensor, D>, - ) -> Tensor, D> { - Tensor::from_primitive(tune::( - lhs.into_primitive(), - rhs.into_primitive(), - )) - } -} - fn main() { let num_repeats = 3; let batch_size = 3; @@ -141,10 +125,4 @@ fn main() { num_repeats, matmul: PhantomData }); - run_benchmark!(MatmulBenchmark:: { - shape_lhs: [batch_size, m, k].into(), - shape_rhs: [batch_size, k, n].into(), - num_repeats, - matmul: PhantomData - }); } diff --git a/burn-wgpu/src/backend.rs b/burn-wgpu/src/backend.rs index 008b32b25..8295bf8c2 100644 --- a/burn-wgpu/src/backend.rs +++ b/burn-wgpu/src/backend.rs @@ -2,6 +2,7 @@ use burn_tensor::backend::Backend; use rand::{rngs::StdRng, SeedableRng}; use crate::{ + compute::compute_client, element::{FloatElement, IntElement}, tensor::WgpuTensor, GraphicsApi, WgpuDevice, @@ -43,4 +44,9 @@ impl Backend for WgpuB fn ad_enabled() -> bool { false } + + fn sync(device: &Self::Device) { + let client = compute_client::(device); + client.sync(); + } } diff --git a/burn-wgpu/src/benchmark.rs b/burn-wgpu/src/benchmark.rs index 64e90e187..daef0f4e2 100644 --- a/burn-wgpu/src/benchmark.rs +++ b/burn-wgpu/src/benchmark.rs @@ -1,4 +1,4 @@ -use crate::{pool::get_context, GraphicsApi, WgpuDevice}; +use crate::{compute::compute_client, GraphicsApi, WgpuDevice}; use std::{ fmt::Display, time::{Duration, Instant}, @@ -15,6 +15,7 @@ impl BenchmarkResult { self.durations.iter().sum::() / self.durations.len() as u32 } + #[allow(dead_code)] pub(crate) fn median_duration(&self) -> Duration { let mut sorted = self.durations.clone(); sorted.sort(); @@ -83,23 +84,23 @@ pub trait Benchmark { fn name(&self) -> String; /// Run the benchmark a number of times. fn run(&self, device: &WgpuDevice) -> BenchmarkResult { - let context = get_context::(device); + let client = compute_client::(device); // Warmup self.execute(self.prepare(device)); - context.sync(); + client.sync(); let mut durations = Vec::with_capacity(self.num_samples()); for _ in 0..self.num_samples() { // Prepare let args = self.prepare(device); - context.sync(); + client.sync(); // Execute the benchmark let start = Instant::now(); self.execute(args); - context.sync(); + client.sync(); let end = Instant::now(); // Register the duration diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs index 75f42ef27..2fc487377 100644 --- a/burn-wgpu/src/compute/base.rs +++ b/burn-wgpu/src/compute/base.rs @@ -1,26 +1,28 @@ -use super::{Kernel, WgpuServer}; -use crate::{ - compute::WgpuStorage, - context::{select_device, WorkGroup}, - kernel::{DynamicKernel, SourceTemplate, StaticKernel}, - GraphicsApi, WgpuDevice, -}; +use super::WgpuServer; +use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice}; +use alloc::sync::Arc; use burn_compute::{ channel::MutexComputeChannel, client::ComputeClient, - memory_management::{DeallocStrategy, SimpleMemoryManagement}, + memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, Compute, }; -use std::{marker::PhantomData, sync::Arc}; +use wgpu::{DeviceDescriptor, DeviceType}; -type WgpuChannel = MutexComputeChannel; +type MemoryManagement = SimpleMemoryManagement; +type Server = WgpuServer; +type Channel = MutexComputeChannel; + +/// Wgpu [compute client](ComputeClient) to communicate with the [compute server](WgpuServer). +pub type WgpuComputeClient = ComputeClient; +/// Wgpu [server handle](burn_compute::server::Handle). +pub type WgpuHandle = burn_compute::server::Handle; /// Compute handle for the wgpu backend. -static COMPUTE: Compute = Compute::new(); +static COMPUTE: Compute, Channel> = Compute::new(); -pub fn compute_client( - device: &WgpuDevice, -) -> ComputeClient { +/// Get the [compute client](ComputeClient) for the given [device](WgpuDevice). +pub fn compute_client(device: &WgpuDevice) -> ComputeClient { let device = Arc::new(device); COMPUTE.client(&device, move || { @@ -37,93 +39,159 @@ pub fn compute_client( Ok(value) => value .parse::() .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => 16, // 16 tasks by default + Err(_) => 64, // 64 tasks by default }; let device = Arc::new(device_wgpu); let storage = WgpuStorage::new(device.clone()); - // Maximum reusability. - let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never); + let memory_management = SimpleMemoryManagement::new( + storage, + DeallocStrategy::new_period_tick(1000), + SliceStrategy::Ratio(0.9), + ); let server = WgpuServer::new(memory_management, device, queue, max_tasks); - let channel = WgpuChannel::new(server); + let channel = Channel::new(server); ComputeClient::new(channel) }) } -pub struct DynamicComputeKernel { - kernel: K, - workgroup: WorkGroup, +/// Select the wgpu device and queue based on the provided [device](WgpuDevice). +pub async fn select_device( + device: &WgpuDevice, +) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { + let adapter = select_adapter::(device); + let limits = adapter.limits(); + + let (device, queue) = adapter + .request_device( + &DeviceDescriptor { + label: None, + features: wgpu::Features::empty(), + limits, + }, + None, + ) + .await + .map_err(|err| { + format!( + "Unable to request the device with the adapter {:?}, err {:?}", + adapter.get_info(), + err + ) + }) + .unwrap(); + + (device, queue, adapter.get_info()) } -impl Kernel for DynamicComputeKernel -where - K: DynamicKernel + 'static, -{ - fn source_template(self: Box) -> SourceTemplate { - self.kernel.source_template() +fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { + let instance = wgpu::Instance::default(); + + let mut adapters_other = Vec::new(); + let mut adapters = Vec::new(); + + instance + .enumerate_adapters(G::backend().into()) + .for_each(|adapter| { + let device_type = adapter.get_info().device_type; + + if let DeviceType::Other = device_type { + adapters_other.push(adapter); + return; + } + + let is_same_type = match device { + WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, + WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, + WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, + WgpuDevice::Cpu => device_type == DeviceType::Cpu, + WgpuDevice::BestAvailable => true, + }; + + if is_same_type { + adapters.push(adapter); + } + }); + + fn select( + num: usize, + error: &str, + mut adapters: Vec, + mut adapters_other: Vec, + ) -> wgpu::Adapter { + if adapters.len() <= num { + if adapters_other.len() <= num { + panic!( + "{}, adapters {:?}, other adapters {:?}", + error, + adapters + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + adapters_other + .into_iter() + .map(|adapter| adapter.get_info()) + .collect::>(), + ); + } else { + return adapters_other.remove(num); + } + } + + adapters.remove(num) } - fn id(&self) -> String { - self.kernel.id() - } + let adapter = match device { + WgpuDevice::DiscreteGpu(num) => select( + *num, + "No Discrete GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::IntegratedGpu(num) => select( + *num, + "No Integrated GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::VirtualGpu(num) => select( + *num, + "No Virtual GPU device found", + adapters, + adapters_other, + ), + WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), + WgpuDevice::BestAvailable => { + let mut most_performant_adapter = None; + let mut current_score = -1; - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } -} - -#[derive(new)] -pub struct StaticComputeKernel { - workgroup: WorkGroup, - _kernel: PhantomData, -} - -impl Kernel for StaticComputeKernel -where - K: StaticKernel + 'static, -{ - fn source_template(self: Box) -> SourceTemplate { - K::source_template() - } - - fn id(&self) -> String { - format!("{:?}", core::any::TypeId::of::()) - } - - fn workgroup(&self) -> WorkGroup { - self.workgroup.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi}; - - #[test] - fn can_run_kernel() { - binary_elemwise!(Add, "+"); - - let client = compute_client::(&WgpuDevice::default()); - - let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; - let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; - let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; - - let lhs = client.create(bytemuck::cast_slice(&lhs)); - let rhs = client.create(bytemuck::cast_slice(&rhs)); - let out = client.empty(core::mem::size_of::() * 8); - let info = client.create(bytemuck::cast_slice(&info)); - - type Kernel = KernelSettings; - let kernel = Box::new(StaticComputeKernel::::new(WorkGroup::new(1, 1, 1))); - - client.execute(kernel, &[&lhs, &rhs, &out, &info]); - - let data = client.read(&out); - let output: &[f32] = bytemuck::cast_slice(&data); - - assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); - } + adapters.into_iter().for_each(|adapter| { + let info = adapter.get_info(); + let score = match info.device_type { + DeviceType::DiscreteGpu => 5, + DeviceType::Other => 4, // Let's be optimistic with the Other device, it's + // often a Discrete Gpu. + DeviceType::IntegratedGpu => 3, + DeviceType::VirtualGpu => 2, + DeviceType::Cpu => 1, + }; + + if score > current_score { + most_performant_adapter = Some(adapter); + current_score = score; + } + }); + + if let Some(adapter) = most_performant_adapter { + adapter + } else { + panic!("No adapter found for graphics API {:?}", G::default()); + } + } + }; + + log::info!("Using adapter {:?}", adapter.get_info()); + + adapter } diff --git a/burn-wgpu/src/compute/kernel.rs b/burn-wgpu/src/compute/kernel.rs new file mode 100644 index 000000000..7f01f189b --- /dev/null +++ b/burn-wgpu/src/compute/kernel.rs @@ -0,0 +1,106 @@ +use super::Kernel; +use crate::kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}; +use core::marker::PhantomData; + +/// Provides launch information specifying the number of work groups to be used by a compute shader. +#[derive(new, Clone, Debug)] +pub struct WorkGroup { + /// Work groups for the x axis. + pub x: u32, + /// Work groups for the y axis. + pub y: u32, + /// Work groups for the z axis. + pub z: u32, +} + +impl WorkGroup { + /// Calculate the number of invocations of a compute shader. + pub fn num_invocations(&self) -> usize { + (self.x * self.y * self.z) as usize + } +} + +/// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch +/// information such as [workgroup](WorkGroup). +#[derive(new)] +pub struct DynamicKernel { + kernel: K, + workgroup: WorkGroup, +} + +/// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch +/// information such as [workgroup](WorkGroup). +#[derive(new)] +pub struct StaticKernel { + workgroup: WorkGroup, + _kernel: PhantomData, +} + +impl Kernel for DynamicKernel +where + K: DynamicKernelSource + 'static, +{ + fn source(self: Box) -> SourceTemplate { + self.kernel.source() + } + + fn id(&self) -> String { + self.kernel.id() + } + + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } +} + +impl Kernel for StaticKernel +where + K: StaticKernelSource + 'static, +{ + fn source(self: Box) -> SourceTemplate { + K::source() + } + + fn id(&self) -> String { + format!("{:?}", core::any::TypeId::of::()) + } + + fn workgroup(&self) -> WorkGroup { + self.workgroup.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi, + WgpuDevice, + }; + + #[test] + fn can_run_kernel() { + binary_elemwise!(Add, "+"); + + let client = compute_client::(&WgpuDevice::default()); + + let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; + let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; + let info: Vec = vec![1, 1, 1, 1, 8, 8, 8]; + + let lhs = client.create(bytemuck::cast_slice(&lhs)); + let rhs = client.create(bytemuck::cast_slice(&rhs)); + let out = client.empty(core::mem::size_of::() * 8); + let info = client.create(bytemuck::cast_slice(&info)); + + type Kernel = KernelSettings; + let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); + + client.execute(kernel, &[&lhs, &rhs, &out, &info]); + + let data = client.read(&out); + let output: &[f32] = bytemuck::cast_slice(&data); + + assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); + } +} diff --git a/burn-wgpu/src/compute/mod.rs b/burn-wgpu/src/compute/mod.rs index 3b14e2686..757695e20 100644 --- a/burn-wgpu/src/compute/mod.rs +++ b/burn-wgpu/src/compute/mod.rs @@ -1,7 +1,9 @@ mod base; +mod kernel; mod server; mod storage; pub use base::*; +pub use kernel::*; pub use server::*; pub use storage::*; diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs index 01893c255..ebc2b7b5f 100644 --- a/burn-wgpu/src/compute/server.rs +++ b/burn-wgpu/src/compute/server.rs @@ -1,9 +1,8 @@ -use std::{borrow::Cow, sync::Arc}; - -use super::WgpuStorage; -use crate::{context::WorkGroup, kernel::SourceTemplate}; +use super::{WgpuStorage, WorkGroup}; +use crate::kernel::SourceTemplate; +use alloc::{borrow::Cow, sync::Arc}; use burn_compute::{ - memory_management::{MemoryManagement, SimpleMemoryManagement}, + memory_management::MemoryManagement, server::{self, ComputeServer}, }; use hashbrown::HashMap; @@ -13,7 +12,8 @@ use wgpu::{ }; /// Wgpu compute server. -pub struct WgpuServer> { +#[derive(Debug)] +pub struct WgpuServer> { memory_management: MM, device: Arc, queue: wgpu::Queue, @@ -21,20 +21,27 @@ pub struct WgpuServer> { pipelines: HashMap>, tasks: Vec, max_tasks: usize, + manual_available: HashMap>>, + manual_taken: Vec<(usize, server::Handle)>, } -#[derive(new)] +#[derive(new, Debug)] struct ComputeTask { pipeline: Arc, bind_group: BindGroup, work_group: WorkGroup, } +/// Kernel trait with the [source](SourceTemplate) that will be compiled and cached based on the +/// provided id. +/// +/// The kernel will be launched with the given [workgroup](WorkGroup). pub trait Kernel: 'static + Send { /// Source template for the kernel. - fn source_template(self: Box) -> SourceTemplate; + fn source(self: Box) -> SourceTemplate; /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> String; + /// Launch information. fn workgroup(&self) -> WorkGroup; } @@ -42,6 +49,7 @@ impl WgpuServer where MM: MemoryManagement, { + /// Create a new server. pub fn new( memory_management: MM, device: Arc, @@ -60,6 +68,8 @@ where pipelines: HashMap::new(), tasks: Vec::new(), max_tasks, + manual_available: HashMap::new(), + manual_taken: Vec::new(), } } @@ -68,13 +78,54 @@ where self.tasks.is_empty(), "Tasks should be completed before submitting the current encoder." ); - println!("Submit"); let mut new_encoder = self .device .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); core::mem::swap(&mut new_encoder, &mut self.encoder); self.queue.submit(Some(new_encoder.finish())); + + // Cleanup allocations and deallocations. + self.free_manual_allocations(); + self.memory_management.storage().perform_deallocations(); + } + + fn free_manual_allocations(&mut self) { + let mut manual_taken_tmp = Vec::new(); + core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken); + + for (size, handle) in manual_taken_tmp.drain(..) { + if handle.can_mut() { + self.register_manual(size, handle); + } else { + self.manual_taken.push((size, handle)); + } + } + } + + // Finds a free, manually-added handle of specified size, or creates it if none is found + fn manual_reserve(&mut self, size: usize) -> server::Handle { + let handle = self + .manual_available + .get_mut(&size) + .and_then(|h| h.pop()) + .unwrap_or_else(|| { + let memory = self.memory_management.alloc(size); + server::Handle::new(memory) + }); + + self.manual_taken.push((size, handle.clone())); + + handle + } + + // Manually adds a handle of given size + fn register_manual(&mut self, size: usize, handle: server::Handle) { + if let Some(handles) = self.manual_available.get_mut(&size) { + handles.push(handle); + } else { + self.manual_available.insert(size, [handle].into()); + } } fn register_tasks(&mut self) { @@ -102,7 +153,7 @@ where return pipeline.clone(); } - let pipeline = self.compile_source(&kernel.source_template().complete()); + let pipeline = self.compile_source(&kernel.source().complete()); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); pipeline @@ -138,7 +189,7 @@ where // Register previous tasks before reading the buffer so that it is up to date. self.register_tasks(); - let resource = self.memory_management.get(handle); + let resource = self.memory_management.get(&handle.memory); let size = resource.size(); let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { @@ -182,8 +233,13 @@ where } } + /// When we create a new handle from existing data, we use custom allocations so that we don't + /// have to execute the current pending tasks. + /// + /// This is important, otherwise the compute passes are going to be too small and we won't be able to + /// fully utilize the GPU. fn create(&mut self, data: &[u8]) -> server::Handle { - let handle = self.empty(data.len()); + let handle = self.manual_reserve(data.len()); let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { label: Some("Buffer Src"), @@ -191,9 +247,7 @@ where usage: wgpu::BufferUsages::COPY_SRC, })); - let resource = self.memory_management.get(&handle); - - self.register_tasks(); + let resource = self.memory_management.get(&handle.memory); self.encoder.copy_buffer_to_buffer( &buffer_src, @@ -207,7 +261,7 @@ where } fn empty(&mut self, size: usize) -> server::Handle { - self.memory_management.reserve(size) + server::Handle::new(self.memory_management.reserve(size)) } fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { @@ -217,7 +271,7 @@ where let handles = handles .iter() - .map(|handle| self.memory_management.get(handle)) + .map(|handle| self.memory_management.get(&handle.memory)) .collect::>(); let entries = handles @@ -249,5 +303,7 @@ where self.register_tasks(); self.submit(); } + + self.device.poll(wgpu::Maintain::Wait); } } diff --git a/burn-wgpu/src/compute/storage.rs b/burn-wgpu/src/compute/storage.rs index 5a8f09669..ef74a927a 100644 --- a/burn-wgpu/src/compute/storage.rs +++ b/burn-wgpu/src/compute/storage.rs @@ -2,14 +2,25 @@ use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUti use hashbrown::HashMap; use std::{num::NonZeroU64, sync::Arc}; +/// Buffer storage for wgpu. pub struct WgpuStorage { memory: HashMap>, + deallocations: Vec, device: Arc, } +impl core::fmt::Debug for WgpuStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) + } +} + +/// The memory resource that can be allocated for wgpu. #[derive(new, Debug)] pub struct WgpuResource { + /// The wgpu buffer. pub buffer: Arc, + /// How the resource is used. pub kind: WgpuResourceKind, } @@ -44,6 +55,7 @@ impl WgpuResource { } } +/// How the resource is used, either as a slice or fully. #[derive(Debug)] pub enum WgpuResourceKind { /// Represents an entire buffer. @@ -54,12 +66,23 @@ pub enum WgpuResourceKind { /// Keeps actual wgpu buffer references in a hashmap with ids as key. impl WgpuStorage { + /// Create a new storage on the given [device](wgpu::Device). pub fn new(device: Arc) -> Self { Self { memory: HashMap::new(), + deallocations: Vec::new(), device, } } + + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() + } + } + } } impl ComputeStorage for WgpuStorage { @@ -96,7 +119,6 @@ impl ComputeStorage for WgpuStorage { } fn dealloc(&mut self, id: StorageId) { - self.memory.get(&id).unwrap().destroy(); - let _ = self.memory.remove(&id); + self.deallocations.push(id); } } diff --git a/burn-wgpu/src/context/base.rs b/burn-wgpu/src/context/base.rs deleted file mode 100644 index ee207dbf3..000000000 --- a/burn-wgpu/src/context/base.rs +++ /dev/null @@ -1,410 +0,0 @@ -use super::client::ContextClient; -use crate::{ - context::server::ContextServer, - kernel::{DynamicKernel, StaticKernel}, - tune::Tuner, - GraphicsApi, WgpuDevice, -}; -use burn_common::id::IdGenerator; -use spin::Mutex; -use std::{ - any::TypeId, - borrow::Cow, - collections::HashMap, - sync::atomic::{AtomicBool, Ordering}, - sync::Arc, -}; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt}, - Buffer, ComputePipeline, DeviceDescriptor, DeviceType, ShaderModuleDescriptor, -}; - -#[cfg(feature = "async")] -pub(crate) type ContextClientImpl = super::client::AsyncContextClient; -#[cfg(not(feature = "async"))] -pub(crate) type ContextClientImpl = super::client::SyncContextClient; - -#[cfg(feature = "async")] -pub(crate) type ContextServerImpl = super::server::AsyncContextServer; -#[cfg(not(feature = "async"))] -pub(crate) type ContextServerImpl = super::server::SyncContextServer; - -/// The context is the basic struct that allows to execute GPU kernel on devices. -/// -/// You can access a context for a WGPUDevice using get_context. -#[derive(Debug)] -pub struct Context { - id: String, - device_wgpu: Arc, - cache: Mutex>>, - is_tuning: AtomicBool, - client: ContextClientImpl, - pub(crate) tuner: Tuner, - tuning_template_ids: Mutex>, - pub(crate) device: WgpuDevice, - pub(crate) info: wgpu::AdapterInfo, -} - -#[derive(Debug, Hash, Clone, PartialOrd, PartialEq, Eq)] -enum TemplateKey { - Static(TypeId), - Dynamic(String), -} - -/// Provides launch information specifying the number of work groups to be used by a compute shader. -#[derive(new, Clone, Debug)] -pub struct WorkGroup { - /// Work groups for the x axis. - pub x: u32, - /// Work groups for the y axis. - pub y: u32, - /// Work groups for the z axis. - pub z: u32, -} - -impl WorkGroup { - /// Calculate the number of invocations of a compute shader. - pub fn num_invocations(&self) -> usize { - (self.x * self.y * self.z) as usize - } -} - -impl Context { - /// Create a new context where computing tasks will be executed on the given - /// [device](WgpuDevice). - pub(crate) fn new(device: &WgpuDevice) -> Self { - let (device_wgpu, queue, info) = pollster::block_on(select_device::(device)); - let device = device.clone(); - let device_wgpu = Arc::new(device_wgpu); - let client = ContextServerImpl::start(device_wgpu.clone(), queue); - - Self { - id: IdGenerator::generate(), - device_wgpu, - device, - client, - cache: Mutex::new(HashMap::new()), - is_tuning: AtomicBool::new(false), - tuner: Tuner::new(), - tuning_template_ids: Mutex::new(Vec::new()), - info, - } - } - - /// Wait for all computation to be executed. - /// - /// Useful for benchmarks. - pub fn sync(&self) { - self.client.sync(); - } - - /// Execute a kernel using the provided buffers. - /// - /// # Notes - /// - /// This function isn't safe, buffer can be mutated by the GPU. The users must ensure that a - /// buffer can be mutated when launching a compute shaders with write access to a buffer. - /// - /// Buffer positions are used as bindings when launching a compute kernel. - pub fn execute( - &self, - work_group: WorkGroup, - pipeline: Arc, - buffers: &[&Buffer], - ) { - let group_layout = pipeline.get_bind_group_layout(0); - - let entries = buffers - .iter() - .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { - binding: i as u32, - resource: buffer.as_entire_binding(), - }) - .collect::>(); - - let bind_group = self - .device_wgpu - .create_bind_group(&wgpu::BindGroupDescriptor { - label: None, - layout: &group_layout, - entries: &entries, - }); - - self.client - .register_compute(bind_group, pipeline, work_group) - } - - /// Create a new buffer with the provided size. - pub fn create_buffer(&self, size: usize) -> Arc { - Arc::new(self.device_wgpu.create_buffer(&wgpu::BufferDescriptor { - label: None, - size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC, - mapped_at_creation: false, - })) - } - - /// Create a new buffer initialized with the provided bytes. - pub fn create_buffer_with_data(&self, data: &[u8]) -> Arc { - self.create_buffer_with_data_options(data, false) - } - - /// Create a new buffer initialized with the provided bytes with the option to be sync. - /// - /// It's important to be sync when you want to reuse the buffer using the Arc strong count for - /// inner mutability. - pub fn create_buffer_with_data_options(&self, data: &[u8], sync: bool) -> Arc { - let buffer_src = Arc::new(self.device_wgpu.create_buffer_init(&BufferInitDescriptor { - label: Some("Buffer Src"), - contents: data, - usage: wgpu::BufferUsages::COPY_SRC, - })); - - let buffer_dest = self.create_buffer(buffer_src.size() as usize); - - self.client.copy_buffer(buffer_src, buffer_dest, sync) - } - - /// Copy buffer to buffer. - /// - /// Wait for registered may be useful if you want to allow inplace operations on the created - /// buffer. Otherwise, the strong count of the buffer might not be 1 when registering a new - /// operation, which makes the buffer readonly. - pub fn copy_buffer(&self, buffer_src: Arc, wait_for_registered: bool) -> Arc { - let buffer_dest = self.create_buffer(buffer_src.size() as usize); - - self.client - .copy_buffer(buffer_src, buffer_dest, wait_for_registered) - } - - /// Read a buffer from the GPU and return its content as bytes. - pub fn read_buffer(&self, buffer: Arc) -> Vec { - self.client.read_buffer(buffer) - } - - /// Compile a kernel template if not present in the cache. - pub fn compile_static(&self) -> Arc { - let mut cache = self.cache.lock(); - let template_id = TemplateKey::Static(TypeId::of::()); - - if let Some(module) = cache.get(&template_id) { - return module.clone(); - } - - let source = K::source_template(); - let pipeline = self.compile_source(&source.complete()); - - if self.is_tuning.load(Ordering::Relaxed) { - let mut templates_vec = self.tuning_template_ids.lock(); - templates_vec.push(template_id.clone()); - } - - cache.insert(template_id, pipeline.clone()); - pipeline - } - - /// Compile a dynamic template if not present in the cache. - pub fn compile_dynamic(&self, kernel: K) -> Arc { - let mut cache = self.cache.lock(); - let template_id = TemplateKey::Dynamic(kernel.id()); - - if let Some(module) = cache.get(&template_id) { - return module.clone(); - } - - let source = kernel.source_template(); - let pipeline = self.compile_source(&source.complete()); - - if self.is_tuning.load(Ordering::Relaxed) { - let mut templates_vec = self.tuning_template_ids.lock(); - templates_vec.push(template_id.clone()); - } - - cache.insert(template_id, pipeline.clone()); - pipeline - } - - fn compile_source(&self, source: &str) -> Arc { - let module = self - .device_wgpu - .create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); - let pipeline = self - .device_wgpu - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - }); - - Arc::new(pipeline) - } - - pub(crate) fn start_tuning(&self) { - self.is_tuning.store(true, Ordering::Relaxed); - } - - pub(crate) fn stop_tuning(&self) { - self.is_tuning.store(false, Ordering::Relaxed); - - // clean cache of pipelines accumulated during tuning - let mut cache = self.cache.lock(); - let mut tuning_template_ids = self.tuning_template_ids.lock(); - for template_id in tuning_template_ids.iter() { - cache.remove(template_id); - } - - tuning_template_ids.clear(); - } -} - -impl PartialEq for Context { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} - -pub(crate) async fn select_device( - device: &WgpuDevice, -) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { - let adapter = select_adapter::(device); - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - features: wgpu::Features::empty(), - limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); - - (device, queue, adapter.get_info()) -} - -fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); - - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); - - instance - .enumerate_adapters(G::backend().into()) - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - }; - - if is_same_type { - adapters.push(adapter); - } - }); - - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } else { - return adapters_other.remove(num); - } - } - - adapters.remove(num) - } - - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters.into_iter().for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. - DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } - } - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - - adapter -} diff --git a/burn-wgpu/src/context/client.rs b/burn-wgpu/src/context/client.rs deleted file mode 100644 index ee3af7e2f..000000000 --- a/burn-wgpu/src/context/client.rs +++ /dev/null @@ -1,190 +0,0 @@ -use super::WorkGroup; -use std::sync::Arc; -use wgpu::{BindGroup, Buffer, ComputePipeline}; - -#[cfg(feature = "async")] -pub use async_client::AsyncContextClient; -#[cfg(not(feature = "async"))] -pub use sync_client::SyncContextClient; - -/// Context client allows to speak with a server to execute tasks on the GPU. -pub trait ContextClient { - /// Copy the source buffer content into the destination buffer. - /// - /// # Notes - /// - /// Make sure the source buffer isn't used afterward, since a race condition may happen. - /// - /// If the source buffer is still used afterward, use [tensor copy](crate::tensor::WgpuTensor::copy) - /// instead. This method is still useful to load data from the CPU into a new buffer. - fn copy_buffer( - &self, - buffer_src: Arc, - buffer_dest: Arc, - wait_for_registered: bool, - ) -> Arc; - /// Read a [buffer](Buffer). - /// - /// # Notes - /// - /// All pending compute tasks will be executed. - fn read_buffer(&self, buffer: Arc) -> Vec; - /// Register a new computing task. - fn register_compute( - &self, - bind_group: BindGroup, - pipeline: Arc, - work_group: WorkGroup, - ); - /// Wait for all computation to be done. - /// - /// Useful for benchmarks. - fn sync(&self); -} - -#[cfg(feature = "async")] -mod async_client { - use super::ContextClient; - use crate::context::{ - server::{ComputeTask, ContextTask, CopyBufferTask, ReadBufferTask}, - WorkGroup, - }; - use std::sync::{mpsc, Arc}; - use wgpu::{BindGroup, Buffer, ComputePipeline}; - - /// Client returned by - #[derive(new, Debug)] - pub struct AsyncContextClient { - sender: mpsc::SyncSender, - _server_handle: std::thread::JoinHandle<()>, - } - - impl ContextClient for AsyncContextClient { - fn sync(&self) { - let (sender, receiver) = std::sync::mpsc::channel(); - - self.sender.send(ContextTask::Sync(sender)).unwrap(); - - if receiver.iter().next().is_some() { - log::debug!("Sync completed"); - } else { - panic!("Unable sync") - } - } - - fn copy_buffer( - &self, - buffer_src: Arc, - buffer_dest: Arc, - wait_for_registered: bool, - ) -> Arc { - if wait_for_registered { - assert_eq!(Arc::strong_count(&buffer_dest), 1, "You can't wait for the buffer to be registered when multiple references already exist."); - } - - self.sender - .send(CopyBufferTask::new(buffer_src, buffer_dest.clone()).into()) - .unwrap(); - - if !wait_for_registered { - return buffer_dest; - } - - // Wait for the buffer to be correctly registered so that inplace operations can be - // prioritize. - // - // Note that this is unsafe and a channel could have been used to wait for completion. - // The loop is there for performance reason. - // - // TODO: Use a performant one time channel here as callback instead. - loop { - std::thread::sleep(std::time::Duration::from_micros(1)); - - if Arc::strong_count(&buffer_dest) == 1 { - return buffer_dest; - } - } - } - - fn read_buffer(&self, buffer: Arc) -> Vec { - let (sender, receiver) = std::sync::mpsc::channel(); - - self.sender - .send(ReadBufferTask::new(buffer, sender).into()) - .unwrap(); - - let mut iter = receiver.iter(); - if let Some(data) = iter.next() { - data - } else { - panic!("Unable to read buffer") - } - } - fn register_compute( - &self, - bind_group: BindGroup, - pipeline: Arc, - work_group: WorkGroup, - ) { - self.sender - .send(ComputeTask::new(bind_group, pipeline, work_group).into()) - .unwrap(); - } - } -} - -#[cfg(not(feature = "async"))] -mod sync_client { - use super::ContextClient; - use crate::context::{ - server::{ComputeTask, SyncContextServer}, - WorkGroup, - }; - use std::sync::Arc; - use wgpu::{BindGroup, Buffer, ComputePipeline}; - - #[derive(Debug)] - pub struct SyncContextClient { - server: spin::Mutex, - } - - impl SyncContextClient { - pub fn new(server: SyncContextServer) -> Self { - Self { - server: spin::Mutex::new(server), - } - } - } - - impl ContextClient for SyncContextClient { - fn sync(&self) { - let mut server = self.server.lock(); - server.sync(); - } - fn copy_buffer( - &self, - buffer_src: Arc, - buffer_dest: Arc, - _wait_for_registered: bool, // Ignored when sync - ) -> Arc { - let mut server = self.server.lock(); - server.buffer_to_buffer(buffer_src, buffer_dest.clone()); - - buffer_dest - } - fn read_buffer(&self, buffer: Arc) -> Vec { - let mut server = self.server.lock(); - server.read_buffer(&buffer) - } - - fn register_compute( - &self, - bind_group: BindGroup, - pipeline: Arc, - work_group: WorkGroup, - ) { - let mut server = self.server.lock(); - server.register_compute(ComputeTask::new(bind_group, pipeline, work_group)); - } - } -} diff --git a/burn-wgpu/src/context/mod.rs b/burn-wgpu/src/context/mod.rs deleted file mode 100644 index 8ad1f1544..000000000 --- a/burn-wgpu/src/context/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub(super) mod client; -pub(super) mod server; - -mod base; - -pub use base::*; diff --git a/burn-wgpu/src/context/server.rs b/burn-wgpu/src/context/server.rs deleted file mode 100644 index 81104098f..000000000 --- a/burn-wgpu/src/context/server.rs +++ /dev/null @@ -1,269 +0,0 @@ -use super::{client::ContextClient, WorkGroup}; -use std::sync::Arc; -use wgpu::{BindGroup, Buffer, CommandEncoder, ComputePipeline}; - -#[cfg(feature = "async")] -pub use async_server::{AsyncContextServer, ContextTask, CopyBufferTask, ReadBufferTask}; - -/// Context server allow to run tasks on the GPU. -/// -/// # Notes -/// -/// There are two implementations of this trait. One is a bit more performant while the other -/// doesn't require std. -/// -/// * [Asynchronous server](AsyncContextServer). -/// * [Synchronous server](SyncContextServer). -pub trait ContextServer { - /// The client where task can be sent to the server for execution. - type Client: ContextClient; - - /// Start the server and returns its [client](ContextClient). - fn start(device: Arc, queue: wgpu::Queue) -> Self::Client; -} - -/// Context server where each operation is added in a synchronous maner. -#[derive(Debug)] -pub struct SyncContextServer { - device: Arc, - queue: wgpu::Queue, - encoder: CommandEncoder, - tasks: Vec, - max_tasks: usize, -} - -/// Basic building block to execute computing tasks on the GPU. -#[derive(new, Debug)] -pub struct ComputeTask { - bind_group: BindGroup, - pipeline: Arc, - work_group: WorkGroup, -} - -/// Most of the functions are similar to [server client](IContextClient). -/// -/// The main difference comes from the functions are mutable instead of immutable requirering a -/// lock by a sync client or using an async channel with the async client/server. -impl SyncContextServer { - /// Create a new sync context server. - pub fn new(device: Arc, queue: wgpu::Queue) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Command Encoder"), - }); - - // TODO: Support a way to modify this value without std. - let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { - Ok(value) => value - .parse::() - .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => 16, // 16 tasks by default - }; - - Self { - device, - queue, - encoder, - tasks: Vec::new(), - max_tasks, - } - } - - pub fn register_compute(&mut self, task: ComputeTask) { - self.tasks.push(task); - - if self.tasks.len() > self.max_tasks { - self.register_tasks(); - self.submit(); - } - } - - pub fn read_buffer(&mut self, buffer: &Buffer) -> Vec { - // Register previous tasks before reading the buffer so that it is up to date. - self.register_tasks(); - - let size = buffer.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - self.encoder - .copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size); - - self.submit(); - - let buffer_slice = buffer_dest.slice(..); - let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - self.device.poll(wgpu::Maintain::Wait); - - let result = pollster::block_on(receiver.receive()); - - if let Some(Ok(())) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - buffer_dest.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } - } - - pub fn sync(&mut self) { - if !self.tasks.is_empty() { - self.register_tasks(); - self.submit(); - } - - self.device.poll(wgpu::Maintain::Wait); - } - - pub fn buffer_to_buffer(&mut self, buffer_src: Arc, buffer_dest: Arc) { - self.encoder - .copy_buffer_to_buffer(&buffer_src, 0, &buffer_dest, 0, buffer_src.size()); - } - - fn register_tasks(&mut self) { - let mut compute = self - .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); - for task in self.tasks.iter() { - compute.set_pipeline(&task.pipeline); - compute.set_bind_group(0, &task.bind_group, &[]); - compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); - } - std::mem::drop(compute); - self.tasks.clear(); - } - - fn submit(&mut self) { - assert!( - self.tasks.is_empty(), - "Tasks should be completed before submitting the current encoder." - ); - let mut new_encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, &mut self.encoder); - - self.queue.submit(Some(new_encoder.finish())); - } -} - -#[cfg(feature = "async")] -mod async_server { - use crate::context::client::AsyncContextClient; - - use super::{ComputeTask, ContextServer, SyncContextServer}; - use std::sync::{mpsc, Arc}; - use wgpu::Buffer; - - #[derive(new)] - pub struct ReadBufferTask { - buffer: Arc, - sender: mpsc::Sender>, - } - - #[derive(new)] - pub struct CopyBufferTask { - pub(crate) buffer_src: Arc, - pub(crate) buffer_dest: Arc, - } - - pub enum ContextTask { - Compute(ComputeTask), - ReadBuffer(ReadBufferTask), - CopyBuffer(CopyBufferTask), - Sync(mpsc::Sender<()>), - } - - impl From for ContextTask { - fn from(val: ComputeTask) -> Self { - ContextTask::Compute(val) - } - } - - impl From for ContextTask { - fn from(val: ReadBufferTask) -> Self { - ContextTask::ReadBuffer(val) - } - } - - impl From for ContextTask { - fn from(val: CopyBufferTask) -> Self { - ContextTask::CopyBuffer(val) - } - } - - /// Asynchronous context server where [tasks](ContextTask) are sent using a channel. - /// - /// # Notes - /// - /// This is pretty useful to avoid blocking the main thread when registering and - /// executing [compute tasks](ComputeTask). - pub struct AsyncContextServer { - server: SyncContextServer, - receiver: mpsc::Receiver, - } - - impl AsyncContextServer { - fn run(mut self) { - loop { - let task = self.receiver.recv().unwrap(); - match task { - ContextTask::Compute(task) => self.server.register_compute(task), - ContextTask::CopyBuffer(task) => self - .server - .buffer_to_buffer(task.buffer_src, task.buffer_dest), - ContextTask::ReadBuffer(task) => { - let bytes = self.server.read_buffer(&task.buffer); - task.sender.send(bytes).unwrap(); - } - ContextTask::Sync(callback) => { - self.server.sync(); - callback.send(()).unwrap(); - } - }; - } - } - } - impl ContextServer for AsyncContextServer { - type Client = AsyncContextClient; - - fn start(device: Arc, queue: wgpu::Queue) -> Self::Client { - let (sender, receiver) = std::sync::mpsc::sync_channel(50); - let server = SyncContextServer::new(device, queue); - let context = Self { server, receiver }; - - let handle = std::thread::spawn(|| context.run()); - - AsyncContextClient::new(sender, handle) - } - } -} - -#[cfg(not(feature = "async"))] -mod sync_server { - use super::{ContextServer, SyncContextServer}; - use crate::context::client::SyncContextClient; - use std::sync::Arc; - - impl ContextServer for SyncContextServer { - type Client = SyncContextClient; - - fn start(device: Arc, queue: wgpu::Queue) -> Self::Client { - let server = Self::new(device, queue); - - SyncContextClient::new(server) - } - } -} diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs index bf68b623d..5b5b6fc6f 100644 --- a/burn-wgpu/src/kernel/base.rs +++ b/burn-wgpu/src/kernel/base.rs @@ -1,17 +1,21 @@ use super::SourceTemplate; -use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor}; +use crate::{ + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + tensor::WgpuTensor, +}; use std::marker::PhantomData; /// Static wgpu kernel to create a [source template](SourceTemplate). -pub trait StaticKernel: Send + 'static { +pub trait StaticKernelSource: Send + 'static { /// Source template for the kernel. - fn source_template() -> SourceTemplate; + fn source() -> SourceTemplate; } /// Dynamic wgpu kernel to create a [source template](SourceTemplate). -pub trait DynamicKernel: Send { +pub trait DynamicKernelSource: Send { /// Source template for the kernel. - fn source_template(self) -> SourceTemplate; + fn source(self) -> SourceTemplate; /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> String; } @@ -27,8 +31,8 @@ macro_rules! kernel_wgsl { #[derive(new)] pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { $crate::kernel::SourceTemplate::new(include_str!($file)) } } @@ -48,23 +52,21 @@ pub fn into_contiguous( const WORKGROUP: usize = 32; let num_elems = tensor.shape.num_elements(); - let buffer = tensor - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer); + let handle = tensor.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + handle, + ); let info = build_info(&[&tensor, &output]); - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor - .context - .compile_static::>(); - - tensor.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&tensor.buffer, &output.buffer, &info_buffer], + tensor.client.execute( + Box::new(StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP))), + &[&tensor.handle, &output.handle, &info_handle], ); output @@ -72,7 +74,7 @@ pub fn into_contiguous( /// Generates kernel source code by replacing some information using templating. pub struct KernelSettings< - K: StaticKernel, + K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const WORKGROUP_X_SIZE: usize, @@ -85,17 +87,17 @@ pub struct KernelSettings< } impl< - K: StaticKernel, + K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const WORKGROUP_X_SIZE: usize, const WORKGROUP_Y_SIZE: usize, const WORKGROUP_Z_SIZE: usize, - > StaticKernel + > StaticKernelSource for KernelSettings { - fn source_template() -> SourceTemplate { - K::source_template() + fn source() -> SourceTemplate { + K::source() .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string()) .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string()) .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string()) @@ -110,7 +112,7 @@ impl< /// Generate kernel source code by replacing some information using templating. #[derive(new)] -pub struct DynamicKernelSettings { +pub struct DynamicKernelSettings { workgroup_x_size: usize, workgroup_y_size: usize, workgroup_z_size: usize, @@ -119,11 +121,11 @@ pub struct DynamicKernelSettings, } -impl DynamicKernel +impl DynamicKernelSource for DynamicKernelSettings { - fn source_template(self) -> SourceTemplate { - K::source_template() + fn source(self) -> SourceTemplate { + K::source() .register("workgroup_size_x", self.workgroup_x_size.to_string()) .register("workgroup_size_y", self.workgroup_y_size.to_string()) .register("workgroup_size_z", self.workgroup_z_size.to_string()) diff --git a/burn-wgpu/src/kernel/binary_elemwise.rs b/burn-wgpu/src/kernel/binary_elemwise.rs index 0f181e1a9..43c8879b8 100644 --- a/burn-wgpu/src/kernel/binary_elemwise.rs +++ b/burn-wgpu/src/kernel/binary_elemwise.rs @@ -1,4 +1,5 @@ -use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernel}; +use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource}; +use crate::compute::StaticKernel; use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; use burn_tensor::Shape; @@ -17,9 +18,9 @@ macro_rules! binary_elemwise { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseRaw::source_template().register( + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseRaw::source().register( "body", format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops), ) @@ -37,9 +38,9 @@ macro_rules! binary_elemwise_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::BinaryElemwiseInplaceRaw::source_template().register( + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::BinaryElemwiseInplaceRaw::source().register( "body", format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops), ) @@ -49,7 +50,7 @@ macro_rules! binary_elemwise_inplace { } /// Execute a binary kernel using the default settings. -pub fn binary_elemwise_default( +pub fn binary_elemwise_default( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -57,7 +58,12 @@ pub fn binary_elemwise_default( } /// Execute a binary kernel using the provided WORKGROUP. -pub fn binary_elemwise( +pub fn binary_elemwise< + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, +>( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -76,30 +82,26 @@ pub fn binary_elemwise()); - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle); - let kernel = lhs - .context - .compile_static::>(); let info = build_info(&[&lhs, &rhs, &output]); - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - lhs.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + ); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); output } /// Execute a binary inplace kernel using the default settings. -pub fn binary_elemwise_inplace_default( +pub fn binary_elemwise_inplace_default( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -108,7 +110,7 @@ pub fn binary_elemwise_inplace_default WgpuTensor { lhs.assert_is_on_same_device(&rhs); - let kernel = lhs - .context - .compile_static::>(); let info = build_info(&[&lhs, &rhs]); - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - lhs.context.execute( + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::>::new( elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), - kernel, - &[&lhs.buffer, &rhs.buffer, &info_buffers], ); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + lhs } diff --git a/burn-wgpu/src/kernel/cast.rs b/burn-wgpu/src/kernel/cast.rs index bdb0c3723..80b2b44c6 100644 --- a/burn-wgpu/src/kernel/cast.rs +++ b/burn-wgpu/src/kernel/cast.rs @@ -1,8 +1,10 @@ -use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor}; +use super::{KernelSettings, SourceTemplate, StaticKernelSource}; +use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, + tensor::WgpuTensor, +}; use std::{any::TypeId, marker::PhantomData}; -use super::{KernelSettings, SourceTemplate, StaticKernel}; - kernel_wgsl!(CastRaw, "../template/cast.wgsl"); struct Cast { @@ -10,9 +12,11 @@ struct Cast { _o: PhantomData, } -impl StaticKernel for Cast { - fn source_template() -> SourceTemplate { - CastRaw::source_template() +impl StaticKernelSource + for Cast +{ + fn source() -> SourceTemplate { + CastRaw::source() .register("input_elem", InputElem::type_name()) .register("output_elem", OutputElem::type_name()) } @@ -23,32 +27,30 @@ pub fn cast( tensor: WgpuTensor, ) -> WgpuTensor { if TypeId::of::() == TypeId::of::() { - return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer); + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); } const WORKGROUP: usize = 32; let num_elems = tensor.shape.num_elements(); - let kernel = tensor.context.compile_static::, - f32, - i32, - WORKGROUP, - WORKGROUP, - 1, - >>(); + let kernel = StaticKernel::< + KernelSettings, f32, i32, WORKGROUP, WORKGROUP, 1>, + >::new(elemwise_workgroup(num_elems, WORKGROUP)); - let buffer = tensor - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer); - - tensor.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&tensor.buffer, &output.buffer], + let handle = tensor + .client + .empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device, + tensor.shape.clone(), + handle, ); + tensor + .client + .execute(Box::new(kernel), &[&tensor.handle, &output.handle]); + output } diff --git a/burn-wgpu/src/kernel/cat.rs b/burn-wgpu/src/kernel/cat.rs index 5f0fc6bc5..bea0ab18a 100644 --- a/burn-wgpu/src/kernel/cat.rs +++ b/burn-wgpu/src/kernel/cat.rs @@ -1,4 +1,5 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, @@ -14,16 +15,20 @@ pub fn cat( const WORKGROUP: usize = 32; let first_input = inputs.get(0).unwrap(); - let context = &first_input.context; + let client = &first_input.client; let mut shape_output = first_input.shape.clone(); shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum(); let buffer = first_input - .context - .create_buffer(shape_output.num_elements() * std::mem::size_of::()); + .client + .empty(shape_output.num_elements() * std::mem::size_of::()); - let output = WgpuTensor::new(context.clone(), shape_output, buffer); - let kernel = context.compile_static::>(); + let output = WgpuTensor::new( + client.clone(), + first_input.device.clone(), + shape_output, + buffer, + ); let mut dim_cat_index = 0; @@ -32,12 +37,14 @@ pub fn cat( info.push(dim as u32); info.push(dim_cat_index as u32); dim_cat_index += input.shape.dims[dim]; - let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info)); - - context.execute( + let info_buffer = client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::>::new( elemwise_workgroup(input.shape.num_elements(), WORKGROUP), - kernel.clone(), - &[&input.buffer, &output.buffer, &info_buffer], + ); + + client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_buffer], ); } diff --git a/burn-wgpu/src/kernel/comparison/binary.rs b/burn-wgpu/src/kernel/comparison/binary.rs index 0c529bdcb..33ec0c5e7 100644 --- a/burn-wgpu/src/kernel/comparison/binary.rs +++ b/burn-wgpu/src/kernel/comparison/binary.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, - kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernel}, + kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; use burn_tensor::Shape; @@ -21,9 +23,9 @@ macro_rules! comparison { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonRaw::source_template().register( + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonRaw::source().register( "body", format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops), ) @@ -41,9 +43,9 @@ macro_rules! comparison_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonInplaceRaw::source() .register( "body", "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);", @@ -59,7 +61,7 @@ macro_rules! comparison_inplace { }; } -pub fn comparison( +pub fn comparison( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -79,29 +81,23 @@ pub fn comparison( let shape_out = Shape::new(shape_out); let num_elems = shape_out.num_elements(); - let buffer = lhs - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); + let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); - let kernel = lhs - .context - .compile_static::>(); - let info = build_info(&[&lhs, &rhs, &output]); - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - lhs.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + ); + let info = build_info(&[&lhs, &rhs, &output]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); - WgpuTensor::new(output.context, output.shape, output.buffer) + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } -pub fn comparison_inplace( +pub fn comparison_inplace( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -109,21 +105,16 @@ pub fn comparison_inplace( lhs.assert_is_on_same_device(&rhs); - let kernel = lhs - .context - .compile_static::>(); - let info = build_info(&[&lhs, &rhs]); - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - lhs.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), - kernel, - &[&lhs.buffer, &rhs.buffer, &info_buffers], ); + let info = build_info(&[&lhs, &rhs]); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); - WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer) + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]); + + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] diff --git a/burn-wgpu/src/kernel/comparison/elem.rs b/burn-wgpu/src/kernel/comparison/elem.rs index 6f99e7009..51504e9b7 100644 --- a/burn-wgpu/src/kernel/comparison/elem.rs +++ b/burn-wgpu/src/kernel/comparison/elem.rs @@ -1,6 +1,7 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, - kernel::{elemwise_workgroup, KernelSettings, StaticKernel}, + kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource}, kernel_wgsl, tensor::WgpuTensor, }; @@ -20,9 +21,9 @@ macro_rules! comparison_elem { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemRaw::source() .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops)) } } @@ -38,9 +39,9 @@ macro_rules! comparison_elem_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::ComparisonElemInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::ComparisonElemInplaceRaw::source() .register("body", "lhs[id] = compare(lhs[id], rhs);") .add_template(format!( "{}return {{{{ elem }}}}(lhs {} rhs);{}", @@ -53,48 +54,39 @@ macro_rules! comparison_elem_inplace { }; } -pub fn comparison_elem( +pub fn comparison_elem( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { const WORKGROUP: usize = 32; let num_elems = lhs.shape.num_elements(); - let buffer = lhs - .context - .create_buffer(num_elems * core::mem::size_of::()); - let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs])); - let kernel = lhs - .context - .compile_static::>(); - - lhs.context.execute( + let handle = lhs.client.empty(num_elems * core::mem::size_of::()); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&lhs.buffer, &rhs_buffer, &buffer], ); - WgpuTensor::new(lhs.context, lhs.shape, buffer) + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]); + + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle) } -pub fn comparison_elem_inplace( +pub fn comparison_elem_inplace( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { const WORKGROUP: usize = 32; - let kernel = lhs - .context - .compile_static::>(); - - let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs])); - lhs.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP), - kernel, - &[&lhs.buffer, &rhs_buffer], ); + let rhs_handle = lhs.client.create(E::as_bytes(&[rhs])); + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer) + WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle) } #[cfg(test)] diff --git a/burn-wgpu/src/kernel/conv/conv2d.rs b/burn-wgpu/src/kernel/conv/conv2d.rs index f872a53af..de04d8743 100644 --- a/burn-wgpu/src/kernel/conv/conv2d.rs +++ b/burn-wgpu/src/kernel/conv/conv2d.rs @@ -1,17 +1,19 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{self, build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, - Shape, + Element, ElementConversion, Shape, }; kernel_wgsl!(Conv2d, "../../template/conv/conv2d.wgsl"); -pub(crate) fn conv2d( +pub(crate) fn conv2d( input: WgpuTensor, weight: WgpuTensor, bias: Option>, @@ -40,12 +42,12 @@ pub(crate) fn conv2d( ); let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); - let num_elems = shape_out.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(input.context.clone(), shape_out, buffer); + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + ); let mut info = build_info(&[&input, &output, &weight]); info.push(options.stride[0] as u32); @@ -56,27 +58,24 @@ pub(crate) fn conv2d( info.push(options.dilation[1] as u32); info.push(options.groups as u32); - let bias_buffer = bias - .map(|bias| bias.buffer) - .unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::())); + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - let info_buffer = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - let kernel = input - .context - .compile_static::>(); - - input.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, + ); + + input.client.execute( + Box::new(kernel), &[ - &input.buffer, - &weight.buffer, - &bias_buffer, - &output.buffer, - &info_buffer, + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, ], ); diff --git a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs index 0efeeccfe..0416458e3 100644 --- a/burn-wgpu/src/kernel/conv/conv_transpose2d.rs +++ b/burn-wgpu/src/kernel/conv/conv_transpose2d.rs @@ -1,14 +1,16 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{self, build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; -use burn_tensor::{ops::ConvTransposeOptions, Shape}; +use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape}; kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl"); -pub(crate) fn conv_transpose2d( +pub(crate) fn conv_transpose2d( input: WgpuTensor, weight: WgpuTensor, bias: Option>, @@ -35,12 +37,13 @@ pub(crate) fn conv_transpose2d( let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); let num_elems = shape_out.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(input.context.clone(), shape_out, buffer); - + let output = empty_device( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + ); let mut info = build_info(&[&input, &output, &weight]); + info.push(options.stride[0] as u32); info.push(options.stride[1] as u32); info.push(options.padding[0] as u32); @@ -49,28 +52,24 @@ pub(crate) fn conv_transpose2d( info.push(options.dilation[1] as u32); info.push(options.groups as u32); - let bias_buffer = bias - .map(|bias| bias.buffer) - .unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::())); + let bias_handle = bias + .map(|bias| bias.handle) + .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()]))); - let info_buffer = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - let kernel = input - .context - .compile_static::>(); - - let workgroup = elemwise_workgroup(num_elems, WORKGROUP); - input.context.execute( - workgroup, - kernel, + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + input.client.execute( + Box::new(kernel), &[ - &input.buffer, - &weight.buffer, - &bias_buffer, - &output.buffer, - &info_buffer, + &input.handle, + &weight.handle, + &bias_handle, + &output.handle, + &info_handle, ], ); diff --git a/burn-wgpu/src/kernel/index/gather.rs b/burn-wgpu/src/kernel/index/gather.rs index d623dad6f..cd8280d6e 100644 --- a/burn-wgpu/src/kernel/index/gather.rs +++ b/burn-wgpu/src/kernel/index/gather.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{self, build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -17,29 +19,23 @@ pub(crate) fn gather( let shape_output = indices.shape.clone(); let num_elems = shape_output.num_elements(); let indices = kernel::into_contiguous(indices); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); - let buffer = tensor - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer); let mut info = build_info(&[&tensor, &output]); info.push(dim as u32); - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor - .context - .compile_static::>(); - - tensor.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, + ); + + tensor.client.execute( + Box::new(kernel), &[ - &tensor.buffer, - &indices.buffer, - &output.buffer, - &info_buffer, + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, ], ); diff --git a/burn-wgpu/src/kernel/index/scatter.rs b/burn-wgpu/src/kernel/index/scatter.rs index cc60566e5..beb355087 100644 --- a/burn-wgpu/src/kernel/index/scatter.rs +++ b/burn-wgpu/src/kernel/index/scatter.rs @@ -1,4 +1,5 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{self, build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, @@ -48,18 +49,15 @@ pub(crate) fn scatter( info.push(dim as u32); - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor - .context - .compile_static::>(); - - tensor.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems_per_workgroup, WORKGROUP), - kernel, - &[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer], + ); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], ); tensor diff --git a/burn-wgpu/src/kernel/index/select.rs b/burn-wgpu/src/kernel/index/select.rs index 4d374d108..57c1467f6 100644 --- a/burn-wgpu/src/kernel/index/select.rs +++ b/burn-wgpu/src/kernel/index/select.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -20,32 +22,25 @@ pub(crate) fn select( let mut output_shape = tensor.shape.clone(); output_shape.dims[dim] = indices.shape.dims[0]; - let num_elems = output_shape.num_elements(); - let buffer = tensor - .context - .create_buffer(num_elems * std::mem::size_of::()); - let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer); + let num_elems = output_shape.num_elements(); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape); let mut info = build_info(&[&tensor, &output]); info.push(dim as u32); - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - let kernel = tensor - .context - .compile_static::>(); - - tensor.context.execute( + let info_handle = output.client.create(bytemuck::cast_slice(&info)); + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, + ); + + tensor.client.execute( + Box::new(kernel), &[ - &tensor.buffer, - &indices.buffer, - &output.buffer, - &info_buffer, + &tensor.handle, + &indices.handle, + &output.handle, + &info_handle, ], ); @@ -89,18 +84,16 @@ pub(crate) fn select_assign( info.push(dim as u32); - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor - .context - .compile_static::>(); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems_per_workgroup, WORKGROUP), + ); - tensor.context.execute( - elemwise_workgroup(num_elems_per_workgroup, WORKGROUP), - kernel, - &[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer], + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &indices.handle, &value.handle, &info_handle], ); tensor diff --git a/burn-wgpu/src/kernel/index/slice.rs b/burn-wgpu/src/kernel/index/slice.rs index 3b25cc4d9..db92b94a7 100644 --- a/burn-wgpu/src/kernel/index/slice.rs +++ b/burn-wgpu/src/kernel/index/slice.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; use burn_tensor::Shape; @@ -26,10 +28,7 @@ pub(crate) fn slice( let shape_output = Shape::new(dims); let num_elems = shape_output.num_elements(); - let buffer = tensor - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer); + let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output); let mut info = build_info(&[&tensor, &output]); for i in 0..D1 { @@ -37,18 +36,15 @@ pub(crate) fn slice( info.push(start as u32); } - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = output.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor - .context - .compile_static::>(); - - tensor.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&tensor.buffer, &output.buffer, &info_buffer], + ); + + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &output.handle, &info_handle], ); output @@ -73,23 +69,15 @@ pub(crate) fn slice_assign( info.push(start as u32); } - let info_buffer = tensor - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = tensor.client.create(bytemuck::cast_slice(&info)); - let kernel = tensor.context.compile_static::>(); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(num_elems, WORKGROUP)); - tensor.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&tensor.buffer, &value.buffer, &info_buffer], + tensor.client.execute( + Box::new(kernel), + &[&tensor.handle, &value.handle, &info_handle], ); tensor diff --git a/burn-wgpu/src/kernel/mask/mask_fill.rs b/burn-wgpu/src/kernel/mask/mask_fill.rs index 646ccb0d3..b6e57a409 100644 --- a/burn-wgpu/src/kernel/mask/mask_fill.rs +++ b/burn-wgpu/src/kernel/mask/mask_fill.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -16,30 +18,28 @@ pub fn mask_fill( const WORKGROUP: usize = 32; let num_elems = input.shape.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); - let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value])); - let kernel = input - .context - .compile_static::>(); - let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer); - let info = build_info(&[&input, &mask, &output]); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - input.context.execute( + let value_handle = output.client.create(E::as_bytes(&[value])); + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, + ); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), &[ - &input.buffer, - &value_buffer, - &mask.buffer, - &output.buffer, - &info_buffers, + &input.handle, + &value_handle, + &mask.handle, + &output.handle, + &info_handle, ], ); @@ -54,20 +54,18 @@ pub fn mask_fill_inplace( const WORKGROUP: usize = 32; let num_elems = input.shape.num_elements(); - let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value])); - let kernel = input - .context - .compile_static::>(); - let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer); + let value_handle = input.client.create(E::as_bytes(&[value])); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); let info = build_info(&[&input, &mask]); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&input.buffer, &value_buffer, &mask.buffer, &info_buffers], + input.client.execute( + Box::new(kernel), + &[&input.handle, &value_handle, &mask.handle, &info_handle], ); input diff --git a/burn-wgpu/src/kernel/mask/mask_where.rs b/burn-wgpu/src/kernel/mask/mask_where.rs index 9d60733fa..78e721080 100644 --- a/burn-wgpu/src/kernel/mask/mask_where.rs +++ b/burn-wgpu/src/kernel/mask/mask_where.rs @@ -1,7 +1,9 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{build_info, elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -16,29 +18,27 @@ pub fn mask_where( const WORKGROUP: usize = 32; let num_elems = input.shape.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer); + let output = empty_device( + input.client.clone(), + input.device.clone(), + input.shape.clone(), + ); - let kernel = input - .context - .compile_static::>(); - let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer); - let info = build_info(&[&input, &value, &mask, &output]); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - input.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, + ); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); + let info = build_info(&[&input, &value, &mask, &output]); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), &[ - &input.buffer, - &value.buffer, - &mask.buffer, - &output.buffer, - &info_buffers, + &input.handle, + &value.handle, + &mask.handle, + &output.handle, + &info_handle, ], ); @@ -53,23 +53,21 @@ pub fn mask_where_inplace( ) -> WgpuTensor { const WORKGROUP: usize = 32; - let kernel = input - .context - .compile_static::>(); - let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(input.shape.num_elements(), WORKGROUP), + ); + let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle); let mut info = build_info(&[&input, &value, &mask]); info.push(match reverse { true => 1, false => 0, }); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.context.execute( - elemwise_workgroup(input.shape.num_elements(), WORKGROUP), - kernel, - &[&input.buffer, &value.buffer, &mask.buffer, &info_buffers], + input.client.execute( + Box::new(kernel), + &[&input.handle, &value.handle, &mask.handle, &info_handle], ); input diff --git a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs index dbe4dcb5c..e849d0811 100644 --- a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs +++ b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs @@ -2,10 +2,13 @@ use std::marker::PhantomData; use super::utils::shape_out; use crate::{ - context::WorkGroup, + compute::{DynamicKernel, WorkGroup}, element::WgpuElement, - kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel}, + kernel::{ + build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, + }, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -21,9 +24,9 @@ struct MatmulMemCoalescing { _elem: PhantomData, } -impl DynamicKernel for MatmulMemCoalescing { - fn source_template(self) -> SourceTemplate { - MatmulMemCoalescingRaw::source_template() +impl DynamicKernelSource for MatmulMemCoalescing { + fn source(self) -> SourceTemplate { + MatmulMemCoalescingRaw::source() .register("workgroup_size_x", self.workgroup_size_x.to_string()) .register("workgroup_size_y", self.workgroup_size_y.to_string()) .register("elem", E::type_name()) @@ -59,26 +62,11 @@ pub fn matmul_mem_coalescing( let num_rows = lhs.shape.dims[D - 2]; let num_cols = rhs.shape.dims[D - 1]; - let buffer = lhs - .context - .create_buffer(shape_out.num_elements() * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); + let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); // set number of workgroups let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32; - - let kernel = lhs.context.compile_dynamic(MatmulMemCoalescing::::new( - workgroup_size_x, - workgroup_size_y, - )); - - let info = build_info(&[&lhs, &rhs, &output]); - - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - let mut num_iter = 1; for i in 0..D - 2 { num_iter *= output.shape.dims[i]; @@ -86,10 +74,18 @@ pub fn matmul_mem_coalescing( let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - lhs.context.execute( + let kernel = DynamicKernel::new( + MatmulMemCoalescing::::new(workgroup_size_x, workgroup_size_y), workgroup, - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + ); + + let info = build_info(&[&lhs, &rhs, &output]); + + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); output diff --git a/burn-wgpu/src/kernel/matmul/mod.rs b/burn-wgpu/src/kernel/matmul/mod.rs index 17084730c..9dfcecc3a 100644 --- a/burn-wgpu/src/kernel/matmul/mod.rs +++ b/burn-wgpu/src/kernel/matmul/mod.rs @@ -3,9 +3,7 @@ pub(crate) mod utils; mod mem_coalescing; mod naive; mod tiling2d; -mod tune; pub use mem_coalescing::*; pub use naive::*; pub use tiling2d::*; -pub use tune::*; diff --git a/burn-wgpu/src/kernel/matmul/naive.rs b/burn-wgpu/src/kernel/matmul/naive.rs index 28243ed12..60c36f18a 100644 --- a/burn-wgpu/src/kernel/matmul/naive.rs +++ b/burn-wgpu/src/kernel/matmul/naive.rs @@ -1,9 +1,10 @@ use super::utils::shape_out; use crate::{ - context::WorkGroup, + compute::{StaticKernel, WorkGroup}, element::WgpuElement, - kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernel}, + kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -11,11 +12,11 @@ kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); struct MatmulNaive; -impl StaticKernel +impl StaticKernelSource for MatmulNaive { - fn source_template() -> SourceTemplate { - MatmulNaiveRaw::source_template() + fn source() -> SourceTemplate { + MatmulNaiveRaw::source() .register("block_size_m", WORKGROUP_SIZE_X.to_string()) .register("block_size_n", WORKGROUP_SIZE_Y.to_string()) } @@ -49,41 +50,35 @@ pub fn matmul_naive< let num_rows = lhs.shape.dims[D - 2]; let num_cols = rhs.shape.dims[D - 1]; - let buffer = lhs - .context - .create_buffer(shape_out.num_elements() * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); + let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out); // set number of workgroups let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32; let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32; - - let kernel = lhs.context.compile_static::, - E, - i32, - WORKGROUP_SIZE_X, - WORKGROUP_SIZE_Y, - 1, - >>(); - - let info = build_info(&[&lhs, &rhs, &output]); - - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - let mut num_iter = 1; for i in 0..D - 2 { num_iter *= output.shape.dims[i]; } - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - lhs.context.execute( - workgroup, - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + let kernel = StaticKernel::< + KernelSettings< + MatmulNaive, + E, + i32, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + 1, + >, + >::new(workgroup); + + let info = build_info(&[&lhs, &rhs, &output]); + + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); output diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs index 7e185277c..8b62a6191 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs @@ -1,29 +1,16 @@ +use super::padding::{crop, pad_round, PaddingOutput}; use crate::{ - context::{Context, WorkGroup}, + compute::{DynamicKernel, WgpuHandle, WorkGroup}, element::WgpuElement, - kernel::{build_info, into_contiguous, matmul::utils::shape_out}, + kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource}, + ops::numeric::empty_device, tensor::WgpuTensor, }; -use burn_tensor::Shape; -use std::{ - cmp::{max, min}, - sync::Arc, -}; -use wgpu::ComputePipeline; - -use super::padding::{crop, pad_round, PaddingOutput}; +use burn_tensor::{Element, Shape}; +use std::cmp::{max, min}; const MAX_SHARED_MEMORY_SIZE: usize = 8192; -pub(super) fn empty_from_context( - context: Arc, - shape: &Shape, -) -> WgpuTensor { - let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); - - WgpuTensor::new(context, shape.clone(), buffer) -} - /// Create a source template for tile 2d matmul. #[macro_export(local_inner_macros)] macro_rules! matmul_tile_2d { @@ -63,11 +50,11 @@ macro_rules! matmul_tile_2d { _elem: core::marker::PhantomData, } - impl DynamicKernel for $struct { - fn source_template(self) -> SourceTemplate { + impl DynamicKernelSource for $struct { + fn source(self) -> SourceTemplate { kernel_wgsl!(Raw, $file); - Raw::source_template() + Raw::source() .register("b_m", self.b_m.to_string()) .register("b_n", self.b_n.to_string()) .register("b_k", self.b_k.to_string()) @@ -90,7 +77,7 @@ macro_rules! matmul_tile_2d { } /// Matrix multiplication using tiling 2D algorithm with default parameters - pub fn matmul_tiling_2d_default( + pub fn matmul_tiling_2d_default( lhs: WgpuTensor, rhs: WgpuTensor, ) -> WgpuTensor { @@ -125,7 +112,7 @@ macro_rules! matmul_tile_2d { /// Matrix multiplication using tiling 2D algorithm with custom parameters #[allow(clippy::too_many_arguments)] pub fn matmul_tiling_2d< - E: WgpuElement, + E: WgpuElement + burn_tensor::Element, const D: usize, >( lhs: WgpuTensor, @@ -140,14 +127,13 @@ macro_rules! matmul_tile_2d { ) -> WgpuTensor { let kernel = $struct::::new(b_m, b_n, b_k, t_m, t_n, workgroup_size_x, workgroup_size_y); - let kernel = lhs.context.compile_dynamic(kernel); matmul_tiling_2d_launch::< E, D, + $struct::, >( lhs, rhs, - kernel, b_m, b_n, b_k, @@ -155,6 +141,7 @@ macro_rules! matmul_tile_2d { t_n, workgroup_size_x, workgroup_size_y, + kernel, ) } @@ -412,21 +399,23 @@ pub(super) fn make_workgroup( WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) } -pub(super) fn make_info_buffers( +pub(super) fn make_info_handle( lhs: &WgpuTensor, rhs: &WgpuTensor, output: &WgpuTensor, -) -> Arc { +) -> WgpuHandle { let info = build_info(&[lhs, rhs, output]); - rhs.context - .create_buffer_with_data(bytemuck::cast_slice(&info)) + rhs.client.create(bytemuck::cast_slice(&info)) } #[allow(clippy::too_many_arguments)] -pub(super) fn matmul_tiling_2d_launch( +pub(super) fn matmul_tiling_2d_launch< + E: WgpuElement + Element, + const D: usize, + K: DynamicKernelSource + 'static, +>( lhs: WgpuTensor, rhs: WgpuTensor, - kernel: Arc, b_m: usize, b_n: usize, b_k: usize, @@ -434,6 +423,7 @@ pub(super) fn matmul_tiling_2d_launch( t_n: usize, workgroup_size_x: usize, workgroup_size_y: usize, + kernel: K, ) -> WgpuTensor { matmul_parameter_assertions::( b_m, @@ -470,15 +460,18 @@ pub(super) fn matmul_tiling_2d_launch( let rounded_output_shape = shape_out(&lhs, &rhs); - let output = empty_from_context::(rhs.context.clone(), &rounded_output_shape); + let output = empty_device( + rhs.client.clone(), + rhs.device.clone(), + rounded_output_shape.clone(), + ); let workgroup = make_workgroup(rounded_output_shape, b_m, b_n); - let info_buffers = make_info_buffers(&lhs, &rhs, &output); + let info_handle = make_info_handle(&lhs, &rhs, &output); - lhs.context.execute( - workgroup, - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + output.client.execute( + Box::new(DynamicKernel::new(kernel, workgroup)), + &[&lhs.handle, &rhs.handle, &output.handle, &info_handle], ); crop(output, final_output_shape) diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/contiguous.rs b/burn-wgpu/src/kernel/matmul/tiling2d/contiguous.rs index c3d37b383..32f5debff 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/contiguous.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/contiguous.rs @@ -1,7 +1,7 @@ use super::base::matmul_tiling_2d_launch; use crate::{ element::WgpuElement, - kernel::{DynamicKernel, SourceTemplate, StaticKernel}, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, matmul_tile_2d, tensor::WgpuTensor, }; diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/contiguous_vectorized.rs b/burn-wgpu/src/kernel/matmul/tiling2d/contiguous_vectorized.rs index 68c6edcaa..162f1452b 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/contiguous_vectorized.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/contiguous_vectorized.rs @@ -1,7 +1,7 @@ use super::base::matmul_tiling_2d_launch; use crate::{ element::WgpuElement, - kernel::{DynamicKernel, SourceTemplate, StaticKernel}, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, matmul_tile_2d, tensor::WgpuTensor, }; diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs index 4da1b1c2e..30fc68b83 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs @@ -1,15 +1,14 @@ use std::ops::Range; -use burn_tensor::Shape; +use burn_tensor::{Element, Shape}; use crate::{ element::WgpuElement, kernel::{slice, slice_assign}, + ops::numeric::zeros_device, tensor::WgpuTensor, }; -use super::base::empty_from_context; - // Output of the pad_round function. Allows to know explicitly if early return occurred pub(super) enum PaddingOutput { Padded(WgpuTensor), @@ -29,7 +28,7 @@ impl PaddingOutput { /// divisible by some quantity. /// For instance tensor of shape [1000, 1000] with divisors 64 and 64 /// will be padded to [1024, 1024] with the last 24 elements being zeros -pub(super) fn pad_round( +pub(super) fn pad_round( tensor: WgpuTensor, row_divisor: usize, col_divisor: usize, @@ -62,7 +61,7 @@ pub(super) fn pad_round( } /// Pads tensor by adding zeros when padded dim is larger than tensor dim -fn padding( +fn padding( tensor: WgpuTensor, padded_shape: Shape, ) -> WgpuTensor { @@ -73,8 +72,9 @@ fn padding( .collect::>>() .try_into() .unwrap(); + slice_assign::( - empty_from_context(tensor.context.clone(), &padded_shape), + zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape), ranges, tensor, ) diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs b/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs index 42d201b2c..2a34613de 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs @@ -1,6 +1,6 @@ use crate::{ element::WgpuElement, - kernel::{DynamicKernel, SourceTemplate, StaticKernel}, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, matmul_tile_2d, tensor::WgpuTensor, }; diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs b/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs index 54da1df70..b04406414 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs @@ -1,7 +1,7 @@ use super::base::matmul_tiling_2d_launch; use crate::{ element::WgpuElement, - kernel::{DynamicKernel, SourceTemplate, StaticKernel}, + kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource}, matmul_tile_2d, tensor::WgpuTensor, }; diff --git a/burn-wgpu/src/kernel/matmul/tune.rs b/burn-wgpu/src/kernel/matmul/tune.rs deleted file mode 100644 index 88fa55f3e..000000000 --- a/burn-wgpu/src/kernel/matmul/tune.rs +++ /dev/null @@ -1,396 +0,0 @@ -use burn_tensor::{Distribution, Shape, Tensor}; -use mem_coalescing::matmul_mem_coalescing; - -use crate::{ - benchmark::Benchmark, - element::{FloatElement, WgpuElement}, - kernel, - tensor::{WgpuTensor, WgpuTensorDyn}, - tune::{AutoTuneFunction, AutoTuneKey, Execution, KernelFunction, Tunable}, - GraphicsApi, WgpuBackend, WgpuDevice, -}; -use std::{marker::PhantomData, sync::Arc}; - -use super::mem_coalescing; - -const TILING_2D_BLOCK_SIZES: [usize; 2] = [64, 128]; -const TILING_2D_BLOCK_SIZES_K: [usize; 2] = [16, 32]; -const TILING_2D_TILE_SIZES: [usize; 2] = [4, 16]; -const MEMORY_COALESCING_WORKGROUP_SIZES: [usize; 3] = [8, 16, 32]; - -macro_rules! call_dim { - ($func:expr, $dim:expr, $( $x:expr ),*) => { - match $dim { - 1 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - 2 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - 3 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - 4 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - 5 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - 6 => { - let tensor: WgpuTensor = $func($($x,)*); - tensor.into() - }, - _ => panic!("Tensors of rank 7 and more can't be autotuned."), - } - }; -} - -macro_rules! tiling2d_tunable { - ($name:ident, $func:expr) => { - #[derive(new, Default)] - struct $name { - b_m: usize, - b_n: usize, - b_k: usize, - t_m: usize, - t_n: usize, - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, - } - - impl KernelFunction for $name { - type Input = (WgpuTensorDyn, WgpuTensorDyn); - type Output = WgpuTensorDyn; - - fn call(&self, (lhs, rhs): Self::Input) -> Self::Output { - - #[allow(clippy::too_many_arguments)] - fn call_dyn( - lhs: WgpuTensorDyn, - rhs: WgpuTensorDyn, - b_m: usize, - b_n: usize, - b_k: usize, - t_m: usize, - t_n: usize, - workgroup_size_x: usize, - workgroup_size_y: usize, - ) -> WgpuTensor { - $func( - WgpuTensor::::from(lhs), - WgpuTensor::::from(rhs), - b_m, - b_n, - b_k, - t_m, - t_n, - workgroup_size_x, - workgroup_size_y, - ) - } - - return call_dim!( - call_dyn, - lhs.shape.len(), - lhs, - rhs, - self.b_m, - self.b_n, - self.b_k, - self.t_m, - self.t_n, - self.workgroup_size_x, - self.workgroup_size_y - ); - } - - fn description(&self) -> String { - format!( - "Tiling 2D matmul ({}) - B_M {}, B_N {}, B_K {}, T_M {}, T_N {}, W_X {}, W_X {}", - stringify!($name), - self.b_m, - self.b_n, - self.b_k, - self.t_m, - self.t_n, - self.workgroup_size_x, - self.workgroup_size_y - ) - } - } - }; -} - -tiling2d_tunable!( - Tiling2DContiguousLoad, - kernel::matmul::contiguous::matmul_tiling_2d -); - -tiling2d_tunable!(Tiling2DTileLoad, kernel::matmul::tile::matmul_tiling_2d); -tiling2d_tunable!( - Tiling2DContiguousLoadVectorized, - kernel::matmul::contiguous_vectorized::matmul_tiling_2d -); - -tiling2d_tunable!( - Tiling2DTileLoadVectorized, - kernel::matmul::tile_vectorized::matmul_tiling_2d -); - -#[derive(new)] -struct MemoryCoalescing { - workgroup_size_x: usize, - workgroup_size_y: usize, - _elem: PhantomData, -} - -impl KernelFunction for MemoryCoalescing { - type Input = (WgpuTensorDyn, WgpuTensorDyn); - type Output = WgpuTensorDyn; - - fn call(&self, (lhs, rhs): Self::Input) -> Self::Output { - fn call_dyn( - lhs: WgpuTensorDyn, - rhs: WgpuTensorDyn, - workgroup_size_x: usize, - workgroup_size_y: usize, - ) -> WgpuTensor { - let lhs = WgpuTensor::from(lhs); - let rhs = WgpuTensor::from(rhs); - - matmul_mem_coalescing::(lhs, rhs, workgroup_size_x, workgroup_size_y) - } - - call_dim!( - call_dyn, - lhs.shape.len(), - lhs, - rhs, - self.workgroup_size_x, - self.workgroup_size_y - ) - } - - fn description(&self) -> String { - format!( - "Memory Coalescing matmul - W_X {}, W_Y {}", - self.workgroup_size_x, self.workgroup_size_y - ) - } -} - -#[derive(new)] -struct MatmulBenchmark { - shape_lhs: Shape, - shape_rhs: Shape, - num_repeats: usize, - matmul: PhantomData, - func: AutoTuneFunction<(WgpuTensorDyn, WgpuTensorDyn), WgpuTensorDyn>, -} - -impl Benchmark for MatmulBenchmark -where - E: WgpuElement + FloatElement, - G: GraphicsApi, -{ - type Args = (WgpuTensorDyn, WgpuTensorDyn); - - fn name(&self) -> String { - format!("{:?} x {:?}", self.shape_lhs.dims, self.shape_rhs.dims) - } - - fn num_samples(&self) -> usize { - 5 - } - - fn execute(&self, (lhs, rhs): Self::Args) { - for _ in 0..self.num_repeats { - self.func.call((lhs.clone(), rhs.clone())); - } - } - - fn prepare(&self, device: &WgpuDevice) -> Self::Args { - let lhs = Tensor::, D>::random_device( - self.shape_lhs.clone(), - Distribution::Default, - device, - ); - let rhs = Tensor::, D>::random_device( - self.shape_rhs.clone(), - Distribution::Default, - device, - ); - - (lhs.into_primitive().into(), rhs.into_primitive().into()) - } -} - -/// Choose the best matmul kernel by using autotuning. -pub fn tune( - lhs: WgpuTensor, - rhs: WgpuTensor, -) -> WgpuTensor -where - E: WgpuElement + FloatElement, -{ - if D > 6 { - log::debug!("Can't autotune matmul for tensors of rank 7 or more."); - return kernel::matmul::matmul_tiling_2d_default(lhs, rhs); - } - - let (shape_lhs, shape_rhs) = calculate_benchmark_shapes(lhs.shape.clone(), rhs.shape.clone()); - let id = AutoTuneKey::new( - vec![ - shape_lhs.dims[D - 2..].to_vec(), - shape_rhs.dims[D - 2..].to_vec(), - ], - format!("matmul {}", E::type_name()), - &lhs.context, - ); - - let context = lhs.context.clone(); - let input: (WgpuTensorDyn, WgpuTensorDyn) = (lhs.into(), rhs.into()); - let output: WgpuTensorDyn = match context.tuner.execute(&id, input) { - Execution::Executed(output) => output, - Execution::NoCacheFound((lhs, rhs)) => { - let tunables = matmul_candidates::(shape_lhs, shape_rhs); - - context - .tuner - .tune(id, (lhs, rhs), tunables, &context.device, &context) - } - }; - - output.into() -} - -/// Shape dims are anchored to the closest (on a log scale) power of 2 -fn calculate_benchmark_shapes( - lhs: Shape, - rhs: Shape, -) -> (Shape, Shape) { - let anchor = |a| f32::powf(2., f32::min(f32::round(f32::log(a as f32, 2.)), 12.)) as usize; - let m = anchor(lhs.dims[D - 2]); - let k = anchor(lhs.dims[D - 1]); - let n = anchor(rhs.dims[D - 1]); - - let mut lhs_shape = [1; D]; - lhs_shape[D - 2] = m; - lhs_shape[D - 1] = k; - let lhs_shape = Shape::new(lhs_shape); - - let mut rhs_shape = [1; D]; - rhs_shape[D - 2] = k; - rhs_shape[D - 1] = n; - let rhs_shape = Shape::new(rhs_shape); - - (lhs_shape, rhs_shape) -} - -type MatmulTunable = Tunable, WgpuTensorDyn), WgpuTensorDyn>; - -/// Enumerates all matmul versions that are candidates for autotuning -fn matmul_candidates( - shape_lhs: Shape, - shape_rhs: Shape, -) -> Vec> -where - E: WgpuElement + FloatElement, -{ - let matmul_benchmark = - |func: AutoTuneFunction<(WgpuTensorDyn, WgpuTensorDyn), WgpuTensorDyn>| { - Tunable::::new( - func.clone(), - Arc::new(MatmulBenchmark::new( - shape_lhs.clone(), - shape_rhs.clone(), - 5, - func.clone(), - )), - ) - }; - - let mut candidates = Vec::new(); - - // All combinations of tiling 2d parameters are pushed for a grid search - for block_size in TILING_2D_BLOCK_SIZES { - for block_size_k in TILING_2D_BLOCK_SIZES_K { - for tile_size in TILING_2D_TILE_SIZES { - candidates.push(matmul_benchmark(Arc::new( - Tiling2DContiguousLoad::::new( - block_size, - block_size, - block_size_k, - tile_size, - tile_size, - block_size / tile_size, - block_size / tile_size, - ), - ))); - candidates.push(matmul_benchmark(Arc::new( - Tiling2DContiguousLoadVectorized::::new( - block_size, - block_size, - block_size_k, - tile_size, - tile_size, - block_size / tile_size, - block_size / tile_size, - ), - ))); - candidates.push(matmul_benchmark(Arc::new(Tiling2DTileLoad::::new( - block_size, - block_size, - block_size_k, - tile_size, - tile_size, - block_size / tile_size, - block_size / tile_size, - )))); - candidates.push(matmul_benchmark(Arc::new( - Tiling2DTileLoadVectorized::::new( - block_size, - block_size, - block_size_k, - tile_size, - tile_size, - block_size / tile_size, - block_size / tile_size, - ), - ))); - } - } - - // All combinations of tiling 2d parameters are pushed for a grid search - for workgroup_size in MEMORY_COALESCING_WORKGROUP_SIZES { - candidates.push(matmul_benchmark(Arc::new(MemoryCoalescing::new( - workgroup_size, - workgroup_size, - )))); - } - } - candidates -} - -#[cfg(test)] -mod tests { - use super::calculate_benchmark_shapes; - - #[test] - pub fn benchmark_shapes_are_anchored_correctly() { - let m = f32::powf(2., 8.49) as usize; - let k = f32::powf(2., 8.51) as usize; - let n = f32::powf(2., 4.) as usize; - let lhs_shape = [m, k].into(); - let rhs_shape = [k, n].into(); - let (lhs_shape, rhs_shape) = calculate_benchmark_shapes(lhs_shape, rhs_shape); - assert_eq!(lhs_shape.dims, [256, 512]); - assert_eq!(rhs_shape.dims, [512, 16]); - } -} diff --git a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs index e37cb88a5..0878af73f 100644 --- a/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/adaptive_avg_pool2d.rs @@ -1,14 +1,12 @@ -use std::sync::Arc; - -use burn_tensor::Shape; -use wgpu::Buffer; - use crate::{ + compute::{StaticKernel, WgpuHandle}, element::WgpuElement, kernel::{elemwise_workgroup, KernelSettings}, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; +use burn_tensor::Shape; kernel_wgsl!( AdaptiveAvgPool2d, @@ -28,23 +26,16 @@ pub(crate) fn adaptive_avg_pool2d( let [batch_size, channels, _, _] = x.shape.dims; let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); - let num_elems = output_shape.num_elements(); - let output_buffer = x - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer); + let output = empty_device(x.client.clone(), x.device.clone(), output_shape); - let kernel = x - .context - .compile_static::>(); + let kernel = + StaticKernel::>::new( + elemwise_workgroup(output.shape.num_elements(), WORKGROUP), + ); - let info_buffer = build_info(&x, &output); - - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&x.buffer, &output.buffer, &info_buffer], - ); + let info_handle = build_info(&x, &output); + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); output } @@ -57,32 +48,29 @@ pub(crate) fn adaptive_avg_pool2d_backward( let output_shape = x.shape.clone(); let num_elems = output_shape.num_elements(); - let output_buffer = x - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer); + let output_buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + x.client.clone(), + x.device.clone(), + output_shape, + output_buffer, + ); - let kernel = x.context.compile_static::>(); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP)); - let info_buffer = build_info(&x, &out_grad); + let info_handle = build_info(&x, &out_grad); - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&out_grad.buffer, &output.buffer, &info_buffer], + x.client.execute( + Box::new(kernel), + &[&out_grad.handle, &output.handle, &info_handle], ); output } -fn build_info(x: &WgpuTensor, output: &WgpuTensor) -> Arc { +fn build_info(x: &WgpuTensor, output: &WgpuTensor) -> WgpuHandle { let mut info: [u32; 16] = [0; 16]; info[0] = x.strides[0] as u32; info[1] = x.strides[1] as u32; @@ -102,7 +90,5 @@ fn build_info(x: &WgpuTensor, output: &WgpuTensor) - info[14] = output.shape.dims[2] as u32; info[15] = output.shape.dims[3] as u32; - output - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)) + output.client.create(bytemuck::cast_slice(&info)) } diff --git a/burn-wgpu/src/kernel/pool/avg_pool2d.rs b/burn-wgpu/src/kernel/pool/avg_pool2d.rs index 8cb88e3c7..e90d4f6b3 100644 --- a/burn-wgpu/src/kernel/pool/avg_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/avg_pool2d.rs @@ -1,11 +1,13 @@ use crate::{ + compute::{Kernel, StaticKernel}, element::WgpuElement, kernel::{ self, elemwise_workgroup, pool::{build_output_and_info_pool2d, build_pool2d_info}, - KernelSettings, StaticKernel, + KernelSettings, StaticKernelSource, }, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -18,17 +20,15 @@ kernel_wgsl!( struct AvgPool2dBackward; struct AvgPool2d; -impl StaticKernel for AvgPool2dBackward { - fn source_template() -> kernel::SourceTemplate { - AvgPool2dBackwardRaw::source_template() - .register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) +impl StaticKernelSource for AvgPool2dBackward { + fn source() -> kernel::SourceTemplate { + AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) } } -impl StaticKernel for AvgPool2d { - fn source_template() -> kernel::SourceTemplate { - AvgPool2dRaw::source_template() - .register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) +impl StaticKernelSource for AvgPool2d { + fn source() -> kernel::SourceTemplate { + AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}")) } } @@ -41,22 +41,21 @@ pub(crate) fn avg_pool2d( ) -> WgpuTensor { const WORKGROUP: usize = 32; - let (info_buffer, output) = + let (info_handle, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]); - let kernel = match count_include_pad { - true => x - .context - .compile_static::, E, i32, WORKGROUP, WORKGROUP, 1>>(), - false => x - .context - .compile_static::, E, i32, WORKGROUP, WORKGROUP, 1>>(), + + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP); + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP, WORKGROUP, 1>, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP, WORKGROUP, 1>, + >::new(workgroup)), }; - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&x.buffer, &output.buffer, &info_buffer], - ); + x.client + .execute(kernel, &[&x.handle, &output.handle, &info_handle]); output } @@ -72,39 +71,21 @@ pub(crate) fn avg_pool2d_backward( const WORKGROUP: usize = 32; let grad = kernel::into_contiguous(grad); + let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone()); + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); + let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP); - let num_elems = x.shape.num_elements(); - let buffer = x - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer); - let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]); + let kernel: Box = match count_include_pad { + true => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP, WORKGROUP, 1>, + >::new(workgroup)), + false => Box::new(StaticKernel::< + KernelSettings, E, i32, WORKGROUP, WORKGROUP, 1>, + >::new(workgroup)), + }; - let kernel = - match count_include_pad { - true => x.context.compile_static::, - E, - i32, - WORKGROUP, - WORKGROUP, - 1, - >>(), - false => x.context.compile_static::, - E, - i32, - WORKGROUP, - WORKGROUP, - 1, - >>(), - }; - - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&grad.buffer, &output.buffer, &info_buffer], - ); + x.client + .execute(kernel, &[&grad.handle, &output.handle, &info_handle]); output } diff --git a/burn-wgpu/src/kernel/pool/base.rs b/burn-wgpu/src/kernel/pool/base.rs index 0370f9713..13a16d6ec 100644 --- a/burn-wgpu/src/kernel/pool/base.rs +++ b/burn-wgpu/src/kernel/pool/base.rs @@ -1,7 +1,7 @@ -use crate::{element::WgpuElement, tensor::WgpuTensor}; +use crate::{ + compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor, +}; use burn_tensor::Shape; -use std::sync::Arc; -use wgpu::Buffer; /// Build basic info to launch pool 2d kernels. pub fn build_output_and_info_pool2d( @@ -10,7 +10,7 @@ pub fn build_output_and_info_pool2d( stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> (Arc, WgpuTensor) { +) -> (WgpuHandle, WgpuTensor) { let [kernel_height, kernel_width] = kernel_size; let [padding_height, padding_width] = padding; let [stride_height, stride_width] = stride; @@ -24,12 +24,7 @@ pub fn build_output_and_info_pool2d( / stride_width) + 1; let shape_out = Shape::new([batch_size, channels, out_height, out_width]); - let num_elems = shape_out.num_elements(); - - let buffer = x - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.context.clone(), shape_out, buffer); + let output = empty_device(x.client.clone(), x.device.clone(), shape_out); let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation); @@ -43,7 +38,7 @@ pub fn build_pool2d_info( stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> Arc { +) -> WgpuHandle { let mut info: [u32; 24] = [0; 24]; info[0] = input.strides[0] as u32; info[1] = input.strides[1] as u32; @@ -72,9 +67,7 @@ pub fn build_pool2d_info( info[22] = dilation[0] as u32; info[23] = dilation[1] as u32; - let info_buffer = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_buffer = input.client.create(bytemuck::cast_slice(&info)); info_buffer } diff --git a/burn-wgpu/src/kernel/pool/max_pool2d.rs b/burn-wgpu/src/kernel/pool/max_pool2d.rs index 70124f935..d2f5d9213 100644 --- a/burn-wgpu/src/kernel/pool/max_pool2d.rs +++ b/burn-wgpu/src/kernel/pool/max_pool2d.rs @@ -1,4 +1,5 @@ use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::{ self, elemwise_workgroup, @@ -6,6 +7,7 @@ use crate::{ KernelSettings, }, kernel_wgsl, + ops::numeric::empty_device, tensor::WgpuTensor, }; @@ -28,18 +30,15 @@ pub(crate) fn max_pool2d( ) -> WgpuTensor { const WORKGROUP: usize = 32; - let (info_buffer, output) = + let (info_handle, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let kernel = x - .context - .compile_static::>(); - - x.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&x.buffer, &output.buffer, &info_buffer], ); + x.client + .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]); + output } @@ -52,25 +51,17 @@ pub(crate) fn max_pool2d_with_indices( ) -> (WgpuTensor, WgpuTensor) { const WORKGROUP: usize = 32; - let (info_buffer, output) = + let (info_handle, output) = build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation); - let num_elems = output.shape.num_elements(); + let indices = empty_device(x.client.clone(), x.device, output.shape.clone()); - let indices = WgpuTensor::new( - x.context.clone(), - output.shape.clone(), - x.context - .create_buffer(num_elems * std::mem::size_of::()), - ); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP)); - let kernel = x - .context - .compile_static::>(); - - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&x.buffer, &output.buffer, &indices.buffer, &info_buffer], + x.client.execute( + Box::new(kernel), + &[&x.handle, &output.handle, &indices.handle, &info_handle], ); (output, indices) @@ -91,26 +82,18 @@ pub(crate) fn max_pool2d_with_indices_backward( let indices = kernel::into_contiguous(indices); let num_elems = x.shape.num_elements(); - let buffer = x - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer); + let buffer = x.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer); - let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); + let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation); - let kernel = x.context.compile_static::>(); + let kernel = StaticKernel::< + KernelSettings, + >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP)); - x.context.execute( - elemwise_workgroup(output.shape.num_elements(), WORKGROUP), - kernel, - &[&indices.buffer, &grad.buffer, &output.buffer, &info_buffer], + x.client.execute( + Box::new(kernel), + &[&indices.handle, &grad.handle, &output.handle, &info_handle], ); output } diff --git a/burn-wgpu/src/kernel/prng/base.rs b/burn-wgpu/src/kernel/prng/base.rs index f9e34479b..c0478db0e 100644 --- a/burn-wgpu/src/kernel/prng/base.rs +++ b/burn-wgpu/src/kernel/prng/base.rs @@ -1,11 +1,10 @@ -use std::sync::Arc; - +use crate::{ + compute::{WgpuComputeClient, WgpuHandle}, + element::WgpuElement, + kernel_wgsl, SEED, +}; use burn_common::rand::get_seeded_rng; -use burn_tensor::Shape; use rand::Rng; -use wgpu::Buffer; - -use crate::{context::Context, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor, SEED}; kernel_wgsl!(Prng, "../../template/prng/prng.wgsl"); @@ -23,22 +22,20 @@ pub(crate) fn get_seeds() -> Vec { seeds } -pub(crate) fn make_output_tensor( - context: Arc, - shape: Shape, -) -> WgpuTensor { - let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(context, shape, buffer) -} - -pub(crate) fn make_info_buffer(context: Arc, n_values_per_thread: usize) -> Arc { +pub(crate) fn make_info_buffer( + client: WgpuComputeClient, + n_values_per_thread: usize, +) -> WgpuHandle { let mut info = get_seeds(); info.insert(0, n_values_per_thread as u32); - context.create_buffer_with_data(bytemuck::cast_slice(&info)) + client.create(bytemuck::cast_slice(&info)) } -pub(crate) fn make_args_buffer(context: Arc, args: &[E]) -> Arc { - context.create_buffer_with_data(E::as_bytes(args)) +pub(crate) fn make_args_buffer( + client: WgpuComputeClient, + args: &[E], +) -> WgpuHandle { + client.create(E::as_bytes(args)) } #[cfg(test)] diff --git a/burn-wgpu/src/kernel/prng/bernoulli.rs b/burn-wgpu/src/kernel/prng/bernoulli.rs index 1ecbf7d30..9399a10fc 100644 --- a/burn-wgpu/src/kernel/prng/bernoulli.rs +++ b/burn-wgpu/src/kernel/prng/bernoulli.rs @@ -1,12 +1,13 @@ use burn_tensor::Shape; use crate::{ + compute::{compute_client, StaticKernel}, element::WgpuElement, kernel::{ - prng::base::{make_args_buffer, make_info_buffer, make_output_tensor}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernel, + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, }, - pool::get_context, + ops::numeric::empty_device, tensor::WgpuTensor, GraphicsApi, WgpuDevice, }; @@ -15,9 +16,9 @@ use super::base::Prng; struct BernoulliPrng; -impl StaticKernel for BernoulliPrng { - fn source_template() -> SourceTemplate { - Prng::source_template() +impl StaticKernelSource for BernoulliPrng { + fn source() -> SourceTemplate { + Prng::source() .register("num_args", "1") .register( "prng_loop", @@ -36,15 +37,19 @@ pub fn random_bernoulli( const WORKGROUP: usize = 32; const N_VALUES_PER_THREAD: usize = 128; - let context = get_context::(device); - let output = make_output_tensor(context.clone(), shape.clone()); - let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD); - let args_buffer = make_args_buffer(context.clone(), &[prob]); + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[prob]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD); + let kernel = + StaticKernel::>::new( + workgroup, + ); - context.execute( - prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD), - context.compile_static::>(), - &[&output.buffer, &info_buffer, &args_buffer], + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], ); output diff --git a/burn-wgpu/src/kernel/prng/normal.rs b/burn-wgpu/src/kernel/prng/normal.rs index 9b38c4b1a..8c261f6c2 100644 --- a/burn-wgpu/src/kernel/prng/normal.rs +++ b/burn-wgpu/src/kernel/prng/normal.rs @@ -1,12 +1,13 @@ use burn_tensor::Shape; use crate::{ + compute::{compute_client, StaticKernel}, element::WgpuElement, kernel::{ - prng::base::{make_args_buffer, make_info_buffer, make_output_tensor}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernel, + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, }, - pool::get_context, + ops::numeric::empty_device, tensor::WgpuTensor, GraphicsApi, WgpuDevice, }; @@ -15,9 +16,9 @@ use super::base::Prng; struct NormalPrng; -impl StaticKernel for NormalPrng { - fn source_template() -> SourceTemplate { - Prng::source_template() +impl StaticKernelSource for NormalPrng { + fn source() -> SourceTemplate { + Prng::source() .register("num_args", "2") .register( "prng_loop", @@ -39,15 +40,17 @@ pub fn random_normal( const WORKGROUP: usize = 32; const N_VALUES_PER_THREAD: usize = 128; // must be even - let context = get_context::(device); - let output = make_output_tensor(context.clone(), shape.clone()); - let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD); - let args_buffer = make_args_buffer(context.clone(), &[mean, std]); + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[mean, std]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD); + let kernel = + StaticKernel::>::new(workgroup); - context.execute( - prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD), - context.compile_static::>(), - &[&output.buffer, &info_buffer, &args_buffer], + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], ); output diff --git a/burn-wgpu/src/kernel/prng/uniform.rs b/burn-wgpu/src/kernel/prng/uniform.rs index b8772fd01..dfcabee68 100644 --- a/burn-wgpu/src/kernel/prng/uniform.rs +++ b/burn-wgpu/src/kernel/prng/uniform.rs @@ -1,12 +1,13 @@ use burn_tensor::Shape; use crate::{ + compute::{compute_client, StaticKernel}, element::WgpuElement, kernel::{ - prng::base::{make_args_buffer, make_info_buffer, make_output_tensor}, - prng_workgroup, KernelSettings, SourceTemplate, StaticKernel, + prng::base::{make_args_buffer, make_info_buffer}, + prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, }, - pool::get_context, + ops::numeric::empty_device, tensor::WgpuTensor, GraphicsApi, WgpuDevice, }; @@ -15,9 +16,9 @@ use super::base::Prng; struct UniformPrng; -impl StaticKernel for UniformPrng { - fn source_template() -> SourceTemplate { - Prng::source_template().register("num_args", "2").register( +impl StaticKernelSource for UniformPrng { + fn source() -> SourceTemplate { + Prng::source().register("num_args", "2").register( "prng_loop", include_str!("../../template/prng/uniform_inner_loop.wgsl"), ) @@ -34,15 +35,18 @@ pub fn random_uniform( const WORKGROUP: usize = 32; const N_VALUES_PER_THREAD: usize = 128; - let context = get_context::(device); - let output = make_output_tensor(context.clone(), shape.clone()); - let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD); - let args_buffer = make_args_buffer(context.clone(), &[low, high]); + let client = compute_client::(device); + let output = empty_device(client.clone(), device.clone(), shape.clone()); + let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD); + let args_handle = make_args_buffer(client.clone(), &[low, high]); + let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD); + let kernel = StaticKernel::>::new( + workgroup, + ); - context.execute( - prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD), - context.compile_static::>(), - &[&output.buffer, &info_buffer, &args_buffer], + client.execute( + Box::new(kernel), + &[&output.handle, &info_handle, &args_handle], ); output diff --git a/burn-wgpu/src/kernel/reduction.rs b/burn-wgpu/src/kernel/reduction.rs index d83d287f8..59bdfa425 100644 --- a/burn-wgpu/src/kernel/reduction.rs +++ b/burn-wgpu/src/kernel/reduction.rs @@ -1,5 +1,8 @@ -use super::{build_info, KernelSettings, SourceTemplate, StaticKernel}; -use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor}; +use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource}; +use crate::{ + compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, + tensor::WgpuTensor, +}; use burn_tensor::Shape; kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl"); @@ -11,15 +14,15 @@ pub struct ArgsMin; pub struct SumDim; pub struct MeanDim; -impl StaticKernel for SumDim { - fn source_template() -> SourceTemplate { - ReductionDimRaw::source_template().register("assign", "output[id] = sum;") +impl StaticKernelSource for SumDim { + fn source() -> SourceTemplate { + ReductionDimRaw::source().register("assign", "output[id] = sum;") } } -impl StaticKernel for MeanDim { - fn source_template() -> SourceTemplate { - ReductionDimRaw::source_template() +impl StaticKernelSource for MeanDim { + fn source() -> SourceTemplate { + ReductionDimRaw::source() .add_template( "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { return sum / {{ elem }}(dim); @@ -29,17 +32,17 @@ impl StaticKernel for MeanDim { } } -impl StaticKernel for ArgsMax { - fn source_template() -> SourceTemplate { - ReductionArgsRaw::source_template() +impl StaticKernelSource for ArgsMax { + fn source() -> SourceTemplate { + ReductionArgsRaw::source() .register("cmp", ">") .register("initial", (-32767).to_string()) } } -impl StaticKernel for ArgsMin { - fn source_template() -> SourceTemplate { - ReductionArgsRaw::source_template() +impl StaticKernelSource for ArgsMin { + fn source() -> SourceTemplate { + ReductionArgsRaw::source() .register("cmp", "<") .register("initial", 32767.to_string()) } @@ -49,28 +52,29 @@ impl StaticKernel for ArgsMin { pub fn sum(input: WgpuTensor) -> WgpuTensor { const WORKGROUP: usize = 32; - let mut input_buffer = input.buffer; + let mut input_handle = input.handle; let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP); - let kernel = input - .context - .compile_static::>(); - loop { let num_invocations = workgroup.num_invocations(); - let buffer = input - .context - .create_buffer(core::mem::size_of::() * num_invocations); + let handle = input + .client + .empty(core::mem::size_of::() * num_invocations); + + let kernel = + StaticKernel::>::new( + workgroup, + ); input - .context - .execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]); + .client + .execute(Box::new(kernel), &[&input_handle, &handle]); if num_invocations <= 1 { - return WgpuTensor::new(input.context, Shape::new([1]), buffer); + return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle); } - input_buffer = buffer; + input_handle = handle; workgroup = elemwise_workgroup(num_invocations, WORKGROUP); } } @@ -91,7 +95,7 @@ pub fn mean_dim( reduction_dim::(input, dim) } -fn reduction_dim( +fn reduction_dim( input: WgpuTensor, dim: usize, ) -> WgpuTensor { @@ -100,25 +104,25 @@ fn reduction_dim( let mut shape_out = input.shape.clone(); shape_out.dims[dim] = 1; let num_elems = shape_out.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(input.context.clone(), shape_out, buffer); + let handle = input.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out, + handle, + ); - let kernel = input - .context - .compile_static::>(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); let mut info = build_info(&[&input, &output]); info.push(dim as u32); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); - input.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&input.buffer, &output.buffer, &info_buffers], + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], ); output @@ -140,7 +144,7 @@ pub fn argmin( reduction_args_dim::(input, dim) } -fn reduction_args_dim( +fn reduction_args_dim( input: WgpuTensor, dim: usize, ) -> WgpuTensor { @@ -149,27 +153,27 @@ fn reduction_args_dim()); - let output = WgpuTensor::new(input.context.clone(), shape_out, buffer); - - let kernel = input - .context - .compile_static::>(); - let mut info = build_info(&[&input, &output]); - info.push(dim as u32); - let info_buffers = input - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - input.context.execute( - elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&input.buffer, &output.buffer, &info_buffers], + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + input.client.clone(), + input.device.clone(), + shape_out, + buffer, ); - WgpuTensor::new(output.context, output.shape, output.buffer) + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), + ); + let mut info = build_info(&[&input, &output]); + info.push(dim as u32); + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); + + WgpuTensor::new(output.client, output.device, output.shape, output.handle) } #[cfg(test)] diff --git a/burn-wgpu/src/kernel/unary.rs b/burn-wgpu/src/kernel/unary.rs index 89fa37414..e03b63177 100644 --- a/burn-wgpu/src/kernel/unary.rs +++ b/burn-wgpu/src/kernel/unary.rs @@ -1,5 +1,5 @@ -use super::{elemwise_workgroup, KernelSettings, StaticKernel}; -use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; +use super::{elemwise_workgroup, KernelSettings, StaticKernelSource}; +use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; kernel_wgsl!(UnaryRaw, "../template/unary.wgsl"); kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl"); @@ -13,9 +13,9 @@ macro_rules! unary { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - let source = $crate::kernel::UnaryRaw::source_template(); + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + let source = $crate::kernel::UnaryRaw::source(); source.register("body", format!("output[id] = {}(input[id]);", $func)) } } @@ -26,9 +26,9 @@ macro_rules! unary { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source_template().register("body", $body) + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source().register("body", $body) } } }; @@ -39,9 +39,9 @@ macro_rules! unary { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryRaw::source() .register("body", format!("output[id] = {}(input[id]);", $func)) .add_template(include_str!($file)) } @@ -58,9 +58,9 @@ macro_rules! unary_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() .register("body", format!("input[id] = {}(input[id]);", $func)) } } @@ -71,9 +71,9 @@ macro_rules! unary_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source_template().register("body", $body) + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source().register("body", $body) } } }; @@ -84,9 +84,9 @@ macro_rules! unary_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryInplaceRaw::source() .register("body", format!("input[id] = {}(input[id]);", $func)) .add_template(include_str!($file)) } @@ -95,59 +95,55 @@ macro_rules! unary_inplace { } /// Execute a unary kernel using the default settings. -pub fn unary_default( +pub fn unary_default( input: WgpuTensor, ) -> WgpuTensor { unary::(input) } /// Execute a unary inplace kernel using the default settings. -pub fn unary_inplace_default( +pub fn unary_inplace_default( input: WgpuTensor, ) -> WgpuTensor { unary_inplace::(input) } /// Execute a unary inplace kernel using the provided WORKGROUP. -pub fn unary_inplace( +pub fn unary_inplace< + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, +>( input: WgpuTensor, ) -> WgpuTensor { let num_elems = input.shape.num_elements(); - let kernel = input - .context - .compile_static::>(); - - input.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&input.buffer], ); + input.client.execute(Box::new(kernel), &[&input.handle]); + input } /// Execute a unary kernel using the provided WORKGROUP. -pub fn unary( +pub fn unary( input: WgpuTensor, ) -> WgpuTensor { let num_elems = input.shape.num_elements(); - let buffer = input - .context - .create_buffer(num_elems * core::mem::size_of::()); - let mut output = WgpuTensor::new(input.context.clone(), input.shape, buffer); + let buffer = input.client.empty(num_elems * core::mem::size_of::()); + let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); // Since we don't handle the stride inside the kernel, the output tensor have the same strides // as the input tensor. It might not be in the default format. output.strides = input.strides; - let kernel = input - .context - .compile_static::>(); - - input.context.execute( + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&input.buffer, &output.buffer], ); + input + .client + .execute(Box::new(kernel), &[&input.handle, &output.handle]); output } diff --git a/burn-wgpu/src/kernel/unary_scalar.rs b/burn-wgpu/src/kernel/unary_scalar.rs index b4700b904..82b1fb583 100644 --- a/burn-wgpu/src/kernel/unary_scalar.rs +++ b/burn-wgpu/src/kernel/unary_scalar.rs @@ -1,5 +1,5 @@ -use super::{elemwise_workgroup, KernelSettings, StaticKernel}; -use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; +use super::{elemwise_workgroup, KernelSettings, StaticKernelSource}; +use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl"); kernel_wgsl!( @@ -16,9 +16,9 @@ macro_rules! unary_scalar { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) } } @@ -30,9 +30,9 @@ macro_rules! unary_scalar { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) } } @@ -45,9 +45,9 @@ macro_rules! unary_scalar { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarRaw::source() .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) .add_template(include_str!($file)) } @@ -64,9 +64,9 @@ macro_rules! unary_scalar_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) } } @@ -78,9 +78,9 @@ macro_rules! unary_scalar_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) } } @@ -93,9 +93,9 @@ macro_rules! unary_scalar_inplace { ) => { pub struct $struct; - impl $crate::kernel::StaticKernel for $struct { - fn source_template() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source_template() + impl $crate::kernel::StaticKernelSource for $struct { + fn source() -> $crate::kernel::SourceTemplate { + $crate::kernel::UnaryScalarInplaceRaw::source() .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) .add_template(include_str!($file)) } @@ -104,7 +104,7 @@ macro_rules! unary_scalar_inplace { } /// Execute a unary scalar kernel using the default settings. -pub fn unary_scalar_default( +pub fn unary_scalar_default( lhs: WgpuTensor, scalar: E, ) -> WgpuTensor { @@ -112,31 +112,33 @@ pub fn unary_scalar_default( } /// Execute a unary scalar kernel using the provided WORKGROUP. -pub fn unary_scalar( +pub fn unary_scalar< + K: StaticKernelSource, + E: WgpuElement, + const D: usize, + const WORKGROUP: usize, +>( lhs: WgpuTensor, scalar: E, ) -> WgpuTensor { let num_elems = lhs.shape.num_elements(); - let buffer = lhs - .context - .create_buffer(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.context.clone(), lhs.shape, buffer); - let kernel = lhs - .context - .compile_static::>(); - let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar])); - - lhs.context.execute( + let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); + let kernel = StaticKernel::>::new( elemwise_workgroup(num_elems, WORKGROUP), - kernel, - &[&lhs.buffer, &rhs_buffer, &output.buffer], + ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + + lhs.client.execute( + Box::new(kernel), + &[&lhs.handle, &rhs_handle, &output.handle], ); output } /// Execute a unary scalar inplace kernel using the default settings. -pub fn unary_scalar_inplace_default( +pub fn unary_scalar_inplace_default( lhs: WgpuTensor, scalar: E, ) -> WgpuTensor { @@ -145,7 +147,7 @@ pub fn unary_scalar_inplace_default, scalar: E, ) -> WgpuTensor { - let kernel = lhs - .context - .compile_static::>(); - let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar])); - - lhs.context.execute( - { - let num_elems = lhs.shape.num_elements(); - elemwise_workgroup(num_elems, WORKGROUP) - }, - kernel, - &[&lhs.buffer, &rhs_buffer], + let num_elems = lhs.shape.num_elements(); + let kernel = StaticKernel::>::new( + elemwise_workgroup(num_elems, WORKGROUP), ); + let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); + + lhs.client + .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); lhs } diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index 4a88a17d7..20a288cfc 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -10,20 +10,13 @@ mod ops; /// Benchmark module pub mod benchmark; -/// Context module. -pub mod context; +/// Compute related module. +pub mod compute; /// Kernel module pub mod kernel; /// Tensor module. pub mod tensor; -#[cfg(test)] // Only enabled for dev for now. -/// Compute related module. -pub mod compute; - -pub(crate) mod pool; -pub(crate) mod tune; - mod element; pub use element::{FloatElement, IntElement}; diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs index 52aad52f8..d71e3af4a 100644 --- a/burn-wgpu/src/ops/base.rs +++ b/burn-wgpu/src/ops/base.rs @@ -1,5 +1,6 @@ use crate::{ - element::WgpuElement, kernel, pool::get_context, tensor::WgpuTensor, GraphicsApi, WgpuDevice, + compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi, + WgpuDevice, }; use burn_tensor::{backend::Backend, Data, Shape}; @@ -18,15 +19,15 @@ pub fn from_data( data: Data, device: &WgpuDevice, ) -> WgpuTensor { - let context = get_context::(device); - let buffer = context.create_buffer_with_data_options(E::as_bytes(&data.value), true); + let client = compute_client::(device); + let buffer = client.create(E::as_bytes(&data.value)); - WgpuTensor::new(context, data.shape, buffer) + WgpuTensor::new(client, device.clone(), data.shape, buffer) } pub fn into_data(tensor: WgpuTensor) -> Data { let tensor = kernel::into_contiguous(tensor); - let bytes = tensor.context.read_buffer(tensor.buffer); + let bytes = tensor.client.read(&tensor.handle); let values = E::from_bytes(&bytes); Data::new(values.to_vec(), tensor.shape) @@ -36,22 +37,22 @@ pub fn to_device( tensor: WgpuTensor, device: &WgpuDevice, ) -> WgpuTensor { - if &tensor.context.device == device { + if &tensor.device == device { return tensor; } - let context = get_context::(device); - tensor.to_context(context) + let client = compute_client::(device); + tensor.to_client(client, device.clone()) } pub fn empty( shape: Shape, device: &WgpuDevice, ) -> WgpuTensor { - let context = get_context::(device); - let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); + let client = compute_client::(device); + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); - WgpuTensor::new(context, shape, buffer) + WgpuTensor::new(client, device.clone(), shape, buffer) } pub fn swap_dims( @@ -72,5 +73,5 @@ pub fn reshape( // TODO: Not force standard layout all the time (improve performance). let tensor = kernel::into_contiguous(tensor); - WgpuTensor::new(tensor.context, shape, tensor.buffer) + WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle) } diff --git a/burn-wgpu/src/ops/bool_ops.rs b/burn-wgpu/src/ops/bool_ops.rs index 8e2c51ccc..312b3c513 100644 --- a/burn-wgpu/src/ops/bool_ops.rs +++ b/burn-wgpu/src/ops/bool_ops.rs @@ -47,7 +47,7 @@ where fn bool_into_int(tensor: BoolTensor) -> IntTensor { if std::mem::size_of::() == std::mem::size_of::() { - return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer); + return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); } let device = Self::bool_device(&tensor); @@ -57,7 +57,7 @@ where } fn bool_device(tensor: &BoolTensor) -> Device { - tensor.context.device.clone() + tensor.device.clone() } fn bool_to_device( diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index beb655704..b66957571 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -57,7 +57,7 @@ where } fn device(tensor: &FloatTensor) -> Device { - tensor.context.device.clone() + tensor.device.clone() } fn to_device( @@ -139,11 +139,7 @@ where lhs: FloatTensor, rhs: FloatTensor, ) -> FloatTensor { - #[cfg(feature = "autotune")] - return kernel::matmul::tune::(lhs, rhs); - - #[cfg(not(feature = "autotune"))] - kernel::matmul::contiguous_vectorized::matmul_tiling_2d_default(lhs, rhs) + kernel::matmul::contiguous::matmul_tiling_2d_default(lhs, rhs) } fn swap_dims( diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs index 8fd6892b6..1f58ffd38 100644 --- a/burn-wgpu/src/ops/int_ops.rs +++ b/burn-wgpu/src/ops/int_ops.rs @@ -33,7 +33,7 @@ where } fn int_device(tensor: &IntTensor) -> Device { - tensor.context.device.clone() + tensor.device.clone() } fn int_to_device( diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs index 7561c5c63..6ae646ca6 100644 --- a/burn-wgpu/src/ops/numeric.rs +++ b/burn-wgpu/src/ops/numeric.rs @@ -1,8 +1,8 @@ +use crate::compute::{compute_client, WgpuComputeClient}; use crate::kernel::{ binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, unary_scalar_inplace_default, }; -use crate::pool::get_context; use crate::{ binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary_scalar, unary_scalar_inplace, @@ -10,26 +10,38 @@ use crate::{ use crate::{GraphicsApi, WgpuDevice}; use burn_tensor::{Element, ElementConversion, Shape}; -pub fn zeros( +pub fn zeros( shape: Shape, device: &WgpuDevice, ) -> WgpuTensor { - let context = get_context::(device); + let client = compute_client::(device); - let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); + zeros_device(client, device.clone(), shape) +} - WgpuTensor::new(context, shape, buffer) +pub fn zeros_device( + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, +) -> WgpuTensor { + mul_scalar(empty_device(client, device, shape), 0i32.elem::()) +} + +pub fn empty_device( + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, +) -> WgpuTensor { + let buffer = client.empty(shape.num_elements() * core::mem::size_of::()); + + WgpuTensor::new(client, device, shape, buffer) } pub fn ones( shape: Shape, device: &WgpuDevice, ) -> WgpuTensor { - let context = get_context::(device); - - let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); - - add_scalar(WgpuTensor::new(context, shape, buffer), 1i32.elem::()) + add_scalar(zeros::(shape, device), 1i32.elem::()) } pub fn add( diff --git a/burn-wgpu/src/pool.rs b/burn-wgpu/src/pool.rs deleted file mode 100644 index 05ce3dc28..000000000 --- a/burn-wgpu/src/pool.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::{context::Context, GraphicsApi, WgpuDevice}; -use std::{ - any::TypeId, - collections::HashMap, - sync::{Arc, Mutex}, -}; - -static POOL_CONTEXT: Mutex> = Mutex::new(None); - -#[derive(Default)] -struct ContextPool { - contexts: HashMap>, -} - -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -struct Key { - api_id: TypeId, - device: WgpuDevice, -} - -impl Key { - fn new(device: &WgpuDevice) -> Self { - Self { - api_id: TypeId::of::(), - device: device.clone(), - } - } -} - -/// Get a [context](Context) for the given [device](WGPUDevice). -/// -/// # Notes -/// -/// If a context already exist for the current [device](WGPUDevice), the same instance will be -/// returned. -pub fn get_context(device: &WgpuDevice) -> Arc { - let mut pool = POOL_CONTEXT.lock().unwrap(); - - let context = if let Some(pool) = pool.as_mut() { - // Fetch device in pool - match pool.contexts.get(&Key::new::(device)) { - Some(context) => context.clone(), - None => { - // Init new device - let context = Arc::new(Context::new::(device)); - pool.contexts.insert(Key::new::(device), context.clone()); - context - } - } - } else { - // Initialize pool - let context = Arc::new(Context::new::(device)); - let mut new_pool = ContextPool::default(); - - new_pool - .contexts - .insert(Key::new::(device), context.clone()); - *pool = Some(new_pool); - context - }; - - context -} diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs index 309e7367a..46ada4a57 100644 --- a/burn-wgpu/src/tensor/base.rs +++ b/burn-wgpu/src/tensor/base.rs @@ -1,19 +1,22 @@ -use crate::unary; +use crate::{ + compute::{WgpuComputeClient, WgpuHandle}, + unary, WgpuDevice, +}; +use crate::{element::WgpuElement, kernel::unary_default}; use burn_tensor::Shape; -use std::{marker::PhantomData, sync::Arc}; -use wgpu::Buffer; - -use crate::{context::Context, element::WgpuElement, kernel::unary_default}; +use std::marker::PhantomData; /// The basic tensor primitive struct. #[derive(Debug, Clone)] pub struct WgpuTensor { - /// The context the tensor is binded to. - pub context: Arc, + /// Compute client for wgpu. + pub client: WgpuComputeClient, /// The buffer where the data are stored. - pub buffer: Arc, + pub handle: WgpuHandle, /// The shape of the current tensor. pub shape: Shape, + /// The device of the current tensor. + pub device: WgpuDevice, /// The strides of the current tensor. pub strides: [usize; D], elem: PhantomData, @@ -21,8 +24,12 @@ pub struct WgpuTensor { #[derive(Debug, Clone)] pub(crate) struct WgpuTensorDyn { - pub(crate) context: Arc, - pub(crate) buffer: Arc, + /// Compute client for wgpu. + pub client: WgpuComputeClient, + /// The buffer where the data are stored. + pub handle: WgpuHandle, + /// The device of the current tensor. + pub device: WgpuDevice, pub(crate) shape: Vec, pub(crate) strides: Vec, elem: PhantomData, @@ -31,8 +38,9 @@ pub(crate) struct WgpuTensorDyn { impl From> for WgpuTensorDyn { fn from(value: WgpuTensor) -> Self { WgpuTensorDyn { - context: value.context, - buffer: value.buffer, + client: value.client, + handle: value.handle, + device: value.device, shape: value.shape.dims.to_vec(), strides: value.strides.to_vec(), elem: PhantomData, @@ -43,8 +51,9 @@ impl From> for WgpuTensorDyn impl From> for WgpuTensor { fn from(value: WgpuTensorDyn) -> Self { WgpuTensor { - context: value.context, - buffer: value.buffer, + client: value.client, + handle: value.handle, + device: value.device, shape: Shape::new(value.shape.try_into().expect("Wrong dimension")), strides: value.strides.try_into().expect("Wrong dimension"), elem: PhantomData, @@ -54,7 +63,12 @@ impl From> for WgpuTensor impl WgpuTensor { /// Create a new tensor. - pub fn new(context: Arc, shape: Shape, buffer: Arc) -> Self { + pub fn new( + client: WgpuComputeClient, + device: WgpuDevice, + shape: Shape, + handle: WgpuHandle, + ) -> Self { let mut strides = [0; D]; let mut current = 1; @@ -69,29 +83,31 @@ impl WgpuTensor { }); Self { - context, - buffer, + client, + handle, shape, strides, + device, elem: PhantomData, } } /// Change the context of the current tensor and return the newly transferred tensor. - pub fn to_context(&self, context: Arc) -> Self { - let data = self.context.read_buffer(self.buffer.clone()); - let buffer = context.create_buffer_with_data(&data); + pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self { + let data = self.client.read(&self.handle); + let handle = client.create(&data); Self { - context, - buffer, + client, + handle, shape: self.shape.clone(), strides: self.strides, + device, elem: PhantomData, } } pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor) -> bool { - if Arc::strong_count(&self.buffer) > 1 { + if !self.handle.can_mut() { return false; } @@ -120,19 +136,15 @@ impl WgpuTensor { /// Check if the tensor is safe to mutate. pub fn can_mut(&self) -> bool { - if Arc::strong_count(&self.buffer) > 1 { - return false; - } - - true + self.handle.can_mut() } /// Assert that both tensors are on the same device. pub fn assert_is_on_same_device(&self, other: &Self) { - if self.context.device != other.context.device { + if self.device != other.device { panic!( "Both tensors should be on the same device {:?} != {:?}", - self.context.device, other.context.device + self.device, other.device ); } } diff --git a/burn-wgpu/src/tune/base.rs b/burn-wgpu/src/tune/base.rs deleted file mode 100644 index f514f278c..000000000 --- a/burn-wgpu/src/tune/base.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::{collections::HashMap, fmt::Display, hash::Hash, sync::Arc, time::Duration}; - -use burn_common::stub::RwLock; - -use crate::{ - benchmark::{Benchmark, BenchmarkResult}, - context::Context, - GraphicsApi, WgpuDevice, -}; - -/// Key used for caching. -#[derive(Hash, Clone, Debug, PartialEq, Eq)] -pub struct AutoTuneKey { - /// List all shapes used for the autotuned kernel. - shapes: Vec>, - /// Operation name. - ops: String, - /// Device name used to benchmark. - device: String, - /// Graphics api name. - graphics_api: String, -} - -impl AutoTuneKey { - pub fn new(shapes: Vec>, ops: String, context: &Context) -> Self { - let device = format!("{:?}", context.info.name); - let graphics_api = format!("{:?}", context.info.backend); - - Self { - shapes, - ops, - device, - graphics_api, - } - } -} - -impl Display for AutoTuneKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let message = format!( - "(AutoTuneKey) Kernel {} - Shapes {:?} - Device {} - API {}", - self.ops, self.shapes, self.device, self.graphics_api, - ); - f.write_str(&message) - } -} - -/// Objects that are stored in the tuner cache. Can have any inputs and outputs. -pub type AutoTuneValue = Box; - -/// Executable function -pub trait KernelFunction: Send + Sync + 'static { - type Input; - type Output; - - fn call(&self, input: Self::Input) -> Self::Output; - fn description(&self) -> String; -} - -/// Encapsulates kernel functions, with specified inputs and outputs -pub type AutoTuneFunction = Arc>; - -/// The tunable links an executable function to its corresponding benchmark -#[derive(new)] -pub struct Tunable { - func: AutoTuneFunction, - benchmark: Arc>, -} - -impl std::fmt::Display for Tunable -where - G: GraphicsApi, - I: Send + Sync + 'static, - O: Send + Sync + 'static, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.func.description().as_str()) - } -} - -/// Output of the tuner execution. If execution succeeded, the output of -/// the execution is contained. Otherwise, the function must be tuned and -/// the input is given back to preserve ownership. -#[derive(Debug)] -pub enum Execution { - Executed(O), - NoCacheFound(I), -} - -/// The tuner allows to find the best version of a kernel by benchmarking -/// different versions. It keeps the best version found in a cache, so the best -/// function is reused automatically in similar circumstances. -#[derive(Debug)] -pub struct Tuner { - cache: RwLock>, -} - -impl Tuner { - pub fn new() -> Self { - Self { - cache: RwLock::new(HashMap::new()), - } - } - - /// Executes the function stored in the cache at key id, on specified input, - /// and returns its output. If cache has no such id, returns NoCacheFound. - pub fn execute(&self, id: &AutoTuneKey, input: I) -> Execution - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - { - let cache = self.cache.read().unwrap(); - let obj = cache.get(id); - - let obj = match obj { - None => return Execution::NoCacheFound(input), - Some(value) => value, - }; - - let func: &Arc> = obj.downcast_ref().unwrap(); - let output = func.call(input); - - Execution::Executed(output) - } - - /// Finds the best tunable and writes it to the cache. - pub fn tune( - &self, - id: AutoTuneKey, - input: I, - tunables: Vec>, - device: &WgpuDevice, - context: &Context, - ) -> O - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - { - let mut cache = self.cache.write().unwrap(); - - context.start_tuning(); - let results = benchmark(&tunables, device); - let kernel = find_best(&id, tunables, results); - cache.insert(id.clone(), kernel); - drop(cache); - context.stop_tuning(); - - match self.execute(&id, input) { - Execution::Executed(output) => output, - _ => panic!("Should have found a kernel to execute. "), - } - } -} - -/// Finds the best kernel by keeping the one with the smallest median duration. -fn find_best( - id: &AutoTuneKey, - tunables: Vec>, - results: Vec, -) -> AutoTuneValue -where - I: Send + Sync + 'static, - O: Send + Sync + 'static, -{ - let mut best_duration = Duration::MAX; - let mut best_tunable = None; - - for (tunable, result) in tunables.into_iter().zip(results) { - let duration = result.median_duration(); - - if duration < best_duration { - best_duration = duration; - best_tunable = Some(tunable); - } - } - - let tunable = best_tunable.expect("At least one tunable needed. "); - - log::info!("{} => {}", id, tunable); - Box::new(tunable.func) -} - -/// Run benchmarks. -fn benchmark( - tunables: &[Tunable], - device: &WgpuDevice, -) -> Vec -where - I: Send + Sync + 'static, - O: Send + Sync + 'static, -{ - tunables - .iter() - .map(|tunable| tunable.benchmark.run(device)) - .collect() -} diff --git a/burn-wgpu/src/tune/mod.rs b/burn-wgpu/src/tune/mod.rs deleted file mode 100644 index cbcb6ac7e..000000000 --- a/burn-wgpu/src/tune/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod base; - -pub use base::*; diff --git a/burn/Cargo.toml b/burn/Cargo.toml index a3478c71a..98357fcf6 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -44,7 +44,6 @@ ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"] ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"] wgpu = ["burn-core/wgpu"] -wgpu-autotune = ["burn-core/wgpu-autotune"] tch = ["burn-core/tch"] diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index 954da5e86..2462df7d2 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -2,8 +2,10 @@ use crate::FloatTensor; use super::Backend; use burn::backend::wgpu::{ - context::WorkGroup, - kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel}, + compute::{DynamicKernel, WorkGroup}, + kernel::{ + build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource, + }, kernel_wgsl, tensor::WgpuTensor, FloatElement, GraphicsApi, IntElement, WgpuBackend, @@ -24,11 +26,11 @@ struct FusedMatmulAddRelu { } // Implement the dynamic kernel trait for our kernel type. -impl DynamicKernel for FusedMatmulAddRelu { - fn source_template(self) -> SourceTemplate { +impl DynamicKernelSource for FusedMatmulAddRelu { + fn source(self) -> SourceTemplate { // Extend our raw kernel with workgroup size information using the // `SourceTemplate` trait. - FusedMatmulAddReluRaw::source_template() + FusedMatmulAddReluRaw::source() .register("workgroup_size_x", self.workgroup_size_x.to_string()) .register("workgroup_size_y", self.workgroup_size_y.to_string()) .register("elem", E::type_name()) @@ -76,23 +78,18 @@ impl Backend for WgpuBackend()); + .client + .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); + let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer); - // Compile the kernel or use the cache based on the template id. - let kernel = lhs.context.compile_dynamic(FusedMatmulAddRelu::::new( - workgroup_size_x, - workgroup_size_y, - )); + // Create the kernel. + let kernel = FusedMatmulAddRelu::::new(workgroup_size_x, workgroup_size_y); // Build info buffer with tensor information needed by the kernel, such as shapes and strides. let info = build_info(&[&lhs, &rhs, &output]); - let info_buffer = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); + let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); // Declare the wgsl workgroup with the number of blocks in x, y and z. let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; @@ -100,15 +97,14 @@ impl Backend for WgpuBackend