From 3f9e97946fed2dc97776f5cbea32ddcd878ab103 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Sun, 7 Jul 2024 00:17:01 +0100
Subject: [PATCH] Feat: Dynamic cube count dispatch (#1975)

---
 .../backend-extension/custom-wgpu-kernel.md   |  5 +-
 crates/burn-compute/src/channel/base.rs       |  7 ++-
 crates/burn-compute/src/channel/cell.rs       |  9 ++-
 crates/burn-compute/src/channel/mpsc.rs       | 16 +++--
 crates/burn-compute/src/channel/mutex.rs      |  9 ++-
 crates/burn-compute/src/client.rs             |  9 ++-
 crates/burn-compute/src/server.rs             |  9 ++-
 crates/burn-compute/tests/dummy/server.rs     |  8 ++-
 .../tests/dummy/tune/autotune_operations.rs   |  2 +-
 crates/burn-compute/tests/integration_test.rs |  1 +
 .../src/codegen_function/launch.rs            |  2 +-
 crates/burn-cube/src/codegen/execution.rs     | 34 +++++------
 crates/burn-cube/src/codegen/integrator.rs    |  6 +-
 crates/burn-cube/src/compute/kernel.rs        | 56 +++++------------
 crates/burn-cube/src/compute/launcher.rs      |  6 +-
 crates/burn-cube/src/lib.rs                   |  8 ++-
 crates/burn-cube/src/runtime.rs               | 12 +++-
 crates/burn-cube/src/runtime_tests/cmma.rs    |  2 +-
 crates/burn-cube/src/runtime_tests/launch.rs  |  4 +-
 crates/burn-cube/src/runtime_tests/subcube.rs |  4 +-
 crates/burn-cuda/src/compute/server.rs        | 59 ++++++++++++++-----
 .../src/fusion/elemwise/optimization.rs       |  4 +-
 crates/burn-jit/src/fusion/kernel.rs          | 20 ++++---
 crates/burn-jit/src/kernel/index/scatter.rs   |  8 +--
 .../src/kernel/index/select_assign.rs         |  8 +--
 crates/burn-jit/src/kernel/matmul/base.rs     | 26 ++++----
 crates/burn-jit/src/kernel/matmul/simple.rs   |  8 +--
 crates/burn-jit/src/kernel/matmul/tiling2d.rs |  4 +-
 .../tiling2d_shader/shader_information.rs     | 10 ++--
 crates/burn-jit/src/kernel/prng/base.rs       | 40 +++++++------
 .../src/kernel/reduce/shared/shader.rs        | 46 +++++++--------
 crates/burn-jit/src/lib.rs                    |  6 +-
 crates/burn-jit/src/template/base.rs          | 16 +----
 crates/burn-wgpu/src/compute/server.rs        | 33 +++++++++--
 crates/burn-wgpu/src/compute/storage.rs       |  3 +-
 examples/custom-wgpu-kernel/src/forward.rs    |  6 +-
 examples/gelu/src/lib.rs                      |  2 +-
 37 files changed, 293 insertions(+), 215 deletions(-)

diff --git a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
index fda6dcea5..b03f146db 100644
--- a/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
+++ b/burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
@@ -248,11 +248,12 @@ impl Backend for JitBackend {
         // Declare the wgsl workgroup with the number of blocks in x, y and z.
         let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
         let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
-        let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
+        let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

         // Execute lazily the kernel with the launch information and the given buffers.
         lhs.client.execute(
-            Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
+            Box::new(SourceKernel::new(kernel, cube_dim)),
+            cube_count,
             vec![
                 lhs.handle.binding(),
                 rhs.handle.binding(),
diff --git a/crates/burn-compute/src/channel/base.rs b/crates/burn-compute/src/channel/base.rs
index 0a6910337..0fa266159 100644
--- a/crates/burn-compute/src/channel/base.rs
+++ b/crates/burn-compute/src/channel/base.rs
@@ -24,7 +24,12 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send
     fn empty(&self, size: usize) -> Handle;

     /// Executes the `kernel` over the given `bindings`.
-    fn execute(&self, kernel: Server::Kernel, bindings: Vec>);
+    fn execute(
+        &self,
+        kernel: Server::Kernel,
+        count: Server::DispatchOptions,
+        bindings: Vec>,
+    );

     /// Perform some synchronization of commands on the server.
     fn sync(&self, sync_type: SyncType);
diff --git a/crates/burn-compute/src/channel/cell.rs b/crates/burn-compute/src/channel/cell.rs
index 631947966..d80a44b9f 100644
--- a/crates/burn-compute/src/channel/cell.rs
+++ b/crates/burn-compute/src/channel/cell.rs
@@ -63,10 +63,15 @@ where
         self.server.borrow_mut().empty(size)
     }

-    fn execute(&self, kernel_description: Server::Kernel, bindings: Vec>) {
+    fn execute(
+        &self,
+        kernel_description: Server::Kernel,
+        count: Server::DispatchOptions,
+        bindings: Vec>,
+    ) {
         self.server
             .borrow_mut()
-            .execute(kernel_description, bindings)
+            .execute(kernel_description, count, bindings)
     }

     fn sync(&self, sync_type: SyncType) {
diff --git a/crates/burn-compute/src/channel/mpsc.rs b/crates/burn-compute/src/channel/mpsc.rs
index 8eeb04326..c7c22659f 100644
--- a/crates/burn-compute/src/channel/mpsc.rs
+++ b/crates/burn-compute/src/channel/mpsc.rs
@@ -39,7 +39,10 @@ where
     ),
     Create(Vec, Callback>),
     Empty(usize, Callback>),
-    ExecuteKernel(Server::Kernel, Vec>),
+    ExecuteKernel(
+        (Server::Kernel, Server::DispatchOptions),
+        Vec>,
+    ),
     Sync(SyncType, Callback<()>),
 }
@@ -74,7 +77,7 @@ where
                 callback.send(handle).await.unwrap();
             }
             Message::ExecuteKernel(kernel, bindings) => {
-                server.execute(kernel, bindings);
+                server.execute(kernel.0, kernel.1, bindings);
             }
             Message::Sync(sync_type, callback) => {
                 server.sync(sync_type);
@@ -148,10 +151,15 @@ where
         handle_response(response.recv_blocking())
     }

-    fn execute(&self, kernel: Server::Kernel, bindings: Vec>) {
+    fn execute(
+        &self,
+        kernel: Server::Kernel,
+        count: Server::DispatchOptions,
+        bindings: Vec>,
+    ) {
         self.state
             .sender
-            .send_blocking(Message::ExecuteKernel(kernel, bindings))
+            .send_blocking(Message::ExecuteKernel((kernel, count), bindings))
             .unwrap()
     }
diff --git a/crates/burn-compute/src/channel/mutex.rs b/crates/burn-compute/src/channel/mutex.rs
index a063ab1f1..1eeb1bf37 100644
--- a/crates/burn-compute/src/channel/mutex.rs
+++ b/crates/burn-compute/src/channel/mutex.rs
@@ -56,8 +56,13 @@ where
         self.server.lock().empty(size)
     }

-    fn execute(&self, kernel: Server::Kernel, handles: Vec>) {
-        self.server.lock().execute(kernel, handles)
+    fn execute(
+        &self,
+        kernel: Server::Kernel,
+        count: Server::DispatchOptions,
+        handles: Vec>,
+    ) {
+        self.server.lock().execute(kernel, count, handles)
     }

     fn sync(&self, sync_type: SyncType) {
diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs
index bd085712d..2c0bb14ad 100644
--- a/crates/burn-compute/src/client.rs
+++ b/crates/burn-compute/src/client.rs
@@ -82,8 +82,13 @@ where
     }

     /// Executes the `kernel` over the given `bindings`.
-    pub fn execute(&self, kernel: Server::Kernel, bindings: Vec>) {
-        self.channel.execute(kernel, bindings)
+    pub fn execute(
+        &self,
+        kernel: Server::Kernel,
+        count: Server::DispatchOptions,
+        bindings: Vec>,
+    ) {
+        self.channel.execute(kernel, count, bindings)
     }

     /// Wait for the completion of every task in the server.
diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs
index b6aa2eb9c..948fa24e9 100644
--- a/crates/burn-compute/src/server.rs
+++ b/crates/burn-compute/src/server.rs
@@ -17,6 +17,8 @@ where
 {
     /// The kernel type defines the computation algorithms.
     type Kernel: Send;
+    /// Options when dispatching the kernel, e.g. the number of executions.
+    type DispatchOptions: Send;
     /// The [storage](ComputeStorage) type defines how data is stored and accessed.
     type Storage: ComputeStorage;
     /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
@@ -45,7 +47,12 @@ where
     ///
     /// Kernels have mutable access to every resource they are given
     /// and are responsible for determining which should be read or written.
-    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>);
+    fn execute(
+        &mut self,
+        kernel: Self::Kernel,
+        count: Self::DispatchOptions,
+        bindings: Vec>,
+    );

     /// Wait for the completion of every task in the server.
     fn sync(&mut self, command: SyncType);
diff --git a/crates/burn-compute/tests/dummy/server.rs b/crates/burn-compute/tests/dummy/server.rs
index eb4a6ecd8..5e786fe16 100644
--- a/crates/burn-compute/tests/dummy/server.rs
+++ b/crates/burn-compute/tests/dummy/server.rs
@@ -21,6 +21,7 @@ impl ComputeServer for DummyServer
 where
     MM: MemoryManagement,
 {
+    type DispatchOptions = ();
     type Kernel = Arc;
     type Storage = BytesStorage;
     type MemoryManagement = MM;
@@ -53,7 +54,12 @@ where
         Handle::new(self.memory_management.reserve(size, || {}))
     }

-    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>) {
+    fn execute(
+        &mut self,
+        kernel: Self::Kernel,
+        _count: Self::DispatchOptions,
+        bindings: Vec>,
+    ) {
         let mut resources = bindings
             .into_iter()
             .map(|binding| self.memory_management.get(binding.memory))
diff --git a/crates/burn-compute/tests/dummy/tune/autotune_operations.rs b/crates/burn-compute/tests/dummy/tune/autotune_operations.rs
index b4abad81f..97e1a132f 100644
--- a/crates/burn-compute/tests/dummy/tune/autotune_operations.rs
+++ b/crates/burn-compute/tests/dummy/tune/autotune_operations.rs
@@ -18,7 +18,7 @@ pub struct OneKernelAutotuneOperation {
 impl AutotuneOperation for OneKernelAutotuneOperation {
     /// Executes the operation on given bindings and server, with the additional parameters
     fn execute(self: Box) {
-        self.client.execute(self.kernel.clone(), self.bindings);
+        self.client.execute(self.kernel.clone(), (), self.bindings);
     }

     fn clone(&self) -> Box {
diff --git a/crates/burn-compute/tests/integration_test.rs b/crates/burn-compute/tests/integration_test.rs
index b134090b8..db6927b45 100644
--- a/crates/burn-compute/tests/integration_test.rs
+++ b/crates/burn-compute/tests/integration_test.rs
@@ -38,6 +38,7 @@ fn execute_elementwise_addition() {

     client.execute(
         Arc::new(DummyElementwiseAddition),
+        (),
         vec![lhs.binding(), rhs.binding(), out.clone().binding()],
     );
diff --git a/crates/burn-cube-macros/src/codegen_function/launch.rs b/crates/burn-cube-macros/src/codegen_function/launch.rs
index 016bdf728..f87de3a3c 100644
--- a/crates/burn-cube-macros/src/codegen_function/launch.rs
+++ b/crates/burn-cube-macros/src/codegen_function/launch.rs
@@ -470,7 +470,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
         /// Launch
         pub fn #ident #generics (
             client: ComputeClient,
-            cube_count: CubeCount,
+            cube_count: CubeCount,
             cube_dim: CubeDim,
             #inputs
         ) -> #output {
diff --git a/crates/burn-cube/src/codegen/execution.rs b/crates/burn-cube/src/codegen/execution.rs
index 59195ca66..d9fe2cb5f 100644
--- a/crates/burn-cube/src/codegen/execution.rs
+++ b/crates/burn-cube/src/codegen/execution.rs
@@ -4,13 +4,13 @@ use crate::ir::Elem;
 use crate::pod::CubeElement;
 use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
 use burn_compute::client::ComputeClient;
-use burn_compute::server::{Binding, Handle};
+use burn_compute::server::{Binding, ComputeServer, Handle};

-/// The position of the input or output to calculate the number of workgroups to launch.
-pub enum CubeCountSettings {
+/// The position of the input or output to calculate the number of cubes to launch.
+pub enum CubeCountSettings {
     Input { pos: usize },
     Output { pos: usize },
-    Custom(CubeCount),
+    Custom(CubeCount),
 }

 pub struct Execution<'h, K, R: Runtime, Scalars> {
@@ -73,7 +73,7 @@ where
     }

     /// Execute a dynamic kernel.
     #[allow(unused)]
-    pub fn execute(self, launch: CubeCountSettings) {
+    pub fn execute(self, launch: CubeCountSettings) {
         execute_dynamic::(
             self.inputs,
             self.outputs,
@@ -108,7 +108,7 @@ where

     /// Execute a dynamic kernel.
     #[allow(unused)]
-    pub fn execute(self, launch: CubeCountSettings) {
+    pub fn execute(self, launch: CubeCountSettings) {
         execute_dynamic::(
             self.inputs,
             self.outputs,
@@ -144,7 +144,7 @@ where
     }

     /// Execute a dynamic kernel.
     #[allow(clippy::too_many_arguments)]
-    pub fn execute(self, launch: CubeCountSettings)
+    pub fn execute(self, launch: CubeCountSettings)
     where
         K: Kernel + 'static,
         R: Runtime,
@@ -172,7 +172,7 @@ {
     /// Execute a dynamic kernel.
     #[allow(unused)]
-    pub fn execute(self, launch: CubeCountSettings) {
+    pub fn execute(self, launch: CubeCountSettings) {
         execute_dynamic::(
             self.inputs,
             self.outputs,
@@ -194,7 +194,7 @@ fn execute_dynamic(
     scalars_2: Option<&[E2]>,
     scalars_3: Option<&[E3]>,
     kernel: K,
-    launch: CubeCountSettings,
+    launch: CubeCountSettings,
     client: ComputeClient,
 ) where
     K: Kernel + 'static,
@@ -207,23 +207,21 @@ fn execute_dynamic(
         inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
     );

     let mut handles = settings.handles_tensors;
-    let workgroup = settings.cube_count;
     handles.push(settings.handle_info.binding());
     for handle in settings.handles_scalars.into_iter() {
         handles.push(handle.binding());
     }

-    let kernel = Box::new(KernelTask::::new(kernel, workgroup));
-
-    client.execute(kernel, handles);
+    let kernel = Box::new(KernelTask::::new(kernel));
+    client.execute(kernel, settings.cube_count, handles);
 }

 struct ExecuteSettings {
     handles_tensors: Vec>,
     handle_info: Handle,
     handles_scalars: Vec>,
-    cube_count: CubeCount,
+    cube_count: CubeCount,
 }

 fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
@@ -232,7 +230,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
     scalars_1: Option<&[E1]>,
     scalars_2: Option<&[E2]>,
     scalars_3: Option<&[E3]>,
-    launch: CubeCountSettings,
+    launch: CubeCountSettings,
     client: &ComputeClient,
 ) -> ExecuteSettings {
     let mut info = Vec::new();
@@ -295,8 +293,8 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
     let handles_scalars =
         create_scalar_handles::(scalars_1, scalars_2, scalars_3, client);

-    let workgroup = match launch {
-        CubeCountSettings::Custom(workgroup) => workgroup,
+    let cube_count = match launch {
+        CubeCountSettings::Custom(count) => count,
         _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
     };
@@ -304,7 +302,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
         handles_tensors: handles,
         handle_info: info,
         handles_scalars,
-        cube_count: workgroup,
+        cube_count,
     }
 }
diff --git a/crates/burn-cube/src/codegen/integrator.rs b/crates/burn-cube/src/codegen/integrator.rs
index 4832905e5..d65df226b 100644
--- a/crates/burn-cube/src/codegen/integrator.rs
+++ b/crates/burn-cube/src/codegen/integrator.rs
@@ -74,9 +74,9 @@ impl core::fmt::Display for KernelSettings {
     // * Vectorization Global: vg{factor}
     // * Vectorization Partial Input: v{factor}i{pos}
     // * Vectorization Partial Output: vo
-    // * Workgroup Size X: x
-    // * Workgroup Size Y: y
-    // * Workgroup Size Z: z
+    // * Cube Dim X: x
+    // * Cube Dim Y: y
+    // * Cube Dim Z: z
     f.write_str("m")?;
     for mapping in self.mappings.iter() {
         f.write_fmt(format_args!(
diff --git a/crates/burn-cube/src/compute/kernel.rs b/crates/burn-cube/src/compute/kernel.rs
index e6e751508..356ecb4e0 100644
--- a/crates/burn-cube/src/compute/kernel.rs
+++ b/crates/burn-cube/src/compute/kernel.rs
@@ -1,41 +1,32 @@
+use std::marker::PhantomData;
+
 use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
 use alloc::sync::Arc;
-use std::marker::PhantomData;
+use burn_compute::server::{Binding, ComputeServer};

 /// A kernel, compiled in the target language
 pub struct CompiledKernel {
     /// Source code of the kernel
     pub source: String,
-    /// Size of a workgroup for the compiled kernel
+    /// Size of a cube for the compiled kernel
     pub cube_dim: CubeDim,
     /// The number of bytes used by the shared memory
     pub shared_mem_bytes: usize,
 }

-/// Information needed to launch the kernel
-pub struct LaunchSettings {
-    /// Layout of workgroups for the kernel
-    pub cube_count: CubeCount,
-}
-
 /// Kernel trait with the ComputeShader that will be compiled and cached based on the
 /// provided id.
-///
-/// The kernel will be launched with the given [launch settings](LaunchSettings).
 pub trait CubeTask: Send + Sync {
     /// Identifier for the kernel, used for caching kernel compilation.
     fn id(&self) -> String;
     /// Compile the kernel into source
     fn compile(&self) -> CompiledKernel;
-    /// Launch settings.
-    fn launch_settings(&self) -> LaunchSettings;
 }

-/// Wraps a [kernel](Kernel) with its [cube count](CubeCount) to create a [cube task](CubeTask).
+/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
 #[derive(new)]
 pub struct KernelTask {
     kernel_definition: K,
-    cube_count: CubeCount,
     _compiler: PhantomData,
 }
@@ -57,12 +48,6 @@ impl CubeTask for KernelTask {
     fn id(&self) -> String {
         self.kernel_definition.id().clone()
     }
-
-    fn launch_settings(&self) -> LaunchSettings {
-        LaunchSettings {
-            cube_count: self.cube_count.clone(),
-        }
-    }
 }

 impl CubeTask for Arc {
@@ -73,10 +58,6 @@ impl CubeTask for Arc {
     fn id(&self) -> String {
         self.as_ref().id()
     }
-
-    fn launch_settings(&self) -> LaunchSettings {
-        self.as_ref().launch_settings()
-    }
 }

 impl CubeTask for Box {
@@ -87,26 +68,21 @@ impl CubeTask for Box {
     fn id(&self) -> String {
         self.as_ref().id()
     }
-
-    fn launch_settings(&self) -> LaunchSettings {
-        self.as_ref().launch_settings()
-    }
 }

 /// Provides launch information specifying the number of work groups to be used by a compute shader.
-#[derive(new, Clone, Debug)]
-pub struct CubeCount {
-    /// Work groups for the x axis.
-    pub x: u32,
-    /// Work groups for the y axis.
-    pub y: u32,
-    /// Work groups for the z axis.
-    pub z: u32,
+pub enum CubeCount {
+    /// Dispatch x,y,z work groups.
+    Static(u32, u32, u32),
+    /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
+    Dynamic(Binding),
 }

-impl CubeCount {
-    /// Calculate the number of invocations of a compute shader.
-    pub fn num_invocations(&self) -> usize {
-        (self.x * self.y * self.z) as usize
+impl Clone for CubeCount {
+    fn clone(&self) -> Self {
+        match self {
+            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
+            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
+        }
     }
 }
diff --git a/crates/burn-cube/src/compute/launcher.rs b/crates/burn-cube/src/compute/launcher.rs
index d51640873..7bc292229 100644
--- a/crates/burn-cube/src/compute/launcher.rs
+++ b/crates/burn-cube/src/compute/launcher.rs
@@ -78,15 +78,15 @@ impl KernelLauncher {
     /// Launch the kernel.
     pub fn launch(
         self,
-        cube_count: CubeCount,
+        cube_count: CubeCount,
         kernel: K,
         client: ComputeClient,
     ) {
         let bindings = self.into_bindings(&client);

-        let kernel = Box::new(KernelTask::::new(kernel, cube_count));
+        let kernel = Box::new(KernelTask::::new(kernel));

-        client.execute(kernel, bindings);
+        client.execute(kernel, cube_count, bindings);
     }

     /// We need to create the bindings in the same order they are defined in the compilation step.
diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs
index 1d0687403..bf2372b72 100644
--- a/crates/burn-cube/src/lib.rs
+++ b/crates/burn-cube/src/lib.rs
@@ -6,6 +6,7 @@ extern crate derive_new;

 /// Cube Frontend Types.
 pub mod frontend;
+use burn_compute::server::ComputeServer;

 pub use frontend::cmma;

 /// Cube Language Internal Representation.
@@ -45,13 +46,16 @@ pub trait Kernel: Send + Sync + 'static {

 /// Calculate the number of cubes required to execute an operation where one cube unit is
 /// assigned to one element.
-pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: usize) -> CubeCount {
+pub fn calculate_cube_count_elemwise(
+    num_elems: usize,
+    cube_dim: usize,
+) -> CubeCount {
     let num_elems_per_cube = cube_dim * cube_dim;
     let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32);
     let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
     let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));

-    CubeCount::new(cube_count_x as u32, cube_count_y as u32, 1)
+    CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
 }

 pub fn tensor_vectorization_factor(
diff --git a/crates/burn-cube/src/runtime.rs b/crates/burn-cube/src/runtime.rs
index ea962b726..2975a3f57 100644
--- a/crates/burn-cube/src/runtime.rs
+++ b/crates/burn-cube/src/runtime.rs
@@ -1,4 +1,8 @@
-use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
+use crate::{
+    codegen::Compiler,
+    compute::{CubeCount, CubeTask},
+    ir::Elem,
+};
 use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};

 /// Runtime for the CubeCL.
@@ -6,7 +10,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
     /// The compiler used to compile the inner representation into tokens.
     type Compiler: Compiler;
     /// The compute server used to run kernels and perform autotuning.
-    type Server: ComputeServer, FeatureSet = FeatureSet>;
+    type Server: ComputeServer<
+        Kernel = Box,
+        DispatchOptions = CubeCount,
+        FeatureSet = FeatureSet,
+    >;
     /// The channel used to communicate with the compute server.
     type Channel: ComputeChannel;
     /// The device used to retrieve the compute client.
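As a usage sketch of the two dispatch modes before the runtime tests below (not part of the patch; the kernel `demo_kernel_launch`, its array argument, and the client setup are assumed, following the launch signature generated by `#[cube(launch)]`):

    // Host-side count, fixed when the launch is recorded.
    demo_kernel_launch::<TestRuntime>(
        client.clone(),
        CubeCount::Static(16, 16, 1),
        CubeDim::default(),
        ArrayArg::new(&handle, 2),
    );

    // GPU-side count: a buffer holding a u32 array [x, y, z]. It is uploaded
    // from the host here for brevity, but would typically be written by an
    // earlier kernel, which is the point of dynamic dispatch.
    let counts = client.create(bytemuck::cast_slice(&[16u32, 16, 1]));
    demo_kernel_launch::<TestRuntime>(
        client.clone(),
        CubeCount::Dynamic(counts.binding()),
        CubeDim::default(),
        ArrayArg::new(&handle, 2),
    );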
diff --git a/crates/burn-cube/src/runtime_tests/cmma.rs b/crates/burn-cube/src/runtime_tests/cmma.rs
index 08da16d24..2fa312218 100644
--- a/crates/burn-cube/src/runtime_tests/cmma.rs
+++ b/crates/burn-cube/src/runtime_tests/cmma.rs
@@ -61,7 +61,7 @@ pub fn test_simple_1(client: ComputeClient) {

     kernel_simple_1_launch::(
         client.clone(),
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::new(16, 16, 1),
         ArrayArg::new(&lhs, 256),
         ArrayArg::new(&rhs, 256),
diff --git a/crates/burn-cube/src/runtime_tests/launch.rs b/crates/burn-cube/src/runtime_tests/launch.rs
index 7cadafb36..158ca490b 100644
--- a/crates/burn-cube/src/runtime_tests/launch.rs
+++ b/crates/burn-cube/src/runtime_tests/launch.rs
@@ -20,7 +20,7 @@ pub fn test_kernel_with_generics(client: ComputeClient(
         client.clone(),
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&handle, 2),
     );
@@ -36,7 +36,7 @@ pub fn test_kernel_without_generics(client: ComputeClient(
         client.clone(),
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&handle, 2),
     );
diff --git a/crates/burn-cube/src/runtime_tests/subcube.rs b/crates/burn-cube/src/runtime_tests/subcube.rs
index bde884d10..6cfe6f5db 100644
--- a/crates/burn-cube/src/runtime_tests/subcube.rs
+++ b/crates/burn-cube/src/runtime_tests/subcube.rs
@@ -98,7 +98,7 @@ fn test_subcube_operation(
     client: ComputeClient,
     launch: Launch,
 ) where
-    Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
+    Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
 {
     if !client.features().enabled(Feature::Subcube) {
         // Can't execute the test.
@@ -109,7 +109,7 @@ fn test_subcube_operation(
     let (shape, strides) = ([input.len()], [1]);

     launch(
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::new(input.len() as u32, 1, 1),
         TensorArg::new(&handle, &strides, &shape),
     );
diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs
index 4e1c8cca5..d147f2597 100644
--- a/crates/burn-cuda/src/compute/server.rs
+++ b/crates/burn-cuda/src/compute/server.rs
@@ -57,14 +57,8 @@ struct CompiledKernel {

 unsafe impl> Send for CudaServer {}

-impl> ComputeServer for CudaServer {
-    type Kernel = Box;
-    type Storage = CudaStorage;
-    type MemoryManagement = MM;
-    type AutotuneKey = JitAutotuneKey;
-    type FeatureSet = FeatureSet;
-
-    fn read(&mut self, binding: server::Binding) -> Reader {
+impl> CudaServer {
+    fn read_sync(&mut self, binding: server::Binding) -> Vec {
         let ctx = self.get_context();
         let resource = ctx.memory_management.get(binding.memory);
@@ -74,7 +68,20 @@ impl> ComputeServer for CudaServer {
             cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
         };

         ctx.sync();
-        reader_from_concrete(data)
+        data
+    }
+}
+
+impl> ComputeServer for CudaServer {
+    type Kernel = Box;
+    type DispatchOptions = CubeCount;
+    type Storage = CudaStorage;
+    type MemoryManagement = MM;
+    type AutotuneKey = JitAutotuneKey;
+    type FeatureSet = FeatureSet;
+
+    fn read(&mut self, binding: server::Binding) -> Reader {
+        reader_from_concrete(self.read_sync(binding))
     }

     fn create(&mut self, data: &[u8]) -> server::Handle {
@@ -101,12 +108,33 @@ impl> ComputeServer for CudaServer {
         server::Handle::new(handle)
     }

-    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>) {
+    fn execute(
+        &mut self,
+        kernel: Self::Kernel,
+        count: Self::DispatchOptions,
+        bindings: Vec>,
+    ) {
         let arch = self.minimum_arch_version;
-        let ctx = self.get_context();
         let kernel_id = kernel.id();
-        let settings = kernel.launch_settings();
+
+        let count = match count {
+            CubeCount::Static(x, y, z) => (x, y, z),
+            // TODO: CUDA doesn't have an exact equivalent of dynamic dispatch. Instead, kernels are free to launch other kernels.
+            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
+            // For now, just read the dispatch settings from the buffer.
+            CubeCount::Dynamic(binding) => {
+                let data = self.read_sync(binding);
+                let data = bytemuck::cast_slice(&data);
+                assert!(
+                    data.len() == 3,
+                    "Dynamic cube count should contain 3 values"
+                );
+                (data[0], data[1], data[2])
+            }
+        };
+
+        let ctx = self.get_context();

         if !ctx.module_names.contains_key(&kernel_id) {
             ctx.compile_kernel(&kernel_id, kernel, arch);
@@ -117,7 +145,7 @@
             .map(|binding| ctx.memory_management.get(binding.memory).as_binding())
             .collect();

-        ctx.execute_task(kernel_id, settings.cube_count, bindings);
+        ctx.execute_task(kernel_id, count, bindings);
         // TODO: fix this
         // self.memory_management.storage().perform_deallocations();
     }
@@ -217,16 +245,15 @@ impl> CudaContext {
     fn execute_task(
         &mut self,
         kernel_id: String,
-        cube_count: CubeCount,
+        dispatch_count: (u32, u32, u32),
         mut bindings: Vec,
     ) {
         let kernel = self.module_names.get(&kernel_id).unwrap();
         let cube_dim = kernel.cube_dim;
-
         unsafe {
             cudarc::driver::result::launch_kernel(
                 kernel.func,
-                (cube_count.x, cube_count.y, cube_count.z),
+                dispatch_count,
                 (cube_dim.x, cube_dim.y, cube_dim.z),
                 kernel.shared_mem_bytes as u32,
                 self.stream,
diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs
index a5312c9c7..86d641a33 100644
--- a/crates/burn-jit/src/fusion/elemwise/optimization.rs
+++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs
@@ -29,9 +29,9 @@ pub struct CompilationPhase;

 /// Phase where the kernel should be executed.
 #[derive(new)]
 pub struct ExecutionPhase {
-    /// Kernel set with default workgroup size.
+    /// Kernel set with default cube size.
     pub(super) kernel_factory_1: ElementWiseKernelFactory,
-    /// Kernel set with custom workgroup size.
+    /// Kernel set with custom cube size.
     pub(super) kernel_factory_2: ElementWiseKernelFactory,
 }
diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs
index c7d95e601..c7a766ac1 100644
--- a/crates/burn-jit/src/fusion/kernel.rs
+++ b/crates/burn-jit/src/fusion/kernel.rs
@@ -22,7 +22,7 @@ pub struct FusionKernel {
     info: Arc,
     settings: KernelSettings,
     runtime_info: Vec,
-    cube_count: CubeCount,
+    cube_count: CubeCount,
     _runtime: PhantomData,
 }
@@ -41,6 +41,7 @@ pub trait FusionKernelFactory {
 #[derive(new)]
 pub struct ExecutableKernel {
     kernel: Box,
+    cube_count: CubeCount,
     bindings: Vec>,
     client: ComputeClient,
 }
@@ -54,6 +55,7 @@ pub struct ExecutableKernel {
 #[derive(new)]
 pub struct AutotunableKernel {
     kernel: Arc,
+    count: CubeCount,
     bindings: Vec>,
     client: ComputeClient,
 }
@@ -68,18 +70,21 @@ pub enum OutputRuntimeInfo {

 impl ExecutableKernel {
     /// Execute the kernel.
     pub fn execute(self) {
-        self.client.execute(self.kernel, self.bindings)
+        self.client
+            .execute(self.kernel, self.cube_count, self.bindings)
     }
 }

 impl AutotuneOperation for AutotunableKernel {
     fn execute(self: Box) {
-        self.client.execute(Box::new(self.kernel), self.bindings)
+        self.client
+            .execute(Box::new(self.kernel), self.count, self.bindings)
     }

     fn clone(&self) -> Box {
         Box::new(Self {
             kernel: self.kernel.clone(),
+            count: self.count.clone(),
             bindings: self.bindings.clone(),
             client: self.client.clone(),
         })
@@ -90,6 +95,7 @@ impl From> for AutotunableKernel {
     fn from(value: ExecutableKernel) -> Self {
         Self {
             kernel: Arc::new(value.kernel),
+            count: value.cube_count.clone(),
             bindings: value.bindings,
             client: value.client,
         }
@@ -233,12 +239,10 @@ impl FusionKernel {
             context.handles.register_handle(id, handle);
         }

-        let workgroup = fusion_kernel.cube_count.clone();
+        let cube_count = fusion_kernel.cube_count.clone();

         ExecutableKernel::new(
-            Box::new(KernelTask::>::new(
-                fusion_kernel,
-                workgroup,
-            )),
+            Box::new(KernelTask::::new(fusion_kernel)),
+            cube_count,
             bindings,
             client,
         )
diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs
index 0845a0eba..7a795e18d 100644
--- a/crates/burn-jit/src/kernel/index/scatter.rs
+++ b/crates/burn-jit/src/kernel/index/scatter.rs
@@ -196,7 +196,7 @@ pub(crate) fn scatter::new(dim);
     let mut strides = [0; D];
     let mut current = 1;
-    let mut num_elems_per_workgroup = 1;
+    let mut num_elems = 1;

     tensor
         .shape
@@ -208,13 +208,13 @@ pub(crate) fn scatter::new(dim);

-    let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
+    let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);

     Execution::start(kernel, indices.client)
         .inputs(&[
@@ -215,7 +215,7 @@ pub(crate) fn select_assign(
 }
 }

-pub(crate) fn simple_launch_options(
+pub(crate) fn simple_cube_count(
     lhs_shape: &Shape,
     rhs_shape: &Shape,
     output_shape: &Shape,
-    workgroup_size_x: usize,
-    workgroup_size_y: usize,
-) -> CubeCount {
+    cube_dim_x: usize,
+    cube_dim_y: usize,
+) -> CubeCount {
     let num_rows = lhs_shape.dims[D - 2];
     let num_cols = rhs_shape.dims[D - 1];

-    // set number of workgroups
-    let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
-    let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
+    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
+    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;

     let mut num_iter = 1;
     for i in 0..D - 2 {
         num_iter *= output_shape.dims[i];
     }

-    CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
+    CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
 }

-pub(crate) fn tiling2d_launch_options(
+pub(crate) fn tiling2d_launch_options(
     output_shape: &Shape,
     config: Tiling2dConfig,
-) -> CubeCount {
+) -> CubeCount {
     let num_rows = output_shape.dims[D - 2];
     let num_cols = output_shape.dims[D - 1];

-    // set number of workgroups
-    let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
-    let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
+    let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
+    let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;

     let mut num_iter = 1;
     for i in 0..D - 2 {
         num_iter *= output_shape.dims[i];
     }

-    CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
+    CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
 }
diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs
index 57c2fa9f3..16d92516d 100644
--- a/crates/burn-jit/src/kernel/matmul/simple.rs
+++ b/crates/burn-jit/src/kernel/matmul/simple.rs
@@ -7,7 +7,7 @@ use crate::{
 use burn_cube::ir::KernelDefinition;
 use burn_cube::{frontend::TensorArg, KernelSettings};

-use super::simple_launch_options;
+use super::simple_cube_count;
 use burn_cube::prelude::*;

 #[cube(launch)]
@@ -80,7 +80,7 @@ fn matmul_kernel(
     }
 }

-/// Matrix multiplication using memory coalescing algorithm with workgroups of size 16
+/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16
 pub fn matmul_mem_coalescing_default(
     lhs: JitTensor,
     rhs: JitTensor,
@@ -89,7 +89,7 @@ pub fn matmul_mem_coalescing_default(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX)
 }

-/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
+/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
 pub fn matmul_simple(
     lhs: JitTensor,
     rhs: JitTensor,
@@ -103,7 +103,7 @@ pub fn matmul_simple(
     let rhs_original_shape = rhs.shape.clone();
     let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2));

-    let cube_count = simple_launch_options(
+    let cube_count = simple_cube_count::(
         &lhs.shape,
         &rhs_original_shape,
         &out.shape,
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
index 3fd0e3618..dea73cd02 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs
@@ -118,7 +118,7 @@ pub fn matmul_tiling_2d(
         &out.strides,
         &out.shape.dims,
     )])
-    .execute(CubeCountSettings::Custom(tiling2d_launch_options(
+    .execute(CubeCountSettings::Custom(tiling2d_launch_options::(
         &out.shape,
         config,
     )));
@@ -175,7 +175,7 @@ pub fn matmul_tiling_2d_padded(
         &rounded_output.shape,
         config,
     )));
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
index 0b76238af..903321fe5 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs
@@ -55,14 +55,14 @@ pub(crate) fn gather_shader_information(
     cpa!(scope, out_stride_row = stride(out, second_to_last_dim));
     cpa!(scope, out_stride_col = stride(out, last_dim));

-    // Workgroup offset
+    // Cube offset
     let skip_row = scope.create_local(Elem::UInt);
     let skip_col = scope.create_local(Elem::UInt);
-    let workgroup_id_x = Variable::CubePosX;
-    let workgroup_id_y = Variable::CubePosY;
-    cpa!(scope, skip_row = workgroup_id_x);
+    let cube_pos_x = Variable::CubePosX;
+    let cube_pos_y = Variable::CubePosY;
+    cpa!(scope, skip_row = cube_pos_x);
     cpa!(scope, skip_row *= block_size_m);
-    cpa!(scope, skip_col = workgroup_id_y);
+    cpa!(scope, skip_col = cube_pos_y);
     cpa!(scope, skip_col *= block_size_n);

     // Position of the first element of the thread, relative to the block
diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs
index 852462d75..bf04e33e2 100644
--- a/crates/burn-jit/src/kernel/prng/base.rs
+++ b/crates/burn-jit/src/kernel/prng/base.rs
@@ -38,7 +38,7 @@ pub(crate) fn random, R: JitRuntime, E: JitElement, const D: usize>(
     )])
     .with_scalars(&seeds)
     .with_scalars(&prng.args())
-    .execute(CubeCountSettings::Custom(prng_cube_count(
+    .execute(CubeCountSettings::Custom(prng_cube_count::(
         num_elems,
         SUBCUBE_DIM_APPROX,
         N_VALUES_PER_THREAD,
     )));

     output
 }

-fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount {
+fn prng_cube_count(
+    num_elems: usize,
+    cube_dim: usize,
+    n_values_per_thread: usize,
+) -> CubeCount {
     let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
     let num_elems_per_cube = cube_dim * cube_dim;
     let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
-    let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
-    let workgroup_y = f32::ceil(num_invocations / workgroup_x);
+    let cubes_x = f32::ceil(f32::sqrt(num_invocations));
+    let cubes_y = f32::ceil(num_invocations / cubes_x);

-    CubeCount::new(workgroup_x as u32, workgroup_y as u32, 1)
+    CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
 }

 impl, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel {
@@ -163,24 +167,24 @@ impl, E: JitElement> PrngShader {
         let n_values_per_thread: Variable = self.n_values_per_thread.into();
         let args = self.args;

-        let workgroup_size_x = Variable::CubeDimX;
-        let workgroup_size_y = Variable::CubeDimY;
-        let workgroup_id_x = Variable::CubePosX;
-        let workgroup_id_y = Variable::CubePosY;
-        let num_workgroups_y = Variable::CubeCountY;
+        let cube_dim_x = Variable::CubeDimX;
+        let cube_dim_y = Variable::CubeDimY;
+        let cube_pos_x = Variable::CubePosX;
+        let cube_pos_y = Variable::CubePosY;
+        let cube_count_y = Variable::CubeCountY;
         let local_index = Variable::UnitPos;

         let n_invocations = scope.create_local(Elem::UInt);
-        cpa!(scope, n_invocations = workgroup_size_x);
-        cpa!(scope, n_invocations *= workgroup_size_y);
+        cpa!(scope, n_invocations = cube_dim_x);
+        cpa!(scope, n_invocations *= cube_dim_y);

-        let workgroup_offset = scope.create_local(Elem::UInt);
-        cpa!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
-        cpa!(scope, workgroup_offset += workgroup_id_y);
-        cpa!(scope, workgroup_offset *= n_invocations);
+        let cube_offset = scope.create_local(Elem::UInt);
+        cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
+        cpa!(scope, cube_offset += cube_pos_y);
+        cpa!(scope, cube_offset *= n_invocations);

         let write_index_base = scope.create_local(Elem::UInt);
-        cpa!(scope, write_index_base = workgroup_offset);
+        cpa!(scope, write_index_base = cube_offset);
         cpa!(scope, write_index_base *= n_values_per_thread);
         cpa!(scope, write_index_base += local_index);
@@ -188,7 +192,7 @@ impl, E: JitElement> PrngShader {
         let thread_seed = scope.create_local(Elem::UInt);
         cpa!(scope, thread_seed = cast(1000000007));
         let thread_seed_index = scope.create_local(Elem::UInt);
-        cpa!(scope, thread_seed_index = workgroup_offset + local_index);
+        cpa!(scope, thread_seed_index = cube_offset + local_index);
         cpa!(scope, thread_seed *= thread_seed_index);

         let state_0 = scope.create_local(Elem::UInt);
diff --git a/crates/burn-jit/src/kernel/reduce/shared/shader.rs b/crates/burn-jit/src/kernel/reduce/shared/shader.rs
index c5c98a4f9..c7704a591 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/shader.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/shader.rs
@@ -33,8 +33,8 @@ pub(crate) struct SharedReduceDimEagerKernel<
     EO: JitElement,
 > {
     dim: usize,
-    workgroup_size_x: usize,
-    workgroup_size_y: usize,
+    cube_dim_x: usize,
+    cube_dim_y: usize,
     n_input_values_per_thread: u32,
     divisible_shape: bool,
     _reduce_dim: PhantomData,
@@ -58,7 +58,7 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
         SharedReduceDimComputeShader {
             tensor,
             dim: self.dim,
-            shared_memory_size: self.workgroup_size_x * self.workgroup_size_y,
+            shared_memory_size: self.cube_dim_x * self.cube_dim_y,
             n_input_values_per_thread: self.n_input_values_per_thread,
             output,
             divisible_shape: self.divisible_shape,
@@ -83,8 +83,8 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
         };

         let settings = KernelSettings::default().cube_dim(CubeDim::new(
-            self.workgroup_size_x as u32,
-            self.workgroup_size_y as u32,
+            self.cube_dim_x as u32,
+            self.cube_dim_y as u32,
             1,
         ));
         KernelIntegrator::new(info).integrate(settings)
@@ -95,8 +95,8 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
             "{:?}dim={}x={}y={}n={}divshape={}",
             core::any::TypeId::of::(),
             self.dim,
-            self.workgroup_size_x,
-            self.workgroup_size_y,
+            self.cube_dim_x,
+            self.cube_dim_y,
             self.n_input_values_per_thread,
             self.divisible_shape
         )
@@ -111,13 +111,13 @@ impl> SharedReduceDimComputeShader
         let rank = Variable::Rank;
         let dim: Variable = self.dim.into();

-        let workgroup_id_x = Variable::CubePosX;
-        let workgroup_id_y = Variable::CubePosY;
-        let num_workgroups_x = Variable::CubeCountX;
+        let cube_pos_x = Variable::CubePosX;
+        let cube_pos_y = Variable::CubePosY;
+        let cube_count_x = Variable::CubeCountX;
         let local_invocation_id_x = Variable::UnitPosX;
         let local_invocation_id_y = Variable::UnitPosY;
-        let workgroup_size_x = Variable::CubeDimX;
-        let workgroup_size_y = Variable::CubeDimY;
+        let cube_dim_x = Variable::CubeDimX;
+        let cube_dim_y = Variable::CubeDimY;

         let stride_reduce_dim_input = scope.create_local(Elem::UInt);
         cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
@@ -126,16 +126,16 @@ impl> SharedReduceDimComputeShader

         // To determine which reduce_group (not position, but absolute id)
         let reduce_group_id = scope.create_local(Elem::UInt);
-        cpa!(scope, reduce_group_id = workgroup_id_y * num_workgroups_x);
-        cpa!(scope, reduce_group_id += workgroup_id_x);
+        cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
+        cpa!(scope, reduce_group_id += cube_pos_x);

-        // nth thread in the workgroup
+        // nth thread in the cube
         let local_id = scope.create_local(Elem::UInt);
-        cpa!(scope, local_id = local_invocation_id_y * workgroup_size_x);
+        cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
         cpa!(scope, local_id += local_invocation_id_x);

         let n_threads = scope.create_local(Elem::UInt);
-        cpa!(scope, n_threads = workgroup_size_x * workgroup_size_y);
+        cpa!(scope, n_threads = cube_dim_x * cube_dim_y);

         let index_offset = scope.zero(Elem::UInt);
@@ -242,17 +242,17 @@ pub fn reduce_dim_shared<
     dim: usize,
 ) -> JitTensor {
     let num_elems_output = output.shape.num_elements();
-    let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32));
-    let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x);
-    let grid = CubeCount::new(n_workgroups_x as u32, n_workgroups_y as u32, 1);
+    let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32));
+    let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x);
+    let grid = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1);

     let reduce_group_size = input.shape.dims[dim];
-    let n_invocation_per_workgroup = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
+    let n_invocation_per_cube = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
     let n_input_values_per_thread =
-        f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
+        f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
     let divisible_shape =
-        n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
+        n_invocation_per_cube as u32 * n_input_values_per_thread == reduce_group_size as u32;

     let kernel = SharedReduceDimEagerKernel::::new(
         dim,
diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs
index 9e3ed3e0b..450454b07 100644
--- a/crates/burn-jit/src/lib.rs
+++ b/crates/burn-jit/src/lib.rs
@@ -18,7 +18,10 @@ pub(crate) mod tune;

 /// Elements for JIT backend
 pub mod element;
-use burn_cube::{compute::CubeTask, Runtime};
+use burn_cube::{
+    compute::{CubeCount, CubeTask},
+    Runtime,
+};
 pub use element::{FloatElement, IntElement, JitElement};

 mod backend;
@@ -48,5 +51,6 @@ pub trait JitRuntime: Runtime,
+        DispatchOptions = CubeCount,
     >;
 }
diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs
index 253520dbd..e94515af0 100644
--- a/crates/burn-jit/src/template/base.rs
+++ b/crates/burn-jit/src/template/base.rs
@@ -1,5 +1,4 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
-use burn_cube::compute::LaunchSettings;
 use burn_cube::prelude::*;

 use super::SourceTemplate;
@@ -11,18 +10,13 @@ pub trait KernelSource: Send + 'static + Sync {
 }

 #[derive(new)]
-/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask) with launch
-/// information.
+/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).
 pub struct SourceKernel {
     kernel_source: K,
-    cube_count: CubeCount,
     cube_dim: CubeDim,
 }

-impl CubeTask for SourceKernel
-where
-    K: KernelSource + 'static,
-{
+impl CubeTask for SourceKernel {
     fn compile(&self) -> CompiledKernel {
         let source_template = self.kernel_source.source();
         let source = source_template.complete();
@@ -37,12 +31,6 @@ where
     fn id(&self) -> String {
         format!("{:?}", core::any::TypeId::of::())
     }
-
-    fn launch_settings(&self) -> LaunchSettings {
-        LaunchSettings {
-            cube_count: self.cube_count.clone(),
-        }
-    }
 }

 /// Generates kernel source code by replacing some information using templating.
diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs
index 45f0dbd00..b432d670d 100644
--- a/crates/burn-wgpu/src/compute/server.rs
+++ b/crates/burn-wgpu/src/compute/server.rs
@@ -64,8 +64,15 @@ where
         &mut self,
         pipeline: Arc,
         bind_group: BindGroup,
-        work_group: CubeCount,
+        count: CubeCount,
     ) {
+        // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
+        // needs to be longer than the compute pass, so we can't do this just before dispatching.
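+        // The dynamic count binding resolves to a buffer holding a u32 array
+        // [x, y, z]; dispatch_workgroups_indirect below reads the workgroup
+        // counts from that buffer on the GPU.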
+        let dispatch_resource = match count.clone() {
+            CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)),
+            _ => None,
+        };
+
         let mut compute = self
             .encoder
             .begin_compute_pass(&wgpu::ComputePassDescriptor {
@@ -75,12 +82,21 @@ where

         compute.set_pipeline(&pipeline);
         compute.set_bind_group(0, &bind_group, &[]);
-        compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
+
+        match count {
+            CubeCount::Static(x, y, z) => {
+                compute.dispatch_workgroups(x, y, z);
+            }
+            CubeCount::Dynamic(_) => {
+                let resource = dispatch_resource.as_ref().unwrap();
+                compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset());
+            }
+        }

         self.tasks_count += 1;
     }

-    fn pipeline(&mut self, kernel: Box) -> Arc {
+    fn pipeline(&mut self, kernel: ::Kernel) -> Arc {
         let kernel_id = kernel.id();

         if let Some(pipeline) = self.pipelines.get(&kernel_id) {
@@ -143,6 +159,7 @@ where
     MM: MemoryManagement,
 {
     type Kernel = Box;
+    type DispatchOptions = CubeCount;
     type Storage = WgpuStorage;
     type MemoryManagement = MM;
     type AutotuneKey = JitAutotuneKey;
@@ -259,8 +276,12 @@ where
         }))
     }

-    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>) {
-        let work_group = kernel.launch_settings().cube_count;
+    fn execute(
+        &mut self,
+        kernel: Self::Kernel,
+        count: Self::DispatchOptions,
+        bindings: Vec>,
+    ) {
         let pipeline = self.pipeline(kernel);
         let group_layout = pipeline.get_bind_group_layout(0);
@@ -284,7 +305,7 @@ where
             entries: &entries,
         });

-        self.register_compute(pipeline, bind_group, work_group);
+        self.register_compute(pipeline, bind_group, count);

         if self.tasks_count >= self.tasks_max {
             self.sync(SyncType::Flush);
diff --git a/crates/burn-wgpu/src/compute/storage.rs b/crates/burn-wgpu/src/compute/storage.rs
index 336c8089e..a2733b806 100644
--- a/crates/burn-wgpu/src/compute/storage.rs
+++ b/crates/burn-wgpu/src/compute/storage.rs
@@ -111,7 +111,8 @@ impl ComputeStorage for WgpuStorage {
             size: size as u64,
             usage: wgpu::BufferUsages::COPY_DST
                 | wgpu::BufferUsages::STORAGE
-                | wgpu::BufferUsages::COPY_SRC,
+                | wgpu::BufferUsages::COPY_SRC
+                | wgpu::BufferUsages::INDIRECT,
             mapped_at_creation: false,
         }));
diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs
index 632177695..84a982660 100644
--- a/examples/custom-wgpu-kernel/src/forward.rs
+++ b/examples/custom-wgpu-kernel/src/forward.rs
@@ -87,11 +87,13 @@ impl Backend for JitBackend {
         // Declare the wgsl workgroup with the number of cubes in x, y and z.
         let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
         let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
-        let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
+        let cube_count =
+            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

         // Execute lazily the kernel with the launch information and the given buffers.
         lhs.client.execute(
-            Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
+            Box::new(SourceKernel::new(kernel, cube_dim)),
+            cube_count,
             vec![
                 lhs.handle.binding(),
                 rhs.handle.binding(),
diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs
index 76149a090..9854fd4e3 100644
--- a/examples/gelu/src/lib.rs
+++ b/examples/gelu/src/lib.rs
@@ -22,7 +22,7 @@ pub fn launch(device: &R::Device) {

     gelu_launch::(
         client.clone(),
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&input_handle, input.len()),
         ArrayArg::new(&output_handle, input.len()),
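As a closing sketch (not part of the patch), the custom-kernel example above switches to GPU-driven dispatch by supplying the count as a buffer instead of host values; `kernel`, `cube_dim`, and the tensor handles (including the assumed `output`) are taken as set up in `forward.rs`:

    // Upload the cube count as a u32 array [x, y, z]. On wgpu the backing
    // buffer is created with BufferUsages::INDIRECT (see the storage change
    // above) and consumed by dispatch_workgroups_indirect; the CUDA server
    // currently reads the three values back on the host instead.
    let count_handle = lhs.client.create(bytemuck::cast_slice(&[
        cubes_needed_in_x,
        cubes_needed_in_y,
        num_batches as u32,
    ]));

    lhs.client.execute(
        Box::new(SourceKernel::new(kernel, cube_dim)),
        CubeCount::Dynamic(count_handle.binding()),
        vec![
            lhs.handle.binding(),
            rhs.handle.binding(),
            output.handle.binding(),
        ],
    );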