mirror of https://github.com/tracel-ai/burn.git
Feat: Dynamic cube count dispatch (#1975)
This commit is contained in:
Parent: b331290f8a
Commit: 3f9e97946f
@@ -248,11 +248,12 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
     // Declare the wgsl workgroup with the number of blocks in x, y and z.
     let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
     let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
-    let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
+    let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

     // Execute lazily the kernel with the launch information and the given buffers.
     lhs.client.execute(
-        Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
+        Box::new(SourceKernel::new(kernel, cube_dim)),
+        cube_count,
         vec![
             lhs.handle.binding(),
             rhs.handle.binding(),
@@ -24,7 +24,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
    fn empty(&self, size: usize) -> Handle<Server>;

    /// Executes the `kernel` over the given `bindings`.
-   fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);
+   fn execute(
+       &self,
+       kernel: Server::Kernel,
+       count: Server::DispatchOptions,
+       bindings: Vec<Binding<Server>>,
+   );

    /// Perform some synchronization of commands on the server.
    fn sync(&self, sync_type: SyncType);

@@ -63,10 +63,15 @@ where
        self.server.borrow_mut().empty(size)
    }

-   fn execute(&self, kernel_description: Server::Kernel, bindings: Vec<Binding<Server>>) {
+   fn execute(
+       &self,
+       kernel_description: Server::Kernel,
+       count: Server::DispatchOptions,
+       bindings: Vec<Binding<Server>>,
+   ) {
        self.server
            .borrow_mut()
-           .execute(kernel_description, bindings)
+           .execute(kernel_description, count, bindings)
    }

    fn sync(&self, sync_type: SyncType) {

@@ -39,7 +39,10 @@ where
    ),
    Create(Vec<u8>, Callback<Handle<Server>>),
    Empty(usize, Callback<Handle<Server>>),
-   ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
+   ExecuteKernel(
+       (Server::Kernel, Server::DispatchOptions),
+       Vec<Binding<Server>>,
+   ),
    Sync(SyncType, Callback<()>),
 }
@@ -74,7 +77,7 @@ where
                callback.send(handle).await.unwrap();
            }
            Message::ExecuteKernel(kernel, bindings) => {
-               server.execute(kernel, bindings);
+               server.execute(kernel.0, kernel.1, bindings);
            }
            Message::Sync(sync_type, callback) => {
                server.sync(sync_type);

@@ -148,10 +151,15 @@ where
        handle_response(response.recv_blocking())
    }

-   fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
+   fn execute(
+       &self,
+       kernel: Server::Kernel,
+       count: Server::DispatchOptions,
+       bindings: Vec<Binding<Server>>,
+   ) {
        self.state
            .sender
-           .send_blocking(Message::ExecuteKernel(kernel, bindings))
+           .send_blocking(Message::ExecuteKernel((kernel, count), bindings))
            .unwrap()
    }

@@ -56,8 +56,13 @@ where
        self.server.lock().empty(size)
    }

-   fn execute(&self, kernel: Server::Kernel, handles: Vec<Binding<Server>>) {
-       self.server.lock().execute(kernel, handles)
+   fn execute(
+       &self,
+       kernel: Server::Kernel,
+       count: Server::DispatchOptions,
+       handles: Vec<Binding<Server>>,
+   ) {
+       self.server.lock().execute(kernel, count, handles)
    }

    fn sync(&self, sync_type: SyncType) {

@@ -82,8 +82,13 @@ where
    }

    /// Executes the `kernel` over the given `bindings`.
-   pub fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
-       self.channel.execute(kernel, bindings)
+   pub fn execute(
+       &self,
+       kernel: Server::Kernel,
+       count: Server::DispatchOptions,
+       bindings: Vec<Binding<Server>>,
+   ) {
+       self.channel.execute(kernel, count, bindings)
    }

    /// Wait for the completion of every task in the server.
@@ -17,6 +17,8 @@ where
 {
    /// The kernel type defines the computation algorithms.
    type Kernel: Send;
+   /// Options when dispatching the kernel, eg. the number of executions.
+   type DispatchOptions: Send;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;
    /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.

@@ -45,7 +47,12 @@ where
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible of determining which should be read or written.
-   fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>);
+   fn execute(
+       &mut self,
+       kernel: Self::Kernel,
+       count: Self::DispatchOptions,
+       bindings: Vec<Binding<Self>>,
+   );

    /// Wait for the completion of every task in the server.
    fn sync(&mut self, command: SyncType);
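Note (not part of the patch): the dispatch options now travel explicitly from the client through the channel down to the server. A toy mirror of that flow, using made-up stand-in types (`Printer`, `Client`) rather than the burn-compute ones, and assuming a server whose options are a plain (x, y, z) grid:

// Sketch only: the real traits carry more associated types and a channel layer.
trait Server {
    type DispatchOptions;
    fn execute(&mut self, kernel: &str, count: Self::DispatchOptions, bindings: Vec<u32>);
}

struct Printer;

impl Server for Printer {
    type DispatchOptions = (u32, u32, u32);
    fn execute(&mut self, kernel: &str, count: Self::DispatchOptions, bindings: Vec<u32>) {
        // A real server would enqueue the kernel; this just records the call.
        println!("launch {kernel} with grid {count:?} over {} bindings", bindings.len());
    }
}

struct Client<S: Server>(S);

impl<S: Server> Client<S> {
    // The channel layer is elided here; it would forward the same three arguments.
    fn execute(&mut self, kernel: &str, count: S::DispatchOptions, bindings: Vec<u32>) {
        self.0.execute(kernel, count, bindings);
    }
}

fn main() {
    let mut client = Client(Printer);
    client.execute("matmul", (63, 63, 1), vec![0, 1, 2]);
}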
@@ -21,6 +21,7 @@ impl<MM> ComputeServer for DummyServer<MM>
 where
    MM: MemoryManagement<BytesStorage>,
 {
+   type DispatchOptions = ();
    type Kernel = Arc<dyn DummyKernel>;
    type Storage = BytesStorage;
    type MemoryManagement = MM;

@@ -53,7 +54,12 @@ where
        Handle::new(self.memory_management.reserve(size, || {}))
    }

-   fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) {
+   fn execute(
+       &mut self,
+       kernel: Self::Kernel,
+       _count: Self::DispatchOptions,
+       bindings: Vec<Binding<Self>>,
+   ) {
        let mut resources = bindings
            .into_iter()
            .map(|binding| self.memory_management.get(binding.memory))

@@ -18,7 +18,7 @@ pub struct OneKernelAutotuneOperation {
 impl AutotuneOperation for OneKernelAutotuneOperation {
    /// Executes the operation on given bindings and server, with the additional parameters
    fn execute(self: Box<Self>) {
-       self.client.execute(self.kernel.clone(), self.bindings);
+       self.client.execute(self.kernel.clone(), (), self.bindings);
    }

    fn clone(&self) -> Box<dyn AutotuneOperation> {

@@ -38,6 +38,7 @@ fn execute_elementwise_addition() {

    client.execute(
        Arc::new(DummyElementwiseAddition),
+       (),
        vec![lhs.binding(), rhs.binding(), out.clone().binding()],
    );
@@ -470,7 +470,7 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
        /// Launch
        pub fn #ident #generics (
            client: ComputeClient<R::Server, R::Channel>,
-           cube_count: CubeCount,
+           cube_count: CubeCount<R::Server>,
            cube_dim: CubeDim,
            #inputs
        ) -> #output {

@@ -4,13 +4,13 @@ use crate::ir::Elem;
 use crate::pod::CubeElement;
 use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
 use burn_compute::client::ComputeClient;
-use burn_compute::server::{Binding, Handle};
+use burn_compute::server::{Binding, ComputeServer, Handle};

-/// The position of the input or output to calculate the number of workgroups to launch.
-pub enum CubeCountSettings {
+/// The position of the input or output to calculate the number of cubes to launch.
+pub enum CubeCountSettings<S: ComputeServer> {
    Input { pos: usize },
    Output { pos: usize },
-   Custom(CubeCount),
+   Custom(CubeCount<S>),
 }

 pub struct Execution<'h, K, R: Runtime, Scalars> {
@@ -73,7 +73,7 @@ where
    }
    /// Execute a dynamic kernel.
    #[allow(unused)]
-   pub fn execute(self, launch: CubeCountSettings) {
+   pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        execute_dynamic::<R, K, f32, f32, f32>(
            self.inputs,
            self.outputs,

@@ -108,7 +108,7 @@ where

    /// Execute a dynamic kernel.
    #[allow(unused)]
-   pub fn execute(self, launch: CubeCountSettings) {
+   pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        execute_dynamic::<R, K, E, f32, f32>(
            self.inputs,
            self.outputs,

@@ -144,7 +144,7 @@ where
    }
    /// Execute a dynamic kernel.
    #[allow(clippy::too_many_arguments)]
-   pub fn execute(self, launch: CubeCountSettings)
+   pub fn execute(self, launch: CubeCountSettings<R::Server>)
    where
        K: Kernel + 'static,
        R: Runtime,

@@ -172,7 +172,7 @@ where
 {
    /// Execute a dynamic kernel.
    #[allow(unused)]
-   pub fn execute(self, launch: CubeCountSettings) {
+   pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        execute_dynamic::<R, K, E1, E2, E3>(
            self.inputs,
            self.outputs,

@@ -194,7 +194,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    kernel: K,
-   launch: CubeCountSettings,
+   launch: CubeCountSettings<R::Server>,
    client: ComputeClient<R::Server, R::Channel>,
 ) where
    K: Kernel + 'static,
@@ -207,23 +207,21 @@ fn execute_dynamic<R, K, E1, E2, E3>(
        inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
    );
    let mut handles = settings.handles_tensors;
-   let workgroup = settings.cube_count;

    handles.push(settings.handle_info.binding());
    for handle in settings.handles_scalars.into_iter() {
        handles.push(handle.binding());
    }

-   let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, workgroup));
-
-   client.execute(kernel, handles);
+   let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
+
+   client.execute(kernel, settings.cube_count, handles);
 }

 struct ExecuteSettings<R: Runtime> {
    handles_tensors: Vec<Binding<R::Server>>,
    handle_info: Handle<R::Server>,
    handles_scalars: Vec<Handle<R::Server>>,
-   cube_count: CubeCount,
+   cube_count: CubeCount<R::Server>,
 }

 fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(

@@ -232,7 +230,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
-   launch: CubeCountSettings,
+   launch: CubeCountSettings<R::Server>,
    client: &ComputeClient<R::Server, R::Channel>,
 ) -> ExecuteSettings<R> {
    let mut info = Vec::new();

@@ -295,8 +293,8 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    let handles_scalars =
        create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);

-   let workgroup = match launch {
-       CubeCountSettings::Custom(workgroup) => workgroup,
+   let cube_count = match launch {
+       CubeCountSettings::Custom(count) => count,
        _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
    };

@@ -304,7 +302,7 @@ fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
        handles_tensors: handles,
        handle_info: info,
        handles_scalars,
-       cube_count: workgroup,
+       cube_count,
    }
 }

@@ -74,9 +74,9 @@ impl core::fmt::Display for KernelSettings {
        // * Vectorization Global: vg{factor}
        // * Vectorization Partial Input: v{factor}i{pos}
        // * Vectorization Partial Output: vo
-       // * Workgroup Size X: x
-       // * Workgroup Size Y: y
-       // * Workgroup Size Z: z
+       // * Cube Dim X: x
+       // * Cube Dim Y: y
+       // * Cube Dim Z: z
        f.write_str("m")?;
        for mapping in self.mappings.iter() {
            f.write_fmt(format_args!(
@@ -1,41 +1,32 @@
+use std::marker::PhantomData;
+
 use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
 use alloc::sync::Arc;
-use std::marker::PhantomData;
+use burn_compute::server::{Binding, ComputeServer};

 /// A kernel, compiled in the target language
 pub struct CompiledKernel {
    /// Source code of the kernel
    pub source: String,
-   /// Size of a workgroup for the compiled kernel
+   /// Size of a cube for the compiled kernel
    pub cube_dim: CubeDim,
    /// The number of bytes used by the share memory
    pub shared_mem_bytes: usize,
 }

-/// Information needed to launch the kernel
-pub struct LaunchSettings {
-   /// Layout of workgroups for the kernel
-   pub cube_count: CubeCount,
-}
-
 /// Kernel trait with the ComputeShader that will be compiled and cached based on the
 /// provided id.
-///
-/// The kernel will be launched with the given [launch settings](LaunchSettings).
 pub trait CubeTask: Send + Sync {
    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> String;
    /// Compile the kernel into source
    fn compile(&self) -> CompiledKernel;
-   /// Launch settings.
-   fn launch_settings(&self) -> LaunchSettings;
 }

-/// Wraps a [kernel](Kernel) with its [cube count](CubeCount) to create a [cube task](CubeTask).
+/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
 #[derive(new)]
 pub struct KernelTask<C: Compiler, K: Kernel> {
    kernel_definition: K,
-   cube_count: CubeCount,
    _compiler: PhantomData<C>,
 }

@@ -57,12 +48,6 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
    fn id(&self) -> String {
        self.kernel_definition.id().clone()
    }
-
-   fn launch_settings(&self) -> LaunchSettings {
-       LaunchSettings {
-           cube_count: self.cube_count.clone(),
-       }
-   }
 }

 impl CubeTask for Arc<dyn CubeTask> {

@@ -73,10 +58,6 @@ impl CubeTask for Arc<dyn CubeTask> {
    fn id(&self) -> String {
        self.as_ref().id()
    }
-
-   fn launch_settings(&self) -> LaunchSettings {
-       self.as_ref().launch_settings()
-   }
 }

 impl CubeTask for Box<dyn CubeTask> {

@@ -87,26 +68,21 @@ impl CubeTask for Box<dyn CubeTask> {
    fn id(&self) -> String {
        self.as_ref().id()
    }
-
-   fn launch_settings(&self) -> LaunchSettings {
-       self.as_ref().launch_settings()
-   }
 }

 /// Provides launch information specifying the number of work groups to be used by a compute shader.
-#[derive(new, Clone, Debug)]
-pub struct CubeCount {
-   /// Work groups for the x axis.
-   pub x: u32,
-   /// Work groups for the y axis.
-   pub y: u32,
-   /// Work groups for the z axis.
-   pub z: u32,
+pub enum CubeCount<S: ComputeServer> {
+   /// Dispatch x,y,z work groups.
+   Static(u32, u32, u32),
+   /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
+   Dynamic(Binding<S>),
 }

-impl CubeCount {
-   /// Calculate the number of invocations of a compute shader.
-   pub fn num_invocations(&self) -> usize {
-       (self.x * self.y * self.z) as usize
+impl<S: ComputeServer> Clone for CubeCount<S> {
+   fn clone(&self) -> Self {
+       match self {
+           Self::Static(x, y, z) => Self::Static(*x, *y, *z),
+           Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
+       }
    }
 }
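A side note on the hand-written `Clone` above: deriving `Clone` on the generic enum would also require `S: Clone`, which a compute server generally is not, even though only the binding payload needs cloning. A self-contained illustration of that pitfall with stand-in types (not the burn-compute ones):

use std::marker::PhantomData;

// Stand-in for Binding<S>: clonable regardless of whether S is.
struct Handle<S> {
    id: u64,
    _server: PhantomData<S>,
}

impl<S> Clone for Handle<S> {
    fn clone(&self) -> Self {
        Self { id: self.id, _server: PhantomData }
    }
}

enum Count<S> {
    Static(u32, u32, u32),
    Dynamic(Handle<S>),
}

// A derived Clone would add an unwanted `S: Clone` bound; writing it manually avoids that.
impl<S> Clone for Count<S> {
    fn clone(&self) -> Self {
        match self {
            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
            Self::Dynamic(h) => Self::Dynamic(h.clone()),
        }
    }
}

fn main() {
    struct Server; // deliberately does not implement Clone
    let count: Count<Server> = Count::Static(8, 8, 1);
    let _copy = count.clone(); // still works
}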
@@ -78,15 +78,15 @@ impl<R: Runtime> KernelLauncher<R> {
    /// Launch the kernel.
    pub fn launch<K: Kernel>(
        self,
-       cube_count: CubeCount,
+       cube_count: CubeCount<R::Server>,
        kernel: K,
        client: ComputeClient<R::Server, R::Channel>,
    ) {
        let bindings = self.into_bindings(&client);

-       let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel, cube_count));
+       let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));

-       client.execute(kernel, bindings);
+       client.execute(kernel, cube_count, bindings);
    }

    /// We need to create the bindings in the same order they are defined in the compilation step.
@@ -6,6 +6,7 @@ extern crate derive_new;
 /// Cube Frontend Types.
 pub mod frontend;

+use burn_compute::server::ComputeServer;
 pub use frontend::cmma;

 /// Cube Language Internal Representation.

@@ -45,13 +46,16 @@ pub trait Kernel: Send + Sync + 'static {

 /// Calculate the number of cubes required to execute an operation where one cube unit is
 /// assigned to one element.
-pub fn calculate_cube_count_elemwise(num_elems: usize, cube_dim: usize) -> CubeCount {
+pub fn calculate_cube_count_elemwise<S: ComputeServer>(
+   num_elems: usize,
+   cube_dim: usize,
+) -> CubeCount<S> {
    let num_elems_per_cube = cube_dim * cube_dim;
    let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32);
    let cube_count_x = f32::ceil(f32::sqrt(cube_counts));
    let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32));

-   CubeCount::new(cube_count_x as u32, cube_count_y as u32, 1)
+   CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1)
 }

 pub fn tensor_vectorization_factor(
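As a quick sanity check of the element-wise dispatch math above, a standalone sketch (it mirrors `calculate_cube_count_elemwise` but is not the library function itself):

// Mirror of the grid math: spread ceil(num_elems / cube_dim^2) cubes over x and y.
fn elemwise_counts(num_elems: usize, cube_dim: usize) -> (u32, u32, u32) {
    let per_cube = cube_dim * cube_dim;
    let cubes = f32::ceil(num_elems as f32 / per_cube as f32);
    let x = f32::ceil(f32::sqrt(cubes));
    let y = f32::ceil(num_elems as f32 / (x * per_cube as f32));
    (x as u32, y as u32, 1)
}

fn main() {
    // 1_000_000 elements with a 16x16 cube: 3907 cubes, laid out as a 63 x 63 grid.
    assert_eq!(elemwise_counts(1_000_000, 16), (63, 63, 1));
}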
@@ -1,4 +1,8 @@
-use crate::{codegen::Compiler, compute::CubeTask, ir::Elem};
+use crate::{
+   codegen::Compiler,
+   compute::{CubeCount, CubeTask},
+   ir::Elem,
+};
 use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};

 /// Runtime for the CubeCL.

@@ -6,7 +10,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
    /// The compiler used to compile the inner representation into tokens.
    type Compiler: Compiler;
    /// The compute server used to run kernels and perform autotuning.
-   type Server: ComputeServer<Kernel = Box<dyn CubeTask>, FeatureSet = FeatureSet>;
+   type Server: ComputeServer<
+       Kernel = Box<dyn CubeTask>,
+       DispatchOptions = CubeCount<Self::Server>,
+       FeatureSet = FeatureSet,
+   >;
    /// The channel used to communicate with the compute server.
    type Channel: ComputeChannel<Self::Server>;
    /// The device used to retrieve the compute client.
@@ -61,7 +61,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {

    kernel_simple_1_launch::<R>(
        client.clone(),
-       CubeCount::new(1, 1, 1),
+       CubeCount::Static(1, 1, 1),
        CubeDim::new(16, 16, 1),
        ArrayArg::new(&lhs, 256),
        ArrayArg::new(&rhs, 256),

@@ -20,7 +20,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {

    kernel_with_generics_launch::<F32, R>(
        client.clone(),
-       CubeCount::new(1, 1, 1),
+       CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        ArrayArg::new(&handle, 2),
    );

@@ -36,7 +36,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {

    kernel_without_generics_launch::<R>(
        client.clone(),
-       CubeCount::new(1, 1, 1),
+       CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        ArrayArg::new(&handle, 2),
    );

@@ -98,7 +98,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
    client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
    launch: Launch,
 ) where
-   Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
+   Launch: Fn(CubeCount<TestRuntime::Server>, CubeDim, TensorArg<'_, TestRuntime>),
 {
    if !client.features().enabled(Feature::Subcube) {
        // Can't execute the test.

@@ -109,7 +109,7 @@ fn test_subcube_operation<TestRuntime: Runtime, Launch>(
    let (shape, strides) = ([input.len()], [1]);

    launch(
-       CubeCount::new(1, 1, 1),
+       CubeCount::Static(1, 1, 1),
        CubeDim::new(input.len() as u32, 1, 1),
        TensorArg::new(&handle, &strides, &shape),
    );
@@ -57,14 +57,8 @@ struct CompiledKernel {

 unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}

-impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
-   type Kernel = Box<dyn CubeTask>;
-   type Storage = CudaStorage;
-   type MemoryManagement = MM;
-   type AutotuneKey = JitAutotuneKey;
-   type FeatureSet = FeatureSet;
-
-   fn read(&mut self, binding: server::Binding<Self>) -> Reader {
+impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
+   fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
        let ctx = self.get_context();
        let resource = ctx.memory_management.get(binding.memory);

@@ -74,7 +68,20 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
            cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
        };
        ctx.sync();
-       reader_from_concrete(data)
+       data
+   }
+}
+
+impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
+   type Kernel = Box<dyn CubeTask>;
+   type DispatchOptions = CubeCount<Self>;
+   type Storage = CudaStorage;
+   type MemoryManagement = MM;
+   type AutotuneKey = JitAutotuneKey;
+   type FeatureSet = FeatureSet;
+
+   fn read(&mut self, binding: server::Binding<Self>) -> Reader {
+       reader_from_concrete(self.read_sync(binding))
    }

    fn create(&mut self, data: &[u8]) -> server::Handle<Self> {

@@ -101,12 +108,33 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
        server::Handle::new(handle)
    }

-   fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
+   fn execute(
+       &mut self,
+       kernel: Self::Kernel,
+       count: Self::DispatchOptions,
+       bindings: Vec<server::Binding<Self>>,
+   ) {
        let arch = self.minimum_arch_version;

-       let ctx = self.get_context();
        let kernel_id = kernel.id();
-       let settings = kernel.launch_settings();
+       let count = match count {
+           CubeCount::Static(x, y, z) => (x, y, z),
+           // TODO: CUDA doesn't have an exact equivalen of dynamic dispatch. Instead, kernels are free to launch other kernels.
+           // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
+           // For now, just read the dispatch settings from the buffer.
+           CubeCount::Dynamic(binding) => {
+               let data = self.read_sync(binding);
+               let data = bytemuck::cast_slice(&data);
+               assert!(
+                   data.len() == 3,
+                   "Dynamic cube count should contain 3 values"
+               );
+               (data[0], data[1], data[2])
+           }
+       };
+
+       let ctx = self.get_context();
+
        if !ctx.module_names.contains_key(&kernel_id) {
            ctx.compile_kernel(&kernel_id, kernel, arch);
|
||||||
.map(|binding| ctx.memory_management.get(binding.memory).as_binding())
|
.map(|binding| ctx.memory_management.get(binding.memory).as_binding())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
ctx.execute_task(kernel_id, settings.cube_count, bindings);
|
ctx.execute_task(kernel_id, count, bindings);
|
||||||
// TODO: fix this
|
// TODO: fix this
|
||||||
// self.memory_management.storage().perform_deallocations();
|
// self.memory_management.storage().perform_deallocations();
|
||||||
}
|
}
|
||||||
|
@ -217,16 +245,15 @@ impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
|
||||||
fn execute_task(
|
fn execute_task(
|
||||||
&mut self,
|
&mut self,
|
||||||
kernel_id: String,
|
kernel_id: String,
|
||||||
cube_count: CubeCount,
|
dispatch_count: (u32, u32, u32),
|
||||||
mut bindings: Vec<Binding>,
|
mut bindings: Vec<Binding>,
|
||||||
) {
|
) {
|
||||||
let kernel = self.module_names.get(&kernel_id).unwrap();
|
let kernel = self.module_names.get(&kernel_id).unwrap();
|
||||||
let cube_dim = kernel.cube_dim;
|
let cube_dim = kernel.cube_dim;
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
cudarc::driver::result::launch_kernel(
|
cudarc::driver::result::launch_kernel(
|
||||||
kernel.func,
|
kernel.func,
|
||||||
(cube_count.x, cube_count.y, cube_count.z),
|
dispatch_count,
|
||||||
(cube_dim.x, cube_dim.y, cube_dim.z),
|
(cube_dim.x, cube_dim.y, cube_dim.z),
|
||||||
kernel.shared_mem_bytes as u32,
|
kernel.shared_mem_bytes as u32,
|
||||||
self.stream,
|
self.stream,
|
||||||
|
|
|
@ -29,9 +29,9 @@ pub struct CompilationPhase;
|
||||||
/// Phase where the kernel should be executed.
|
/// Phase where the kernel should be executed.
|
||||||
#[derive(new)]
|
#[derive(new)]
|
||||||
pub struct ExecutionPhase<R: JitRuntime> {
|
pub struct ExecutionPhase<R: JitRuntime> {
|
||||||
/// Kernel set with default workgroup size.
|
/// Kernel set with default cube size.
|
||||||
pub(super) kernel_factory_1: ElementWiseKernelFactory<R>,
|
pub(super) kernel_factory_1: ElementWiseKernelFactory<R>,
|
||||||
/// Kernel set with custom workgroup size.
|
/// Kernel set with custom cube size.
|
||||||
pub(super) kernel_factory_2: ElementWiseKernelFactory<R>,
|
pub(super) kernel_factory_2: ElementWiseKernelFactory<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -22,7 +22,7 @@ pub struct FusionKernel<R: JitRuntime> {
    info: Arc<KernelExpansion>,
    settings: KernelSettings,
    runtime_info: Vec<OutputRuntimeInfo>,
-   cube_count: CubeCount,
+   cube_count: CubeCount<R::Server>,
    _runtime: PhantomData<R>,
 }

@@ -41,6 +41,7 @@ pub trait FusionKernelFactory<R: JitRuntime> {
 #[derive(new)]
 pub struct ExecutableKernel<R: JitRuntime> {
    kernel: Box<dyn CubeTask>,
+   cube_count: CubeCount<R::Server>,
    bindings: Vec<Binding<R::Server>>,
    client: ComputeClient<R::Server, R::Channel>,
 }

@@ -54,6 +55,7 @@ pub struct ExecutableKernel<R: JitRuntime> {
 #[derive(new)]
 pub struct AutotunableKernel<R: JitRuntime> {
    kernel: Arc<dyn CubeTask>,
+   count: CubeCount<R::Server>,
    bindings: Vec<Binding<R::Server>>,
    client: ComputeClient<R::Server, R::Channel>,
 }

@@ -68,18 +70,21 @@ pub enum OutputRuntimeInfo {
 impl<R: JitRuntime> ExecutableKernel<R> {
    /// Execute the kernel.
    pub fn execute(self) {
-       self.client.execute(self.kernel, self.bindings)
+       self.client
+           .execute(self.kernel, self.cube_count, self.bindings)
    }
 }

 impl<R: JitRuntime> AutotuneOperation for AutotunableKernel<R> {
    fn execute(self: Box<Self>) {
-       self.client.execute(Box::new(self.kernel), self.bindings)
+       self.client
+           .execute(Box::new(self.kernel), self.count, self.bindings)
    }

    fn clone(&self) -> Box<dyn AutotuneOperation> {
        Box::new(Self {
            kernel: self.kernel.clone(),
+           count: self.count.clone(),
            bindings: self.bindings.clone(),
            client: self.client.clone(),
        })

@@ -90,6 +95,7 @@ impl<R: JitRuntime> From<ExecutableKernel<R>> for AutotunableKernel<R> {
    fn from(value: ExecutableKernel<R>) -> Self {
        Self {
            kernel: Arc::new(value.kernel),
+           count: value.cube_count.clone(),
            bindings: value.bindings,
            client: value.client,
        }
@@ -233,12 +239,10 @@ impl<R: JitRuntime> FusionKernel<R> {
            context.handles.register_handle(id, handle);
        }

-       let workgroup = fusion_kernel.cube_count.clone();
+       let cube_count = fusion_kernel.cube_count.clone();
        ExecutableKernel::new(
-           Box::new(KernelTask::<R::Compiler, FusionKernel<R>>::new(
-               fusion_kernel,
-               workgroup,
-           )),
+           Box::new(KernelTask::<R::Compiler, _>::new(fusion_kernel)),
+           cube_count,
            bindings,
            client,
        )
|
@ -196,7 +196,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
||||||
let kernel = ScatterEagerKernel::<R, E>::new(dim);
|
let kernel = ScatterEagerKernel::<R, E>::new(dim);
|
||||||
let mut strides = [0; D];
|
let mut strides = [0; D];
|
||||||
let mut current = 1;
|
let mut current = 1;
|
||||||
let mut num_elems_per_workgroup = 1;
|
let mut num_elems = 1;
|
||||||
|
|
||||||
tensor
|
tensor
|
||||||
.shape
|
.shape
|
||||||
|
@ -208,13 +208,13 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
||||||
.for_each(|(index, val)| {
|
.for_each(|(index, val)| {
|
||||||
strides[index] = current;
|
strides[index] = current;
|
||||||
current *= val;
|
current *= val;
|
||||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
num_elems *= tensor.shape.dims[index];
|
||||||
});
|
});
|
||||||
|
|
||||||
// Fake strides of the virtual output where the strides of dim is hardcoded to one.
|
// Fake strides of the virtual output where the strides of dim is hardcoded to one.
|
||||||
indices.strides = strides;
|
indices.strides = strides;
|
||||||
|
|
||||||
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
|
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
|
||||||
|
|
||||||
Execution::start(kernel, indices.client)
|
Execution::start(kernel, indices.client)
|
||||||
.inputs(&[
|
.inputs(&[
|
||||||
|
@ -222,7 +222,7 @@ pub(crate) fn scatter<R: JitRuntime, E: JitElement, I: JitElement, const D: usiz
|
||||||
TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||||
TensorHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
TensorHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||||
])
|
])
|
||||||
.execute(CubeCountSettings::Custom(workgroup));
|
.execute(CubeCountSettings::Custom(cube_count));
|
||||||
|
|
||||||
tensor
|
tensor
|
||||||
}
|
}
|
||||||
|
|
|
@ -189,7 +189,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
|
|
||||||
let mut strides = [0; D];
|
let mut strides = [0; D];
|
||||||
let mut current = 1;
|
let mut current = 1;
|
||||||
let mut num_elems_per_workgroup = 1;
|
let mut num_elems = 1;
|
||||||
|
|
||||||
tensor
|
tensor
|
||||||
.shape
|
.shape
|
||||||
|
@ -201,11 +201,11 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
.for_each(|(index, val)| {
|
.for_each(|(index, val)| {
|
||||||
strides[index] = current;
|
strides[index] = current;
|
||||||
current *= val;
|
current *= val;
|
||||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
num_elems *= tensor.shape.dims[index];
|
||||||
});
|
});
|
||||||
|
|
||||||
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
|
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
|
||||||
let workgroup = calculate_cube_count_elemwise(num_elems_per_workgroup, SUBCUBE_DIM_APPROX);
|
let cube_count = calculate_cube_count_elemwise(num_elems, SUBCUBE_DIM_APPROX);
|
||||||
|
|
||||||
Execution::start(kernel, indices.client)
|
Execution::start(kernel, indices.client)
|
||||||
.inputs(&[
|
.inputs(&[
|
||||||
|
@ -215,7 +215,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement, const D
|
||||||
// kernel, but we need to put the right number of dimensions (rank).
|
// kernel, but we need to put the right number of dimensions (rank).
|
||||||
TensorHandle::new(&indices.handle, &strides, &strides),
|
TensorHandle::new(&indices.handle, &strides, &strides),
|
||||||
])
|
])
|
||||||
.execute(CubeCountSettings::Custom(workgroup));
|
.execute(CubeCountSettings::Custom(cube_count));
|
||||||
|
|
||||||
tensor
|
tensor
|
||||||
}
|
}
|
||||||
|
|
|
@@ -140,41 +140,39 @@ pub fn matmul<R: JitRuntime, E: FloatElement, const D: usize>(
    }
 }

-pub(crate) fn simple_launch_options<const D: usize>(
+pub(crate) fn simple_cube_count<R: JitRuntime, const D: usize>(
    lhs_shape: &Shape<D>,
    rhs_shape: &Shape<D>,
    output_shape: &Shape<D>,
-   workgroup_size_x: usize,
-   workgroup_size_y: usize,
-) -> CubeCount {
+   cube_dim_x: usize,
+   cube_dim_y: usize,
+) -> CubeCount<R::Server> {
    let num_rows = lhs_shape.dims[D - 2];
    let num_cols = rhs_shape.dims[D - 1];

-   // set number of workgroups
-   let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
-   let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
+   let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
+   let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
    let mut num_iter = 1;
    for i in 0..D - 2 {
        num_iter *= output_shape.dims[i];
    }

-   CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
+   CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
 }

-pub(crate) fn tiling2d_launch_options<const D: usize>(
+pub(crate) fn tiling2d_launch_options<R: JitRuntime, const D: usize>(
    output_shape: &Shape<D>,
    config: Tiling2dConfig,
-) -> CubeCount {
+) -> CubeCount<R::Server> {
    let num_rows = output_shape.dims[D - 2];
    let num_cols = output_shape.dims[D - 1];

-   // set number of workgroups
-   let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
-   let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
+   let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
+   let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
    let mut num_iter = 1;
    for i in 0..D - 2 {
        num_iter *= output_shape.dims[i];
    }

-   CubeCount::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
+   CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
 }
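A quick check of the grid arithmetic used by `simple_cube_count` above (standalone sketch, not the library function):

fn matmul_grid(num_rows: usize, num_cols: usize, cube_dim_x: usize, cube_dim_y: usize) -> (u32, u32) {
    // One cube covers cube_dim_x rows and cube_dim_y columns of the output.
    let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
    let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
    (cubes_x, cubes_y)
}

fn main() {
    // A 1000 x 1000 output tiled by 16 x 16 cubes needs ceil(1000 / 16) = 63 cubes per axis.
    assert_eq!(matmul_grid(1000, 1000, 16, 16), (63, 63));
}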
@@ -7,7 +7,7 @@ use crate::{
 use burn_cube::ir::KernelDefinition;
 use burn_cube::{frontend::TensorArg, KernelSettings};

-use super::simple_launch_options;
+use super::simple_cube_count;
 use burn_cube::prelude::*;

 #[cube(launch)]

@@ -80,7 +80,7 @@ fn matmul_kernel<F: Float>(
    }
 }

-/// Matrix multiplication using memory coalescing algorithm with workgroups of size 16
+/// Matrix multiplication using memory coalescing algorithm with cube dimensions of size 16
 pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,

@@ -89,7 +89,7 @@ pub fn matmul_mem_coalescing_default<R: JitRuntime, E: FloatElement, const D: usize>(
    matmul_simple::<R, E, D>(lhs, rhs, out, SUBCUBE_DIM_APPROX, SUBCUBE_DIM_APPROX)
 }

-/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
+/// Matrix multiplication using memory coalescing algorithm with custom cube dimensions
 pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,

@@ -103,7 +103,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
    let rhs_original_shape = rhs.shape.clone();
    let rhs = into_contiguous(swap_dims(rhs, D - 1, D - 2));

-   let cube_count = simple_launch_options(
+   let cube_count = simple_cube_count::<R, D>(
        &lhs.shape,
        &rhs_original_shape,
        &out.shape,

@@ -118,7 +118,7 @@ pub fn matmul_tiling_2d<R: JitRuntime, E: JitElement + Element, const D: usize>(
            &out.strides,
            &out.shape.dims,
        )])
-       .execute(CubeCountSettings::Custom(tiling2d_launch_options(
+       .execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
            &out.shape, config,
        )));

@@ -175,7 +175,7 @@ pub fn matmul_tiling_2d_padded<R: JitRuntime, E: JitElement + Element, const D: usize>(
            &rounded_output.strides,
            &rounded_output.shape.dims,
        )])
-       .execute(CubeCountSettings::Custom(tiling2d_launch_options(
+       .execute(CubeCountSettings::Custom(tiling2d_launch_options::<R, D>(
            &rounded_output.shape,
            config,
        )));
@@ -55,14 +55,14 @@ pub(crate) fn gather_shader_information(
        cpa!(scope, out_stride_row = stride(out, second_to_last_dim));
        cpa!(scope, out_stride_col = stride(out, last_dim));

-       // Workgroup offset
+       // Cube offset
        let skip_row = scope.create_local(Elem::UInt);
        let skip_col = scope.create_local(Elem::UInt);
-       let workgroup_id_x = Variable::CubePosX;
-       let workgroup_id_y = Variable::CubePosY;
-       cpa!(scope, skip_row = workgroup_id_x);
+       let cube_pos_x = Variable::CubePosX;
+       let cube_pos_y = Variable::CubePosY;
+       cpa!(scope, skip_row = cube_pos_x);
        cpa!(scope, skip_row *= block_size_m);
-       cpa!(scope, skip_col = workgroup_id_y);
+       cpa!(scope, skip_col = cube_pos_y);
        cpa!(scope, skip_col *= block_size_n);

        // Position of the first element of the thread, relative to the block

@@ -38,7 +38,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
        )])
        .with_scalars(&seeds)
        .with_scalars(&prng.args())
-       .execute(CubeCountSettings::Custom(prng_cube_count(
+       .execute(CubeCountSettings::Custom(prng_cube_count::<R>(
            num_elems,
            SUBCUBE_DIM_APPROX,
            N_VALUES_PER_THREAD,

@@ -47,14 +47,18 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement, const D: usize>(
    output
 }

-fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount {
+fn prng_cube_count<R: JitRuntime>(
+   num_elems: usize,
+   cube_dim: usize,
+   n_values_per_thread: usize,
+) -> CubeCount<R::Server> {
    let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
    let num_elems_per_cube = cube_dim * cube_dim;
    let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
-   let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
-   let workgroup_y = f32::ceil(num_invocations / workgroup_x);
+   let cubes_x = f32::ceil(f32::sqrt(num_invocations));
+   let cubes_y = f32::ceil(num_invocations / cubes_x);

-   CubeCount::new(workgroup_x as u32, workgroup_y as u32, 1)
+   CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
 }

 impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> {
@@ -163,24 +167,24 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
        let n_values_per_thread: Variable = self.n_values_per_thread.into();
        let args = self.args;

-       let workgroup_size_x = Variable::CubeDimX;
-       let workgroup_size_y = Variable::CubeDimY;
-       let workgroup_id_x = Variable::CubePosX;
-       let workgroup_id_y = Variable::CubePosY;
-       let num_workgroups_y = Variable::CubeCountY;
+       let cube_dim_x = Variable::CubeDimX;
+       let cube_dim_y = Variable::CubeDimY;
+       let cube_pos_x = Variable::CubePosX;
+       let cube_pos_y = Variable::CubePosY;
+       let cube_count_y = Variable::CubeCountY;
        let local_index = Variable::UnitPos;

        let n_invocations = scope.create_local(Elem::UInt);
-       cpa!(scope, n_invocations = workgroup_size_x);
-       cpa!(scope, n_invocations *= workgroup_size_y);
+       cpa!(scope, n_invocations = cube_dim_x);
+       cpa!(scope, n_invocations *= cube_dim_y);

-       let workgroup_offset = scope.create_local(Elem::UInt);
-       cpa!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y);
-       cpa!(scope, workgroup_offset += workgroup_id_y);
-       cpa!(scope, workgroup_offset *= n_invocations);
+       let cube_offset = scope.create_local(Elem::UInt);
+       cpa!(scope, cube_offset = cube_pos_x * cube_count_y);
+       cpa!(scope, cube_offset += cube_pos_y);
+       cpa!(scope, cube_offset *= n_invocations);

        let write_index_base = scope.create_local(Elem::UInt);
-       cpa!(scope, write_index_base = workgroup_offset);
+       cpa!(scope, write_index_base = cube_offset);
        cpa!(scope, write_index_base *= n_values_per_thread);
        cpa!(scope, write_index_base += local_index);

@@ -188,7 +192,7 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
        let thread_seed = scope.create_local(Elem::UInt);
        cpa!(scope, thread_seed = cast(1000000007));
        let thread_seed_index = scope.create_local(Elem::UInt);
-       cpa!(scope, thread_seed_index = workgroup_offset + local_index);
+       cpa!(scope, thread_seed_index = cube_offset + local_index);
        cpa!(scope, thread_seed *= thread_seed_index);

        let state_0 = scope.create_local(Elem::UInt);
@@ -33,8 +33,8 @@ pub(crate) struct SharedReduceDimEagerKernel<
    EO: JitElement,
 > {
    dim: usize,
-   workgroup_size_x: usize,
-   workgroup_size_y: usize,
+   cube_dim_x: usize,
+   cube_dim_y: usize,
    n_input_values_per_thread: u32,
    divisible_shape: bool,
    _reduce_dim: PhantomData<RD>,

@@ -58,7 +58,7 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
        SharedReduceDimComputeShader {
            tensor,
            dim: self.dim,
-           shared_memory_size: self.workgroup_size_x * self.workgroup_size_y,
+           shared_memory_size: self.cube_dim_x * self.cube_dim_y,
            n_input_values_per_thread: self.n_input_values_per_thread,
            output,
            divisible_shape: self.divisible_shape,

@@ -83,8 +83,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
        };

        let settings = KernelSettings::default().cube_dim(CubeDim::new(
-           self.workgroup_size_x as u32,
-           self.workgroup_size_y as u32,
+           self.cube_dim_x as u32,
+           self.cube_dim_y as u32,
            1,
        ));
        KernelIntegrator::new(info).integrate(settings)

@@ -95,8 +95,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
            "{:?}dim={}x={}y={}n={}divshape={}",
            core::any::TypeId::of::<Self>(),
            self.dim,
-           self.workgroup_size_x,
-           self.workgroup_size_y,
+           self.cube_dim_x,
+           self.cube_dim_y,
            self.n_input_values_per_thread,
            self.divisible_shape
        )
@ -111,13 +111,13 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
||||||
let rank = Variable::Rank;
|
let rank = Variable::Rank;
|
||||||
let dim: Variable = self.dim.into();
|
let dim: Variable = self.dim.into();
|
||||||
|
|
||||||
let workgroup_id_x = Variable::CubePosX;
|
let cube_pos_x = Variable::CubePosX;
|
||||||
let workgroup_id_y = Variable::CubePosY;
|
let cube_pos_y = Variable::CubePosY;
|
||||||
let num_workgroups_x = Variable::CubeCountX;
|
let cube_count_x = Variable::CubeCountX;
|
||||||
let local_invocation_id_x = Variable::UnitPosX;
|
let local_invocation_id_x = Variable::UnitPosX;
|
||||||
let local_invocation_id_y = Variable::UnitPosY;
|
let local_invocation_id_y = Variable::UnitPosY;
|
||||||
let workgroup_size_x = Variable::CubeDimX;
|
let cube_dim_x = Variable::CubeDimX;
|
||||||
let workgroup_size_y = Variable::CubeDimY;
|
let cube_dim_y = Variable::CubeDimY;
|
||||||
|
|
||||||
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
|
let stride_reduce_dim_input = scope.create_local(Elem::UInt);
|
||||||
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
|
cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
|
||||||
|
@ -126,16 +126,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
|
||||||
|
|
||||||
// To determine which reduce_group (not position, but absolute id)
|
// To determine which reduce_group (not position, but absolute id)
|
||||||
let reduce_group_id = scope.create_local(Elem::UInt);
|
let reduce_group_id = scope.create_local(Elem::UInt);
|
||||||
cpa!(scope, reduce_group_id = workgroup_id_y * num_workgroups_x);
|
cpa!(scope, reduce_group_id = cube_pos_y * cube_count_x);
|
||||||
cpa!(scope, reduce_group_id += workgroup_id_x);
|
cpa!(scope, reduce_group_id += cube_pos_x);
|
||||||
|
|
||||||
// nth thread in the workgroup
|
// nth thread in the cube
|
||||||
let local_id = scope.create_local(Elem::UInt);
|
let local_id = scope.create_local(Elem::UInt);
|
||||||
cpa!(scope, local_id = local_invocation_id_y * workgroup_size_x);
|
cpa!(scope, local_id = local_invocation_id_y * cube_dim_x);
|
||||||
cpa!(scope, local_id += local_invocation_id_x);
|
cpa!(scope, local_id += local_invocation_id_x);
|
||||||
|
|
||||||
let n_threads = scope.create_local(Elem::UInt);
|
let n_threads = scope.create_local(Elem::UInt);
|
||||||
cpa!(scope, n_threads = workgroup_size_x * workgroup_size_y);
|
cpa!(scope, n_threads = cube_dim_x * cube_dim_y);
|
||||||
|
|
||||||
let index_offset = scope.zero(Elem::UInt);
|
let index_offset = scope.zero(Elem::UInt);
|
||||||
|
|
||||||
|
@ -242,17 +242,17 @@ pub fn reduce_dim_shared<
|
||||||
dim: usize,
|
dim: usize,
|
||||||
) -> JitTensor<R, EO, D> {
|
) -> JitTensor<R, EO, D> {
|
||||||
let num_elems_output = output.shape.num_elements();
|
let num_elems_output = output.shape.num_elements();
|
||||||
let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32));
|
let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32));
|
||||||
let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x);
|
let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x);
|
||||||
let grid = CubeCount::new(n_workgroups_x as u32, n_workgroups_y as u32, 1);
|
let grid = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1);
|
||||||
|
|
||||||
let reduce_group_size = input.shape.dims[dim];
|
let reduce_group_size = input.shape.dims[dim];
|
||||||
let n_invocation_per_workgroup = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
|
let n_invocation_per_cube = SUBCUBE_DIM_APPROX * SUBCUBE_DIM_APPROX;
|
||||||
let n_input_values_per_thread =
|
let n_input_values_per_thread =
|
||||||
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
|
f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32;
|
||||||
|
|
||||||
let divisible_shape =
|
let divisible_shape =
|
||||||
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
|
n_invocation_per_cube as u32 * n_input_values_per_thread == reduce_group_size as u32;
|
||||||
|
|
||||||
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
|
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
|
||||||
dim,
|
dim,
|
||||||
|
|
|
@@ -18,7 +18,10 @@ pub(crate) mod tune;
 /// Elements for JIT backend
 pub mod element;

-use burn_cube::{compute::CubeTask, Runtime};
+use burn_cube::{
+    compute::{CubeCount, CubeTask},
+    Runtime,
+};
 pub use element::{FloatElement, IntElement, JitElement};

 mod backend;
@@ -48,5 +51,6 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
     type JitServer: burn_compute::server::ComputeServer<
         AutotuneKey = JitAutotuneKey,
         Kernel = Box<dyn CubeTask>,
+        DispatchOptions = CubeCount<Self::JitServer>,
     >;
 }
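Note: with `DispatchOptions` pinned to `CubeCount<Self::JitServer>`, code generic over `JitRuntime` can build a dispatch size on the host and hand it to `execute` next to the kernel. A minimal sketch, assumed to live in this crate; the helper name, the 256-unit cube size, and the argument set are illustrative, only `CubeCount::Static` and the `execute(kernel, count, bindings)` shape come from this commit:

use burn_compute::{client::ComputeClient, server::Binding};
use burn_cube::compute::{CubeCount, CubeTask};
use crate::JitRuntime;

// Hypothetical helper: derive a static 1D dispatch from an element count and launch.
fn launch_static<R: JitRuntime>(
    client: &ComputeClient<R::JitServer, R::Channel>,
    kernel: Box<dyn CubeTask>,
    num_elems: usize,
    bindings: Vec<Binding<R::JitServer>>,
) {
    // Assumes 256 units per cube; pick whatever matches the kernel's CubeDim.
    let cubes_needed = f32::ceil(num_elems as f32 / 256.0) as u32;
    let count = CubeCount::Static(cubes_needed, 1, 1);
    client.execute(kernel, count, bindings);
}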
@@ -1,5 +1,4 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
-use burn_cube::compute::LaunchSettings;
 use burn_cube::prelude::*;

 use super::SourceTemplate;
@@ -11,18 +10,13 @@ pub trait KernelSource: Send + 'static + Sync {
 }

 #[derive(new)]
-/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask) with launch
-/// information.
+/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).
 pub struct SourceKernel<K> {
     kernel_source: K,
-    cube_count: CubeCount,
     cube_dim: CubeDim,
 }

-impl<K> CubeTask for SourceKernel<K>
-where
-    K: KernelSource + 'static,
-{
+impl<K: KernelSource> CubeTask for SourceKernel<K> {
     fn compile(&self) -> CompiledKernel {
         let source_template = self.kernel_source.source();
         let source = source_template.complete();
@@ -37,12 +31,6 @@ where
     fn id(&self) -> String {
         format!("{:?}", core::any::TypeId::of::<K>())
     }
-
-    fn launch_settings(&self) -> LaunchSettings {
-        LaunchSettings {
-            cube_count: self.cube_count.clone(),
-        }
-    }
 }

 /// Generates kernel source code by replacing some information using templating.
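Note: since the cube count no longer lives on `SourceKernel`, template kernels are built from the source and the cube dimensions only, and the count travels with the `execute` call. A sketch of the call-site shape, mirroring the wgpu backend example further down; `my_source`, `client`, and the handles are placeholders:

// `my_source` is any `KernelSource` implementor; only the two-argument
// `SourceKernel::new` and the separate count argument come from this commit.
let kernel = SourceKernel::new(my_source, CubeDim::new(16, 16, 1));
let count = CubeCount::Static(32, 32, 1);
client.execute(Box::new(kernel), count, vec![input.binding(), output.binding()]);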
@@ -64,8 +64,15 @@ where
         &mut self,
         pipeline: Arc<ComputePipeline>,
         bind_group: BindGroup,
-        work_group: CubeCount,
+        count: CubeCount<Self>,
     ) {
+        // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this
+        // needs to be longer than the compute pass, so we can't do this just before dispatching.
+        let dispatch_resource = match count.clone() {
+            CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)),
+            _ => None,
+        };
+
         let mut compute = self
             .encoder
             .begin_compute_pass(&wgpu::ComputePassDescriptor {
@@ -75,12 +82,21 @@ where

         compute.set_pipeline(&pipeline);
         compute.set_bind_group(0, &bind_group, &[]);
-        compute.dispatch_workgroups(work_group.x, work_group.y, work_group.z);
+
+        match count {
+            CubeCount::Static(x, y, z) => {
+                compute.dispatch_workgroups(x, y, z);
+            }
+            CubeCount::Dynamic(_) => {
+                let resource = dispatch_resource.as_ref().unwrap();
+                compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset());
+            }
+        }

         self.tasks_count += 1;
     }

-    fn pipeline(&mut self, kernel: Box<dyn CubeTask>) -> Arc<ComputePipeline> {
+    fn pipeline(&mut self, kernel: <Self as ComputeServer>::Kernel) -> Arc<ComputePipeline> {
         let kernel_id = kernel.id();

         if let Some(pipeline) = self.pipelines.get(&kernel_id) {
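Note: `register_compute` now branches on the two dispatch modes: a `Static` count issues a regular `dispatch_workgroups`, while a `Dynamic` count resolves a server-side buffer and issues `dispatch_workgroups_indirect`, which wgpu reads as three consecutive `u32` values (x, y, z). The enum itself is defined in burn-cube; reconstructed from its usage in this diff it presumably looks roughly like the sketch below, though the real definition may carry extra derives or bounds:

use burn_compute::server::{Binding, ComputeServer};

// Assumed shape of the dispatch-count type, inferred from the match arms above.
pub enum CubeCount<S: ComputeServer> {
    /// Count known on the host; dispatched with `dispatch_workgroups(x, y, z)`.
    Static(u32, u32, u32),
    /// Count stored in a GPU buffer holding `[x, y, z]` as three `u32`s;
    /// dispatched with `dispatch_workgroups_indirect`.
    Dynamic(Binding<S>),
}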
@@ -143,6 +159,7 @@ where
     MM: MemoryManagement<WgpuStorage>,
 {
     type Kernel = Box<dyn CubeTask>;
+    type DispatchOptions = CubeCount<Self>;
     type Storage = WgpuStorage;
     type MemoryManagement = MM;
     type AutotuneKey = JitAutotuneKey;
@@ -259,8 +276,12 @@ where
         }))
     }

-    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
-        let work_group = kernel.launch_settings().cube_count;
+    fn execute(
+        &mut self,
+        kernel: Self::Kernel,
+        count: Self::DispatchOptions,
+        bindings: Vec<server::Binding<Self>>,
+    ) {
         let pipeline = self.pipeline(kernel);
         let group_layout = pipeline.get_bind_group_layout(0);

@@ -284,7 +305,7 @@ where
             entries: &entries,
         });

-        self.register_compute(pipeline, bind_group, work_group);
+        self.register_compute(pipeline, bind_group, count);

         if self.tasks_count >= self.tasks_max {
             self.sync(SyncType::Flush);
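Note: the same signature change applies to every `ComputeServer` implementation and to the channels in front of them (see the channel changes earlier in this commit); the count is threaded through unchanged until `register_compute` consumes it. A sketch of what a non-GPU server now has to accept, with `DummyServer` hypothetical and the remaining trait items elided:

// Sketch only: the new `execute` shape another server must provide.
impl ComputeServer for DummyServer {
    type DispatchOptions = CubeCount<Self>;
    // ...other associated types unchanged and omitted...

    fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: Self::DispatchOptions,
        bindings: Vec<Binding<Self>>,
    ) {
        // A CPU or test server can simply record `count` or ignore it.
        let _ = count;
        let _ = (kernel, bindings);
    }
}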
@@ -111,7 +111,8 @@ impl ComputeStorage for WgpuStorage {
             size: size as u64,
             usage: wgpu::BufferUsages::COPY_DST
                 | wgpu::BufferUsages::STORAGE
-                | wgpu::BufferUsages::COPY_SRC,
+                | wgpu::BufferUsages::COPY_SRC
+                | wgpu::BufferUsages::INDIRECT,
             mapped_at_creation: false,
         }));

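Note: every wgpu storage buffer now also carries `BufferUsages::INDIRECT`, so any handle can back an indirect dispatch. The buffer consumed by `dispatch_workgroups_indirect` must hold the cube count as three consecutive little-endian `u32` values. A host-side sketch of such a payload; how the handle is created and bound is outside this hunk:

// 12-byte indirect-dispatch payload: x, y, z cube counts as u32s.
let cubes: [u32; 3] = [128, 64, 1];
let bytes: Vec<u8> = cubes.iter().flat_map(|v| v.to_le_bytes()).collect();
// e.g. a handle created from `bytes` could then be wrapped in `CubeCount::Dynamic(handle.binding())`.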
@@ -87,11 +87,13 @@ impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
         // Declare the wgsl workgroup with the number of cubes in x, y and z.
         let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
         let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
-        let cube_count = CubeCount::new(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);
+        let cube_count =
+            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

         // Execute lazily the kernel with the launch information and the given buffers.
         lhs.client.execute(
-            Box::new(SourceKernel::new(kernel, cube_count, cube_dim)),
+            Box::new(SourceKernel::new(kernel, cube_dim)),
+            cube_count,
             vec![
                 lhs.handle.binding(),
                 rhs.handle.binding(),
@@ -22,7 +22,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {

     gelu_launch::<F32, R>(
         client.clone(),
-        CubeCount::new(1, 1, 1),
+        CubeCount::Static(1, 1, 1),
         CubeDim::default(),
         ArrayArg::new(&input_handle, input.len()),
         ArrayArg::new(&output_handle, input.len()),
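Note: both examples above still use `CubeCount::Static`; the point of this commit is that the count can instead come from a GPU buffer. A hedged variant of the gelu launch, assuming `count_handle` already points at a buffer holding `[x, y, z]` as three `u32`s (for example written by an earlier kernel):

// Hypothetical dynamic launch; only the `CubeCount::Dynamic` variant and
// `binding()` come from this commit, the rest mirrors the static example.
gelu_launch::<F32, R>(
    client.clone(),
    CubeCount::Dynamic(count_handle.binding()),
    CubeDim::default(),
    ArrayArg::new(&input_handle, input.len()),
    ArrayArg::new(&output_handle, input.len()),
);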