Feat: Dynamic cube count dispatch (#1975)

Arthur Brussee 2024-07-07 00:17:01 +01:00 committed by GitHub
parent b331290f8a
commit 3f9e97946f
37 changed files with 293 additions and 215 deletions

View File

@ -248,11 +248,12 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
// Declare the wgsl workgroup with the number of blocks in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Execute lazily the kernel with the launch information and the given buffers.
lhs.client.execute(
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
Box::new(SourceKernel::new(kernel, cube_dim)),
cube_count,
vec![
lhs.handle.binding(),
rhs.handle.binding(),

View File

@ -24,7 +24,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
fn empty(&self, size: usize) -> Handle<Server>;
/// Executes the `kernel` over the given `bindings`.
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);
fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
);
/// Perform some synchronization of commands on the server.
fn sync(&self, sync_type: SyncType);

View File

@ -63,10 +63,15 @@ where
self.server.borrow_mut().empty(size)
}
fn execute(&self, kernel_description: Server::Kernel, bindings: Vec<Binding<Server>>) {
fn execute(
&self,
kernel_description: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.server
.borrow_mut()
.execute(kernel_description, bindings)
.execute(kernel_description, count, bindings)
}
fn sync(&self, sync_type: SyncType) {

View File

@ -39,7 +39,10 @@ where
),
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
ExecuteKernel(
(Server::Kernel, Server::DispatchOptions),
Vec<Binding<Server>>,
),
Sync(SyncType, Callback<()>),
}
@ -74,7 +77,7 @@ where
callback.send(handle).await.unwrap();
}
Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings);
server.execute(kernel.0, kernel.1, bindings);
}
Message::Sync(sync_type, callback) => {
server.sync(sync_type);
@ -148,10 +151,15 @@ where
handle_response(response.recv_blocking())
}
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.state
.sender
.send_blocking(Message::ExecuteKernel(kernel, bindings))
.send_blocking(Message::ExecuteKernel((kernel, count), bindings))
.unwrap()
}

View File

@ -56,8 +56,13 @@ where
self.server.lock().empty(size)
}
fn execute(&self, kernel: Server::Kernel, handles: Vec<Binding<Server>>) {
self.server.lock().execute(kernel, handles)
fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
handles: Vec<Binding<Server>>,
) {
self.server.lock().execute(kernel, count, handles)
}
fn sync(&self, sync_type: SyncType) {

View File

@ -82,8 +82,13 @@ where
}
/// Executes the `kernel` over the given `bindings`.
pub fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
self.channel.execute(kernel, bindings)
pub fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
) {
self.channel.execute(kernel, count, bindings)
}
/// Wait for the completion of every task in the server.

View File

@ -17,6 +17,8 @@ where
{
/// The kernel type defines the computation algorithms.
type Kernel: Send;
/// Options when dispatching the kernel, e.g. the number of executions.
type DispatchOptions: Send;
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
type Storage: ComputeStorage;
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
@ -45,7 +47,12 @@ where
///
/// Kernels have mutable access to every resource they are given
/// and are responsible for determining which should be read or written.
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>);
fn execute(
&mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<Binding<Self>>,
);
/// Wait for the completion of every task in the server.
fn sync(&mut self, command: SyncType);

View File

@ -21,6 +21,7 @@ impl<MM> ComputeServer for DummyServer<MM>
where
MM: MemoryManagement<BytesStorage>,
{
type DispatchOptions = ();
type Kernel = Arc<dyn DummyKernel>;
type Storage = BytesStorage;
type MemoryManagement = MM;
@ -53,7 +54,12 @@ where
Handle::new(self.memory_management.reserve(size, || {}))
}
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) {
fn execute(
&mut self,
kernel: Self::Kernel,
_count: Self::DispatchOptions,
bindings: Vec<Binding<Self>>,
) {
let mut resources = bindings
.into_iter()
.map(|binding| self.memory_management.get(binding.memory))

View File

@ -18,7 +18,7 @@ pub struct OneKernelAutotuneOperation {
impl AutotuneOperation for OneKernelAutotuneOperation {
/// Executes the operation on given bindings and server, with the additional parameters
fn execute(self: Box<Self>) {
self.client.execute(self.kernel.clone(), self.bindings);
self.client.execute(self.kernel.clone(), (), self.bindings);
}
fn clone(&self) -> Box<dyn AutotuneOperation> {

View File

@ -38,6 +38,7 @@ fn execute_elementwise_addition() {
client.execute(
Arc::new(DummyElementwiseAddition),
(),
vec![lhs.binding(), rhs.binding(), out.clone().binding()],
);

View File

@ -470,7 +470,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
/// Launch
pub fn #ident #generics (
client: ComputeClient<R::Server, R::Channel>,
cube_count: CubeCount,
cube_count: CubeCount<R::Server>,
cube_dim: CubeDim,
#inputs
) -> #output {

View File

@ -4,13 +4,13 @@ use crate::ir::Elem;
use crate::pod::CubeElement;
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
use burn_compute::client::ComputeClient;
use burn_compute::server::{Binding, Handle};
use burn_compute::server::{Binding, ComputeServer, Handle};
/// The position of the input or output to calculate the number of workgroups to launch.
pub enum CubeCountSettings {
/// The position of the input or output to calculate the number of cubes to launch.
pub enum CubeCountSettings<S: ComputeServer> {
Input { pos: usize },
Output { pos: usize },
Custom(CubeCount),
Custom(CubeCount<S>),
}
pub struct Execution<'h, K, R: Runtime, Scalars> {
@ -73,7 +73,7 @@ where
}
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) {
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, f32, f32, f32>(
self.inputs,
self.outputs,
@ -108,7 +108,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) {
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, E, f32, f32>(
self.inputs,
self.outputs,
@ -144,7 +144,7 @@ where
}
/// Execute a dynamic kernel.
#[allow(clippy::too_many_arguments)]
pub fn execute(self, launch: CubeCountSettings)
pub fn execute(self, launch: CubeCountSettings<R::Server>)
where
K: Kernel + 'static,
R: Runtime,
@ -172,7 +172,7 @@ where
{
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: CubeCountSettings) {
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, E1, E2, E3>(
self.inputs,
self.outputs,
@ -194,7 +194,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
kernel: K,
launch: CubeCountSettings,
launch: CubeCountSettings<R::Server>,
client: ComputeClient<R::Server, R::Channel>,
) where
K: Kernel + 'static,
@ -207,23 +207,21 @@ fn execute_dynamic<R, K, E1, E2, E3>(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let mut handles = settings.handles_tensors;
let workgroup = settings.cube_count;
handles.push(settings.handle_info.binding());
for handle in settings.handles_scalars.into_iter() {
handles.push(handle.binding());
}
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, workgroup));
client.execute(kernel, handles);
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, settings.cube_count, handles);
}
struct ExecuteSettings<R: Runtime> {
handles_tensors: Vec<Binding<R::Server>>,
handle_info: Handle<R::Server>,
handles_scalars: Vec<Handle<R::Server>>,
cube_count: CubeCount,
cube_count: CubeCount<R::Server>,
}
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
@ -232,7 +230,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
launch: CubeCountSettings,
launch: CubeCountSettings<R::Server>,
client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<R> {
let mut info = Vec::new();
@ -295,8 +293,8 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
let handles_scalars =
create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
let workgroup = match launch {
CubeCountSettings::Custom(workgroup) => workgroup,
let cube_count = match launch {
CubeCountSettings::Custom(count) => count,
_ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
};
@ -304,7 +302,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
handles_tensors: handles,
handle_info: info,
handles_scalars,
cube_count: workgroup,
cube_count,
}
}

View File

@ -74,9 +74,9 @@ impl core::fmt::Display for KernelSettings {
// * Vectorization Global: vg{factor}
// * Vectorization Partial Input: v{factor}i{pos}
// * Vectorization Partial Output: vo
// * Workgroup Size X: x
// * Workgroup Size Y: y
// * Workgroup Size Z: z
// * Cube Dim X: x
// * Cube Dim Y: y
// * Cube Dim Z: z
f.write_str("m")?;
for mapping in self.mappings.iter() {
f.write_fmt(format_args!(

View File

@ -1,41 +1,32 @@
use std::marker::PhantomData;
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use alloc::sync::Arc;
use std::marker::PhantomData;
use burn_compute::server::{Binding, ComputeServer};
/// A kernel, compiled in the target language
pub struct CompiledKernel {
/// Source code of the kernel
pub source: String,
/// Size of a workgroup for the compiled kernel
/// Size of a cube for the compiled kernel
pub cube_dim: CubeDim,
/// The number of bytes used by the shared memory
pub shared_mem_bytes: usize,
}
/// Information needed to launch the kernel
pub struct LaunchSettings {
/// Layout of workgroups for the kernel
pub cube_count: CubeCount,
}
/// Kernel trait with the ComputeShader that will be compiled and cached based on the
/// provided id.
///
/// The kernel will be launched with the given [launch settings](LaunchSettings).
pub trait CubeTask: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String;
/// Compile the kernel into source
fn compile(&self) -> CompiledKernel;
/// Launch settings.
fn launch_settings(&self) -> LaunchSettings;
}
/// Wraps a [kernel](Kernel) with its [cube count](CubeCount) to create a [cube task](CubeTask).
/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
#[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> {
kernel_definition: K,
cube_count: CubeCount,
_compiler: PhantomData<C>,
}
@ -57,12 +48,6 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
fn id(&self) -> String {
self.kernel_definition.id().clone()
}
fn launch_settings(&self) -> LaunchSettings {
LaunchSettings {
cube_count: self.cube_count.clone(),
}
}
}
impl CubeTask for Arc<dyn CubeTask> {
@ -73,10 +58,6 @@ impl CubeTask for Arc<dyn CubeTask> {
fn id(&self) -> String {
self.as_ref().id()
}
fn launch_settings(&self) -> LaunchSettings {
self.as_ref().launch_settings()
}
}
impl CubeTask for Box<dyn CubeTask> {
@ -87,26 +68,21 @@ impl CubeTask for Box<dyn CubeTask> {
fn id(&self) -> String {
self.as_ref().id()
}
fn launch_settings(&self) -> LaunchSettings {
self.as_ref().launch_settings()
}
}
/// Provides launch information specifying the number of work groups to be used by a compute shader.
#[derive(new, Clone, Debug)]
pub struct CubeCount {
/// Work groups for the x axis.
pub x: u32,
/// Work groups for the y axis.
pub y: u32,
/// Work groups for the z axis.
pub z: u32,
pub enum CubeCount<S: ComputeServer> {
/// Dispatch x,y,z work groups.
Static(u32, u32, u32),
/// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
Dynamic(Binding<S>),
}
impl CubeCount {
/// Calculate the number of invocations of a compute shader.
pub fn num_invocations(&self) -> usize {
(self.x * self.y * self.z) as usize
impl<S: ComputeServer> Clone for CubeCount<S> {
fn clone(&self) -> Self {
match self {
Self::Static(x, y, z) => Self::Static(*x, *y, *z),
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
}
}
}
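
The new `CubeCount` is the dispatch payload for cube runtimes: `Static(x, y, z)` keeps the previous host-side behaviour, while `Dynamic` takes a binding to a buffer that holds the three counts. A minimal caller-side sketch of both modes, assuming `bytemuck` is available and using the `ComputeClient` API shown elsewhere in this diff (the helper functions are illustrative, not part of the commit):

use burn_compute::client::ComputeClient;
use burn_cube::{compute::CubeCount, Runtime};

/// Fixed dispatch, resolved on the host as before this change.
fn static_count<R: Runtime>() -> CubeCount<R::Server> {
    CubeCount::Static(16, 16, 1)
}

/// Dispatch whose x, y, z counts live in a GPU buffer holding three u32 values.
/// Here they are uploaded from the host, but an earlier kernel could write them instead.
fn dynamic_count<R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    x: u32,
    y: u32,
    z: u32,
) -> CubeCount<R::Server> {
    let counts: &[u32] = &[x, y, z];
    let handle = client.create(bytemuck::cast_slice(counts));
    CubeCount::Dynamic(handle.binding())
}

Either value is then passed alongside the kernel, as in `client.execute(kernel, count, bindings)` in the diffs above.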

View File

@ -78,15 +78,15 @@ impl<R: Runtime> KernelLauncher<R> {
/// Launch the kernel.
pub fn launch<K: Kernel>(
self,
cube_count: CubeCount,
cube_count: CubeCount<R::Server>,
kernel: K,
client: ComputeClient<R::Server, R::Channel>,
) {
let bindings = self.into_bindings(&client);
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, cube_count));
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, bindings);
client.execute(kernel, cube_count, bindings);
}
/// We need to create the bindings in the same order they are defined in the compilation step.

View File

@ -6,6 +6,7 @@ extern crate derive_new;
/// Cube Frontend Types.
pub mod frontend;
use burn_compute::server::ComputeServer;
pub use frontend::cmma;
/// Cube Language Internal Representation.
@ -45,13 +46,16 @@ pub trait Kernel: Send + Sync + 'static {
/// Calculate the number of cubes required to execute an operation where one cube unit is
/// assigned to one element.
pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: usize) -> CubeCount {
pub fn calculate_cube_count_elemwise<S: ComputeServer>(
num_elems: usize,
cube_dim: usize,
) -> CubeCount<S> {
let num_elems_per_cube = cube_dim * cube_dim;
let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32);
let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
CubeCount::new(cube_count_x as u32, cube_count_y as u32, 1)
CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
}
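As a worked instance of this arithmetic (illustrative numbers, not from the commit): with num_elems = 100_000 and cube_dim = 16, num_elems_per_cube = 256, cube_counts = ceil(100_000 / 256) = 391, cube_count_x = ceil(sqrt(391)) = 20 and cube_count_y = ceil(100_000 / (20 * 256)) = 20, so the function returns CubeCount::Static(20, 20, 1), i.e. 400 cubes to cover the 391 required.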
pub fn tensor_vectorization_factor(

View File

@ -1,4 +1,8 @@
use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
use crate::{
codegen::Compiler,
compute::{CubeCount, CubeTask},
ir::Elem,
};
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
/// Runtime for the CubeCL.
@ -6,7 +10,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
/// The compiler used to compile the inner representation into tokens.
type Compiler: Compiler;
/// The compute server used to run kernels and perform autotuning.
type Server: ComputeServer<Kernel = Box<dyn CubeTask>, FeatureSet = FeatureSet>;
type Server: ComputeServer<
Kernel = Box<dyn CubeTask>,
DispatchOptions = CubeCount<Self::Server>,
FeatureSet = FeatureSet,
>;
/// The channel used to communicate with the compute server.
type Channel: ComputeChannel<Self::Server>;
/// The device used to retrieve the compute client.

View File

@ -61,7 +61,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
kernel_simple_1_launch::<R>(
client.clone(),
CubeCount::new(1, 1, 1),
CubeCount::Static(1, 1, 1),
CubeDim::new(16, 16, 1),
ArrayArg::new(&lhs, 256),
ArrayArg::new(&rhs, 256),

View File

@ -20,7 +20,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
kernel_with_generics_launch::<F32, R>(
client.clone(),
CubeCount::new(1, 1, 1),
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::new(&handle, 2),
);
@ -36,7 +36,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
kernel_without_generics_launch::<R>(
client.clone(),
CubeCount::new(1, 1, 1),
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::new(&handle, 2),
);

View File

@ -98,7 +98,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
launch: Launch,
) where
Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
Launch: Fn(CubeCount<TestRuntime::Server>, CubeDim, TensorArg<'_, TestRuntime>),
{
if !client.features().enabled(Feature::Subcube) {
// Can't execute the test.
@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
let (shape, strides) = ([input.len()], [1]);
launch(
CubeCount::new(1, 1, 1),
CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32, 1, 1),
TensorArg::new(&handle, &strides, &shape),
);

View File

@ -57,14 +57,8 @@ struct CompiledKernel {
unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
type Kernel = Box<dyn CubeTask>;
type Storage = CudaStorage;
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
type FeatureSet = FeatureSet;
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
let ctx = self.get_context();
let resource = ctx.memory_management.get(binding.memory);
@ -74,7 +68,20 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
};
ctx.sync();
reader_from_concrete(data)
data
}
}
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
type Kernel = Box<dyn CubeTask>;
type DispatchOptions = CubeCount<Self>;
type Storage = CudaStorage;
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
type FeatureSet = FeatureSet;
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
reader_from_concrete(self.read_sync(binding))
}
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
@ -101,12 +108,33 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
server::Handle::new(handle)
}
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
fn execute(
&mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<server::Binding<Self>>,
) {
let arch = self.minimum_arch_version;
let ctx = self.get_context();
let kernel_id = kernel.id();
let settings = kernel.launch_settings();
let count = match count {
CubeCount::Static(x, y, z) => (x, y, z),
// TODO: CUDA doesn't have an exact equivalent of dynamic dispatch. Instead, kernels are free to launch other kernels.
// One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
// For now, just read the dispatch settings from the buffer.
CubeCount::Dynamic(binding) => {
let data = self.read_sync(binding);
let data = bytemuck::cast_slice(&data);
assert!(
data.len() == 3,
"Dynamic cube count should contain 3 values"
);
(data[0], data[1], data[2])
}
};
let ctx = self.get_context();
if !ctx.module_names.contains_key(&kernel_id) {
ctx.compile_kernel(&kernel_id, kernel, arch);
@ -117,7 +145,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
.map(|binding| ctx.memory_management.get(binding.memory).as_binding())
.collect();
ctx.execute_task(kernel_id, settings.cube_count, bindings);
ctx.execute_task(kernel_id, count, bindings);
// TODO: fix this
// self.memory_management.storage().perform_deallocations();
}
@ -217,16 +245,15 @@ impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
fn execute_task(
&mut self,
kernel_id: String,
cube_count: CubeCount,
dispatch_count: (u32, u32, u32),
mut bindings: Vec<Binding>,
) {
let kernel = self.module_names.get(&kernel_id).unwrap();
let cube_dim = kernel.cube_dim;
unsafe {
cudarc::driver::result::launch_kernel(
kernel.func,
(cube_count.x, cube_count.y, cube_count.z),
dispatch_count,
(cube_dim.x, cube_dim.y, cube_dim.z),
kernel.shared_mem_bytes as u32,
self.stream,

View File

@ -29,9 +29,9 @@ pub struct CompilationPhase;
/// Phase where the kernel should be executed.
#[derive(new)]
pub struct ExecutionPhase<R: JitRuntime> {
/// Kernel set with default workgroup size.
/// Kernel set with default cube size.
pub(super) kernel_factory_1: ElementWiseKernelFactory<R>,
/// Kernel set with custom workgroup size.
/// Kernel set with custom cube size.
pub(super) kernel_factory_2: ElementWiseKernelFactory<R>,
}

View File

@ -22,7 +22,7 @@ pub struct FusionKernel<R: JitRuntime> {
info: Arc<KernelExpansion>,
settings: KernelSettings,
runtime_info: Vec<OutputRuntimeInfo>,
cube_count: CubeCount,
cube_count: CubeCount<R::Server>,
_runtime: PhantomData<R>,
}
@ -41,6 +41,7 @@ pub trait FusionKernelFactory<R: JitRuntime> {
#[derive(new)]
pub struct ExecutableKernel<R: JitRuntime> {
kernel: Box<dyn CubeTask>,
cube_count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
}
@ -54,6 +55,7 @@ pub struct ExecutableKernel<R: JitRuntime> {
#[derive(new)]
pub struct AutotunableKernel<R: JitRuntime> {
kernel: Arc<dyn CubeTask>,
count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
}
@ -68,18 +70,21 @@ pub enum OutputRuntimeInfo {
impl<R: JitRuntime> ExecutableKernel<R> {
/// Execute the kernel.
pub fn execute(self) {
self.client.execute(self.kernel, self.bindings)
self.client
.execute(self.kernel, self.cube_count, self.bindings)
}
}
impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
fn execute(self: Box<Self>) {
self.client.execute(Box::new(self.kernel), self.bindings)
self.client
.execute(Box::new(self.kernel), self.count, self.bindings)
}
fn clone(&self) -> Box<dyn AutotuneOperation> {
Box::new(Self {
kernel: self.kernel.clone(),
count: self.count.clone(),
bindings: self.bindings.clone(),
client: self.client.clone(),
})
@ -90,6 +95,7 @@ impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
fn from(value: ExecutableKernel<R>) -> Self {
Self {
kernel: Arc::new(value.kernel),
count: value.cube_count.clone(),
bindings: value.bindings,
client: value.client,
}
@ -233,12 +239,10 @@ impl<R: JitRuntime> FusionKernel<R> {
context.handles.register_handle(id, handle);
}
let workgroup = fusion_kernel.cube_count.clone();
let cube_count = fusion_kernel.cube_count.clone();
ExecutableKernel::new(
Box::new(KernelTask::<R::Compiler, FusionKernel<R>>::new(
fusion_kernel,
workgroup,
)),
Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
cube_count,
bindings,
client,
)

View File

@ -196,7 +196,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
let kernel = ScatterEagerKernel::<R, E>::new(dim);
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
let mut num_elems = 1;
tensor
.shape
@ -208,13 +208,13 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
num_elems *= tensor.shape.dims[index];
});
// Fake strides of the virtual output where the stride of dim is hardcoded to one.
indices.strides = strides;
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
Execution::start(kernel, indices.client)
.inputs(&[
@ -222,7 +222,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
TensorHandle::new(&value.handle, &value.strides, &value.shape.dims),
])
.execute(CubeCountSettings::Custom(workgroup));
.execute(CubeCountSettings::Custom(cube_count));
tensor
}

View File

@ -189,7 +189,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
let mut num_elems = 1;
tensor
.shape
@ -201,11 +201,11 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
.for_each(|(index, val)| {
strides[index] = current;
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
num_elems *= tensor.shape.dims[index];
});
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
Execution::start(kernel, indices.client)
.inputs(&[
@ -215,7 +215,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
// kernel, but we need to put the right number of dimensions (rank).
TensorHandle::new(&indices.handle, &strides, &strides),
])
.execute(CubeCountSettings::Custom(workgroup));
.execute(CubeCountSettings::Custom(cube_count));
tensor
}

View File

@ -140,41 +140,39 @@ pub fn matmul<R: JitRuntime, E: FloatElement, const D: usize>(
}
}
pub(crate) fn simple_launch_options<const D: usize>(
pub(crate) fn simple_cube_count<R: JitRuntime, const D: usize>(
lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>,
output_shape: &Shape<D>,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> CubeCount {
cube_dim_x: usize,
cube_dim_y: usize,
) -> CubeCount<R::Server> {
let num_rows = lhs_shape.dims[D - 2];
let num_cols = rhs_shape.dims[D - 1];
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
}
pub(crate) fn tiling2d_launch_options<const D: usize>(
pub(crate) fn tiling2d_launch_options<R: JitRuntime, const D: usize>(
output_shape: &Shape<D>,
config: Tiling2dConfig,
) -> CubeCount {
) -> CubeCount<R::Server> {
let num_rows = output_shape.dims[D - 2];
let num_cols = output_shape.dims[D - 1];
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
}
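
As a worked instance (illustrative shapes, not from the commit): for a batched matmul with lhs_shape = [8, 512, 256], rhs_shape = [8, 256, 384], output_shape = [8, 512, 384] and cube_dim_x = cube_dim_y = 16, this gives cubes_x = ceil(512 / 16) = 32, cubes_y = ceil(384 / 16) = 24 and num_iter = 8 from the leading batch dimension, so the dispatch is CubeCount::Static(32, 24, 8).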

View File

@ -7,7 +7,7 @@ use crate::{
use burn_cube::ir::KernelDefinition;
use burn_cube::{frontend::TensorArg, KernelSettings};
use super::simple_launch_options;
use super::simple_cube_count;
use burn_cube::prelude::*;
#[cube(launch)]
@ -80,7 +80,7 @@ fn matmul_kernel<F: Float>(
}
}
/// Matrix multiplication using memory coalescing algorithm with workgroups of size 16
/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16
pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
@ -89,7 +89,7 @@ pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: us
matmul_simple::<R, E, D>(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX)
}
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
@ -103,7 +103,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
let rhs_original_shape = rhs.shape.clone();
let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2));
let cube_count = simple_launch_options(
let cube_count = simple_cube_count::<R, D>(
&lhs.shape,
&rhs_original_shape,
&out.shape,

View File

@ -118,7 +118,7 @@ pub fn matmul_tiling_2d<R: JitRuntime, E: JitElement + Element, const D: usize>(
&out.strides,
&out.shape.dims,
)])
.execute(CubeCountSettings::Custom(tiling2d_launch_options(
.execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
&out.shape, config,
)));
@ -175,7 +175,7 @@ pub fn matmul_tiling_2d_padded<R: JitRuntime, E: JitElement + Element, const D:
&rounded_output.strides,
&rounded_output.shape.dims,
)])
.execute(CubeCountSettings::Custom(tiling2d_launch_options(
.execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
&rounded_output.shape,
config,
)));

View File

@ -55,14 +55,14 @@ pub(crate) fn gather_shader_information(
cpa!(scope, out_stride_row = stride(out, second_to_last_dim));
cpa!(scope, out_stride_col = stride(out, last_dim));
// Workgroup offset
// Cube offset
let skip_row = scope.create_local(Elem::UInt);
let skip_col = scope.create_local(Elem::UInt);
let workgroup_id_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY;
cpa!(scope, skip_row = workgroup_id_x);
let cube_pos_x = Variable::CubePosX;
let cube_pos_y = Variable::CubePosY;
cpa!(scope, skip_row = cube_pos_x);
cpa!(scope, skip_row *= block_size_m);
cpa!(scope, skip_col = workgroup_id_y);
cpa!(scope, skip_col = cube_pos_y);
cpa!(scope, skip_col *= block_size_n);
// Position of the first element of the thread, relative to the block

View File

@ -38,7 +38,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
)])
.with_scalars(&seeds)
.with_scalars(&prng.args())
.execute(CubeCountSettings::Custom(prng_cube_count(
.execute(CubeCountSettings::Custom(prng_cube_count::<R>(
num_elems,
SUBCUBE_DIM_APPROX,
N_VALUES_PER_THREAD,
@ -47,14 +47,18 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
output
}
fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount {
fn prng_cube_count<R: JitRuntime>(
num_elems: usize,
cube_dim: usize,
n_values_per_thread: usize,
) -> CubeCount<R::Server> {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elems_per_cube = cube_dim * cube_dim;
let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
let cubes_x = f32::ceil(f32::sqrt(num_invocations));
let cubes_y = f32::ceil(num_invocations / cubes_x);
CubeCount::new(workgroup_x as u32, workgroup_y as u32, 1)
CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
}
impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> {
@ -163,24 +167,24 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let n_values_per_thread: Variable = self.n_values_per_thread.into();
let args = self.args;
let workgroup_size_x = Variable::CubeDimX;
let workgroup_size_y = Variable::CubeDimY;
let workgroup_id_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY;
let num_workgroups_y = Variable::CubeCountY;
let cube_dim_x = Variable::CubeDimX;
let cube_dim_y = Variable::CubeDimY;
let cube_pos_x = Variable::CubePosX;
let cube_pos_y = Variable::CubePosY;
let cube_count_y = Variable::CubeCountY;
let local_index = Variable::UnitPos;
let n_invocations = scope.create_local(Elem::UInt);
cpa!(scope, n_invocations = workgroup_size_x);
cpa!(scope, n_invocations *= workgroup_size_y);
cpa!(scope, n_invocations = cube_dim_x);
cpa!(scope, n_invocations *= cube_dim_y);
let workgroup_offset = scope.create_local(Elem::UInt);
cpa!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
cpa!(scope, workgroup_offset += workgroup_id_y);
cpa!(scope, workgroup_offset *= n_invocations);
let cube_offset = scope.create_local(Elem::UInt);
cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
cpa!(scope, cube_offset += cube_pos_y);
cpa!(scope, cube_offset *= n_invocations);
let write_index_base = scope.create_local(Elem::UInt);
cpa!(scope, write_index_base = workgroup_offset);
cpa!(scope, write_index_base = cube_offset);
cpa!(scope, write_index_base *= n_values_per_thread);
cpa!(scope, write_index_base += local_index);
@ -188,7 +192,7 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
let thread_seed = scope.create_local(Elem::UInt);
cpa!(scope, thread_seed = cast(1000000007));
let thread_seed_index = scope.create_local(Elem::UInt);
cpa!(scope, thread_seed_index = workgroup_offset + local_index);
cpa!(scope, thread_seed_index = cube_offset + local_index);
cpa!(scope, thread_seed *= thread_seed_index);
let state_0 = scope.create_local(Elem::UInt);

View File

@ -33,8 +33,8 @@ pub(crate) struct SharedReduceDimEagerKernel<
EO: JitElement,
> {
dim: usize,
workgroup_size_x: usize,
workgroup_size_y: usize,
cube_dim_x: usize,
cube_dim_y: usize,
n_input_values_per_thread: u32,
divisible_shape: bool,
_reduce_dim: PhantomData<RD>,
@ -58,7 +58,7 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
SharedReduceDimComputeShader {
tensor,
dim: self.dim,
shared_memory_size: self.workgroup_size_x * self.workgroup_size_y,
shared_memory_size: self.cube_dim_x * self.cube_dim_y,
n_input_values_per_thread: self.n_input_values_per_thread,
output,
divisible_shape: self.divisible_shape,
@ -83,8 +83,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
};
let settings = KernelSettings::default().cube_dim(CubeDim::new(
self.workgroup_size_x as u32,
self.workgroup_size_y as u32,
self.cube_dim_x as u32,
self.cube_dim_y as u32,
1,
));
KernelIntegrator::new(info).integrate(settings)
@ -95,8 +95,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
"{:?}dim={}x={}y={}n={}divshape={}",
core::any::TypeId::of::<Self>(),
self.dim,
self.workgroup_size_x,
self.workgroup_size_y,
self.cube_dim_x,
self.cube_dim_y,
self.n_input_values_per_thread,
self.divisible_shape
)
@ -111,13 +111,13 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
let rank = Variable::Rank;
let dim: Variable = self.dim.into();
let workgroup_id_x = Variable::CubePosX;
let workgroup_id_y = Variable::CubePosY;
let num_workgroups_x = Variable::CubeCountX;
let cube_pos_x = Variable::CubePosX;
let cube_pos_y = Variable::CubePosY;
let cube_count_x = Variable::CubeCountX;
let local_invocation_id_x = Variable::UnitPosX;
let local_invocation_id_y = Variable::UnitPosY;
let workgroup_size_x = Variable::CubeDimX;
let workgroup_size_y = Variable::CubeDimY;
let cube_dim_x = Variable::CubeDimX;
let cube_dim_y = Variable::CubeDimY;
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
@ -126,16 +126,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
// To determine which reduce_group (not position, but absolute id)
let reduce_group_id = scope.create_local(Elem::UInt);
cpa!(scope, reduce_group_id = workgroup_id_y * num_workgroups_x);
cpa!(scope, reduce_group_id += workgroup_id_x);
cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
cpa!(scope, reduce_group_id += cube_pos_x);
// nth thread in the workgroup
// nth thread in the cube
let local_id = scope.create_local(Elem::UInt);
cpa!(scope, local_id = local_invocation_id_y * workgroup_size_x);
cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
cpa!(scope, local_id += local_invocation_id_x);
let n_threads = scope.create_local(Elem::UInt);
cpa!(scope, n_threads = workgroup_size_x * workgroup_size_y);
cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
let index_offset = scope.zero(Elem::UInt);
@ -242,17 +242,17 @@ pub fn reduce_dim_shared<
dim: usize,
) -> JitTensor<R, EO, D> {
let num_elems_output = output.shape.num_elements();
let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32));
let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x);
let grid = CubeCount::new(n_workgroups_x as u32, n_workgroups_y as u32, 1);
let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32));
let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x);
let grid = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1);
let reduce_group_size = input.shape.dims[dim];
let n_invocation_per_workgroup = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
let n_invocation_per_cube = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
let n_input_values_per_thread =
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
let divisible_shape =
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
n_invocation_per_cube as u32 * n_input_values_per_thread == reduce_group_size as u32;
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
dim,

View File

@ -18,7 +18,10 @@ pub(crate) mod tune;
/// Elements for JIT backend
pub mod element;
use burn_cube::{compute::CubeTask, Runtime};
use burn_cube::{
compute::{CubeCount, CubeTask},
Runtime,
};
pub use element::{FloatElement, IntElement, JitElement};
mod backend;
@ -48,5 +51,6 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
type JitServer: burn_compute::server::ComputeServer<
AutotuneKey = JitAutotuneKey,
Kernel = Box<dyn CubeTask>,
DispatchOptions = CubeCount<Self::JitServer>,
>;
}

View File

@ -1,5 +1,4 @@
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_cube::compute::LaunchSettings;
use burn_cube::prelude::*;
use super::SourceTemplate;
@ -11,18 +10,13 @@ pub trait KernelSource: Send + 'static + Sync {
}
#[derive(new)]
/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask) with launch
/// information.
/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).
pub struct SourceKernel<K> {
kernel_source: K,
cube_count: CubeCount,
cube_dim: CubeDim,
}
impl<K> CubeTask for SourceKernel<K>
where
K: KernelSource + 'static,
{
impl<K: KernelSource> CubeTask for SourceKernel<K> {
fn compile(&self) -> CompiledKernel {
let source_template = self.kernel_source.source();
let source = source_template.complete();
@ -37,12 +31,6 @@ where
fn id(&self) -> String {
format!("{:?}", core::any::TypeId::of::<K>())
}
fn launch_settings(&self) -> LaunchSettings {
LaunchSettings {
cube_count: self.cube_count.clone(),
}
}
}
/// Generates kernel source code by replacing some information using templating.

View File

@ -64,8 +64,15 @@ where
&mut self,
pipeline: Arc<ComputePipeline>,
bind_group: BindGroup,
work_group: CubeCount,
count: CubeCount<Self>,
) {
// First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
// needs to be longer than the compute pass, so we can't do this just before dispatching.
let dispatch_resource = match count.clone() {
CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)),
_ => None,
};
let mut compute = self
.encoder
.begin_compute_pass(&wgpu::ComputePassDescriptor {
@ -75,12 +82,21 @@ where
compute.set_pipeline(&pipeline);
compute.set_bind_group(0, &bind_group, &[]);
compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
match count {
CubeCount::Static(x, y, z) => {
compute.dispatch_workgroups(x, y, z);
}
CubeCount::Dynamic(_) => {
let resource = dispatch_resource.as_ref().unwrap();
compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset());
}
}
self.tasks_count += 1;
}
fn pipeline(&mut self, kernel: Box<dyn CubeTask>) -> Arc<ComputePipeline> {
fn pipeline(&mut self, kernel: <Self as ComputeServer>::Kernel) -> Arc<ComputePipeline> {
let kernel_id = kernel.id();
if let Some(pipeline) = self.pipelines.get(&kernel_id) {
@ -143,6 +159,7 @@ where
MM: MemoryManagement<WgpuStorage>,
{
type Kernel = Box<dyn CubeTask>;
type DispatchOptions = CubeCount<Self>;
type Storage = WgpuStorage;
type MemoryManagement = MM;
type AutotuneKey = JitAutotuneKey;
@ -259,8 +276,12 @@ where
}))
}
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
let work_group = kernel.launch_settings().cube_count;
fn execute(
&mut self,
kernel: Self::Kernel,
count: Self::DispatchOptions,
bindings: Vec<server::Binding<Self>>,
) {
let pipeline = self.pipeline(kernel);
let group_layout = pipeline.get_bind_group_layout(0);
@ -284,7 +305,7 @@ where
entries: &entries,
});
self.register_compute(pipeline, bind_group, work_group);
self.register_compute(pipeline, bind_group, count);
if self.tasks_count >= self.tasks_max {
self.sync(SyncType::Flush);

View File

@ -111,7 +111,8 @@ impl ComputeStorage for WgpuStorage {
size: size as u64,
usage: wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC,
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
}));
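
Adding INDIRECT to every storage buffer is what lets any handle back the `dispatch_workgroups_indirect` call used for `CubeCount::Dynamic` in the wgpu server above. For reference, a minimal raw-wgpu sketch of that pattern (the `device`, `queue`, `encoder`, `pipeline` and `bind_group` setup is assumed, not part of this commit):

// The indirect buffer needs BufferUsages::INDIRECT and must hold three u32
// values [x, y, z] at the offset passed to the dispatch call.
let args: &[u32] = &[16, 16, 1];
let indirect = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("dispatch args"),
    size: 12,
    usage: wgpu::BufferUsages::INDIRECT | wgpu::BufferUsages::COPY_DST,
    mapped_at_creation: false,
});
queue.write_buffer(&indirect, 0, bytemuck::cast_slice(args));

let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups_indirect(&indirect, 0);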

View File

@ -87,11 +87,13 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
// Declare the wgsl workgroup with the number of cubes in x, y and z.
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
let cube_count =
CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
// Execute lazily the kernel with the launch information and the given buffers.
lhs.client.execute(
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
Box::new(SourceKernel::new(kernel, cube_dim)),
cube_count,
vec![
lhs.handle.binding(),
rhs.handle.binding(),

View File

@ -22,7 +22,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
gelu_launch::<F32, R>(
client.clone(),
CubeCount::new(1, 1, 1),
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::new(&input_handle, input.len()),
ArrayArg::new(&output_handle, input.len()),