mirror of https://github.com/tracel-ai/burn.git
Feat: Dynamic cube count dispatch (#1975)
This commit is contained in:
parent
b331290f8a
commit
3f9e97946f
|
@ -248,11 +248,12 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
|
|||
// Declare the wgsl workgroup with the number of blocks in x, y and z.
|
||||
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
|
||||
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
|
||||
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
|
||||
let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
|
||||
|
||||
// Execute lazily the kernel with the launch information and the given buffers.
|
||||
lhs.client.execute(
|
||||
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
|
||||
Box::new(SourceKernel::new(kernel, cube_dim)),
|
||||
cube_count,
|
||||
vec![
|
||||
lhs.handle.binding(),
|
||||
rhs.handle.binding(),
|
||||
|
|
|
@ -24,7 +24,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
|
|||
fn empty(&self, size: usize) -> Handle<Server>;
|
||||
|
||||
/// Executes the `kernel` over the given `bindings`.
|
||||
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
);
|
||||
|
||||
/// Perform some synchronization of commands on the server.
|
||||
fn sync(&self, sync_type: SyncType);
|
||||
|
|
|
@ -63,10 +63,15 @@ where
|
|||
self.server.borrow_mut().empty(size)
|
||||
}
|
||||
|
||||
fn execute(&self, kernel_description: Server::Kernel, bindings: Vec<Binding<Server>>) {
|
||||
fn execute(
|
||||
&self,
|
||||
kernel_description: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.server
|
||||
.borrow_mut()
|
||||
.execute(kernel_description, bindings)
|
||||
.execute(kernel_description, count, bindings)
|
||||
}
|
||||
|
||||
fn sync(&self, sync_type: SyncType) {
|
||||
|
|
|
@ -39,7 +39,10 @@ where
|
|||
),
|
||||
Create(Vec<u8>, Callback<Handle<Server>>),
|
||||
Empty(usize, Callback<Handle<Server>>),
|
||||
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
|
||||
ExecuteKernel(
|
||||
(Server::Kernel, Server::DispatchOptions),
|
||||
Vec<Binding<Server>>,
|
||||
),
|
||||
Sync(SyncType, Callback<()>),
|
||||
}
|
||||
|
||||
|
@ -74,7 +77,7 @@ where
|
|||
callback.send(handle).await.unwrap();
|
||||
}
|
||||
Message::ExecuteKernel(kernel, bindings) => {
|
||||
server.execute(kernel, bindings);
|
||||
server.execute(kernel.0, kernel.1, bindings);
|
||||
}
|
||||
Message::Sync(sync_type, callback) => {
|
||||
server.sync(sync_type);
|
||||
|
@ -148,10 +151,15 @@ where
|
|||
handle_response(response.recv_blocking())
|
||||
}
|
||||
|
||||
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::ExecuteKernel(kernel, bindings))
|
||||
.send_blocking(Message::ExecuteKernel((kernel, count), bindings))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
|
|
@ -56,8 +56,13 @@ where
|
|||
self.server.lock().empty(size)
|
||||
}
|
||||
|
||||
fn execute(&self, kernel: Server::Kernel, handles: Vec<Binding<Server>>) {
|
||||
self.server.lock().execute(kernel, handles)
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
handles: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.server.lock().execute(kernel, count, handles)
|
||||
}
|
||||
|
||||
fn sync(&self, sync_type: SyncType) {
|
||||
|
|
|
@ -82,8 +82,13 @@ where
|
|||
}
|
||||
|
||||
/// Executes the `kernel` over the given `bindings`.
|
||||
pub fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
|
||||
self.channel.execute(kernel, bindings)
|
||||
pub fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.channel.execute(kernel, count, bindings)
|
||||
}
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
|
|
|
@ -17,6 +17,8 @@ where
|
|||
{
|
||||
/// The kernel type defines the computation algorithms.
|
||||
type Kernel: Send;
|
||||
/// Options when dispatching the kernel, eg. the number of executions.
|
||||
type DispatchOptions: Send;
|
||||
/// The [storage](ComputeStorage) type defines how data is stored and accessed.
|
||||
type Storage: ComputeStorage;
|
||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||
|
@ -45,7 +47,12 @@ where
|
|||
///
|
||||
/// Kernels have mutable access to every resource they are given
|
||||
/// and are responsible of determining which should be read or written.
|
||||
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>);
|
||||
fn execute(
|
||||
&mut self,
|
||||
kernel: Self::Kernel,
|
||||
count: Self::DispatchOptions,
|
||||
bindings: Vec<Binding<Self>>,
|
||||
);
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
fn sync(&mut self, command: SyncType);
|
||||
|
|
|
@ -21,6 +21,7 @@ impl<MM> ComputeServer for DummyServer<MM>
|
|||
where
|
||||
MM: MemoryManagement<BytesStorage>,
|
||||
{
|
||||
type DispatchOptions = ();
|
||||
type Kernel = Arc<dyn DummyKernel>;
|
||||
type Storage = BytesStorage;
|
||||
type MemoryManagement = MM;
|
||||
|
@ -53,7 +54,12 @@ where
|
|||
Handle::new(self.memory_management.reserve(size, || {}))
|
||||
}
|
||||
|
||||
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) {
|
||||
fn execute(
|
||||
&mut self,
|
||||
kernel: Self::Kernel,
|
||||
_count: Self::DispatchOptions,
|
||||
bindings: Vec<Binding<Self>>,
|
||||
) {
|
||||
let mut resources = bindings
|
||||
.into_iter()
|
||||
.map(|binding| self.memory_management.get(binding.memory))
|
||||
|
|
|
@ -18,7 +18,7 @@ pub struct OneKernelAutotuneOperation {
|
|||
impl AutotuneOperation for OneKernelAutotuneOperation {
|
||||
/// Executes the operation on given bindings and server, with the additional parameters
|
||||
fn execute(self: Box<Self>) {
|
||||
self.client.execute(self.kernel.clone(), self.bindings);
|
||||
self.client.execute(self.kernel.clone(), (), self.bindings);
|
||||
}
|
||||
|
||||
fn clone(&self) -> Box<dyn AutotuneOperation> {
|
||||
|
|
|
@ -38,6 +38,7 @@ fn execute_elementwise_addition() {
|
|||
|
||||
client.execute(
|
||||
Arc::new(DummyElementwiseAddition),
|
||||
(),
|
||||
vec![lhs.binding(), rhs.binding(), out.clone().binding()],
|
||||
);
|
||||
|
||||
|
|
|
@ -470,7 +470,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
|
|||
/// Launch
|
||||
pub fn #ident #generics (
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
cube_count: CubeCount,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
cube_dim: CubeDim,
|
||||
#inputs
|
||||
) -> #output {
|
||||
|
|
|
@ -4,13 +4,13 @@ use crate::ir::Elem;
|
|||
use crate::pod::CubeElement;
|
||||
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_compute::server::{Binding, Handle};
|
||||
use burn_compute::server::{Binding, ComputeServer, Handle};
|
||||
|
||||
/// The position of the input or output to calculate the number of workgroups to launch.
|
||||
pub enum CubeCountSettings {
|
||||
/// The position of the input or output to calculate the number of cubes to launch.
|
||||
pub enum CubeCountSettings<S: ComputeServer> {
|
||||
Input { pos: usize },
|
||||
Output { pos: usize },
|
||||
Custom(CubeCount),
|
||||
Custom(CubeCount<S>),
|
||||
}
|
||||
|
||||
pub struct Execution<'h, K, R: Runtime, Scalars> {
|
||||
|
@ -73,7 +73,7 @@ where
|
|||
}
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: CubeCountSettings) {
|
||||
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
|
||||
execute_dynamic::<R, K, f32, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
|
@ -108,7 +108,7 @@ where
|
|||
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: CubeCountSettings) {
|
||||
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
|
||||
execute_dynamic::<R, K, E, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
|
@ -144,7 +144,7 @@ where
|
|||
}
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn execute(self, launch: CubeCountSettings)
|
||||
pub fn execute(self, launch: CubeCountSettings<R::Server>)
|
||||
where
|
||||
K: Kernel + 'static,
|
||||
R: Runtime,
|
||||
|
@ -172,7 +172,7 @@ where
|
|||
{
|
||||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: CubeCountSettings) {
|
||||
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
|
||||
execute_dynamic::<R, K, E1, E2, E3>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
|
@ -194,7 +194,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
|
|||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
kernel: K,
|
||||
launch: CubeCountSettings,
|
||||
launch: CubeCountSettings<R::Server>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) where
|
||||
K: Kernel + 'static,
|
||||
|
@ -207,23 +207,21 @@ fn execute_dynamic<R, K, E1, E2, E3>(
|
|||
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
|
||||
);
|
||||
let mut handles = settings.handles_tensors;
|
||||
let workgroup = settings.cube_count;
|
||||
|
||||
handles.push(settings.handle_info.binding());
|
||||
for handle in settings.handles_scalars.into_iter() {
|
||||
handles.push(handle.binding());
|
||||
}
|
||||
|
||||
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, workgroup));
|
||||
|
||||
client.execute(kernel, handles);
|
||||
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
|
||||
client.execute(kernel, settings.cube_count, handles);
|
||||
}
|
||||
|
||||
struct ExecuteSettings<R: Runtime> {
|
||||
handles_tensors: Vec<Binding<R::Server>>,
|
||||
handle_info: Handle<R::Server>,
|
||||
handles_scalars: Vec<Handle<R::Server>>,
|
||||
cube_count: CubeCount,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
}
|
||||
|
||||
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
|
||||
|
@ -232,7 +230,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
|
|||
scalars_1: Option<&[E1]>,
|
||||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
launch: CubeCountSettings,
|
||||
launch: CubeCountSettings<R::Server>,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
) -> ExecuteSettings<R> {
|
||||
let mut info = Vec::new();
|
||||
|
@ -295,8 +293,8 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
|
|||
let handles_scalars =
|
||||
create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
|
||||
|
||||
let workgroup = match launch {
|
||||
CubeCountSettings::Custom(workgroup) => workgroup,
|
||||
let cube_count = match launch {
|
||||
CubeCountSettings::Custom(count) => count,
|
||||
_ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
|
||||
};
|
||||
|
||||
|
@ -304,7 +302,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeEl
|
|||
handles_tensors: handles,
|
||||
handle_info: info,
|
||||
handles_scalars,
|
||||
cube_count: workgroup,
|
||||
cube_count,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -74,9 +74,9 @@ impl core::fmt::Display for KernelSettings {
|
|||
// * Vectorization Global: vg{factor}
|
||||
// * Vectorization Partial Input: v{factor}i{pos}
|
||||
// * Vectorization Partial Output: vo
|
||||
// * Workgroup Size X: x
|
||||
// * Workgroup Size Y: y
|
||||
// * Workgroup Size Z: z
|
||||
// * Cube Dim X: x
|
||||
// * Cube Dim Y: y
|
||||
// * Cube Dim Z: z
|
||||
f.write_str("m")?;
|
||||
for mapping in self.mappings.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
|
|
|
@ -1,41 +1,32 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
|
||||
use alloc::sync::Arc;
|
||||
use std::marker::PhantomData;
|
||||
use burn_compute::server::{Binding, ComputeServer};
|
||||
|
||||
/// A kernel, compiled in the target language
|
||||
pub struct CompiledKernel {
|
||||
/// Source code of the kernel
|
||||
pub source: String,
|
||||
/// Size of a workgroup for the compiled kernel
|
||||
/// Size of a cube for the compiled kernel
|
||||
pub cube_dim: CubeDim,
|
||||
/// The number of bytes used by the share memory
|
||||
pub shared_mem_bytes: usize,
|
||||
}
|
||||
|
||||
/// Information needed to launch the kernel
|
||||
pub struct LaunchSettings {
|
||||
/// Layout of workgroups for the kernel
|
||||
pub cube_count: CubeCount,
|
||||
}
|
||||
|
||||
/// Kernel trait with the ComputeShader that will be compiled and cached based on the
|
||||
/// provided id.
|
||||
///
|
||||
/// The kernel will be launched with the given [launch settings](LaunchSettings).
|
||||
pub trait CubeTask: Send + Sync {
|
||||
/// Identifier for the kernel, used for caching kernel compilation.
|
||||
fn id(&self) -> String;
|
||||
/// Compile the kernel into source
|
||||
fn compile(&self) -> CompiledKernel;
|
||||
/// Launch settings.
|
||||
fn launch_settings(&self) -> LaunchSettings;
|
||||
}
|
||||
|
||||
/// Wraps a [kernel](Kernel) with its [cube count](CubeCount) to create a [cube task](CubeTask).
|
||||
/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
|
||||
#[derive(new)]
|
||||
pub struct KernelTask<C: Compiler, K: Kernel> {
|
||||
kernel_definition: K,
|
||||
cube_count: CubeCount,
|
||||
_compiler: PhantomData<C>,
|
||||
}
|
||||
|
||||
|
@ -57,12 +48,6 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
|
|||
fn id(&self) -> String {
|
||||
self.kernel_definition.id().clone()
|
||||
}
|
||||
|
||||
fn launch_settings(&self) -> LaunchSettings {
|
||||
LaunchSettings {
|
||||
cube_count: self.cube_count.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CubeTask for Arc<dyn CubeTask> {
|
||||
|
@ -73,10 +58,6 @@ impl CubeTask for Arc<dyn CubeTask> {
|
|||
fn id(&self) -> String {
|
||||
self.as_ref().id()
|
||||
}
|
||||
|
||||
fn launch_settings(&self) -> LaunchSettings {
|
||||
self.as_ref().launch_settings()
|
||||
}
|
||||
}
|
||||
|
||||
impl CubeTask for Box<dyn CubeTask> {
|
||||
|
@ -87,26 +68,21 @@ impl CubeTask for Box<dyn CubeTask> {
|
|||
fn id(&self) -> String {
|
||||
self.as_ref().id()
|
||||
}
|
||||
|
||||
fn launch_settings(&self) -> LaunchSettings {
|
||||
self.as_ref().launch_settings()
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides launch information specifying the number of work groups to be used by a compute shader.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct CubeCount {
|
||||
/// Work groups for the x axis.
|
||||
pub x: u32,
|
||||
/// Work groups for the y axis.
|
||||
pub y: u32,
|
||||
/// Work groups for the z axis.
|
||||
pub z: u32,
|
||||
pub enum CubeCount<S: ComputeServer> {
|
||||
/// Dispatch x,y,z work groups.
|
||||
Static(u32, u32, u32),
|
||||
/// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
|
||||
Dynamic(Binding<S>),
|
||||
}
|
||||
|
||||
impl CubeCount {
|
||||
/// Calculate the number of invocations of a compute shader.
|
||||
pub fn num_invocations(&self) -> usize {
|
||||
(self.x * self.y * self.z) as usize
|
||||
impl<S: ComputeServer> Clone for CubeCount<S> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Static(x, y, z) => Self::Static(*x, *y, *z),
|
||||
Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,15 +78,15 @@ impl<R: Runtime> KernelLauncher<R> {
|
|||
/// Launch the kernel.
|
||||
pub fn launch<K: Kernel>(
|
||||
self,
|
||||
cube_count: CubeCount,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
kernel: K,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) {
|
||||
let bindings = self.into_bindings(&client);
|
||||
|
||||
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, cube_count));
|
||||
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
|
||||
|
||||
client.execute(kernel, bindings);
|
||||
client.execute(kernel, cube_count, bindings);
|
||||
}
|
||||
|
||||
/// We need to create the bindings in the same order they are defined in the compilation step.
|
||||
|
|
|
@ -6,6 +6,7 @@ extern crate derive_new;
|
|||
/// Cube Frontend Types.
|
||||
pub mod frontend;
|
||||
|
||||
use burn_compute::server::ComputeServer;
|
||||
pub use frontend::cmma;
|
||||
|
||||
/// Cube Language Internal Representation.
|
||||
|
@ -45,13 +46,16 @@ pub trait Kernel: Send + Sync + 'static {
|
|||
|
||||
/// Calculate the number of cubes required to execute an operation where one cube unit is
|
||||
/// assigned to one element.
|
||||
pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: usize) -> CubeCount {
|
||||
pub fn calculate_cube_count_elemwise<S: ComputeServer>(
|
||||
num_elems: usize,
|
||||
cube_dim: usize,
|
||||
) -> CubeCount<S> {
|
||||
let num_elems_per_cube = cube_dim * cube_dim;
|
||||
let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32);
|
||||
let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
|
||||
let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));
|
||||
|
||||
CubeCount::new(cube_count_x as u32, cube_count_y as u32, 1)
|
||||
CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
|
||||
}
|
||||
|
||||
pub fn tensor_vectorization_factor(
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
|
||||
use crate::{
|
||||
codegen::Compiler,
|
||||
compute::{CubeCount, CubeTask},
|
||||
ir::Elem,
|
||||
};
|
||||
use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
|
||||
|
||||
/// Runtime for the CubeCL.
|
||||
|
@ -6,7 +10,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
|
|||
/// The compiler used to compile the inner representation into tokens.
|
||||
type Compiler: Compiler;
|
||||
/// The compute server used to run kernels and perform autotuning.
|
||||
type Server: ComputeServer<Kernel = Box<dyn CubeTask>, FeatureSet = FeatureSet>;
|
||||
type Server: ComputeServer<
|
||||
Kernel = Box<dyn CubeTask>,
|
||||
DispatchOptions = CubeCount<Self::Server>,
|
||||
FeatureSet = FeatureSet,
|
||||
>;
|
||||
/// The channel used to communicate with the compute server.
|
||||
type Channel: ComputeChannel<Self::Server>;
|
||||
/// The device used to retrieve the compute client.
|
||||
|
|
|
@ -61,7 +61,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
|||
|
||||
kernel_simple_1_launch::<R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(16, 16, 1),
|
||||
ArrayArg::new(&lhs, 256),
|
||||
ArrayArg::new(&rhs, 256),
|
||||
|
|
|
@ -20,7 +20,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
|
|||
|
||||
kernel_with_generics_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
ArrayArg::new(&handle, 2),
|
||||
);
|
||||
|
@ -36,7 +36,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
|
|||
|
||||
kernel_without_generics_launch::<R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
ArrayArg::new(&handle, 2),
|
||||
);
|
||||
|
|
|
@ -98,7 +98,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
|
|||
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
|
||||
launch: Launch,
|
||||
) where
|
||||
Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
|
||||
Launch: Fn(CubeCount<TestRuntime::Server>, CubeDim, TensorArg<'_, TestRuntime>),
|
||||
{
|
||||
if !client.features().enabled(Feature::Subcube) {
|
||||
// Can't execute the test.
|
||||
|
@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
|
|||
let (shape, strides) = ([input.len()], [1]);
|
||||
|
||||
launch(
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(input.len() as u32, 1, 1),
|
||||
TensorArg::new(&handle, &strides, &shape),
|
||||
);
|
||||
|
|
|
@ -57,14 +57,8 @@ struct CompiledKernel {
|
|||
|
||||
unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
|
||||
|
||||
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
||||
type Kernel = Box<dyn CubeTask>;
|
||||
type Storage = CudaStorage;
|
||||
type MemoryManagement = MM;
|
||||
type AutotuneKey = JitAutotuneKey;
|
||||
type FeatureSet = FeatureSet;
|
||||
|
||||
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
|
||||
impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
|
||||
fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
|
||||
let ctx = self.get_context();
|
||||
let resource = ctx.memory_management.get(binding.memory);
|
||||
|
||||
|
@ -74,7 +68,20 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
|||
cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
|
||||
};
|
||||
ctx.sync();
|
||||
reader_from_concrete(data)
|
||||
data
|
||||
}
|
||||
}
|
||||
|
||||
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
||||
type Kernel = Box<dyn CubeTask>;
|
||||
type DispatchOptions = CubeCount<Self>;
|
||||
type Storage = CudaStorage;
|
||||
type MemoryManagement = MM;
|
||||
type AutotuneKey = JitAutotuneKey;
|
||||
type FeatureSet = FeatureSet;
|
||||
|
||||
fn read(&mut self, binding: server::Binding<Self>) -> Reader {
|
||||
reader_from_concrete(self.read_sync(binding))
|
||||
}
|
||||
|
||||
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
|
||||
|
@ -101,12 +108,33 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
|||
server::Handle::new(handle)
|
||||
}
|
||||
|
||||
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
|
||||
fn execute(
|
||||
&mut self,
|
||||
kernel: Self::Kernel,
|
||||
count: Self::DispatchOptions,
|
||||
bindings: Vec<server::Binding<Self>>,
|
||||
) {
|
||||
let arch = self.minimum_arch_version;
|
||||
|
||||
let ctx = self.get_context();
|
||||
let kernel_id = kernel.id();
|
||||
let settings = kernel.launch_settings();
|
||||
|
||||
let count = match count {
|
||||
CubeCount::Static(x, y, z) => (x, y, z),
|
||||
// TODO: CUDA doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
|
||||
// One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
|
||||
// For now, just read the dispatch settings from the buffer.
|
||||
CubeCount::Dynamic(binding) => {
|
||||
let data = self.read_sync(binding);
|
||||
let data = bytemuck::cast_slice(&data);
|
||||
assert!(
|
||||
data.len() == 3,
|
||||
"Dynamic cube count should contain 3 values"
|
||||
);
|
||||
(data[0], data[1], data[2])
|
||||
}
|
||||
};
|
||||
|
||||
let ctx = self.get_context();
|
||||
|
||||
if !ctx.module_names.contains_key(&kernel_id) {
|
||||
ctx.compile_kernel(&kernel_id, kernel, arch);
|
||||
|
@ -117,7 +145,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
|
|||
.map(|binding| ctx.memory_management.get(binding.memory).as_binding())
|
||||
.collect();
|
||||
|
||||
ctx.execute_task(kernel_id, settings.cube_count, bindings);
|
||||
ctx.execute_task(kernel_id, count, bindings);
|
||||
// TODO: fix this
|
||||
// self.memory_management.storage().perform_deallocations();
|
||||
}
|
||||
|
@ -217,16 +245,15 @@ impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
|
|||
fn execute_task(
|
||||
&mut self,
|
||||
kernel_id: String,
|
||||
cube_count: CubeCount,
|
||||
dispatch_count: (u32, u32, u32),
|
||||
mut bindings: Vec<Binding>,
|
||||
) {
|
||||
let kernel = self.module_names.get(&kernel_id).unwrap();
|
||||
let cube_dim = kernel.cube_dim;
|
||||
|
||||
unsafe {
|
||||
cudarc::driver::result::launch_kernel(
|
||||
kernel.func,
|
||||
(cube_count.x, cube_count.y, cube_count.z),
|
||||
dispatch_count,
|
||||
(cube_dim.x, cube_dim.y, cube_dim.z),
|
||||
kernel.shared_mem_bytes as u32,
|
||||
self.stream,
|
||||
|
|
|
@ -29,9 +29,9 @@ pub struct CompilationPhase;
|
|||
/// Phase where the kernel should be executed.
|
||||
#[derive(new)]
|
||||
pub struct ExecutionPhase<R: JitRuntime> {
|
||||
/// Kernel set with default workgroup size.
|
||||
/// Kernel set with default cube size.
|
||||
pub(super) kernel_factory_1: ElementWiseKernelFactory<R>,
|
||||
/// Kernel set with custom workgroup size.
|
||||
/// Kernel set with custom cube size.
|
||||
pub(super) kernel_factory_2: ElementWiseKernelFactory<R>,
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ pub struct FusionKernel<R: JitRuntime> {
|
|||
info: Arc<KernelExpansion>,
|
||||
settings: KernelSettings,
|
||||
runtime_info: Vec<OutputRuntimeInfo>,
|
||||
cube_count: CubeCount,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,7 @@ pub trait FusionKernelFactory<R: JitRuntime> {
|
|||
#[derive(new)]
|
||||
pub struct ExecutableKernel<R: JitRuntime> {
|
||||
kernel: Box<dyn CubeTask>,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
bindings: Vec<Binding<R::Server>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
}
|
||||
|
@ -54,6 +55,7 @@ pub struct ExecutableKernel<R: JitRuntime> {
|
|||
#[derive(new)]
|
||||
pub struct AutotunableKernel<R: JitRuntime> {
|
||||
kernel: Arc<dyn CubeTask>,
|
||||
count: CubeCount<R::Server>,
|
||||
bindings: Vec<Binding<R::Server>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
}
|
||||
|
@ -68,18 +70,21 @@ pub enum OutputRuntimeInfo {
|
|||
impl<R: JitRuntime> ExecutableKernel<R> {
|
||||
/// Execute the kernel.
|
||||
pub fn execute(self) {
|
||||
self.client.execute(self.kernel, self.bindings)
|
||||
self.client
|
||||
.execute(self.kernel, self.cube_count, self.bindings)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
|
||||
fn execute(self: Box<Self>) {
|
||||
self.client.execute(Box::new(self.kernel), self.bindings)
|
||||
self.client
|
||||
.execute(Box::new(self.kernel), self.count, self.bindings)
|
||||
}
|
||||
|
||||
fn clone(&self) -> Box<dyn AutotuneOperation> {
|
||||
Box::new(Self {
|
||||
kernel: self.kernel.clone(),
|
||||
count: self.count.clone(),
|
||||
bindings: self.bindings.clone(),
|
||||
client: self.client.clone(),
|
||||
})
|
||||
|
@ -90,6 +95,7 @@ impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
|
|||
fn from(value: ExecutableKernel<R>) -> Self {
|
||||
Self {
|
||||
kernel: Arc::new(value.kernel),
|
||||
count: value.cube_count.clone(),
|
||||
bindings: value.bindings,
|
||||
client: value.client,
|
||||
}
|
||||
|
@ -233,12 +239,10 @@ impl<R: JitRuntime> FusionKernel<R> {
|
|||
context.handles.register_handle(id, handle);
|
||||
}
|
||||
|
||||
let workgroup = fusion_kernel.cube_count.clone();
|
||||
let cube_count = fusion_kernel.cube_count.clone();
|
||||
ExecutableKernel::new(
|
||||
Box::new(KernelTask::<R::Compiler, FusionKernel<R>>::new(
|
||||
fusion_kernel,
|
||||
workgroup,
|
||||
)),
|
||||
Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
|
||||
cube_count,
|
||||
bindings,
|
||||
client,
|
||||
)
|
||||
|
|
|
@ -196,7 +196,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
|||
let kernel = ScatterEagerKernel::<R, E>::new(dim);
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
let mut num_elems = 1;
|
||||
|
||||
tensor
|
||||
.shape
|
||||
|
@ -208,13 +208,13 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
|||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
num_elems *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
// Fake strides of the virtual output where the strides of dim is hardcoded to one.
|
||||
indices.strides = strides;
|
||||
|
||||
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
|
||||
|
||||
Execution::start(kernel, indices.client)
|
||||
.inputs(&[
|
||||
|
@ -222,7 +222,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
|||
TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
TensorHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
])
|
||||
.execute(CubeCountSettings::Custom(workgroup));
|
||||
.execute(CubeCountSettings::Custom(cube_count));
|
||||
|
||||
tensor
|
||||
}
|
||||
|
|
|
@ -189,7 +189,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
|||
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
let mut num_elems = 1;
|
||||
|
||||
tensor
|
||||
.shape
|
||||
|
@ -201,11 +201,11 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
|||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
num_elems *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
|
||||
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
|
||||
|
||||
Execution::start(kernel, indices.client)
|
||||
.inputs(&[
|
||||
|
@ -215,7 +215,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
|||
// kernel, but we need to put the right number of dimensions (rank).
|
||||
TensorHandle::new(&indices.handle, &strides, &strides),
|
||||
])
|
||||
.execute(CubeCountSettings::Custom(workgroup));
|
||||
.execute(CubeCountSettings::Custom(cube_count));
|
||||
|
||||
tensor
|
||||
}
|
||||
|
|
|
@ -140,41 +140,39 @@ pub fn matmul<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn simple_launch_options<const D: usize>(
|
||||
pub(crate) fn simple_cube_count<R: JitRuntime, const D: usize>(
|
||||
lhs_shape: &Shape<D>,
|
||||
rhs_shape: &Shape<D>,
|
||||
output_shape: &Shape<D>,
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
) -> CubeCount {
|
||||
cube_dim_x: usize,
|
||||
cube_dim_y: usize,
|
||||
) -> CubeCount<R::Server> {
|
||||
let num_rows = lhs_shape.dims[D - 2];
|
||||
let num_cols = rhs_shape.dims[D - 1];
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
|
||||
let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
|
||||
let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
|
||||
CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
|
||||
}
|
||||
|
||||
pub(crate) fn tiling2d_launch_options<const D: usize>(
|
||||
pub(crate) fn tiling2d_launch_options<R: JitRuntime, const D: usize>(
|
||||
output_shape: &Shape<D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> CubeCount {
|
||||
) -> CubeCount<R::Server> {
|
||||
let num_rows = output_shape.dims[D - 2];
|
||||
let num_cols = output_shape.dims[D - 1];
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
|
||||
let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
|
||||
let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
|
||||
CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
use burn_cube::ir::KernelDefinition;
|
||||
use burn_cube::{frontend::TensorArg, KernelSettings};
|
||||
|
||||
use super::simple_launch_options;
|
||||
use super::simple_cube_count;
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube(launch)]
|
||||
|
@ -80,7 +80,7 @@ fn matmul_kernel<F: Float>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using memory coalescing algorithm with workgroups of size 16
|
||||
/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16
|
||||
pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
|
@ -89,7 +89,7 @@ pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: us
|
|||
matmul_simple::<R, E, D>(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX)
|
||||
}
|
||||
|
||||
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
|
||||
/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
|
||||
pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
|
@ -103,7 +103,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
let rhs_original_shape = rhs.shape.clone();
|
||||
let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2));
|
||||
|
||||
let cube_count = simple_launch_options(
|
||||
let cube_count = simple_cube_count::<R, D>(
|
||||
&lhs.shape,
|
||||
&rhs_original_shape,
|
||||
&out.shape,
|
||||
|
|
|
@ -118,7 +118,7 @@ pub fn matmul_tiling_2d<R: JitRuntime, E: JitElement + Element, const D: usize>(
|
|||
&out.strides,
|
||||
&out.shape.dims,
|
||||
)])
|
||||
.execute(CubeCountSettings::Custom(tiling2d_launch_options(
|
||||
.execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
|
||||
&out.shape, config,
|
||||
)));
|
||||
|
||||
|
@ -175,7 +175,7 @@ pub fn matmul_tiling_2d_padded<R: JitRuntime, E: JitElement + Element, const D:
|
|||
&rounded_output.strides,
|
||||
&rounded_output.shape.dims,
|
||||
)])
|
||||
.execute(CubeCountSettings::Custom(tiling2d_launch_options(
|
||||
.execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
|
||||
&rounded_output.shape,
|
||||
config,
|
||||
)));
|
||||
|
|
|
@ -55,14 +55,14 @@ pub(crate) fn gather_shader_information(
|
|||
cpa!(scope, out_stride_row = stride(out, second_to_last_dim));
|
||||
cpa!(scope, out_stride_col = stride(out, last_dim));
|
||||
|
||||
// Workgroup offset
|
||||
// Cube offset
|
||||
let skip_row = scope.create_local(Elem::UInt);
|
||||
let skip_col = scope.create_local(Elem::UInt);
|
||||
let workgroup_id_x = Variable::CubePosX;
|
||||
let workgroup_id_y = Variable::CubePosY;
|
||||
cpa!(scope, skip_row = workgroup_id_x);
|
||||
let cube_pos_x = Variable::CubePosX;
|
||||
let cube_pos_y = Variable::CubePosY;
|
||||
cpa!(scope, skip_row = cube_pos_x);
|
||||
cpa!(scope, skip_row *= block_size_m);
|
||||
cpa!(scope, skip_col = workgroup_id_y);
|
||||
cpa!(scope, skip_col = cube_pos_y);
|
||||
cpa!(scope, skip_col *= block_size_n);
|
||||
|
||||
// Position of the first element of the thread, relative to the block
|
||||
|
|
|
@ -38,7 +38,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
|
|||
)])
|
||||
.with_scalars(&seeds)
|
||||
.with_scalars(&prng.args())
|
||||
.execute(CubeCountSettings::Custom(prng_cube_count(
|
||||
.execute(CubeCountSettings::Custom(prng_cube_count::<R>(
|
||||
num_elems,
|
||||
SUBCUBE_DIM_APPROX,
|
||||
N_VALUES_PER_THREAD,
|
||||
|
@ -47,14 +47,18 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
|
|||
output
|
||||
}
|
||||
|
||||
fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount {
|
||||
fn prng_cube_count<R: JitRuntime>(
|
||||
num_elems: usize,
|
||||
cube_dim: usize,
|
||||
n_values_per_thread: usize,
|
||||
) -> CubeCount<R::Server> {
|
||||
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
|
||||
let num_elems_per_cube = cube_dim * cube_dim;
|
||||
let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
|
||||
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
|
||||
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
|
||||
let cubes_x = f32::ceil(f32::sqrt(num_invocations));
|
||||
let cubes_y = f32::ceil(num_invocations / cubes_x);
|
||||
|
||||
CubeCount::new(workgroup_x as u32, workgroup_y as u32, 1)
|
||||
CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
|
||||
}
|
||||
|
||||
impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> {
|
||||
|
@ -163,24 +167,24 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
|
|||
let n_values_per_thread: Variable = self.n_values_per_thread.into();
|
||||
let args = self.args;
|
||||
|
||||
let workgroup_size_x = Variable::CubeDimX;
|
||||
let workgroup_size_y = Variable::CubeDimY;
|
||||
let workgroup_id_x = Variable::CubePosX;
|
||||
let workgroup_id_y = Variable::CubePosY;
|
||||
let num_workgroups_y = Variable::CubeCountY;
|
||||
let cube_dim_x = Variable::CubeDimX;
|
||||
let cube_dim_y = Variable::CubeDimY;
|
||||
let cube_pos_x = Variable::CubePosX;
|
||||
let cube_pos_y = Variable::CubePosY;
|
||||
let cube_count_y = Variable::CubeCountY;
|
||||
let local_index = Variable::UnitPos;
|
||||
|
||||
let n_invocations = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, n_invocations = workgroup_size_x);
|
||||
cpa!(scope, n_invocations *= workgroup_size_y);
|
||||
cpa!(scope, n_invocations = cube_dim_x);
|
||||
cpa!(scope, n_invocations *= cube_dim_y);
|
||||
|
||||
let workgroup_offset = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
|
||||
cpa!(scope, workgroup_offset += workgroup_id_y);
|
||||
cpa!(scope, workgroup_offset *= n_invocations);
|
||||
let cube_offset = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
|
||||
cpa!(scope, cube_offset += cube_pos_y);
|
||||
cpa!(scope, cube_offset *= n_invocations);
|
||||
|
||||
let write_index_base = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, write_index_base = workgroup_offset);
|
||||
cpa!(scope, write_index_base = cube_offset);
|
||||
cpa!(scope, write_index_base *= n_values_per_thread);
|
||||
cpa!(scope, write_index_base += local_index);
|
||||
|
||||
|
@ -188,7 +192,7 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
|
|||
let thread_seed = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, thread_seed = cast(1000000007));
|
||||
let thread_seed_index = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, thread_seed_index = workgroup_offset + local_index);
|
||||
cpa!(scope, thread_seed_index = cube_offset + local_index);
|
||||
cpa!(scope, thread_seed *= thread_seed_index);
|
||||
|
||||
let state_0 = scope.create_local(Elem::UInt);
|
||||
|
|
|
@ -33,8 +33,8 @@ pub(crate) struct SharedReduceDimEagerKernel<
|
|||
EO: JitElement,
|
||||
> {
|
||||
dim: usize,
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
cube_dim_x: usize,
|
||||
cube_dim_y: usize,
|
||||
n_input_values_per_thread: u32,
|
||||
divisible_shape: bool,
|
||||
_reduce_dim: PhantomData<RD>,
|
||||
|
@ -58,7 +58,7 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
|
|||
SharedReduceDimComputeShader {
|
||||
tensor,
|
||||
dim: self.dim,
|
||||
shared_memory_size: self.workgroup_size_x * self.workgroup_size_y,
|
||||
shared_memory_size: self.cube_dim_x * self.cube_dim_y,
|
||||
n_input_values_per_thread: self.n_input_values_per_thread,
|
||||
output,
|
||||
divisible_shape: self.divisible_shape,
|
||||
|
@ -83,8 +83,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
|
|||
};
|
||||
|
||||
let settings = KernelSettings::default().cube_dim(CubeDim::new(
|
||||
self.workgroup_size_x as u32,
|
||||
self.workgroup_size_y as u32,
|
||||
self.cube_dim_x as u32,
|
||||
self.cube_dim_y as u32,
|
||||
1,
|
||||
));
|
||||
KernelIntegrator::new(info).integrate(settings)
|
||||
|
@ -95,8 +95,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
|
|||
"{:?}dim={}x={}y={}n={}divshape={}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.dim,
|
||||
self.workgroup_size_x,
|
||||
self.workgroup_size_y,
|
||||
self.cube_dim_x,
|
||||
self.cube_dim_y,
|
||||
self.n_input_values_per_thread,
|
||||
self.divisible_shape
|
||||
)
|
||||
|
@ -111,13 +111,13 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
let rank = Variable::Rank;
|
||||
let dim: Variable = self.dim.into();
|
||||
|
||||
let workgroup_id_x = Variable::CubePosX;
|
||||
let workgroup_id_y = Variable::CubePosY;
|
||||
let num_workgroups_x = Variable::CubeCountX;
|
||||
let cube_pos_x = Variable::CubePosX;
|
||||
let cube_pos_y = Variable::CubePosY;
|
||||
let cube_count_x = Variable::CubeCountX;
|
||||
let local_invocation_id_x = Variable::UnitPosX;
|
||||
let local_invocation_id_y = Variable::UnitPosY;
|
||||
let workgroup_size_x = Variable::CubeDimX;
|
||||
let workgroup_size_y = Variable::CubeDimY;
|
||||
let cube_dim_x = Variable::CubeDimX;
|
||||
let cube_dim_y = Variable::CubeDimY;
|
||||
|
||||
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
|
||||
|
@ -126,16 +126,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
|||
|
||||
// To determine which reduce_group (not position, but absolute id)
|
||||
let reduce_group_id = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, reduce_group_id = workgroup_id_y * num_workgroups_x);
|
||||
cpa!(scope, reduce_group_id += workgroup_id_x);
|
||||
cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
|
||||
cpa!(scope, reduce_group_id += cube_pos_x);
|
||||
|
||||
// nth thread in the workgroup
|
||||
// nth thread in the cube
|
||||
let local_id = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, local_id = local_invocation_id_y * workgroup_size_x);
|
||||
cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
|
||||
cpa!(scope, local_id += local_invocation_id_x);
|
||||
|
||||
let n_threads = scope.create_local(Elem::UInt);
|
||||
cpa!(scope, n_threads = workgroup_size_x * workgroup_size_y);
|
||||
cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
|
||||
|
||||
let index_offset = scope.zero(Elem::UInt);
|
||||
|
||||
|
@ -242,17 +242,17 @@ pub fn reduce_dim_shared<
|
|||
dim: usize,
|
||||
) -> JitTensor<R, EO, D> {
|
||||
let num_elems_output = output.shape.num_elements();
|
||||
let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32));
|
||||
let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x);
|
||||
let grid = CubeCount::new(n_workgroups_x as u32, n_workgroups_y as u32, 1);
|
||||
let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32));
|
||||
let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x);
|
||||
let grid = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1);
|
||||
|
||||
let reduce_group_size = input.shape.dims[dim];
|
||||
let n_invocation_per_workgroup = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
|
||||
let n_invocation_per_cube = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
|
||||
let n_input_values_per_thread =
|
||||
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
|
||||
f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
|
||||
|
||||
let divisible_shape =
|
||||
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
|
||||
n_invocation_per_cube as u32 * n_input_values_per_thread == reduce_group_size as u32;
|
||||
|
||||
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
|
||||
dim,
|
||||
|
|
|
@ -18,7 +18,10 @@ pub(crate) mod tune;
|
|||
/// Elements for JIT backend
|
||||
pub mod element;
|
||||
|
||||
use burn_cube::{compute::CubeTask, Runtime};
|
||||
use burn_cube::{
|
||||
compute::{CubeCount, CubeTask},
|
||||
Runtime,
|
||||
};
|
||||
pub use element::{FloatElement, IntElement, JitElement};
|
||||
|
||||
mod backend;
|
||||
|
@ -48,5 +51,6 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
|
|||
type JitServer: burn_compute::server::ComputeServer<
|
||||
AutotuneKey = JitAutotuneKey,
|
||||
Kernel = Box<dyn CubeTask>,
|
||||
DispatchOptions = CubeCount<Self::JitServer>,
|
||||
>;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
|
||||
use burn_cube::compute::LaunchSettings;
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
use super::SourceTemplate;
|
||||
|
@ -11,18 +10,13 @@ pub trait KernelSource: Send + 'static + Sync {
|
|||
}
|
||||
|
||||
#[derive(new)]
|
||||
/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask) with launch
|
||||
/// information.
|
||||
/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).
|
||||
pub struct SourceKernel<K> {
|
||||
kernel_source: K,
|
||||
cube_count: CubeCount,
|
||||
cube_dim: CubeDim,
|
||||
}
|
||||
|
||||
impl<K> CubeTask for SourceKernel<K>
|
||||
where
|
||||
K: KernelSource + 'static,
|
||||
{
|
||||
impl<K: KernelSource> CubeTask for SourceKernel<K> {
|
||||
fn compile(&self) -> CompiledKernel {
|
||||
let source_template = self.kernel_source.source();
|
||||
let source = source_template.complete();
|
||||
|
@ -37,12 +31,6 @@ where
|
|||
fn id(&self) -> String {
|
||||
format!("{:?}", core::any::TypeId::of::<K>())
|
||||
}
|
||||
|
||||
fn launch_settings(&self) -> LaunchSettings {
|
||||
LaunchSettings {
|
||||
cube_count: self.cube_count.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates kernel source code by replacing some information using templating.
|
||||
|
|
|
@ -64,8 +64,15 @@ where
|
|||
&mut self,
|
||||
pipeline: Arc<ComputePipeline>,
|
||||
bind_group: BindGroup,
|
||||
work_group: CubeCount,
|
||||
count: CubeCount<Self>,
|
||||
) {
|
||||
// First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
|
||||
// needs to be longer than the compute pass, so we can't do this just before dispatching.
|
||||
let dispatch_resource = match count.clone() {
|
||||
CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let mut compute = self
|
||||
.encoder
|
||||
.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
|
@ -75,12 +82,21 @@ where
|
|||
|
||||
compute.set_pipeline(&pipeline);
|
||||
compute.set_bind_group(0, &bind_group, &[]);
|
||||
compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
|
||||
|
||||
match count {
|
||||
CubeCount::Static(x, y, z) => {
|
||||
compute.dispatch_workgroups(x, y, z);
|
||||
}
|
||||
CubeCount::Dynamic(_) => {
|
||||
let resource = dispatch_resource.as_ref().unwrap();
|
||||
compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset());
|
||||
}
|
||||
}
|
||||
|
||||
self.tasks_count += 1;
|
||||
}
|
||||
|
||||
fn pipeline(&mut self, kernel: Box<dyn CubeTask>) -> Arc<ComputePipeline> {
|
||||
fn pipeline(&mut self, kernel: <Self as ComputeServer>::Kernel) -> Arc<ComputePipeline> {
|
||||
let kernel_id = kernel.id();
|
||||
|
||||
if let Some(pipeline) = self.pipelines.get(&kernel_id) {
|
||||
|
@ -143,6 +159,7 @@ where
|
|||
MM: MemoryManagement<WgpuStorage>,
|
||||
{
|
||||
type Kernel = Box<dyn CubeTask>;
|
||||
type DispatchOptions = CubeCount<Self>;
|
||||
type Storage = WgpuStorage;
|
||||
type MemoryManagement = MM;
|
||||
type AutotuneKey = JitAutotuneKey;
|
||||
|
@ -259,8 +276,12 @@ where
|
|||
}))
|
||||
}
|
||||
|
||||
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
|
||||
let work_group = kernel.launch_settings().cube_count;
|
||||
fn execute(
|
||||
&mut self,
|
||||
kernel: Self::Kernel,
|
||||
count: Self::DispatchOptions,
|
||||
bindings: Vec<server::Binding<Self>>,
|
||||
) {
|
||||
let pipeline = self.pipeline(kernel);
|
||||
let group_layout = pipeline.get_bind_group_layout(0);
|
||||
|
||||
|
@ -284,7 +305,7 @@ where
|
|||
entries: &entries,
|
||||
});
|
||||
|
||||
self.register_compute(pipeline, bind_group, work_group);
|
||||
self.register_compute(pipeline, bind_group, count);
|
||||
|
||||
if self.tasks_count >= self.tasks_max {
|
||||
self.sync(SyncType::Flush);
|
||||
|
|
|
@ -111,7 +111,8 @@ impl ComputeStorage for WgpuStorage {
|
|||
size: size as u64,
|
||||
usage: wgpu::BufferUsages::COPY_DST
|
||||
| wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_SRC,
|
||||
| wgpu::BufferUsages::COPY_SRC
|
||||
| wgpu::BufferUsages::INDIRECT,
|
||||
mapped_at_creation: false,
|
||||
}));
|
||||
|
||||
|
|
|
@ -87,11 +87,13 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
|
|||
// Declare the wgsl workgroup with the number of cubes in x, y and z.
|
||||
let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
|
||||
let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
|
||||
let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
|
||||
let cube_count =
|
||||
CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
|
||||
|
||||
// Execute lazily the kernel with the launch information and the given buffers.
|
||||
lhs.client.execute(
|
||||
Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
|
||||
Box::new(SourceKernel::new(kernel, cube_dim)),
|
||||
cube_count,
|
||||
vec![
|
||||
lhs.handle.binding(),
|
||||
rhs.handle.binding(),
|
||||
|
|
|
@ -22,7 +22,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
|
|||
|
||||
gelu_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
ArrayArg::new(&input_handle, input.len()),
|
||||
ArrayArg::new(&output_handle, input.len()),
|
||||
|
|
Loading…
Reference in New Issue