Refactor/burn compute wgpu (#826)

Nathaniel Simard 2023-09-25 10:42:45 -04:00 committed by GitHub
parent 7d706fae98
commit 95e660488e
79 changed files with 1460 additions and 2664 deletions

View File

@ -32,6 +32,10 @@ impl<B: Backend> Backend for ADBackendDecorator<B> {
fn seed(seed: u64) {
B::seed(seed)
}
fn sync(device: &B::Device) {
B::sync(device);
}
}
impl<B: Backend> ADBackend for ADBackendDecorator<B> {

View File

@ -3,7 +3,7 @@ use alloc::vec::Vec;
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone {
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
/// Given a handle, returns owned resource as bytes
fn read(&self, handle: &Handle<Server>) -> Vec<u8>;

View File

@ -12,6 +12,7 @@ use alloc::vec::Vec;
///
/// This is mostly useful for `no-std` environments where threads aren't supported; otherwise prefer
/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
#[derive(Debug)]
pub struct RefCellComputeChannel<Server> {
server: Arc<core::cell::RefCell<Server>>,
}

View File

@ -8,6 +8,7 @@ use crate::server::{ComputeServer, Handle};
/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
/// the compute server spawned on its own thread.
#[derive(Debug)]
pub struct MpscComputeChannel<Server>
where
Server: ComputeServer,
@ -15,6 +16,7 @@ where
state: Arc<MpscComputeChannelState<Server>>,
}
#[derive(Debug)]
struct MpscComputeChannelState<Server>
where
Server: ComputeServer,

View File

@ -6,6 +6,7 @@ use spin::Mutex;
/// The MutexComputeChannel ensures thread-safety by locking the server
/// on every operation
#[derive(Debug)]
pub struct MutexComputeChannel<Server> {
server: Arc<Mutex<Server>>,
}

View File

@ -7,6 +7,7 @@ use core::marker::PhantomData;
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server, Channel> {
channel: Channel,
_server: PhantomData<Server>,

View File

@ -29,7 +29,7 @@ macro_rules! storage_id_type {
/// Create a new memory ID type.
macro_rules! memory_id_type {
($name:ident) => {
#[derive(Clone, Hash, PartialEq, Eq)]
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
/// Memory ID.
pub struct $name {
id: alloc::sync::Arc<alloc::string::String>,

View File

@ -5,7 +5,7 @@ use crate::storage::ComputeStorage;
///
/// It is responsible for determining if the memory segment can be mutated,
/// for instance by keeping track of a reference count
pub trait MemoryHandle: Clone + Send {
pub trait MemoryHandle: Clone + Send + core::fmt::Debug {
/// Checks if the underlying memory can be safely mutated.
fn can_mut(&self) -> bool;
}
@ -15,7 +15,7 @@ pub trait MemoryHandle: Clone + Send {
///
/// The MemoryManagement can only reserve memory space or get the resource located at a space.
/// Modification of the resource data should be done directly on the resource.
pub trait MemoryManagement<Storage: ComputeStorage>: Send {
pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
/// The associated type Handle must implement MemoryHandle
type Handle: MemoryHandle;
@ -24,4 +24,30 @@ pub trait MemoryManagement<Storage: ComputeStorage>: Send {
/// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
fn reserve(&mut self, size: usize) -> Self::Handle;
/// Bypass the memory allocation algorithm to allocate data directly.
///
/// # Notes
///
/// Can be useful for servers that want specific control over memory.
fn alloc(&mut self, size: usize) -> Self::Handle;
/// Bypass the memory allocation algorithm to deallocate data directly.
///
/// # Notes
///
/// Can be useful for servers that want specific control over memory.
fn dealloc(&mut self, handle: &Self::Handle);
/// Fetch the storage used by the memory manager.
///
/// # Notes
///
/// The storage should probably not be used for allocations since the handles won't be
/// compatible with the ones provided by the current trait. Prefer using the
/// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
///
/// This is useful if you need to time the deallocations based on async computation, or to
/// change the mode of storage for different reasons.
fn storage(&mut self) -> &mut Storage;
}
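
As a quick editor's sketch (not part of this commit), the new methods can be driven directly with the SimpleMemoryManagement and BytesStorage types that appear later in this diff; the import paths follow the test code shown below and may need adjusting to the enabled features.

use burn_compute::memory_management::{
    DeallocStrategy, MemoryHandle, MemoryManagement, SimpleMemoryManagement, SliceStrategy,
};
use burn_compute::storage::BytesStorage;

fn manual_allocation_roundtrip() {
    let mut memory_management = SimpleMemoryManagement::new(
        BytesStorage::default(),
        DeallocStrategy::Never,
        SliceStrategy::Never,
    );

    // Bypass the reservation algorithm: the caller controls exactly when this
    // memory lives and dies.
    let handle = memory_management.alloc(1024);
    assert!(handle.can_mut());

    // Direct access to the underlying storage, e.g. to time deallocations manually.
    let _storage = memory_management.storage();

    memory_management.dealloc(&handle);
}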

View File

@ -26,7 +26,7 @@ impl SliceId {
}
/// The SimpleHandle is a memory handle, referring to either a chunk or a slice.
#[derive(Clone)]
#[derive(Debug, Clone)]
pub enum SimpleHandle {
/// A whole chunk of memory.
Chunk(ChunkId),
@ -35,34 +35,72 @@ pub enum SimpleHandle {
}
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
#[derive(Debug)]
pub enum DeallocStrategy {
/// Once every n calls to reserve.
///
/// First associated data is n, second is the state and should start at 0
PeriodTick(usize, usize),
PeriodTick {
/// Number of calls to be executed before triggering the deallocation.
period: usize,
/// Current state. Should start at zero.
state: usize,
},
#[cfg(feature = "std")]
/// Once every period of time
PeriodTime(std::time::Duration, std::time::Instant),
PeriodTime {
/// Amount of time to wait before triggering the deallocation.
period: std::time::Duration,
/// Current state. Should start at the current time.
state: std::time::Instant,
},
/// Never deallocate.
Never,
}
/// The strategy defines when to reuse chunks for slices.
#[derive(Debug)]
pub enum SliceStrategy {
/// Never use slices.
Never,
/// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
Ratio(f32),
/// When the reserved memory is at least the given number of bytes.
MinimumSize(usize),
/// When the reserved memory is at most the given number of bytes.
MaximumSize(usize),
}
impl SliceStrategy {
/// If the chunk can be used with a slice.
pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
if chunk_size < reserved_size {
return false;
}
match self {
SliceStrategy::Never => false,
SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
}
}
}
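
A small editor's illustration of the rule above, mirroring the tests added at the bottom of this file: a chunk smaller than the request is always rejected, and Ratio compares the requested size against the chunk size.

use burn_compute::memory_management::SliceStrategy;

fn slice_strategy_examples() {
    let strategy = SliceStrategy::Ratio(0.8);
    assert!(strategy.can_use_chunk(1024, 900)); // 900 / 1024 ≈ 0.88, above the ratio
    assert!(!strategy.can_use_chunk(1024, 512)); // 512 / 1024 = 0.5, below the ratio
    assert!(!strategy.can_use_chunk(256, 512)); // chunk smaller than the request
}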
impl DeallocStrategy {
/// Create a new strategy with the given period.
pub fn new_period_tick(period: usize) -> Self {
DeallocStrategy::PeriodTick(period, 0)
DeallocStrategy::PeriodTick { period, state: 0 }
}
fn should_dealloc(&mut self) -> bool {
match self {
DeallocStrategy::PeriodTick(period, last) => {
*last = (*last + 1) % *period;
*last == 0
DeallocStrategy::PeriodTick { period, state } => {
*state = (*state + 1) % *period;
*state == 0
}
#[cfg(feature = "std")]
DeallocStrategy::PeriodTime(period, last) => {
if &last.elapsed() > period {
*last = std::time::Instant::now();
DeallocStrategy::PeriodTime { period, state } => {
if &state.elapsed() > period {
*state = std::time::Instant::now();
true
} else {
false
@ -78,9 +116,23 @@ pub struct SimpleMemoryManagement<Storage> {
chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
dealloc_strategy: DeallocStrategy,
slice_strategy: SliceStrategy,
storage: Storage,
}
impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
alloc::format!(
"SimpleMemoryManagement {:?} - {:?}",
self.dealloc_strategy,
core::any::type_name::<Storage>(),
)
.as_str(),
)
}
}
impl MemoryHandle for SimpleHandle {
/// Returns true if referenced by only one tensor, and only once by the
/// memory management hashmaps
@ -126,24 +178,43 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManageme
handle
}
fn alloc(&mut self, size: usize) -> Self::Handle {
self.create_chunk(size)
}
fn dealloc(&mut self, handle: &Self::Handle) {
match handle {
SimpleHandle::Chunk(id) => {
if let Some((handle, _slices)) = self.chunks.remove(id) {
self.storage.dealloc(handle.id);
}
}
SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"),
}
}
fn storage(&mut self) -> &mut Storage {
&mut self.storage
}
}
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
/// Creates a new instance using the given storage and deallocation strategy.
pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self {
/// Creates a new instance using the given storage, deallocation strategy and slice strategy.
pub fn new(
storage: Storage,
dealloc_strategy: DeallocStrategy,
slice_strategy: SliceStrategy,
) -> Self {
Self {
chunks: HashMap::new(),
slices: HashMap::new(),
dealloc_strategy,
slice_strategy,
storage,
}
}
/// Creates a new instance using the given storage without deallocation.
pub fn never_dealloc(storage: Storage) -> Self {
Self::new(storage, DeallocStrategy::Never)
}
fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
// Looks for a large enough, existing but unused chunk of memory.
let chunk = self.find_free_chunk(size);
@ -169,19 +240,31 @@ impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
let mut size_diff_current = usize::MAX;
let mut current = None;
self.chunks
.iter()
.for_each(|(chunk_id, (resource, slices))| {
let is_free = slices.is_empty() && chunk_id.is_free();
for (chunk_id, (resource, slices)) in self.chunks.iter() {
// If chunk is already used, we do not choose it
if !slices.is_empty() || !chunk_id.is_free() {
continue;
}
if is_free && resource.size() > size {
let size_diff = resource.size() - size;
if size_diff < size_diff_current {
current = Some((chunk_id, resource));
size_diff_current = size_diff;
}
let resource_size = resource.size();
// If we find a chunk of exactly the right size, we stop searching altogether
if size == resource_size {
current = Some((chunk_id, resource));
break;
}
// Finds the smallest of the large enough chunks that can accept a slice
// of the given size
if self.slice_strategy.can_use_chunk(resource_size, size) {
let size_diff = resource_size - size;
if size_diff < size_diff_current {
current = Some((chunk_id, resource));
size_diff_current = size_diff;
}
});
}
}
current.map(|(id, handle)| (id.clone(), handle.size()))
}
@ -263,7 +346,7 @@ impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
#[cfg(test)]
mod tests {
use crate::{
memory_management::{MemoryHandle, MemoryManagement},
memory_management::{MemoryHandle, MemoryManagement, SliceStrategy},
storage::BytesStorage,
};
@ -271,7 +354,11 @@ mod tests {
#[test]
fn can_mut_with_single_tensor_reference() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let mut memory_management = SimpleMemoryManagement::new(
BytesStorage::default(),
DeallocStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
@ -284,7 +371,11 @@ mod tests {
#[test]
fn two_tensor_references_remove_mutability() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let mut memory_management = SimpleMemoryManagement::new(
BytesStorage::default(),
DeallocStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
@ -297,7 +388,11 @@ mod tests {
#[test]
fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let mut memory_management = SimpleMemoryManagement::new(
BytesStorage::default(),
DeallocStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let _chunk_handle = memory_management.reserve(chunk_size);
let _new_handle = memory_management.reserve(chunk_size);
@ -307,7 +402,11 @@ mod tests {
#[test]
fn when_empty_chunk_is_cleaned_upexists_it_disappears() {
let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
let mut memory_management = SimpleMemoryManagement::new(
BytesStorage::default(),
DeallocStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let chunk_handle = memory_management.reserve(chunk_size);
drop(chunk_handle);
@ -336,4 +435,28 @@ mod tests {
assert!(period_tick_dealloc.should_dealloc());
}
}
#[test]
fn slice_strategy_minimum_bytes() {
let strategy = SliceStrategy::MinimumSize(100);
assert!(strategy.can_use_chunk(200, 101));
assert!(!strategy.can_use_chunk(200, 99));
}
#[test]
fn slice_strategy_maximum_bytes() {
let strategy = SliceStrategy::MaximumSize(100);
assert!(strategy.can_use_chunk(200, 99));
assert!(!strategy.can_use_chunk(200, 101));
}
#[test]
fn slice_strategy_ratio() {
let strategy = SliceStrategy::Ratio(0.9);
assert!(strategy.can_use_chunk(200, 180));
assert!(!strategy.can_use_chunk(200, 179));
}
}

View File

@ -1,18 +1,39 @@
use crate::{
memory_management::{MemoryHandle, MemoryManagement},
storage::ComputeStorage,
};
use alloc::vec::Vec;
use crate::{memory_management::MemoryManagement, storage::ComputeStorage};
/// Server handle containing the [memory handle](MemoryManagement::Handle).
#[derive(new, Debug)]
pub struct Handle<Server: ComputeServer> {
/// Handle for the memory in use.
pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
}
type _Storage<Server> = <Server as ComputeServer>::Storage;
type _MemoryManagement<Server> = <Server as ComputeServer>::MemoryManagement;
impl<Server: ComputeServer> Handle<Server> {
/// Whether the tensor handle can be mutated by an in-place operation.
pub fn can_mut(&self) -> bool {
self.memory.can_mut()
}
}
/// This is an alias for a [memory handle](MemoryManagement::Handle).
pub type Handle<Server> = <_MemoryManagement<Server> as MemoryManagement<_Storage<Server>>>::Handle;
impl<Server: ComputeServer> Clone for Handle<Server> {
fn clone(&self) -> Self {
Self {
memory: self.memory.clone(),
}
}
}
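
The wrapper above is what lets servers make in-place decisions; a hedged sketch of how a kernel launcher could use it (the helper and its name are hypothetical, not from this PR):

use burn_compute::server::{ComputeServer, Handle};

// Reuse the left-hand operand's memory for the output when it isn't shared,
// otherwise fall back to a freshly allocated handle.
fn output_handle<S: ComputeServer>(
    lhs: &Handle<S>,
    alloc_new: impl FnOnce() -> Handle<S>,
) -> Handle<S> {
    if lhs.can_mut() {
        lhs.clone()
    } else {
        alloc_new()
    }
}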
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send {
pub trait ComputeServer: Send + core::fmt::Debug
where
Self: Sized,
{
/// The kernel type defines the computation algorithms.
type Kernel: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed.

View File

@ -8,6 +8,12 @@ pub struct BytesStorage {
memory: HashMap<StorageId, AllocatedBytes>,
}
impl core::fmt::Debug for BytesStorage {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str("BytesStorage")
}
}
/// Can send to other threads, but can't sync.
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}

View File

@ -1,7 +1,7 @@
use super::DummyServer;
use burn_compute::channel::MutexComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::memory_management::SimpleMemoryManagement;
use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy};
use burn_compute::storage::BytesStorage;
use burn_compute::Compute;
@ -17,7 +17,8 @@ pub fn client(
) -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
COMPUTE.client(device, || {
let storage = BytesStorage::default();
let memory_management = SimpleMemoryManagement::never_dealloc(storage);
let memory_management =
SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never);
let server = DummyServer::new(memory_management);
let channel = MutexComputeChannel::new(server);

View File

@ -2,14 +2,14 @@ use burn_compute::storage::BytesResource;
/// The DummyKernel trait should be implemented for every supported operation
pub trait DummyKernel: Send {
fn compute<'a>(&self, resources: &mut [BytesResource]);
fn compute(&self, resources: &mut [BytesResource]);
}
/// Contains the algorithm for element-wise addition
pub struct DummyElementwiseAddition;
impl DummyKernel for DummyElementwiseAddition {
fn compute<'a>(&self, inputs: &mut [BytesResource]) {
fn compute(&self, inputs: &mut [BytesResource]) {
// Notice how the kernel is responsible for determining which inputs
// are read-only and which are writable.
let lhs = &inputs[0].read();

View File

@ -9,7 +9,7 @@ use super::DummyKernel;
/// The dummy server is used to test the burn-compute infrastructure.
/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
#[derive(new)]
#[derive(new, Debug)]
pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
memory_management: MM,
}
@ -23,7 +23,7 @@ where
type MemoryManagement = MM;
fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
let bytes = self.memory_management.get(handle);
let bytes = self.memory_management.get(&handle.memory);
bytes.read().to_vec()
}
@ -38,17 +38,17 @@ where
bytes[i] = *val;
}
handle
Handle::new(handle)
}
fn empty(&mut self, size: usize) -> Handle<Self> {
self.memory_management.reserve(size)
Handle::new(self.memory_management.reserve(size))
}
fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]) {
let mut resources = handles
.iter()
.map(|handle| self.memory_management.get(handle))
.map(|handle| self.memory_management.get(&handle.memory))
.collect::<Vec<_>>();
kernel.compute(&mut resources);

View File

@ -23,6 +23,7 @@ std = [
"serde_json/std",
"bincode/std",
"half/std",
"derive-new/std",
]
dataset = ["burn-dataset/default"]
@ -42,7 +43,6 @@ ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openb
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
wgpu = ["burn-wgpu"]
wgpu-autotune = ["wgpu", "burn-wgpu/autotune"]
tch = ["burn-tch"]
@ -67,7 +67,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
derive-new = { workspace = true }
derive-new = { workspace = true, default-features = false }
libm = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std

View File

@ -1,6 +1,5 @@
use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
use alloc::vec::Vec;
use core::marker::PhantomData;
use serde::{de::DeserializeOwned, Serialize};
/// Recorder trait specialized to save and load data to and from bytes.
@ -17,7 +16,7 @@ pub trait BytesRecorder:
/// In memory recorder using the [bincode format](bincode).
#[derive(new, Debug, Default, Clone)]
pub struct BinBytesRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
_settings: core::marker::PhantomData<S>,
}
impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}
@ -45,7 +44,7 @@ impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
/// In memory recorder using the [Named MessagePack](rmp_serde).
#[derive(new, Debug, Default, Clone)]
pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
_settings: PhantomData<S>,
_settings: core::marker::PhantomData<S>,
}
#[cfg(feature = "std")]

View File

@ -1,6 +1,5 @@
use super::{PrecisionSettings, Record};
use burn_tensor::{backend::Backend, Bool, DataSerialize, Int, Tensor};
use core::marker::PhantomData;
use serde::{Deserialize, Serialize};
/// This struct implements serde to lazily serialize and deserialize a float tensor
@ -8,7 +7,7 @@ use serde::{Deserialize, Serialize};
#[derive(new, Clone, Debug)]
pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
tensor: Tensor<B, D>,
elem: PhantomData<S>,
elem: core::marker::PhantomData<S>,
}
/// This struct implements serde to lazily serialize and deserialize an int tensor
@ -16,7 +15,7 @@ pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
#[derive(new, Clone, Debug)]
pub struct IntTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
tensor: Tensor<B, D, Int>,
elem: PhantomData<S>,
elem: core::marker::PhantomData<S>,
}
/// This struct implements serde to lazily serialize and deserialize an bool tensor.

View File

@ -91,4 +91,10 @@ impl<E: TchElement> Backend for TchBackend<E> {
fn name() -> String {
"tch".to_string()
}
fn sync(device: &Self::Device) {
if let TchDevice::Cuda(index) = device {
tch::Cuda::synchronize(*index as i64);
}
}
}

View File

@ -94,6 +94,9 @@ pub trait Backend:
/// Seed the backend.
fn seed(seed: u64);
/// Sync the backend, ensuring that all computations are finished.
fn sync(_device: &Self::Device) {}
}
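
sync is a no-op by default and is only overridden by asynchronous backends (tch/CUDA and wgpu in this PR); below is an editor's sketch of the benchmarking pattern it enables, with a made-up helper name.

use burn_tensor::backend::Backend;
use std::time::{Duration, Instant};

// Time a closure, waiting for all queued device work before stopping the clock.
fn timed<B: Backend>(device: &B::Device, run: impl FnOnce()) -> Duration {
    let start = Instant::now();
    run();
    B::sync(device);
    start.elapsed()
}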
pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =

View File

@ -11,10 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
version = "0.10.0"
[features]
default = ["async"]
async = []
# Still experimental
autotune = []
default = []
[dependencies]
burn-common = { path = "../burn-common", version = "0.10.0" }
@ -35,6 +32,10 @@ wgpu = { workspace = true }
serde = { workspace = true }
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
hashbrown = { workspace = true }
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", default-features = false, features = [
"export_tests",
@ -45,10 +46,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features =
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" }
serial_test = "2.0.0"
# Still only in dev mode
hashbrown = { workspace = true }
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
[[bench]]
name = "unary"
harness = false

View File

@ -3,7 +3,7 @@ use burn_wgpu::{
benchmark::Benchmark,
kernel::matmul::{
contiguous, contiguous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
tile, tile_vectorized, tune,
tile, tile_vectorized,
},
run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
};
@ -50,8 +50,8 @@ where
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default).to_device(device);
let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default).to_device(device);
let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, device);
let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, device);
(lhs, rhs)
}
@ -88,22 +88,6 @@ benchmark!(
contiguous_vectorized::matmul_tiling_2d_default
);
struct MatmulAutotune;
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
for MatmulAutotune
{
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
Tensor::from_primitive(tune::<G, f32, D>(
lhs.into_primitive(),
rhs.into_primitive(),
))
}
}
fn main() {
let num_repeats = 3;
let batch_size = 3;
@ -141,10 +125,4 @@ fn main() {
num_repeats,
matmul: PhantomData
});
run_benchmark!(MatmulBenchmark::<MatmulAutotune, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
}

View File

@ -2,6 +2,7 @@ use burn_tensor::backend::Backend;
use rand::{rngs::StdRng, SeedableRng};
use crate::{
compute::compute_client,
element::{FloatElement, IntElement},
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
@ -43,4 +44,9 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> Backend for WgpuB
fn ad_enabled() -> bool {
false
}
fn sync(device: &Self::Device) {
let client = compute_client::<G>(device);
client.sync();
}
}

View File

@ -1,4 +1,4 @@
use crate::{pool::get_context, GraphicsApi, WgpuDevice};
use crate::{compute::compute_client, GraphicsApi, WgpuDevice};
use std::{
fmt::Display,
time::{Duration, Instant},
@ -15,6 +15,7 @@ impl BenchmarkResult {
self.durations.iter().sum::<Duration>() / self.durations.len() as u32
}
#[allow(dead_code)]
pub(crate) fn median_duration(&self) -> Duration {
let mut sorted = self.durations.clone();
sorted.sort();
@ -83,23 +84,23 @@ pub trait Benchmark<G: GraphicsApi> {
fn name(&self) -> String;
/// Run the benchmark a number of times.
fn run(&self, device: &WgpuDevice) -> BenchmarkResult {
let context = get_context::<G>(device);
let client = compute_client::<G>(device);
// Warmup
self.execute(self.prepare(device));
context.sync();
client.sync();
let mut durations = Vec::with_capacity(self.num_samples());
for _ in 0..self.num_samples() {
// Prepare
let args = self.prepare(device);
context.sync();
client.sync();
// Execute the benchmark
let start = Instant::now();
self.execute(args);
context.sync();
client.sync();
let end = Instant::now();
// Register the duration

View File

@ -1,26 +1,28 @@
use super::{Kernel, WgpuServer};
use crate::{
compute::WgpuStorage,
context::{select_device, WorkGroup},
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
GraphicsApi, WgpuDevice,
};
use super::WgpuServer;
use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice};
use alloc::sync::Arc;
use burn_compute::{
channel::MutexComputeChannel,
client::ComputeClient,
memory_management::{DeallocStrategy, SimpleMemoryManagement},
memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
Compute,
};
use std::{marker::PhantomData, sync::Arc};
use wgpu::{DeviceDescriptor, DeviceType};
type WgpuChannel = MutexComputeChannel<WgpuServer>;
type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
type Server = WgpuServer<MemoryManagement>;
type Channel = MutexComputeChannel<Server>;
/// Wgpu [compute client](ComputeClient) to communicate with the [compute server](WgpuServer).
pub type WgpuComputeClient = ComputeClient<Server, Channel>;
/// Wgpu [server handle](burn_compute::server::Handle).
pub type WgpuHandle = burn_compute::server::Handle<Server>;
/// Compute handle for the wgpu backend.
static COMPUTE: Compute<WgpuDevice, WgpuServer, WgpuChannel> = Compute::new();
static COMPUTE: Compute<WgpuDevice, WgpuServer<MemoryManagement>, Channel> = Compute::new();
pub fn compute_client<G: GraphicsApi>(
device: &WgpuDevice,
) -> ComputeClient<WgpuServer, WgpuChannel> {
/// Get the [compute client](ComputeClient) for the given [device](WgpuDevice).
pub fn compute_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
let device = Arc::new(device);
COMPUTE.client(&device, move || {
@ -37,93 +39,159 @@ pub fn compute_client<G: GraphicsApi>(
Ok(value) => value
.parse::<usize>()
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
Err(_) => 16, // 16 tasks by default
Err(_) => 64, // 64 tasks by default
};
let device = Arc::new(device_wgpu);
let storage = WgpuStorage::new(device.clone());
// Maximum reusability.
let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never);
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(1000),
SliceStrategy::Ratio(0.9),
);
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
let channel = WgpuChannel::new(server);
let channel = Channel::new(server);
ComputeClient::new(channel)
})
}
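
For reference, an editor's sketch of typical client usage, adapted from the test shown later in this diff; the external module path and the default device resolving to BestAvailable are assumptions.

use burn_wgpu::compute::compute_client;
use burn_wgpu::{AutoGraphicsApi, WgpuDevice};

fn roundtrip() {
    // BURN_WGPU_MAX_TASKS can still be set in the environment to override the
    // new default of 64 tasks per submission.
    let device = WgpuDevice::default();
    let client = compute_client::<AutoGraphicsApi>(&device);

    let handle = client.create(&[1u8, 2, 3, 4]);
    assert_eq!(client.read(&handle), vec![1u8, 2, 3, 4]);
}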
pub struct DynamicComputeKernel<K> {
kernel: K,
workgroup: WorkGroup,
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
pub async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
let adapter = select_adapter::<G>(device);
let limits = adapter.limits();
let (device, queue) = adapter
.request_device(
&DeviceDescriptor {
label: None,
features: wgpu::Features::empty(),
limits,
},
None,
)
.await
.map_err(|err| {
format!(
"Unable to request the device with the adapter {:?}, err {:?}",
adapter.get_info(),
err
)
})
.unwrap();
(device, queue, adapter.get_info())
}
impl<K> Kernel for DynamicComputeKernel<K>
where
K: DynamicKernel + 'static,
{
fn source_template(self: Box<Self>) -> SourceTemplate {
self.kernel.source_template()
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
let instance = wgpu::Instance::default();
let mut adapters_other = Vec::new();
let mut adapters = Vec::new();
instance
.enumerate_adapters(G::backend().into())
.for_each(|adapter| {
let device_type = adapter.get_info().device_type;
if let DeviceType::Other = device_type {
adapters_other.push(adapter);
return;
}
let is_same_type = match device {
WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu,
WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu,
WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu,
WgpuDevice::Cpu => device_type == DeviceType::Cpu,
WgpuDevice::BestAvailable => true,
};
if is_same_type {
adapters.push(adapter);
}
});
fn select(
num: usize,
error: &str,
mut adapters: Vec<wgpu::Adapter>,
mut adapters_other: Vec<wgpu::Adapter>,
) -> wgpu::Adapter {
if adapters.len() <= num {
if adapters_other.len() <= num {
panic!(
"{}, adapters {:?}, other adapters {:?}",
error,
adapters
.into_iter()
.map(|adapter| adapter.get_info())
.collect::<Vec<_>>(),
adapters_other
.into_iter()
.map(|adapter| adapter.get_info())
.collect::<Vec<_>>(),
);
} else {
return adapters_other.remove(num);
}
}
adapters.remove(num)
}
fn id(&self) -> String {
self.kernel.id()
}
let adapter = match device {
WgpuDevice::DiscreteGpu(num) => select(
*num,
"No Discrete GPU device found",
adapters,
adapters_other,
),
WgpuDevice::IntegratedGpu(num) => select(
*num,
"No Integrated GPU device found",
adapters,
adapters_other,
),
WgpuDevice::VirtualGpu(num) => select(
*num,
"No Virtual GPU device found",
adapters,
adapters_other,
),
WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other),
WgpuDevice::BestAvailable => {
let mut most_performant_adapter = None;
let mut current_score = -1;
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
#[derive(new)]
pub struct StaticComputeKernel<K> {
workgroup: WorkGroup,
_kernel: PhantomData<K>,
}
impl<K> Kernel for StaticComputeKernel<K>
where
K: StaticKernel + 'static,
{
fn source_template(self: Box<Self>) -> SourceTemplate {
K::source_template()
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<K>())
}
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi};
#[test]
fn can_run_kernel() {
binary_elemwise!(Add, "+");
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
let lhs = client.create(bytemuck::cast_slice(&lhs));
let rhs = client.create(bytemuck::cast_slice(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 8);
let info = client.create(bytemuck::cast_slice(&info));
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
let kernel = Box::new(StaticComputeKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
let data = client.read(&out);
let output: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
}
adapters.into_iter().for_each(|adapter| {
let info = adapter.get_info();
let score = match info.device_type {
DeviceType::DiscreteGpu => 5,
DeviceType::Other => 4, // Let's be optimistic with the Other device, it's
// often a Discrete Gpu.
DeviceType::IntegratedGpu => 3,
DeviceType::VirtualGpu => 2,
DeviceType::Cpu => 1,
};
if score > current_score {
most_performant_adapter = Some(adapter);
current_score = score;
}
});
if let Some(adapter) = most_performant_adapter {
adapter
} else {
panic!("No adapter found for graphics API {:?}", G::default());
}
}
};
log::info!("Using adapter {:?}", adapter.get_info());
adapter
}

View File

@ -0,0 +1,106 @@
use super::Kernel;
use crate::kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource};
use core::marker::PhantomData;
/// Provides launch information specifying the number of work groups to be used by a compute shader.
#[derive(new, Clone, Debug)]
pub struct WorkGroup {
/// Work groups for the x axis.
pub x: u32,
/// Work groups for the y axis.
pub y: u32,
/// Work groups for the z axis.
pub z: u32,
}
impl WorkGroup {
/// Calculate the number of invocations of a compute shader.
pub fn num_invocations(&self) -> usize {
(self.x * self.y * self.z) as usize
}
}
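
A hypothetical helper (editor's sketch, assuming WorkGroup above is in scope) showing how a launch size is typically derived for an element-wise kernel; the crate's own helpers may distribute groups differently.

// Number of workgroups needed to cover `num_elems` items along x,
// with `workgroup_size` invocations per group (y and z left at 1).
fn linear_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
    let num_groups = ((num_elems + workgroup_size - 1) / workgroup_size) as u32;
    WorkGroup::new(num_groups, 1, 1)
}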
/// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch
/// information such as [workgroup](WorkGroup).
#[derive(new)]
pub struct DynamicKernel<K> {
kernel: K,
workgroup: WorkGroup,
}
/// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch
/// information such as [workgroup](WorkGroup).
#[derive(new)]
pub struct StaticKernel<K> {
workgroup: WorkGroup,
_kernel: PhantomData<K>,
}
impl<K> Kernel for DynamicKernel<K>
where
K: DynamicKernelSource + 'static,
{
fn source(self: Box<Self>) -> SourceTemplate {
self.kernel.source()
}
fn id(&self) -> String {
self.kernel.id()
}
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
impl<K> Kernel for StaticKernel<K>
where
K: StaticKernelSource + 'static,
{
fn source(self: Box<Self>) -> SourceTemplate {
K::source()
}
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<K>())
}
fn workgroup(&self) -> WorkGroup {
self.workgroup.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi,
WgpuDevice,
};
#[test]
fn can_run_kernel() {
binary_elemwise!(Add, "+");
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
let lhs = client.create(bytemuck::cast_slice(&lhs));
let rhs = client.create(bytemuck::cast_slice(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 8);
let info = client.create(bytemuck::cast_slice(&info));
type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
let kernel = Box::new(StaticKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
client.execute(kernel, &[&lhs, &rhs, &out, &info]);
let data = client.read(&out);
let output: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
}
}

View File

@ -1,7 +1,9 @@
mod base;
mod kernel;
mod server;
mod storage;
pub use base::*;
pub use kernel::*;
pub use server::*;
pub use storage::*;

View File

@ -1,9 +1,8 @@
use std::{borrow::Cow, sync::Arc};
use super::WgpuStorage;
use crate::{context::WorkGroup, kernel::SourceTemplate};
use super::{WgpuStorage, WorkGroup};
use crate::kernel::SourceTemplate;
use alloc::{borrow::Cow, sync::Arc};
use burn_compute::{
memory_management::{MemoryManagement, SimpleMemoryManagement},
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use hashbrown::HashMap;
@ -13,7 +12,8 @@ use wgpu::{
};
/// Wgpu compute server.
pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
#[derive(Debug)]
pub struct WgpuServer<MM: MemoryManagement<WgpuStorage>> {
memory_management: MM,
device: Arc<wgpu::Device>,
queue: wgpu::Queue,
@ -21,20 +21,27 @@ pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
pipelines: HashMap<String, Arc<ComputePipeline>>,
tasks: Vec<ComputeTask>,
max_tasks: usize,
manual_available: HashMap<usize, Vec<server::Handle<Self>>>,
manual_taken: Vec<(usize, server::Handle<Self>)>,
}
#[derive(new)]
#[derive(new, Debug)]
struct ComputeTask {
pipeline: Arc<ComputePipeline>,
bind_group: BindGroup,
work_group: WorkGroup,
}
/// Kernel trait with the [source](SourceTemplate) that will be compiled and cached based on the
/// provided id.
///
/// The kernel will be launched with the given [workgroup](WorkGroup).
pub trait Kernel: 'static + Send {
/// Source template for the kernel.
fn source_template(self: Box<Self>) -> SourceTemplate;
fn source(self: Box<Self>) -> SourceTemplate;
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String;
/// Launch information.
fn workgroup(&self) -> WorkGroup;
}
@ -42,6 +49,7 @@ impl<MM> WgpuServer<MM>
where
MM: MemoryManagement<WgpuStorage>,
{
/// Create a new server.
pub fn new(
memory_management: MM,
device: Arc<wgpu::Device>,
@ -60,6 +68,8 @@ where
pipelines: HashMap::new(),
tasks: Vec::new(),
max_tasks,
manual_available: HashMap::new(),
manual_taken: Vec::new(),
}
}
@ -68,13 +78,54 @@ where
self.tasks.is_empty(),
"Tasks should be completed before submitting the current encoder."
);
println!("Submit");
let mut new_encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
core::mem::swap(&mut new_encoder, &mut self.encoder);
self.queue.submit(Some(new_encoder.finish()));
// Cleanup allocations and deallocations.
self.free_manual_allocations();
self.memory_management.storage().perform_deallocations();
}
fn free_manual_allocations(&mut self) {
let mut manual_taken_tmp = Vec::new();
core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken);
for (size, handle) in manual_taken_tmp.drain(..) {
if handle.can_mut() {
self.register_manual(size, handle);
} else {
self.manual_taken.push((size, handle));
}
}
}
// Finds a free, manually-added handle of specified size, or creates it if none is found
fn manual_reserve(&mut self, size: usize) -> server::Handle<Self> {
let handle = self
.manual_available
.get_mut(&size)
.and_then(|h| h.pop())
.unwrap_or_else(|| {
let memory = self.memory_management.alloc(size);
server::Handle::new(memory)
});
self.manual_taken.push((size, handle.clone()));
handle
}
// Manually adds a handle of given size
fn register_manual(&mut self, size: usize, handle: server::Handle<Self>) {
if let Some(handles) = self.manual_available.get_mut(&size) {
handles.push(handle);
} else {
self.manual_available.insert(size, [handle].into());
}
}
fn register_tasks(&mut self) {
@ -102,7 +153,7 @@ where
return pipeline.clone();
}
let pipeline = self.compile_source(&kernel.source_template().complete());
let pipeline = self.compile_source(&kernel.source().complete());
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
pipeline
@ -138,7 +189,7 @@ where
// Register previous tasks before reading the buffer so that it is up to date.
self.register_tasks();
let resource = self.memory_management.get(handle);
let resource = self.memory_management.get(&handle.memory);
let size = resource.size();
let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
@ -182,8 +233,13 @@ where
}
}
/// When we create a new handle from existing data, we use custom allocations so that we don't
/// have to execute the current pending tasks.
///
/// This is important; otherwise the compute passes would be too small and we wouldn't be able
/// to fully utilize the GPU.
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
let handle = self.empty(data.len());
let handle = self.manual_reserve(data.len());
let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
label: Some("Buffer Src"),
@ -191,9 +247,7 @@ where
usage: wgpu::BufferUsages::COPY_SRC,
}));
let resource = self.memory_management.get(&handle);
self.register_tasks();
let resource = self.memory_management.get(&handle.memory);
self.encoder.copy_buffer_to_buffer(
&buffer_src,
@ -207,7 +261,7 @@ where
}
fn empty(&mut self, size: usize) -> server::Handle<Self> {
self.memory_management.reserve(size)
server::Handle::new(self.memory_management.reserve(size))
}
fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle<Self>]) {
@ -217,7 +271,7 @@ where
let handles = handles
.iter()
.map(|handle| self.memory_management.get(handle))
.map(|handle| self.memory_management.get(&handle.memory))
.collect::<Vec<_>>();
let entries = handles
@ -249,5 +303,7 @@ where
self.register_tasks();
self.submit();
}
self.device.poll(wgpu::Maintain::Wait);
}
}

View File

@ -2,14 +2,25 @@ use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUti
use hashbrown::HashMap;
use std::{num::NonZeroU64, sync::Arc};
/// Buffer storage for wgpu.
pub struct WgpuStorage {
memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
deallocations: Vec<StorageId>,
device: Arc<wgpu::Device>,
}
impl core::fmt::Debug for WgpuStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
}
}
/// The memory resource that can be allocated for wgpu.
#[derive(new, Debug)]
pub struct WgpuResource {
/// The wgpu buffer.
pub buffer: Arc<wgpu::Buffer>,
/// How the resource is used.
pub kind: WgpuResourceKind,
}
@ -44,6 +55,7 @@ impl WgpuResource {
}
}
/// How the resource is used, either as a slice or fully.
#[derive(Debug)]
pub enum WgpuResourceKind {
/// Represents an entire buffer.
@ -54,12 +66,23 @@ pub enum WgpuResourceKind {
/// Keeps actual wgpu buffer references in a hashmap with ids as key.
impl WgpuStorage {
/// Create a new storage on the given [device](wgpu::Device).
pub fn new(device: Arc<wgpu::Device>) -> Self {
Self {
memory: HashMap::new(),
deallocations: Vec::new(),
device,
}
}
/// Actually deallocates buffers tagged to be deallocated.
pub fn perform_deallocations(&mut self) {
for id in self.deallocations.drain(..) {
if let Some(buffer) = self.memory.remove(&id) {
buffer.destroy()
}
}
}
}
impl ComputeStorage for WgpuStorage {
@ -96,7 +119,6 @@ impl ComputeStorage for WgpuStorage {
}
fn dealloc(&mut self, id: StorageId) {
self.memory.get(&id).unwrap().destroy();
let _ = self.memory.remove(&id);
self.deallocations.push(id);
}
}

View File

@ -1,410 +0,0 @@
use super::client::ContextClient;
use crate::{
context::server::ContextServer,
kernel::{DynamicKernel, StaticKernel},
tune::Tuner,
GraphicsApi, WgpuDevice,
};
use burn_common::id::IdGenerator;
use spin::Mutex;
use std::{
any::TypeId,
borrow::Cow,
collections::HashMap,
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
};
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Buffer, ComputePipeline, DeviceDescriptor, DeviceType, ShaderModuleDescriptor,
};
#[cfg(feature = "async")]
pub(crate) type ContextClientImpl = super::client::AsyncContextClient;
#[cfg(not(feature = "async"))]
pub(crate) type ContextClientImpl = super::client::SyncContextClient;
#[cfg(feature = "async")]
pub(crate) type ContextServerImpl = super::server::AsyncContextServer;
#[cfg(not(feature = "async"))]
pub(crate) type ContextServerImpl = super::server::SyncContextServer;
/// The context is the basic struct that allows executing GPU kernels on devices.
///
/// You can access a context for a WGPUDevice using get_context.
#[derive(Debug)]
pub struct Context {
id: String,
device_wgpu: Arc<wgpu::Device>,
cache: Mutex<HashMap<TemplateKey, Arc<ComputePipeline>>>,
is_tuning: AtomicBool,
client: ContextClientImpl,
pub(crate) tuner: Tuner,
tuning_template_ids: Mutex<Vec<TemplateKey>>,
pub(crate) device: WgpuDevice,
pub(crate) info: wgpu::AdapterInfo,
}
#[derive(Debug, Hash, Clone, PartialOrd, PartialEq, Eq)]
enum TemplateKey {
Static(TypeId),
Dynamic(String),
}
/// Provides launch information specifying the number of work groups to be used by a compute shader.
#[derive(new, Clone, Debug)]
pub struct WorkGroup {
/// Work groups for the x axis.
pub x: u32,
/// Work groups for the y axis.
pub y: u32,
/// Work groups for the z axis.
pub z: u32,
}
impl WorkGroup {
/// Calculate the number of invocations of a compute shader.
pub fn num_invocations(&self) -> usize {
(self.x * self.y * self.z) as usize
}
}
impl Context {
/// Create a new context where computing tasks will be executed on the given
/// [device](WgpuDevice).
pub(crate) fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(device));
let device = device.clone();
let device_wgpu = Arc::new(device_wgpu);
let client = ContextServerImpl::start(device_wgpu.clone(), queue);
Self {
id: IdGenerator::generate(),
device_wgpu,
device,
client,
cache: Mutex::new(HashMap::new()),
is_tuning: AtomicBool::new(false),
tuner: Tuner::new(),
tuning_template_ids: Mutex::new(Vec::new()),
info,
}
}
/// Wait for all computation to be executed.
///
/// Useful for benchmarks.
pub fn sync(&self) {
self.client.sync();
}
/// Execute a kernel using the provided buffers.
///
/// # Notes
///
/// This function isn't safe; buffers can be mutated by the GPU. Users must ensure that a
/// buffer may be mutated when launching a compute shader with write access to it.
///
/// Buffer positions are used as bindings when launching a compute kernel.
pub fn execute(
&self,
work_group: WorkGroup,
pipeline: Arc<ComputePipeline>,
buffers: &[&Buffer],
) {
let group_layout = pipeline.get_bind_group_layout(0);
let entries = buffers
.iter()
.enumerate()
.map(|(i, buffer)| wgpu::BindGroupEntry {
binding: i as u32,
resource: buffer.as_entire_binding(),
})
.collect::<Vec<_>>();
let bind_group = self
.device_wgpu
.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &group_layout,
entries: &entries,
});
self.client
.register_compute(bind_group, pipeline, work_group)
}
/// Create a new buffer with the provided size.
pub fn create_buffer(&self, size: usize) -> Arc<Buffer> {
Arc::new(self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: size as u64,
usage: wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
}))
}
/// Create a new buffer initialized with the provided bytes.
pub fn create_buffer_with_data(&self, data: &[u8]) -> Arc<Buffer> {
self.create_buffer_with_data_options(data, false)
}
/// Create a new buffer initialized with the provided bytes with the option to be sync.
///
/// It's important to be sync when you want to reuse the buffer using the Arc strong count for
/// inner mutability.
pub fn create_buffer_with_data_options(&self, data: &[u8], sync: bool) -> Arc<Buffer> {
let buffer_src = Arc::new(self.device_wgpu.create_buffer_init(&BufferInitDescriptor {
label: Some("Buffer Src"),
contents: data,
usage: wgpu::BufferUsages::COPY_SRC,
}));
let buffer_dest = self.create_buffer(buffer_src.size() as usize);
self.client.copy_buffer(buffer_src, buffer_dest, sync)
}
/// Copy buffer to buffer.
///
/// Waiting for the buffer to be registered may be useful if you want to allow inplace operations on the created
/// buffer. Otherwise, the strong count of the buffer might not be 1 when registering a new
/// operation, which makes the buffer readonly.
pub fn copy_buffer(&self, buffer_src: Arc<Buffer>, wait_for_registered: bool) -> Arc<Buffer> {
let buffer_dest = self.create_buffer(buffer_src.size() as usize);
self.client
.copy_buffer(buffer_src, buffer_dest, wait_for_registered)
}
/// Read a buffer from the GPU and return its content as bytes.
pub fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
self.client.read_buffer(buffer)
}
/// Compile a kernel template if not present in the cache.
pub fn compile_static<K: StaticKernel>(&self) -> Arc<ComputePipeline> {
let mut cache = self.cache.lock();
let template_id = TemplateKey::Static(TypeId::of::<K>());
if let Some(module) = cache.get(&template_id) {
return module.clone();
}
let source = K::source_template();
let pipeline = self.compile_source(&source.complete());
if self.is_tuning.load(Ordering::Relaxed) {
let mut templates_vec = self.tuning_template_ids.lock();
templates_vec.push(template_id.clone());
}
cache.insert(template_id, pipeline.clone());
pipeline
}
/// Compile a dynamic template if not present in the cache.
pub fn compile_dynamic<K: DynamicKernel>(&self, kernel: K) -> Arc<ComputePipeline> {
let mut cache = self.cache.lock();
let template_id = TemplateKey::Dynamic(kernel.id());
if let Some(module) = cache.get(&template_id) {
return module.clone();
}
let source = kernel.source_template();
let pipeline = self.compile_source(&source.complete());
if self.is_tuning.load(Ordering::Relaxed) {
let mut templates_vec = self.tuning_template_ids.lock();
templates_vec.push(template_id.clone());
}
cache.insert(template_id, pipeline.clone());
pipeline
}
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
let module = self
.device_wgpu
.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
});
let pipeline = self
.device_wgpu
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
});
Arc::new(pipeline)
}
pub(crate) fn start_tuning(&self) {
self.is_tuning.store(true, Ordering::Relaxed);
}
pub(crate) fn stop_tuning(&self) {
self.is_tuning.store(false, Ordering::Relaxed);
// clean cache of pipelines accumulated during tuning
let mut cache = self.cache.lock();
let mut tuning_template_ids = self.tuning_template_ids.lock();
for template_id in tuning_template_ids.iter() {
cache.remove(template_id);
}
tuning_template_ids.clear();
}
}
impl PartialEq for Context {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}
pub(crate) async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
let adapter = select_adapter::<G>(device);
let limits = adapter.limits();
let (device, queue) = adapter
.request_device(
&DeviceDescriptor {
label: None,
features: wgpu::Features::empty(),
limits,
},
None,
)
.await
.map_err(|err| {
format!(
"Unable to request the device with the adapter {:?}, err {:?}",
adapter.get_info(),
err
)
})
.unwrap();
(device, queue, adapter.get_info())
}
fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
let instance = wgpu::Instance::default();
let mut adapters_other = Vec::new();
let mut adapters = Vec::new();
instance
.enumerate_adapters(G::backend().into())
.for_each(|adapter| {
let device_type = adapter.get_info().device_type;
if let DeviceType::Other = device_type {
adapters_other.push(adapter);
return;
}
let is_same_type = match device {
WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu,
WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu,
WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu,
WgpuDevice::Cpu => device_type == DeviceType::Cpu,
WgpuDevice::BestAvailable => true,
};
if is_same_type {
adapters.push(adapter);
}
});
fn select(
num: usize,
error: &str,
mut adapters: Vec<wgpu::Adapter>,
mut adapters_other: Vec<wgpu::Adapter>,
) -> wgpu::Adapter {
if adapters.len() <= num {
if adapters_other.len() <= num {
panic!(
"{}, adapters {:?}, other adapters {:?}",
error,
adapters
.into_iter()
.map(|adapter| adapter.get_info())
.collect::<Vec<_>>(),
adapters_other
.into_iter()
.map(|adapter| adapter.get_info())
.collect::<Vec<_>>(),
);
} else {
return adapters_other.remove(num);
}
}
adapters.remove(num)
}
let adapter = match device {
WgpuDevice::DiscreteGpu(num) => select(
*num,
"No Discrete GPU device found",
adapters,
adapters_other,
),
WgpuDevice::IntegratedGpu(num) => select(
*num,
"No Integrated GPU device found",
adapters,
adapters_other,
),
WgpuDevice::VirtualGpu(num) => select(
*num,
"No Virtual GPU device found",
adapters,
adapters_other,
),
WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other),
WgpuDevice::BestAvailable => {
let mut most_performant_adapter = None;
let mut current_score = -1;
adapters.into_iter().for_each(|adapter| {
let info = adapter.get_info();
let score = match info.device_type {
DeviceType::DiscreteGpu => 5,
DeviceType::Other => 4, // Let's be optimistic with the Other device, it's
// often a Discrete Gpu.
DeviceType::IntegratedGpu => 3,
DeviceType::VirtualGpu => 2,
DeviceType::Cpu => 1,
};
if score > current_score {
most_performant_adapter = Some(adapter);
current_score = score;
}
});
if let Some(adapter) = most_performant_adapter {
adapter
} else {
panic!("No adapter found for graphics API {:?}", G::default());
}
}
};
log::info!("Using adapter {:?}", adapter.get_info());
adapter
}

View File

@ -1,190 +0,0 @@
use super::WorkGroup;
use std::sync::Arc;
use wgpu::{BindGroup, Buffer, ComputePipeline};
#[cfg(feature = "async")]
pub use async_client::AsyncContextClient;
#[cfg(not(feature = "async"))]
pub use sync_client::SyncContextClient;
/// The context client allows speaking with a server to execute tasks on the GPU.
pub trait ContextClient {
/// Copy the source buffer content into the destination buffer.
///
/// # Notes
///
/// Make sure the source buffer isn't used afterward, since a race condition may happen.
///
/// If the source buffer is still used afterward, use [tensor copy](crate::tensor::WgpuTensor::copy)
/// instead. This method is still useful to load data from the CPU into a new buffer.
fn copy_buffer(
&self,
buffer_src: Arc<Buffer>,
buffer_dest: Arc<Buffer>,
wait_for_registered: bool,
) -> Arc<Buffer>;
/// Read a [buffer](Buffer).
///
/// # Notes
///
/// All pending compute tasks will be executed.
fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8>;
/// Register a new computing task.
fn register_compute(
&self,
bind_group: BindGroup,
pipeline: Arc<ComputePipeline>,
work_group: WorkGroup,
);
/// Wait for all computation to be done.
///
/// Useful for benchmarks.
fn sync(&self);
}
#[cfg(feature = "async")]
mod async_client {
use super::ContextClient;
use crate::context::{
server::{ComputeTask, ContextTask, CopyBufferTask, ReadBufferTask},
WorkGroup,
};
use std::sync::{mpsc, Arc};
use wgpu::{BindGroup, Buffer, ComputePipeline};
/// Client returned by the [async context server](AsyncContextServer).
#[derive(new, Debug)]
pub struct AsyncContextClient {
sender: mpsc::SyncSender<ContextTask>,
_server_handle: std::thread::JoinHandle<()>,
}
impl ContextClient for AsyncContextClient {
fn sync(&self) {
let (sender, receiver) = std::sync::mpsc::channel();
self.sender.send(ContextTask::Sync(sender)).unwrap();
if receiver.iter().next().is_some() {
log::debug!("Sync completed");
} else {
panic!("Unable sync")
}
}
fn copy_buffer(
&self,
buffer_src: Arc<Buffer>,
buffer_dest: Arc<Buffer>,
wait_for_registered: bool,
) -> Arc<Buffer> {
if wait_for_registered {
assert_eq!(Arc::strong_count(&buffer_dest), 1, "You can't wait for the buffer to be registered when multiple references already exist.");
}
self.sender
.send(CopyBufferTask::new(buffer_src, buffer_dest.clone()).into())
.unwrap();
if !wait_for_registered {
return buffer_dest;
}
// Wait for the buffer to be correctly registered so that inplace operations can be
// prioritized.
//
// Note that this is unsafe and a channel could have been used to wait for completion.
// The loop is there for performance reasons.
//
// TODO: Use a performant one time channel here as callback instead.
loop {
std::thread::sleep(std::time::Duration::from_micros(1));
if Arc::strong_count(&buffer_dest) == 1 {
return buffer_dest;
}
}
}
fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
let (sender, receiver) = std::sync::mpsc::channel();
self.sender
.send(ReadBufferTask::new(buffer, sender).into())
.unwrap();
let mut iter = receiver.iter();
if let Some(data) = iter.next() {
data
} else {
panic!("Unable to read buffer")
}
}
fn register_compute(
&self,
bind_group: BindGroup,
pipeline: Arc<ComputePipeline>,
work_group: WorkGroup,
) {
self.sender
.send(ComputeTask::new(bind_group, pipeline, work_group).into())
.unwrap();
}
}
}
#[cfg(not(feature = "async"))]
mod sync_client {
use super::ContextClient;
use crate::context::{
server::{ComputeTask, SyncContextServer},
WorkGroup,
};
use std::sync::Arc;
use wgpu::{BindGroup, Buffer, ComputePipeline};
#[derive(Debug)]
pub struct SyncContextClient {
server: spin::Mutex<SyncContextServer>,
}
impl SyncContextClient {
pub fn new(server: SyncContextServer) -> Self {
Self {
server: spin::Mutex::new(server),
}
}
}
impl ContextClient for SyncContextClient {
fn sync(&self) {
let mut server = self.server.lock();
server.sync();
}
fn copy_buffer(
&self,
buffer_src: Arc<Buffer>,
buffer_dest: Arc<Buffer>,
_wait_for_registered: bool, // Ignored when sync
) -> Arc<Buffer> {
let mut server = self.server.lock();
server.buffer_to_buffer(buffer_src, buffer_dest.clone());
buffer_dest
}
fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
let mut server = self.server.lock();
server.read_buffer(&buffer)
}
fn register_compute(
&self,
bind_group: BindGroup,
pipeline: Arc<ComputePipeline>,
work_group: WorkGroup,
) {
let mut server = self.server.lock();
server.register_compute(ComputeTask::new(bind_group, pipeline, work_group));
}
}
}

View File

@ -1,6 +0,0 @@
pub(super) mod client;
pub(super) mod server;
mod base;
pub use base::*;

View File

@ -1,269 +0,0 @@
use super::{client::ContextClient, WorkGroup};
use std::sync::Arc;
use wgpu::{BindGroup, Buffer, CommandEncoder, ComputePipeline};
#[cfg(feature = "async")]
pub use async_server::{AsyncContextServer, ContextTask, CopyBufferTask, ReadBufferTask};
/// The context server allows running tasks on the GPU.
///
/// # Notes
///
/// There are two implementations of this trait. One is a bit more performant while the other
/// doesn't require std.
///
/// * [Asynchronous server](AsyncContextServer).
/// * [Synchronous server](SyncContextServer).
pub trait ContextServer {
/// The client through which tasks can be sent to the server for execution.
type Client: ContextClient;
/// Starts the server and returns its [client](ContextClient).
fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client;
}
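// A minimal usage sketch, not part of this commit: obtaining a client from whichever
// server the `async` feature selects. `device` and `queue` are assumed to already exist.
#[allow(dead_code)]
fn start_context_sketch(device: std::sync::Arc<wgpu::Device>, queue: wgpu::Queue) {
    #[cfg(feature = "async")]
    let _client = AsyncContextServer::start(device, queue);
    #[cfg(not(feature = "async"))]
    let _client = SyncContextServer::start(device, queue);
}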
/// Context server where each operation is added in a synchronous manner.
#[derive(Debug)]
pub struct SyncContextServer {
device: Arc<wgpu::Device>,
queue: wgpu::Queue,
encoder: CommandEncoder,
tasks: Vec<ComputeTask>,
max_tasks: usize,
}
/// Basic building block to execute computing tasks on the GPU.
#[derive(new, Debug)]
pub struct ComputeTask {
bind_group: BindGroup,
pipeline: Arc<ComputePipeline>,
work_group: WorkGroup,
}
/// Most of the functions are similar to the ones on the [context client](ContextClient).
///
/// The main difference is that these functions take `&mut self` instead of `&self`, which requires
/// either a lock when using the sync client or an async channel when using the async client/server.
impl SyncContextServer {
/// Create a new sync context server.
pub fn new(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self {
let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Command Encoder"),
});
// TODO: Support a way to modify this value without std.
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
Ok(value) => value
.parse::<usize>()
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
Err(_) => 16, // 16 tasks by default
};
Self {
device,
queue,
encoder,
tasks: Vec::new(),
max_tasks,
}
}
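// A small sketch (assumption: this must run before the first server is created, since
// the value is read once in `new`): raising the batching threshold for a process.
#[allow(dead_code)]
fn raise_max_tasks_sketch() {
    std::env::set_var("BURN_WGPU_MAX_TASKS", "64");
}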
pub fn register_compute(&mut self, task: ComputeTask) {
self.tasks.push(task);
if self.tasks.len() > self.max_tasks {
self.register_tasks();
self.submit();
}
}
pub fn read_buffer(&mut self, buffer: &Buffer) -> Vec<u8> {
// Register previous tasks before reading the buffer so that it is up to date.
self.register_tasks();
let size = buffer.size();
let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.encoder
.copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size);
self.submit();
let buffer_slice = buffer_dest.slice(..);
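// Map the staging buffer for reading; the callback fires once the copy has completed,
// and the `device.poll(wgpu::Maintain::Wait)` call below blocks until that happens.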
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender
.send(v)
.expect("Unable to send buffer slice result to async channel.")
});
self.device.poll(wgpu::Maintain::Wait);
let result = pollster::block_on(receiver.receive());
if let Some(Ok(())) = result {
let data = buffer_slice.get_mapped_range();
let result = bytemuck::cast_slice(&data).to_vec();
drop(data);
buffer_dest.unmap();
result
} else {
panic!("Unable to read buffer {:?}", result)
}
}
pub fn sync(&mut self) {
if !self.tasks.is_empty() {
self.register_tasks();
self.submit();
}
self.device.poll(wgpu::Maintain::Wait);
}
pub fn buffer_to_buffer(&mut self, buffer_src: Arc<Buffer>, buffer_dest: Arc<Buffer>) {
self.encoder
.copy_buffer_to_buffer(&buffer_src, 0, &buffer_dest, 0, buffer_src.size());
}
fn register_tasks(&mut self) {
let mut compute = self
.encoder
.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
for task in self.tasks.iter() {
compute.set_pipeline(&task.pipeline);
compute.set_bind_group(0, &task.bind_group, &[]);
compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z);
}
std::mem::drop(compute);
self.tasks.clear();
}
fn submit(&mut self) {
assert!(
self.tasks.is_empty(),
"Tasks should be completed before submitting the current encoder."
);
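// Swap a fresh encoder into `self.encoder` and submit the finished one to the queue.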
let mut new_encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
core::mem::swap(&mut new_encoder, &mut self.encoder);
self.queue.submit(Some(new_encoder.finish()));
}
}
#[cfg(feature = "async")]
mod async_server {
use crate::context::client::AsyncContextClient;
use super::{ComputeTask, ContextServer, SyncContextServer};
use std::sync::{mpsc, Arc};
use wgpu::Buffer;
#[derive(new)]
pub struct ReadBufferTask {
buffer: Arc<Buffer>,
sender: mpsc::Sender<Vec<u8>>,
}
#[derive(new)]
pub struct CopyBufferTask {
pub(crate) buffer_src: Arc<Buffer>,
pub(crate) buffer_dest: Arc<Buffer>,
}
pub enum ContextTask {
Compute(ComputeTask),
ReadBuffer(ReadBufferTask),
CopyBuffer(CopyBufferTask),
Sync(mpsc::Sender<()>),
}
impl From<ComputeTask> for ContextTask {
fn from(val: ComputeTask) -> Self {
ContextTask::Compute(val)
}
}
impl From<ReadBufferTask> for ContextTask {
fn from(val: ReadBufferTask) -> Self {
ContextTask::ReadBuffer(val)
}
}
impl From<CopyBufferTask> for ContextTask {
fn from(val: CopyBufferTask) -> Self {
ContextTask::CopyBuffer(val)
}
}
/// Asynchronous context server where [tasks](ContextTask) are sent using a channel.
///
/// # Notes
///
/// This is useful to avoid blocking the main thread when registering and
/// executing [compute tasks](ComputeTask).
pub struct AsyncContextServer {
server: SyncContextServer,
receiver: mpsc::Receiver<ContextTask>,
}
impl AsyncContextServer {
fn run(mut self) {
loop {
let task = self.receiver.recv().unwrap();
match task {
ContextTask::Compute(task) => self.server.register_compute(task),
ContextTask::CopyBuffer(task) => self
.server
.buffer_to_buffer(task.buffer_src, task.buffer_dest),
ContextTask::ReadBuffer(task) => {
let bytes = self.server.read_buffer(&task.buffer);
task.sender.send(bytes).unwrap();
}
ContextTask::Sync(callback) => {
self.server.sync();
callback.send(()).unwrap();
}
};
}
}
}
impl ContextServer for AsyncContextServer {
type Client = AsyncContextClient;
fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client {
let (sender, receiver) = std::sync::mpsc::sync_channel(50);
let server = SyncContextServer::new(device, queue);
let context = Self { server, receiver };
let handle = std::thread::spawn(|| context.run());
AsyncContextClient::new(sender, handle)
}
}
}
#[cfg(not(feature = "async"))]
mod sync_server {
use super::{ContextServer, SyncContextServer};
use crate::context::client::SyncContextClient;
use std::sync::Arc;
impl ContextServer for SyncContextServer {
type Client = SyncContextClient;
fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client {
let server = Self::new(device, queue);
SyncContextClient::new(server)
}
}
}

View File

@ -1,17 +1,21 @@
use super::SourceTemplate;
use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor};
use crate::{
compute::{StaticKernel, WorkGroup},
element::WgpuElement,
tensor::WgpuTensor,
};
use std::marker::PhantomData;
/// Static wgpu kernel to create a [source template](SourceTemplate).
pub trait StaticKernel: Send + 'static {
pub trait StaticKernelSource: Send + 'static {
/// Source template for the kernel.
fn source_template() -> SourceTemplate;
fn source() -> SourceTemplate;
}
/// Dynamic wgpu kernel to create a [source template](SourceTemplate).
pub trait DynamicKernel: Send {
pub trait DynamicKernelSource: Send {
/// Source template for the kernel.
fn source_template(self) -> SourceTemplate;
fn source(self) -> SourceTemplate;
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String;
}
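// A minimal sketch (the kernel below is hypothetical, not part of this commit) of
// implementing the renamed trait by hand instead of going through `kernel_wgsl!`:
#[allow(dead_code)]
struct IdentityKernelSketch;
impl StaticKernelSource for IdentityKernelSketch {
    fn source() -> SourceTemplate {
        SourceTemplate::new("@compute @workgroup_size(1) fn main() {}")
    }
}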
@ -27,8 +31,8 @@ macro_rules! kernel_wgsl {
#[derive(new)]
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::SourceTemplate::new(include_str!($file))
}
}
@ -48,23 +52,21 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
let num_elems = tensor.shape.num_elements();
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
let handle = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
handle,
);
let info = build_info(&[&tensor, &output]);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
tensor.client.execute(
Box::new(StaticKernel::<
KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP))),
&[&tensor.handle, &output.handle, &info_handle],
);
output
@ -72,7 +74,7 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
K: StaticKernel,
K: StaticKernelSource,
E: WgpuElement,
I: WgpuElement,
const WORKGROUP_X_SIZE: usize,
@ -85,17 +87,17 @@ pub struct KernelSettings<
}
impl<
K: StaticKernel,
K: StaticKernelSource,
E: WgpuElement,
I: WgpuElement,
const WORKGROUP_X_SIZE: usize,
const WORKGROUP_Y_SIZE: usize,
const WORKGROUP_Z_SIZE: usize,
> StaticKernel
> StaticKernelSource
for KernelSettings<K, E, I, WORKGROUP_X_SIZE, WORKGROUP_Y_SIZE, WORKGROUP_Z_SIZE>
{
fn source_template() -> SourceTemplate {
K::source_template()
fn source() -> SourceTemplate {
K::source()
.register("workgroup_size_x", WORKGROUP_X_SIZE.to_string())
.register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string())
.register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string())
@ -110,7 +112,7 @@ impl<
/// Generate kernel source code by replacing some information using templating.
#[derive(new)]
pub struct DynamicKernelSettings<K: StaticKernel, E: WgpuElement, I: WgpuElement> {
pub struct DynamicKernelSettings<K: StaticKernelSource, E: WgpuElement, I: WgpuElement> {
workgroup_x_size: usize,
workgroup_y_size: usize,
workgroup_z_size: usize,
@ -119,11 +121,11 @@ pub struct DynamicKernelSettings<K: StaticKernel, E: WgpuElement, I: WgpuElement
_i: PhantomData<I>,
}
impl<K: StaticKernel, E: WgpuElement, I: WgpuElement> DynamicKernel
impl<K: StaticKernelSource, E: WgpuElement, I: WgpuElement> DynamicKernelSource
for DynamicKernelSettings<K, E, I>
{
fn source_template(self) -> SourceTemplate {
K::source_template()
fn source(self) -> SourceTemplate {
K::source()
.register("workgroup_size_x", self.workgroup_x_size.to_string())
.register("workgroup_size_y", self.workgroup_y_size.to_string())
.register("workgroup_size_z", self.workgroup_z_size.to_string())

View File

@ -1,4 +1,5 @@
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernel};
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource};
use crate::compute::StaticKernel;
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use burn_tensor::Shape;
@ -17,9 +18,9 @@ macro_rules! binary_elemwise {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseRaw::source_template().register(
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseRaw::source().register(
"body",
format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops),
)
@ -37,9 +38,9 @@ macro_rules! binary_elemwise_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseInplaceRaw::source_template().register(
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::BinaryElemwiseInplaceRaw::source().register(
"body",
format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops),
)
@ -49,7 +50,7 @@ macro_rules! binary_elemwise_inplace {
}
/// Execute a binary kernel using the default settings.
pub fn binary_elemwise_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
@ -57,7 +58,12 @@ pub fn binary_elemwise_default<K: StaticKernel, E: WgpuElement, const D: usize>(
}
/// Execute a binary kernel using the provided WORKGROUP.
pub fn binary_elemwise<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
pub fn binary_elemwise<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
@ -76,30 +82,26 @@ pub fn binary_elemwise<K: StaticKernel, E: WgpuElement, const D: usize, const WO
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let buffer = lhs
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let handle = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle);
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
);
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
output
}
/// Execute a binary inplace kernel using the default settings.
pub fn binary_elemwise_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
@ -108,7 +110,7 @@ pub fn binary_elemwise_inplace_default<K: StaticKernel, E: WgpuElement, const D:
/// Execute a binary inplace kernel using the provided WORKGROUP.
pub fn binary_elemwise_inplace<
K: StaticKernel,
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
@ -118,20 +120,15 @@ pub fn binary_elemwise_inplace<
) -> WgpuTensor<E, D> {
lhs.assert_is_on_same_device(&rhs);
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let info = build_info(&[&lhs, &rhs]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
lhs.context.execute(
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
kernel,
&[&lhs.buffer, &rhs.buffer, &info_buffers],
);
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
lhs
}
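// A minimal usage sketch; the macro invocation assumes the `(StructName, "op")` argument
// form used by this crate's elemwise macros, and `Add` is a hypothetical kernel name.
binary_elemwise!(Add, "+");

#[allow(dead_code)]
fn add_tensors_sketch(lhs: WgpuTensor<f32, 2>, rhs: WgpuTensor<f32, 2>) -> WgpuTensor<f32, 2> {
    binary_elemwise_default::<Add, f32, 2>(lhs, rhs)
}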

View File

@ -1,8 +1,10 @@
use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
use super::{KernelSettings, SourceTemplate, StaticKernelSource};
use crate::{
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
tensor::WgpuTensor,
};
use std::{any::TypeId, marker::PhantomData};
use super::{KernelSettings, SourceTemplate, StaticKernel};
kernel_wgsl!(CastRaw, "../template/cast.wgsl");
struct Cast<InputElem: WgpuElement, OutputElem: WgpuElement> {
@ -10,9 +12,11 @@ struct Cast<InputElem: WgpuElement, OutputElem: WgpuElement> {
_o: PhantomData<OutputElem>,
}
impl<InputElem: WgpuElement, OutputElem: WgpuElement> StaticKernel for Cast<InputElem, OutputElem> {
fn source_template() -> SourceTemplate {
CastRaw::source_template()
impl<InputElem: WgpuElement, OutputElem: WgpuElement> StaticKernelSource
for Cast<InputElem, OutputElem>
{
fn source() -> SourceTemplate {
CastRaw::source()
.register("input_elem", InputElem::type_name())
.register("output_elem", OutputElem::type_name())
}
@ -23,32 +27,30 @@ pub fn cast<InputElem: WgpuElement, OutputElem: WgpuElement, const D: usize>(
tensor: WgpuTensor<InputElem, D>,
) -> WgpuTensor<OutputElem, D> {
if TypeId::of::<InputElem>() == TypeId::of::<OutputElem>() {
return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer);
return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}
const WORKGROUP: usize = 32;
let num_elems = tensor.shape.num_elements();
let kernel = tensor.context.compile_static::<KernelSettings<
Cast<InputElem, OutputElem>,
f32,
i32,
WORKGROUP,
WORKGROUP,
1,
>>();
let kernel = StaticKernel::<
KernelSettings<Cast<InputElem, OutputElem>, f32, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP));
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<OutputElem>());
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &output.buffer],
let handle = tensor
.client
.empty(num_elems * core::mem::size_of::<OutputElem>());
let output = WgpuTensor::new(
tensor.client.clone(),
tensor.device,
tensor.shape.clone(),
handle,
);
tensor
.client
.execute(Box::new(kernel), &[&tensor.handle, &output.handle]);
output
}
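// A minimal usage sketch (f32 and i32 both implement WgpuElement elsewhere in this crate):
#[allow(dead_code)]
fn to_i32_sketch<const D: usize>(tensor: WgpuTensor<f32, D>) -> WgpuTensor<i32, D> {
    cast::<f32, i32, D>(tensor)
}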

View File

@ -1,4 +1,5 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
@ -14,16 +15,20 @@ pub fn cat<E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
let first_input = inputs.get(0).unwrap();
let context = &first_input.context;
let client = &first_input.client;
let mut shape_output = first_input.shape.clone();
shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum();
let buffer = first_input
.context
.create_buffer(shape_output.num_elements() * std::mem::size_of::<E>());
.client
.empty(shape_output.num_elements() * std::mem::size_of::<E>());
let output = WgpuTensor::new(context.clone(), shape_output, buffer);
let kernel = context.compile_static::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>();
let output = WgpuTensor::new(
client.clone(),
first_input.device.clone(),
shape_output,
buffer,
);
let mut dim_cat_index = 0;
@ -32,12 +37,14 @@ pub fn cat<E: WgpuElement, const D: usize>(
info.push(dim as u32);
info.push(dim_cat_index as u32);
dim_cat_index += input.shape.dims[dim];
let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));
context.execute(
let info_buffer = client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
kernel.clone(),
&[&input.buffer, &output.buffer, &info_buffer],
);
client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &info_buffer],
);
}

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernel},
kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
@ -21,9 +23,9 @@ macro_rules! comparison {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonRaw::source_template().register(
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonRaw::source().register(
"body",
format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops),
)
@ -41,9 +43,9 @@ macro_rules! comparison_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonInplaceRaw::source()
.register(
"body",
"lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);",
@ -59,7 +61,7 @@ macro_rules! comparison_inplace {
};
}
pub fn comparison<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
@ -79,29 +81,23 @@ pub fn comparison<K: StaticKernel, E: WgpuElement, const D: usize>(
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let buffer = lhs
.context
.create_buffer(num_elems * core::mem::size_of::<u32>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
lhs.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
WgpuTensor::new(output.context, output.shape, output.buffer)
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}
pub fn comparison_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
@ -109,21 +105,16 @@ pub fn comparison_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
lhs.assert_is_on_same_device(&rhs);
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let info = build_info(&[&lhs, &rhs]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
lhs.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
kernel,
&[&lhs.buffer, &rhs.buffer, &info_buffers],
);
let info = build_info(&[&lhs, &rhs]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer)
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
#[cfg(test)]

View File

@ -1,6 +1,7 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings, StaticKernel},
kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource},
kernel_wgsl,
tensor::WgpuTensor,
};
@ -20,9 +21,9 @@ macro_rules! comparison_elem {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemRaw::source()
.register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops))
}
}
@ -38,9 +39,9 @@ macro_rules! comparison_elem_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::ComparisonElemInplaceRaw::source()
.register("body", "lhs[id] = compare(lhs[id], rhs);")
.add_template(format!(
"{}return {{{{ elem }}}}(lhs {} rhs);{}",
@ -53,48 +54,39 @@ macro_rules! comparison_elem_inplace {
};
}
pub fn comparison_elem<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
const WORKGROUP: usize = 32;
let num_elems = lhs.shape.num_elements();
let buffer = lhs
.context
.create_buffer(num_elems * core::mem::size_of::<u32>());
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs]));
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
lhs.context.execute(
let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&lhs.buffer, &rhs_buffer, &buffer],
);
WgpuTensor::new(lhs.context, lhs.shape, buffer)
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle)
}
pub fn comparison_elem_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: E,
) -> WgpuTensor<u32, D> {
const WORKGROUP: usize = 32;
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs]));
lhs.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
kernel,
&[&lhs.buffer, &rhs_buffer],
);
let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer)
WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
#[cfg(test)]

View File

@ -1,17 +1,19 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
Shape,
Element, ElementConversion, Shape,
};
kernel_wgsl!(Conv2d, "../../template/conv/conv2d.wgsl");
pub(crate) fn conv2d<E: WgpuElement>(
pub(crate) fn conv2d<E: WgpuElement + Element>(
input: WgpuTensor<E, 4>,
weight: WgpuTensor<E, 4>,
bias: Option<WgpuTensor<E, 1>>,
@ -40,12 +42,12 @@ pub(crate) fn conv2d<E: WgpuElement>(
);
let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]);
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let output = empty_device(
input.client.clone(),
input.device.clone(),
shape_out.clone(),
);
let mut info = build_info(&[&input, &output, &weight]);
info.push(options.stride[0] as u32);
@ -56,27 +58,24 @@ pub(crate) fn conv2d<E: WgpuElement>(
info.push(options.dilation[1] as u32);
info.push(options.groups as u32);
let bias_buffer = bias
.map(|bias| bias.buffer)
.unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::<E>()));
let bias_handle = bias
.map(|bias| bias.handle)
.unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));
let info_buffer = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = input.client.create(bytemuck::cast_slice(&info));
let kernel = input
.context
.compile_static::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
input.context.execute(
let kernel = StaticKernel::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
);
input.client.execute(
Box::new(kernel),
&[
&input.buffer,
&weight.buffer,
&bias_buffer,
&output.buffer,
&info_buffer,
&input.handle,
&weight.handle,
&bias_handle,
&output.handle,
&info_handle,
],
);

View File

@ -1,14 +1,16 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::{ops::ConvTransposeOptions, Shape};
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");
pub(crate) fn conv_transpose2d<E: WgpuElement>(
pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
input: WgpuTensor<E, 4>,
weight: WgpuTensor<E, 4>,
bias: Option<WgpuTensor<E, 1>>,
@ -35,12 +37,13 @@ pub(crate) fn conv_transpose2d<E: WgpuElement>(
let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let output = empty_device(
input.client.clone(),
input.device.clone(),
shape_out.clone(),
);
let mut info = build_info(&[&input, &output, &weight]);
info.push(options.stride[0] as u32);
info.push(options.stride[1] as u32);
info.push(options.padding[0] as u32);
@ -49,28 +52,24 @@ pub(crate) fn conv_transpose2d<E: WgpuElement>(
info.push(options.dilation[1] as u32);
info.push(options.groups as u32);
let bias_buffer = bias
.map(|bias| bias.buffer)
.unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::<E>()));
let bias_handle = bias
.map(|bias| bias.handle)
.unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));
let info_buffer = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = input.client.create(bytemuck::cast_slice(&info));
let kernel = input
.context
.compile_static::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
let workgroup = elemwise_workgroup(num_elems, WORKGROUP);
input.context.execute(
workgroup,
kernel,
let kernel =
StaticKernel::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
input.client.execute(
Box::new(kernel),
&[
&input.buffer,
&weight.buffer,
&bias_buffer,
&output.buffer,
&info_buffer,
&input.handle,
&weight.handle,
&bias_handle,
&output.handle,
&info_handle,
],
);

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -17,29 +19,23 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
let shape_output = indices.shape.clone();
let num_elems = shape_output.num_elements();
let indices = kernel::into_contiguous(indices);
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
let kernel = StaticKernel::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
);
tensor.client.execute(
Box::new(kernel),
&[
&tensor.buffer,
&indices.buffer,
&output.buffer,
&info_buffer,
&tensor.handle,
&indices.handle,
&output.handle,
&info_handle,
],
);

View File

@ -1,4 +1,5 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
@ -48,18 +49,15 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
let kernel = StaticKernel::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
kernel,
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
);
tensor.client.execute(
Box::new(kernel),
&[&tensor.handle, &indices.handle, &value.handle, &info_handle],
);
tensor

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -20,32 +22,25 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
let mut output_shape = tensor.shape.clone();
output_shape.dims[dim] = indices.shape.dims[0];
let num_elems = output_shape.num_elements();
let buffer = tensor
.context
.create_buffer(num_elems * std::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
let num_elems = output_shape.num_elements();
let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape);
let mut info = build_info(&[&tensor, &output]);
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
let info_handle = output.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
);
tensor.client.execute(
Box::new(kernel),
&[
&tensor.buffer,
&indices.buffer,
&output.buffer,
&info_buffer,
&tensor.handle,
&indices.handle,
&output.handle,
&info_handle,
],
);
@ -89,18 +84,16 @@ pub(crate) fn select_assign<E: WgpuElement, I: WgpuElement, const D: usize>(
info.push(dim as u32);
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>();
let kernel =
StaticKernel::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
);
tensor.context.execute(
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
kernel,
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
tensor.client.execute(
Box::new(kernel),
&[&tensor.handle, &indices.handle, &value.handle, &info_handle],
);
tensor

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
@ -26,10 +28,7 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
let shape_output = Shape::new(dims);
let num_elems = shape_output.num_elements();
let buffer = tensor
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let mut info = build_info(&[&tensor, &output]);
for i in 0..D1 {
@ -37,18 +36,15 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = output.client.create(bytemuck::cast_slice(&info));
let kernel = tensor
.context
.compile_static::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
tensor.context.execute(
let kernel = StaticKernel::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &output.buffer, &info_buffer],
);
tensor.client.execute(
Box::new(kernel),
&[&tensor.handle, &output.handle, &info_handle],
);
output
@ -73,23 +69,15 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
info.push(start as u32);
}
let info_buffer = tensor
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
let kernel = tensor.context.compile_static::<KernelSettings<
IndexAssignInplaceRaw,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>();
let kernel = StaticKernel::<
KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP));
tensor.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&tensor.buffer, &value.buffer, &info_buffer],
tensor.client.execute(
Box::new(kernel),
&[&tensor.handle, &value.handle, &info_handle],
);
tensor

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -16,30 +18,28 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer);
let output = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value]));
let kernel = input
.context
.compile_static::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>();
let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
let info = build_info(&[&input, &mask, &output]);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
let value_handle = output.client.create(E::as_bytes(&[value]));
let kernel = StaticKernel::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
);
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.client.execute(
Box::new(kernel),
&[
&input.buffer,
&value_buffer,
&mask.buffer,
&output.buffer,
&info_buffers,
&input.handle,
&value_handle,
&mask.handle,
&output.handle,
&info_handle,
],
);
@ -54,20 +54,18 @@ pub fn mask_fill_inplace<E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value]));
let kernel = input
.context
.compile_static::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>();
let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
let value_handle = input.client.create(E::as_bytes(&[value]));
let kernel =
StaticKernel::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &mask]);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &value_buffer, &mask.buffer, &info_buffers],
input.client.execute(
Box::new(kernel),
&[&input.handle, &value_handle, &mask.handle, &info_handle],
);
input

View File

@ -1,7 +1,9 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -16,29 +18,27 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
let num_elems = input.shape.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer);
let output = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let kernel = input
.context
.compile_static::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>();
let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
let info = build_info(&[&input, &value, &mask, &output]);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
let kernel = StaticKernel::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
);
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let info = build_info(&[&input, &value, &mask, &output]);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.client.execute(
Box::new(kernel),
&[
&input.buffer,
&value.buffer,
&mask.buffer,
&output.buffer,
&info_buffers,
&input.handle,
&value.handle,
&mask.handle,
&output.handle,
&info_handle,
],
);
@ -53,23 +53,21 @@ pub fn mask_where_inplace<E: WgpuElement, const D: usize>(
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let kernel = input
.context
.compile_static::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>();
let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
let kernel =
StaticKernel::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
);
let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
let mut info = build_info(&[&input, &value, &mask]);
info.push(match reverse {
true => 1,
false => 0,
});
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.context.execute(
elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
kernel,
&[&input.buffer, &value.buffer, &mask.buffer, &info_buffers],
input.client.execute(
Box::new(kernel),
&[&input.handle, &value.handle, &mask.handle, &info_handle],
);
input

View File

@ -2,10 +2,13 @@ use std::marker::PhantomData;
use super::utils::shape_out;
use crate::{
context::WorkGroup,
compute::{DynamicKernel, WorkGroup},
element::WgpuElement,
kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -21,9 +24,9 @@ struct MatmulMemCoalescing<E: WgpuElement> {
_elem: PhantomData<E>,
}
impl<E: WgpuElement> DynamicKernel for MatmulMemCoalescing<E> {
fn source_template(self) -> SourceTemplate {
MatmulMemCoalescingRaw::source_template()
impl<E: WgpuElement> DynamicKernelSource for MatmulMemCoalescing<E> {
fn source(self) -> SourceTemplate {
MatmulMemCoalescingRaw::source()
.register("workgroup_size_x", self.workgroup_size_x.to_string())
.register("workgroup_size_y", self.workgroup_size_y.to_string())
.register("elem", E::type_name())
@ -59,26 +62,11 @@ pub fn matmul_mem_coalescing<E: WgpuElement, const D: usize>(
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
let buffer = lhs
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let kernel = lhs.context.compile_dynamic(MatmulMemCoalescing::<E>::new(
workgroup_size_x,
workgroup_size_y,
));
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output.shape.dims[i];
@ -86,10 +74,18 @@ pub fn matmul_mem_coalescing<E: WgpuElement, const D: usize>(
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
lhs.context.execute(
let kernel = DynamicKernel::new(
MatmulMemCoalescing::<E>::new(workgroup_size_x, workgroup_size_y),
workgroup,
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
output
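// Sketch of the dispatch arithmetic used above (example numbers assumed): a 1000x1000
// output with a 16x16 workgroup needs ceil(1000 / 16) = 63 blocks on each axis, and the
// z dimension carries the product of the remaining batch dimensions.
#[allow(dead_code)]
fn blocks_needed_sketch(dim: usize, workgroup_size: usize) -> u32 {
    f32::ceil(dim as f32 / workgroup_size as f32) as u32
}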

View File

@ -3,9 +3,7 @@ pub(crate) mod utils;
mod mem_coalescing;
mod naive;
mod tiling2d;
mod tune;
pub use mem_coalescing::*;
pub use naive::*;
pub use tiling2d::*;
pub use tune::*;

View File

@ -1,9 +1,10 @@
use super::utils::shape_out;
use crate::{
context::WorkGroup,
compute::{StaticKernel, WorkGroup},
element::WgpuElement,
kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernel},
kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -11,11 +12,11 @@ kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl");
struct MatmulNaive<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
impl<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize> StaticKernel
impl<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize> StaticKernelSource
for MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
{
fn source_template() -> SourceTemplate {
MatmulNaiveRaw::source_template()
fn source() -> SourceTemplate {
MatmulNaiveRaw::source()
.register("block_size_m", WORKGROUP_SIZE_X.to_string())
.register("block_size_n", WORKGROUP_SIZE_Y.to_string())
}
@ -49,41 +50,35 @@ pub fn matmul_naive<
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
let buffer = lhs
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32;
let kernel = lhs.context.compile_static::<KernelSettings<
MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
E,
i32,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
1,
>>();
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output.shape.dims[i];
}
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
lhs.context.execute(
workgroup,
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
let kernel = StaticKernel::<
KernelSettings<
MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
E,
i32,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
1,
>,
>::new(workgroup);
let info = build_info(&[&lhs, &rhs, &output]);
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
output

View File

@ -1,29 +1,16 @@
use super::padding::{crop, pad_round, PaddingOutput};
use crate::{
context::{Context, WorkGroup},
compute::{DynamicKernel, WgpuHandle, WorkGroup},
element::WgpuElement,
kernel::{build_info, into_contiguous, matmul::utils::shape_out},
kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource},
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
use std::{
cmp::{max, min},
sync::Arc,
};
use wgpu::ComputePipeline;
use super::padding::{crop, pad_round, PaddingOutput};
use burn_tensor::{Element, Shape};
use std::cmp::{max, min};
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
pub(super) fn empty_from_context<E: WgpuElement, const D: usize>(
context: Arc<Context>,
shape: &Shape<D>,
) -> WgpuTensor<E, D> {
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context, shape.clone(), buffer)
}
/// Create a source template for tile 2d matmul.
#[macro_export(local_inner_macros)]
macro_rules! matmul_tile_2d {
@ -63,11 +50,11 @@ macro_rules! matmul_tile_2d {
_elem: core::marker::PhantomData<E>,
}
impl<E: WgpuElement> DynamicKernel for $struct<E> {
fn source_template(self) -> SourceTemplate {
impl<E: WgpuElement> DynamicKernelSource for $struct<E> {
fn source(self) -> SourceTemplate {
kernel_wgsl!(Raw, $file);
Raw::source_template()
Raw::source()
.register("b_m", self.b_m.to_string())
.register("b_n", self.b_n.to_string())
.register("b_k", self.b_k.to_string())
@ -90,7 +77,7 @@ macro_rules! matmul_tile_2d {
}
/// Matrix multiplication using tiling 2D algorithm with default parameters
pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
pub fn matmul_tiling_2d_default<E: WgpuElement + burn_tensor::Element, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
@ -125,7 +112,7 @@ macro_rules! matmul_tile_2d {
/// Matrix multiplication using tiling 2D algorithm with custom parameters
#[allow(clippy::too_many_arguments)]
pub fn matmul_tiling_2d<
E: WgpuElement,
E: WgpuElement + burn_tensor::Element,
const D: usize,
>(
lhs: WgpuTensor<E, D>,
@ -140,14 +127,13 @@ macro_rules! matmul_tile_2d {
) -> WgpuTensor<E, D> {
let kernel = $struct::<E>::new(b_m, b_n, b_k, t_m, t_n, workgroup_size_x, workgroup_size_y);
let kernel = lhs.context.compile_dynamic(kernel);
matmul_tiling_2d_launch::<
E,
D,
$struct::<E>,
>(
lhs,
rhs,
kernel,
b_m,
b_n,
b_k,
@ -155,6 +141,7 @@ macro_rules! matmul_tile_2d {
t_n,
workgroup_size_x,
workgroup_size_y,
kernel,
)
}
@ -412,21 +399,23 @@ pub(super) fn make_workgroup<const D: usize>(
WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
}
pub(super) fn make_info_buffers<E: WgpuElement, const D: usize>(
pub(super) fn make_info_handle<E: WgpuElement, const D: usize>(
lhs: &WgpuTensor<E, D>,
rhs: &WgpuTensor<E, D>,
output: &WgpuTensor<E, D>,
) -> Arc<wgpu::Buffer> {
) -> WgpuHandle {
let info = build_info(&[lhs, rhs, output]);
rhs.context
.create_buffer_with_data(bytemuck::cast_slice(&info))
rhs.client.create(bytemuck::cast_slice(&info))
}
#[allow(clippy::too_many_arguments)]
pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
pub(super) fn matmul_tiling_2d_launch<
E: WgpuElement + Element,
const D: usize,
K: DynamicKernelSource + 'static,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
kernel: Arc<ComputePipeline>,
b_m: usize,
b_n: usize,
b_k: usize,
@ -434,6 +423,7 @@ pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
t_n: usize,
workgroup_size_x: usize,
workgroup_size_y: usize,
kernel: K,
) -> WgpuTensor<E, D> {
matmul_parameter_assertions::<E, D>(
b_m,
@ -470,15 +460,18 @@ pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
let rounded_output_shape = shape_out(&lhs, &rhs);
let output = empty_from_context::<E, D>(rhs.context.clone(), &rounded_output_shape);
let output = empty_device(
rhs.client.clone(),
rhs.device.clone(),
rounded_output_shape.clone(),
);
let workgroup = make_workgroup(rounded_output_shape, b_m, b_n);
let info_buffers = make_info_buffers(&lhs, &rhs, &output);
let info_handle = make_info_handle(&lhs, &rhs, &output);
lhs.context.execute(
workgroup,
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
output.client.execute(
Box::new(DynamicKernel::new(kernel, workgroup)),
&[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
);
crop(output, final_output_shape)

View File

@ -1,7 +1,7 @@
use super::base::matmul_tiling_2d_launch;
use crate::{
element::WgpuElement,
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
matmul_tile_2d,
tensor::WgpuTensor,
};

View File

@ -1,7 +1,7 @@
use super::base::matmul_tiling_2d_launch;
use crate::{
element::WgpuElement,
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
matmul_tile_2d,
tensor::WgpuTensor,
};

View File

@ -1,15 +1,14 @@
use std::ops::Range;
use burn_tensor::Shape;
use burn_tensor::{Element, Shape};
use crate::{
element::WgpuElement,
kernel::{slice, slice_assign},
ops::numeric::zeros_device,
tensor::WgpuTensor,
};
use super::base::empty_from_context;
// Output of the pad_round function. Indicates explicitly whether an early return occurred
pub(super) enum PaddingOutput<E: WgpuElement, const D: usize> {
Padded(WgpuTensor<E, D>),
@ -29,7 +28,7 @@ impl<E: WgpuElement, const D: usize> PaddingOutput<E, D> {
/// divisible by some quantity.
/// For instance, a tensor of shape [1000, 1000] with divisors 64 and 64
/// will be padded to [1024, 1024], with the last 24 elements along each dimension being zeros
pub(super) fn pad_round<E: WgpuElement, const D: usize>(
pub(super) fn pad_round<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
row_divisor: usize,
col_divisor: usize,
@ -62,7 +61,7 @@ pub(super) fn pad_round<E: WgpuElement, const D: usize>(
}
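// A small sketch of the rounding rule described above, using integer arithmetic:
// with divisor 64, a dimension of 1000 rounds up to 1024.
#[allow(dead_code)]
fn round_up_to_divisor_sketch(dim: usize, divisor: usize) -> usize {
    ((dim + divisor - 1) / divisor) * divisor
}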
/// Pads tensor by adding zeros when padded dim is larger than tensor dim
fn padding<E: WgpuElement, const D: usize>(
fn padding<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
padded_shape: Shape<D>,
) -> WgpuTensor<E, D> {
@ -73,8 +72,9 @@ fn padding<E: WgpuElement, const D: usize>(
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap();
slice_assign::<E, D, D>(
empty_from_context(tensor.context.clone(), &padded_shape),
zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape),
ranges,
tensor,
)

View File

@ -1,6 +1,6 @@
use crate::{
element::WgpuElement,
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
matmul_tile_2d,
tensor::WgpuTensor,
};

View File

@ -1,7 +1,7 @@
use super::base::matmul_tiling_2d_launch;
use crate::{
element::WgpuElement,
kernel::{DynamicKernel, SourceTemplate, StaticKernel},
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
matmul_tile_2d,
tensor::WgpuTensor,
};

View File

@ -1,396 +0,0 @@
use burn_tensor::{Distribution, Shape, Tensor};
use mem_coalescing::matmul_mem_coalescing;
use crate::{
benchmark::Benchmark,
element::{FloatElement, WgpuElement},
kernel,
tensor::{WgpuTensor, WgpuTensorDyn},
tune::{AutoTuneFunction, AutoTuneKey, Execution, KernelFunction, Tunable},
GraphicsApi, WgpuBackend, WgpuDevice,
};
use std::{marker::PhantomData, sync::Arc};
use super::mem_coalescing;
const TILING_2D_BLOCK_SIZES: [usize; 2] = [64, 128];
const TILING_2D_BLOCK_SIZES_K: [usize; 2] = [16, 32];
const TILING_2D_TILE_SIZES: [usize; 2] = [4, 16];
const MEMORY_COALESCING_WORKGROUP_SIZES: [usize; 3] = [8, 16, 32];
macro_rules! call_dim {
($func:expr, $dim:expr, $( $x:expr ),*) => {
match $dim {
1 => {
let tensor: WgpuTensor<E, 1> = $func($($x,)*);
tensor.into()
},
2 => {
let tensor: WgpuTensor<E, 2> = $func($($x,)*);
tensor.into()
},
3 => {
let tensor: WgpuTensor<E, 3> = $func($($x,)*);
tensor.into()
},
4 => {
let tensor: WgpuTensor<E, 4> = $func($($x,)*);
tensor.into()
},
5 => {
let tensor: WgpuTensor<E, 5> = $func($($x,)*);
tensor.into()
},
6 => {
let tensor: WgpuTensor<E, 6> = $func($($x,)*);
tensor.into()
},
_ => panic!("Tensors of rank 7 and more can't be autotuned."),
}
};
}
macro_rules! tiling2d_tunable {
($name:ident, $func:expr) => {
#[derive(new, Default)]
struct $name<E: WgpuElement> {
b_m: usize,
b_n: usize,
b_k: usize,
t_m: usize,
t_n: usize,
workgroup_size_x: usize,
workgroup_size_y: usize,
_elem: PhantomData<E>,
}
impl<E: WgpuElement> KernelFunction for $name<E> {
type Input = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
type Output = WgpuTensorDyn<E>;
fn call(&self, (lhs, rhs): Self::Input) -> Self::Output {
#[allow(clippy::too_many_arguments)]
fn call_dyn<E: WgpuElement, const D: usize>(
lhs: WgpuTensorDyn<E>,
rhs: WgpuTensorDyn<E>,
b_m: usize,
b_n: usize,
b_k: usize,
t_m: usize,
t_n: usize,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> WgpuTensor<E, D> {
$func(
WgpuTensor::<E, D>::from(lhs),
WgpuTensor::<E, D>::from(rhs),
b_m,
b_n,
b_k,
t_m,
t_n,
workgroup_size_x,
workgroup_size_y,
)
}
return call_dim!(
call_dyn,
lhs.shape.len(),
lhs,
rhs,
self.b_m,
self.b_n,
self.b_k,
self.t_m,
self.t_n,
self.workgroup_size_x,
self.workgroup_size_y
);
}
fn description(&self) -> String {
format!(
"Tiling 2D matmul ({}) - B_M {}, B_N {}, B_K {}, T_M {}, T_N {}, W_X {}, W_X {}",
stringify!($name),
self.b_m,
self.b_n,
self.b_k,
self.t_m,
self.t_n,
self.workgroup_size_x,
self.workgroup_size_y
)
}
}
};
}
tiling2d_tunable!(
Tiling2DContiguousLoad,
kernel::matmul::contiguous::matmul_tiling_2d
);
tiling2d_tunable!(Tiling2DTileLoad, kernel::matmul::tile::matmul_tiling_2d);
tiling2d_tunable!(
Tiling2DContiguousLoadVectorized,
kernel::matmul::contiguous_vectorized::matmul_tiling_2d
);
tiling2d_tunable!(
Tiling2DTileLoadVectorized,
kernel::matmul::tile_vectorized::matmul_tiling_2d
);
#[derive(new)]
struct MemoryCoalescing<E: WgpuElement> {
workgroup_size_x: usize,
workgroup_size_y: usize,
_elem: PhantomData<E>,
}
impl<E: WgpuElement> KernelFunction for MemoryCoalescing<E> {
type Input = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
type Output = WgpuTensorDyn<E>;
fn call(&self, (lhs, rhs): Self::Input) -> Self::Output {
fn call_dyn<E: WgpuElement, const D: usize>(
lhs: WgpuTensorDyn<E>,
rhs: WgpuTensorDyn<E>,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> WgpuTensor<E, D> {
let lhs = WgpuTensor::from(lhs);
let rhs = WgpuTensor::from(rhs);
matmul_mem_coalescing::<E, D>(lhs, rhs, workgroup_size_x, workgroup_size_y)
}
call_dim!(
call_dyn,
lhs.shape.len(),
lhs,
rhs,
self.workgroup_size_x,
self.workgroup_size_y
)
}
fn description(&self) -> String {
format!(
"Memory Coalescing matmul - W_X {}, W_Y {}",
self.workgroup_size_x, self.workgroup_size_y
)
}
}
#[derive(new)]
struct MatmulBenchmark<F: WgpuElement, const D: usize> {
shape_lhs: Shape<D>,
shape_rhs: Shape<D>,
num_repeats: usize,
matmul: PhantomData<F>,
func: AutoTuneFunction<(WgpuTensorDyn<F>, WgpuTensorDyn<F>), WgpuTensorDyn<F>>,
}
impl<E, const D: usize, G> Benchmark<G> for MatmulBenchmark<E, D>
where
E: WgpuElement + FloatElement,
G: GraphicsApi,
{
type Args = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
fn name(&self) -> String {
format!("{:?} x {:?}", self.shape_lhs.dims, self.shape_rhs.dims)
}
fn num_samples(&self) -> usize {
5
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.num_repeats {
self.func.call((lhs.clone(), rhs.clone()));
}
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
let lhs = Tensor::<WgpuBackend<G, E, i32>, D>::random_device(
self.shape_lhs.clone(),
Distribution::Default,
device,
);
let rhs = Tensor::<WgpuBackend<G, E, i32>, D>::random_device(
self.shape_rhs.clone(),
Distribution::Default,
device,
);
(lhs.into_primitive().into(), rhs.into_primitive().into())
}
}
/// Choose the best matmul kernel by using autotuning.
pub fn tune<G: GraphicsApi, E, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D>
where
E: WgpuElement + FloatElement,
{
if D > 6 {
log::debug!("Can't autotune matmul for tensors of rank 7 or more.");
return kernel::matmul::matmul_tiling_2d_default(lhs, rhs);
}
let (shape_lhs, shape_rhs) = calculate_benchmark_shapes(lhs.shape.clone(), rhs.shape.clone());
let id = AutoTuneKey::new(
vec![
shape_lhs.dims[D - 2..].to_vec(),
shape_rhs.dims[D - 2..].to_vec(),
],
format!("matmul {}", E::type_name()),
&lhs.context,
);
let context = lhs.context.clone();
let input: (WgpuTensorDyn<E>, WgpuTensorDyn<E>) = (lhs.into(), rhs.into());
let output: WgpuTensorDyn<E> = match context.tuner.execute(&id, input) {
Execution::Executed(output) => output,
Execution::NoCacheFound((lhs, rhs)) => {
let tunables = matmul_candidates::<G, E, D>(shape_lhs, shape_rhs);
context
.tuner
.tune(id, (lhs, rhs), tunables, &context.device, &context)
}
};
output.into()
}
/// Shape dims are anchored to the closest (on a log scale) power of 2
fn calculate_benchmark_shapes<const D: usize>(
lhs: Shape<D>,
rhs: Shape<D>,
) -> (Shape<D>, Shape<D>) {
let anchor = |a| f32::powf(2., f32::min(f32::round(f32::log(a as f32, 2.)), 12.)) as usize;
let m = anchor(lhs.dims[D - 2]);
let k = anchor(lhs.dims[D - 1]);
let n = anchor(rhs.dims[D - 1]);
let mut lhs_shape = [1; D];
lhs_shape[D - 2] = m;
lhs_shape[D - 1] = k;
let lhs_shape = Shape::new(lhs_shape);
let mut rhs_shape = [1; D];
rhs_shape[D - 2] = k;
rhs_shape[D - 1] = n;
let rhs_shape = Shape::new(rhs_shape);
(lhs_shape, rhs_shape)
}
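A small worked sketch of the anchoring above, using the same formula: log2 of the dimension is rounded to the nearest integer and large dimensions are capped at 2^12 = 4096. The standalone function below is illustrative only.

// Illustrative copy of the anchor formula used by `calculate_benchmark_shapes`.
fn anchor(dim: usize) -> usize {
    f32::powf(2., f32::min(f32::round(f32::log2(dim as f32)), 12.)) as usize
}

fn main() {
    assert_eq!(anchor(1000), 1024); // log2(1000) ≈ 9.97 rounds to 10
    assert_eq!(anchor(100_000), 4096); // exponent capped at 12
}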
type MatmulTunable<G, E> = Tunable<G, (WgpuTensorDyn<E>, WgpuTensorDyn<E>), WgpuTensorDyn<E>>;
/// Enumerates all matmul versions that are candidates for autotuning
fn matmul_candidates<G: GraphicsApi, E, const D: usize>(
shape_lhs: Shape<D>,
shape_rhs: Shape<D>,
) -> Vec<MatmulTunable<G, E>>
where
E: WgpuElement + FloatElement,
{
let matmul_benchmark =
|func: AutoTuneFunction<(WgpuTensorDyn<E>, WgpuTensorDyn<E>), WgpuTensorDyn<E>>| {
Tunable::<G, _, _>::new(
func.clone(),
Arc::new(MatmulBenchmark::new(
shape_lhs.clone(),
shape_rhs.clone(),
5,
func.clone(),
)),
)
};
let mut candidates = Vec::new();
// All combinations of tiling 2d parameters are pushed for a grid search
for block_size in TILING_2D_BLOCK_SIZES {
for block_size_k in TILING_2D_BLOCK_SIZES_K {
for tile_size in TILING_2D_TILE_SIZES {
candidates.push(matmul_benchmark(Arc::new(
Tiling2DContiguousLoad::<E>::new(
block_size,
block_size,
block_size_k,
tile_size,
tile_size,
block_size / tile_size,
block_size / tile_size,
),
)));
candidates.push(matmul_benchmark(Arc::new(
Tiling2DContiguousLoadVectorized::<E>::new(
block_size,
block_size,
block_size_k,
tile_size,
tile_size,
block_size / tile_size,
block_size / tile_size,
),
)));
candidates.push(matmul_benchmark(Arc::new(Tiling2DTileLoad::<E>::new(
block_size,
block_size,
block_size_k,
tile_size,
tile_size,
block_size / tile_size,
block_size / tile_size,
))));
candidates.push(matmul_benchmark(Arc::new(
Tiling2DTileLoadVectorized::<E>::new(
block_size,
block_size,
block_size_k,
tile_size,
tile_size,
block_size / tile_size,
block_size / tile_size,
),
)));
}
}
// All memory coalescing workgroup sizes are pushed for the grid search
for workgroup_size in MEMORY_COALESCING_WORKGROUP_SIZES {
candidates.push(matmul_benchmark(Arc::new(MemoryCoalescing::new(
workgroup_size,
workgroup_size,
))));
}
}
candidates
}
#[cfg(test)]
mod tests {
use super::calculate_benchmark_shapes;
#[test]
pub fn benchmark_shapes_are_anchored_correctly() {
let m = f32::powf(2., 8.49) as usize;
let k = f32::powf(2., 8.51) as usize;
let n = f32::powf(2., 4.) as usize;
let lhs_shape = [m, k].into();
let rhs_shape = [k, n].into();
let (lhs_shape, rhs_shape) = calculate_benchmark_shapes(lhs_shape, rhs_shape);
assert_eq!(lhs_shape.dims, [256, 512]);
assert_eq!(rhs_shape.dims, [512, 16]);
}
}

View File

@ -1,14 +1,12 @@
use std::sync::Arc;
use burn_tensor::Shape;
use wgpu::Buffer;
use crate::{
compute::{StaticKernel, WgpuHandle},
element::WgpuElement,
kernel::{elemwise_workgroup, KernelSettings},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
kernel_wgsl!(
AdaptiveAvgPool2d,
@ -28,23 +26,16 @@ pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
let [batch_size, channels, _, _] = x.shape.dims;
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
let num_elems = output_shape.num_elements();
let output_buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
let kernel = x
.context
.compile_static::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
let kernel =
StaticKernel::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
);
let info_buffer = build_info(&x, &output);
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&x.buffer, &output.buffer, &info_buffer],
);
let info_handle = build_info(&x, &output);
x.client
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
output
}
@ -57,32 +48,29 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
let output_shape = x.shape.clone();
let num_elems = output_shape.num_elements();
let output_buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
x.client.clone(),
x.device.clone(),
output_shape,
output_buffer,
);
let kernel = x.context.compile_static::<KernelSettings<
AdaptiveAvgPool2dBackward,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>();
let kernel = StaticKernel::<
KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
let info_buffer = build_info(&x, &out_grad);
let info_handle = build_info(&x, &out_grad);
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&out_grad.buffer, &output.buffer, &info_buffer],
x.client.execute(
Box::new(kernel),
&[&out_grad.handle, &output.handle, &info_handle],
);
output
}
fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> Arc<Buffer> {
fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> WgpuHandle {
let mut info: [u32; 16] = [0; 16];
info[0] = x.strides[0] as u32;
info[1] = x.strides[1] as u32;
@ -102,7 +90,5 @@ fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -
info[14] = output.shape.dims[2] as u32;
info[15] = output.shape.dims[3] as u32;
output
.context
.create_buffer_with_data(bytemuck::cast_slice(&info))
output.client.create(bytemuck::cast_slice(&info))
}
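For reference, a minimal sketch of the byte cast performed by `build_info` above, assuming the `bytemuck` crate used throughout this diff: the `u32` info array is reinterpreted as a raw byte slice before being uploaded through `client.create`.

// Sketch only: a [u32; N] array viewed as bytes, as done for the info handle.
fn main() {
    let info: [u32; 4] = [1, 2, 3, 4];
    let bytes: &[u8] = bytemuck::cast_slice(&info);
    assert_eq!(bytes.len(), 4 * core::mem::size_of::<u32>());
}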

View File

@ -1,11 +1,13 @@
use crate::{
compute::{Kernel, StaticKernel},
element::WgpuElement,
kernel::{
self, elemwise_workgroup,
pool::{build_output_and_info_pool2d, build_pool2d_info},
KernelSettings, StaticKernel,
KernelSettings, StaticKernelSource,
},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -18,17 +20,15 @@ kernel_wgsl!(
struct AvgPool2dBackward<const COUNT_INCLUDE_PAD: bool>;
struct AvgPool2d<const COUNT_INCLUDE_PAD: bool>;
impl<const COUNT_INCLUDE_PAD: bool> StaticKernel for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
fn source_template() -> kernel::SourceTemplate {
AvgPool2dBackwardRaw::source_template()
.register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
fn source() -> kernel::SourceTemplate {
AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
}
}
impl<const COUNT_INCLUDE_PAD: bool> StaticKernel for AvgPool2d<COUNT_INCLUDE_PAD> {
fn source_template() -> kernel::SourceTemplate {
AvgPool2dRaw::source_template()
.register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2d<COUNT_INCLUDE_PAD> {
fn source() -> kernel::SourceTemplate {
AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
}
}
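A self-contained analogue of the pattern above (toy names, not part of this diff): the `count_include_pad` flag is baked into each monomorphized kernel source at compile time through a const generic, and the runtime `match` only selects which variant to launch.

// Toy illustration of const-generic kernel-source selection.
trait Source {
    fn source() -> String;
}

struct Toy<const COUNT_INCLUDE_PAD: bool>;

impl<const COUNT_INCLUDE_PAD: bool> Source for Toy<COUNT_INCLUDE_PAD> {
    fn source() -> String {
        format!("const count_include_pad = {COUNT_INCLUDE_PAD};")
    }
}

fn pick(count_include_pad: bool) -> String {
    match count_include_pad {
        true => Toy::<true>::source(),
        false => Toy::<false>::source(),
    }
}

fn main() {
    assert!(pick(true).contains("true"));
}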
@ -41,22 +41,21 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_buffer, output) =
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
let kernel = match count_include_pad {
true => x
.context
.compile_static::<KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>>(),
false => x
.context
.compile_static::<KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>>(),
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(workgroup)),
};
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&x.buffer, &output.buffer, &info_buffer],
);
x.client
.execute(kernel, &[&x.handle, &output.handle, &info_handle]);
output
}
@ -72,39 +71,21 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
const WORKGROUP: usize = 32;
let grad = kernel::into_contiguous(grad);
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
let num_elems = x.shape.num_elements();
let buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
let kernel: Box<dyn Kernel> = match count_include_pad {
true => Box::new(StaticKernel::<
KernelSettings<AvgPool2dBackward<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(workgroup)),
false => Box::new(StaticKernel::<
KernelSettings<AvgPool2dBackward<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(workgroup)),
};
let kernel =
match count_include_pad {
true => x.context.compile_static::<KernelSettings<
AvgPool2dBackward<true>,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>(),
false => x.context.compile_static::<KernelSettings<
AvgPool2dBackward<false>,
E,
i32,
WORKGROUP,
WORKGROUP,
1,
>>(),
};
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&grad.buffer, &output.buffer, &info_buffer],
);
x.client
.execute(kernel, &[&grad.handle, &output.handle, &info_handle]);
output
}

View File

@ -1,7 +1,7 @@
use crate::{element::WgpuElement, tensor::WgpuTensor};
use crate::{
compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor,
};
use burn_tensor::Shape;
use std::sync::Arc;
use wgpu::Buffer;
/// Build basic info to launch pool 2d kernels.
pub fn build_output_and_info_pool2d<E: WgpuElement>(
@ -10,7 +10,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (Arc<Buffer>, WgpuTensor<E, 4>) {
) -> (WgpuHandle, WgpuTensor<E, 4>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
@ -24,12 +24,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
/ stride_width)
+ 1;
let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
let num_elems = shape_out.num_elements();
let buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), shape_out, buffer);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
@ -43,7 +38,7 @@ pub fn build_pool2d_info<E: WgpuElement>(
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Arc<Buffer> {
) -> WgpuHandle {
let mut info: [u32; 24] = [0; 24];
info[0] = input.strides[0] as u32;
info[1] = input.strides[1] as u32;
@ -72,9 +67,7 @@ pub fn build_pool2d_info<E: WgpuElement>(
info[22] = dilation[0] as u32;
info[23] = dilation[1] as u32;
let info_buffer = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_buffer = input.client.create(bytemuck::cast_slice(&info));
info_buffer
}

View File

@ -1,4 +1,5 @@
use crate::{
compute::StaticKernel,
element::WgpuElement,
kernel::{
self, elemwise_workgroup,
@ -6,6 +7,7 @@ use crate::{
KernelSettings,
},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};
@ -28,18 +30,15 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
) -> WgpuTensor<E, 4> {
const WORKGROUP: usize = 32;
let (info_buffer, output) =
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let kernel = x
.context
.compile_static::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
x.context.execute(
let kernel = StaticKernel::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&x.buffer, &output.buffer, &info_buffer],
);
x.client
.execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
output
}
@ -52,25 +51,17 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
const WORKGROUP: usize = 32;
let (info_buffer, output) =
let (info_handle, output) =
build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
let num_elems = output.shape.num_elements();
let indices = empty_device(x.client.clone(), x.device, output.shape.clone());
let indices = WgpuTensor::new(
x.context.clone(),
output.shape.clone(),
x.context
.create_buffer(num_elems * std::mem::size_of::<I>()),
);
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
let kernel = x
.context
.compile_static::<KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>>();
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&x.buffer, &output.buffer, &indices.buffer, &info_buffer],
x.client.execute(
Box::new(kernel),
&[&x.handle, &output.handle, &indices.handle, &info_handle],
);
(output, indices)
@ -91,26 +82,18 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
let indices = kernel::into_contiguous(indices);
let num_elems = x.shape.num_elements();
let buffer = x
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
let buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer);
let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
let kernel = x.context.compile_static::<KernelSettings<
MaxPool2dWithIndicesBackward,
E,
I,
WORKGROUP,
WORKGROUP,
1,
>>();
let kernel = StaticKernel::<
KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP, WORKGROUP, 1>,
>::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
x.context.execute(
elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
kernel,
&[&indices.buffer, &grad.buffer, &output.buffer, &info_buffer],
x.client.execute(
Box::new(kernel),
&[&indices.handle, &grad.handle, &output.handle, &info_handle],
);
output
}

View File

@ -1,11 +1,10 @@
use std::sync::Arc;
use crate::{
compute::{WgpuComputeClient, WgpuHandle},
element::WgpuElement,
kernel_wgsl, SEED,
};
use burn_common::rand::get_seeded_rng;
use burn_tensor::Shape;
use rand::Rng;
use wgpu::Buffer;
use crate::{context::Context, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor, SEED};
kernel_wgsl!(Prng, "../../template/prng/prng.wgsl");
@ -23,22 +22,20 @@ pub(crate) fn get_seeds() -> Vec<u32> {
seeds
}
pub(crate) fn make_output_tensor<E: WgpuElement, const D: usize>(
context: Arc<Context>,
shape: Shape<D>,
) -> WgpuTensor<E, D> {
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context, shape, buffer)
}
pub(crate) fn make_info_buffer(context: Arc<Context>, n_values_per_thread: usize) -> Arc<Buffer> {
pub(crate) fn make_info_buffer(
client: WgpuComputeClient,
n_values_per_thread: usize,
) -> WgpuHandle {
let mut info = get_seeds();
info.insert(0, n_values_per_thread as u32);
context.create_buffer_with_data(bytemuck::cast_slice(&info))
client.create(bytemuck::cast_slice(&info))
}
pub(crate) fn make_args_buffer<E: WgpuElement>(context: Arc<Context>, args: &[E]) -> Arc<Buffer> {
context.create_buffer_with_data(E::as_bytes(args))
pub(crate) fn make_args_buffer<E: WgpuElement>(
client: WgpuComputeClient,
args: &[E],
) -> WgpuHandle {
client.create(E::as_bytes(args))
}
#[cfg(test)]

View File

@ -1,12 +1,13 @@
use burn_tensor::Shape;
use crate::{
compute::{compute_client, StaticKernel},
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
},
pool::get_context,
ops::numeric::empty_device,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
@ -15,9 +16,9 @@ use super::base::Prng;
struct BernoulliPrng;
impl StaticKernel for BernoulliPrng {
fn source_template() -> SourceTemplate {
Prng::source_template()
impl StaticKernelSource for BernoulliPrng {
fn source() -> SourceTemplate {
Prng::source()
.register("num_args", "1")
.register(
"prng_loop",
@ -36,15 +37,19 @@ pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128;
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[prob]);
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[prob]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel =
StaticKernel::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
context.execute(
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output
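A small sketch of the launch-size arithmetic implied by `N_VALUES_PER_THREAD` above; the exact `prng_workgroup` computation is not shown in this diff, so the helper below is an assumption: each invocation produces 128 values, so roughly `num_elements / 128` invocations (rounded up) are needed.

// Assumed sketch: invocations needed when each one generates several values.
fn invocations_needed(num_elems: usize, n_values_per_thread: usize) -> usize {
    (num_elems + n_values_per_thread - 1) / n_values_per_thread
}

fn main() {
    assert_eq!(invocations_needed(1_000_000, 128), 7813);
}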

View File

@ -1,12 +1,13 @@
use burn_tensor::Shape;
use crate::{
compute::{compute_client, StaticKernel},
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
},
pool::get_context,
ops::numeric::empty_device,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
@ -15,9 +16,9 @@ use super::base::Prng;
struct NormalPrng;
impl StaticKernel for NormalPrng {
fn source_template() -> SourceTemplate {
Prng::source_template()
impl StaticKernelSource for NormalPrng {
fn source() -> SourceTemplate {
Prng::source()
.register("num_args", "2")
.register(
"prng_loop",
@ -39,15 +40,17 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128; // must be even
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[mean, std]);
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[mean, std]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel =
StaticKernel::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(workgroup);
context.execute(
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output

View File

@ -1,12 +1,13 @@
use burn_tensor::Shape;
use crate::{
compute::{compute_client, StaticKernel},
element::WgpuElement,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
prng::base::{make_args_buffer, make_info_buffer},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
},
pool::get_context,
ops::numeric::empty_device,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
@ -15,9 +16,9 @@ use super::base::Prng;
struct UniformPrng;
impl StaticKernel for UniformPrng {
fn source_template() -> SourceTemplate {
Prng::source_template().register("num_args", "2").register(
impl StaticKernelSource for UniformPrng {
fn source() -> SourceTemplate {
Prng::source().register("num_args", "2").register(
"prng_loop",
include_str!("../../template/prng/uniform_inner_loop.wgsl"),
)
@ -34,15 +35,18 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: usize = 128;
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[low, high]);
let client = compute_client::<G>(device);
let output = empty_device(client.clone(), device.clone(), shape.clone());
let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
let args_handle = make_args_buffer(client.clone(), &[low, high]);
let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
let kernel = StaticKernel::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
context.execute(
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
client.execute(
Box::new(kernel),
&[&output.handle, &info_handle, &args_handle],
);
output

View File

@ -1,5 +1,8 @@
use super::{build_info, KernelSettings, SourceTemplate, StaticKernel};
use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource};
use crate::{
compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
@ -11,15 +14,15 @@ pub struct ArgsMin;
pub struct SumDim;
pub struct MeanDim;
impl StaticKernel for SumDim {
fn source_template() -> SourceTemplate {
ReductionDimRaw::source_template().register("assign", "output[id] = sum;")
impl StaticKernelSource for SumDim {
fn source() -> SourceTemplate {
ReductionDimRaw::source().register("assign", "output[id] = sum;")
}
}
impl StaticKernel for MeanDim {
fn source_template() -> SourceTemplate {
ReductionDimRaw::source_template()
impl StaticKernelSource for MeanDim {
fn source() -> SourceTemplate {
ReductionDimRaw::source()
.add_template(
"fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} {
return sum / {{ elem }}(dim);
@ -29,17 +32,17 @@ impl StaticKernel for MeanDim {
}
}
impl StaticKernel for ArgsMax {
fn source_template() -> SourceTemplate {
ReductionArgsRaw::source_template()
impl StaticKernelSource for ArgsMax {
fn source() -> SourceTemplate {
ReductionArgsRaw::source()
.register("cmp", ">")
.register("initial", (-32767).to_string())
}
}
impl StaticKernel for ArgsMin {
fn source_template() -> SourceTemplate {
ReductionArgsRaw::source_template()
impl StaticKernelSource for ArgsMin {
fn source() -> SourceTemplate {
ReductionArgsRaw::source()
.register("cmp", "<")
.register("initial", 32767.to_string())
}
@ -49,28 +52,29 @@ impl StaticKernel for ArgsMin {
pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
const WORKGROUP: usize = 32;
let mut input_buffer = input.buffer;
let mut input_handle = input.handle;
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);
let kernel = input
.context
.compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
loop {
let num_invocations = workgroup.num_invocations();
let buffer = input
.context
.create_buffer(core::mem::size_of::<E>() * num_invocations);
let handle = input
.client
.empty(core::mem::size_of::<E>() * num_invocations);
let kernel =
StaticKernel::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
workgroup,
);
input
.context
.execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]);
.client
.execute(Box::new(kernel), &[&input_handle, &handle]);
if num_invocations <= 1 {
return WgpuTensor::new(input.context, Shape::new([1]), buffer);
return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle);
}
input_buffer = buffer;
input_handle = handle;
workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
}
}
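A conceptual, host-side analogue of the recursive loop above (not the actual GPU code): each pass folds a chunk of partial sums into one value, and the loop stops once a single value remains.

// Host-side sketch of the recursive reduction driver.
fn reduce_recursively(mut values: Vec<f32>, chunk: usize) -> f32 {
    while values.len() > 1 {
        values = values.chunks(chunk).map(|c| c.iter().sum()).collect();
    }
    values[0]
}

fn main() {
    assert_eq!(reduce_recursively(vec![1.0; 1000], 32 * 32), 1000.0);
}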
@ -91,7 +95,7 @@ pub fn mean_dim<E: WgpuElement, const D: usize>(
reduction_dim::<MeanDim, E, D>(input, dim)
}
fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
@ -100,25 +104,25 @@ fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let handle = input.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(
input.client.clone(),
input.device.clone(),
shape_out,
handle,
);
let kernel = input
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &output.buffer, &info_buffers],
input.client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &info_handle],
);
output
@ -140,7 +144,7 @@ pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
reduction_args_dim::<ArgsMin, E, I, D>(input, dim)
}
fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
@ -149,27 +153,27 @@ fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D:
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<I>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let kernel = input
.context
.compile_static::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>();
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &output.buffer, &info_buffers],
let buffer = input.client.empty(num_elems * core::mem::size_of::<I>());
let output = WgpuTensor::new(
input.client.clone(),
input.device.clone(),
shape_out,
buffer,
);
WgpuTensor::new(output.context, output.shape, output.buffer)
let kernel = StaticKernel::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
input.client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &info_handle],
);
WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}
#[cfg(test)]

View File

@ -1,5 +1,5 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernel};
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
@ -13,9 +13,9 @@ macro_rules! unary {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
let source = $crate::kernel::UnaryRaw::source_template();
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
let source = $crate::kernel::UnaryRaw::source();
source.register("body", format!("output[id] = {}(input[id]);", $func))
}
}
@ -26,9 +26,9 @@ macro_rules! unary {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source_template().register("body", $body)
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source().register("body", $body)
}
}
};
@ -39,9 +39,9 @@ macro_rules! unary {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryRaw::source()
.register("body", format!("output[id] = {}(input[id]);", $func))
.add_template(include_str!($file))
}
@ -58,9 +58,9 @@ macro_rules! unary_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source()
.register("body", format!("input[id] = {}(input[id]);", $func))
}
}
@ -71,9 +71,9 @@ macro_rules! unary_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source_template().register("body", $body)
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source().register("body", $body)
}
}
};
@ -84,9 +84,9 @@ macro_rules! unary_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryInplaceRaw::source()
.register("body", format!("input[id] = {}(input[id]);", $func))
.add_template(include_str!($file))
}
@ -95,59 +95,55 @@ macro_rules! unary_inplace {
}
/// Execute a unary kernel using the default settings.
pub fn unary_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary::<K, E, D, 32>(input)
}
/// Execute a unary inplace kernel using the default settings.
pub fn unary_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
unary_inplace::<K, E, D, 32>(input)
}
/// Execute a unary inplace kernel using the provided WORKGROUP.
pub fn unary_inplace<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
pub fn unary_inplace<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let num_elems = input.shape.num_elements();
let kernel = input
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
input.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer],
);
input.client.execute(Box::new(kernel), &[&input.handle]);
input
}
/// Execute a unary kernel using the provided WORKGROUP.
pub fn unary<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
pub fn unary<K: StaticKernelSource, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
input: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let num_elems = input.shape.num_elements();
let buffer = input
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let mut output = WgpuTensor::new(input.context.clone(), input.shape, buffer);
let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer);
// Since we don't handle the stride inside the kernel, the output tensor has the same strides
// as the input tensor. It might not be in the default format.
output.strides = input.strides;
let kernel = input
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
input.context.execute(
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &output.buffer],
);
input
.client
.execute(Box::new(kernel), &[&input.handle, &output.handle]);
output
}
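Roughly what the `unary!` macro above expands to for a `func`-style kernel; the struct name and WGSL function below are illustrative, and the exact macro invocation tokens are elided from this hunk.

// Hypothetical expansion for a kernel applying the WGSL `abs` function.
pub struct Abs;

impl crate::kernel::StaticKernelSource for Abs {
    fn source() -> crate::kernel::SourceTemplate {
        let source = crate::kernel::UnaryRaw::source();
        source.register("body", format!("output[id] = {}(input[id]);", "abs"))
    }
}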

View File

@ -1,5 +1,5 @@
use super::{elemwise_workgroup, KernelSettings, StaticKernel};
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
kernel_wgsl!(
@ -16,9 +16,9 @@ macro_rules! unary_scalar {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = lhs[id] {} rhs;", $ops))
}
}
@ -30,9 +30,9 @@ macro_rules! unary_scalar {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
}
}
@ -45,9 +45,9 @@ macro_rules! unary_scalar {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarRaw::source()
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
.add_template(include_str!($file))
}
@ -64,9 +64,9 @@ macro_rules! unary_scalar_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops))
}
}
@ -78,9 +78,9 @@ macro_rules! unary_scalar_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
}
}
@ -93,9 +93,9 @@ macro_rules! unary_scalar_inplace {
) => {
pub struct $struct;
impl $crate::kernel::StaticKernel for $struct {
fn source_template() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source_template()
impl $crate::kernel::StaticKernelSource for $struct {
fn source() -> $crate::kernel::SourceTemplate {
$crate::kernel::UnaryScalarInplaceRaw::source()
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
.add_template(include_str!($file))
}
@ -104,7 +104,7 @@ macro_rules! unary_scalar_inplace {
}
/// Execute a unary scalar kernel using the default settings.
pub fn unary_scalar_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
@ -112,31 +112,33 @@ pub fn unary_scalar_default<K: StaticKernel, E: WgpuElement, const D: usize>(
}
/// Execute a unary scalar kernel using the provided WORKGROUP.
pub fn unary_scalar<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
pub fn unary_scalar<
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
let num_elems = lhs.shape.num_elements();
let buffer = lhs
.context
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), lhs.shape, buffer);
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
lhs.context.execute(
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer);
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&lhs.buffer, &rhs_buffer, &output.buffer],
);
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
lhs.client.execute(
Box::new(kernel),
&[&lhs.handle, &rhs_handle, &output.handle],
);
output
}
/// Execute a unary scalar inplace kernel using the default settings.
pub fn unary_scalar_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
@ -145,7 +147,7 @@ pub fn unary_scalar_inplace_default<K: StaticKernel, E: WgpuElement, const D: us
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.
pub fn unary_scalar_inplace<
K: StaticKernel,
K: StaticKernelSource,
E: WgpuElement,
const D: usize,
const WORKGROUP: usize,
@ -153,19 +155,14 @@ pub fn unary_scalar_inplace<
lhs: WgpuTensor<E, D>,
scalar: E,
) -> WgpuTensor<E, D> {
let kernel = lhs
.context
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
lhs.context.execute(
{
let num_elems = lhs.shape.num_elements();
elemwise_workgroup(num_elems, WORKGROUP)
},
kernel,
&[&lhs.buffer, &rhs_buffer],
let num_elems = lhs.shape.num_elements();
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
elemwise_workgroup(num_elems, WORKGROUP),
);
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
lhs.client
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
lhs
}

View File

@ -10,20 +10,13 @@ mod ops;
/// Benchmark module
pub mod benchmark;
/// Context module.
pub mod context;
/// Compute related module.
pub mod compute;
/// Kernel module
pub mod kernel;
/// Tensor module.
pub mod tensor;
#[cfg(test)] // Only enabled for dev for now.
/// Compute related module.
pub mod compute;
pub(crate) mod pool;
pub(crate) mod tune;
mod element;
pub use element::{FloatElement, IntElement};

View File

@ -1,5 +1,6 @@
use crate::{
element::WgpuElement, kernel, pool::get_context, tensor::WgpuTensor, GraphicsApi, WgpuDevice,
compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi,
WgpuDevice,
};
use burn_tensor::{backend::Backend, Data, Shape};
@ -18,15 +19,15 @@ pub fn from_data<G: GraphicsApi, E: WgpuElement, const D: usize>(
data: Data<E, D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let buffer = context.create_buffer_with_data_options(E::as_bytes(&data.value), true);
let client = compute_client::<G>(device);
let buffer = client.create(E::as_bytes(&data.value));
WgpuTensor::new(context, data.shape, buffer)
WgpuTensor::new(client, device.clone(), data.shape, buffer)
}
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.context.read_buffer(tensor.buffer);
let bytes = tensor.client.read(&tensor.handle);
let values = E::from_bytes(&bytes);
Data::new(values.to_vec(), tensor.shape)
@ -36,22 +37,22 @@ pub fn to_device<G: GraphicsApi, E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
if &tensor.context.device == device {
if &tensor.device == device {
return tensor;
}
let context = get_context::<G>(device);
tensor.to_context(context)
let client = compute_client::<G>(device);
tensor.to_client(client, device.clone())
}
pub fn empty<G: GraphicsApi, E: WgpuElement, const D: usize>(
shape: Shape<D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
let client = compute_client::<G>(device);
let buffer = client.empty(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context, shape, buffer)
WgpuTensor::new(client, device.clone(), shape, buffer)
}
pub fn swap_dims<E: WgpuElement, const D: usize>(
@ -72,5 +73,5 @@ pub fn reshape<E: WgpuElement, const D1: usize, const D2: usize>(
// TODO: Not force standard layout all the time (improve performance).
let tensor = kernel::into_contiguous(tensor);
WgpuTensor::new(tensor.context, shape, tensor.buffer)
WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle)
}
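A sketch of the host round trip implied by `from_data`/`into_data` above, assuming bytemuck-style casts like those used elsewhere in this diff: element values travel to and from the device as raw bytes.

// Sketch only: values cast to bytes and back, mirroring E::as_bytes/E::from_bytes.
fn main() {
    let values: [f32; 3] = [1.0, 2.0, 3.0];
    let bytes: &[u8] = bytemuck::cast_slice(&values);
    let back: &[f32] = bytemuck::cast_slice(bytes);
    assert_eq!(back, values.as_slice());
}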

View File

@ -47,7 +47,7 @@ where
fn bool_into_int<const D: usize>(tensor: BoolTensor<Self, D>) -> IntTensor<Self, D> {
if std::mem::size_of::<I>() == std::mem::size_of::<u32>() {
return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer);
return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}
let device = Self::bool_device(&tensor);
@ -57,7 +57,7 @@ where
}
fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
tensor.device.clone()
}
fn bool_to_device<const D: usize>(

View File

@ -57,7 +57,7 @@ where
}
fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
tensor.device.clone()
}
fn to_device<const D: usize>(
@ -139,11 +139,7 @@ where
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
#[cfg(feature = "autotune")]
return kernel::matmul::tune::<G, F, D>(lhs, rhs);
#[cfg(not(feature = "autotune"))]
kernel::matmul::contiguous_vectorized::matmul_tiling_2d_default(lhs, rhs)
kernel::matmul::contiguous::matmul_tiling_2d_default(lhs, rhs)
}
fn swap_dims<const D: usize>(

View File

@ -33,7 +33,7 @@ where
}
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
tensor.device.clone()
}
fn int_to_device<const D: usize>(

View File

@ -1,8 +1,8 @@
use crate::compute::{compute_client, WgpuComputeClient};
use crate::kernel::{
binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
};
use crate::pool::get_context;
use crate::{
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor,
unary_scalar, unary_scalar_inplace,
@ -10,26 +10,38 @@ use crate::{
use crate::{GraphicsApi, WgpuDevice};
use burn_tensor::{Element, ElementConversion, Shape};
pub fn zeros<G: GraphicsApi, E: WgpuElement, const D: usize>(
pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
shape: Shape<D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let client = compute_client::<G>(device);
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
zeros_device(client, device.clone(), shape)
}
WgpuTensor::new(context, shape, buffer)
pub fn zeros_device<E: WgpuElement + Element, const D: usize>(
client: WgpuComputeClient,
device: WgpuDevice,
shape: Shape<D>,
) -> WgpuTensor<E, D> {
mul_scalar(empty_device(client, device, shape), 0i32.elem::<E>())
}
pub fn empty_device<E: WgpuElement, const D: usize>(
client: WgpuComputeClient,
device: WgpuDevice,
shape: Shape<D>,
) -> WgpuTensor<E, D> {
let buffer = client.empty(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(client, device, shape, buffer)
}
pub fn ones<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
shape: Shape<D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
add_scalar(WgpuTensor::new(context, shape, buffer), 1i32.elem::<E>())
add_scalar(zeros::<G, E, D>(shape, device), 1i32.elem::<E>())
}
pub fn add<E: WgpuElement, const D: usize>(

View File

@ -1,63 +0,0 @@
use crate::{context::Context, GraphicsApi, WgpuDevice};
use std::{
any::TypeId,
collections::HashMap,
sync::{Arc, Mutex},
};
static POOL_CONTEXT: Mutex<Option<ContextPool>> = Mutex::new(None);
#[derive(Default)]
struct ContextPool {
contexts: HashMap<Key, Arc<Context>>,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Key {
api_id: TypeId,
device: WgpuDevice,
}
impl Key {
fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
Self {
api_id: TypeId::of::<G>(),
device: device.clone(),
}
}
}
/// Get a [context](Context) for the given [device](WGPUDevice).
///
/// # Notes
///
/// If a context already exists for the current [device](WGPUDevice), the same instance will be
/// returned.
pub fn get_context<G: GraphicsApi>(device: &WgpuDevice) -> Arc<Context> {
let mut pool = POOL_CONTEXT.lock().unwrap();
let context = if let Some(pool) = pool.as_mut() {
// Fetch device in pool
match pool.contexts.get(&Key::new::<G>(device)) {
Some(context) => context.clone(),
None => {
// Init new device
let context = Arc::new(Context::new::<G>(device));
pool.contexts.insert(Key::new::<G>(device), context.clone());
context
}
}
} else {
// Initialize pool
let context = Arc::new(Context::new::<G>(device));
let mut new_pool = ContextPool::default();
new_pool
.contexts
.insert(Key::new::<G>(device), context.clone());
*pool = Some(new_pool);
context
};
context
}

View File

@ -1,19 +1,22 @@
use crate::unary;
use crate::{
compute::{WgpuComputeClient, WgpuHandle},
unary, WgpuDevice,
};
use crate::{element::WgpuElement, kernel::unary_default};
use burn_tensor::Shape;
use std::{marker::PhantomData, sync::Arc};
use wgpu::Buffer;
use crate::{context::Context, element::WgpuElement, kernel::unary_default};
use std::marker::PhantomData;
/// The basic tensor primitive struct.
#[derive(Debug, Clone)]
pub struct WgpuTensor<E: WgpuElement, const D: usize> {
/// The context the tensor is bound to.
pub context: Arc<Context>,
/// Compute client for wgpu.
pub client: WgpuComputeClient,
/// The buffer where the data are stored.
pub buffer: Arc<Buffer>,
pub handle: WgpuHandle,
/// The shape of the current tensor.
pub shape: Shape<D>,
/// The device of the current tensor.
pub device: WgpuDevice,
/// The strides of the current tensor.
pub strides: [usize; D],
elem: PhantomData<E>,
@ -21,8 +24,12 @@ pub struct WgpuTensor<E: WgpuElement, const D: usize> {
#[derive(Debug, Clone)]
pub(crate) struct WgpuTensorDyn<E: WgpuElement> {
pub(crate) context: Arc<Context>,
pub(crate) buffer: Arc<Buffer>,
/// Compute client for wgpu.
pub client: WgpuComputeClient,
/// The buffer where the data are stored.
pub handle: WgpuHandle,
/// The device of the current tensor.
pub device: WgpuDevice,
pub(crate) shape: Vec<usize>,
pub(crate) strides: Vec<usize>,
elem: PhantomData<E>,
@ -31,8 +38,9 @@ pub(crate) struct WgpuTensorDyn<E: WgpuElement> {
impl<E: WgpuElement, const D: usize> From<WgpuTensor<E, D>> for WgpuTensorDyn<E> {
fn from(value: WgpuTensor<E, D>) -> Self {
WgpuTensorDyn {
context: value.context,
buffer: value.buffer,
client: value.client,
handle: value.handle,
device: value.device,
shape: value.shape.dims.to_vec(),
strides: value.strides.to_vec(),
elem: PhantomData,
@ -43,8 +51,9 @@ impl<E: WgpuElement, const D: usize> From<WgpuTensor<E, D>> for WgpuTensorDyn<E>
impl<E: WgpuElement, const D: usize> From<WgpuTensorDyn<E>> for WgpuTensor<E, D> {
fn from(value: WgpuTensorDyn<E>) -> Self {
WgpuTensor {
context: value.context,
buffer: value.buffer,
client: value.client,
handle: value.handle,
device: value.device,
shape: Shape::new(value.shape.try_into().expect("Wrong dimension")),
strides: value.strides.try_into().expect("Wrong dimension"),
elem: PhantomData,
@ -54,7 +63,12 @@ impl<E: WgpuElement, const D: usize> From<WgpuTensorDyn<E>> for WgpuTensor<E, D>
impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
/// Create a new tensor.
pub fn new(context: Arc<Context>, shape: Shape<D>, buffer: Arc<Buffer>) -> Self {
pub fn new(
client: WgpuComputeClient,
device: WgpuDevice,
shape: Shape<D>,
handle: WgpuHandle,
) -> Self {
let mut strides = [0; D];
let mut current = 1;
@ -69,29 +83,31 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
});
Self {
context,
buffer,
client,
handle,
shape,
strides,
device,
elem: PhantomData,
}
}
/// Change the context of the current tensor and return the newly transferred tensor.
pub fn to_context(&self, context: Arc<Context>) -> Self {
let data = self.context.read_buffer(self.buffer.clone());
let buffer = context.create_buffer_with_data(&data);
pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self {
let data = self.client.read(&self.handle);
let handle = client.create(&data);
Self {
context,
buffer,
client,
handle,
shape: self.shape.clone(),
strides: self.strides,
device,
elem: PhantomData,
}
}
pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
if Arc::strong_count(&self.buffer) > 1 {
if !self.handle.can_mut() {
return false;
}
@ -120,19 +136,15 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
/// Check if the tensor is safe to mutate.
pub fn can_mut(&self) -> bool {
if Arc::strong_count(&self.buffer) > 1 {
return false;
}
true
self.handle.can_mut()
}
/// Assert that both tensors are on the same device.
pub fn assert_is_on_same_device(&self, other: &Self) {
if self.context.device != other.context.device {
if self.device != other.device {
panic!(
"Both tensors should be on the same device {:?} != {:?}",
self.context.device, other.context.device
self.device, other.device
);
}
}
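A self-contained sketch of the reference-count rule being replaced here: the old check allowed in-place mutation only when the buffer had a single strong reference, which is presumably what `handle.can_mut()` now encapsulates.

// Sketch of the old Arc-based rule (toy code, not part of this diff).
use std::sync::Arc;

fn can_mut<T>(buffer: &Arc<T>) -> bool {
    Arc::strong_count(buffer) == 1
}

fn main() {
    let a = Arc::new(vec![0u8; 4]);
    assert!(can_mut(&a));
    let b = Arc::clone(&a);
    assert!(!can_mut(&a));
    drop(b);
    assert!(can_mut(&a));
}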

View File

@ -1,196 +0,0 @@
use std::{collections::HashMap, fmt::Display, hash::Hash, sync::Arc, time::Duration};
use burn_common::stub::RwLock;
use crate::{
benchmark::{Benchmark, BenchmarkResult},
context::Context,
GraphicsApi, WgpuDevice,
};
/// Key used for caching.
#[derive(Hash, Clone, Debug, PartialEq, Eq)]
pub struct AutoTuneKey {
/// List all shapes used for the autotuned kernel.
shapes: Vec<Vec<usize>>,
/// Operation name.
ops: String,
/// Device name used to benchmark.
device: String,
/// Graphics api name.
graphics_api: String,
}
impl AutoTuneKey {
pub fn new(shapes: Vec<Vec<usize>>, ops: String, context: &Context) -> Self {
let device = format!("{:?}", context.info.name);
let graphics_api = format!("{:?}", context.info.backend);
Self {
shapes,
ops,
device,
graphics_api,
}
}
}
impl Display for AutoTuneKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let message = format!(
"(AutoTuneKey) Kernel {} - Shapes {:?} - Device {} - API {}",
self.ops, self.shapes, self.device, self.graphics_api,
);
f.write_str(&message)
}
}
/// Objects that are stored in the tuner cache. Can have any inputs and outputs.
pub type AutoTuneValue = Box<dyn core::any::Any + Send + Sync>;
/// Executable function
pub trait KernelFunction: Send + Sync + 'static {
type Input;
type Output;
fn call(&self, input: Self::Input) -> Self::Output;
fn description(&self) -> String;
}
/// Encapsulates kernel functions, with specified inputs and outputs
pub type AutoTuneFunction<I, O> = Arc<dyn KernelFunction<Input = I, Output = O>>;
/// The tunable links an executable function to its corresponding benchmark
#[derive(new)]
pub struct Tunable<G, I, O> {
func: AutoTuneFunction<I, O>,
benchmark: Arc<dyn Benchmark<G, Args = I>>,
}
impl<G, I, O> std::fmt::Display for Tunable<G, I, O>
where
G: GraphicsApi,
I: Send + Sync + 'static,
O: Send + Sync + 'static,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.func.description().as_str())
}
}
/// Output of the tuner execution. If execution succeeded, the output of
/// the execution is contained. Otherwise, the function must be tuned and
/// the input is given back to preserve ownership.
#[derive(Debug)]
pub enum Execution<I, O> {
Executed(O),
NoCacheFound(I),
}
/// The tuner allows to find the best version of a kernel by benchmarking
/// different versions. It keeps the best version found in a cache, so the best
/// function is reused automatically in similar circumstances.
#[derive(Debug)]
pub struct Tuner {
cache: RwLock<HashMap<AutoTuneKey, AutoTuneValue>>,
}
impl Tuner {
pub fn new() -> Self {
Self {
cache: RwLock::new(HashMap::new()),
}
}
/// Executes the function stored in the cache at key id, on specified input,
/// and returns its output. If cache has no such id, returns NoCacheFound.
pub fn execute<I, O>(&self, id: &AutoTuneKey, input: I) -> Execution<I, O>
where
I: Send + Sync + 'static,
O: Send + Sync + 'static,
{
let cache = self.cache.read().unwrap();
let obj = cache.get(id);
let obj = match obj {
None => return Execution::NoCacheFound(input),
Some(value) => value,
};
let func: &Arc<dyn KernelFunction<Input = I, Output = O>> = obj.downcast_ref().unwrap();
let output = func.call(input);
Execution::Executed(output)
}
/// Finds the best tunable and writes it to the cache.
pub fn tune<G: GraphicsApi, I, O>(
&self,
id: AutoTuneKey,
input: I,
tunables: Vec<Tunable<G, I, O>>,
device: &WgpuDevice,
context: &Context,
) -> O
where
I: Send + Sync + 'static,
O: Send + Sync + 'static,
{
let mut cache = self.cache.write().unwrap();
context.start_tuning();
let results = benchmark(&tunables, device);
let kernel = find_best(&id, tunables, results);
cache.insert(id.clone(), kernel);
drop(cache);
context.stop_tuning();
match self.execute(&id, input) {
Execution::Executed(output) => output,
_ => panic!("Should have found a kernel to execute. "),
}
}
}
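The cache above stores each winning kernel behind `Box<dyn Any + Send + Sync>`, and `execute` recovers the typed function with `downcast_ref`. A standalone sketch of that type-erasure pattern using only std types; the key string and the closure are made up and are not burn types.

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

fn main() {
    // Type-erased cache, analogous to `HashMap<AutoTuneKey, AutoTuneValue>`.
    let mut cache: HashMap<String, Box<dyn Any + Send + Sync>> = HashMap::new();

    // Store an Arc'd function under a key, erasing its concrete type.
    let fastest: Arc<dyn Fn(u32) -> u32 + Send + Sync> = Arc::new(|x| x * 2);
    cache.insert("matmul/[2,3]x[3,4]".to_string(), Box::new(fastest));

    // Recover it later by downcasting back to the exact stored type.
    let stored = cache.get("matmul/[2,3]x[3,4]").expect("cached entry");
    let func: &Arc<dyn Fn(u32) -> u32 + Send + Sync> =
        stored.downcast_ref().expect("stored type matches");
    assert_eq!((func.as_ref())(21), 42);
}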
/// Finds the best kernel by keeping the one with the smallest median duration.
fn find_best<G: GraphicsApi, I, O>(
id: &AutoTuneKey,
tunables: Vec<Tunable<G, I, O>>,
results: Vec<BenchmarkResult>,
) -> AutoTuneValue
where
I: Send + Sync + 'static,
O: Send + Sync + 'static,
{
let mut best_duration = Duration::MAX;
let mut best_tunable = None;
for (tunable, result) in tunables.into_iter().zip(results) {
let duration = result.median_duration();
if duration < best_duration {
best_duration = duration;
best_tunable = Some(tunable);
}
}
let tunable = best_tunable.expect("At least one tunable needed. ");
log::info!("{} => {}", id, tunable);
Box::new(tunable.func)
}
/// Run benchmarks.
fn benchmark<G: GraphicsApi, I, O>(
tunables: &[Tunable<G, I, O>],
device: &WgpuDevice,
) -> Vec<BenchmarkResult>
where
I: Send + Sync + 'static,
O: Send + Sync + 'static,
{
tunables
.iter()
.map(|tunable| tunable.benchmark.run(device))
.collect()
}

View File

@@ -1,3 +0,0 @@
mod base;
pub use base::*;

View File

@@ -44,7 +44,6 @@ ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"]
ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
wgpu = ["burn-core/wgpu"]
wgpu-autotune = ["burn-core/wgpu-autotune"]
tch = ["burn-core/tch"]

View File

@@ -2,8 +2,10 @@ use crate::FloatTensor;
use super::Backend;
use burn::backend::wgpu::{
context::WorkGroup,
kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel},
compute::{DynamicKernel, WorkGroup},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
},
kernel_wgsl,
tensor::WgpuTensor,
FloatElement, GraphicsApi, IntElement, WgpuBackend,
@@ -24,11 +26,11 @@ struct FusedMatmulAddRelu<E: FloatElement> {
}
// Implement the dynamic kernel trait for our kernel type.
impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
fn source_template(self) -> SourceTemplate {
impl<E: FloatElement> DynamicKernelSource for FusedMatmulAddRelu<E> {
fn source(self) -> SourceTemplate {
// Extend our raw kernel with workgroup size information using the
// `SourceTemplate` trait.
FusedMatmulAddReluRaw::source_template()
FusedMatmulAddReluRaw::source()
.register("workgroup_size_x", self.workgroup_size_x.to_string())
.register("workgroup_size_y", self.workgroup_size_y.to_string())
.register("elem", E::type_name())
@@ -76,23 +78,18 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,
// Create a buffer for the output tensor.
let buffer = lhs
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<F>());
.client
.empty(shape_out.num_elements() * core::mem::size_of::<F>());
// Create the output tensor primitive.
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);
// Compile the kernel or use the cache based on the template id.
let kernel = lhs.context.compile_dynamic(FusedMatmulAddRelu::<F>::new(
workgroup_size_x,
workgroup_size_y,
));
// Create the kernel.
let kernel = FusedMatmulAddRelu::<F>::new(workgroup_size_x, workgroup_size_y);
// Build info buffer with tensor information needed by the kernel, such as shapes and strides.
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffer = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
// Declare the wgsl workgroup with the number of blocks in x, y and z.
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
@@ -100,15 +97,14 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32);
// Execute lazily the kernel with the launch information and the given buffers.
lhs.context.execute(
workgroup,
kernel,
lhs.client.execute(
Box::new(DynamicKernel::new(kernel, workgroup)),
&[
&lhs.buffer,
&rhs.buffer,
&bias.buffer,
&output.buffer,
&info_buffer,
&lhs.handle,
&rhs.handle,
&bias.handle,
&output.handle,
&info_handle,
],
);
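The workgroup count above is plain ceiling division of the output dimensions by the workgroup size. A self-contained sketch of that arithmetic with illustrative sizes (a 1000x1000 output, 16x16 workgroups, 4 batches):

/// Integer ceiling division, matching the `f32::ceil(dim / workgroup_size)` calls above.
fn blocks_needed(dim: usize, workgroup_size: usize) -> u32 {
    ((dim + workgroup_size - 1) / workgroup_size) as u32
}

fn main() {
    let (num_rows, num_cols, num_batches) = (1000usize, 1000usize, 4u32);
    let (workgroup_size_x, workgroup_size_y) = (16, 16);

    let blocks_x = blocks_needed(num_rows, workgroup_size_x); // ceil(1000 / 16) = 63
    let blocks_y = blocks_needed(num_cols, workgroup_size_y); // ceil(1000 / 16) = 63

    // These three numbers are what `WorkGroup::new(x, y, z)` receives above.
    assert_eq!((blocks_x, blocks_y, num_batches), (63, 63, 4));
}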