mirror of https://github.com/tracel-ai/burn.git
Refactor/burn compute wgpu (#826)
This commit is contained in:
parent 7d706fae98
commit 95e660488e
@@ -32,6 +32,10 @@ impl<B: Backend> Backend for ADBackendDecorator<B> {
     fn seed(seed: u64) {
         B::seed(seed)
     }
 
+    fn sync(device: &B::Device) {
+        B::sync(device);
+    }
 }
 
 impl<B: Backend> ADBackend for ADBackendDecorator<B> {
@@ -3,7 +3,7 @@ use alloc::vec::Vec;
 
 /// The ComputeChannel trait links the ComputeClient to the ComputeServer
 /// while ensuring thread-safety
-pub trait ComputeChannel<Server: ComputeServer>: Clone {
+pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
     /// Given a handle, returns owned resource as bytes
     fn read(&self, handle: &Handle<Server>) -> Vec<u8>;
@@ -12,6 +12,7 @@ use alloc::vec::Vec;
 ///
 /// This is mostly useful for `no-std` environments where threads aren't supported, otherwise prefer
 /// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
+#[derive(Debug)]
 pub struct RefCellComputeChannel<Server> {
     server: Arc<core::cell::RefCell<Server>>,
 }
@@ -8,6 +8,7 @@ use crate::server::{ComputeServer, Handle};
 
 /// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
 /// the compute server spawned on its own thread.
+#[derive(Debug)]
 pub struct MpscComputeChannel<Server>
 where
     Server: ComputeServer,
@@ -15,6 +16,7 @@ where
     state: Arc<MpscComputeChannelState<Server>>,
 }
 
+#[derive(Debug)]
 struct MpscComputeChannelState<Server>
 where
     Server: ComputeServer,
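As an aside, the mpsc channel variant works by running the server on a dedicated thread and sending it messages, so the server state never needs a lock. A minimal sketch of that pattern (illustrative names, not the actual burn-compute implementation):

    use std::sync::mpsc;
    use std::thread;

    // Messages the client can send to the server thread.
    enum Message {
        Read(Vec<u8>, mpsc::Sender<Vec<u8>>), // payload + channel for the response
        Stop,
    }

    fn spawn_server() -> mpsc::Sender<Message> {
        let (sender, receiver) = mpsc::channel::<Message>();

        // The server owns its state exclusively on this thread: no Mutex required.
        thread::spawn(move || {
            for message in receiver.iter() {
                match message {
                    Message::Read(bytes, response) => {
                        // A real server would resolve a handle here; we just echo.
                        response.send(bytes).unwrap();
                    }
                    Message::Stop => break,
                }
            }
        });

        sender
    }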
@@ -6,6 +6,7 @@ use spin::Mutex;
 
 /// The MutexComputeChannel ensures thread-safety by locking the server
 /// on every operation
+#[derive(Debug)]
 pub struct MutexComputeChannel<Server> {
     server: Arc<Mutex<Server>>,
 }
@@ -7,6 +7,7 @@ use core::marker::PhantomData;
 
 /// The ComputeClient is the entry point to request tasks from the ComputeServer.
 /// It should be obtained for a specific device via the Compute struct.
+#[derive(Debug)]
 pub struct ComputeClient<Server, Channel> {
     channel: Channel,
     _server: PhantomData<Server>,
@@ -29,7 +29,7 @@ macro_rules! storage_id_type {
 /// Create a new memory ID type.
 macro_rules! memory_id_type {
     ($name:ident) => {
-        #[derive(Clone, Hash, PartialEq, Eq)]
+        #[derive(Clone, Hash, PartialEq, Eq, Debug)]
         /// Memory ID.
        pub struct $name {
            id: alloc::sync::Arc<alloc::string::String>,
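For reference, this macro is invoked elsewhere in the crate roughly as `memory_id_type!(ChunkId);`, and after this change the generated ID types become Debug-printable. A self-contained sketch of the expansion (simplified, using std in place of alloc so it runs on its own):

    // Roughly what memory_id_type!(ChunkId) expands to after this change (sketch):
    #[derive(Clone, Hash, PartialEq, Eq, Debug)]
    pub struct ChunkId {
        id: std::sync::Arc<String>,
    }

    fn main() {
        let id = ChunkId { id: std::sync::Arc::new("chunk-0".into()) };
        println!("{id:?}"); // Debug now works: ChunkId { id: "chunk-0" }
    }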
@@ -5,7 +5,7 @@ use crate::storage::ComputeStorage;
 ///
 /// It is responsible for determining if the memory segment can be mutated,
 /// for instance by keeping track of a reference count
-pub trait MemoryHandle: Clone + Send {
+pub trait MemoryHandle: Clone + Send + core::fmt::Debug {
     /// Checks if the underlying memory can be safely mutated.
     fn can_mut(&self) -> bool;
 }
@@ -15,7 +15,7 @@ pub trait MemoryHandle: Clone + Send {
 ///
 /// The MemoryManagement can only reserve memory space or get the resource located at a space.
 /// Modification of the resource data should be done directly on the resource.
-pub trait MemoryManagement<Storage: ComputeStorage>: Send {
+pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
     /// The associated type Handle must implement MemoryHandle
     type Handle: MemoryHandle;
 
@@ -24,4 +24,30 @@ pub trait MemoryManagement<Storage: ComputeStorage>: Send {
 
     /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
     fn reserve(&mut self, size: usize) -> Self::Handle;
 
+    /// Bypass the memory allocation algorithm to allocate data directly.
+    ///
+    /// # Notes
+    ///
+    /// Can be useful for servers that want specific control over memory.
+    fn alloc(&mut self, size: usize) -> Self::Handle;
+
+    /// Bypass the memory allocation algorithm to deallocate data directly.
+    ///
+    /// # Notes
+    ///
+    /// Can be useful for servers that want specific control over memory.
+    fn dealloc(&mut self, handle: &Self::Handle);
+
+    /// Fetch the storage used by the memory manager.
+    ///
+    /// # Notes
+    ///
+    /// The storage should probably not be used for allocations since the handles won't be
+    /// compatible with the ones provided by the current trait. Prefer using the
+    /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
+    ///
+    /// This is useful if you need to time the deallocations based on async computation, or to
+    /// change the mode of storage for different reasons.
+    fn storage(&mut self) -> &mut Storage;
 }
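To make the split concrete: `reserve` goes through the allocation strategy, while `alloc`/`dealloc` bypass it for callers that manage lifetimes themselves. A hedged usage sketch against the trait as declared above (the control flow is illustrative, not from the commit):

    use burn_compute::{memory_management::MemoryManagement, storage::ComputeStorage};

    /// Illustrative only: allocate a scratch buffer outside the reuse algorithm,
    /// use it, then free it immediately instead of waiting for the dealloc strategy.
    fn with_scratch<Storage: ComputeStorage, MM: MemoryManagement<Storage>>(
        memory: &mut MM,
        size: usize,
    ) {
        let handle = memory.alloc(size); // bypasses reserve()'s chunk reuse
        // ... bind `handle` to a kernel here ...
        memory.dealloc(&handle); // freed deterministically, not by the strategy
    }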
@@ -26,7 +26,7 @@ impl SliceId {
 }
 
 /// The SimpleHandle is a memory handle, referring to either a chunk or a slice.
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub enum SimpleHandle {
     /// A whole chunk of memory.
     Chunk(ChunkId),
@@ -35,34 +35,72 @@ pub enum SimpleHandle {
 }
 
 /// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
+#[derive(Debug)]
 pub enum DeallocStrategy {
     /// Once every n calls to reserve.
-    ///
-    /// First associated data is n, second is the state and should start at 0
-    PeriodTick(usize, usize),
+    PeriodTick {
+        /// Number of calls to be executed before triggering the deallocation.
+        period: usize,
+        /// Current state. Should start at zero.
+        state: usize,
+    },
     #[cfg(feature = "std")]
     /// Once every period of time
-    PeriodTime(std::time::Duration, std::time::Instant),
+    PeriodTime {
+        /// Amount of time before triggering the deallocation.
+        period: std::time::Duration,
+        /// Current state. Should start at now.
+        state: std::time::Instant,
+    },
     /// Never deallocate.
     Never,
 }
 
+/// The strategy defines when to reuse chunks with slices.
+#[derive(Debug)]
+pub enum SliceStrategy {
+    /// Never use slices.
+    Never,
+    /// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
+    Ratio(f32),
+    /// When the reserved memory is at least {} bytes.
+    MinimumSize(usize),
+    /// When the reserved memory is less than {} bytes.
+    MaximumSize(usize),
+}
+
+impl SliceStrategy {
+    /// If the chunk can be used with a slice.
+    pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
+        if chunk_size < reserved_size {
+            return false;
+        }
+
+        match self {
+            SliceStrategy::Never => false,
+            SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
+            SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
+            SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
+        }
+    }
+}
+
 impl DeallocStrategy {
     /// Create a new strategy with the given period.
     pub fn new_period_tick(period: usize) -> Self {
-        DeallocStrategy::PeriodTick(period, 0)
+        DeallocStrategy::PeriodTick { period, state: 0 }
     }
 
     fn should_dealloc(&mut self) -> bool {
         match self {
-            DeallocStrategy::PeriodTick(period, last) => {
-                *last = (*last + 1) % *period;
-                *last == 0
+            DeallocStrategy::PeriodTick { period, state } => {
+                *state = (*state + 1) % *period;
+                *state == 0
             }
             #[cfg(feature = "std")]
-            DeallocStrategy::PeriodTime(period, last) => {
-                if &last.elapsed() > period {
-                    *last = std::time::Instant::now();
+            DeallocStrategy::PeriodTime { period, state } => {
+                if &state.elapsed() > period {
+                    *state = std::time::Instant::now();
                     true
                 } else {
                     false
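Putting the two strategies together: a deallocation sweep every 1000 reserves combined with a 90% slice-ratio threshold can be expressed as below (the values are illustrative, taken from this commit's wgpu defaults rather than being general recommendations):

    use burn_compute::memory_management::{DeallocStrategy, SliceStrategy};

    fn strategies() -> (DeallocStrategy, SliceStrategy) {
        // Sweep unused chunks every 1000 calls to reserve.
        let dealloc = DeallocStrategy::new_period_tick(1000);
        // Reuse a free chunk as a slice only if the request fills >= 90% of it,
        // e.g. a 180-byte request may slice a 200-byte chunk, a 179-byte one may not.
        let slice = SliceStrategy::Ratio(0.9);
        (dealloc, slice)
    }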
@@ -78,9 +116,23 @@ pub struct SimpleMemoryManagement<Storage> {
     chunks: HashMap<ChunkId, (StorageHandle, Vec<SliceId>)>,
     slices: HashMap<SliceId, (StorageHandle, ChunkId)>,
     dealloc_strategy: DeallocStrategy,
+    slice_strategy: SliceStrategy,
     storage: Storage,
 }
 
+impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str(
+            alloc::format!(
+                "SimpleMemoryManagement {:?} - {:?}",
+                self.dealloc_strategy,
+                core::any::type_name::<Storage>(),
+            )
+            .as_str(),
+        )
+    }
+}
+
 impl MemoryHandle for SimpleHandle {
     /// Returns true if referenced by only one tensor, and only once by the
     /// memory management hashmaps
@@ -126,24 +178,43 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManageme
 
         handle
     }
 
+    fn alloc(&mut self, size: usize) -> Self::Handle {
+        self.create_chunk(size)
+    }
+
+    fn dealloc(&mut self, handle: &Self::Handle) {
+        match handle {
+            SimpleHandle::Chunk(id) => {
+                if let Some((handle, _slices)) = self.chunks.remove(id) {
+                    self.storage.dealloc(handle.id);
+                }
+            }
+            SimpleHandle::Slice(_) => panic!("Can't dealloc slice manually"),
+        }
+    }
+
+    fn storage(&mut self) -> &mut Storage {
+        &mut self.storage
+    }
 }
 
 impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
-    /// Creates a new instance using the given storage and deallocation strategy.
-    pub fn new(storage: Storage, dealloc_strategy: DeallocStrategy) -> Self {
+    /// Creates a new instance using the given storage, deallocation strategy and slice strategy.
+    pub fn new(
+        storage: Storage,
+        dealloc_strategy: DeallocStrategy,
+        slice_strategy: SliceStrategy,
+    ) -> Self {
         Self {
             chunks: HashMap::new(),
             slices: HashMap::new(),
             dealloc_strategy,
+            slice_strategy,
             storage,
         }
     }
 
-    /// Creates a new instance using the given storage without deallocation.
-    pub fn never_dealloc(storage: Storage) -> Self {
-        Self::new(storage, DeallocStrategy::Never)
-    }
-
     fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
         // Looks for a large enough, existing but unused chunk of memory.
         let chunk = self.find_free_chunk(size);
@@ -169,19 +240,31 @@ impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
         let mut size_diff_current = usize::MAX;
         let mut current = None;
 
-        self.chunks
-            .iter()
-            .for_each(|(chunk_id, (resource, slices))| {
-                let is_free = slices.is_empty() && chunk_id.is_free();
+        for (chunk_id, (resource, slices)) in self.chunks.iter() {
+            // If chunk is already used, we do not choose it
+            if !slices.is_empty() || !chunk_id.is_free() {
+                continue;
+            }
 
-                if is_free && resource.size() > size {
-                    let size_diff = resource.size() - size;
-                    if size_diff < size_diff_current {
-                        current = Some((chunk_id, resource));
-                        size_diff_current = size_diff;
-                    }
-                }
-            });
+            let resource_size = resource.size();
+
+            // If we find a chunk of exactly the right size, we stop searching altogether
+            if size == resource_size {
+                current = Some((chunk_id, resource));
+                break;
+            }
+
+            // Finds the smallest of the large enough chunks that can accept a slice
+            // of the given size
+            if self.slice_strategy.can_use_chunk(resource_size, size) {
+                let size_diff = resource_size - size;
+
+                if size_diff < size_diff_current {
+                    current = Some((chunk_id, resource));
+                    size_diff_current = size_diff;
+                }
+            }
+        }
 
         current.map(|(id, handle)| (id.clone(), handle.size()))
     }
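In other words, the loop is a best-fit search with an exact-size early exit: among the free chunks the slice strategy accepts, it keeps the one with the smallest leftover. The same selection logic over plain sizes, as a standalone sketch (illustrative, not the crate's API):

    /// Best-fit with exact-match early exit, mirroring find_free_chunk's shape.
    fn best_fit(free_chunk_sizes: &[usize], requested: usize, ratio: f32) -> Option<usize> {
        let mut best: Option<usize> = None;
        let mut best_diff = usize::MAX;

        for (index, &size) in free_chunk_sizes.iter().enumerate() {
            if size == requested {
                return Some(index); // exactly right: stop searching altogether
            }
            // Same test as SliceStrategy::Ratio: the request must fill enough of the chunk.
            if size >= requested && (requested as f32 / size as f32) >= ratio {
                let diff = size - requested;
                if diff < best_diff {
                    best = Some(index);
                    best_diff = diff;
                }
            }
        }
        best
    }

    fn main() {
        // 100 fills 100% of the 100-byte chunk -> exact match wins over the 128 chunk.
        assert_eq!(best_fit(&[256, 128, 100], 100, 0.9), Some(2));
    }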
@@ -263,7 +346,7 @@ impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
 #[cfg(test)]
 mod tests {
     use crate::{
-        memory_management::{MemoryHandle, MemoryManagement},
+        memory_management::{MemoryHandle, MemoryManagement, SliceStrategy},
         storage::BytesStorage,
     };
 
@@ -271,7 +354,11 @@ mod tests {
 
     #[test]
     fn can_mut_with_single_tensor_reference() {
-        let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
+        let mut memory_management = SimpleMemoryManagement::new(
+            BytesStorage::default(),
+            DeallocStrategy::Never,
+            SliceStrategy::Never,
+        );
 
         let chunk_size = 4;
         let simple_handle = memory_management.create_chunk(chunk_size);
@@ -284,7 +371,11 @@ mod tests {
 
     #[test]
     fn two_tensor_references_remove_mutability() {
-        let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
+        let mut memory_management = SimpleMemoryManagement::new(
+            BytesStorage::default(),
+            DeallocStrategy::Never,
+            SliceStrategy::Never,
+        );
 
         let chunk_size = 4;
         let simple_handle = memory_management.create_chunk(chunk_size);
@@ -297,7 +388,11 @@ mod tests {
 
     #[test]
     fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
-        let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
+        let mut memory_management = SimpleMemoryManagement::new(
+            BytesStorage::default(),
+            DeallocStrategy::Never,
+            SliceStrategy::Never,
+        );
         let chunk_size = 4;
         let _chunk_handle = memory_management.reserve(chunk_size);
         let _new_handle = memory_management.reserve(chunk_size);
@@ -307,7 +402,11 @@ mod tests {
 
     #[test]
     fn when_empty_chunk_is_cleaned_upexists_it_disappears() {
-        let mut memory_management = SimpleMemoryManagement::never_dealloc(BytesStorage::default());
+        let mut memory_management = SimpleMemoryManagement::new(
+            BytesStorage::default(),
+            DeallocStrategy::Never,
+            SliceStrategy::Never,
+        );
         let chunk_size = 4;
         let chunk_handle = memory_management.reserve(chunk_size);
         drop(chunk_handle);
@@ -336,4 +435,28 @@ mod tests {
         assert!(period_tick_dealloc.should_dealloc());
     }
 
+    #[test]
+    fn slice_strategy_minimum_bytes() {
+        let strategy = SliceStrategy::MinimumSize(100);
+
+        assert!(strategy.can_use_chunk(200, 101));
+        assert!(!strategy.can_use_chunk(200, 99));
+    }
+
+    #[test]
+    fn slice_strategy_maximum_bytes() {
+        let strategy = SliceStrategy::MaximumSize(100);
+
+        assert!(strategy.can_use_chunk(200, 99));
+        assert!(!strategy.can_use_chunk(200, 101));
+    }
+
+    #[test]
+    fn slice_strategy_ratio() {
+        let strategy = SliceStrategy::Ratio(0.9);
+
+        assert!(strategy.can_use_chunk(200, 180));
+        assert!(!strategy.can_use_chunk(200, 179));
+    }
 }
@@ -1,18 +1,39 @@
-use crate::{
-    memory_management::{MemoryHandle, MemoryManagement},
-    storage::ComputeStorage,
-};
+use crate::{memory_management::MemoryManagement, storage::ComputeStorage};
 use alloc::vec::Vec;
 
-type _Storage<Server> = <Server as ComputeServer>::Storage;
-type _MemoryManagement<Server> = <Server as ComputeServer>::MemoryManagement;
-
-/// This alias for a [memory handle](MemoryManagement::Handle).
-pub type Handle<Server> = <_MemoryManagement<Server> as MemoryManagement<_Storage<Server>>>::Handle;
+/// Server handle containing the [memory handle](MemoryManagement::Handle).
+#[derive(new, Debug)]
+pub struct Handle<Server: ComputeServer> {
+    /// Handle for the memory in use.
+    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
+}
+
+impl<Server: ComputeServer> Handle<Server> {
+    /// If the tensor handle can be mut with an inplace operation.
+    pub fn can_mut(&self) -> bool {
+        self.memory.can_mut()
+    }
+}
+
+impl<Server: ComputeServer> Clone for Handle<Server> {
+    fn clone(&self) -> Self {
+        Self {
+            memory: self.memory.clone(),
+        }
+    }
+}
 
 /// The compute server is responsible for handling resources and computations over resources.
 ///
 /// Everything in the server is mutable, therefore it should be solely accessed through the
 /// [compute channel](crate::channel::ComputeChannel) for thread safety.
-pub trait ComputeServer: Send {
+pub trait ComputeServer: Send + core::fmt::Debug
+where
+    Self: Sized,
+{
     /// The kernel type defines the computation algorithms.
     type Kernel: Send;
     /// The [storage](ComputeStorage) type defines how data is stored and accessed.
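The practical effect of turning `Handle` from a type alias into a wrapper struct is that the crate can now hang its own behavior off it (the `can_mut` method, the manual `Clone` impl), which an alias over an associated type cannot do. A minimal standalone sketch of the same pattern (illustrative names, using `Rc` in place of the server's handle type):

    use std::rc::Rc;

    // With an alias, `AliasHandle` *is* Rc<Vec<u8>>: no inherent methods possible.
    type AliasHandle = Rc<Vec<u8>>;

    // With a newtype, we can add methods such as a reference-count based can_mut.
    struct Handle {
        memory: Rc<Vec<u8>>,
    }

    impl Handle {
        /// Mutable only while this is the sole reference, like the server handles.
        fn can_mut(&self) -> bool {
            Rc::strong_count(&self.memory) == 1
        }
    }

    fn main() {
        let handle = Handle { memory: Rc::new(vec![0u8; 4]) };
        assert!(handle.can_mut());
        let _alias: AliasHandle = handle.memory.clone();
        assert!(!handle.can_mut()); // a second reference removes mutability
    }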
@@ -8,6 +8,12 @@ pub struct BytesStorage {
     memory: HashMap<StorageId, AllocatedBytes>,
 }
 
+impl core::fmt::Debug for BytesStorage {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str("BytesStorage")
+    }
+}
+
 /// Can send to other threads, but can't sync.
 unsafe impl Send for BytesStorage {}
 unsafe impl Send for BytesResource {}
@@ -1,7 +1,7 @@
 use super::DummyServer;
 use burn_compute::channel::MutexComputeChannel;
 use burn_compute::client::ComputeClient;
-use burn_compute::memory_management::SimpleMemoryManagement;
+use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy};
 use burn_compute::storage::BytesStorage;
 use burn_compute::Compute;
 
@@ -17,7 +17,8 @@ pub fn client(
 ) -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
     COMPUTE.client(device, || {
         let storage = BytesStorage::default();
-        let memory_management = SimpleMemoryManagement::never_dealloc(storage);
+        let memory_management =
+            SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never);
         let server = DummyServer::new(memory_management);
         let channel = MutexComputeChannel::new(server);
 
@@ -2,14 +2,14 @@ use burn_compute::storage::BytesResource;
 
 /// The DummyKernel trait should be implemented for every supported operation
 pub trait DummyKernel: Send {
-    fn compute<'a>(&self, resources: &mut [BytesResource]);
+    fn compute(&self, resources: &mut [BytesResource]);
 }
 
 /// Contains the algorithm for element-wise addition
 pub struct DummyElementwiseAddition;
 
 impl DummyKernel for DummyElementwiseAddition {
-    fn compute<'a>(&self, inputs: &mut [BytesResource]) {
+    fn compute(&self, inputs: &mut [BytesResource]) {
         // Notice how the kernel is responsible for determining which inputs
         // are read-only and which are writable.
         let lhs = &inputs[0].read();
@@ -9,7 +9,7 @@ use super::DummyKernel;
 
 /// The dummy server is used to test the burn-compute infrastructure.
 /// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
-#[derive(new)]
+#[derive(new, Debug)]
 pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
     memory_management: MM,
 }
@@ -23,7 +23,7 @@ where
     type MemoryManagement = MM;
 
     fn read(&mut self, handle: &Handle<Self>) -> Vec<u8> {
-        let bytes = self.memory_management.get(handle);
+        let bytes = self.memory_management.get(&handle.memory);
 
         bytes.read().to_vec()
     }
@@ -38,17 +38,17 @@ where
             bytes[i] = *val;
         }
 
-        handle
+        Handle::new(handle)
     }
 
     fn empty(&mut self, size: usize) -> Handle<Self> {
-        self.memory_management.reserve(size)
+        Handle::new(self.memory_management.reserve(size))
     }
 
     fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]) {
         let mut resources = handles
             .iter()
-            .map(|handle| self.memory_management.get(handle))
+            .map(|handle| self.memory_management.get(&handle.memory))
            .collect::<Vec<_>>();
 
        kernel.compute(&mut resources);
@@ -23,6 +23,7 @@ std = [
     "serde_json/std",
     "bincode/std",
     "half/std",
+    "derive-new/std",
 ]
 
 dataset = ["burn-dataset/default"]
@@ -42,7 +43,6 @@ ndarray-blas-openblas-system = ["__ndarray", "ndarray", "burn-ndarray/blas-openb
 __ndarray = [] # Internal flag to know when one ndarray feature is enabled.
 
 wgpu = ["burn-wgpu"]
-wgpu-autotune = ["wgpu", "burn-wgpu/autotune"]
 
 tch = ["burn-tch"]
 
@@ -67,7 +67,7 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true
 burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
 burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
 
-derive-new = { workspace = true }
+derive-new = { workspace = true, default-features = false }
 libm = { workspace = true }
 log = { workspace = true, optional = true }
 rand = { workspace = true, features = ["std_rng"] } # Default enables std
@@ -1,6 +1,5 @@
 use super::{bin_config, PrecisionSettings, Recorder, RecorderError};
 use alloc::vec::Vec;
-use core::marker::PhantomData;
 use serde::{de::DeserializeOwned, Serialize};
 
 /// Recorder trait specialized to save and load data to and from bytes.
@@ -17,7 +16,7 @@ pub trait BytesRecorder:
 /// In memory recorder using the [bincode format](bincode).
 #[derive(new, Debug, Default, Clone)]
 pub struct BinBytesRecorder<S: PrecisionSettings> {
-    _settings: PhantomData<S>,
+    _settings: core::marker::PhantomData<S>,
 }
 
 impl<S: PrecisionSettings> BytesRecorder for BinBytesRecorder<S> {}
@@ -45,7 +44,7 @@ impl<S: PrecisionSettings> Recorder for BinBytesRecorder<S> {
 /// In memory recorder using the [Named MessagePack](rmp_serde).
 #[derive(new, Debug, Default, Clone)]
 pub struct NamedMpkBytesRecorder<S: PrecisionSettings> {
-    _settings: PhantomData<S>,
+    _settings: core::marker::PhantomData<S>,
 }
 
 #[cfg(feature = "std")]
@@ -1,6 +1,5 @@
 use super::{PrecisionSettings, Record};
 use burn_tensor::{backend::Backend, Bool, DataSerialize, Int, Tensor};
-use core::marker::PhantomData;
 use serde::{Deserialize, Serialize};
 
 /// This struct implements serde to lazily serialize and deserialize a float tensor
@@ -8,7 +7,7 @@ use serde::{Deserialize, Serialize};
 #[derive(new, Clone, Debug)]
 pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
     tensor: Tensor<B, D>,
-    elem: PhantomData<S>,
+    elem: core::marker::PhantomData<S>,
 }
 
 /// This struct implements serde to lazily serialize and deserialize an int tensor
@@ -16,7 +15,7 @@ pub struct FloatTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
 #[derive(new, Clone, Debug)]
 pub struct IntTensorSerde<B: Backend, const D: usize, S: PrecisionSettings> {
     tensor: Tensor<B, D, Int>,
-    elem: PhantomData<S>,
+    elem: core::marker::PhantomData<S>,
 }
 
 /// This struct implements serde to lazily serialize and deserialize a bool tensor.
@@ -91,4 +91,10 @@ impl<E: TchElement> Backend for TchBackend<E> {
     fn name() -> String {
         "tch".to_string()
     }
 
+    fn sync(device: &Self::Device) {
+        if let TchDevice::Cuda(index) = device {
+            tch::Cuda::synchronize(*index as i64);
+        }
+    }
 }
@@ -94,6 +94,9 @@ pub trait Backend:
 
     /// Seed the backend.
     fn seed(seed: u64);
 
+    /// Sync the backend, ensuring that all computations are finished.
+    fn sync(_device: &Self::Device) {}
 }
 
 pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
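Since most backends execute asynchronously, the default no-op `sync` exists so generic code can insert a barrier before timing. A hedged sketch of the intended call pattern (the helper itself is illustrative, not from the commit):

    use std::time::Instant;
    use burn_tensor::backend::Backend;

    /// Illustrative benchmark helper: without the sync calls, `elapsed` would only
    /// measure kernel *submission* time on asynchronous backends such as wgpu.
    fn time_op<B: Backend>(device: &B::Device, op: impl Fn()) -> std::time::Duration {
        B::sync(device); // make sure previous work doesn't leak into the measurement
        let start = Instant::now();
        op();
        B::sync(device); // wait for the op itself to finish
        start.elapsed()
    }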
@@ -11,10 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
 version = "0.10.0"
 
 [features]
-default = ["async"]
-async = []
-# Still experimental
-autotune = []
+default = []
 
 [dependencies]
 burn-common = { path = "../burn-common", version = "0.10.0" }
@@ -35,6 +32,10 @@ wgpu = { workspace = true }
 serde = { workspace = true }
 text_placeholder = { version = "0.5.0", features = ["struct_context"] }
 
+hashbrown = { workspace = true }
+burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
+
 
 [dev-dependencies]
 burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", default-features = false, features = [
     "export_tests",
@@ -45,10 +46,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features =
 burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" }
 serial_test = "2.0.0"
 
-# Still only in dev mode
-hashbrown = { workspace = true }
-burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
-
 [[bench]]
 name = "unary"
 harness = false
@@ -3,7 +3,7 @@ use burn_wgpu::{
     benchmark::Benchmark,
     kernel::matmul::{
         contiguous, contiguous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
-        tile, tile_vectorized, tune,
+        tile, tile_vectorized,
     },
     run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
 };
@@ -50,8 +50,8 @@ where
     }
 
     fn prepare(&self, device: &WgpuDevice) -> Self::Args {
-        let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default).to_device(device);
-        let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default).to_device(device);
+        let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, device);
+        let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, device);
 
         (lhs, rhs)
     }
@@ -88,22 +88,6 @@ benchmark!(
     contiguous_vectorized::matmul_tiling_2d_default
 );
 
-struct MatmulAutotune;
-
-impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
-    for MatmulAutotune
-{
-    fn run(
-        lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
-        rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
-    ) -> Tensor<WgpuBackend<G, f32, i32>, D> {
-        Tensor::from_primitive(tune::<G, f32, D>(
-            lhs.into_primitive(),
-            rhs.into_primitive(),
-        ))
-    }
-}
-
 fn main() {
     let num_repeats = 3;
     let batch_size = 3;
@@ -141,10 +125,4 @@ fn main() {
         num_repeats,
         matmul: PhantomData
     });
-    run_benchmark!(MatmulBenchmark::<MatmulAutotune, 3> {
-        shape_lhs: [batch_size, m, k].into(),
-        shape_rhs: [batch_size, k, n].into(),
-        num_repeats,
-        matmul: PhantomData
-    });
 }
@@ -2,6 +2,7 @@ use burn_tensor::backend::Backend;
 use rand::{rngs::StdRng, SeedableRng};
 
 use crate::{
+    compute::compute_client,
     element::{FloatElement, IntElement},
     tensor::WgpuTensor,
     GraphicsApi, WgpuDevice,
@@ -43,4 +44,9 @@ impl<G: GraphicsApi + 'static, F: FloatElement, I: IntElement> Backend for WgpuB
     fn ad_enabled() -> bool {
         false
     }
 
+    fn sync(device: &Self::Device) {
+        let client = compute_client::<G>(device);
+        client.sync();
+    }
 }
@@ -1,4 +1,4 @@
-use crate::{pool::get_context, GraphicsApi, WgpuDevice};
+use crate::{compute::compute_client, GraphicsApi, WgpuDevice};
 use std::{
     fmt::Display,
     time::{Duration, Instant},
@@ -15,6 +15,7 @@ impl BenchmarkResult {
         self.durations.iter().sum::<Duration>() / self.durations.len() as u32
     }
 
+    #[allow(dead_code)]
     pub(crate) fn median_duration(&self) -> Duration {
         let mut sorted = self.durations.clone();
         sorted.sort();
@@ -83,23 +84,23 @@ pub trait Benchmark<G: GraphicsApi> {
     fn name(&self) -> String;
     /// Run the benchmark a number of times.
     fn run(&self, device: &WgpuDevice) -> BenchmarkResult {
-        let context = get_context::<G>(device);
+        let client = compute_client::<G>(device);
 
         // Warmup
         self.execute(self.prepare(device));
-        context.sync();
+        client.sync();
 
         let mut durations = Vec::with_capacity(self.num_samples());
 
         for _ in 0..self.num_samples() {
             // Prepare
             let args = self.prepare(device);
-            context.sync();
+            client.sync();
 
             // Execute the benchmark
             let start = Instant::now();
             self.execute(args);
-            context.sync();
+            client.sync();
             let end = Instant::now();
 
             // Register the duration
@@ -1,26 +1,28 @@
-use super::{Kernel, WgpuServer};
-use crate::{
-    compute::WgpuStorage,
-    context::{select_device, WorkGroup},
-    kernel::{DynamicKernel, SourceTemplate, StaticKernel},
-    GraphicsApi, WgpuDevice,
-};
+use super::WgpuServer;
+use crate::{compute::WgpuStorage, GraphicsApi, WgpuDevice};
+use alloc::sync::Arc;
 use burn_compute::{
     channel::MutexComputeChannel,
     client::ComputeClient,
-    memory_management::{DeallocStrategy, SimpleMemoryManagement},
+    memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
     Compute,
 };
-use std::{marker::PhantomData, sync::Arc};
 use wgpu::{DeviceDescriptor, DeviceType};
 
-type WgpuChannel = MutexComputeChannel<WgpuServer>;
+type MemoryManagement = SimpleMemoryManagement<WgpuStorage>;
+type Server = WgpuServer<MemoryManagement>;
+type Channel = MutexComputeChannel<Server>;
+
+/// Wgpu [compute client](ComputeClient) to communicate with the [compute server](WgpuServer).
+pub type WgpuComputeClient = ComputeClient<Server, Channel>;
+/// Wgpu [server handle](burn_compute::server::Handle).
+pub type WgpuHandle = burn_compute::server::Handle<Server>;
 
 /// Compute handle for the wgpu backend.
-static COMPUTE: Compute<WgpuDevice, WgpuServer, WgpuChannel> = Compute::new();
+static COMPUTE: Compute<WgpuDevice, WgpuServer<MemoryManagement>, Channel> = Compute::new();
 
-pub fn compute_client<G: GraphicsApi>(
-    device: &WgpuDevice,
-) -> ComputeClient<WgpuServer, WgpuChannel> {
+/// Get the [compute client](ComputeClient) for the given [device](WgpuDevice).
+pub fn compute_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
     let device = Arc::new(device);
 
     COMPUTE.client(&device, move || {
@@ -37,93 +39,159 @@ pub fn compute_client<G: GraphicsApi>(
             Ok(value) => value
                 .parse::<usize>()
                 .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
-            Err(_) => 16, // 16 tasks by default
+            Err(_) => 64, // 64 tasks by default
         };
 
         let device = Arc::new(device_wgpu);
         let storage = WgpuStorage::new(device.clone());
-        // Maximum reusability.
-        let memory_management = SimpleMemoryManagement::new(storage, DeallocStrategy::Never);
+        let memory_management = SimpleMemoryManagement::new(
+            storage,
+            DeallocStrategy::new_period_tick(1000),
+            SliceStrategy::Ratio(0.9),
+        );
        let server = WgpuServer::new(memory_management, device, queue, max_tasks);
-        let channel = WgpuChannel::new(server);
+        let channel = Channel::new(server);
 
         ComputeClient::new(channel)
     })
 }
 
-pub struct DynamicComputeKernel<K> {
-    kernel: K,
-    workgroup: WorkGroup,
-}
+/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
+pub async fn select_device<G: GraphicsApi>(
+    device: &WgpuDevice,
+) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
+    let adapter = select_adapter::<G>(device);
+    let limits = adapter.limits();
+
+    let (device, queue) = adapter
+        .request_device(
+            &DeviceDescriptor {
+                label: None,
+                features: wgpu::Features::empty(),
+                limits,
+            },
+            None,
+        )
+        .await
+        .map_err(|err| {
+            format!(
+                "Unable to request the device with the adapter {:?}, err {:?}",
+                adapter.get_info(),
+                err
+            )
+        })
+        .unwrap();
+
+    (device, queue, adapter.get_info())
+}
 
-impl<K> Kernel for DynamicComputeKernel<K>
-where
-    K: DynamicKernel + 'static,
-{
-    fn source_template(self: Box<Self>) -> SourceTemplate {
-        self.kernel.source_template()
-    }
+fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
+    let instance = wgpu::Instance::default();
+
+    let mut adapters_other = Vec::new();
+    let mut adapters = Vec::new();
+
+    instance
+        .enumerate_adapters(G::backend().into())
+        .for_each(|adapter| {
+            let device_type = adapter.get_info().device_type;
+
+            if let DeviceType::Other = device_type {
+                adapters_other.push(adapter);
+                return;
+            }
+
+            let is_same_type = match device {
+                WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu,
+                WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu,
+                WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu,
+                WgpuDevice::Cpu => device_type == DeviceType::Cpu,
+                WgpuDevice::BestAvailable => true,
+            };
+
+            if is_same_type {
+                adapters.push(adapter);
+            }
+        });
+
+    fn select(
+        num: usize,
+        error: &str,
+        mut adapters: Vec<wgpu::Adapter>,
+        mut adapters_other: Vec<wgpu::Adapter>,
+    ) -> wgpu::Adapter {
+        if adapters.len() <= num {
+            if adapters_other.len() <= num {
+                panic!(
+                    "{}, adapters {:?}, other adapters {:?}",
+                    error,
+                    adapters
+                        .into_iter()
+                        .map(|adapter| adapter.get_info())
+                        .collect::<Vec<_>>(),
+                    adapters_other
+                        .into_iter()
+                        .map(|adapter| adapter.get_info())
+                        .collect::<Vec<_>>(),
+                );
+            } else {
+                return adapters_other.remove(num);
+            }
+        }
+
+        adapters.remove(num)
+    }
 
-    fn id(&self) -> String {
-        self.kernel.id()
-    }
+    let adapter = match device {
+        WgpuDevice::DiscreteGpu(num) => select(
+            *num,
+            "No Discrete GPU device found",
+            adapters,
+            adapters_other,
+        ),
+        WgpuDevice::IntegratedGpu(num) => select(
+            *num,
+            "No Integrated GPU device found",
+            adapters,
+            adapters_other,
+        ),
+        WgpuDevice::VirtualGpu(num) => select(
+            *num,
+            "No Virtual GPU device found",
+            adapters,
+            adapters_other,
+        ),
+        WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other),
+        WgpuDevice::BestAvailable => {
+            let mut most_performant_adapter = None;
+            let mut current_score = -1;
 
-    fn workgroup(&self) -> WorkGroup {
-        self.workgroup.clone()
-    }
-}
-
-#[derive(new)]
-pub struct StaticComputeKernel<K> {
-    workgroup: WorkGroup,
-    _kernel: PhantomData<K>,
-}
-
-impl<K> Kernel for StaticComputeKernel<K>
-where
-    K: StaticKernel + 'static,
-{
-    fn source_template(self: Box<Self>) -> SourceTemplate {
-        K::source_template()
-    }
-
-    fn id(&self) -> String {
-        format!("{:?}", core::any::TypeId::of::<K>())
-    }
-
-    fn workgroup(&self) -> WorkGroup {
-        self.workgroup.clone()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::{binary_elemwise, kernel::KernelSettings, AutoGraphicsApi};
-
-    #[test]
-    fn can_run_kernel() {
-        binary_elemwise!(Add, "+");
-
-        let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
-
-        let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
-        let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
-        let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
-
-        let lhs = client.create(bytemuck::cast_slice(&lhs));
-        let rhs = client.create(bytemuck::cast_slice(&rhs));
-        let out = client.empty(core::mem::size_of::<f32>() * 8);
-        let info = client.create(bytemuck::cast_slice(&info));
-
-        type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
-        let kernel = Box::new(StaticComputeKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
-
-        client.execute(kernel, &[&lhs, &rhs, &out, &info]);
-
-        let data = client.read(&out);
-        let output: &[f32] = bytemuck::cast_slice(&data);
-
-        assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
-    }
-}
+            adapters.into_iter().for_each(|adapter| {
+                let info = adapter.get_info();
+                let score = match info.device_type {
+                    DeviceType::DiscreteGpu => 5,
+                    DeviceType::Other => 4, // Let's be optimistic with the Other device, it's
+                    // often a Discrete Gpu.
+                    DeviceType::IntegratedGpu => 3,
+                    DeviceType::VirtualGpu => 2,
+                    DeviceType::Cpu => 1,
+                };
+
+                if score > current_score {
+                    most_performant_adapter = Some(adapter);
+                    current_score = score;
+                }
+            });
+
+            if let Some(adapter) = most_performant_adapter {
+                adapter
+            } else {
+                panic!("No adapter found for graphics API {:?}", G::default());
+            }
+        }
+    };
+
+    log::info!("Using adapter {:?}", adapter.get_info());
+
+    adapter
+}
@@ -0,0 +1,106 @@
+use super::Kernel;
+use crate::kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource};
+use core::marker::PhantomData;
+
+/// Provides launch information specifying the number of work groups to be used by a compute shader.
+#[derive(new, Clone, Debug)]
+pub struct WorkGroup {
+    /// Work groups for the x axis.
+    pub x: u32,
+    /// Work groups for the y axis.
+    pub y: u32,
+    /// Work groups for the z axis.
+    pub z: u32,
+}
+
+impl WorkGroup {
+    /// Calculate the number of invocations of a compute shader.
+    pub fn num_invocations(&self) -> usize {
+        (self.x * self.y * self.z) as usize
+    }
+}
+
+/// Wraps a [dynamic kernel source](DynamicKernelSource) into a [kernel](Kernel) with launch
+/// information such as [workgroup](WorkGroup).
+#[derive(new)]
+pub struct DynamicKernel<K> {
+    kernel: K,
+    workgroup: WorkGroup,
+}
+
+/// Wraps a [static kernel source](StaticKernelSource) into a [kernel](Kernel) with launch
+/// information such as [workgroup](WorkGroup).
+#[derive(new)]
+pub struct StaticKernel<K> {
+    workgroup: WorkGroup,
+    _kernel: PhantomData<K>,
+}
+
+impl<K> Kernel for DynamicKernel<K>
+where
+    K: DynamicKernelSource + 'static,
+{
+    fn source(self: Box<Self>) -> SourceTemplate {
+        self.kernel.source()
+    }
+
+    fn id(&self) -> String {
+        self.kernel.id()
+    }
+
+    fn workgroup(&self) -> WorkGroup {
+        self.workgroup.clone()
+    }
+}
+
+impl<K> Kernel for StaticKernel<K>
+where
+    K: StaticKernelSource + 'static,
+{
+    fn source(self: Box<Self>) -> SourceTemplate {
+        K::source()
+    }
+
+    fn id(&self) -> String {
+        format!("{:?}", core::any::TypeId::of::<K>())
+    }
+
+    fn workgroup(&self) -> WorkGroup {
+        self.workgroup.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        binary_elemwise, compute::compute_client, kernel::KernelSettings, AutoGraphicsApi,
+        WgpuDevice,
+    };
+
+    #[test]
+    fn can_run_kernel() {
+        binary_elemwise!(Add, "+");
+
+        let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());
+
+        let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
+        let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
+        let info: Vec<u32> = vec![1, 1, 1, 1, 8, 8, 8];
+
+        let lhs = client.create(bytemuck::cast_slice(&lhs));
+        let rhs = client.create(bytemuck::cast_slice(&rhs));
+        let out = client.empty(core::mem::size_of::<f32>() * 8);
+        let info = client.create(bytemuck::cast_slice(&info));
+
+        type Kernel = KernelSettings<Add, f32, i32, 16, 16, 1>;
+        let kernel = Box::new(StaticKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));
+
+        client.execute(kernel, &[&lhs, &rhs, &out, &info]);
+
+        let data = client.read(&out);
+        let output: &[f32] = bytemuck::cast_slice(&data);
+
+        assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
+    }
+}
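One common use of `WorkGroup` is deriving the launch size from an element count and the shader's workgroup dimensions. A hedged sketch (the 256 threads-per-group figure is illustrative, not a value taken from the commit):

    /// Illustrative: number of work groups needed to cover `num_elems` elements
    /// when each group processes `workgroup_size` of them along x.
    fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> (u32, u32, u32) {
        let num_groups = (num_elems + workgroup_size - 1) / workgroup_size; // ceil div
        (num_groups as u32, 1, 1)
    }

    fn main() {
        // 1000 elements at 256 per group -> 4 groups, 1024 invocations total.
        let (x, y, z) = elemwise_workgroup(1000, 256);
        assert_eq!((x, y, z), (4, 1, 1));
        assert_eq!((x * y * z) as usize * 256, 1024);
    }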
@@ -1,7 +1,9 @@
 mod base;
+mod kernel;
 mod server;
 mod storage;
 
 pub use base::*;
+pub use kernel::*;
 pub use server::*;
 pub use storage::*;
@@ -1,9 +1,8 @@
-use std::{borrow::Cow, sync::Arc};
-
-use super::WgpuStorage;
-use crate::{context::WorkGroup, kernel::SourceTemplate};
+use super::{WgpuStorage, WorkGroup};
+use crate::kernel::SourceTemplate;
+use alloc::{borrow::Cow, sync::Arc};
 use burn_compute::{
-    memory_management::{MemoryManagement, SimpleMemoryManagement},
+    memory_management::MemoryManagement,
     server::{self, ComputeServer},
 };
 use hashbrown::HashMap;
@@ -13,7 +12,8 @@ use wgpu::{
 };
 
 /// Wgpu compute server.
-pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
+#[derive(Debug)]
+pub struct WgpuServer<MM: MemoryManagement<WgpuStorage>> {
     memory_management: MM,
     device: Arc<wgpu::Device>,
     queue: wgpu::Queue,
@@ -21,20 +21,27 @@ pub struct WgpuServer<MM = SimpleMemoryManagement<WgpuStorage>> {
     pipelines: HashMap<String, Arc<ComputePipeline>>,
     tasks: Vec<ComputeTask>,
     max_tasks: usize,
+    manual_available: HashMap<usize, Vec<server::Handle<Self>>>,
+    manual_taken: Vec<(usize, server::Handle<Self>)>,
 }
 
-#[derive(new)]
+#[derive(new, Debug)]
 struct ComputeTask {
     pipeline: Arc<ComputePipeline>,
     bind_group: BindGroup,
     work_group: WorkGroup,
 }
 
+/// Kernel trait with the [source](SourceTemplate) that will be compiled and cached based on the
+/// provided id.
+///
+/// The kernel will be launched with the given [workgroup](WorkGroup).
 pub trait Kernel: 'static + Send {
-    fn source_template(self: Box<Self>) -> SourceTemplate;
+    /// Source template for the kernel.
+    fn source(self: Box<Self>) -> SourceTemplate;
+    /// Identifier for the kernel, used for caching kernel compilation.
     fn id(&self) -> String;
+    /// Launch information.
     fn workgroup(&self) -> WorkGroup;
 }
 
@@ -42,6 +49,7 @@ impl<MM> WgpuServer<MM>
 where
     MM: MemoryManagement<WgpuStorage>,
 {
+    /// Create a new server.
     pub fn new(
         memory_management: MM,
         device: Arc<wgpu::Device>,
@@ -60,6 +68,8 @@ where
             pipelines: HashMap::new(),
             tasks: Vec::new(),
             max_tasks,
+            manual_available: HashMap::new(),
+            manual_taken: Vec::new(),
         }
     }
 
@@ -68,13 +78,54 @@ where
             self.tasks.is_empty(),
             "Tasks should be completed before submitting the current encoder."
         );
+        println!("Submit");
         let mut new_encoder = self
             .device
             .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
         core::mem::swap(&mut new_encoder, &mut self.encoder);
 
         self.queue.submit(Some(new_encoder.finish()));
+
+        // Cleanup allocations and deallocations.
+        self.free_manual_allocations();
+        self.memory_management.storage().perform_deallocations();
     }
 
+    fn free_manual_allocations(&mut self) {
+        let mut manual_taken_tmp = Vec::new();
+        core::mem::swap(&mut manual_taken_tmp, &mut self.manual_taken);
+
+        for (size, handle) in manual_taken_tmp.drain(..) {
+            if handle.can_mut() {
+                self.register_manual(size, handle);
+            } else {
+                self.manual_taken.push((size, handle));
+            }
+        }
+    }
+
+    // Finds a free, manually-added handle of specified size, or creates it if none is found
+    fn manual_reserve(&mut self, size: usize) -> server::Handle<Self> {
+        let handle = self
+            .manual_available
+            .get_mut(&size)
+            .and_then(|h| h.pop())
+            .unwrap_or_else(|| {
+                let memory = self.memory_management.alloc(size);
+                server::Handle::new(memory)
+            });
+
+        self.manual_taken.push((size, handle.clone()));
+
+        handle
+    }
+
+    // Manually adds a handle of given size
+    fn register_manual(&mut self, size: usize, handle: server::Handle<Self>) {
+        if let Some(handles) = self.manual_available.get_mut(&size) {
+            handles.push(handle);
+        } else {
+            self.manual_available.insert(size, [handle].into());
+        }
+    }
+
     fn register_tasks(&mut self) {
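The `manual_available` map above is a size-bucketed free list: handles come back to the bucket for their exact size once nothing else references them. The same idea in isolation (illustrative, using plain `Rc` in place of server handles):

    use std::collections::HashMap;
    use std::rc::Rc;

    type Buffer = Rc<Vec<u8>>;

    /// Illustrative size-bucketed pool: pop an exact-size free buffer or allocate.
    struct Pool {
        available: HashMap<usize, Vec<Buffer>>,
    }

    impl Pool {
        fn reserve(&mut self, size: usize) -> Buffer {
            self.available
                .get_mut(&size)
                .and_then(|bucket| bucket.pop())
                .unwrap_or_else(|| Rc::new(vec![0u8; size]))
        }

        /// Return a buffer to its bucket only when no one else holds it,
        /// mirroring the `handle.can_mut()` check in free_manual_allocations.
        fn release(&mut self, buffer: Buffer) {
            if Rc::strong_count(&buffer) == 1 {
                self.available.entry(buffer.len()).or_default().push(buffer);
            }
        }
    }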
@@ -102,7 +153,7 @@ where
             return pipeline.clone();
         }
 
-        let pipeline = self.compile_source(&kernel.source_template().complete());
+        let pipeline = self.compile_source(&kernel.source().complete());
         self.pipelines.insert(kernel_id.clone(), pipeline.clone());
 
         pipeline
@@ -138,7 +189,7 @@ where
         // Register previous tasks before reading the buffer so that it is up to date.
         self.register_tasks();
 
-        let resource = self.memory_management.get(handle);
+        let resource = self.memory_management.get(&handle.memory);
 
         let size = resource.size();
         let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
@@ -182,8 +233,13 @@ where
         }
     }
 
+    /// When we create a new handle from existing data, we use custom allocations so that we don't
+    /// have to execute the current pending tasks.
+    ///
+    /// This is important, otherwise the compute passes are going to be too small and we won't be able to
+    /// fully utilize the GPU.
     fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
-        let handle = self.empty(data.len());
+        let handle = self.manual_reserve(data.len());
 
         let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
             label: Some("Buffer Src"),
@@ -191,9 +247,7 @@ where
             usage: wgpu::BufferUsages::COPY_SRC,
         }));
 
-        let resource = self.memory_management.get(&handle);
-
-        self.register_tasks();
+        let resource = self.memory_management.get(&handle.memory);
 
         self.encoder.copy_buffer_to_buffer(
             &buffer_src,
@@ -207,7 +261,7 @@ where
     }
 
     fn empty(&mut self, size: usize) -> server::Handle<Self> {
-        self.memory_management.reserve(size)
+        server::Handle::new(self.memory_management.reserve(size))
     }
 
     fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle<Self>]) {
@@ -217,7 +271,7 @@ where
 
         let handles = handles
             .iter()
-            .map(|handle| self.memory_management.get(handle))
+            .map(|handle| self.memory_management.get(&handle.memory))
             .collect::<Vec<_>>();
 
         let entries = handles
@@ -249,5 +303,7 @@ where
             self.register_tasks();
             self.submit();
         }
+
+        self.device.poll(wgpu::Maintain::Wait);
     }
 }
@@ -2,14 +2,25 @@ use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUti
 use hashbrown::HashMap;
 use std::{num::NonZeroU64, sync::Arc};
 
 /// Buffer storage for wgpu.
 pub struct WgpuStorage {
     memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
+    deallocations: Vec<StorageId>,
     device: Arc<wgpu::Device>,
 }
 
+impl core::fmt::Debug for WgpuStorage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
+    }
+}
+
 /// The memory resource that can be allocated for wgpu.
 #[derive(new, Debug)]
 pub struct WgpuResource {
     /// The wgpu buffer.
     pub buffer: Arc<wgpu::Buffer>,
     /// How the resource is used.
     pub kind: WgpuResourceKind,
 }
 
@@ -44,6 +55,7 @@ impl WgpuResource {
     }
 }
 
 /// How the resource is used, either as a slice or fully.
+#[derive(Debug)]
 pub enum WgpuResourceKind {
     /// Represents an entire buffer.
@@ -54,12 +66,23 @@ pub enum WgpuResourceKind {
 
 /// Keeps actual wgpu buffer references in a hashmap with ids as key.
 impl WgpuStorage {
+    /// Create a new storage on the given [device](wgpu::Device).
     pub fn new(device: Arc<wgpu::Device>) -> Self {
         Self {
             memory: HashMap::new(),
+            deallocations: Vec::new(),
             device,
         }
     }
+
+    /// Actually deallocates buffers tagged to be deallocated.
+    pub fn perform_deallocations(&mut self) {
+        for id in self.deallocations.drain(..) {
+            if let Some(buffer) = self.memory.remove(&id) {
+                buffer.destroy()
+            }
+        }
+    }
 }
 
 impl ComputeStorage for WgpuStorage {
@@ -96,7 +119,6 @@ impl ComputeStorage for WgpuStorage {
     }
 
     fn dealloc(&mut self, id: StorageId) {
-        self.memory.get(&id).unwrap().destroy();
-        let _ = self.memory.remove(&id);
+        self.deallocations.push(id);
     }
 }
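Deallocation is now two-phase: `dealloc` only records the ID, and the buffers are actually destroyed later by `perform_deallocations`, which the server calls after submitting work. A standalone sketch of this defer-then-sweep pattern (illustrative types):

    use std::collections::HashMap;

    struct Storage {
        memory: HashMap<u64, Vec<u8>>,
        deallocations: Vec<u64>, // tagged, not yet destroyed
    }

    impl Storage {
        /// Phase 1: cheap, and safe to call while queued work may still reference the buffer.
        fn dealloc(&mut self, id: u64) {
            self.deallocations.push(id);
        }

        /// Phase 2: run once the queue has been submitted and the GPU is done with them.
        fn perform_deallocations(&mut self) {
            for id in self.deallocations.drain(..) {
                self.memory.remove(&id); // dropping the allocation destroys it
            }
        }
    }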
@ -1,410 +0,0 @@
|
|||
use super::client::ContextClient;
|
||||
use crate::{
|
||||
context::server::ContextServer,
|
||||
kernel::{DynamicKernel, StaticKernel},
|
||||
tune::Tuner,
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
use burn_common::id::IdGenerator;
|
||||
use spin::Mutex;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
sync::Arc,
|
||||
};
|
||||
use wgpu::{
|
||||
util::{BufferInitDescriptor, DeviceExt},
|
||||
Buffer, ComputePipeline, DeviceDescriptor, DeviceType, ShaderModuleDescriptor,
|
||||
};
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) type ContextClientImpl = super::client::AsyncContextClient;
|
||||
#[cfg(not(feature = "async"))]
|
||||
pub(crate) type ContextClientImpl = super::client::SyncContextClient;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) type ContextServerImpl = super::server::AsyncContextServer;
|
||||
#[cfg(not(feature = "async"))]
|
||||
pub(crate) type ContextServerImpl = super::server::SyncContextServer;
|
||||
|
||||
/// The context is the basic struct that allows to execute GPU kernel on devices.
|
||||
///
|
||||
/// You can access a context for a WGPUDevice using get_context.
|
||||
#[derive(Debug)]
|
||||
pub struct Context {
|
||||
id: String,
|
||||
device_wgpu: Arc<wgpu::Device>,
|
||||
cache: Mutex<HashMap<TemplateKey, Arc<ComputePipeline>>>,
|
||||
is_tuning: AtomicBool,
|
||||
client: ContextClientImpl,
|
||||
pub(crate) tuner: Tuner,
|
||||
tuning_template_ids: Mutex<Vec<TemplateKey>>,
|
||||
pub(crate) device: WgpuDevice,
|
||||
pub(crate) info: wgpu::AdapterInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Hash, Clone, PartialOrd, PartialEq, Eq)]
|
||||
enum TemplateKey {
|
||||
Static(TypeId),
|
||||
Dynamic(String),
|
||||
}
|
||||
|
||||
/// Provides launch information specifying the number of work groups to be used by a compute shader.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct WorkGroup {
|
||||
/// Work groups for the x axis.
|
||||
pub x: u32,
|
||||
/// Work groups for the y axis.
|
||||
pub y: u32,
|
||||
/// Work groups for the z axis.
|
||||
pub z: u32,
|
||||
}
|
||||
|
||||
impl WorkGroup {
|
||||
/// Calculate the number of invocations of a compute shader.
|
||||
pub fn num_invocations(&self) -> usize {
|
||||
(self.x * self.y * self.z) as usize
|
||||
}
|
||||
}

impl Context {
    /// Create a new context where computing tasks will be executed on the given
    /// [device](WgpuDevice).
    pub(crate) fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
        let (device_wgpu, queue, info) = pollster::block_on(select_device::<G>(device));
        let device = device.clone();
        let device_wgpu = Arc::new(device_wgpu);
        let client = ContextServerImpl::start(device_wgpu.clone(), queue);

        Self {
            id: IdGenerator::generate(),
            device_wgpu,
            device,
            client,
            cache: Mutex::new(HashMap::new()),
            is_tuning: AtomicBool::new(false),
            tuner: Tuner::new(),
            tuning_template_ids: Mutex::new(Vec::new()),
            info,
        }
    }

    /// Wait for all computation to be executed.
    ///
    /// Useful for benchmarks.
    pub fn sync(&self) {
        self.client.sync();
    }

    /// Execute a kernel using the provided buffers.
    ///
    /// # Notes
    ///
    /// This function isn't safe: buffers can be mutated by the GPU. Users must ensure that a
    /// buffer can safely be mutated when launching a compute shader with write access to it.
    ///
    /// Buffer positions are used as bindings when launching a compute kernel.
    pub fn execute(
        &self,
        work_group: WorkGroup,
        pipeline: Arc<ComputePipeline>,
        buffers: &[&Buffer],
    ) {
        let group_layout = pipeline.get_bind_group_layout(0);

        let entries = buffers
            .iter()
            .enumerate()
            .map(|(i, buffer)| wgpu::BindGroupEntry {
                binding: i as u32,
                resource: buffer.as_entire_binding(),
            })
            .collect::<Vec<_>>();

        let bind_group = self
            .device_wgpu
            .create_bind_group(&wgpu::BindGroupDescriptor {
                label: None,
                layout: &group_layout,
                entries: &entries,
            });

        self.client
            .register_compute(bind_group, pipeline, work_group)
    }

    /// Create a new buffer with the provided size.
    pub fn create_buffer(&self, size: usize) -> Arc<Buffer> {
        Arc::new(self.device_wgpu.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: size as u64,
            usage: wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        }))
    }

    /// Create a new buffer initialized with the provided bytes.
    pub fn create_buffer_with_data(&self, data: &[u8]) -> Arc<Buffer> {
        self.create_buffer_with_data_options(data, false)
    }

    /// Create a new buffer initialized with the provided bytes, with the option to be synchronous.
    ///
    /// Being synchronous is important when you want to reuse the buffer using the Arc strong
    /// count for inner mutability.
    pub fn create_buffer_with_data_options(&self, data: &[u8], sync: bool) -> Arc<Buffer> {
        let buffer_src = Arc::new(self.device_wgpu.create_buffer_init(&BufferInitDescriptor {
            label: Some("Buffer Src"),
            contents: data,
            usage: wgpu::BufferUsages::COPY_SRC,
        }));

        let buffer_dest = self.create_buffer(buffer_src.size() as usize);

        self.client.copy_buffer(buffer_src, buffer_dest, sync)
    }

    /// Copy buffer to buffer.
    ///
    /// Waiting for the copy to be registered may be useful if you want to allow in-place
    /// operations on the created buffer. Otherwise, the strong count of the buffer might not be 1
    /// when registering a new operation, which makes the buffer read-only.
    pub fn copy_buffer(&self, buffer_src: Arc<Buffer>, wait_for_registered: bool) -> Arc<Buffer> {
        let buffer_dest = self.create_buffer(buffer_src.size() as usize);

        self.client
            .copy_buffer(buffer_src, buffer_dest, wait_for_registered)
    }

    /// Read a buffer from the GPU and return its content as bytes.
    pub fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
        self.client.read_buffer(buffer)
    }

    /// Compile a kernel template if not present in the cache.
    pub fn compile_static<K: StaticKernel>(&self) -> Arc<ComputePipeline> {
        let mut cache = self.cache.lock();
        let template_id = TemplateKey::Static(TypeId::of::<K>());

        if let Some(module) = cache.get(&template_id) {
            return module.clone();
        }

        let source = K::source_template();
        let pipeline = self.compile_source(&source.complete());

        if self.is_tuning.load(Ordering::Relaxed) {
            let mut templates_vec = self.tuning_template_ids.lock();
            templates_vec.push(template_id.clone());
        }

        cache.insert(template_id, pipeline.clone());
        pipeline
    }

    /// Compile a dynamic template if not present in the cache.
    pub fn compile_dynamic<K: DynamicKernel>(&self, kernel: K) -> Arc<ComputePipeline> {
        let mut cache = self.cache.lock();
        let template_id = TemplateKey::Dynamic(kernel.id());

        if let Some(module) = cache.get(&template_id) {
            return module.clone();
        }

        let source = kernel.source_template();
        let pipeline = self.compile_source(&source.complete());

        if self.is_tuning.load(Ordering::Relaxed) {
            let mut templates_vec = self.tuning_template_ids.lock();
            templates_vec.push(template_id.clone());
        }

        cache.insert(template_id, pipeline.clone());
        pipeline
    }

    fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
        let module = self
            .device_wgpu
            .create_shader_module(ShaderModuleDescriptor {
                label: None,
                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
            });
        let pipeline = self
            .device_wgpu
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: None,
                layout: None,
                module: &module,
                entry_point: "main",
            });

        Arc::new(pipeline)
    }

    pub(crate) fn start_tuning(&self) {
        self.is_tuning.store(true, Ordering::Relaxed);
    }

    pub(crate) fn stop_tuning(&self) {
        self.is_tuning.store(false, Ordering::Relaxed);

        // Clean the cache of pipelines accumulated during tuning.
        let mut cache = self.cache.lock();
        let mut tuning_template_ids = self.tuning_template_ids.lock();
        for template_id in tuning_template_ids.iter() {
            cache.remove(template_id);
        }

        tuning_template_ids.clear();
    }
}
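For reference, the pre-refactor flow looked roughly like the sketch below, using only the `Context` methods shown above. `Double` is a hypothetical kernel type implementing the old `StaticKernel` trait; everything else is the API being removed here.

// Upload data, compile the cached pipeline, dispatch, then read back.
fn double_on_gpu(context: &Context, values: &[f32]) -> Vec<u8> {
    let input = context.create_buffer_with_data(bytemuck::cast_slice(values));
    let output = context.create_buffer(values.len() * core::mem::size_of::<f32>());

    let pipeline = context.compile_static::<Double>();
    // One work group per element along x; bindings follow buffer positions.
    context.execute(
        WorkGroup::new(values.len() as u32, 1, 1),
        pipeline,
        &[&input, &output],
    );

    context.read_buffer(output)
}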

impl PartialEq for Context {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

pub(crate) async fn select_device<G: GraphicsApi>(
    device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
    let adapter = select_adapter::<G>(device);
    let limits = adapter.limits();

    let (device, queue) = adapter
        .request_device(
            &DeviceDescriptor {
                label: None,
                features: wgpu::Features::empty(),
                limits,
            },
            None,
        )
        .await
        .map_err(|err| {
            format!(
                "Unable to request the device with the adapter {:?}, err {:?}",
                adapter.get_info(),
                err
            )
        })
        .unwrap();

    (device, queue, adapter.get_info())
}

fn select_adapter<G: GraphicsApi>(device: &WgpuDevice) -> wgpu::Adapter {
    let instance = wgpu::Instance::default();

    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(G::backend().into())
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            if let DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == DeviceType::Cpu,
                WgpuDevice::BestAvailable => true,
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    fn select(
        num: usize,
        error: &str,
        mut adapters: Vec<wgpu::Adapter>,
        mut adapters_other: Vec<wgpu::Adapter>,
    ) -> wgpu::Adapter {
        if adapters.len() <= num {
            if adapters_other.len() <= num {
                panic!(
                    "{}, adapters {:?}, other adapters {:?}",
                    error,
                    adapters
                        .into_iter()
                        .map(|adapter| adapter.get_info())
                        .collect::<Vec<_>>(),
                    adapters_other
                        .into_iter()
                        .map(|adapter| adapter.get_info())
                        .collect::<Vec<_>>(),
                );
            } else {
                return adapters_other.remove(num);
            }
        }

        adapters.remove(num)
    }

    let adapter = match device {
        WgpuDevice::DiscreteGpu(num) => select(
            *num,
            "No Discrete GPU device found",
            adapters,
            adapters_other,
        ),
        WgpuDevice::IntegratedGpu(num) => select(
            *num,
            "No Integrated GPU device found",
            adapters,
            adapters_other,
        ),
        WgpuDevice::VirtualGpu(num) => select(
            *num,
            "No Virtual GPU device found",
            adapters,
            adapters_other,
        ),
        WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other),
        WgpuDevice::BestAvailable => {
            let mut most_performant_adapter = None;
            let mut current_score = -1;

            adapters.into_iter().for_each(|adapter| {
                let info = adapter.get_info();
                let score = match info.device_type {
                    DeviceType::DiscreteGpu => 5,
                    // Let's be optimistic with the Other device; it's often a discrete GPU.
                    DeviceType::Other => 4,
                    DeviceType::IntegratedGpu => 3,
                    DeviceType::VirtualGpu => 2,
                    DeviceType::Cpu => 1,
                };

                if score > current_score {
                    most_performant_adapter = Some(adapter);
                    current_score = score;
                }
            });

            if let Some(adapter) = most_performant_adapter {
                adapter
            } else {
                panic!("No adapter found for graphics API {:?}", G::default());
            }
        }
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    adapter
}
@@ -1,190 +0,0 @@
use super::WorkGroup;
use std::sync::Arc;
use wgpu::{BindGroup, Buffer, ComputePipeline};

#[cfg(feature = "async")]
pub use async_client::AsyncContextClient;
#[cfg(not(feature = "async"))]
pub use sync_client::SyncContextClient;

/// The context client allows speaking with a server to execute tasks on the GPU.
pub trait ContextClient {
    /// Copy the source buffer content into the destination buffer.
    ///
    /// # Notes
    ///
    /// Make sure the source buffer isn't used afterward, since a race condition may happen.
    ///
    /// If the source buffer is still used afterward, use [tensor copy](crate::tensor::WgpuTensor::copy)
    /// instead. This method is still useful to load data from the CPU into a new buffer.
    fn copy_buffer(
        &self,
        buffer_src: Arc<Buffer>,
        buffer_dest: Arc<Buffer>,
        wait_for_registered: bool,
    ) -> Arc<Buffer>;
    /// Read a [buffer](Buffer).
    ///
    /// # Notes
    ///
    /// All pending compute tasks will be executed.
    fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8>;
    /// Register a new computing task.
    fn register_compute(
        &self,
        bind_group: BindGroup,
        pipeline: Arc<ComputePipeline>,
        work_group: WorkGroup,
    );
    /// Wait for all computation to be done.
    ///
    /// Useful for benchmarks.
    fn sync(&self);
}
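Either implementation can be driven through the trait alone; a minimal sketch (the helper name is illustrative):

// Copy a buffer and read the duplicate back, waiting for registration so
// the destination could also be reused for in-place work afterward.
fn duplicate_and_read<C: ContextClient>(
    client: &C,
    src: Arc<Buffer>,
    dest: Arc<Buffer>,
) -> Vec<u8> {
    let dest = client.copy_buffer(src, dest, true);
    client.read_buffer(dest) // runs all pending tasks first
}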

#[cfg(feature = "async")]
mod async_client {
    use super::ContextClient;
    use crate::context::{
        server::{ComputeTask, ContextTask, CopyBufferTask, ReadBufferTask},
        WorkGroup,
    };
    use std::sync::{mpsc, Arc};
    use wgpu::{BindGroup, Buffer, ComputePipeline};

    /// Client returned by the asynchronous context server.
    #[derive(new, Debug)]
    pub struct AsyncContextClient {
        sender: mpsc::SyncSender<ContextTask>,
        _server_handle: std::thread::JoinHandle<()>,
    }

    impl ContextClient for AsyncContextClient {
        fn sync(&self) {
            let (sender, receiver) = std::sync::mpsc::channel();

            self.sender.send(ContextTask::Sync(sender)).unwrap();

            if receiver.iter().next().is_some() {
                log::debug!("Sync completed");
            } else {
                panic!("Unable to sync")
            }
        }

        fn copy_buffer(
            &self,
            buffer_src: Arc<Buffer>,
            buffer_dest: Arc<Buffer>,
            wait_for_registered: bool,
        ) -> Arc<Buffer> {
            if wait_for_registered {
                assert_eq!(Arc::strong_count(&buffer_dest), 1, "You can't wait for the buffer to be registered when multiple references already exist.");
            }

            self.sender
                .send(CopyBufferTask::new(buffer_src, buffer_dest.clone()).into())
                .unwrap();

            if !wait_for_registered {
                return buffer_dest;
            }

            // Wait for the buffer to be correctly registered so that in-place operations can be
            // prioritized.
            //
            // Note that this is unsafe and a channel could have been used to wait for completion.
            // The loop is there for performance reasons.
            //
            // TODO: Use a performant one-time channel here as callback instead.
            loop {
                std::thread::sleep(std::time::Duration::from_micros(1));

                if Arc::strong_count(&buffer_dest) == 1 {
                    return buffer_dest;
                }
            }
        }

        fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
            let (sender, receiver) = std::sync::mpsc::channel();

            self.sender
                .send(ReadBufferTask::new(buffer, sender).into())
                .unwrap();

            let mut iter = receiver.iter();
            if let Some(data) = iter.next() {
                data
            } else {
                panic!("Unable to read buffer")
            }
        }

        fn register_compute(
            &self,
            bind_group: BindGroup,
            pipeline: Arc<ComputePipeline>,
            work_group: WorkGroup,
        ) {
            self.sender
                .send(ComputeTask::new(bind_group, pipeline, work_group).into())
                .unwrap();
        }
    }
}

#[cfg(not(feature = "async"))]
mod sync_client {
    use super::ContextClient;
    use crate::context::{
        server::{ComputeTask, SyncContextServer},
        WorkGroup,
    };
    use std::sync::Arc;
    use wgpu::{BindGroup, Buffer, ComputePipeline};

    #[derive(Debug)]
    pub struct SyncContextClient {
        server: spin::Mutex<SyncContextServer>,
    }

    impl SyncContextClient {
        pub fn new(server: SyncContextServer) -> Self {
            Self {
                server: spin::Mutex::new(server),
            }
        }
    }

    impl ContextClient for SyncContextClient {
        fn sync(&self) {
            let mut server = self.server.lock();
            server.sync();
        }

        fn copy_buffer(
            &self,
            buffer_src: Arc<Buffer>,
            buffer_dest: Arc<Buffer>,
            _wait_for_registered: bool, // Ignored when sync
        ) -> Arc<Buffer> {
            let mut server = self.server.lock();
            server.buffer_to_buffer(buffer_src, buffer_dest.clone());

            buffer_dest
        }

        fn read_buffer(&self, buffer: Arc<Buffer>) -> Vec<u8> {
            let mut server = self.server.lock();
            server.read_buffer(&buffer)
        }

        fn register_compute(
            &self,
            bind_group: BindGroup,
            pipeline: Arc<ComputePipeline>,
            work_group: WorkGroup,
        ) {
            let mut server = self.server.lock();
            server.register_compute(ComputeTask::new(bind_group, pipeline, work_group));
        }
    }
}
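The `Arc` strong-count handshake used by `copy_buffer` above can be illustrated in isolation; this is a self-contained sketch, not code from this crate:

use std::sync::Arc;

// Once every clone handed to the server thread is dropped, the strong
// count falls back to 1 and the caller knows registration finished.
fn wait_until_sole_owner<T>(value: &Arc<T>) {
    while Arc::strong_count(value) != 1 {
        std::thread::sleep(std::time::Duration::from_micros(1));
    }
}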
@@ -1,6 +0,0 @@
pub(super) mod client;
pub(super) mod server;

mod base;

pub use base::*;
@@ -1,269 +0,0 @@
use super::{client::ContextClient, WorkGroup};
use std::sync::Arc;
use wgpu::{BindGroup, Buffer, CommandEncoder, ComputePipeline};

#[cfg(feature = "async")]
pub use async_server::{AsyncContextServer, ContextTask, CopyBufferTask, ReadBufferTask};

/// A context server allows running tasks on the GPU.
///
/// # Notes
///
/// There are two implementations of this trait. One is a bit more performant while the other
/// doesn't require std.
///
/// * [Asynchronous server](AsyncContextServer).
/// * [Synchronous server](SyncContextServer).
pub trait ContextServer {
    /// The client where tasks can be sent to the server for execution.
    type Client: ContextClient;

    /// Start the server and return its [client](ContextClient).
    fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client;
}

/// Context server where each operation is added in a synchronous manner.
#[derive(Debug)]
pub struct SyncContextServer {
    device: Arc<wgpu::Device>,
    queue: wgpu::Queue,
    encoder: CommandEncoder,
    tasks: Vec<ComputeTask>,
    max_tasks: usize,
}

/// Basic building block to execute computing tasks on the GPU.
#[derive(new, Debug)]
pub struct ComputeTask {
    bind_group: BindGroup,
    pipeline: Arc<ComputePipeline>,
    work_group: WorkGroup,
}

/// Most of the functions are similar to the [client](ContextClient) ones.
///
/// The main difference is that these functions take `&mut self` instead of `&self`, which
/// requires a lock from the sync client or an async channel with the async client/server.
impl SyncContextServer {
    /// Create a new sync context server.
    pub fn new(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self {
        let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Command Encoder"),
        });

        // TODO: Support a way to modify this value without std.
        let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
            Ok(value) => value
                .parse::<usize>()
                .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
            Err(_) => 16, // 16 tasks by default
        };

        Self {
            device,
            queue,
            encoder,
            tasks: Vec::new(),
            max_tasks,
        }
    }
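The `BURN_WGPU_MAX_TASKS` threshold can be overridden from the environment before the server starts; a hypothetical example, where the value 32 is arbitrary:

fn configure_batching() {
    // Batch up to 32 compute tasks per submission instead of the default 16.
    std::env::set_var("BURN_WGPU_MAX_TASKS", "32");
}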

    pub fn register_compute(&mut self, task: ComputeTask) {
        self.tasks.push(task);

        if self.tasks.len() > self.max_tasks {
            self.register_tasks();
            self.submit();
        }
    }

    pub fn read_buffer(&mut self, buffer: &Buffer) -> Vec<u8> {
        // Register previous tasks before reading the buffer so that it is up to date.
        self.register_tasks();

        let size = buffer.size();
        let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        self.encoder
            .copy_buffer_to_buffer(buffer, 0, &buffer_dest, 0, size);

        self.submit();

        let buffer_slice = buffer_dest.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
            sender
                .send(v)
                .expect("Unable to send buffer slice result to async channel.")
        });

        self.device.poll(wgpu::Maintain::Wait);

        let result = pollster::block_on(receiver.receive());

        if let Some(Ok(())) = result {
            let data = buffer_slice.get_mapped_range();
            let result = bytemuck::cast_slice(&data).to_vec();

            drop(data);
            buffer_dest.unmap();
            result
        } else {
            panic!("Unable to read buffer {:?}", result)
        }
    }

    pub fn sync(&mut self) {
        if !self.tasks.is_empty() {
            self.register_tasks();
            self.submit();
        }

        self.device.poll(wgpu::Maintain::Wait);
    }

    pub fn buffer_to_buffer(&mut self, buffer_src: Arc<Buffer>, buffer_dest: Arc<Buffer>) {
        self.encoder
            .copy_buffer_to_buffer(&buffer_src, 0, &buffer_dest, 0, buffer_src.size());
    }

    fn register_tasks(&mut self) {
        let mut compute = self
            .encoder
            .begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
        for task in self.tasks.iter() {
            compute.set_pipeline(&task.pipeline);
            compute.set_bind_group(0, &task.bind_group, &[]);
            compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z);
        }
        std::mem::drop(compute);
        self.tasks.clear();
    }

    fn submit(&mut self) {
        assert!(
            self.tasks.is_empty(),
            "Tasks should be completed before submitting the current encoder."
        );
        let mut new_encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        core::mem::swap(&mut new_encoder, &mut self.encoder);

        self.queue.submit(Some(new_encoder.finish()));
    }
}

#[cfg(feature = "async")]
mod async_server {
    use crate::context::client::AsyncContextClient;

    use super::{ComputeTask, ContextServer, SyncContextServer};
    use std::sync::{mpsc, Arc};
    use wgpu::Buffer;

    #[derive(new)]
    pub struct ReadBufferTask {
        buffer: Arc<Buffer>,
        sender: mpsc::Sender<Vec<u8>>,
    }

    #[derive(new)]
    pub struct CopyBufferTask {
        pub(crate) buffer_src: Arc<Buffer>,
        pub(crate) buffer_dest: Arc<Buffer>,
    }

    pub enum ContextTask {
        Compute(ComputeTask),
        ReadBuffer(ReadBufferTask),
        CopyBuffer(CopyBufferTask),
        Sync(mpsc::Sender<()>),
    }

    impl From<ComputeTask> for ContextTask {
        fn from(val: ComputeTask) -> Self {
            ContextTask::Compute(val)
        }
    }

    impl From<ReadBufferTask> for ContextTask {
        fn from(val: ReadBufferTask) -> Self {
            ContextTask::ReadBuffer(val)
        }
    }

    impl From<CopyBufferTask> for ContextTask {
        fn from(val: CopyBufferTask) -> Self {
            ContextTask::CopyBuffer(val)
        }
    }

    /// Asynchronous context server where [tasks](ContextTask) are sent using a channel.
    ///
    /// # Notes
    ///
    /// This is pretty useful to avoid blocking the main thread when registering and
    /// executing [compute tasks](ComputeTask).
    pub struct AsyncContextServer {
        server: SyncContextServer,
        receiver: mpsc::Receiver<ContextTask>,
    }

    impl AsyncContextServer {
        fn run(mut self) {
            loop {
                let task = self.receiver.recv().unwrap();
                match task {
                    ContextTask::Compute(task) => self.server.register_compute(task),
                    ContextTask::CopyBuffer(task) => self
                        .server
                        .buffer_to_buffer(task.buffer_src, task.buffer_dest),
                    ContextTask::ReadBuffer(task) => {
                        let bytes = self.server.read_buffer(&task.buffer);
                        task.sender.send(bytes).unwrap();
                    }
                    ContextTask::Sync(callback) => {
                        self.server.sync();
                        callback.send(()).unwrap();
                    }
                };
            }
        }
    }

    impl ContextServer for AsyncContextServer {
        type Client = AsyncContextClient;

        fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client {
            let (sender, receiver) = std::sync::mpsc::sync_channel(50);
            let server = SyncContextServer::new(device, queue);
            let context = Self { server, receiver };

            let handle = std::thread::spawn(|| context.run());

            AsyncContextClient::new(sender, handle)
        }
    }
}
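Thanks to the `From` impls above, call sites can enqueue any task type uniformly; a sketch (the helper name is illustrative):

// Enqueue a compute task on the server thread through the bounded channel.
fn enqueue(sender: &std::sync::mpsc::SyncSender<ContextTask>, task: ComputeTask) {
    sender.send(task.into()).expect("Context server thread died");
}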

#[cfg(not(feature = "async"))]
mod sync_server {
    use super::{ContextServer, SyncContextServer};
    use crate::context::client::SyncContextClient;
    use std::sync::Arc;

    impl ContextServer for SyncContextServer {
        type Client = SyncContextClient;

        fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client {
            let server = Self::new(device, queue);

            SyncContextClient::new(server)
        }
    }
}
@@ -1,17 +1,21 @@
use super::SourceTemplate;
use crate::{context::WorkGroup, element::WgpuElement, tensor::WgpuTensor};
use crate::{
    compute::{StaticKernel, WorkGroup},
    element::WgpuElement,
    tensor::WgpuTensor,
};
use std::marker::PhantomData;

/// Static wgpu kernel to create a [source template](SourceTemplate).
pub trait StaticKernel: Send + 'static {
pub trait StaticKernelSource: Send + 'static {
    /// Source template for the kernel.
    fn source_template() -> SourceTemplate;
    fn source() -> SourceTemplate;
}

/// Dynamic wgpu kernel to create a [source template](SourceTemplate).
pub trait DynamicKernel: Send {
pub trait DynamicKernelSource: Send {
    /// Source template for the kernel.
    fn source_template(self) -> SourceTemplate;
    fn source(self) -> SourceTemplate;
    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> String;
}
@@ -27,8 +31,8 @@ macro_rules! kernel_wgsl {
        #[derive(new)]
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::SourceTemplate::new(include_str!($file))
            }
        }
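After this rename, declaring a kernel source through the macro looks like the sketch below; `MyKernel` and the WGSL path are hypothetical:

// Declares `MyKernel` and implements the renamed `StaticKernelSource`.
kernel_wgsl!(MyKernel, "../template/my_kernel.wgsl");

fn my_kernel_source() -> SourceTemplate {
    // Was `MyKernel::source_template()` before this commit.
    MyKernel::source()
}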
@@ -48,23 +52,21 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(
    const WORKGROUP: usize = 32;

    let num_elems = tensor.shape.num_elements();
    let buffer = tensor
        .context
        .create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
    let handle = tensor.client.empty(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(
        tensor.client.clone(),
        tensor.device.clone(),
        tensor.shape.clone(),
        handle,
    );
    let info = build_info(&[&tensor, &output]);
    let info_buffer = tensor
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = tensor.client.create(bytemuck::cast_slice(&info));

    let kernel = tensor
        .context
        .compile_static::<KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();

    tensor.context.execute(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
        &[&tensor.buffer, &output.buffer, &info_buffer],
    tensor.client.execute(
        Box::new(StaticKernel::<
            KernelSettings<ContiguousRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
        >::new(elemwise_workgroup(num_elems, WORKGROUP))),
        &[&tensor.handle, &output.handle, &info_handle],
    );

    output
@@ -72,7 +74,7 @@ pub fn into_contiguous<E: WgpuElement, const D: usize>(

/// Generates kernel source code by replacing some information using templating.
pub struct KernelSettings<
    K: StaticKernel,
    K: StaticKernelSource,
    E: WgpuElement,
    I: WgpuElement,
    const WORKGROUP_X_SIZE: usize,
@@ -85,17 +87,17 @@ pub struct KernelSettings<
}

impl<
    K: StaticKernel,
    K: StaticKernelSource,
    E: WgpuElement,
    I: WgpuElement,
    const WORKGROUP_X_SIZE: usize,
    const WORKGROUP_Y_SIZE: usize,
    const WORKGROUP_Z_SIZE: usize,
> StaticKernel
> StaticKernelSource
    for KernelSettings<K, E, I, WORKGROUP_X_SIZE, WORKGROUP_Y_SIZE, WORKGROUP_Z_SIZE>
{
    fn source_template() -> SourceTemplate {
        K::source_template()
    fn source() -> SourceTemplate {
        K::source()
            .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string())
            .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string())
            .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string())
@@ -110,7 +112,7 @@ impl<

/// Generate kernel source code by replacing some information using templating.
#[derive(new)]
pub struct DynamicKernelSettings<K: StaticKernel, E: WgpuElement, I: WgpuElement> {
pub struct DynamicKernelSettings<K: StaticKernelSource, E: WgpuElement, I: WgpuElement> {
    workgroup_x_size: usize,
    workgroup_y_size: usize,
    workgroup_z_size: usize,
@@ -119,11 +121,11 @@ pub struct DynamicKernelSettings<K: StaticKernel, E: WgpuElement, I: WgpuElement
    _i: PhantomData<I>,
}

impl<K: StaticKernel, E: WgpuElement, I: WgpuElement> DynamicKernel
impl<K: StaticKernelSource, E: WgpuElement, I: WgpuElement> DynamicKernelSource
    for DynamicKernelSettings<K, E, I>
{
    fn source_template(self) -> SourceTemplate {
        K::source_template()
    fn source(self) -> SourceTemplate {
        K::source()
            .register("workgroup_size_x", self.workgroup_x_size.to_string())
            .register("workgroup_size_y", self.workgroup_y_size.to_string())
            .register("workgroup_size_z", self.workgroup_z_size.to_string())
@@ -1,4 +1,5 @@
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernel};
use super::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource};
use crate::compute::StaticKernel;
use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use burn_tensor::Shape;

@@ -17,9 +18,9 @@ macro_rules! binary_elemwise {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::BinaryElemwiseRaw::source_template().register(
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::BinaryElemwiseRaw::source().register(
                    "body",
                    format!("output[id] = lhs[index_lhs] {} rhs[index_rhs];", $ops),
                )

@@ -37,9 +38,9 @@ macro_rules! binary_elemwise_inplace {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::BinaryElemwiseInplaceRaw::source_template().register(
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::BinaryElemwiseInplaceRaw::source().register(
                    "body",
                    format!("lhs[id] = lhs[id] {} rhs[index_rhs];", $ops),
                )

@@ -49,7 +50,7 @@ macro_rules! binary_elemwise_inplace {
}

/// Execute a binary kernel using the default settings.
pub fn binary_elemwise_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn binary_elemwise_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {

@@ -57,7 +58,12 @@ pub fn binary_elemwise_default<K: StaticKernel, E: WgpuElement, const D: usize>(
}

/// Execute a binary kernel using the provided WORKGROUP.
pub fn binary_elemwise<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
pub fn binary_elemwise<
    K: StaticKernelSource,
    E: WgpuElement,
    const D: usize,
    const WORKGROUP: usize,
>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {

@@ -76,30 +82,26 @@ pub fn binary_elemwise<K: StaticKernel, E: WgpuElement, const D: usize, const WO
    let shape_out = Shape::new(shape_out);
    let num_elems = shape_out.num_elements();

    let buffer = lhs
        .context
        .create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
    let handle = lhs.client.empty(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, handle);

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
    let info = build_info(&[&lhs, &rhs, &output]);
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

    lhs.context.execute(
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
    );

    lhs.client.execute(
        Box::new(kernel),
        &[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
    );

    output
}

/// Execute a binary inplace kernel using the default settings.
pub fn binary_elemwise_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn binary_elemwise_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {

@@ -108,7 +110,7 @@ pub fn binary_elemwise_inplace_default<K: StaticKernel, E: WgpuElement, const D:
/// Execute a binary inplace kernel using the provided WORKGROUP.
pub fn binary_elemwise_inplace<
    K: StaticKernel,
    K: StaticKernelSource,
    E: WgpuElement,
    const D: usize,
    const WORKGROUP: usize,

@@ -118,20 +120,15 @@ pub fn binary_elemwise_inplace<
) -> WgpuTensor<E, D> {
    lhs.assert_is_on_same_device(&rhs);

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
    let info = build_info(&[&lhs, &rhs]);
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    lhs.context.execute(
    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs.buffer, &info_buffers],
    );

    lhs.client
        .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);

    lhs
}
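A hypothetical op built with the macro and launcher above; `AddOp` is illustrative, not a kernel added by this commit:

// Generates `output[id] = lhs[index_lhs] + rhs[index_rhs];` in the WGSL body.
binary_elemwise!(AddOp, "+");

fn add<E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
    // Dispatch with the default workgroup settings.
    binary_elemwise_default::<AddOp, E, D>(lhs, rhs)
}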
@@ -1,8 +1,10 @@
use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
use super::{KernelSettings, SourceTemplate, StaticKernelSource};
use crate::{
    compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
    tensor::WgpuTensor,
};
use std::{any::TypeId, marker::PhantomData};

use super::{KernelSettings, SourceTemplate, StaticKernel};

kernel_wgsl!(CastRaw, "../template/cast.wgsl");

struct Cast<InputElem: WgpuElement, OutputElem: WgpuElement> {

@@ -10,9 +12,11 @@ struct Cast<InputElem: WgpuElement, OutputElem: WgpuElement> {
    _o: PhantomData<OutputElem>,
}

impl<InputElem: WgpuElement, OutputElem: WgpuElement> StaticKernel for Cast<InputElem, OutputElem> {
    fn source_template() -> SourceTemplate {
        CastRaw::source_template()
impl<InputElem: WgpuElement, OutputElem: WgpuElement> StaticKernelSource
    for Cast<InputElem, OutputElem>
{
    fn source() -> SourceTemplate {
        CastRaw::source()
            .register("input_elem", InputElem::type_name())
            .register("output_elem", OutputElem::type_name())
    }

@@ -23,32 +27,30 @@ pub fn cast<InputElem: WgpuElement, OutputElem: WgpuElement, const D: usize>(
    tensor: WgpuTensor<InputElem, D>,
) -> WgpuTensor<OutputElem, D> {
    if TypeId::of::<InputElem>() == TypeId::of::<OutputElem>() {
        return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer);
        return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
    }

    const WORKGROUP: usize = 32;

    let num_elems = tensor.shape.num_elements();
    let kernel = tensor.context.compile_static::<KernelSettings<
        Cast<InputElem, OutputElem>,
        f32,
        i32,
        WORKGROUP,
        WORKGROUP,
        1,
    >>();
    let kernel = StaticKernel::<
        KernelSettings<Cast<InputElem, OutputElem>, f32, i32, WORKGROUP, WORKGROUP, 1>,
    >::new(elemwise_workgroup(num_elems, WORKGROUP));

    let buffer = tensor
        .context
        .create_buffer(num_elems * core::mem::size_of::<OutputElem>());
    let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);

    tensor.context.execute(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
        &[&tensor.buffer, &output.buffer],
    let handle = tensor
        .client
        .empty(num_elems * core::mem::size_of::<OutputElem>());
    let output = WgpuTensor::new(
        tensor.client.clone(),
        tensor.device,
        tensor.shape.clone(),
        handle,
    );

    tensor
        .client
        .execute(Box::new(kernel), &[&tensor.handle, &output.handle]);

    output
}
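Usage is unchanged apart from the internals; a hypothetical call converting a 2-D tensor's elements (`float_tensor` is illustrative):

fn to_int_tensor(float_tensor: WgpuTensor<f32, 2>) -> WgpuTensor<i32, 2> {
    // Same-type casts return the tensor as-is; otherwise a cast kernel runs.
    cast::<f32, i32, 2>(float_tensor)
}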
@@ -1,4 +1,5 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,

@@ -14,16 +15,20 @@ pub fn cat<E: WgpuElement, const D: usize>(
    const WORKGROUP: usize = 32;

    let first_input = inputs.get(0).unwrap();
    let context = &first_input.context;
    let client = &first_input.client;
    let mut shape_output = first_input.shape.clone();
    shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum();

    let buffer = first_input
        .context
        .create_buffer(shape_output.num_elements() * std::mem::size_of::<E>());
        .client
        .empty(shape_output.num_elements() * std::mem::size_of::<E>());

    let output = WgpuTensor::new(context.clone(), shape_output, buffer);
    let kernel = context.compile_static::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>();
    let output = WgpuTensor::new(
        client.clone(),
        first_input.device.clone(),
        shape_output,
        buffer,
    );

    let mut dim_cat_index = 0;

@@ -32,12 +37,14 @@ pub fn cat<E: WgpuElement, const D: usize>(
        info.push(dim as u32);
        info.push(dim_cat_index as u32);
        dim_cat_index += input.shape.dims[dim];
        let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));

        context.execute(
        let info_buffer = client.create(bytemuck::cast_slice(&info));
        let kernel = StaticKernel::<KernelSettings<Cat, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
            elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
            kernel.clone(),
            &[&input.buffer, &output.buffer, &info_buffer],
        );

        client.execute(
            Box::new(kernel),
            &[&input.handle, &output.handle, &info_buffer],
        );
    }
@@ -1,7 +1,9 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernel},
    kernel::{build_info, elemwise_workgroup, KernelSettings, StaticKernelSource},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};
use burn_tensor::Shape;

@@ -21,9 +23,9 @@ macro_rules! comparison {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonRaw::source_template().register(
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonRaw::source().register(
                    "body",
                    format!("output[id] = u32(lhs[index_lhs] {} rhs[index_rhs]);", $ops),
                )

@@ -41,9 +43,9 @@ macro_rules! comparison_inplace {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonInplaceRaw::source_template()
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonInplaceRaw::source()
                    .register(
                        "body",
                        "lhs[index_lhs] = compare(lhs[index_lhs], rhs[index_rhs]);",

@@ -59,7 +61,7 @@ macro_rules! comparison_inplace {
    };
}

pub fn comparison<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {

@@ -79,29 +81,23 @@ pub fn comparison<K: StaticKernel, E: WgpuElement, const D: usize>(
    let shape_out = Shape::new(shape_out);
    let num_elems = shape_out.num_elements();

    let buffer = lhs
        .context
        .create_buffer(num_elems * core::mem::size_of::<u32>());
    let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
    let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
    let info = build_info(&[&lhs, &rhs, &output]);
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    lhs.context.execute(
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
    );
    let info = build_info(&[&lhs, &rhs, &output]);
    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

    lhs.client.execute(
        Box::new(kernel),
        &[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
    );

    WgpuTensor::new(output.context, output.shape, output.buffer)
    WgpuTensor::new(output.client, output.device, output.shape, output.handle)
}

pub fn comparison_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {

@@ -109,21 +105,16 @@ pub fn comparison_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
    lhs.assert_is_on_same_device(&rhs);

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
    let info = build_info(&[&lhs, &rhs]);
    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    lhs.context.execute(
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs.buffer, &info_buffers],
    );
    let info = build_info(&[&lhs, &rhs]);
    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

    WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer)
    lhs.client
        .execute(Box::new(kernel), &[&lhs.handle, &rhs.handle, &info_handle]);

    WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
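A hypothetical comparison defined with the macro above (`EqualOp` is illustrative):

// Generates `output[id] = u32(lhs[index_lhs] == rhs[index_rhs]);`.
comparison!(EqualOp, "==");

fn equal<E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<u32, D> {
    // Writes 0/1 flags into a freshly allocated u32 tensor.
    comparison::<EqualOp, E, D>(lhs, rhs)
}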

#[cfg(test)]
@@ -1,6 +1,7 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{elemwise_workgroup, KernelSettings, StaticKernel},
    kernel::{elemwise_workgroup, KernelSettings, StaticKernelSource},
    kernel_wgsl,
    tensor::WgpuTensor,
};

@@ -20,9 +21,9 @@ macro_rules! comparison_elem {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonElemRaw::source_template()
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonElemRaw::source()
                    .register("body", format!("output[id] = u32(lhs[id] {} rhs);", $ops))
            }
        }

@@ -38,9 +39,9 @@ macro_rules! comparison_elem_inplace {
    ) => {
        pub struct $struct;

        impl $crate::kernel::StaticKernel for $struct {
            fn source_template() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonElemInplaceRaw::source_template()
        impl $crate::kernel::StaticKernelSource for $struct {
            fn source() -> $crate::kernel::SourceTemplate {
                $crate::kernel::ComparisonElemInplaceRaw::source()
                    .register("body", "lhs[id] = compare(lhs[id], rhs);")
                    .add_template(format!(
                        "{}return {{{{ elem }}}}(lhs {} rhs);{}",

@@ -53,48 +54,39 @@ macro_rules! comparison_elem_inplace {
    };
}

pub fn comparison_elem<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_elem<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: E,
) -> WgpuTensor<u32, D> {
    const WORKGROUP: usize = 32;
    let num_elems = lhs.shape.num_elements();

    let buffer = lhs
        .context
        .create_buffer(num_elems * core::mem::size_of::<u32>());
    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs]));
    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();

    lhs.context.execute(
    let handle = lhs.client.empty(num_elems * core::mem::size_of::<u32>());
    let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs_buffer, &buffer],
    );

    WgpuTensor::new(lhs.context, lhs.shape, buffer)
    lhs.client
        .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle, &handle]);

    WgpuTensor::new(lhs.client, lhs.device, lhs.shape, handle)
}

pub fn comparison_elem_inplace<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn comparison_elem_inplace<K: StaticKernelSource, E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: E,
) -> WgpuTensor<u32, D> {
    const WORKGROUP: usize = 32;

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();

    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[rhs]));
    lhs.context.execute(
    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(lhs.shape.num_elements(), WORKGROUP),
        kernel,
        &[&lhs.buffer, &rhs_buffer],
    );
    let rhs_handle = lhs.client.create(E::as_bytes(&[rhs]));
    lhs.client
        .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);

    WgpuTensor::new(lhs.context, lhs.shape, lhs.buffer)
    WgpuTensor::new(lhs.client, lhs.device, lhs.shape, lhs.handle)
}
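The scalar variant follows the same pattern; a hypothetical op (`LowerElemOp` is illustrative):

// Generates `output[id] = u32(lhs[id] < rhs);` against a single scalar.
comparison_elem!(LowerElemOp, "<");

fn lower_elem<E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: E,
) -> WgpuTensor<u32, D> {
    comparison_elem::<LowerElemOp, E, D>(lhs, rhs)
}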

#[cfg(test)]
@@ -1,17 +1,19 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{self, build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};
use burn_tensor::{
    ops::{conv::calculate_conv_output_size, ConvOptions},
    Shape,
    Element, ElementConversion, Shape,
};

kernel_wgsl!(Conv2d, "../../template/conv/conv2d.wgsl");

pub(crate) fn conv2d<E: WgpuElement>(
pub(crate) fn conv2d<E: WgpuElement + Element>(
    input: WgpuTensor<E, 4>,
    weight: WgpuTensor<E, 4>,
    bias: Option<WgpuTensor<E, 1>>,

@@ -40,12 +42,12 @@ pub(crate) fn conv2d<E: WgpuElement>(
    );

    let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]);
    let num_elems = shape_out.num_elements();

    let buffer = input
        .context
        .create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
    let output = empty_device(
        input.client.clone(),
        input.device.clone(),
        shape_out.clone(),
    );

    let mut info = build_info(&[&input, &output, &weight]);
    info.push(options.stride[0] as u32);

@@ -56,27 +58,24 @@ pub(crate) fn conv2d<E: WgpuElement>(
    info.push(options.dilation[1] as u32);
    info.push(options.groups as u32);

    let bias_buffer = bias
        .map(|bias| bias.buffer)
        .unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::<E>()));
    let bias_handle = bias
        .map(|bias| bias.handle)
        .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));

    let info_buffer = input
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = input.client.create(bytemuck::cast_slice(&info));

    let kernel = input
        .context
        .compile_static::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>();

    input.context.execute(
    let kernel = StaticKernel::<KernelSettings<Conv2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
        kernel,
    );

    input.client.execute(
        Box::new(kernel),
        &[
            &input.buffer,
            &weight.buffer,
            &bias_buffer,
            &output.buffer,
            &info_buffer,
            &input.handle,
            &weight.handle,
            &bias_handle,
            &output.handle,
            &info_handle,
        ],
    );
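For reference on the shape computation feeding `shape_out` above, `calculate_conv_output_size` performs the standard convolution arithmetic; a sketch, not the crate's exact implementation:

// out = (in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1
fn conv_output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

For example, a 3x3 kernel with stride 1, padding 1, and dilation 1 maps a 32-pixel axis to (32 + 2 - 2 - 1) / 1 + 1 = 32, preserving the size.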
@@ -1,14 +1,16 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{self, build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};
use burn_tensor::{ops::ConvTransposeOptions, Shape};
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};

kernel_wgsl!(ConvTranspose2d, "../../template/conv/conv_transpose2d.wgsl");

pub(crate) fn conv_transpose2d<E: WgpuElement>(
pub(crate) fn conv_transpose2d<E: WgpuElement + Element>(
    input: WgpuTensor<E, 4>,
    weight: WgpuTensor<E, 4>,
    bias: Option<WgpuTensor<E, 1>>,

@@ -35,12 +37,13 @@ pub(crate) fn conv_transpose2d<E: WgpuElement>(
    let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);
    let num_elems = shape_out.num_elements();

    let buffer = input
        .context
        .create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);

    let output = empty_device(
        input.client.clone(),
        input.device.clone(),
        shape_out.clone(),
    );
    let mut info = build_info(&[&input, &output, &weight]);

    info.push(options.stride[0] as u32);
    info.push(options.stride[1] as u32);
    info.push(options.padding[0] as u32);

@@ -49,28 +52,24 @@ pub(crate) fn conv_transpose2d<E: WgpuElement>(
    info.push(options.dilation[1] as u32);
    info.push(options.groups as u32);

    let bias_buffer = bias
        .map(|bias| bias.buffer)
        .unwrap_or_else(|| input.context.create_buffer(core::mem::size_of::<E>()));
    let bias_handle = bias
        .map(|bias| bias.handle)
        .unwrap_or_else(|| input.client.create(E::as_bytes(&[0.elem()])));

    let info_buffer = input
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = input.client.create(bytemuck::cast_slice(&info));

    let kernel = input
        .context
        .compile_static::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>();

    let workgroup = elemwise_workgroup(num_elems, WORKGROUP);
    input.context.execute(
        workgroup,
        kernel,
    let kernel =
        StaticKernel::<KernelSettings<ConvTranspose2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
            elemwise_workgroup(num_elems, WORKGROUP),
        );
    input.client.execute(
        Box::new(kernel),
        &[
            &input.buffer,
            &weight.buffer,
            &bias_buffer,
            &output.buffer,
            &info_buffer,
            &input.handle,
            &weight.handle,
            &bias_handle,
            &output.handle,
            &info_handle,
        ],
    );
@@ -1,7 +1,9 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{self, build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};

@@ -17,29 +19,23 @@ pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
    let shape_output = indices.shape.clone();
    let num_elems = shape_output.num_elements();
    let indices = kernel::into_contiguous(indices);
    let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);

    let buffer = tensor
        .context
        .create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
    let mut info = build_info(&[&tensor, &output]);
    info.push(dim as u32);
    let info_buffer = tensor
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = tensor.client.create(bytemuck::cast_slice(&info));

    let kernel = tensor
        .context
        .compile_static::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>();

    tensor.context.execute(
    let kernel = StaticKernel::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
    );

    tensor.client.execute(
        Box::new(kernel),
        &[
            &tensor.buffer,
            &indices.buffer,
            &output.buffer,
            &info_buffer,
            &tensor.handle,
            &indices.handle,
            &output.handle,
            &info_handle,
        ],
    );
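Semantically, gather indexes the input along `dim`; the 1-D case reduces to the CPU reference below (illustrative, not code from this crate):

// output[i] = input[indices[i]]
fn gather_1d<E: Copy>(input: &[E], indices: &[usize]) -> Vec<E> {
    indices.iter().map(|&i| input[i]).collect()
}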
@@ -1,4 +1,5 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{self, build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,

@@ -48,18 +49,15 @@ pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
    info.push(dim as u32);

    let info_buffer = tensor
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = tensor.client.create(bytemuck::cast_slice(&info));

    let kernel = tensor
        .context
        .compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>();

    tensor.context.execute(
    let kernel = StaticKernel::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
        kernel,
        &[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
    );

    tensor.client.execute(
        Box::new(kernel),
        &[&tensor.handle, &indices.handle, &value.handle, &info_handle],
    );

    tensor
@@ -1,7 +1,9 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{build_info, elemwise_workgroup, KernelSettings},
    kernel_wgsl,
    ops::numeric::empty_device,
    tensor::WgpuTensor,
};

@@ -20,32 +22,25 @@ pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
    let mut output_shape = tensor.shape.clone();
    output_shape.dims[dim] = indices.shape.dims[0];
    let num_elems = output_shape.num_elements();

    let buffer = tensor
        .context
        .create_buffer(num_elems * std::mem::size_of::<E>());
    let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
    let num_elems = output_shape.num_elements();
    let output = empty_device(tensor.client.clone(), tensor.device.clone(), output_shape);

    let mut info = build_info(&[&tensor, &output]);
    info.push(dim as u32);

    let info_buffer = tensor
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

    let kernel = tensor
        .context
        .compile_static::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>();

    tensor.context.execute(
    let info_handle = output.client.create(bytemuck::cast_slice(&info));
    let kernel = StaticKernel::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>::new(
        elemwise_workgroup(num_elems, WORKGROUP),
        kernel,
    );

    tensor.client.execute(
        Box::new(kernel),
        &[
            &tensor.buffer,
            &indices.buffer,
            &output.buffer,
            &info_buffer,
            &tensor.handle,
            &indices.handle,
            &output.handle,
            &info_handle,
        ],
    );

@@ -89,18 +84,16 @@ pub(crate) fn select_assign<E: WgpuElement, I: WgpuElement, const D: usize>(
    info.push(dim as u32);

    let info_buffer = tensor
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));
    let info_handle = tensor.client.create(bytemuck::cast_slice(&info));

    let kernel = tensor
        .context
        .compile_static::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>();
    let kernel =
        StaticKernel::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>::new(
            elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
        );

    tensor.context.execute(
        elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
        kernel,
        &[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
    tensor.client.execute(
        Box::new(kernel),
        &[&tensor.handle, &indices.handle, &value.handle, &info_handle],
    );

    tensor
|
||||
|
|
|
@@ -1,7 +1,9 @@
 use crate::{
+    compute::StaticKernel,
     element::WgpuElement,
     kernel::{build_info, elemwise_workgroup, KernelSettings},
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
 use burn_tensor::Shape;
@@ -26,10 +28,7 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
     let shape_output = Shape::new(dims);
     let num_elems = shape_output.num_elements();
 
-    let buffer = tensor
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
+    let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
     let mut info = build_info(&[&tensor, &output]);
 
     for i in 0..D1 {
@@ -37,18 +36,15 @@ pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
         info.push(start as u32);
     }
 
-    let info_buffer = tensor
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_handle = output.client.create(bytemuck::cast_slice(&info));
 
-    let kernel = tensor
-        .context
-        .compile_static::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
-    tensor.context.execute(
+    let kernel = StaticKernel::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&tensor.buffer, &output.buffer, &info_buffer],
     );
 
+    tensor.client.execute(
+        Box::new(kernel),
+        &[&tensor.handle, &output.handle, &info_handle],
+    );
+
     output
@@ -73,23 +69,15 @@ pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
         info.push(start as u32);
     }
 
-    let info_buffer = tensor
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_handle = tensor.client.create(bytemuck::cast_slice(&info));
 
-    let kernel = tensor.context.compile_static::<KernelSettings<
-        IndexAssignInplaceRaw,
-        E,
-        i32,
-        WORKGROUP,
-        WORKGROUP,
-        1,
-    >>();
+    let kernel = StaticKernel::<
+        KernelSettings<IndexAssignInplaceRaw, E, i32, WORKGROUP, WORKGROUP, 1>,
+    >::new(elemwise_workgroup(num_elems, WORKGROUP));
 
-    tensor.context.execute(
-        elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&tensor.buffer, &value.buffer, &info_buffer],
+    tensor.client.execute(
+        Box::new(kernel),
+        &[&tensor.handle, &value.handle, &info_handle],
     );
 
     tensor
@@ -1,7 +1,9 @@
 use crate::{
+    compute::StaticKernel,
     element::WgpuElement,
     kernel::{build_info, elemwise_workgroup, KernelSettings},
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -16,30 +18,28 @@ pub fn mask_fill<E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
 
     let num_elems = input.shape.num_elements();
-    let buffer = input
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer);
+    let output = empty_device(
+        input.client.clone(),
+        input.device.clone(),
+        input.shape.clone(),
+    );
 
-    let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value]));
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
-    let info = build_info(&[&input, &mask, &output]);
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
-
-    input.context.execute(
+    let value_handle = output.client.create(E::as_bytes(&[value]));
+    let kernel = StaticKernel::<KernelSettings<MaskFill, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
     );
+    let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
+    let info = build_info(&[&input, &mask, &output]);
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
 
+    input.client.execute(
+        Box::new(kernel),
         &[
-            &input.buffer,
-            &value_buffer,
-            &mask.buffer,
-            &output.buffer,
-            &info_buffers,
+            &input.handle,
+            &value_handle,
+            &mask.handle,
+            &output.handle,
+            &info_handle,
         ],
     );
@@ -54,20 +54,18 @@ pub fn mask_fill_inplace<E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
 
     let num_elems = input.shape.num_elements();
-    let value_buffer = input.context.create_buffer_with_data(E::as_bytes(&[value]));
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
+    let value_handle = input.client.create(E::as_bytes(&[value]));
+    let kernel =
+        StaticKernel::<KernelSettings<MaskFillInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+            elemwise_workgroup(num_elems, WORKGROUP),
+        );
+    let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
     let info = build_info(&[&input, &mask]);
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
 
-    input.context.execute(
-        elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&input.buffer, &value_buffer, &mask.buffer, &info_buffers],
+    input.client.execute(
+        Box::new(kernel),
+        &[&input.handle, &value_handle, &mask.handle, &info_handle],
    );
 
     input
@@ -1,7 +1,9 @@
 use crate::{
+    compute::StaticKernel,
     element::WgpuElement,
     kernel::{build_info, elemwise_workgroup, KernelSettings},
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -16,29 +18,27 @@ pub fn mask_where<E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
 
     let num_elems = input.shape.num_elements();
-    let buffer = input
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(input.context.clone(), input.shape.clone(), buffer);
+    let output = empty_device(
+        input.client.clone(),
+        input.device.clone(),
+        input.shape.clone(),
+    );
 
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
-    let info = build_info(&[&input, &value, &mask, &output]);
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
-
-    input.context.execute(
+    let kernel = StaticKernel::<KernelSettings<MaskWhere, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
     );
+    let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
+    let info = build_info(&[&input, &value, &mask, &output]);
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
 
+    input.client.execute(
+        Box::new(kernel),
         &[
-            &input.buffer,
-            &value.buffer,
-            &mask.buffer,
-            &output.buffer,
-            &info_buffers,
+            &input.handle,
+            &value.handle,
+            &mask.handle,
+            &output.handle,
+            &info_handle,
         ],
     );
@@ -53,23 +53,21 @@ pub fn mask_where_inplace<E: WgpuElement, const D: usize>(
 ) -> WgpuTensor<E, D> {
     const WORKGROUP: usize = 32;
 
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let mask = WgpuTensor::new(mask.context, mask.shape, mask.buffer);
+    let kernel =
+        StaticKernel::<KernelSettings<MaskWhereInplace, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+            elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
+        );
+    let mask = WgpuTensor::new(mask.client, mask.device, mask.shape, mask.handle);
     let mut info = build_info(&[&input, &value, &mask]);
     info.push(match reverse {
         true => 1,
         false => 0,
     });
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
 
-    input.context.execute(
-        elemwise_workgroup(input.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&input.buffer, &value.buffer, &mask.buffer, &info_buffers],
+    input.client.execute(
+        Box::new(kernel),
+        &[&input.handle, &value.handle, &mask.handle, &info_handle],
     );
 
     input
@@ -2,10 +2,13 @@ use std::marker::PhantomData;
 
 use super::utils::shape_out;
 use crate::{
-    context::WorkGroup,
+    compute::{DynamicKernel, WorkGroup},
     element::WgpuElement,
-    kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel},
+    kernel::{
+        build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
+    },
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -21,9 +24,9 @@ struct MatmulMemCoalescing<E: WgpuElement> {
     _elem: PhantomData<E>,
 }
 
-impl<E: WgpuElement> DynamicKernel for MatmulMemCoalescing<E> {
-    fn source_template(self) -> SourceTemplate {
-        MatmulMemCoalescingRaw::source_template()
+impl<E: WgpuElement> DynamicKernelSource for MatmulMemCoalescing<E> {
+    fn source(self) -> SourceTemplate {
+        MatmulMemCoalescingRaw::source()
             .register("workgroup_size_x", self.workgroup_size_x.to_string())
             .register("workgroup_size_y", self.workgroup_size_y.to_string())
             .register("elem", E::type_name())
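The register calls above drive plain string substitution in the WGSL templates before compilation. A self-contained sketch of the mechanism, assuming {{name}}-style placeholders; the placeholder syntax is an assumption, not necessarily burn's:

use std::collections::HashMap;

/// Minimal stand-in for a kernel source template: placeholders such as
/// "{{workgroup_size_x}}" are replaced by registered values before compilation.
struct TemplateSketch {
    source: String,
    values: HashMap<String, String>,
}

impl TemplateSketch {
    fn new(source: &str) -> Self {
        Self { source: source.to_string(), values: HashMap::new() }
    }
    fn register(mut self, name: &str, value: String) -> Self {
        self.values.insert(name.to_string(), value);
        self
    }
    fn complete(self) -> String {
        let mut out = self.source;
        for (name, value) in self.values {
            out = out.replace(&format!("{{{{{}}}}}", name), &value);
        }
        out
    }
}

fn main() {
    let wgsl = TemplateSketch::new("@workgroup_size({{workgroup_size_x}}, {{workgroup_size_y}}, 1)")
        .register("workgroup_size_x", 16.to_string())
        .register("workgroup_size_y", 16.to_string())
        .complete();
    assert_eq!(wgsl, "@workgroup_size(16, 16, 1)");
}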
@@ -59,26 +62,11 @@ pub fn matmul_mem_coalescing<E: WgpuElement, const D: usize>(
     let num_rows = lhs.shape.dims[D - 2];
     let num_cols = rhs.shape.dims[D - 1];
 
-    let buffer = lhs
-        .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
+    let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
 
     // set number of workgroups
     let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
     let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
 
-    let kernel = lhs.context.compile_dynamic(MatmulMemCoalescing::<E>::new(
-        workgroup_size_x,
-        workgroup_size_y,
-    ));
-
-    let info = build_info(&[&lhs, &rhs, &output]);
-
-    let info_buffers = lhs
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
-
     let mut num_iter = 1;
     for i in 0..D - 2 {
         num_iter *= output.shape.dims[i];
@@ -86,10 +74,18 @@ pub fn matmul_mem_coalescing<E: WgpuElement, const D: usize>(
 
     let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
 
-    lhs.context.execute(
+    let kernel = DynamicKernel::new(
+        MatmulMemCoalescing::<E>::new(workgroup_size_x, workgroup_size_y),
         workgroup,
-        kernel,
-        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
+    );
+
+    let info = build_info(&[&lhs, &rhs, &output]);
+
+    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
+
+    lhs.client.execute(
+        Box::new(kernel),
+        &[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
     );
 
     output
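As a quick check of the dispatch math kept intact by this hunk: with 16x16 workgroups over a 1000x1000 output and a batch of 4, the grid is ceil(1000/16) = 63 blocks in each of x and y, with the batch in z. In code, with illustrative values:

// Worked example of the grid-sizing arithmetic used above.
fn main() {
    let (num_rows, num_cols, batch) = (1000usize, 1000usize, 4usize);
    let (wg_x, wg_y) = (16usize, 16usize);
    let blocks_x = f32::ceil(num_rows as f32 / wg_x as f32) as u32;
    let blocks_y = f32::ceil(num_cols as f32 / wg_y as f32) as u32;
    assert_eq!((blocks_x, blocks_y), (63, 63));
    println!("WorkGroup::new({blocks_x}, {blocks_y}, {batch})");
}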
@@ -3,9 +3,7 @@ pub(crate) mod utils;
 mod mem_coalescing;
 mod naive;
 mod tiling2d;
-mod tune;
 
 pub use mem_coalescing::*;
 pub use naive::*;
 pub use tiling2d::*;
-pub use tune::*;
@@ -1,9 +1,10 @@
 use super::utils::shape_out;
 use crate::{
-    context::WorkGroup,
+    compute::{StaticKernel, WorkGroup},
     element::WgpuElement,
-    kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernel},
+    kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource},
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -11,11 +12,11 @@ kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl");
 
 struct MatmulNaive<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
 
-impl<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize> StaticKernel
+impl<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize> StaticKernelSource
     for MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
 {
-    fn source_template() -> SourceTemplate {
-        MatmulNaiveRaw::source_template()
+    fn source() -> SourceTemplate {
+        MatmulNaiveRaw::source()
             .register("block_size_m", WORKGROUP_SIZE_X.to_string())
             .register("block_size_n", WORKGROUP_SIZE_Y.to_string())
     }
@@ -49,41 +50,35 @@ pub fn matmul_naive<
     let num_rows = lhs.shape.dims[D - 2];
     let num_cols = rhs.shape.dims[D - 1];
 
-    let buffer = lhs
-        .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
+    let output = empty_device(lhs.client.clone(), lhs.device.clone(), shape_out);
 
     // set number of workgroups
     let blocks_needed_in_x = f32::ceil(num_rows as f32 / WORKGROUP_SIZE_X as f32) as u32;
     let blocks_needed_in_y = f32::ceil(num_cols as f32 / WORKGROUP_SIZE_Y as f32) as u32;
 
-    let kernel = lhs.context.compile_static::<KernelSettings<
-        MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
-        E,
-        i32,
-        WORKGROUP_SIZE_X,
-        WORKGROUP_SIZE_Y,
-        1,
-    >>();
-
-    let info = build_info(&[&lhs, &rhs, &output]);
-
-    let info_buffers = lhs
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
-
     let mut num_iter = 1;
     for i in 0..D - 2 {
         num_iter *= output.shape.dims[i];
     }
 
     let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
 
-    lhs.context.execute(
-        workgroup,
-        kernel,
-        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
+    let kernel = StaticKernel::<
+        KernelSettings<
+            MatmulNaive<WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
+            E,
+            i32,
+            WORKGROUP_SIZE_X,
+            WORKGROUP_SIZE_Y,
+            1,
+        >,
+    >::new(workgroup);
+
+    let info = build_info(&[&lhs, &rhs, &output]);
+
+    let info_handle = lhs.client.create(bytemuck::cast_slice(&info));
+
+    lhs.client.execute(
+        Box::new(kernel),
+        &[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
     );
 
     output
@@ -1,29 +1,16 @@
-use super::padding::{crop, pad_round, PaddingOutput};
 use crate::{
-    context::{Context, WorkGroup},
+    compute::{DynamicKernel, WgpuHandle, WorkGroup},
     element::WgpuElement,
-    kernel::{build_info, into_contiguous, matmul::utils::shape_out},
+    kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource},
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
-use burn_tensor::Shape;
-use std::{
-    cmp::{max, min},
-    sync::Arc,
-};
-use wgpu::ComputePipeline;
+
+use super::padding::{crop, pad_round, PaddingOutput};
+use burn_tensor::{Element, Shape};
+use std::cmp::{max, min};
 
 const MAX_SHARED_MEMORY_SIZE: usize = 8192;
 
-pub(super) fn empty_from_context<E: WgpuElement, const D: usize>(
-    context: Arc<Context>,
-    shape: &Shape<D>,
-) -> WgpuTensor<E, D> {
-    let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
-
-    WgpuTensor::new(context, shape.clone(), buffer)
-}
-
 /// Create a source template for tile 2d matmul.
 #[macro_export(local_inner_macros)]
 macro_rules! matmul_tile_2d {
@@ -63,11 +50,11 @@ macro_rules! matmul_tile_2d {
             _elem: core::marker::PhantomData<E>,
         }
 
-        impl<E: WgpuElement> DynamicKernel for $struct<E> {
-            fn source_template(self) -> SourceTemplate {
+        impl<E: WgpuElement> DynamicKernelSource for $struct<E> {
+            fn source(self) -> SourceTemplate {
                 kernel_wgsl!(Raw, $file);
 
-                Raw::source_template()
+                Raw::source()
                     .register("b_m", self.b_m.to_string())
                     .register("b_n", self.b_n.to_string())
                     .register("b_k", self.b_k.to_string())
@@ -90,7 +77,7 @@ macro_rules! matmul_tile_2d {
         }
 
         /// Matrix multiplication using tiling 2D algorithm with default parameters
-        pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
+        pub fn matmul_tiling_2d_default<E: WgpuElement + burn_tensor::Element, const D: usize>(
            lhs: WgpuTensor<E, D>,
            rhs: WgpuTensor<E, D>,
        ) -> WgpuTensor<E, D> {
@@ -125,7 +112,7 @@ macro_rules! matmul_tile_2d {
         /// Matrix multiplication using tiling 2D algorithm with custom parameters
         #[allow(clippy::too_many_arguments)]
         pub fn matmul_tiling_2d<
-            E: WgpuElement,
+            E: WgpuElement + burn_tensor::Element,
             const D: usize,
         >(
             lhs: WgpuTensor<E, D>,
@@ -140,14 +127,13 @@ macro_rules! matmul_tile_2d {
         ) -> WgpuTensor<E, D> {
             let kernel = $struct::<E>::new(b_m, b_n, b_k, t_m, t_n, workgroup_size_x, workgroup_size_y);
-            let kernel = lhs.context.compile_dynamic(kernel);
             matmul_tiling_2d_launch::<
                 E,
                 D,
+                $struct::<E>,
             >(
                 lhs,
                 rhs,
-                kernel,
                 b_m,
                 b_n,
                 b_k,
@@ -155,6 +141,7 @@ macro_rules! matmul_tile_2d {
                 t_n,
                 workgroup_size_x,
                 workgroup_size_y,
+                kernel,
             )
         }
 
@@ -412,21 +399,23 @@ pub(super) fn make_workgroup<const D: usize>(
     WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
 }
 
-pub(super) fn make_info_buffers<E: WgpuElement, const D: usize>(
+pub(super) fn make_info_handle<E: WgpuElement, const D: usize>(
     lhs: &WgpuTensor<E, D>,
     rhs: &WgpuTensor<E, D>,
     output: &WgpuTensor<E, D>,
-) -> Arc<wgpu::Buffer> {
+) -> WgpuHandle {
     let info = build_info(&[lhs, rhs, output]);
-    rhs.context
-        .create_buffer_with_data(bytemuck::cast_slice(&info))
+    rhs.client.create(bytemuck::cast_slice(&info))
 }
 
 #[allow(clippy::too_many_arguments)]
-pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
+pub(super) fn matmul_tiling_2d_launch<
+    E: WgpuElement + Element,
+    const D: usize,
+    K: DynamicKernelSource + 'static,
+>(
     lhs: WgpuTensor<E, D>,
     rhs: WgpuTensor<E, D>,
-    kernel: Arc<ComputePipeline>,
     b_m: usize,
     b_n: usize,
     b_k: usize,
@@ -434,6 +423,7 @@ pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
     t_n: usize,
     workgroup_size_x: usize,
     workgroup_size_y: usize,
+    kernel: K,
 ) -> WgpuTensor<E, D> {
     matmul_parameter_assertions::<E, D>(
         b_m,
@@ -470,15 +460,18 @@ pub(super) fn matmul_tiling_2d_launch<E: WgpuElement, const D: usize>(
 
     let rounded_output_shape = shape_out(&lhs, &rhs);
 
-    let output = empty_from_context::<E, D>(rhs.context.clone(), &rounded_output_shape);
+    let output = empty_device(
+        rhs.client.clone(),
+        rhs.device.clone(),
+        rounded_output_shape.clone(),
+    );
 
     let workgroup = make_workgroup(rounded_output_shape, b_m, b_n);
-    let info_buffers = make_info_buffers(&lhs, &rhs, &output);
+    let info_handle = make_info_handle(&lhs, &rhs, &output);
 
-    lhs.context.execute(
-        workgroup,
-        kernel,
-        &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
+    output.client.execute(
+        Box::new(DynamicKernel::new(kernel, workgroup)),
+        &[&lhs.handle, &rhs.handle, &output.handle, &info_handle],
     );
 
     crop(output, final_output_shape)
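matmul_tiling_2d_launch now accepts any K: DynamicKernelSource + 'static and boxes it together with its workgroup only at dispatch. A self-contained sketch of why the 'static bound and the late boxing fit together; the trait and names below are illustrative, not burn's exact definitions:

trait DynamicKernelSource {
    fn source(&self) -> String;
}

struct Tiling2d { b_m: usize, b_n: usize }

impl DynamicKernelSource for Tiling2d {
    fn source(&self) -> String {
        format!("// wgsl specialized for b_m={}, b_n={}", self.b_m, self.b_n)
    }
}

// The 'static bound lets the generic kernel be boxed into a trait object
// whose lifetime is independent of the caller's stack frame.
fn launch<K: DynamicKernelSource + 'static>(kernel: K) {
    let boxed: Box<dyn DynamicKernelSource> = Box::new(kernel);
    println!("{}", boxed.source());
}

fn main() {
    launch(Tiling2d { b_m: 64, b_n: 64 });
}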
@@ -1,7 +1,7 @@
 use super::base::matmul_tiling_2d_launch;
 use crate::{
     element::WgpuElement,
-    kernel::{DynamicKernel, SourceTemplate, StaticKernel},
+    kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
     matmul_tile_2d,
     tensor::WgpuTensor,
 };
@@ -1,7 +1,7 @@
 use super::base::matmul_tiling_2d_launch;
 use crate::{
     element::WgpuElement,
-    kernel::{DynamicKernel, SourceTemplate, StaticKernel},
+    kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
     matmul_tile_2d,
     tensor::WgpuTensor,
 };
@@ -1,15 +1,14 @@
 use std::ops::Range;
 
-use burn_tensor::Shape;
+use burn_tensor::{Element, Shape};
 
 use crate::{
     element::WgpuElement,
     kernel::{slice, slice_assign},
+    ops::numeric::zeros_device,
     tensor::WgpuTensor,
 };
 
-use super::base::empty_from_context;
-
 // Output of the pad_round function. Allows to know explicitly if early return occurred
 pub(super) enum PaddingOutput<E: WgpuElement, const D: usize> {
     Padded(WgpuTensor<E, D>),
@@ -29,7 +28,7 @@ impl<E: WgpuElement, const D: usize> PaddingOutput<E, D> {
 /// divisible by some quantity.
 /// For instance tensor of shape [1000, 1000] with divisors 64 and 64
 /// will be padded to [1024, 1024] with the last 24 elements being zeros
-pub(super) fn pad_round<E: WgpuElement, const D: usize>(
+pub(super) fn pad_round<E: WgpuElement + Element, const D: usize>(
     tensor: WgpuTensor<E, D>,
     row_divisor: usize,
     col_divisor: usize,
@@ -62,7 +61,7 @@ pub(super) fn pad_round<E: WgpuElement, const D: usize>(
 }
 
 /// Pads tensor by adding zeros when padded dim is larger than tensor dim
-fn padding<E: WgpuElement, const D: usize>(
+fn padding<E: WgpuElement + Element, const D: usize>(
     tensor: WgpuTensor<E, D>,
     padded_shape: Shape<D>,
 ) -> WgpuTensor<E, D> {
@@ -73,8 +72,9 @@ fn padding<E: WgpuElement, const D: usize>(
         .collect::<Vec<Range<usize>>>()
         .try_into()
         .unwrap();
+
     slice_assign::<E, D, D>(
-        empty_from_context(tensor.context.clone(), &padded_shape),
+        zeros_device(tensor.client.clone(), tensor.device.clone(), padded_shape),
         ranges,
         tensor,
     )
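The rounding behind pad_round is plain round-up-to-a-multiple, as the doc comment's [1000, 1000] to [1024, 1024] example implies. A one-line sketch of that arithmetic:

// Round a dim up to the next multiple of the divisor (sketch of the rule).
fn round_up(dim: usize, divisor: usize) -> usize {
    ((dim + divisor - 1) / divisor) * divisor
}

fn main() {
    // Matches the doc comment: 1000 with divisor 64 rounds up to 1024.
    assert_eq!(round_up(1000, 64), 1024);
}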
@@ -1,6 +1,6 @@
 use crate::{
     element::WgpuElement,
-    kernel::{DynamicKernel, SourceTemplate, StaticKernel},
+    kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
     matmul_tile_2d,
     tensor::WgpuTensor,
 };
@@ -1,7 +1,7 @@
 use super::base::matmul_tiling_2d_launch;
 use crate::{
     element::WgpuElement,
-    kernel::{DynamicKernel, SourceTemplate, StaticKernel},
+    kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
     matmul_tile_2d,
     tensor::WgpuTensor,
 };
@@ -1,396 +0,0 @@
-use burn_tensor::{Distribution, Shape, Tensor};
-use mem_coalescing::matmul_mem_coalescing;
-
-use crate::{
-    benchmark::Benchmark,
-    element::{FloatElement, WgpuElement},
-    kernel,
-    tensor::{WgpuTensor, WgpuTensorDyn},
-    tune::{AutoTuneFunction, AutoTuneKey, Execution, KernelFunction, Tunable},
-    GraphicsApi, WgpuBackend, WgpuDevice,
-};
-use std::{marker::PhantomData, sync::Arc};
-
-use super::mem_coalescing;
-
-const TILING_2D_BLOCK_SIZES: [usize; 2] = [64, 128];
-const TILING_2D_BLOCK_SIZES_K: [usize; 2] = [16, 32];
-const TILING_2D_TILE_SIZES: [usize; 2] = [4, 16];
-const MEMORY_COALESCING_WORKGROUP_SIZES: [usize; 3] = [8, 16, 32];
-
-macro_rules! call_dim {
-    ($func:expr, $dim:expr, $( $x:expr ),*) => {
-        match $dim {
-            1 => {
-                let tensor: WgpuTensor<E, 1> = $func($($x,)*);
-                tensor.into()
-            },
-            2 => {
-                let tensor: WgpuTensor<E, 2> = $func($($x,)*);
-                tensor.into()
-            },
-            3 => {
-                let tensor: WgpuTensor<E, 3> = $func($($x,)*);
-                tensor.into()
-            },
-            4 => {
-                let tensor: WgpuTensor<E, 4> = $func($($x,)*);
-                tensor.into()
-            },
-            5 => {
-                let tensor: WgpuTensor<E, 5> = $func($($x,)*);
-                tensor.into()
-            },
-            6 => {
-                let tensor: WgpuTensor<E, 6> = $func($($x,)*);
-                tensor.into()
-            },
-            _ => panic!("Tensors of rank 7 and more can't be autotuned."),
-        }
-    };
-}
-
-macro_rules! tiling2d_tunable {
-    ($name:ident, $func:expr) => {
-        #[derive(new, Default)]
-        struct $name<E: WgpuElement> {
-            b_m: usize,
-            b_n: usize,
-            b_k: usize,
-            t_m: usize,
-            t_n: usize,
-            workgroup_size_x: usize,
-            workgroup_size_y: usize,
-            _elem: PhantomData<E>,
-        }
-
-        impl<E: WgpuElement> KernelFunction for $name<E> {
-            type Input = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
-            type Output = WgpuTensorDyn<E>;
-
-            fn call(&self, (lhs, rhs): Self::Input) -> Self::Output {
-                #[allow(clippy::too_many_arguments)]
-                fn call_dyn<E: WgpuElement, const D: usize>(
-                    lhs: WgpuTensorDyn<E>,
-                    rhs: WgpuTensorDyn<E>,
-                    b_m: usize,
-                    b_n: usize,
-                    b_k: usize,
-                    t_m: usize,
-                    t_n: usize,
-                    workgroup_size_x: usize,
-                    workgroup_size_y: usize,
-                ) -> WgpuTensor<E, D> {
-                    $func(
-                        WgpuTensor::<E, D>::from(lhs),
-                        WgpuTensor::<E, D>::from(rhs),
-                        b_m,
-                        b_n,
-                        b_k,
-                        t_m,
-                        t_n,
-                        workgroup_size_x,
-                        workgroup_size_y,
-                    )
-                }
-
-                return call_dim!(
-                    call_dyn,
-                    lhs.shape.len(),
-                    lhs,
-                    rhs,
-                    self.b_m,
-                    self.b_n,
-                    self.b_k,
-                    self.t_m,
-                    self.t_n,
-                    self.workgroup_size_x,
-                    self.workgroup_size_y
-                );
-            }
-
-            fn description(&self) -> String {
-                format!(
-                    "Tiling 2D matmul ({}) - B_M {}, B_N {}, B_K {}, T_M {}, T_N {}, W_X {}, W_X {}",
-                    stringify!($name),
-                    self.b_m,
-                    self.b_n,
-                    self.b_k,
-                    self.t_m,
-                    self.t_n,
-                    self.workgroup_size_x,
-                    self.workgroup_size_y
-                )
-            }
-        }
-    };
-}
-
-tiling2d_tunable!(
-    Tiling2DContiguousLoad,
-    kernel::matmul::contiguous::matmul_tiling_2d
-);
-
-tiling2d_tunable!(Tiling2DTileLoad, kernel::matmul::tile::matmul_tiling_2d);
-tiling2d_tunable!(
-    Tiling2DContiguousLoadVectorized,
-    kernel::matmul::contiguous_vectorized::matmul_tiling_2d
-);
-
-tiling2d_tunable!(
-    Tiling2DTileLoadVectorized,
-    kernel::matmul::tile_vectorized::matmul_tiling_2d
-);
-
-#[derive(new)]
-struct MemoryCoalescing<E: WgpuElement> {
-    workgroup_size_x: usize,
-    workgroup_size_y: usize,
-    _elem: PhantomData<E>,
-}
-
-impl<E: WgpuElement> KernelFunction for MemoryCoalescing<E> {
-    type Input = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
-    type Output = WgpuTensorDyn<E>;
-
-    fn call(&self, (lhs, rhs): Self::Input) -> Self::Output {
-        fn call_dyn<E: WgpuElement, const D: usize>(
-            lhs: WgpuTensorDyn<E>,
-            rhs: WgpuTensorDyn<E>,
-            workgroup_size_x: usize,
-            workgroup_size_y: usize,
-        ) -> WgpuTensor<E, D> {
-            let lhs = WgpuTensor::from(lhs);
-            let rhs = WgpuTensor::from(rhs);
-
-            matmul_mem_coalescing::<E, D>(lhs, rhs, workgroup_size_x, workgroup_size_y)
-        }
-
-        call_dim!(
-            call_dyn,
-            lhs.shape.len(),
-            lhs,
-            rhs,
-            self.workgroup_size_x,
-            self.workgroup_size_y
-        )
-    }
-
-    fn description(&self) -> String {
-        format!(
-            "Memory Coalescing matmul - W_X {}, W_Y {}",
-            self.workgroup_size_x, self.workgroup_size_y
-        )
-    }
-}
-
-#[derive(new)]
-struct MatmulBenchmark<F: WgpuElement, const D: usize> {
-    shape_lhs: Shape<D>,
-    shape_rhs: Shape<D>,
-    num_repeats: usize,
-    matmul: PhantomData<F>,
-    func: AutoTuneFunction<(WgpuTensorDyn<F>, WgpuTensorDyn<F>), WgpuTensorDyn<F>>,
-}
-
-impl<E, const D: usize, G> Benchmark<G> for MatmulBenchmark<E, D>
-where
-    E: WgpuElement + FloatElement,
-    G: GraphicsApi,
-{
-    type Args = (WgpuTensorDyn<E>, WgpuTensorDyn<E>);
-
-    fn name(&self) -> String {
-        format!("{:?} x {:?}", self.shape_lhs.dims, self.shape_rhs.dims)
-    }
-
-    fn num_samples(&self) -> usize {
-        5
-    }
-
-    fn execute(&self, (lhs, rhs): Self::Args) {
-        for _ in 0..self.num_repeats {
-            self.func.call((lhs.clone(), rhs.clone()));
-        }
-    }
-
-    fn prepare(&self, device: &WgpuDevice) -> Self::Args {
-        let lhs = Tensor::<WgpuBackend<G, E, i32>, D>::random_device(
-            self.shape_lhs.clone(),
-            Distribution::Default,
-            device,
-        );
-        let rhs = Tensor::<WgpuBackend<G, E, i32>, D>::random_device(
-            self.shape_rhs.clone(),
-            Distribution::Default,
-            device,
-        );
-
-        (lhs.into_primitive().into(), rhs.into_primitive().into())
-    }
-}
-
-/// Choose the best matmul kernel by using autotuning.
-pub fn tune<G: GraphicsApi, E, const D: usize>(
-    lhs: WgpuTensor<E, D>,
-    rhs: WgpuTensor<E, D>,
-) -> WgpuTensor<E, D>
-where
-    E: WgpuElement + FloatElement,
-{
-    if D > 6 {
-        log::debug!("Can't autotune matmul for tensors of rank 7 or more.");
-        return kernel::matmul::matmul_tiling_2d_default(lhs, rhs);
-    }
-
-    let (shape_lhs, shape_rhs) = calculate_benchmark_shapes(lhs.shape.clone(), rhs.shape.clone());
-    let id = AutoTuneKey::new(
-        vec![
-            shape_lhs.dims[D - 2..].to_vec(),
-            shape_rhs.dims[D - 2..].to_vec(),
-        ],
-        format!("matmul {}", E::type_name()),
-        &lhs.context,
-    );
-
-    let context = lhs.context.clone();
-    let input: (WgpuTensorDyn<E>, WgpuTensorDyn<E>) = (lhs.into(), rhs.into());
-    let output: WgpuTensorDyn<E> = match context.tuner.execute(&id, input) {
-        Execution::Executed(output) => output,
-        Execution::NoCacheFound((lhs, rhs)) => {
-            let tunables = matmul_candidates::<G, E, D>(shape_lhs, shape_rhs);
-
-            context
-                .tuner
-                .tune(id, (lhs, rhs), tunables, &context.device, &context)
-        }
-    };
-
-    output.into()
-}
-
-/// Shape dims are anchored to the closest (on a log scale) power of 2
-fn calculate_benchmark_shapes<const D: usize>(
-    lhs: Shape<D>,
-    rhs: Shape<D>,
-) -> (Shape<D>, Shape<D>) {
-    let anchor = |a| f32::powf(2., f32::min(f32::round(f32::log(a as f32, 2.)), 12.)) as usize;
-    let m = anchor(lhs.dims[D - 2]);
-    let k = anchor(lhs.dims[D - 1]);
-    let n = anchor(rhs.dims[D - 1]);
-
-    let mut lhs_shape = [1; D];
-    lhs_shape[D - 2] = m;
-    lhs_shape[D - 1] = k;
-    let lhs_shape = Shape::new(lhs_shape);
-
-    let mut rhs_shape = [1; D];
-    rhs_shape[D - 2] = k;
-    rhs_shape[D - 1] = n;
-    let rhs_shape = Shape::new(rhs_shape);
-
-    (lhs_shape, rhs_shape)
-}
-
-type MatmulTunable<G, E> = Tunable<G, (WgpuTensorDyn<E>, WgpuTensorDyn<E>), WgpuTensorDyn<E>>;
-
-/// Enumerates all matmul versions that are candidates for autotuning
-fn matmul_candidates<G: GraphicsApi, E, const D: usize>(
-    shape_lhs: Shape<D>,
-    shape_rhs: Shape<D>,
-) -> Vec<MatmulTunable<G, E>>
-where
-    E: WgpuElement + FloatElement,
-{
-    let matmul_benchmark =
-        |func: AutoTuneFunction<(WgpuTensorDyn<E>, WgpuTensorDyn<E>), WgpuTensorDyn<E>>| {
-            Tunable::<G, _, _>::new(
-                func.clone(),
-                Arc::new(MatmulBenchmark::new(
-                    shape_lhs.clone(),
-                    shape_rhs.clone(),
-                    5,
-                    func.clone(),
-                )),
-            )
-        };
-
-    let mut candidates = Vec::new();
-
-    // All combinations of tiling 2d parameters are pushed for a grid search
-    for block_size in TILING_2D_BLOCK_SIZES {
-        for block_size_k in TILING_2D_BLOCK_SIZES_K {
-            for tile_size in TILING_2D_TILE_SIZES {
-                candidates.push(matmul_benchmark(Arc::new(
-                    Tiling2DContiguousLoad::<E>::new(
-                        block_size,
-                        block_size,
-                        block_size_k,
-                        tile_size,
-                        tile_size,
-                        block_size / tile_size,
-                        block_size / tile_size,
-                    ),
-                )));
-                candidates.push(matmul_benchmark(Arc::new(
-                    Tiling2DContiguousLoadVectorized::<E>::new(
-                        block_size,
-                        block_size,
-                        block_size_k,
-                        tile_size,
-                        tile_size,
-                        block_size / tile_size,
-                        block_size / tile_size,
-                    ),
-                )));
-                candidates.push(matmul_benchmark(Arc::new(Tiling2DTileLoad::<E>::new(
-                    block_size,
-                    block_size,
-                    block_size_k,
-                    tile_size,
-                    tile_size,
-                    block_size / tile_size,
-                    block_size / tile_size,
-                ))));
-                candidates.push(matmul_benchmark(Arc::new(
-                    Tiling2DTileLoadVectorized::<E>::new(
-                        block_size,
-                        block_size,
-                        block_size_k,
-                        tile_size,
-                        tile_size,
-                        block_size / tile_size,
-                        block_size / tile_size,
-                    ),
-                )));
-            }
-        }
-
-        // All combinations of tiling 2d parameters are pushed for a grid search
-        for workgroup_size in MEMORY_COALESCING_WORKGROUP_SIZES {
-            candidates.push(matmul_benchmark(Arc::new(MemoryCoalescing::new(
-                workgroup_size,
-                workgroup_size,
-            ))));
-        }
-    }
-    candidates
-}
-
-#[cfg(test)]
-mod tests {
-    use super::calculate_benchmark_shapes;
-
-    #[test]
-    pub fn benchmark_shapes_are_anchored_correctly() {
-        let m = f32::powf(2., 8.49) as usize;
-        let k = f32::powf(2., 8.51) as usize;
-        let n = f32::powf(2., 4.) as usize;
-        let lhs_shape = [m, k].into();
-        let rhs_shape = [k, n].into();
-        let (lhs_shape, rhs_shape) = calculate_benchmark_shapes(lhs_shape, rhs_shape);
-        assert_eq!(lhs_shape.dims, [256, 512]);
-        assert_eq!(rhs_shape.dims, [512, 16]);
-    }
-}
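For reference, the deleted tuner's anchor closure rounds each benchmark dim to the nearest power of two on a log scale, capped at 2^12, which is exactly what the test above asserts. A few more worked values using the same expression:

fn anchor(a: usize) -> usize {
    f32::powf(2., f32::min(f32::round(f32::log(a as f32, 2.)), 12.)) as usize
}

fn main() {
    assert_eq!(anchor(300), 256); // log2(300) is about 8.23, rounds to 8
    assert_eq!(anchor(400), 512); // log2(400) is about 8.64, rounds to 9
    assert_eq!(anchor(10_000), 4096); // log2(10000) is about 13.3, capped at 12
}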
@@ -1,14 +1,12 @@
-use std::sync::Arc;
-
-use burn_tensor::Shape;
-use wgpu::Buffer;
-
 use crate::{
+    compute::{StaticKernel, WgpuHandle},
     element::WgpuElement,
     kernel::{elemwise_workgroup, KernelSettings},
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
+use burn_tensor::Shape;
 
 kernel_wgsl!(
     AdaptiveAvgPool2d,
@@ -28,23 +26,16 @@ pub(crate) fn adaptive_avg_pool2d<E: WgpuElement>(
     let [batch_size, channels, _, _] = x.shape.dims;
 
     let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
-    let num_elems = output_shape.num_elements();
-    let output_buffer = x
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
+    let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
 
-    let kernel = x
-        .context
-        .compile_static::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
+    let kernel =
+        StaticKernel::<KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+            elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
+        );
 
-    let info_buffer = build_info(&x, &output);
-
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&x.buffer, &output.buffer, &info_buffer],
-    );
+    let info_handle = build_info(&x, &output);
+    x.client
+        .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
 
     output
 }
@@ -57,32 +48,29 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: WgpuElement>(
 
     let output_shape = x.shape.clone();
     let num_elems = output_shape.num_elements();
-    let output_buffer = x
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(x.context.clone(), output_shape, output_buffer);
+    let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
+    let output = WgpuTensor::new(
+        x.client.clone(),
+        x.device.clone(),
+        output_shape,
+        output_buffer,
+    );
 
-    let kernel = x.context.compile_static::<KernelSettings<
-        AdaptiveAvgPool2dBackward,
-        E,
-        i32,
-        WORKGROUP,
-        WORKGROUP,
-        1,
-    >>();
+    let kernel = StaticKernel::<
+        KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP, WORKGROUP, 1>,
+    >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
 
-    let info_buffer = build_info(&x, &out_grad);
+    let info_handle = build_info(&x, &out_grad);
 
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&out_grad.buffer, &output.buffer, &info_buffer],
+    x.client.execute(
+        Box::new(kernel),
+        &[&out_grad.handle, &output.handle, &info_handle],
     );
 
     output
 }
 
-fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> Arc<Buffer> {
+fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -> WgpuHandle {
     let mut info: [u32; 16] = [0; 16];
     info[0] = x.strides[0] as u32;
     info[1] = x.strides[1] as u32;
 
@@ -102,7 +90,5 @@ fn build_info<E: WgpuElement>(x: &WgpuTensor<E, 4>, output: &WgpuTensor<E, 4>) -
     info[14] = output.shape.dims[2] as u32;
     info[15] = output.shape.dims[3] as u32;
 
-    output
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info))
+    output.client.create(bytemuck::cast_slice(&info))
 }
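The info handle built above is a densely packed u32 array that the WGSL shader indexes positionally: strides come first and shape dims last, as the visible assignments show (the exact middle ordering is not shown in this hunk). For a contiguous rank-4 tensor the strides are the running products of the trailing dims:

// Contiguous strides are running products of the trailing dims (rank 4 here).
fn main() {
    let dims = [1usize, 3, 8, 8];
    let mut strides = [0usize; 4];
    let mut acc = 1;
    for i in (0..4).rev() {
        strides[i] = acc;
        acc *= dims[i];
    }
    // info[0..4] would carry these values for the input tensor above.
    assert_eq!(strides, [192, 64, 8, 1]);
}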
@@ -1,11 +1,13 @@
 use crate::{
+    compute::{Kernel, StaticKernel},
     element::WgpuElement,
     kernel::{
         self, elemwise_workgroup,
         pool::{build_output_and_info_pool2d, build_pool2d_info},
-        KernelSettings, StaticKernel,
+        KernelSettings, StaticKernelSource,
     },
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -18,17 +20,15 @@ kernel_wgsl!(
 struct AvgPool2dBackward<const COUNT_INCLUDE_PAD: bool>;
 struct AvgPool2d<const COUNT_INCLUDE_PAD: bool>;
 
-impl<const COUNT_INCLUDE_PAD: bool> StaticKernel for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
-    fn source_template() -> kernel::SourceTemplate {
-        AvgPool2dBackwardRaw::source_template()
-            .register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
+impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2dBackward<COUNT_INCLUDE_PAD> {
+    fn source() -> kernel::SourceTemplate {
+        AvgPool2dBackwardRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
     }
 }
 
-impl<const COUNT_INCLUDE_PAD: bool> StaticKernel for AvgPool2d<COUNT_INCLUDE_PAD> {
-    fn source_template() -> kernel::SourceTemplate {
-        AvgPool2dRaw::source_template()
-            .register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
+impl<const COUNT_INCLUDE_PAD: bool> StaticKernelSource for AvgPool2d<COUNT_INCLUDE_PAD> {
+    fn source() -> kernel::SourceTemplate {
+        AvgPool2dRaw::source().register("count_include_pad", format!("{COUNT_INCLUDE_PAD}"))
     }
 }
 
@@ -41,22 +41,21 @@ pub(crate) fn avg_pool2d<E: WgpuElement>(
 ) -> WgpuTensor<E, 4> {
     const WORKGROUP: usize = 32;
 
-    let (info_buffer, output) =
+    let (info_handle, output) =
         build_output_and_info_pool2d(&x, kernel_size, stride, padding, [1, 1]);
-    let kernel = match count_include_pad {
-        true => x
-            .context
-            .compile_static::<KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>>(),
-        false => x
-            .context
-            .compile_static::<KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>>(),
+
+    let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
+    let kernel: Box<dyn Kernel> = match count_include_pad {
+        true => Box::new(StaticKernel::<
+            KernelSettings<AvgPool2d<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
+        >::new(workgroup)),
+        false => Box::new(StaticKernel::<
+            KernelSettings<AvgPool2d<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
+        >::new(workgroup)),
     };
 
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&x.buffer, &output.buffer, &info_buffer],
-    );
+    x.client
+        .execute(kernel, &[&x.handle, &output.handle, &info_handle]);
 
     output
 }
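count_include_pad is a const generic here, so the two match arms above have different concrete kernel types; boxing them as Box<dyn Kernel> erases that difference so a single execute call handles both. A self-contained sketch of the pattern, with an illustrative trait and types:

trait Kernel { fn id(&self) -> String; }

struct AvgPool<const COUNT_INCLUDE_PAD: bool>;
impl<const COUNT_INCLUDE_PAD: bool> Kernel for AvgPool<COUNT_INCLUDE_PAD> {
    fn id(&self) -> String { format!("avg_pool2d<{COUNT_INCLUDE_PAD}>") }
}

fn main() {
    let count_include_pad = true;
    // The two match arms have different concrete types; boxing erases that.
    let kernel: Box<dyn Kernel> = match count_include_pad {
        true => Box::new(AvgPool::<true>),
        false => Box::new(AvgPool::<false>),
    };
    println!("{}", kernel.id());
}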
@@ -72,39 +71,21 @@ pub(crate) fn avg_pool2d_backward<E: WgpuElement>(
     const WORKGROUP: usize = 32;
 
     let grad = kernel::into_contiguous(grad);
+    let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
+    let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
+    let workgroup = elemwise_workgroup(output.shape.num_elements(), WORKGROUP);
 
-    let num_elems = x.shape.num_elements();
-    let buffer = x
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
-    let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, [1, 1]);
+    let kernel: Box<dyn Kernel> = match count_include_pad {
+        true => Box::new(StaticKernel::<
+            KernelSettings<AvgPool2dBackward<true>, E, i32, WORKGROUP, WORKGROUP, 1>,
+        >::new(workgroup)),
+        false => Box::new(StaticKernel::<
+            KernelSettings<AvgPool2dBackward<false>, E, i32, WORKGROUP, WORKGROUP, 1>,
+        >::new(workgroup)),
+    };
 
-    let kernel =
-        match count_include_pad {
-            true => x.context.compile_static::<KernelSettings<
-                AvgPool2dBackward<true>,
-                E,
-                i32,
-                WORKGROUP,
-                WORKGROUP,
-                1,
-            >>(),
-            false => x.context.compile_static::<KernelSettings<
-                AvgPool2dBackward<false>,
-                E,
-                i32,
-                WORKGROUP,
-                WORKGROUP,
-                1,
-            >>(),
-        };
-
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&grad.buffer, &output.buffer, &info_buffer],
-    );
+    x.client
+        .execute(kernel, &[&grad.handle, &output.handle, &info_handle]);
 
     output
 }
@@ -1,7 +1,7 @@
-use crate::{element::WgpuElement, tensor::WgpuTensor};
+use crate::{
+    compute::WgpuHandle, element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor,
+};
 use burn_tensor::Shape;
-use std::sync::Arc;
-use wgpu::Buffer;
 
 /// Build basic info to launch pool 2d kernels.
 pub fn build_output_and_info_pool2d<E: WgpuElement>(
@@ -10,7 +10,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
     stride: [usize; 2],
     padding: [usize; 2],
     dilation: [usize; 2],
-) -> (Arc<Buffer>, WgpuTensor<E, 4>) {
+) -> (WgpuHandle, WgpuTensor<E, 4>) {
     let [kernel_height, kernel_width] = kernel_size;
     let [padding_height, padding_width] = padding;
     let [stride_height, stride_width] = stride;
@@ -24,12 +24,7 @@ pub fn build_output_and_info_pool2d<E: WgpuElement>(
         / stride_width)
         + 1;
     let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
-    let num_elems = shape_out.num_elements();
-
-    let buffer = x
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(x.context.clone(), shape_out, buffer);
+    let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
 
     let info_buffer = build_pool2d_info(x, &output, kernel_size, stride, padding, dilation);
 
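The truncated expression above ends in "/ stride + 1", which is the usual convolution-style output size. Assuming the standard formula (an assumption; only its tail is visible in this hunk):

// out = (input + 2*pad - dilation*(kernel - 1) - 1) / stride + 1
fn pool_out_dim(input: usize, kernel: usize, stride: usize, pad: usize, dilation: usize) -> usize {
    (input + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // A 224-wide feature map with a 2x2 kernel, stride 2, no padding halves to 112.
    assert_eq!(pool_out_dim(224, 2, 2, 0, 1), 112);
}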
@@ -43,7 +38,7 @@ pub fn build_pool2d_info<E: WgpuElement>(
     stride: [usize; 2],
     padding: [usize; 2],
     dilation: [usize; 2],
-) -> Arc<Buffer> {
+) -> WgpuHandle {
     let mut info: [u32; 24] = [0; 24];
     info[0] = input.strides[0] as u32;
     info[1] = input.strides[1] as u32;
@@ -72,9 +67,7 @@ pub fn build_pool2d_info<E: WgpuElement>(
     info[22] = dilation[0] as u32;
     info[23] = dilation[1] as u32;
 
-    let info_buffer = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_buffer = input.client.create(bytemuck::cast_slice(&info));
 
     info_buffer
 }
@@ -1,4 +1,5 @@
 use crate::{
+    compute::StaticKernel,
     element::WgpuElement,
     kernel::{
         self, elemwise_workgroup,
@@ -6,6 +7,7 @@ use crate::{
         KernelSettings,
     },
     kernel_wgsl,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
 
@@ -28,18 +30,15 @@ pub(crate) fn max_pool2d<E: WgpuElement>(
 ) -> WgpuTensor<E, 4> {
     const WORKGROUP: usize = 32;
 
-    let (info_buffer, output) =
+    let (info_handle, output) =
         build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
-    let kernel = x
-        .context
-        .compile_static::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
-    x.context.execute(
+    let kernel = StaticKernel::<KernelSettings<MaxPool2d, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&x.buffer, &output.buffer, &info_buffer],
     );
 
+    x.client
+        .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
+
     output
 }
 
@@ -52,25 +51,17 @@ pub(crate) fn max_pool2d_with_indices<E: WgpuElement, I: WgpuElement>(
 ) -> (WgpuTensor<E, 4>, WgpuTensor<I, 4>) {
     const WORKGROUP: usize = 32;
 
-    let (info_buffer, output) =
+    let (info_handle, output) =
         build_output_and_info_pool2d(&x, kernel_size, stride, padding, dilation);
-    let num_elems = output.shape.num_elements();
+    let indices = empty_device(x.client.clone(), x.device, output.shape.clone());
 
-    let indices = WgpuTensor::new(
-        x.context.clone(),
-        output.shape.clone(),
-        x.context
-            .create_buffer(num_elems * std::mem::size_of::<I>()),
-    );
+    let kernel = StaticKernel::<
+        KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>,
+    >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
 
-    let kernel = x
-        .context
-        .compile_static::<KernelSettings<MaxPool2dWithIndices, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&x.buffer, &output.buffer, &indices.buffer, &info_buffer],
+    x.client.execute(
+        Box::new(kernel),
+        &[&x.handle, &output.handle, &indices.handle, &info_handle],
     );
 
     (output, indices)
@@ -91,26 +82,18 @@ pub(crate) fn max_pool2d_with_indices_backward<E: WgpuElement, I: WgpuElement>(
     let indices = kernel::into_contiguous(indices);
 
     let num_elems = x.shape.num_elements();
-    let buffer = x
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(x.context.clone(), x.shape.clone(), buffer);
+    let buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
+    let output = WgpuTensor::new(x.client.clone(), x.device.clone(), x.shape.clone(), buffer);
 
-    let info_buffer = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
+    let info_handle = build_pool2d_info(&x, &grad, kernel_size, stride, padding, dilation);
 
-    let kernel = x.context.compile_static::<KernelSettings<
-        MaxPool2dWithIndicesBackward,
-        E,
-        I,
-        WORKGROUP,
-        WORKGROUP,
-        1,
-    >>();
+    let kernel = StaticKernel::<
+        KernelSettings<MaxPool2dWithIndicesBackward, E, I, WORKGROUP, WORKGROUP, 1>,
+    >::new(elemwise_workgroup(output.shape.num_elements(), WORKGROUP));
 
-    x.context.execute(
-        elemwise_workgroup(output.shape.num_elements(), WORKGROUP),
-        kernel,
-        &[&indices.buffer, &grad.buffer, &output.buffer, &info_buffer],
+    x.client.execute(
+        Box::new(kernel),
+        &[&indices.handle, &grad.handle, &output.handle, &info_handle],
     );
     output
 }
@@ -1,11 +1,10 @@
-use std::sync::Arc;
-
+use crate::{
+    compute::{WgpuComputeClient, WgpuHandle},
+    element::WgpuElement,
+    kernel_wgsl, SEED,
+};
 use burn_common::rand::get_seeded_rng;
 use burn_tensor::Shape;
 use rand::Rng;
-use wgpu::Buffer;
-
-use crate::{context::Context, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor, SEED};
 
 kernel_wgsl!(Prng, "../../template/prng/prng.wgsl");
 
@@ -23,22 +22,20 @@ pub(crate) fn get_seeds() -> Vec<u32> {
     seeds
 }
 
-pub(crate) fn make_output_tensor<E: WgpuElement, const D: usize>(
-    context: Arc<Context>,
-    shape: Shape<D>,
-) -> WgpuTensor<E, D> {
-    let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
-    WgpuTensor::new(context, shape, buffer)
-}
-
-pub(crate) fn make_info_buffer(context: Arc<Context>, n_values_per_thread: usize) -> Arc<Buffer> {
+pub(crate) fn make_info_buffer(
+    client: WgpuComputeClient,
+    n_values_per_thread: usize,
+) -> WgpuHandle {
     let mut info = get_seeds();
     info.insert(0, n_values_per_thread as u32);
-    context.create_buffer_with_data(bytemuck::cast_slice(&info))
+    client.create(bytemuck::cast_slice(&info))
 }
 
-pub(crate) fn make_args_buffer<E: WgpuElement>(context: Arc<Context>, args: &[E]) -> Arc<Buffer> {
-    context.create_buffer_with_data(E::as_bytes(args))
+pub(crate) fn make_args_buffer<E: WgpuElement>(
+    client: WgpuComputeClient,
+    args: &[E],
+) -> WgpuHandle {
+    client.create(E::as_bytes(args))
 }
 
 #[cfg(test)]
@@ -1,12 +1,13 @@
 use burn_tensor::Shape;
 
 use crate::{
+    compute::{compute_client, StaticKernel},
     element::WgpuElement,
     kernel::{
-        prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
-        prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
+        prng::base::{make_args_buffer, make_info_buffer},
+        prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
     },
-    pool::get_context,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
     GraphicsApi, WgpuDevice,
 };
@@ -15,9 +16,9 @@ use super::base::Prng;
 
 struct BernoulliPrng;
 
-impl StaticKernel for BernoulliPrng {
-    fn source_template() -> SourceTemplate {
-        Prng::source_template()
+impl StaticKernelSource for BernoulliPrng {
+    fn source() -> SourceTemplate {
+        Prng::source()
             .register("num_args", "1")
             .register(
                 "prng_loop",
@@ -36,15 +37,19 @@ pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
     const N_VALUES_PER_THREAD: usize = 128;
 
-    let context = get_context::<G>(device);
-    let output = make_output_tensor(context.clone(), shape.clone());
-    let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
-    let args_buffer = make_args_buffer(context.clone(), &[prob]);
+    let client = compute_client::<G>(device);
+    let output = empty_device(client.clone(), device.clone(), shape.clone());
+    let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
+    let args_handle = make_args_buffer(client.clone(), &[prob]);
+    let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
+    let kernel =
+        StaticKernel::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+            workgroup,
+        );
 
-    context.execute(
-        prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
-        context.compile_static::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
-        &[&output.buffer, &info_buffer, &args_buffer],
+    client.execute(
+        Box::new(kernel),
+        &[&output.handle, &info_handle, &args_handle],
     );
 
     output
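prng_workgroup presumably divides the element count by both the threads per group and N_VALUES_PER_THREAD, since each thread writes 128 values. The sketch below is that assumed arithmetic, not burn's exact code:

// Assumed PRNG dispatch arithmetic: each thread produces n_values_per_thread
// outputs, so far fewer workgroups are needed than for elementwise ops.
fn prng_workgroup_sketch(num_elems: usize, workgroup_size: usize, n_values_per_thread: usize) -> u32 {
    let per_group = workgroup_size * workgroup_size * n_values_per_thread;
    ((num_elems + per_group - 1) / per_group) as u32
}

fn main() {
    // 1_000_000 values, 32x32 threads, 128 values per thread -> 8 workgroups.
    assert_eq!(prng_workgroup_sketch(1_000_000, 32, 128), 8);
}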
@@ -1,12 +1,13 @@
 use burn_tensor::Shape;
 
 use crate::{
+    compute::{compute_client, StaticKernel},
     element::WgpuElement,
     kernel::{
-        prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
-        prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
+        prng::base::{make_args_buffer, make_info_buffer},
+        prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
     },
-    pool::get_context,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
     GraphicsApi, WgpuDevice,
 };
@@ -15,9 +16,9 @@ use super::base::Prng;
 
 struct NormalPrng;
 
-impl StaticKernel for NormalPrng {
-    fn source_template() -> SourceTemplate {
-        Prng::source_template()
+impl StaticKernelSource for NormalPrng {
+    fn source() -> SourceTemplate {
+        Prng::source()
             .register("num_args", "2")
             .register(
                 "prng_loop",
@@ -39,15 +40,17 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
     const N_VALUES_PER_THREAD: usize = 128; // must be even
 
-    let context = get_context::<G>(device);
-    let output = make_output_tensor(context.clone(), shape.clone());
-    let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
-    let args_buffer = make_args_buffer(context.clone(), &[mean, std]);
+    let client = compute_client::<G>(device);
+    let output = empty_device(client.clone(), device.clone(), shape.clone());
+    let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
+    let args_handle = make_args_buffer(client.clone(), &[mean, std]);
+    let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
+    let kernel =
+        StaticKernel::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(workgroup);
 
-    context.execute(
-        prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
-        context.compile_static::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
-        &[&output.buffer, &info_buffer, &args_buffer],
+    client.execute(
+        Box::new(kernel),
+        &[&output.handle, &info_handle, &args_handle],
     );
 
     output
@@ -1,12 +1,13 @@
 use burn_tensor::Shape;

 use crate::{
+    compute::{compute_client, StaticKernel},
     element::WgpuElement,
     kernel::{
-        prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
-        prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
+        prng::base::{make_args_buffer, make_info_buffer},
+        prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource,
     },
-    pool::get_context,
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
     GraphicsApi, WgpuDevice,
 };
@@ -15,9 +16,9 @@ use super::base::Prng;
 struct UniformPrng;

-impl StaticKernel for UniformPrng {
-    fn source_template() -> SourceTemplate {
-        Prng::source_template().register("num_args", "2").register(
+impl StaticKernelSource for UniformPrng {
+    fn source() -> SourceTemplate {
+        Prng::source().register("num_args", "2").register(
             "prng_loop",
             include_str!("../../template/prng/uniform_inner_loop.wgsl"),
         )
@@ -34,15 +35,18 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
     const WORKGROUP: usize = 32;
     const N_VALUES_PER_THREAD: usize = 128;

-    let context = get_context::<G>(device);
-    let output = make_output_tensor(context.clone(), shape.clone());
-    let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
-    let args_buffer = make_args_buffer(context.clone(), &[low, high]);
+    let client = compute_client::<G>(device);
+    let output = empty_device(client.clone(), device.clone(), shape.clone());
+    let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
+    let args_handle = make_args_buffer(client.clone(), &[low, high]);
+    let workgroup = prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD);
+    let kernel = StaticKernel::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+        workgroup,
+    );

-    context.execute(
-        prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
-        context.compile_static::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
-        &[&output.buffer, &info_buffer, &args_buffer],
+    client.execute(
+        Box::new(kernel),
+        &[&output.handle, &info_handle, &args_handle],
     );

     output
@@ -1,5 +1,8 @@
-use super::{build_info, KernelSettings, SourceTemplate, StaticKernel};
-use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
+use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource};
+use crate::{
+    compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl,
+    tensor::WgpuTensor,
+};
 use burn_tensor::Shape;

 kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
@@ -11,15 +14,15 @@ pub struct ArgsMin;
 pub struct SumDim;
 pub struct MeanDim;

-impl StaticKernel for SumDim {
-    fn source_template() -> SourceTemplate {
-        ReductionDimRaw::source_template().register("assign", "output[id] = sum;")
+impl StaticKernelSource for SumDim {
+    fn source() -> SourceTemplate {
+        ReductionDimRaw::source().register("assign", "output[id] = sum;")
     }
 }

-impl StaticKernel for MeanDim {
-    fn source_template() -> SourceTemplate {
-        ReductionDimRaw::source_template()
+impl StaticKernelSource for MeanDim {
+    fn source() -> SourceTemplate {
+        ReductionDimRaw::source()
             .add_template(
                 "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} {
                     return sum / {{ elem }}(dim);
@@ -29,17 +32,17 @@ impl StaticKernel for MeanDim {
     }
 }

-impl StaticKernel for ArgsMax {
-    fn source_template() -> SourceTemplate {
-        ReductionArgsRaw::source_template()
+impl StaticKernelSource for ArgsMax {
+    fn source() -> SourceTemplate {
+        ReductionArgsRaw::source()
             .register("cmp", ">")
             .register("initial", (-32767).to_string())
     }
 }

-impl StaticKernel for ArgsMin {
-    fn source_template() -> SourceTemplate {
-        ReductionArgsRaw::source_template()
+impl StaticKernelSource for ArgsMin {
+    fn source() -> SourceTemplate {
+        ReductionArgsRaw::source()
             .register("cmp", "<")
             .register("initial", 32767.to_string())
     }
@@ -49,28 +52,29 @@ impl StaticKernel for ArgsMin {
 pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
     const WORKGROUP: usize = 32;

-    let mut input_buffer = input.buffer;
+    let mut input_handle = input.handle;
     let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);

-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
     loop {
         let num_invocations = workgroup.num_invocations();
-        let buffer = input
-            .context
-            .create_buffer(core::mem::size_of::<E>() * num_invocations);
+        let handle = input
+            .client
+            .empty(core::mem::size_of::<E>() * num_invocations);
+
+        let kernel =
+            StaticKernel::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+                workgroup,
+            );

         input
-            .context
-            .execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]);
+            .client
+            .execute(Box::new(kernel), &[&input_handle, &handle]);

         if num_invocations <= 1 {
-            return WgpuTensor::new(input.context, Shape::new([1]), buffer);
+            return WgpuTensor::new(input.client, input.device, Shape::new([1]), handle);
         }

-        input_buffer = buffer;
+        input_handle = handle;
         workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
     }
 }
@@ -91,7 +95,7 @@ pub fn mean_dim<E: WgpuElement, const D: usize>(
     reduction_dim::<MeanDim, E, D>(input, dim)
 }

-fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
+fn reduction_dim<K: StaticKernelSource, E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<E, D> {
@@ -100,25 +104,25 @@ fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
     let num_elems = shape_out.num_elements();
-    let buffer = input
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
+    let handle = input.client.empty(num_elems * core::mem::size_of::<E>());
+    let output = WgpuTensor::new(
+        input.client.clone(),
+        input.device.clone(),
+        shape_out,
+        handle,
+    );

-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
+    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+        elemwise_workgroup(num_elems, WORKGROUP),
+    );

     let mut info = build_info(&[&input, &output]);
     info.push(dim as u32);
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));

-    input.context.execute(
-        elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&input.buffer, &output.buffer, &info_buffers],
+    input.client.execute(
+        Box::new(kernel),
+        &[&input.handle, &output.handle, &info_handle],
     );

     output
@@ -140,7 +144,7 @@ pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
     reduction_args_dim::<ArgsMin, E, I, D>(input, dim)
 }

-fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
+fn reduction_args_dim<K: StaticKernelSource, E: WgpuElement, I: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<I, D> {
@@ -149,27 +153,27 @@ fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D:
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
     let num_elems = shape_out.num_elements();
-    let buffer = input
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<I>());
-    let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
-
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>();
-    let mut info = build_info(&[&input, &output]);
-    info.push(dim as u32);
-    let info_buffers = input
-        .context
-        .create_buffer_with_data(bytemuck::cast_slice(&info));
-
-    input.context.execute(
-        elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&input.buffer, &output.buffer, &info_buffers],
+    let buffer = input.client.empty(num_elems * core::mem::size_of::<I>());
+    let output = WgpuTensor::new(
+        input.client.clone(),
+        input.device.clone(),
+        shape_out,
+        buffer,
+    );
+
+    let kernel = StaticKernel::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>::new(
+        elemwise_workgroup(num_elems, WORKGROUP),
+    );
+    let mut info = build_info(&[&input, &output]);
+    info.push(dim as u32);
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
+
+    input.client.execute(
+        Box::new(kernel),
+        &[&input.handle, &output.handle, &info_handle],
     );

-    WgpuTensor::new(output.context, output.shape, output.buffer)
+    WgpuTensor::new(output.client, output.device, output.shape, output.handle)
 }

 #[cfg(test)]
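The rewritten `sum` keeps the pre-existing recursive strategy: every pass launches one partial reduction per workgroup-sized chunk, the freshly allocated handle becomes the next pass's input, and the loop stops once a single invocation covers what is left. A self-contained CPU sketch of the same loop shape (the chunk size stands in for the WGPU workgroup; the actual reduction runs in WGSL, so this is illustrative only):

    /// Reduce `values` in rounds of `chunk`-sized partial sums, mirroring the
    /// loop in `sum`: each round plays the role of one kernel launch, and the
    /// partial-sum vector plays the role of the freshly allocated handle.
    fn recursive_sum(mut values: Vec<f32>, chunk: usize) -> f32 {
        loop {
            // One "launch": each chunk of the input collapses to one partial sum.
            let partials: Vec<f32> = values
                .chunks(chunk)
                .map(|c| c.iter().sum())
                .collect();

            // Equivalent to `num_invocations <= 1`: a single value remains.
            if partials.len() <= 1 {
                return partials.first().copied().unwrap_or(0.0);
            }

            // Equivalent to `input_handle = handle`: feed the output back in.
            values = partials;
        }
    }

    fn main() {
        let data: Vec<f32> = (1..=1000).map(|x| x as f32).collect();
        assert_eq!(recursive_sum(data, 32), 500_500.0);
    }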
@@ -1,5 +1,5 @@
-use super::{elemwise_workgroup, KernelSettings, StaticKernel};
-use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
+use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
+use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};

 kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
 kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
@@ -13,9 +13,9 @@ macro_rules! unary {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                let source = $crate::kernel::UnaryRaw::source_template();
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                let source = $crate::kernel::UnaryRaw::source();
                 source.register("body", format!("output[id] = {}(input[id]);", $func))
             }
         }
@@ -26,9 +26,9 @@ macro_rules! unary {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryRaw::source_template().register("body", $body)
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryRaw::source().register("body", $body)
             }
         }
     };
@@ -39,9 +39,9 @@ macro_rules! unary {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryRaw::source()
                     .register("body", format!("output[id] = {}(input[id]);", $func))
                     .add_template(include_str!($file))
             }
@@ -58,9 +58,9 @@ macro_rules! unary_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryInplaceRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryInplaceRaw::source()
                     .register("body", format!("input[id] = {}(input[id]);", $func))
             }
         }
@@ -71,9 +71,9 @@ macro_rules! unary_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryInplaceRaw::source_template().register("body", $body)
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryInplaceRaw::source().register("body", $body)
             }
         }
     };
@@ -84,9 +84,9 @@ macro_rules! unary_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryInplaceRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryInplaceRaw::source()
                     .register("body", format!("input[id] = {}(input[id]);", $func))
                     .add_template(include_str!($file))
             }
@@ -95,59 +95,55 @@ macro_rules! unary_inplace {
 }

 /// Execute a unary kernel using the default settings.
-pub fn unary_default<K: StaticKernel, E: WgpuElement, const D: usize>(
+pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
     unary::<K, E, D, 32>(input)
 }

 /// Execute a unary inplace kernel using the default settings.
-pub fn unary_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
+pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
     unary_inplace::<K, E, D, 32>(input)
 }

 /// Execute a unary inplace kernel using the provided WORKGROUP.
-pub fn unary_inplace<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
+pub fn unary_inplace<
+    K: StaticKernelSource,
+    E: WgpuElement,
+    const D: usize,
+    const WORKGROUP: usize,
+>(
     input: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
     let num_elems = input.shape.num_elements();
-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
-    input.context.execute(
+    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&input.buffer],
     );

+    input.client.execute(Box::new(kernel), &[&input.handle]);
+
     input
 }

 /// Execute a unary kernel using the provided WORKGROUP.
-pub fn unary<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
+pub fn unary<K: StaticKernelSource, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
     input: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
     let num_elems = input.shape.num_elements();
-    let buffer = input
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let mut output = WgpuTensor::new(input.context.clone(), input.shape, buffer);
+    let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
+    let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer);
     // Since we don't handle the stride inside the kernel, the output tensor has the same strides
     // as the input tensor. It might not be in the default format.
     output.strides = input.strides;

-    let kernel = input
-        .context
-        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
-
-    input.context.execute(
+    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&input.buffer, &output.buffer],
     );
+    input
+        .client
+        .execute(Box::new(kernel), &[&input.handle, &output.handle]);

     output
 }
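Each `unary!` arm above generates a marker struct plus a `StaticKernelSource` impl whose `source()` splices a body string into the shared `unary.wgsl` template. A sketch of that registration mechanism with a toy `SourceTemplate` that substitutes `{{ key }}` placeholders; the struct name `Sin`, the macro invocation, and the template string are hypothetical, and the real template engine lives in burn-wgpu:

    /// Toy stand-in for burn-wgpu's `SourceTemplate`: a template string plus
    /// `register`, which substitutes `{{ key }}` placeholders.
    struct SourceTemplate(String);

    impl SourceTemplate {
        fn new(template: &str) -> Self {
            Self(template.to_string())
        }

        fn register(self, key: &str, value: &str) -> Self {
            Self(self.0.replace(&format!("{{{{ {key} }}}}"), value))
        }

        fn complete(self) -> String {
            self.0
        }
    }

    // Roughly what an invocation like `unary!(Sin, func "sin")` would expand
    // to under this stand-in: a marker type whose source fills the body slot.
    struct Sin;

    impl Sin {
        fn source() -> SourceTemplate {
            SourceTemplate::new("fn main(id: u32) { {{ body }} }")
                .register("body", "output[id] = sin(input[id]);")
        }
    }

    fn main() {
        assert_eq!(
            Sin::source().complete(),
            "fn main(id: u32) { output[id] = sin(input[id]); }"
        );
    }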
@@ -1,5 +1,5 @@
-use super::{elemwise_workgroup, KernelSettings, StaticKernel};
-use crate::{element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
+use super::{elemwise_workgroup, KernelSettings, StaticKernelSource};
+use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};

 kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
 kernel_wgsl!(
@@ -16,9 +16,9 @@ macro_rules! unary_scalar {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarRaw::source()
                     .register("body", format!("output[id] = lhs[id] {} rhs;", $ops))
             }
         }
@@ -30,9 +30,9 @@ macro_rules! unary_scalar {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarRaw::source()
                     .register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
             }
         }
@@ -45,9 +45,9 @@ macro_rules! unary_scalar {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarRaw::source()
                     .register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
                     .add_template(include_str!($file))
             }
@@ -64,9 +64,9 @@ macro_rules! unary_scalar_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarInplaceRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarInplaceRaw::source()
                     .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops))
             }
         }
@@ -78,9 +78,9 @@ macro_rules! unary_scalar_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarInplaceRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarInplaceRaw::source()
                     .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
             }
         }
@@ -93,9 +93,9 @@ macro_rules! unary_scalar_inplace {
     ) => {
         pub struct $struct;

-        impl $crate::kernel::StaticKernel for $struct {
-            fn source_template() -> $crate::kernel::SourceTemplate {
-                $crate::kernel::UnaryScalarInplaceRaw::source_template()
+        impl $crate::kernel::StaticKernelSource for $struct {
+            fn source() -> $crate::kernel::SourceTemplate {
+                $crate::kernel::UnaryScalarInplaceRaw::source()
                     .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
                     .add_template(include_str!($file))
             }
@@ -104,7 +104,7 @@ macro_rules! unary_scalar_inplace {
 }

 /// Execute a unary scalar kernel using the default settings.
-pub fn unary_scalar_default<K: StaticKernel, E: WgpuElement, const D: usize>(
+pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
     lhs: WgpuTensor<E, D>,
     scalar: E,
 ) -> WgpuTensor<E, D> {
@@ -112,31 +112,33 @@ pub fn unary_scalar_default<K: StaticKernel, E: WgpuElement, const D: usize>(
 }

 /// Execute a unary scalar kernel using the provided WORKGROUP.
-pub fn unary_scalar<K: StaticKernel, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
+pub fn unary_scalar<
+    K: StaticKernelSource,
+    E: WgpuElement,
+    const D: usize,
+    const WORKGROUP: usize,
+>(
     lhs: WgpuTensor<E, D>,
     scalar: E,
 ) -> WgpuTensor<E, D> {
     let num_elems = lhs.shape.num_elements();
-    let buffer = lhs
-        .context
-        .create_buffer(num_elems * core::mem::size_of::<E>());
-    let output = WgpuTensor::new(lhs.context.clone(), lhs.shape, buffer);
-    let kernel = lhs
-        .context
-        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
-
-    lhs.context.execute(
+    let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
+    let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer);
+    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
         elemwise_workgroup(num_elems, WORKGROUP),
-        kernel,
-        &[&lhs.buffer, &rhs_buffer, &output.buffer],
     );
+    let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
+
+    lhs.client.execute(
+        Box::new(kernel),
+        &[&lhs.handle, &rhs_handle, &output.handle],
+    );

     output
 }

 /// Execute a unary scalar inplace kernel using the default settings.
-pub fn unary_scalar_inplace_default<K: StaticKernel, E: WgpuElement, const D: usize>(
+pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
     lhs: WgpuTensor<E, D>,
     scalar: E,
 ) -> WgpuTensor<E, D> {
@@ -145,7 +147,7 @@ pub fn unary_scalar_inplace_default<K: StaticKernel, E: WgpuElement, const D: us

 /// Execute a unary scalar inplace kernel using the provided WORKGROUP.
 pub fn unary_scalar_inplace<
-    K: StaticKernel,
+    K: StaticKernelSource,
     E: WgpuElement,
     const D: usize,
     const WORKGROUP: usize,
@@ -153,19 +155,14 @@ pub fn unary_scalar_inplace<
     lhs: WgpuTensor<E, D>,
     scalar: E,
 ) -> WgpuTensor<E, D> {
-    let kernel = lhs
-        .context
-        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
-    let rhs_buffer = lhs.context.create_buffer_with_data(E::as_bytes(&[scalar]));
-
-    lhs.context.execute(
-        {
-            let num_elems = lhs.shape.num_elements();
-            elemwise_workgroup(num_elems, WORKGROUP)
-        },
-        kernel,
-        &[&lhs.buffer, &rhs_buffer],
+    let num_elems = lhs.shape.num_elements();
+    let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
+        elemwise_workgroup(num_elems, WORKGROUP),
     );
+    let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
+
+    lhs.client
+        .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);

     lhs
 }
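As with the unary kernels, the scalar variants come in an allocating and an in-place form: `unary_scalar` creates a fresh output handle, while `unary_scalar_inplace` writes through `lhs` and returns it, skipping the allocation. The two shapes, sketched on a plain vector (illustrative, not the GPU code):

    /// Out-of-place: allocate a new vector for the result (one extra buffer),
    /// as `unary_scalar` allocates a new output handle.
    fn add_scalar(lhs: &[f32], rhs: f32) -> Vec<f32> {
        lhs.iter().map(|x| x + rhs).collect()
    }

    /// In-place: mutate and hand back the same allocation, as
    /// `unary_scalar_inplace` hands back `lhs` after the kernel ran.
    fn add_scalar_inplace(mut lhs: Vec<f32>, rhs: f32) -> Vec<f32> {
        for x in lhs.iter_mut() {
            *x += rhs;
        }
        lhs
    }

    fn main() {
        let a = vec![1.0, 2.0, 3.0];
        assert_eq!(add_scalar(&a, 1.0), vec![2.0, 3.0, 4.0]);
        assert_eq!(add_scalar_inplace(a, 1.0), vec![2.0, 3.0, 4.0]);
    }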
@@ -10,20 +10,13 @@ mod ops;

 /// Benchmark module
 pub mod benchmark;
-/// Context module.
-pub mod context;
+/// Compute related module.
+pub mod compute;
 /// Kernel module
 pub mod kernel;
 /// Tensor module.
 pub mod tensor;

-#[cfg(test)] // Only enabled for dev for now.
-/// Compute related module.
-pub mod compute;
-
-pub(crate) mod pool;
-pub(crate) mod tune;
-
 mod element;
 pub use element::{FloatElement, IntElement};
@@ -1,5 +1,6 @@
 use crate::{
-    element::WgpuElement, kernel, pool::get_context, tensor::WgpuTensor, GraphicsApi, WgpuDevice,
+    compute::compute_client, element::WgpuElement, kernel, tensor::WgpuTensor, GraphicsApi,
+    WgpuDevice,
 };
 use burn_tensor::{backend::Backend, Data, Shape};
@@ -18,15 +19,15 @@ pub fn from_data<G: GraphicsApi, E: WgpuElement, const D: usize>(
     data: Data<E, D>,
     device: &WgpuDevice,
 ) -> WgpuTensor<E, D> {
-    let context = get_context::<G>(device);
-    let buffer = context.create_buffer_with_data_options(E::as_bytes(&data.value), true);
+    let client = compute_client::<G>(device);
+    let buffer = client.create(E::as_bytes(&data.value));

-    WgpuTensor::new(context, data.shape, buffer)
+    WgpuTensor::new(client, device.clone(), data.shape, buffer)
 }

 pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
     let tensor = kernel::into_contiguous(tensor);
-    let bytes = tensor.context.read_buffer(tensor.buffer);
+    let bytes = tensor.client.read(&tensor.handle);
     let values = E::from_bytes(&bytes);

     Data::new(values.to_vec(), tensor.shape)
@@ -36,22 +37,22 @@ pub fn to_device<G: GraphicsApi, E: WgpuElement, const D: usize>(
     tensor: WgpuTensor<E, D>,
     device: &WgpuDevice,
 ) -> WgpuTensor<E, D> {
-    if &tensor.context.device == device {
+    if &tensor.device == device {
         return tensor;
     }

-    let context = get_context::<G>(device);
-    tensor.to_context(context)
+    let client = compute_client::<G>(device);
+    tensor.to_client(client, device.clone())
 }

 pub fn empty<G: GraphicsApi, E: WgpuElement, const D: usize>(
     shape: Shape<D>,
     device: &WgpuDevice,
 ) -> WgpuTensor<E, D> {
-    let context = get_context::<G>(device);
-    let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
+    let client = compute_client::<G>(device);
+    let buffer = client.empty(shape.num_elements() * core::mem::size_of::<E>());

-    WgpuTensor::new(context, shape, buffer)
+    WgpuTensor::new(client, device.clone(), shape, buffer)
 }

 pub fn swap_dims<E: WgpuElement, const D: usize>(
@@ -72,5 +73,5 @@ pub fn reshape<E: WgpuElement, const D1: usize, const D2: usize>(
     // TODO: Not force standard layout all the time (improve performance).
     let tensor = kernel::into_contiguous(tensor);

-    WgpuTensor::new(tensor.context, shape, tensor.buffer)
+    WgpuTensor::new(tensor.client, tensor.device, shape, tensor.handle)
 }
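`from_data` and `into_data` are now a symmetric pair over the client: upload with `client.create(E::as_bytes(...))`, download with `client.read(&handle)` followed by `E::from_bytes`. The byte reinterpretation behind that pair can be sketched with the `bytemuck` crate, which these kernels already use for info buffers (add it as a dependency to run this):

    // Sketch of the element <-> bytes round trip behind `from_data`/`into_data`.
    fn main() {
        let values: Vec<f32> = vec![1.0, -2.5, 3.25];

        // `E::as_bytes(&data.value)` equivalent: view the element slice as bytes.
        let bytes: &[u8] = bytemuck::cast_slice(&values);
        assert_eq!(bytes.len(), values.len() * core::mem::size_of::<f32>());

        // `E::from_bytes(&bytes)` equivalent: view the bytes as elements again.
        let roundtrip: &[f32] = bytemuck::cast_slice(bytes);
        assert_eq!(roundtrip, values.as_slice());
    }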
@@ -47,7 +47,7 @@ where

 fn bool_into_int<const D: usize>(tensor: BoolTensor<Self, D>) -> IntTensor<Self, D> {
     if std::mem::size_of::<I>() == std::mem::size_of::<u32>() {
-        return WgpuTensor::new(tensor.context, tensor.shape, tensor.buffer);
+        return WgpuTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
     }

     let device = Self::bool_device(&tensor);
@@ -57,7 +57,7 @@ where
 }

 fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
-    tensor.context.device.clone()
+    tensor.device.clone()
 }

 fn bool_to_device<const D: usize>(
@@ -57,7 +57,7 @@ where
 }

 fn device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
-    tensor.context.device.clone()
+    tensor.device.clone()
 }

 fn to_device<const D: usize>(
@@ -139,11 +139,7 @@ where
     lhs: FloatTensor<Self, D>,
     rhs: FloatTensor<Self, D>,
 ) -> FloatTensor<Self, D> {
-    #[cfg(feature = "autotune")]
-    return kernel::matmul::tune::<G, F, D>(lhs, rhs);
-
-    #[cfg(not(feature = "autotune"))]
-    kernel::matmul::contiguous_vectorized::matmul_tiling_2d_default(lhs, rhs)
+    kernel::matmul::contiguous::matmul_tiling_2d_default(lhs, rhs)
 }

 fn swap_dims<const D: usize>(
@@ -33,7 +33,7 @@ where
 }

 fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
-    tensor.context.device.clone()
+    tensor.device.clone()
 }

 fn int_to_device<const D: usize>(
@@ -1,8 +1,8 @@
+use crate::compute::{compute_client, WgpuComputeClient};
 use crate::kernel::{
     binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
     unary_scalar_inplace_default,
 };
-use crate::pool::get_context;
 use crate::{
     binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor,
     unary_scalar, unary_scalar_inplace,
@@ -10,26 +10,38 @@ use crate::{
 use crate::{GraphicsApi, WgpuDevice};
 use burn_tensor::{Element, ElementConversion, Shape};

-pub fn zeros<G: GraphicsApi, E: WgpuElement, const D: usize>(
+pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
     shape: Shape<D>,
     device: &WgpuDevice,
 ) -> WgpuTensor<E, D> {
-    let context = get_context::<G>(device);
+    let client = compute_client::<G>(device);

-    let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
-
-    WgpuTensor::new(context, shape, buffer)
+    zeros_device(client, device.clone(), shape)
+}
+
+pub fn zeros_device<E: WgpuElement + Element, const D: usize>(
+    client: WgpuComputeClient,
+    device: WgpuDevice,
+    shape: Shape<D>,
+) -> WgpuTensor<E, D> {
+    mul_scalar(empty_device(client, device, shape), 0i32.elem::<E>())
+}
+
+pub fn empty_device<E: WgpuElement, const D: usize>(
+    client: WgpuComputeClient,
+    device: WgpuDevice,
+    shape: Shape<D>,
+) -> WgpuTensor<E, D> {
+    let buffer = client.empty(shape.num_elements() * core::mem::size_of::<E>());
+
+    WgpuTensor::new(client, device, shape, buffer)
 }

 pub fn ones<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
     shape: Shape<D>,
     device: &WgpuDevice,
 ) -> WgpuTensor<E, D> {
-    let context = get_context::<G>(device);
-
-    let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
-
-    add_scalar(WgpuTensor::new(context, shape, buffer), 1i32.elem::<E>())
+    add_scalar(zeros::<G, E, D>(shape, device), 1i32.elem::<E>())
 }

 pub fn add<E: WgpuElement, const D: usize>(
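The numeric constructors now layer on one another: `empty_device` only allocates, `zeros_device` scales that uninitialized buffer by zero via `mul_scalar`, and `ones` adds one to `zeros`. The same layering on a CPU vector (the `42.0` fill is a stand-in for whatever bytes a fresh buffer happens to contain):

    fn empty(len: usize) -> Vec<f32> {
        // Stand-in for `client.empty`: contents are unspecified.
        vec![42.0; len]
    }

    fn mul_scalar(mut t: Vec<f32>, s: f32) -> Vec<f32> {
        t.iter_mut().for_each(|x| *x *= s);
        t
    }

    fn add_scalar(mut t: Vec<f32>, s: f32) -> Vec<f32> {
        t.iter_mut().for_each(|x| *x += s);
        t
    }

    fn zeros(len: usize) -> Vec<f32> {
        mul_scalar(empty(len), 0.0) // as `zeros_device` does on the GPU buffer
    }

    fn ones(len: usize) -> Vec<f32> {
        add_scalar(zeros(len), 1.0) // as the new `ones` does
    }

    fn main() {
        assert_eq!(zeros(3), vec![0.0; 3]);
        assert_eq!(ones(3), vec![1.0; 3]);
    }

One observation about the multiply-by-zero trick: x * 0.0 is NaN when x is NaN or infinite, so it implicitly relies on the allocated contents being finite (or on the backend zero-initializing fresh buffers).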
@@ -1,63 +0,0 @@
-use crate::{context::Context, GraphicsApi, WgpuDevice};
-use std::{
-    any::TypeId,
-    collections::HashMap,
-    sync::{Arc, Mutex},
-};
-
-static POOL_CONTEXT: Mutex<Option<ContextPool>> = Mutex::new(None);
-
-#[derive(Default)]
-struct ContextPool {
-    contexts: HashMap<Key, Arc<Context>>,
-}
-
-#[derive(Clone, Debug, Hash, PartialEq, Eq)]
-struct Key {
-    api_id: TypeId,
-    device: WgpuDevice,
-}
-
-impl Key {
-    fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
-        Self {
-            api_id: TypeId::of::<G>(),
-            device: device.clone(),
-        }
-    }
-}
-
-/// Get a [context](Context) for the given [device](WGPUDevice).
-///
-/// # Notes
-///
-/// If a context already exist for the current [device](WGPUDevice), the same instance will be
-/// returned.
-pub fn get_context<G: GraphicsApi>(device: &WgpuDevice) -> Arc<Context> {
-    let mut pool = POOL_CONTEXT.lock().unwrap();
-
-    let context = if let Some(pool) = pool.as_mut() {
-        // Fetch device in pool
-        match pool.contexts.get(&Key::new::<G>(device)) {
-            Some(context) => context.clone(),
-            None => {
-                // Init new device
-                let context = Arc::new(Context::new::<G>(device));
-                pool.contexts.insert(Key::new::<G>(device), context.clone());
-                context
-            }
-        }
-    } else {
-        // Initialize pool
-        let context = Arc::new(Context::new::<G>(device));
-        let mut new_pool = ContextPool::default();
-
-        new_pool
-            .contexts
-            .insert(Key::new::<G>(device), context.clone());
-        *pool = Some(new_pool);
-        context
-    };
-
-    context
-}
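The deleted pool memoized one `Context` per (graphics API, device) pair behind a global mutex; `compute_client::<G>(device)` now fills that role on the compute side. The memoization pattern itself, in a self-contained form (the `(u32, u32)` key and the `Arc<String>` payload are placeholders for the real API/device key and client type):

    use std::{
        collections::HashMap,
        sync::{Arc, Mutex, OnceLock},
    };

    // Generic shape of the retired `get_context` pool: one shared instance
    // per key, created lazily on first request.
    fn client_for(key: (u32, u32)) -> Arc<String> {
        static POOL: OnceLock<Mutex<HashMap<(u32, u32), Arc<String>>>> = OnceLock::new();
        let pool = POOL.get_or_init(|| Mutex::new(HashMap::new()));
        pool.lock()
            .unwrap()
            .entry(key)
            .or_insert_with(|| Arc::new(format!("client for {key:?}")))
            .clone()
    }

    fn main() {
        let first = client_for((0, 1));
        let second = client_for((0, 1));
        assert!(Arc::ptr_eq(&first, &second)); // the same instance is reused
    }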
@@ -1,19 +1,22 @@
-use crate::unary;
+use crate::{
+    compute::{WgpuComputeClient, WgpuHandle},
+    unary, WgpuDevice,
+};
+use crate::{element::WgpuElement, kernel::unary_default};
 use burn_tensor::Shape;
-use std::{marker::PhantomData, sync::Arc};
-use wgpu::Buffer;
-
-use crate::{context::Context, element::WgpuElement, kernel::unary_default};
+use std::marker::PhantomData;

 /// The basic tensor primitive struct.
 #[derive(Debug, Clone)]
 pub struct WgpuTensor<E: WgpuElement, const D: usize> {
-    /// The context the tensor is binded to.
-    pub context: Arc<Context>,
+    /// Compute client for wgpu.
+    pub client: WgpuComputeClient,
     /// The buffer where the data are stored.
-    pub buffer: Arc<Buffer>,
+    pub handle: WgpuHandle,
     /// The shape of the current tensor.
     pub shape: Shape<D>,
+    /// The device of the current tensor.
+    pub device: WgpuDevice,
     /// The strides of the current tensor.
     pub strides: [usize; D],
     elem: PhantomData<E>,
@@ -21,8 +24,12 @@ pub struct WgpuTensor<E: WgpuElement, const D: usize> {

 #[derive(Debug, Clone)]
 pub(crate) struct WgpuTensorDyn<E: WgpuElement> {
-    pub(crate) context: Arc<Context>,
-    pub(crate) buffer: Arc<Buffer>,
+    /// Compute client for wgpu.
+    pub client: WgpuComputeClient,
+    /// The buffer where the data are stored.
+    pub handle: WgpuHandle,
+    /// The device of the current tensor.
+    pub device: WgpuDevice,
     pub(crate) shape: Vec<usize>,
     pub(crate) strides: Vec<usize>,
     elem: PhantomData<E>,
@@ -31,8 +38,9 @@ pub(crate) struct WgpuTensorDyn<E: WgpuElement> {
 impl<E: WgpuElement, const D: usize> From<WgpuTensor<E, D>> for WgpuTensorDyn<E> {
     fn from(value: WgpuTensor<E, D>) -> Self {
         WgpuTensorDyn {
-            context: value.context,
-            buffer: value.buffer,
+            client: value.client,
+            handle: value.handle,
+            device: value.device,
             shape: value.shape.dims.to_vec(),
             strides: value.strides.to_vec(),
             elem: PhantomData,
@@ -43,8 +51,9 @@ impl<E: WgpuElement, const D: usize> From<WgpuTensor<E, D>> for WgpuTensorDyn<E>
 impl<E: WgpuElement, const D: usize> From<WgpuTensorDyn<E>> for WgpuTensor<E, D> {
     fn from(value: WgpuTensorDyn<E>) -> Self {
         WgpuTensor {
-            context: value.context,
-            buffer: value.buffer,
+            client: value.client,
+            handle: value.handle,
+            device: value.device,
             shape: Shape::new(value.shape.try_into().expect("Wrong dimension")),
             strides: value.strides.try_into().expect("Wrong dimension"),
             elem: PhantomData,
@@ -54,7 +63,12 @@ impl<E: WgpuElement, const D: usize> From<WgpuTensorDyn<E>> for WgpuTensor<E, D>

 impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
     /// Create a new tensor.
-    pub fn new(context: Arc<Context>, shape: Shape<D>, buffer: Arc<Buffer>) -> Self {
+    pub fn new(
+        client: WgpuComputeClient,
+        device: WgpuDevice,
+        shape: Shape<D>,
+        handle: WgpuHandle,
+    ) -> Self {
         let mut strides = [0; D];

         let mut current = 1;
@@ -69,29 +83,31 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
         });

         Self {
-            context,
-            buffer,
+            client,
+            handle,
             shape,
             strides,
+            device,
             elem: PhantomData,
         }
     }

     /// Change the context of the current tensor and return the newly transferred tensor.
-    pub fn to_context(&self, context: Arc<Context>) -> Self {
-        let data = self.context.read_buffer(self.buffer.clone());
-        let buffer = context.create_buffer_with_data(&data);
+    pub fn to_client(&self, client: WgpuComputeClient, device: WgpuDevice) -> Self {
+        let data = self.client.read(&self.handle);
+        let handle = client.create(&data);

         Self {
-            context,
-            buffer,
+            client,
+            handle,
             shape: self.shape.clone(),
             strides: self.strides,
+            device,
             elem: PhantomData,
         }
     }

     pub(crate) fn can_mut_broadcast(&self, tensor_other: &WgpuTensor<E, D>) -> bool {
-        if Arc::strong_count(&self.buffer) > 1 {
+        if !self.handle.can_mut() {
             return false;
         }
@@ -120,19 +136,15 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {

     /// Check if the tensor is safe to mutate.
     pub fn can_mut(&self) -> bool {
-        if Arc::strong_count(&self.buffer) > 1 {
-            return false;
-        }
-
-        true
+        self.handle.can_mut()
     }

     /// Assert that both tensors are on the same device.
     pub fn assert_is_on_same_device(&self, other: &Self) {
-        if self.context.device != other.context.device {
+        if self.device != other.device {
             panic!(
                 "Both tensors should be on the same device {:?} != {:?}",
-                self.context.device, other.context.device
+                self.device, other.device
             );
         }
     }
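`can_mut` no longer inspects `Arc::strong_count` on the wgpu buffer; that reference-count check now lives behind the handle's `can_mut`. The retired logic, isolated into a runnable form:

    use std::sync::Arc;

    /// The old `WgpuTensor::can_mut`, isolated: a buffer is safe to mutate in
    /// place only while this tensor holds the sole strong reference to it.
    fn can_mut<T>(buffer: &Arc<T>) -> bool {
        Arc::strong_count(buffer) == 1
    }

    fn main() {
        let buffer = Arc::new(vec![0u8; 16]);
        assert!(can_mut(&buffer)); // sole owner: in-place kernels allowed

        let shared = Arc::clone(&buffer);
        assert!(!can_mut(&buffer)); // aliased: must allocate a new output
        drop(shared);
        assert!(can_mut(&buffer));
    }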
@@ -1,196 +0,0 @@
-use std::{collections::HashMap, fmt::Display, hash::Hash, sync::Arc, time::Duration};
-
-use burn_common::stub::RwLock;
-
-use crate::{
-    benchmark::{Benchmark, BenchmarkResult},
-    context::Context,
-    GraphicsApi, WgpuDevice,
-};
-
-/// Key used for caching.
-#[derive(Hash, Clone, Debug, PartialEq, Eq)]
-pub struct AutoTuneKey {
-    /// List all shapes used for the autotuned kernel.
-    shapes: Vec<Vec<usize>>,
-    /// Operation name.
-    ops: String,
-    /// Device name used to benchmark.
-    device: String,
-    /// Graphics api name.
-    graphics_api: String,
-}
-
-impl AutoTuneKey {
-    pub fn new(shapes: Vec<Vec<usize>>, ops: String, context: &Context) -> Self {
-        let device = format!("{:?}", context.info.name);
-        let graphics_api = format!("{:?}", context.info.backend);
-
-        Self {
-            shapes,
-            ops,
-            device,
-            graphics_api,
-        }
-    }
-}
-
-impl Display for AutoTuneKey {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let message = format!(
-            "(AutoTuneKey) Kernel {} - Shapes {:?} - Device {} - API {}",
-            self.ops, self.shapes, self.device, self.graphics_api,
-        );
-        f.write_str(&message)
-    }
-}
-
-/// Objects that are stored in the tuner cache. Can have any inputs and outputs.
-pub type AutoTuneValue = Box<dyn core::any::Any + Send + Sync>;
-
-/// Executable function
-pub trait KernelFunction: Send + Sync + 'static {
-    type Input;
-    type Output;
-
-    fn call(&self, input: Self::Input) -> Self::Output;
-    fn description(&self) -> String;
-}
-
-/// Encapsulates kernel functions, with specified inputs and outputs
-pub type AutoTuneFunction<I, O> = Arc<dyn KernelFunction<Input = I, Output = O>>;
-
-/// The tunable links an executable function to its corresponding benchmark
-#[derive(new)]
-pub struct Tunable<G, I, O> {
-    func: AutoTuneFunction<I, O>,
-    benchmark: Arc<dyn Benchmark<G, Args = I>>,
-}
-
-impl<G, I, O> std::fmt::Display for Tunable<G, I, O>
-where
-    G: GraphicsApi,
-    I: Send + Sync + 'static,
-    O: Send + Sync + 'static,
-{
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.write_str(self.func.description().as_str())
-    }
-}
-
-/// Output of the tuner execution. If execution succeeded, the output of
-/// the execution is contained. Otherwise, the function must be tuned and
-/// the input is given back to preserve ownership.
-#[derive(Debug)]
-pub enum Execution<I, O> {
-    Executed(O),
-    NoCacheFound(I),
-}
-
-/// The tuner allows to find the best version of a kernel by benchmarking
-/// different versions. It keeps the best version found in a cache, so the best
-/// function is reused automatically in similar circumstances.
-#[derive(Debug)]
-pub struct Tuner {
-    cache: RwLock<HashMap<AutoTuneKey, AutoTuneValue>>,
-}
-
-impl Tuner {
-    pub fn new() -> Self {
-        Self {
-            cache: RwLock::new(HashMap::new()),
-        }
-    }
-
-    /// Executes the function stored in the cache at key id, on specified input,
-    /// and returns its output. If cache has no such id, returns NoCacheFound.
-    pub fn execute<I, O>(&self, id: &AutoTuneKey, input: I) -> Execution<I, O>
-    where
-        I: Send + Sync + 'static,
-        O: Send + Sync + 'static,
-    {
-        let cache = self.cache.read().unwrap();
-        let obj = cache.get(id);
-
-        let obj = match obj {
-            None => return Execution::NoCacheFound(input),
-            Some(value) => value,
-        };
-
-        let func: &Arc<dyn KernelFunction<Input = I, Output = O>> = obj.downcast_ref().unwrap();
-        let output = func.call(input);
-
-        Execution::Executed(output)
-    }
-
-    /// Finds the best tunable and writes it to the cache.
-    pub fn tune<G: GraphicsApi, I, O>(
-        &self,
-        id: AutoTuneKey,
-        input: I,
-        tunables: Vec<Tunable<G, I, O>>,
-        device: &WgpuDevice,
-        context: &Context,
-    ) -> O
-    where
-        I: Send + Sync + 'static,
-        O: Send + Sync + 'static,
-    {
-        let mut cache = self.cache.write().unwrap();
-
-        context.start_tuning();
-        let results = benchmark(&tunables, device);
-        let kernel = find_best(&id, tunables, results);
-        cache.insert(id.clone(), kernel);
-        drop(cache);
-        context.stop_tuning();
-
-        match self.execute(&id, input) {
-            Execution::Executed(output) => output,
-            _ => panic!("Should have found a kernel to execute. "),
-        }
-    }
-}
-
-/// Finds the best kernel by keeping the one with the smallest median duration.
-fn find_best<G: GraphicsApi, I, O>(
-    id: &AutoTuneKey,
-    tunables: Vec<Tunable<G, I, O>>,
-    results: Vec<BenchmarkResult>,
-) -> AutoTuneValue
-where
-    I: Send + Sync + 'static,
-    O: Send + Sync + 'static,
-{
-    let mut best_duration = Duration::MAX;
-    let mut best_tunable = None;
-
-    for (tunable, result) in tunables.into_iter().zip(results) {
-        let duration = result.median_duration();
-
-        if duration < best_duration {
-            best_duration = duration;
-            best_tunable = Some(tunable);
-        }
-    }
-
-    let tunable = best_tunable.expect("At least one tunable needed. ");
-
-    log::info!("{} => {}", id, tunable);
-    Box::new(tunable.func)
-}
-
-/// Run benchmarks.
-fn benchmark<G: GraphicsApi, I, O>(
-    tunables: &[Tunable<G, I, O>],
-    device: &WgpuDevice,
-) -> Vec<BenchmarkResult>
-where
-    I: Send + Sync + 'static,
-    O: Send + Sync + 'static,
-{
-    tunables
-        .iter()
-        .map(|tunable| tunable.benchmark.run(device))
-        .collect()
-}
@@ -1,3 +0,0 @@
-mod base;
-
-pub use base::*;
@@ -44,7 +44,6 @@ ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"]
 ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]

 wgpu = ["burn-core/wgpu"]
-wgpu-autotune = ["burn-core/wgpu-autotune"]

 tch = ["burn-core/tch"]
@@ -2,8 +2,10 @@ use crate::FloatTensor;

 use super::Backend;
 use burn::backend::wgpu::{
-    context::WorkGroup,
-    kernel::{build_info, into_contiguous, DynamicKernel, SourceTemplate, StaticKernel},
+    compute::{DynamicKernel, WorkGroup},
+    kernel::{
+        build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
+    },
     kernel_wgsl,
     tensor::WgpuTensor,
     FloatElement, GraphicsApi, IntElement, WgpuBackend,
@@ -24,11 +26,11 @@ struct FusedMatmulAddRelu<E: FloatElement> {
 }

 // Implement the dynamic kernel trait for our kernel type.
-impl<E: FloatElement> DynamicKernel for FusedMatmulAddRelu<E> {
-    fn source_template(self) -> SourceTemplate {
+impl<E: FloatElement> DynamicKernelSource for FusedMatmulAddRelu<E> {
+    fn source(self) -> SourceTemplate {
         // Extend our raw kernel with workgroup size information using the
         // `SourceTemplate` trait.
-        FusedMatmulAddReluRaw::source_template()
+        FusedMatmulAddReluRaw::source()
             .register("workgroup_size_x", self.workgroup_size_x.to_string())
             .register("workgroup_size_y", self.workgroup_size_y.to_string())
             .register("elem", E::type_name())
@@ -76,23 +78,18 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,

         // Create a buffer for the output tensor.
         let buffer = lhs
-            .context
-            .create_buffer(shape_out.num_elements() * core::mem::size_of::<F>());
+            .client
+            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

         // Create the output tensor primitive.
-        let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
+        let output = WgpuTensor::new(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);

-        // Compile the kernel or use the cache based on the template id.
-        let kernel = lhs.context.compile_dynamic(FusedMatmulAddRelu::<F>::new(
-            workgroup_size_x,
-            workgroup_size_y,
-        ));
+        // Create the kernel.
+        let kernel = FusedMatmulAddRelu::<F>::new(workgroup_size_x, workgroup_size_y);

         // Build info buffer with tensor information needed by the kernel, such as shapes and strides.
         let info = build_info(&[&lhs, &rhs, &output]);
-        let info_buffer = lhs
-            .context
-            .create_buffer_with_data(bytemuck::cast_slice(&info));
+        let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

         // Declare the wgsl workgroup with the number of blocks in x, y and z.
         let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
@@ -100,15 +97,14 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for WgpuBackend<G,
         let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_batches as u32);

         // Execute lazily the kernel with the launch information and the given buffers.
-        lhs.context.execute(
-            workgroup,
-            kernel,
+        lhs.client.execute(
+            Box::new(DynamicKernel::new(kernel, workgroup)),
             &[
-                &lhs.buffer,
-                &rhs.buffer,
-                &bias.buffer,
-                &output.buffer,
-                &info_buffer,
+                &lhs.handle,
+                &rhs.handle,
+                &bias.handle,
+                &output.handle,
+                &info_handle,
             ],
         );
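The renames in this example mirror the new split between kernel sources: `StaticKernelSource::source()` is an associated function, one source per type, while `DynamicKernelSource::source(self)` consumes the instance, because the generated WGSL depends on runtime state such as the workgroup sizes. A minimal sketch of the distinction; the trait and type names echo the diff, but the bodies here are illustrative:

    // Static: the source is a property of the type itself, so it can be
    // cached by type id; an associated function suffices.
    trait StaticKernelSource {
        fn source() -> String;
    }

    // Dynamic: the source depends on instance state, so the value is
    // consumed, matching `fn source(self) -> SourceTemplate` above.
    trait DynamicKernelSource {
        fn source(self) -> String;
    }

    struct Relu;

    impl StaticKernelSource for Relu {
        fn source() -> String {
            "output[id] = max(input[id], 0.0);".to_string()
        }
    }

    struct Tiled {
        workgroup_size_x: usize,
        workgroup_size_y: usize,
    }

    impl DynamicKernelSource for Tiled {
        fn source(self) -> String {
            format!(
                "@workgroup_size({}, {}, 1) /* ... */",
                self.workgroup_size_x, self.workgroup_size_y
            )
        }
    }

    fn main() {
        println!("{}", Relu::source());
        let kernel = Tiled { workgroup_size_x: 16, workgroup_size_y: 16 };
        println!("{}", kernel.source());
    }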