Feat: Dynamic cube count dispatch (#1975)

This commit is contained in:
Arthur Brussee 2024-07-07 00:17:01 +01:00 committed by GitHub
parent b331290f8a
commit 3f9e97946f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 293 additions and 215 deletions

View File

@ -248,11 +248,12 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
// Declare the wgsl workgroup with the number of blocks in x, y and z. // Declare the wgsl workgroup with the number of blocks in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32; let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32; let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32); let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Execute lazily the kernel with the launch information and the given buffers. // Execute lazily the kernel with the launch information and the given buffers.
lhs.client.execute( lhs.client.execute(
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)), Box::new(SourceKernel::new(kernel, cube_dim)),
cube_count,
vec![ vec![
lhs.handle.binding(), lhs.handle.binding(),
rhs.handle.binding(), rhs.handle.binding(),

View File

@ -24,7 +24,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
fn empty(&self, size: usize) -> Handle<Server>; fn empty(&self, size: usize) -> Handle<Server>;
/// Executes the `kernel` over the given `bindings`. /// Executes the `kernel` over the given `bindings`.
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>); fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
);
/// Perform some synchronization of commands on the server. /// Perform some synchronization of commands on the server.
fn sync(&self, sync_type: SyncType); fn sync(&self, sync_type: SyncType);

View File

@ -63,10 +63,15 @@ where
self.server.borrow_mut().empty(size) self.server.borrow_mut().empty(size)
} }
fn execute(&self, kernel_description: Server::Kernel, bindings: Vec<Binding<Server>>) { fn execute(
&self,
kernel_description: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.server self.server
.borrow_mut() .borrow_mut()
.execute(kernel_description, bindings) .execute(kernel_description, count, bindings)
} }
fn sync(&self, sync_type: SyncType) { fn sync(&self, sync_type: SyncType) {

View File

@ -39,7 +39,10 @@ where
), ),
Create(Vec<u8>, Callback<Handle<Server>>), Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>), Empty(usize, Callback<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>), ExecuteKernel(
(Server::Kernel, Server::DispatchOptions),
Vec<Binding<Server>>,
),
Sync(SyncType, Callback<()>), Sync(SyncType, Callback<()>),
} }
@ -74,7 +77,7 @@ where
callback.send(handle).await.unwrap(); callback.send(handle).await.unwrap();
} }
Message::ExecuteKernel(kernel, bindings) => { Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings); server.execute(kernel.0, kernel.1, bindings);
} }
Message::Sync(sync_type, callback) => { Message::Sync(sync_type, callback) => {
server.sync(sync_type); server.sync(sync_type);
@ -148,10 +151,15 @@ where
handle_response(response.recv_blocking()) handle_response(response.recv_blocking())
} }
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) { fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.state self.state
.sender .sender
.send_blocking(Message::ExecuteKernel(kernel, bindings)) .send_blocking(Message::ExecuteKernel((kernel, count), bindings))
.unwrap() .unwrap()
} }

View File

@ -56,8 +56,13 @@ where
self.server.lock().empty(size) self.server.lock().empty(size)
} }
fn execute(&self, kernel: Server::Kernel, handles: Vec<Binding<Server>>) { fn execute(
self.server.lock().execute(kernel, handles) &self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
handles: Vec<Binding<Server>>,
) {
self.server.lock().execute(kernel, count, handles)
} }
fn sync(&self, sync_type: SyncType) { fn sync(&self, sync_type: SyncType) {

View File

@ -82,8 +82,13 @@ where
} }
/// Executes the `kernel` over the given `bindings`. /// Executes the `kernel` over the given `bindings`.
pub fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) { pub fn execute(
self.channel.execute(kernel, bindings) &self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.channel.execute(kernel, count, bindings)
} }
/// Wait for the completion of every task in the server. /// Wait for the completion of every task in the server.

View File

@ -17,6 +17,8 @@ where
{ {
/// The kernel type defines the computation algorithms. /// The kernel type defines the computation algorithms.
type Kernel: Send; type Kernel: Send;
/// Options when dispatching the kernel, eg. the number of executions.
type DispatchOptions: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed. /// The [storage](ComputeStorage) type defines how data is stored and accessed.
type Storage: ComputeStorage; type Storage: ComputeStorage;
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
@ -45,7 +47,12 @@ where
/// ///
/// Kernels have mutable access to every resource they are given /// Kernels have mutable access to every resource they are given
/// and are responsible of determining which should be read or written. /// and are responsible of determining which should be read or written.
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>); fn execute(
&mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<Binding<Self>>,
);
/// Wait for the completion of every task in the server. /// Wait for the completion of every task in the server.
fn sync(&mut self, command: SyncType); fn sync(&mut self, command: SyncType);

View File

@ -21,6 +21,7 @@ impl<MM> ComputeServer for DummyServer<MM>
where where
MM: MemoryManagement<BytesStorage>, MM: MemoryManagement<BytesStorage>,
{ {
type DispatchOptions = ();
type Kernel = Arc<dyn DummyKernel>; type Kernel = Arc<dyn DummyKernel>;
type Storage = BytesStorage; type Storage = BytesStorage;
type MemoryManagement = MM; type MemoryManagement = MM;
@ -53,7 +54,12 @@ where
Handle::new(self.memory_management.reserve(size, || {})) Handle::new(self.memory_management.reserve(size, || {}))
} }
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) { fn execute(
&mut self,
kernel: Self::Kernel,
_count: Self::DispatchOptions,
bindings: Vec<Binding<Self>>,
) {
let mut resources = bindings let mut resources = bindings
.into_iter() .into_iter()
.map(|binding| self.memory_management.get(binding.memory)) .map(|binding| self.memory_management.get(binding.memory))

View File

@ -18,7 +18,7 @@ pub struct OneKernelAutotuneOperation {
impl AutotuneOperation for OneKernelAutotuneOperation { impl AutotuneOperation for OneKernelAutotuneOperation {
/// Executes the operation on given bindings and server, with the additional parameters /// Executes the operation on given bindings and server, with the additional parameters
fn execute(self: Box<Self>) { fn execute(self: Box<Self>) {
self.client.execute(self.kernel.clone(), self.bindings); self.client.execute(self.kernel.clone(), (), self.bindings);
} }
fn clone(&self) -> Box<dyn AutotuneOperation> { fn clone(&self) -> Box<dyn AutotuneOperation> {

View File

@ -38,6 +38,7 @@ fn execute_elementwise_addition() {
client.execute( client.execute(
Arc::new(DummyElementwiseAddition), Arc::new(DummyElementwiseAddition),
(),
vec![lhs.binding(), rhs.binding(), out.clone().binding()], vec![lhs.binding(), rhs.binding(), out.clone().binding()],
); );

View File

@ -470,7 +470,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
/// Launch /// Launch
pub fn #ident #generics ( pub fn #ident #generics (
client: ComputeClient<R::Server, R::Channel>, client: ComputeClient<R::Server, R::Channel>,
cube_count: CubeCount, cube_count: CubeCount<R::Server>,
cube_dim: CubeDim, cube_dim: CubeDim,
#inputs #inputs
) -> #output { ) -> #output {

View File

@ -4,13 +4,13 @@ use crate::ir::Elem;
use crate::pod::CubeElement; use crate::pod::CubeElement;
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX}; use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
use burn_compute::client::ComputeClient; use burn_compute::client::ComputeClient;
use burn_compute::server::{Binding, Handle}; use burn_compute::server::{Binding, ComputeServer, Handle};
/// The position of the input or output to calculate the number of workgroups to launch. /// The position of the input or output to calculate the number of cubes to launch.
pub enum CubeCountSettings { pub enum CubeCountSettings<S: ComputeServer> {
Input { pos: usize }, Input { pos: usize },
Output { pos: usize }, Output { pos: usize },
Custom(CubeCount), Custom(CubeCount<S>),
} }
pub struct Execution<'h, K, R: Runtime, Scalars> { pub struct Execution<'h, K, R: Runtime, Scalars> {
@ -73,7 +73,7 @@ where
} }
/// Execute a dynamic kernel. /// Execute a dynamic kernel.
#[allow(unused)] #[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) { pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, f32, f32, f32>( execute_dynamic::<R, K, f32, f32, f32>(
self.inputs, self.inputs,
self.outputs, self.outputs,
@ -108,7 +108,7 @@ where
/// Execute a dynamic kernel. /// Execute a dynamic kernel.
#[allow(unused)] #[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) { pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, E, f32, f32>( execute_dynamic::<R, K, E, f32, f32>(
self.inputs, self.inputs,
self.outputs, self.outputs,
@ -144,7 +144,7 @@ where
} }
/// Execute a dynamic kernel. /// Execute a dynamic kernel.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn execute(self, launch: CubeCountSettings) pub fn execute(self, launch: CubeCountSettings<R::Server>)
where where
K: Kernel + 'static, K: Kernel + 'static,
R: Runtime, R: Runtime,
@ -172,7 +172,7 @@ where
{ {
/// Execute a dynamic kernel. /// Execute a dynamic kernel.
#[allow(unused)] #[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) { pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, E1, E2, E3>( execute_dynamic::<R, K, E1, E2, E3>(
self.inputs, self.inputs,
self.outputs, self.outputs,
@ -194,7 +194,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
scalars_2: Option<&[E2]>, scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>, scalars_3: Option<&[E3]>,
kernel: K, kernel: K,
launch: CubeCountSettings, launch: CubeCountSettings<R::Server>,
client: ComputeClient<R::Server, R::Channel>, client: ComputeClient<R::Server, R::Channel>,
) where ) where
K: Kernel + 'static, K: Kernel + 'static,
@ -207,23 +207,21 @@ fn execute_dynamic<R, K, E1, E2, E3>(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client, inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
); );
let mut handles = settings.handles_tensors; let mut handles = settings.handles_tensors;
let workgroup = settings.cube_count;
handles.push(settings.handle_info.binding()); handles.push(settings.handle_info.binding());
for handle in settings.handles_scalars.into_iter() { for handle in settings.handles_scalars.into_iter() {
handles.push(handle.binding()); handles.push(handle.binding());
} }
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, workgroup)); let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, settings.cube_count, handles);
client.execute(kernel, handles);
} }
struct ExecuteSettings<R: Runtime> { struct ExecuteSettings<R: Runtime> {
handles_tensors: Vec<Binding<R::Server>>, handles_tensors: Vec<Binding<R::Server>>,
handle_info: Handle<R::Server>, handle_info: Handle<R::Server>,
handles_scalars: Vec<Handle<R::Server>>, handles_scalars: Vec<Handle<R::Server>>,
cube_count: CubeCount, cube_count: CubeCount<R::Server>,
} }
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>( fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
@ -232,7 +230,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
scalars_1: Option<&[E1]>, scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>, scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>, scalars_3: Option<&[E3]>,
launch: CubeCountSettings, launch: CubeCountSettings<R::Server>,
client: &ComputeClient<R::Server, R::Channel>, client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<R> { ) -> ExecuteSettings<R> {
let mut info = Vec::new(); let mut info = Vec::new();
@ -295,8 +293,8 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
let handles_scalars = let handles_scalars =
create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client); create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
let workgroup = match launch { let cube_count = match launch {
CubeCountSettings::Custom(workgroup) => workgroup, CubeCountSettings::Custom(count) => count,
_ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX), _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
}; };
@ -304,7 +302,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
handles_tensors: handles, handles_tensors: handles,
handle_info: info, handle_info: info,
handles_scalars, handles_scalars,
cube_count: workgroup, cube_count,
} }
} }

View File

@ -74,9 +74,9 @@ impl core::fmt::Display for KernelSettings {
// * Vectorization Global: vg{factor} // * Vectorization Global: vg{factor}
// * Vectorization Partial Input: v{factor}i{pos} // * Vectorization Partial Input: v{factor}i{pos}
// * Vectorization Partial Output: vo // * Vectorization Partial Output: vo
// * Workgroup Size X: x // * Cube Dim X: x
// * Workgroup Size Y: y // * Cube Dim Y: y
// * Workgroup Size Z: z // * Cube Dim Z: z
f.write_str("m")?; f.write_str("m")?;
for mapping in self.mappings.iter() { for mapping in self.mappings.iter() {
f.write_fmt(format_args!( f.write_fmt(format_args!(

View File

@ -1,41 +1,32 @@
use std::marker::PhantomData;
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel}; use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use alloc::sync::Arc; use alloc::sync::Arc;
use std::marker::PhantomData; use burn_compute::server::{Binding, ComputeServer};
/// A kernel, compiled in the target language /// A kernel, compiled in the target language
pub struct CompiledKernel { pub struct CompiledKernel {
/// Source code of the kernel /// Source code of the kernel
pub source: String, pub source: String,
/// Size of a workgroup for the compiled kernel /// Size of a cube for the compiled kernel
pub cube_dim: CubeDim, pub cube_dim: CubeDim,
/// The number of bytes used by the share memory /// The number of bytes used by the share memory
pub shared_mem_bytes: usize, pub shared_mem_bytes: usize,
} }
/// Information needed to launch the kernel
pub struct LaunchSettings {
/// Layout of workgroups for the kernel
pub cube_count: CubeCount,
}
/// Kernel trait with the ComputeShader that will be compiled and cached based on the /// Kernel trait with the ComputeShader that will be compiled and cached based on the
/// provided id. /// provided id.
///
/// The kernel will be launched with the given [launch settings](LaunchSettings).
pub trait CubeTask: Send + Sync { pub trait CubeTask: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation. /// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String; fn id(&self) -> String;
/// Compile the kernel into source /// Compile the kernel into source
fn compile(&self) -> CompiledKernel; fn compile(&self) -> CompiledKernel;
/// Launch settings.
fn launch_settings(&self) -> LaunchSettings;
} }
/// Wraps a [kernel](Kernel) with its [cube count](CubeCount) to create a [cube task](CubeTask). /// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
#[derive(new)] #[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> { pub struct KernelTask<C: Compiler, K: Kernel> {
kernel_definition: K, kernel_definition: K,
cube_count: CubeCount,
_compiler: PhantomData<C>, _compiler: PhantomData<C>,
} }
@ -57,12 +48,6 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
fn id(&self) -> String { fn id(&self) -> String {
self.kernel_definition.id().clone() self.kernel_definition.id().clone()
} }
fn launch_settings(&self) -> LaunchSettings {
LaunchSettings {
cube_count: self.cube_count.clone(),
}
}
} }
impl CubeTask for Arc<dyn CubeTask> { impl CubeTask for Arc<dyn CubeTask> {
@ -73,10 +58,6 @@ impl CubeTask for Arc<dyn CubeTask> {
fn id(&self) -> String { fn id(&self) -> String {
self.as_ref().id() self.as_ref().id()
} }
fn launch_settings(&self) -> LaunchSettings {
self.as_ref().launch_settings()
}
} }
impl CubeTask for Box<dyn CubeTask> { impl CubeTask for Box<dyn CubeTask> {
@ -87,26 +68,21 @@ impl CubeTask for Box<dyn CubeTask> {
fn id(&self) -> String { fn id(&self) -> String {
self.as_ref().id() self.as_ref().id()
} }
fn launch_settings(&self) -> LaunchSettings {
self.as_ref().launch_settings()
}
} }
/// Provides launch information specifying the number of work groups to be used by a compute shader. /// Provides launch information specifying the number of work groups to be used by a compute shader.
#[derive(new, Clone, Debug)] pub enum CubeCount<S: ComputeServer> {
pub struct CubeCount { /// Dispatch x,y,z work groups.
/// Work groups for the x axis. Static(u32, u32, u32),
pub x: u32, /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
/// Work groups for the y axis. Dynamic(Binding<S>),
pub y: u32,
/// Work groups for the z axis.
pub z: u32,
} }
impl CubeCount { impl<S: ComputeServer> Clone for CubeCount<S> {
/// Calculate the number of invocations of a compute shader. fn clone(&self) -> Self {
pub fn num_invocations(&self) -> usize { match self {
(self.x * self.y * self.z) as usize Self::Static(x, y, z) => Self::Static(*x, *y, *z),
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
}
} }
} }

View File

@ -78,15 +78,15 @@ impl<R: Runtime> KernelLauncher<R> {
/// Launch the kernel. /// Launch the kernel.
pub fn launch<K: Kernel>( pub fn launch<K: Kernel>(
self, self,
cube_count: CubeCount, cube_count: CubeCount<R::Server>,
kernel: K, kernel: K,
client: ComputeClient<R::Server, R::Channel>, client: ComputeClient<R::Server, R::Channel>,
) { ) {
let bindings = self.into_bindings(&client); let bindings = self.into_bindings(&client);
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, cube_count)); let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, bindings); client.execute(kernel, cube_count, bindings);
} }
/// We need to create the bindings in the same order they are defined in the compilation step. /// We need to create the bindings in the same order they are defined in the compilation step.

View File

@ -6,6 +6,7 @@ extern crate derive_new;
/// Cube Frontend Types. /// Cube Frontend Types.
pub mod frontend; pub mod frontend;
use burn_compute::server::ComputeServer;
pub use frontend::cmma; pub use frontend::cmma;
/// Cube Language Internal Representation. /// Cube Language Internal Representation.
@ -45,13 +46,16 @@ pub trait Kernel: Send + Sync + 'static {
/// Calculate the number of cubes required to execute an operation where one cube unit is /// Calculate the number of cubes required to execute an operation where one cube unit is
/// assigned to one element. /// assigned to one element.
pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: usize) -> CubeCount { pub fn calculate_cube_count_elemwise<S: ComputeServer>(
num_elems: usize,
cube_dim: usize,
) -> CubeCount<S> {
let num_elems_per_cube = cube_dim * cube_dim; let num_elems_per_cube = cube_dim * cube_dim;
let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32); let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32);
let cube_count_x = f32::ceil(f32::sqrt(cube_counts)); let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32)); let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
CubeCount::new(cube_count_x as u32, cube_count_y as u32, 1) CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
} }
pub fn tensor_vectorization_factor( pub fn tensor_vectorization_factor(

View File

@ -1,4 +1,8 @@
use crate::{codegen::Compiler, compute::CubeTask, ir::Elem}; use crate::{
codegen::Compiler,
compute::{CubeCount, CubeTask},
ir::Elem,
};
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
/// Runtime for the CubeCL. /// Runtime for the CubeCL.
@ -6,7 +10,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
/// The compiler used to compile the inner representation into tokens. /// The compiler used to compile the inner representation into tokens.
type Compiler: Compiler; type Compiler: Compiler;
/// The compute server used to run kernels and perform autotuning. /// The compute server used to run kernels and perform autotuning.
type Server: ComputeServer<Kernel = Box<dyn CubeTask>, FeatureSet = FeatureSet>; type Server: ComputeServer<
Kernel = Box<dyn CubeTask>,
DispatchOptions = CubeCount<Self::Server>,
FeatureSet = FeatureSet,
>;
/// The channel used to communicate with the compute server. /// The channel used to communicate with the compute server.
type Channel: ComputeChannel<Self::Server>; type Channel: ComputeChannel<Self::Server>;
/// The device used to retrieve the compute client. /// The device used to retrieve the compute client.

View File

@ -61,7 +61,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
kernel_simple_1_launch::<R>( kernel_simple_1_launch::<R>(
client.clone(), client.clone(),
CubeCount::new(1, 1, 1), CubeCount::Static(1, 1, 1),
CubeDim::new(16, 16, 1), CubeDim::new(16, 16, 1),
ArrayArg::new(&lhs, 256), ArrayArg::new(&lhs, 256),
ArrayArg::new(&rhs, 256), ArrayArg::new(&rhs, 256),

View File

@ -20,7 +20,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
kernel_with_generics_launch::<F32, R>( kernel_with_generics_launch::<F32, R>(
client.clone(), client.clone(),
CubeCount::new(1, 1, 1), CubeCount::Static(1, 1, 1),
CubeDim::default(), CubeDim::default(),
ArrayArg::new(&handle, 2), ArrayArg::new(&handle, 2),
); );
@ -36,7 +36,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
kernel_without_generics_launch::<R>( kernel_without_generics_launch::<R>(
client.clone(), client.clone(),
CubeCount::new(1, 1, 1), CubeCount::Static(1, 1, 1),
CubeDim::default(), CubeDim::default(),
ArrayArg::new(&handle, 2), ArrayArg::new(&handle, 2),
); );

View File

@ -98,7 +98,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>, client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
launch: Launch, launch: Launch,
) where ) where
Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>), Launch: Fn(CubeCount<TestRuntime::Server>, CubeDim, TensorArg<'_, TestRuntime>),
{ {
if !client.features().enabled(Feature::Subcube) { if !client.features().enabled(Feature::Subcube) {
// Can't execute the test. // Can't execute the test.
@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
let (shape, strides) = ([input.len()], [1]); let (shape, strides) = ([input.len()], [1]);
launch( launch(
CubeCount::new(1, 1, 1), CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32, 1, 1), CubeDim::new(input.len() as u32, 1, 1),
TensorArg::new(&handle, &strides, &shape), TensorArg::new(&handle, &strides, &shape),
); );

View File

@ -57,14 +57,8 @@ struct CompiledKernel {
unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {} unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> { impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
type Kernel = Box<dyn CubeTask>; fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
type Storage = CudaStorage;
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
type FeatureSet = FeatureSet;
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
let ctx = self.get_context(); let ctx = self.get_context();
let resource = ctx.memory_management.get(binding.memory); let resource = ctx.memory_management.get(binding.memory);
@ -74,7 +68,20 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap(); cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
}; };
ctx.sync(); ctx.sync();
reader_from_concrete(data) data
}
}
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
type Kernel = Box<dyn CubeTask>;
type DispatchOptions = CubeCount<Self>;
type Storage = CudaStorage;
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
type FeatureSet = FeatureSet;
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
reader_from_concrete(self.read_sync(binding))
} }
fn create(&mut self, data: &[u8]) -> server::Handle<Self> { fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
@ -101,12 +108,33 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
server::Handle::new(handle) server::Handle::new(handle)
} }
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) { fn execute(
&mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<server::Binding<Self>>,
) {
let arch = self.minimum_arch_version; let arch = self.minimum_arch_version;
let ctx = self.get_context();
let kernel_id = kernel.id(); let kernel_id = kernel.id();
let settings = kernel.launch_settings();
let count = match count {
CubeCount::Static(x, y, z) => (x, y, z),
// TODO: CUDA doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
// One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
// For now, just read the dispatch settings from the buffer.
CubeCount::Dynamic(binding) => {
let data = self.read_sync(binding);
let data = bytemuck::cast_slice(&data);
assert!(
data.len() == 3,
"Dynamic cube count should contain 3 values"
);
(data[0], data[1], data[2])
}
};
let ctx = self.get_context();
if !ctx.module_names.contains_key(&kernel_id) { if !ctx.module_names.contains_key(&kernel_id) {
ctx.compile_kernel(&kernel_id, kernel, arch); ctx.compile_kernel(&kernel_id, kernel, arch);
@ -117,7 +145,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
.map(|binding| ctx.memory_management.get(binding.memory).as_binding()) .map(|binding| ctx.memory_management.get(binding.memory).as_binding())
.collect(); .collect();
ctx.execute_task(kernel_id, settings.cube_count, bindings); ctx.execute_task(kernel_id, count, bindings);
// TODO: fix this // TODO: fix this
// self.memory_management.storage().perform_deallocations(); // self.memory_management.storage().perform_deallocations();
} }
@ -217,16 +245,15 @@ impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
fn execute_task( fn execute_task(
&mut self, &mut self,
kernel_id: String, kernel_id: String,
cube_count: CubeCount, dispatch_count: (u32, u32, u32),
mut bindings: Vec<Binding>, mut bindings: Vec<Binding>,
) { ) {
let kernel = self.module_names.get(&kernel_id).unwrap(); let kernel = self.module_names.get(&kernel_id).unwrap();
let cube_dim = kernel.cube_dim; let cube_dim = kernel.cube_dim;
unsafe { unsafe {
cudarc::driver::result::launch_kernel( cudarc::driver::result::launch_kernel(
kernel.func, kernel.func,
(cube_count.x, cube_count.y, cube_count.z), dispatch_count,
(cube_dim.x, cube_dim.y, cube_dim.z), (cube_dim.x, cube_dim.y, cube_dim.z),
kernel.shared_mem_bytes as u32, kernel.shared_mem_bytes as u32,
self.stream, self.stream,

View File

@ -29,9 +29,9 @@ pub struct CompilationPhase;
/// Phase where the kernel should be executed. /// Phase where the kernel should be executed.
#[derive(new)] #[derive(new)]
pub struct ExecutionPhase<R: JitRuntime> { pub struct ExecutionPhase<R: JitRuntime> {
/// Kernel set with default workgroup size. /// Kernel set with default cube size.
pub(super) kernel_factory_1: ElementWiseKernelFactory<R>, pub(super) kernel_factory_1: ElementWiseKernelFactory<R>,
/// Kernel set with custom workgroup size. /// Kernel set with custom cube size.
pub(super) kernel_factory_2: ElementWiseKernelFactory<R>, pub(super) kernel_factory_2: ElementWiseKernelFactory<R>,
} }

View File

@ -22,7 +22,7 @@ pub struct FusionKernel<R: JitRuntime> {
info: Arc<KernelExpansion>, info: Arc<KernelExpansion>,
settings: KernelSettings, settings: KernelSettings,
runtime_info: Vec<OutputRuntimeInfo>, runtime_info: Vec<OutputRuntimeInfo>,
cube_count: CubeCount, cube_count: CubeCount<R::Server>,
_runtime: PhantomData<R>, _runtime: PhantomData<R>,
} }
@ -41,6 +41,7 @@ pub trait FusionKernelFactory<R: JitRuntime> {
#[derive(new)] #[derive(new)]
pub struct ExecutableKernel<R: JitRuntime> { pub struct ExecutableKernel<R: JitRuntime> {
kernel: Box<dyn CubeTask>, kernel: Box<dyn CubeTask>,
cube_count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>, bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>, client: ComputeClient<R::Server, R::Channel>,
} }
@ -54,6 +55,7 @@ pub struct ExecutableKernel<R: JitRuntime> {
#[derive(new)] #[derive(new)]
pub struct AutotunableKernel<R: JitRuntime> { pub struct AutotunableKernel<R: JitRuntime> {
kernel: Arc<dyn CubeTask>, kernel: Arc<dyn CubeTask>,
count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>, bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>, client: ComputeClient<R::Server, R::Channel>,
} }
@ -68,18 +70,21 @@ pub enum OutputRuntimeInfo {
impl<R: JitRuntime> ExecutableKernel<R> { impl<R: JitRuntime> ExecutableKernel<R> {
/// Execute the kernel. /// Execute the kernel.
pub fn execute(self) { pub fn execute(self) {
self.client.execute(self.kernel, self.bindings) self.client
.execute(self.kernel, self.cube_count, self.bindings)
} }
} }
impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> { impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
fn execute(self: Box<Self>) { fn execute(self: Box<Self>) {
self.client.execute(Box::new(self.kernel), self.bindings) self.client
.execute(Box::new(self.kernel), self.count, self.bindings)
} }
fn clone(&self) -> Box<dyn AutotuneOperation> { fn clone(&self) -> Box<dyn AutotuneOperation> {
Box::new(Self { Box::new(Self {
kernel: self.kernel.clone(), kernel: self.kernel.clone(),
count: self.count.clone(),
bindings: self.bindings.clone(), bindings: self.bindings.clone(),
client: self.client.clone(), client: self.client.clone(),
}) })
@ -90,6 +95,7 @@ impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
fn from(value: ExecutableKernel<R>) -> Self { fn from(value: ExecutableKernel<R>) -> Self {
Self { Self {
kernel: Arc::new(value.kernel), kernel: Arc::new(value.kernel),
count: value.cube_count.clone(),
bindings: value.bindings, bindings: value.bindings,
client: value.client, client: value.client,
} }
@ -233,12 +239,10 @@ impl<R: JitRuntime> FusionKernel<R> {
context.handles.register_handle(id, handle); context.handles.register_handle(id, handle);
} }
let workgroup = fusion_kernel.cube_count.clone(); let cube_count = fusion_kernel.cube_count.clone();
ExecutableKernel::new( ExecutableKernel::new(
Box::new(KernelTask::<R::Compiler, FusionKernel<R>>::new( Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
fusion_kernel, cube_count,
workgroup,
)),
bindings, bindings,
client, client,
) )

View File

@ -196,7 +196,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
let kernel = ScatterEagerKernel::<R, E>::new(dim); let kernel = ScatterEagerKernel::<R, E>::new(dim);
let mut strides = [0; D]; let mut strides = [0; D];
let mut current = 1; let mut current = 1;
let mut num_elems_per_workgroup = 1; let mut num_elems = 1;
tensor tensor
.shape .shape
@ -208,13 +208,13 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
.for_each(|(index, val)| { .for_each(|(index, val)| {
strides[index] = current; strides[index] = current;
current *= val; current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index]; num_elems *= tensor.shape.dims[index];
}); });
// Fake strides of the virtual output where the strides of dim is hardcoded to one. // Fake strides of the virtual output where the strides of dim is hardcoded to one.
indices.strides = strides; indices.strides = strides;
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX); let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
Execution::start(kernel, indices.client) Execution::start(kernel, indices.client)
.inputs(&[ .inputs(&[
@ -222,7 +222,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims), TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
TensorHandle::new(&value.handle, &value.strides, &value.shape.dims), TensorHandle::new(&value.handle, &value.strides, &value.shape.dims),
]) ])
.execute(CubeCountSettings::Custom(workgroup)); .execute(CubeCountSettings::Custom(cube_count));
tensor tensor
} }

View File

@ -189,7 +189,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
let mut strides = [0; D]; let mut strides = [0; D];
let mut current = 1; let mut current = 1;
let mut num_elems_per_workgroup = 1; let mut num_elems = 1;
tensor tensor
.shape .shape
@ -201,11 +201,11 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
.for_each(|(index, val)| { .for_each(|(index, val)| {
strides[index] = current; strides[index] = current;
current *= val; current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index]; num_elems *= tensor.shape.dims[index];
}); });
let kernel = SelectAssignEagerKernel::<R, E>::new(dim); let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX); let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
Execution::start(kernel, indices.client) Execution::start(kernel, indices.client)
.inputs(&[ .inputs(&[
@ -215,7 +215,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
// kernel, but we need to put the right number of dimensions (rank). // kernel, but we need to put the right number of dimensions (rank).
TensorHandle::new(&indices.handle, &strides, &strides), TensorHandle::new(&indices.handle, &strides, &strides),
]) ])
.execute(CubeCountSettings::Custom(workgroup)); .execute(CubeCountSettings::Custom(cube_count));
tensor tensor
} }

View File

@ -140,41 +140,39 @@ pub fn matmul<R: JitRuntime, E: FloatElement, const D: usize>(
} }
} }
pub(crate) fn simple_launch_options<const D: usize>( pub(crate) fn simple_cube_count<R: JitRuntime, const D: usize>(
lhs_shape: &Shape<D>, lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>, rhs_shape: &Shape<D>,
output_shape: &Shape<D>, output_shape: &Shape<D>,
workgroup_size_x: usize, cube_dim_x: usize,
workgroup_size_y: usize, cube_dim_y: usize,
) -> CubeCount { ) -> CubeCount<R::Server> {
let num_rows = lhs_shape.dims[D - 2]; let num_rows = lhs_shape.dims[D - 2];
let num_cols = rhs_shape.dims[D - 1]; let num_cols = rhs_shape.dims[D - 1];
// set number of workgroups let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32; let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let mut num_iter = 1; let mut num_iter = 1;
for i in 0..D - 2 { for i in 0..D - 2 {
num_iter *= output_shape.dims[i]; num_iter *= output_shape.dims[i];
} }
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
} }
pub(crate) fn tiling2d_launch_options<const D: usize>( pub(crate) fn tiling2d_launch_options<R: JitRuntime, const D: usize>(
output_shape: &Shape<D>, output_shape: &Shape<D>,
config: Tiling2dConfig, config: Tiling2dConfig,
) -> CubeCount { ) -> CubeCount<R::Server> {
let num_rows = output_shape.dims[D - 2]; let num_rows = output_shape.dims[D - 2];
let num_cols = output_shape.dims[D - 1]; let num_cols = output_shape.dims[D - 1];
// set number of workgroups let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32; let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let mut num_iter = 1; let mut num_iter = 1;
for i in 0..D - 2 { for i in 0..D - 2 {
num_iter *= output_shape.dims[i]; num_iter *= output_shape.dims[i];
} }
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32) CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
} }

View File

@ -7,7 +7,7 @@ use crate::{
use burn_cube::ir::KernelDefinition; use burn_cube::ir::KernelDefinition;
use burn_cube::{frontend::TensorArg, KernelSettings}; use burn_cube::{frontend::TensorArg, KernelSettings};
use super::simple_launch_options; use super::simple_cube_count;
use burn_cube::prelude::*; use burn_cube::prelude::*;
#[cube(launch)] #[cube(launch)]
@ -80,7 +80,7 @@ fn matmul_kernel<F: Float>(
} }
} }
/// Matrix multiplication using memory coalescing algorithm with workgroups of size 16 /// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16
pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>( pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>, lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>, rhs: JitTensor<R, E, D>,
@ -89,7 +89,7 @@ pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: us
matmul_simple::<R, E, D>(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX) matmul_simple::<R, E, D>(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX)
} }
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes /// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>( pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>, lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>, rhs: JitTensor<R, E, D>,
@ -103,7 +103,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
let rhs_original_shape = rhs.shape.clone(); let rhs_original_shape = rhs.shape.clone();
let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2)); let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2));
let cube_count = simple_launch_options( let cube_count = simple_cube_count::<R, D>(
&lhs.shape, &lhs.shape,
&rhs_original_shape, &rhs_original_shape,
&out.shape, &out.shape,

View File

@ -118,7 +118,7 @@ pub fn matmul_tiling_2d<R: JitRuntime, E: JitElement + Element, const D: usize>(
&out.strides, &out.strides,
&out.shape.dims, &out.shape.dims,
)]) )])
.execute(CubeCountSettings::Custom(tiling2d_launch_options( .execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
&out.shape, config, &out.shape, config,
))); )));
@ -175,7 +175,7 @@ pub fn matmul_tiling_2d_padded<R: JitRuntime, E: JitElement + Element, const D:
&rounded_output.strides, &rounded_output.strides,
&rounded_output.shape.dims, &rounded_output.shape.dims,
)]) )])
.execute(CubeCountSettings::Custom(tiling2d_launch_options( .execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
&rounded_output.shape, &rounded_output.shape,
config, config,
))); )));

View File

@ -55,14 +55,14 @@ pub(crate) fn gather_shader_information(
cpa!(scope, out_stride_row = stride(out, second_to_last_dim)); cpa!(scope, out_stride_row = stride(out, second_to_last_dim));
cpa!(scope, out_stride_col = stride(out, last_dim)); cpa!(scope, out_stride_col = stride(out, last_dim));
// Workgroup offset // Cube offset
let skip_row = scope.create_local(Elem::UInt); let skip_row = scope.create_local(Elem::UInt);
let skip_col = scope.create_local(Elem::UInt); let skip_col = scope.create_local(Elem::UInt);
let workgroup_id_x = Variable::CubePosX; let cube_pos_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY; let cube_pos_y = Variable::CubePosY;
cpa!(scope, skip_row = workgroup_id_x); cpa!(scope, skip_row = cube_pos_x);
cpa!(scope, skip_row *= block_size_m); cpa!(scope, skip_row *= block_size_m);
cpa!(scope, skip_col = workgroup_id_y); cpa!(scope, skip_col = cube_pos_y);
cpa!(scope, skip_col *= block_size_n); cpa!(scope, skip_col *= block_size_n);
// Position of the first element of the thread, relative to the block // Position of the first element of the thread, relative to the block

View File

@ -38,7 +38,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
)]) )])
.with_scalars(&seeds) .with_scalars(&seeds)
.with_scalars(&prng.args()) .with_scalars(&prng.args())
.execute(CubeCountSettings::Custom(prng_cube_count( .execute(CubeCountSettings::Custom(prng_cube_count::<R>(
num_elems, num_elems,
SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX,
N_VALUES_PER_THREAD, N_VALUES_PER_THREAD,
@ -47,14 +47,18 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
output output
} }
fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount { fn prng_cube_count<R: JitRuntime>(
num_elems: usize,
cube_dim: usize,
n_values_per_thread: usize,
) -> CubeCount<R::Server> {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32); let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elems_per_cube = cube_dim * cube_dim; let num_elems_per_cube = cube_dim * cube_dim;
let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32); let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations)); let cubes_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x); let cubes_y = f32::ceil(num_invocations / cubes_x);
CubeCount::new(workgroup_x as u32, workgroup_y as u32, 1) CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
} }
impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> { impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> {
@ -163,24 +167,24 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let n_values_per_thread: Variable = self.n_values_per_thread.into(); let n_values_per_thread: Variable = self.n_values_per_thread.into();
let args = self.args; let args = self.args;
let workgroup_size_x = Variable::CubeDimX; let cube_dim_x = Variable::CubeDimX;
let workgroup_size_y = Variable::CubeDimY; let cube_dim_y = Variable::CubeDimY;
let workgroup_id_x = Variable::CubePosX; let cube_pos_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY; let cube_pos_y = Variable::CubePosY;
let num_workgroups_y = Variable::CubeCountY; let cube_count_y = Variable::CubeCountY;
let local_index = Variable::UnitPos; let local_index = Variable::UnitPos;
let n_invocations = scope.create_local(Elem::UInt); let n_invocations = scope.create_local(Elem::UInt);
cpa!(scope, n_invocations = workgroup_size_x); cpa!(scope, n_invocations = cube_dim_x);
cpa!(scope, n_invocations *= workgroup_size_y); cpa!(scope, n_invocations *= cube_dim_y);
let workgroup_offset = scope.create_local(Elem::UInt); let cube_offset = scope.create_local(Elem::UInt);
cpa!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y); cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
cpa!(scope, workgroup_offset += workgroup_id_y); cpa!(scope, cube_offset += cube_pos_y);
cpa!(scope, workgroup_offset *= n_invocations); cpa!(scope, cube_offset *= n_invocations);
let write_index_base = scope.create_local(Elem::UInt); let write_index_base = scope.create_local(Elem::UInt);
cpa!(scope, write_index_base = workgroup_offset); cpa!(scope, write_index_base = cube_offset);
cpa!(scope, write_index_base *= n_values_per_thread); cpa!(scope, write_index_base *= n_values_per_thread);
cpa!(scope, write_index_base += local_index); cpa!(scope, write_index_base += local_index);
@ -188,7 +192,7 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let thread_seed = scope.create_local(Elem::UInt); let thread_seed = scope.create_local(Elem::UInt);
cpa!(scope, thread_seed = cast(1000000007)); cpa!(scope, thread_seed = cast(1000000007));
let thread_seed_index = scope.create_local(Elem::UInt); let thread_seed_index = scope.create_local(Elem::UInt);
cpa!(scope, thread_seed_index = workgroup_offset + local_index); cpa!(scope, thread_seed_index = cube_offset + local_index);
cpa!(scope, thread_seed *= thread_seed_index); cpa!(scope, thread_seed *= thread_seed_index);
let state_0 = scope.create_local(Elem::UInt); let state_0 = scope.create_local(Elem::UInt);

View File

@ -33,8 +33,8 @@ pub(crate) struct SharedReduceDimEagerKernel<
EO: JitElement, EO: JitElement,
> { > {
dim: usize, dim: usize,
workgroup_size_x: usize, cube_dim_x: usize,
workgroup_size_y: usize, cube_dim_y: usize,
n_input_values_per_thread: u32, n_input_values_per_thread: u32,
divisible_shape: bool, divisible_shape: bool,
_reduce_dim: PhantomData<RD>, _reduce_dim: PhantomData<RD>,
@ -58,7 +58,7 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
SharedReduceDimComputeShader { SharedReduceDimComputeShader {
tensor, tensor,
dim: self.dim, dim: self.dim,
shared_memory_size: self.workgroup_size_x * self.workgroup_size_y, shared_memory_size: self.cube_dim_x * self.cube_dim_y,
n_input_values_per_thread: self.n_input_values_per_thread, n_input_values_per_thread: self.n_input_values_per_thread,
output, output,
divisible_shape: self.divisible_shape, divisible_shape: self.divisible_shape,
@ -83,8 +83,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
}; };
let settings = KernelSettings::default().cube_dim(CubeDim::new( let settings = KernelSettings::default().cube_dim(CubeDim::new(
self.workgroup_size_x as u32, self.cube_dim_x as u32,
self.workgroup_size_y as u32, self.cube_dim_y as u32,
1, 1,
)); ));
KernelIntegrator::new(info).integrate(settings) KernelIntegrator::new(info).integrate(settings)
@ -95,8 +95,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
"{:?}dim={}x={}y={}n={}divshape={}", "{:?}dim={}x={}y={}n={}divshape={}",
core::any::TypeId::of::<Self>(), core::any::TypeId::of::<Self>(),
self.dim, self.dim,
self.workgroup_size_x, self.cube_dim_x,
self.workgroup_size_y, self.cube_dim_y,
self.n_input_values_per_thread, self.n_input_values_per_thread,
self.divisible_shape self.divisible_shape
) )
@ -111,13 +111,13 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
let rank = Variable::Rank; let rank = Variable::Rank;
let dim: Variable = self.dim.into(); let dim: Variable = self.dim.into();
let workgroup_id_x = Variable::CubePosX; let cube_pos_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY; let cube_pos_y = Variable::CubePosY;
let num_workgroups_x = Variable::CubeCountX; let cube_count_x = Variable::CubeCountX;
let local_invocation_id_x = Variable::UnitPosX; let local_invocation_id_x = Variable::UnitPosX;
let local_invocation_id_y = Variable::UnitPosY; let local_invocation_id_y = Variable::UnitPosY;
let workgroup_size_x = Variable::CubeDimX; let cube_dim_x = Variable::CubeDimX;
let workgroup_size_y = Variable::CubeDimY; let cube_dim_y = Variable::CubeDimY;
let stride_reduce_dim_input = scope.create_local(Elem::UInt); let stride_reduce_dim_input = scope.create_local(Elem::UInt);
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim)); cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
@ -126,16 +126,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
// To determine which reduce_group (not position, but absolute id) // To determine which reduce_group (not position, but absolute id)
let reduce_group_id = scope.create_local(Elem::UInt); let reduce_group_id = scope.create_local(Elem::UInt);
cpa!(scope, reduce_group_id = workgroup_id_y * num_workgroups_x); cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
cpa!(scope, reduce_group_id += workgroup_id_x); cpa!(scope, reduce_group_id += cube_pos_x);
// nth thread in the workgroup // nth thread in the cube
let local_id = scope.create_local(Elem::UInt); let local_id = scope.create_local(Elem::UInt);
cpa!(scope, local_id = local_invocation_id_y * workgroup_size_x); cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
cpa!(scope, local_id += local_invocation_id_x); cpa!(scope, local_id += local_invocation_id_x);
let n_threads = scope.create_local(Elem::UInt); let n_threads = scope.create_local(Elem::UInt);
cpa!(scope, n_threads = workgroup_size_x * workgroup_size_y); cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
let index_offset = scope.zero(Elem::UInt); let index_offset = scope.zero(Elem::UInt);
@ -242,17 +242,17 @@ pub fn reduce_dim_shared<
dim: usize, dim: usize,
) -> JitTensor<R, EO, D> { ) -> JitTensor<R, EO, D> {
let num_elems_output = output.shape.num_elements(); let num_elems_output = output.shape.num_elements();
let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32));
let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x);
let grid = CubeCount::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); let grid = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1);
let reduce_group_size = input.shape.dims[dim]; let reduce_group_size = input.shape.dims[dim];
let n_invocation_per_workgroup = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX; let n_invocation_per_cube = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
let n_input_values_per_thread = let n_input_values_per_thread =
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
let divisible_shape = let divisible_shape =
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32; n_invocation_per_cube as u32 * n_input_values_per_thread == reduce_group_size as u32;
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new( let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
dim, dim,

View File

@ -18,7 +18,10 @@ pub(crate) mod tune;
/// Elements for JIT backend /// Elements for JIT backend
pub mod element; pub mod element;
use burn_cube::{compute::CubeTask, Runtime}; use burn_cube::{
compute::{CubeCount, CubeTask},
Runtime,
};
pub use element::{FloatElement, IntElement, JitElement}; pub use element::{FloatElement, IntElement, JitElement};
mod backend; mod backend;
@ -48,5 +51,6 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
type JitServer: burn_compute::server::ComputeServer< type JitServer: burn_compute::server::ComputeServer<
AutotuneKey = JitAutotuneKey, AutotuneKey = JitAutotuneKey,
Kernel = Box<dyn CubeTask>, Kernel = Box<dyn CubeTask>,
DispatchOptions = CubeCount<Self::JitServer>,
>; >;
} }

View File

@ -1,5 +1,4 @@
use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_cube::compute::LaunchSettings;
use burn_cube::prelude::*; use burn_cube::prelude::*;
use super::SourceTemplate; use super::SourceTemplate;
@ -11,18 +10,13 @@ pub trait KernelSource: Send + 'static + Sync {
} }
#[derive(new)] #[derive(new)]
/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask) with launch /// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).
/// information.
pub struct SourceKernel<K> { pub struct SourceKernel<K> {
kernel_source: K, kernel_source: K,
cube_count: CubeCount,
cube_dim: CubeDim, cube_dim: CubeDim,
} }
impl<K> CubeTask for SourceKernel<K> impl<K: KernelSource> CubeTask for SourceKernel<K> {
where
K: KernelSource + 'static,
{
fn compile(&self) -> CompiledKernel { fn compile(&self) -> CompiledKernel {
let source_template = self.kernel_source.source(); let source_template = self.kernel_source.source();
let source = source_template.complete(); let source = source_template.complete();
@ -37,12 +31,6 @@ where
fn id(&self) -> String { fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<K>()) format!("{:?}", core::any::TypeId::of::<K>())
} }
fn launch_settings(&self) -> LaunchSettings {
LaunchSettings {
cube_count: self.cube_count.clone(),
}
}
} }
/// Generates kernel source code by replacing some information using templating. /// Generates kernel source code by replacing some information using templating.

View File

@ -64,8 +64,15 @@ where
&mut self, &mut self,
pipeline: Arc<ComputePipeline>, pipeline: Arc<ComputePipeline>,
bind_group: BindGroup, bind_group: BindGroup,
work_group: CubeCount, count: CubeCount<Self>,
) { ) {
// First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
// needs to be longer than the compute pass, so we can't do this just before dispatching.
let dispatch_resource = match count.clone() {
CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)),
_ => None,
};
let mut compute = self let mut compute = self
.encoder .encoder
.begin_compute_pass(&wgpu::ComputePassDescriptor { .begin_compute_pass(&wgpu::ComputePassDescriptor {
@ -75,12 +82,21 @@ where
compute.set_pipeline(&pipeline); compute.set_pipeline(&pipeline);
compute.set_bind_group(0, &bind_group, &[]); compute.set_bind_group(0, &bind_group, &[]);
compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
match count {
CubeCount::Static(x, y, z) => {
compute.dispatch_workgroups(x, y, z);
}
CubeCount::Dynamic(_) => {
let resource = dispatch_resource.as_ref().unwrap();
compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset());
}
}
self.tasks_count += 1; self.tasks_count += 1;
} }
fn pipeline(&mut self, kernel: Box<dyn CubeTask>) -> Arc<ComputePipeline> { fn pipeline(&mut self, kernel: <Self as ComputeServer>::Kernel) -> Arc<ComputePipeline> {
let kernel_id = kernel.id(); let kernel_id = kernel.id();
if let Some(pipeline) = self.pipelines.get(&kernel_id) { if let Some(pipeline) = self.pipelines.get(&kernel_id) {
@ -143,6 +159,7 @@ where
MM: MemoryManagement<WgpuStorage>, MM: MemoryManagement<WgpuStorage>,
{ {
type Kernel = Box<dyn CubeTask>; type Kernel = Box<dyn CubeTask>;
type DispatchOptions = CubeCount<Self>;
type Storage = WgpuStorage; type Storage = WgpuStorage;
type MemoryManagement = MM; type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey; type AutotuneKey = JitAutotuneKey;
@ -259,8 +276,12 @@ where
})) }))
} }
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) { fn execute(
let work_group = kernel.launch_settings().cube_count; &mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<server::Binding<Self>>,
) {
let pipeline = self.pipeline(kernel); let pipeline = self.pipeline(kernel);
let group_layout = pipeline.get_bind_group_layout(0); let group_layout = pipeline.get_bind_group_layout(0);
@ -284,7 +305,7 @@ where
entries: &entries, entries: &entries,
}); });
self.register_compute(pipeline, bind_group, work_group); self.register_compute(pipeline, bind_group, count);
if self.tasks_count >= self.tasks_max { if self.tasks_count >= self.tasks_max {
self.sync(SyncType::Flush); self.sync(SyncType::Flush);

View File

@ -111,7 +111,8 @@ impl ComputeStorage for WgpuStorage {
size: size as u64, size: size as u64,
usage: wgpu::BufferUsages::COPY_DST usage: wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC, | wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false, mapped_at_creation: false,
})); }));

View File

@ -87,11 +87,13 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
// Declare the wgsl workgroup with the number of cubes in x, y and z. // Declare the wgsl workgroup with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32; let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32; let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32); let cube_count =
CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Execute lazily the kernel with the launch information and the given buffers. // Execute lazily the kernel with the launch information and the given buffers.
lhs.client.execute( lhs.client.execute(
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)), Box::new(SourceKernel::new(kernel, cube_dim)),
cube_count,
vec![ vec![
lhs.handle.binding(), lhs.handle.binding(),
rhs.handle.binding(), rhs.handle.binding(),

View File

@ -22,7 +22,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
gelu_launch::<F32, R>( gelu_launch::<F32, R>(
client.clone(), client.clone(),
CubeCount::new(1, 1, 1), CubeCount::Static(1, 1, 1),
CubeDim::default(), CubeDim::default(),
ArrayArg::new(&input_handle, input.len()), ArrayArg::new(&input_handle, input.len()),
ArrayArg::new(&output_handle, input.len()), ArrayArg::new(&output_handle, input.len()),