Migrate/jit/matmul tiling 2d (#1472)

* refactor matmul files

* wip refactor matmul

* everything is memco

* support local arrays

* advancing tiling2d

* advancing tiling2d

* advancing tiling2d

* tiling2d finished but buggy

* configurable unrolling

* not bugged

* fails on unroll

* stupid break

* tiling2d no assumption works

* clippy

* bounds check as bool

* lhs rhs as enum

* tiling 2d major refactor

* remove assign vec4

* variable declarations above loops

* fmt

* clippy

* Fix autotune + unroll

* move val

* clippy

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
Louis Fortier-Dubois 2024-03-22 08:26:32 -04:00 committed by GitHub
parent 0a8a3cc9e9
commit dd699a90a2
29 changed files with 1296 additions and 330 deletions

View File

@ -22,4 +22,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
fn compile(shader: gpu::ComputeShader) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: gpu::Elem) -> usize;
/// The maximal size of a shared memory.
fn max_shared_memory_size() -> usize;
}

View File

@ -119,3 +119,15 @@ impl Loop {
parent_scope.register(Branch::Loop(op));
}
}
#[allow(missing_docs)]
pub struct UnrolledRangeLoop;
impl UnrolledRangeLoop {
/// Registers an unrolled range loop to the given scope.
pub fn register<F: Fn(Variable, &mut Scope)>(scope: &mut Scope, start: u32, end: u32, func: F) {
for i in start..end {
func(i.into(), scope);
}
}
}
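
Illustrative usage, not part of this diff: registering an unrolled range loop emits the closure body once per constant index, so the compiled shader contains the repeated operations rather than a runtime loop (the `range(start, end, unroll)` macro arm added below dispatches here when `unroll` is true). This is a minimal sketch assuming the gpu dialect items (Scope, Elem, UnrolledRangeLoop) and the gpu! macro are in scope, and that the add arm accepts local variables as it does elsewhere in this diff.

let mut scope = Scope::root();
let sum = scope.create_local(Elem::UInt);
UnrolledRangeLoop::register(&mut scope, 0, 4, |i, scope| {
    // `i` is a constant Variable (0, 1, 2, 3); the add is registered four times.
    gpu!(scope, sum = sum + i);
});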

View File

@ -293,6 +293,17 @@ macro_rules! gpu {
gpu!(unary $input, $out)
));
};
// out = vec4(a, b, c, d)
($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
let i = $scope.zero(Elem::UInt);
gpu!($scope, $out[i] = $a);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $b);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $c);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $d);
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
gpu!($scope, $out = cast($input))
@ -326,10 +337,18 @@ macro_rules! gpu {
out: $out.into(),
});
};
// range(start, end).for_each(|scope| { ... })
// range(start, end).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
};
// range(start, end, unroll).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
if $unroll {
$crate::codegen::dialect::gpu::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg);
} else {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
}
};
// loop(|scope| { ... })
($scope:expr, loop($arg:expr)) => {
$crate::codegen::dialect::gpu::Loop::register($scope, $arg);

View File

@ -36,6 +36,7 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),

View File

@ -20,7 +20,8 @@ pub struct Scope {
pub depth: u8,
pub operations: Vec<Operation>,
locals: Vec<Variable>,
shared: Vec<Variable>,
shared_memories: Vec<Variable>,
local_arrays: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
index_offset_with_output_layout_position: Vec<usize>,
writes_global: Vec<(Variable, Variable)>,
@ -48,7 +49,8 @@ impl Scope {
depth: 0,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
local_arrays: Vec::new(),
shared_memories: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@ -213,7 +215,8 @@ impl Scope {
depth: self.depth + 1,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
shared_memories: Vec::new(),
local_arrays: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@ -308,7 +311,11 @@ impl Scope {
}
fn new_shared_index(&self) -> u16 {
self.shared.len() as u16
self.shared_memories.len() as u16
}
fn new_local_array_index(&self) -> u16 {
self.local_arrays.len() as u16
}
fn read_input_strategy(
@ -339,7 +346,16 @@ impl Scope {
let item = item.into();
let index = self.new_shared_index();
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
self.shared.push(shared_memory);
self.shared_memories.push(shared_memory);
shared_memory
}
/// Create a local array of the given [item type](Item).
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
let item = item.into();
let index = self.new_local_array_index();
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
self.local_arrays.push(local_array);
local_array
}
}

View File

@ -11,6 +11,7 @@ pub enum Variable {
LocalScalar(u16, Elem, u8),
ConstantScalar(f64, Elem),
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
Id,
LocalInvocationIndex,
LocalInvocationIdX,
@ -41,6 +42,7 @@ impl Variable {
Variable::GlobalOutputArray(idx, _) => Some(*idx),
Variable::ConstantScalar(_, _) => None,
Variable::SharedMemory(idx, _, _) => Some(*idx),
Variable::LocalArray(idx, _, _, _) => Some(*idx),
Variable::Id => None,
Variable::LocalInvocationIndex => None,
Variable::LocalInvocationIdX => None,
@ -70,6 +72,7 @@ impl Variable {
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),

View File

@ -52,6 +52,7 @@ impl Operator {
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
@ -130,6 +131,12 @@ impl Variable {
item.vectorize(vectorize),
item.vectorized_size(vectorize, *size),
),
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
*index,
item.vectorize(vectorize),
*name,
item.vectorized_size(vectorize, *size),
),
Variable::ConstantScalar(_, _) => *self,
Variable::GlobalScalar(_, _) => *self,
Variable::Id => *self,

View File

@ -247,6 +247,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,

View File

@ -1,24 +1,98 @@
use crate::{tensor::JitTensor, JitElement, Runtime};
use std::cmp::{max, min};
use burn_tensor::Shape;
use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime};
use super::{
init_matmul_output, matmul_autotune, matmul_mem_coalescing,
unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4,
init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded,
};
#[derive(Debug, Clone)]
/// Tiling 2D parameters
pub struct Tiling2dConfig {
/// Number of invocations in x
pub grid_x: usize,
/// Number of invocations in y
pub grid_y: usize,
/// Block size along the m dimension (rows of lhs)
pub block_size_m: usize,
/// Block size along the common (k) dimension
pub block_size_k: usize,
/// Block size along the n dimension (columns of rhs)
pub block_size_n: usize,
/// Tile size along the m dimension (rows of lhs)
pub tile_size_m: usize,
/// Tile size along the n dimension (columns of rhs)
pub tile_size_n: usize,
}
impl Tiling2dConfig {
#[allow(unused)]
fn new<R: Runtime>(
grid_x: usize,
grid_y: usize,
block_size_m: usize,
block_size_k: usize,
block_size_n: usize,
tile_size_m: usize,
tile_size_n: usize,
) -> Self {
assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize);
assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize);
assert!(
block_size_k <= min(block_size_m, block_size_n),
"Not enough invocations to fill shared memory"
);
assert!(
block_size_k * max(block_size_m, block_size_n)
<= <R::Compiler as Compiler>::max_shared_memory_size(),
"Shared memory limit will be busted. "
);
assert!(
block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0,
"Tile size must divide block size in m and n dimensions"
);
Self {
grid_x,
grid_y,
block_size_m,
block_size_k,
block_size_n,
tile_size_m,
tile_size_n,
}
}
}
impl Default for Tiling2dConfig {
fn default() -> Self {
Self {
grid_x: 16,
grid_y: 16,
block_size_m: 64,
block_size_k: 32,
block_size_n: 64,
tile_size_m: 4,
tile_size_n: 4,
}
}
}
/// The strategy to be used when launching a matmul kernel.
#[derive(Default)]
pub enum MatmulStrategy {
/// A simple kernel will be used with memory coalescing optimization.
Simple {
/// Grad size x
/// Number of invocations in x
grid_x: usize,
/// Grad size y
/// Number of invocations in y
grid_y: usize,
},
/// A tiling 2d kernel will be used, with support for any matrix size without padding.
Tiling2d,
Tiling2d(Tiling2dConfig),
/// A tiling 2d kernel will be used, with support for any matrix size with padding.
Tiling2dPadded,
Tiling2dPadded(Tiling2dConfig),
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
#[default]
@ -42,17 +116,56 @@ pub fn matmul<R: Runtime, E: JitElement, const D: usize>(
match strategy {
MatmulStrategy::Simple { grid_x, grid_y } => {
let out = init_matmul_output(&lhs, &rhs);
matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y)
matmul_simple(lhs, rhs, out, grid_x, grid_y)
}
MatmulStrategy::Tiling2d => {
MatmulStrategy::Tiling2d(config) => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_unpadded(lhs, rhs, out)
matmul_tiling_2d(lhs, rhs, out, config)
}
MatmulStrategy::Tiling2dPadded => {
MatmulStrategy::Tiling2dPadded(config) => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_vec4(lhs, rhs, out)
matmul_tiling_2d_padded(lhs, rhs, out, config)
}
#[cfg(feature = "autotune")]
MatmulStrategy::Autotune => matmul_autotune(lhs, rhs),
}
}
pub(crate) fn simple_launch_options<const D: usize>(
lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>,
output_shape: &Shape<D>,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> WorkGroup {
let num_rows = lhs_shape.dims[D - 2];
let num_cols = rhs_shape.dims[D - 1];
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
}
pub(crate) fn tiling2d_launch_options<const D: usize>(
output_shape: &Shape<D>,
config: Tiling2dConfig,
) -> WorkGroup {
let num_rows = output_shape.dims[D - 2];
let num_cols = output_shape.dims[D - 1];
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
}
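
Illustrative sketch, not part of this commit, of how the pieces above fit together with the default configuration: each workgroup covers a 64x64 output block, every invocation owns a 4x4 tile, and the launch options dispatch one workgroup per block plus one z layer per batch. The x/y/z field names on WorkGroup are assumed from its constructor.

let config = Tiling2dConfig::default();
// 64 / 4 = 16 invocations along each axis of the workgroup.
assert_eq!(config.grid_x, config.block_size_m / config.tile_size_m);
assert_eq!(config.grid_y, config.block_size_n / config.tile_size_n);

// For a (2, 1000, 1000) output: ceil(1000 / 64) = 16 blocks along rows and
// columns, and 2 along z for the batch dimension.
let output_shape: Shape<3> = Shape::new([2, 1000, 1000]);
let workgroup = tiling2d_launch_options(&output_shape, config);
assert_eq!((workgroup.x, workgroup.y, workgroup.z), (16, 16, 2));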

View File

@ -1,13 +1,22 @@
mod base;
mod mem_coalescing;
mod simple;
mod tiling2d;
mod tiling2d_shader;
mod tune;
/// Contains utilities for the matmul operation
pub mod utils;
pub use base::*;
pub use mem_coalescing::*;
pub use tiling2d::*;
pub use simple::*;
pub use tune::*;
pub use utils::*;
#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
pub mod padding;
#[cfg(not(feature = "export_tests"))]
mod padding;
pub use tiling2d::*;

View File

@ -6,15 +6,15 @@ use crate::{
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
},
compute::WorkGroup,
element::JitElement,
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
tensor::JitTensor,
Runtime,
};
use burn_tensor::Shape;
use std::marker::PhantomData;
use super::simple_launch_options;
#[derive(new, Debug)]
struct MatmulEagerKernel<R: Runtime> {
workgroup_size_x: usize,
@ -213,11 +213,11 @@ pub fn matmul_mem_coalescing_default<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
matmul_mem_coalescing::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
matmul_simple::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
}
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
@ -228,7 +228,7 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
let lhs = into_contiguous(lhs);
let rhs = into_contiguous(rhs);
let workgroup = launch_options(
let workgroup = simple_launch_options(
&lhs.shape,
&rhs.shape,
&out.shape,
@ -242,9 +242,8 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
&[
EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
EagerHandle::new(&out.handle, &out.strides, &out.shape.dims),
],
&[],
&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)],
None,
kernel,
WorkgroupLaunch::Custom(workgroup),
@ -253,24 +252,3 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
out
}
fn launch_options<const D: usize>(
lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>,
output_shape: &Shape<D>,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> WorkGroup {
let num_rows = lhs_shape.dims[D - 2];
let num_cols = rhs_shape.dims[D - 1];
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
}

View File

@ -0,0 +1,190 @@
use burn_tensor::{Element, Shape};
use crate::{
codegen::{
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Execution, InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime,
};
use std::marker::PhantomData;
use super::{
padding::{crop, pad_round, PaddingOutput},
shape_out, tiling2d_launch_options,
tiling2d_shader::MatmulTiling2dShader,
Tiling2dConfig,
};
#[derive(new, Debug)]
struct MatmulTiling2d<E: JitElement> {
_elem: PhantomData<E>,
}
#[derive(new, Debug)]
struct MatmulTiling2dEagerKernel<R: Runtime> {
config: Tiling2dConfig,
bounds_check_required: bool,
_runtime: PhantomData<R>,
}
impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {
fn source(&self) -> SourceTemplate {
let mut scope = gpu::Scope::root();
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
scope.write_global_custom(out);
MatmulTiling2dShader {
variables: gpu::BinaryOperator { lhs, rhs, out },
config: self.config.clone(),
bounds_check_required: self.bounds_check_required,
unroll: true,
}
.expand(&mut scope);
let lhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
visibility: gpu::Visibility::Read,
};
let rhs = InputInfo::Array {
item: gpu::Elem::Float.into(),
visibility: gpu::Visibility::Read,
};
let out = OutputInfo::Array {
item: gpu::Elem::Float.into(),
};
let info = CompilationInfo {
inputs: vec![lhs, rhs],
outputs: vec![out],
scope,
};
let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new(
self.config.grid_x as u32,
self.config.grid_y as u32,
1,
));
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
format!(
"{:?}config={:?}boundcheck={:?}",
core::any::TypeId::of::<Self>(),
self.config,
self.bounds_check_required
)
}
}
/// Matrix multiplication using the tiling 2d algorithm with
/// vec4 primitives on both lhs and rhs; no padding needed.
pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), bounds_check_required);
let client = lhs.client.clone();
let lhs = match lhs.batch_swapped_with_row_col() {
true => into_contiguous(lhs),
false => lhs,
};
let rhs = match rhs.batch_swapped_with_row_col() {
true => into_contiguous(rhs),
false => rhs,
};
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
])
.outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)])
.execute(WorkgroupLaunch::Custom(tiling2d_launch_options(
&out.shape, config,
)));
out
}
/// Matrix multiplication using the tiling 2d algorithm, padding the inputs as needed.
pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement + Element, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), false);
let client = lhs.client.clone();
// A tensor may need to be padded, in which case it will implicitly become contiguous
// If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim.
// If batches were swapped among themselves, or if the last two dims are transposed, the underlying
// kernel handles it without needing to turn it into contiguous.
let round_lhs = pad_round::<R, E, D>(lhs, config.block_size_m, config.block_size_k);
let lhs = match round_lhs {
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
into_contiguous(tensor)
}
_ => round_lhs.into_tensor(),
};
let round_rhs = pad_round::<R, E, D>(rhs, config.block_size_k, config.block_size_n);
let rhs = match round_rhs {
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
into_contiguous(tensor)
}
_ => round_rhs.into_tensor(),
};
let rounded_output_shape = shape_out(&lhs, &rhs);
let num_elems = rounded_output_shape.num_elements();
let buffer = client.empty(num_elems * core::mem::size_of::<E>());
let rounded_output = JitTensor::new(
rhs.client.clone(),
rhs.device.clone(),
rounded_output_shape.clone(),
buffer,
);
Execution::start(kernel, client)
.inputs(&[
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
])
.outputs(&[EagerHandle::new(
&rounded_output.handle,
&rounded_output.strides,
&rounded_output.shape.dims,
)])
.execute(WorkgroupLaunch::Custom(tiling2d_launch_options(
&rounded_output.shape,
config,
)));
crop(rounded_output, out)
}
fn check_bound_requirement<const D: usize>(
lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>,
config: &Tiling2dConfig,
) -> bool {
lhs_shape.dims[D - 2] % config.block_size_m != 0
|| lhs_shape.dims[D - 1] % config.block_size_k != 0
|| rhs_shape.dims[D - 1] % config.block_size_n != 0
}
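
A quick illustration, not part of this commit, of when the bounds-checked shader variant is selected: with the default block sizes, any m, k, or n that is not a multiple of the corresponding block size triggers it, mirroring check_bound_requirement above.

let config = Tiling2dConfig::default(); // block_size_m = 64, block_size_k = 32, block_size_n = 64
// lhs: (batch, 100, 64), rhs: (batch, 64, 64)
let needs_bounds_check = 100 % config.block_size_m != 0 // rows of lhs
    || 64 % config.block_size_k != 0                    // k dimension
    || 64 % config.block_size_n != 0;                   // columns of rhs
assert!(needs_bounds_check); // 100 % 64 != 0, so the checked loads are used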

View File

@ -1,91 +0,0 @@
use super::padding::{crop, pad_round, PaddingOutput};
use crate::{
compute::{DynamicKernel, WorkGroup},
element::JitElement,
kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
use burn_compute::server::Handle;
use burn_tensor::Shape;
pub(crate) const B_M: usize = 64;
pub(crate) const B_N: usize = 64;
pub(crate) const B_K: usize = 32;
pub(crate) const WORKGROUP_SIZE: usize = 16;
pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32;
let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32;
let mut num_blocks_z = 1;
for i in 0..D - 2 {
num_blocks_z *= output_shape.dims[i];
}
WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
}
pub(super) fn make_info_handle<R: Runtime, E: JitElement, const D: usize>(
lhs: &JitTensor<R, E, D>,
rhs: &JitTensor<R, E, D>,
output: &JitTensor<R, E, D>,
) -> Handle<R::Server> {
let info = build_info(&[lhs, rhs, output]);
rhs.client.create(bytemuck::cast_slice(&info))
}
#[allow(clippy::too_many_arguments)]
pub(super) fn matmul_tiling_2d_launch<
R: Runtime,
E: JitElement,
const D: usize,
K: DynamicKernelSource + 'static,
>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
output: JitTensor<R, E, D>,
kernel: K,
) -> JitTensor<R, E, D> {
// A tensor may need to be padded, in which case it will implicitly become contiguous
// If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim.
// If batches were swapped among themselves, or if the last two dims are transposed, the underlying
// kernel handles it without needing to turn it into contiguous.
let round_lhs = pad_round::<R, E, D>(lhs, B_M, B_K);
let lhs = match round_lhs {
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
into_contiguous(tensor)
}
_ => round_lhs.into_tensor(),
};
let round_rhs = pad_round::<R, E, D>(rhs, B_K, B_N);
let rhs = match round_rhs {
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
into_contiguous(tensor)
}
_ => round_rhs.into_tensor(),
};
let rounded_output_shape = shape_out(&lhs, &rhs);
let rounded_output = empty_device(
rhs.client.clone(),
rhs.device.clone(),
rounded_output_shape.clone(),
);
let workgroup = make_workgroup(&rounded_output_shape);
let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
lhs.client.execute(
Box::new(DynamicKernel::new(kernel, workgroup)),
&[
&lhs.handle,
&rhs.handle,
&rounded_output.handle,
&info_handle,
],
);
crop(rounded_output, output)
}

View File

@ -1,14 +0,0 @@
mod base;
#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
pub mod padding;
#[cfg(not(feature = "export_tests"))]
mod padding;
/// WGSL vec4 primitives are used on left and right hand tensor,
/// padding is avoided through the use of conditions in the kernel
pub mod unpadded;
/// WGSL vec4 primitives are used on left and right hand tensor
pub mod vec4;

View File

@ -1,74 +0,0 @@
use burn_tensor::Element;
use crate::{
compute::DynamicKernel,
element::JitElement,
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource},
tensor::JitTensor,
Runtime,
};
use std::marker::PhantomData;
use crate::kernel_wgsl;
use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE};
kernel_wgsl!(
MatmulTiling2DUnpaddedRaw,
"../../../template/matmul/blocktiling_2d/unpadded.wgsl"
);
#[derive(new, Debug)]
struct MatmulTiling2DUnpadded<E: JitElement> {
_elem: PhantomData<E>,
}
impl<E: JitElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
fn source(&self) -> SourceTemplate {
MatmulTiling2DUnpaddedRaw::source()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk_4", (B_M * B_K / 4).to_string())
.register("bk_x_bn_4", (B_K * B_N / 4).to_string())
.register("workgroup_size_x", WORKGROUP_SIZE.to_string())
.register("workgroup_size_y", WORKGROUP_SIZE.to_string())
.register("workgroup_size_z", "1".to_string())
.register("elem", E::type_name())
.register("int", "i32")
}
fn id(&self) -> String {
std::format!("{:?}", self)
}
}
/// Matrix multiplication using tiling 2d algorithm with
/// vec4 primitive on both lhs and rhs, with no padding needed
pub fn matmul_tiling_2d_unpadded<R: Runtime, E: JitElement + Element, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
let lhs = match lhs.batch_swapped_with_row_col() {
true => into_contiguous(lhs),
false => lhs,
};
let rhs = match rhs.batch_swapped_with_row_col() {
true => into_contiguous(rhs),
false => rhs,
};
let workgroup = make_workgroup(&out.shape);
let info_handle = make_info_handle(&lhs, &rhs, &out);
lhs.client.execute(
Box::new(DynamicKernel::new(
MatmulTiling2DUnpadded::<E>::new(),
workgroup,
)),
&[&lhs.handle, &rhs.handle, &out.handle, &info_handle],
);
out
}

View File

@ -1,49 +0,0 @@
use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};
use crate::{
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
tensor::JitTensor,
};
use crate::{kernel_wgsl, Runtime};
use std::marker::PhantomData;
kernel_wgsl!(
MatmulTiling2Dvec4Raw,
"../../../template/matmul/blocktiling_2d/vec4.wgsl"
);
#[derive(new, Debug)]
struct MatmulTiling2Dvec4<E: JitElement> {
_elem: PhantomData<E>,
}
impl<E: JitElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
fn source(&self) -> SourceTemplate {
MatmulTiling2Dvec4Raw::source()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk_4", (B_M * B_K / 4).to_string())
.register("bk_x_bn_4", (B_K * B_N / 4).to_string())
.register("workgroup_size_x", WORKGROUP_SIZE.to_string())
.register("workgroup_size_y", WORKGROUP_SIZE.to_string())
.register("workgroup_size_z", "1".to_string())
.register("elem", E::type_name())
.register("int", "i32")
}
fn id(&self) -> String {
std::format!("{:?}", self)
}
}
/// Matrix multiplication using tiling 2d algorithm with
/// vec4 primitive on both lhs and rhs
pub fn matmul_tiling_2d_vec4<R: Runtime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
let kernel = MatmulTiling2Dvec4::<E>::new();
matmul_tiling_2d_launch::<R, _, D, _>(lhs, rhs, out, kernel)
}

View File

@ -0,0 +1,68 @@
use crate::gpu::{gpu, BinaryOperator, Scope, Synchronization, Variable};
use crate::kernel::matmul::tiling2d_shader::{
computation_loop, gather_shader_information, load_shared_memory, write_to_output,
};
use crate::kernel::matmul::Tiling2dConfig;
pub(crate) struct MatmulTiling2dShader {
pub variables: BinaryOperator,
pub config: Tiling2dConfig,
pub bounds_check_required: bool,
pub unroll: bool,
}
pub(crate) struct Tiling2dState {
pub n_loops: Variable,
pub k: Variable,
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
pub offset_lhs: Variable,
pub offset_rhs: Variable,
pub offset_output: Variable,
pub row: Variable,
pub col: Variable,
pub dim_m: Variable,
pub dim_k: Variable,
pub dim_n: Variable,
pub thread_col: Variable,
pub thread_row: Variable,
pub shared_lhs: Variable,
pub shared_rhs: Variable,
pub register_m: Variable,
pub register_n: Variable,
pub results: Variable,
pub lhs_stride_col: Variable,
pub lhs_stride_row: Variable,
pub rhs_stride_col: Variable,
pub rhs_stride_row: Variable,
pub out_stride_row: Variable,
pub out_stride_col: Variable,
}
impl MatmulTiling2dShader {
pub(crate) fn expand(self, scope: &mut Scope) {
let shader_state = gather_shader_information(scope, &self);
let block_size_k: Variable = self.config.block_size_k.into();
gpu!(
scope,
range(0u32, shader_state.n_loops).for_each(|i, scope| {
// From 0 to K in steps of block_size_k
let k = shader_state.k;
gpu!(scope, k = i * block_size_k);
load_shared_memory(scope, &self, &shader_state);
scope.register(Synchronization::WorkgroupBarrier);
computation_loop(scope, &self, &shader_state);
scope.register(Synchronization::WorkgroupBarrier);
})
);
write_to_output(scope, &self, &shader_state);
}
}

View File

@ -0,0 +1,82 @@
use crate::gpu::{gpu, Elem, Scope, Variable};
use super::{MatmulTiling2dShader, Tiling2dState};
#[allow(clippy::too_many_arguments)]
pub fn computation_loop(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
) {
let thread_col = shader_state.thread_col;
let thread_row = shader_state.thread_row;
let shared_lhs = shader_state.shared_lhs;
let shared_rhs = shader_state.shared_rhs;
let register_m = shader_state.register_m;
let register_n = shader_state.register_n;
let results = shader_state.results;
let block_size_k: Variable = shader.config.block_size_k.into();
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = results.item().elem();
let lhs_sm_position = scope.create_local(Elem::UInt);
let rhs_sm_position = scope.create_local(Elem::UInt);
let registered_m = scope.create_local(elem);
let registered_n = scope.create_local(elem);
let multiplied = scope.create_local(elem);
let results_position = scope.create_local(Elem::UInt);
let results_before = scope.create_local(elem);
let results_after = scope.create_local(elem);
gpu!(
scope,
range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each(
|dot_index, scope| {
// Load a subcolumn of values from lhs
gpu!(scope, lhs_sm_position = thread_row / 4u32);
gpu!(scope, lhs_sm_position *= block_size_k);
gpu!(scope, lhs_sm_position += dot_index);
gpu!(scope, register_m = shared_lhs[lhs_sm_position]);
// Load a subrow of values from rhs
gpu!(scope, rhs_sm_position = dot_index * block_size_n);
gpu!(scope, rhs_sm_position += thread_col);
gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32);
gpu!(scope, register_n = shared_rhs[rhs_sm_position]);
gpu!(
scope,
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|res_idx_m, scope| {
gpu!(
scope,
range(0u32, shader.config.tile_size_n as u32, shader.unroll)
.for_each(|res_idx_n, scope| {
gpu!(scope, registered_m = register_m[res_idx_m]);
gpu!(scope, registered_n = register_n[res_idx_n]);
gpu!(scope, multiplied = registered_m * registered_n);
gpu!(
scope,
results_position =
res_idx_m * shader.config.tile_size_n
);
gpu!(scope, results_position += res_idx_n);
gpu!(scope, results_before = results[results_position]);
gpu!(scope, results_after = results_before + multiplied);
gpu!(scope, results[results_position] = results_after);
})
);
}
)
);
}
)
);
}
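
For intuition, the index arithmetic above with the default config (block_size_k = 32, block_size_n = 64, tile sizes 4): shared memories are vec4-packed, hence the divisions by 4, and the 4x4 accumulator is addressed row-major. Illustrative values only, not part of this commit.

let (thread_row, thread_col, dot_index) = (8u32, 12u32, 5u32);
let lhs_sm_position = (thread_row / 4) * 32 + dot_index; // (8 / 4) * 32 + 5 = 69
let rhs_sm_position = (dot_index * 64 + thread_col) / 4; // (5 * 64 + 12) / 4 = 83
let results_position = |m: u32, n: u32| m * 4 + n;       // row-major 4x4 accumulator
assert_eq!(lhs_sm_position, 69);
assert_eq!(rhs_sm_position, 83);
assert_eq!(results_position(2, 3), 11);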

View File

@ -0,0 +1,278 @@
use crate::gpu::{gpu, Elem, Scope, Variable};
use super::{MatmulTiling2dShader, Tiling2dState};
enum InputIdentifier {
Lhs,
Rhs,
}
pub(crate) fn load_shared_memory(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
) {
if shader.bounds_check_required {
load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
} else {
load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
}
}
#[allow(clippy::too_many_arguments)]
fn load_shared_memory_with_bound_check(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
input_identifier: InputIdentifier,
) {
let (
input,
input_offset,
shared_memory,
thread_idx_1,
thread_idx_2,
stride_1,
stride_2,
dim,
pos_in_dim,
) = match input_identifier {
InputIdentifier::Lhs => (
shader_state.lhs,
shader_state.offset_lhs,
shader_state.shared_lhs,
shader_state.thread_col,
shader_state.thread_row,
shader_state.lhs_stride_col,
shader_state.lhs_stride_row,
shader_state.dim_m,
shader_state.row,
),
InputIdentifier::Rhs => (
shader_state.rhs,
shader_state.offset_rhs,
shader_state.shared_rhs,
shader_state.thread_row,
shader_state.thread_col,
shader_state.rhs_stride_row,
shader_state.rhs_stride_col,
shader_state.dim_n,
shader_state.col,
),
};
let k = shader_state.k;
let dim_k = shader_state.dim_k;
// How close the thread is to the end of the matrix.
// If remain < 4 it is an edge case
let remain = scope.create_local(Elem::UInt);
gpu!(scope, remain = dim - pos_in_dim);
let block_size_k: Variable = shader.config.block_size_k.into();
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = input.item().elem();
let current = scope.create_local(Elem::UInt);
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
let sm_position = scope.create_local(Elem::UInt);
let within_input = scope.create_local(Elem::Bool);
let current_with_k = scope.create_local(Elem::UInt);
let remain_at_least_1 = scope.create_local(Elem::Bool);
let read_condition = scope.create_local(Elem::Bool);
let val_vec4 = scope.create_local(shared_memory.item());
let tmp = scope.create_local(Elem::UInt);
let position_0 = scope.create_local(Elem::UInt);
let position_1 = scope.create_local(Elem::UInt);
let position_2 = scope.create_local(Elem::UInt);
let position_3 = scope.create_local(Elem::UInt);
let remain_n = scope.create_local(Elem::Bool);
let val_0 = scope.create_local(elem);
let val_1 = scope.create_local(elem);
let val_2 = scope.create_local(elem);
let val_3 = scope.create_local(elem);
let zero: Variable = 0u32.into();
gpu!(
scope,
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
gpu!(scope, current = thread_idx_1 + j);
gpu!(scope, aligned_with_shared_memory = current < block_size_k);
// To avoid overwriting the following row in shared memory
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
// Position in shared memory
match input_identifier {
InputIdentifier::Lhs => {
gpu!(scope, sm_position = thread_idx_2 / 4u32);
gpu!(scope, sm_position *= block_size_k);
gpu!(scope, sm_position += current);
},
InputIdentifier::Rhs => {
gpu!(scope, sm_position = current * block_size_n);
gpu!(scope, sm_position += thread_idx_2);
gpu!(scope, sm_position = sm_position / 4u32);
}
}
// To pad with zeros if outside the input
gpu!(scope, current_with_k = current + k);
gpu!(scope, within_input = current_with_k < dim_k);
gpu!(scope, remain_at_least_1 = remain >= 1u32);
gpu!(scope, read_condition = within_input && remain_at_least_1);
gpu!(scope, if(read_condition).then(|scope| {
gpu!(scope, position_0 = k + current);
gpu!(scope, position_0 *= stride_1);
gpu!(scope, tmp = thread_idx_2 * stride_2);
gpu!(scope, position_0 += tmp);
gpu!(scope, position_0 += input_offset);
gpu!(scope, position_1 = position_0 + stride_2);
gpu!(scope, position_2 = position_1 + stride_2);
gpu!(scope, position_3 = position_2 + stride_2);
gpu!(scope, remain_n = remain >= 4u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);
gpu!(scope, val_3 = input[position_3]);
}).else(|scope|{
gpu!(scope, remain_n = remain == 3u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);
gpu!(scope, val_3 = zero);
}).else(|scope|{
gpu!(scope, remain_n = remain == 2u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = zero);
gpu!(scope, val_3 = zero);
}).else(|scope|{
gpu!(scope, remain_n = remain == 1u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = zero);
gpu!(scope, val_2 = zero);
gpu!(scope, val_3 = zero);
}));
}));
}));
}));
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
gpu!(scope, shared_memory[sm_position] = val_vec4);
}).else(|scope|{
gpu!(scope, val_0 = zero);
gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
gpu!(scope, shared_memory[sm_position] = val_vec4);
}));
}));
})
);
}
#[allow(clippy::too_many_arguments)]
fn load_shared_memory_no_bound_check(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
input_identifier: InputIdentifier,
) {
let (input, input_offset, shared_memory, thread_idx_1, thread_idx_2, stride_1, stride_2) =
match input_identifier {
InputIdentifier::Lhs => (
shader_state.lhs,
shader_state.offset_lhs,
shader_state.shared_lhs,
shader_state.thread_col,
shader_state.thread_row,
shader_state.lhs_stride_col,
shader_state.lhs_stride_row,
),
InputIdentifier::Rhs => (
shader_state.rhs,
shader_state.offset_rhs,
shader_state.shared_rhs,
shader_state.thread_row,
shader_state.thread_col,
shader_state.rhs_stride_row,
shader_state.rhs_stride_col,
),
};
let k = shader_state.k;
let block_size_k: Variable = shader.config.block_size_k.into();
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = input.item().elem();
let current = scope.create_local(Elem::UInt);
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
let sm_position = scope.create_local(Elem::UInt);
let tmp = scope.create_local(Elem::UInt);
let position_0 = scope.create_local(Elem::UInt);
let position_1 = scope.create_local(Elem::UInt);
let position_2 = scope.create_local(Elem::UInt);
let position_3 = scope.create_local(Elem::UInt);
let val_0 = scope.create_local(elem);
let val_1 = scope.create_local(elem);
let val_2 = scope.create_local(elem);
let val_3 = scope.create_local(elem);
let val_vec4 = scope.create_local(shared_memory.item());
gpu!(
scope,
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
gpu!(scope, current = thread_idx_1 + j);
gpu!(scope, aligned_with_shared_memory = current < block_size_k);
// To avoid overwriting the following row in shared memory
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
match input_identifier {
InputIdentifier::Lhs => {
gpu!(scope, sm_position = thread_idx_2 / 4u32);
gpu!(scope, sm_position *= block_size_k);
gpu!(scope, sm_position += current);
},
InputIdentifier::Rhs => {
gpu!(scope, sm_position = current * block_size_n);
gpu!(scope, sm_position += thread_idx_2);
gpu!(scope, sm_position = sm_position / 4u32);
}
}
gpu!(scope, position_0 = k + current);
gpu!(scope, position_0 *= stride_1);
gpu!(scope, tmp = thread_idx_2 * stride_2);
gpu!(scope, position_0 += tmp);
gpu!(scope, position_0 += input_offset);
gpu!(scope, position_1 = position_0 + stride_2);
gpu!(scope, position_2 = position_1 + stride_2);
gpu!(scope, position_3 = position_2 + stride_2);
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);
gpu!(scope, val_3 = input[position_3]);
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
gpu!(scope, shared_memory[sm_position] = val_vec4);
}));
})
);
}
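
Edge-case illustration for the bounds-checked path, not part of this commit: remain measures how many rows or columns are left before the end of the matrix, and fewer than four means the tail lanes of the vec4 are zero-filled.

let (dim, pos_in_dim) = (70u32, 68u32);
let remain = dim - pos_in_dim;  // 2
let lanes_read = remain.min(4); // val_0 and val_1 come from the input,
assert_eq!(lanes_read, 2);      // val_2 and val_3 are padded with zeros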

View File

@ -0,0 +1,11 @@
mod base;
mod computation;
mod load_shared_memory;
mod shader_information;
mod write_output;
pub(crate) use base::*;
pub(crate) use computation::*;
pub(crate) use load_shared_memory::*;
pub(crate) use shader_information::*;
pub(crate) use write_output::*;

View File

@ -0,0 +1,180 @@
use crate::gpu::{gpu, Elem, Item, Scope, Variable};
use super::{MatmulTiling2dShader, Tiling2dState};
pub(crate) fn gather_shader_information(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
) -> Tiling2dState {
// Inputs
let lhs = shader.variables.lhs;
let rhs = shader.variables.rhs;
let out = shader.variables.out;
// Config variables
let block_size_m: Variable = shader.config.block_size_m.into();
let block_size_k: Variable = shader.config.block_size_k.into();
let block_size_n: Variable = shader.config.block_size_n.into();
let tile_size_m: Variable = shader.config.tile_size_m.into();
let tile_size_n: Variable = shader.config.tile_size_n.into();
let n_threads_per_row: Variable =
(((shader.config.block_size_n - 1) / shader.config.tile_size_n) + 1).into();
let results_size = (shader.config.tile_size_m * shader.config.tile_size_n) as u32;
// Shader info
let local_idx = Variable::LocalInvocationIndex;
let batch = Variable::GlobalInvocationIdZ;
// Shapes
let rank = Variable::Rank;
let last_dim = scope.create_local(Elem::UInt);
let second_to_last_dim = scope.create_local(Elem::UInt);
let dim_m = scope.create_local(Elem::UInt);
let dim_k = scope.create_local(Elem::UInt);
let dim_n = scope.create_local(Elem::UInt);
gpu!(scope, last_dim = rank - 1u32);
gpu!(scope, second_to_last_dim = rank - 2u32);
gpu!(scope, dim_m = shape(lhs, second_to_last_dim));
gpu!(scope, dim_k = shape(lhs, last_dim));
gpu!(scope, dim_n = shape(rhs, last_dim));
// Strides
let lhs_stride_row = scope.create_local(Elem::UInt);
let lhs_stride_col = scope.create_local(Elem::UInt);
let rhs_stride_row = scope.create_local(Elem::UInt);
let rhs_stride_col = scope.create_local(Elem::UInt);
let out_stride_row = scope.create_local(Elem::UInt);
let out_stride_col = scope.create_local(Elem::UInt);
gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim));
gpu!(scope, lhs_stride_col = stride(lhs, last_dim));
gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim));
gpu!(scope, rhs_stride_col = stride(rhs, last_dim));
gpu!(scope, out_stride_row = stride(out, second_to_last_dim));
gpu!(scope, out_stride_col = stride(out, last_dim));
// Workgroup offset
let skip_row = scope.create_local(Elem::UInt);
let skip_col = scope.create_local(Elem::UInt);
let workgroup_id_x = Variable::WorkgroupIdX;
let workgroup_id_y = Variable::WorkgroupIdY;
gpu!(scope, skip_row = workgroup_id_x);
gpu!(scope, skip_row *= block_size_m);
gpu!(scope, skip_col = workgroup_id_y);
gpu!(scope, skip_col *= block_size_n);
// Position of the first element of the thread, relative to the block
let thread_row = scope.create_local(Elem::UInt);
let thread_col = scope.create_local(Elem::UInt);
gpu!(scope, thread_row = local_idx / n_threads_per_row);
gpu!(scope, thread_row *= tile_size_m);
gpu!(scope, thread_col = local_idx % n_threads_per_row);
gpu!(scope, thread_col *= tile_size_n);
// Position of the first element of the thread, absolute within one batch
let row = scope.create_local(Elem::UInt);
let col = scope.create_local(Elem::UInt);
gpu!(scope, row = skip_row + thread_row);
gpu!(scope, col = skip_col + thread_col);
// Calculate the block's base offsets into lhs and rhs.
let offset_lhs = scope.create_local(Elem::UInt);
let offset_rhs = scope.create_local(Elem::UInt);
gpu!(scope, offset_lhs = skip_row * lhs_stride_row);
gpu!(scope, offset_rhs = skip_col * rhs_stride_col);
// Batch offset for the output.
let offset_output = scope.create_local(Elem::UInt);
let batch_dims = scope.create_local(Elem::UInt);
gpu!(scope, offset_output = dim_m * dim_n);
gpu!(scope, offset_output = offset_output * batch);
// Batch offset for the lhs & rhs matrices.
let stride_lhs = scope.create_local(Elem::UInt);
let stride_rhs = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_lhs = scope.create_local(Elem::UInt);
let shape_rhs = scope.create_local(Elem::UInt);
let tmp = scope.create_local(Elem::UInt);
let tmp_lhs = scope.create_local(Elem::UInt);
let tmp_rhs = scope.create_local(Elem::UInt);
gpu!(scope, batch_dims = rank - 2u32);
gpu!(
scope,
range(0u32, batch_dims).for_each(|b, scope| {
gpu!(scope, stride_lhs = stride(lhs, b));
gpu!(scope, stride_rhs = stride(rhs, b));
gpu!(scope, stride_output = stride(out, b));
gpu!(scope, shape_lhs = shape(lhs, b));
gpu!(scope, shape_rhs = shape(rhs, b));
gpu!(scope, tmp = offset_output / stride_output);
gpu!(scope, tmp_lhs = tmp % shape_lhs);
gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs);
gpu!(scope, offset_lhs += tmp_lhs);
gpu!(scope, tmp_rhs = tmp % shape_rhs);
gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs);
gpu!(scope, offset_rhs += tmp_rhs);
})
);
let elem = lhs.item().elem();
// Registers used in the compute pass
let results = scope.create_local_array(elem, results_size);
let register_m = scope.create_local(Item::Vec4(elem));
let register_n = scope.create_local(Item::Vec4(elem));
let shared_lhs = scope.create_shared(
Item::Vec4(elem),
shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32,
);
let shared_rhs = scope.create_shared(
Item::Vec4(elem),
shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
);
// Calculate exact number of loop iterations
let n_loops = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);
if shader.bounds_check_required {
let dim_k_float = scope.create_local(elem);
let block_size_k_float = scope.create_local(elem);
let n_loops_float = scope.create_local(elem);
gpu!(scope, dim_k_float = dim_k);
gpu!(scope, block_size_k_float = block_size_k);
gpu!(scope, n_loops_float = dim_k_float / block_size_k_float);
gpu!(scope, n_loops_float = ceil(n_loops_float));
gpu!(scope, n_loops = n_loops_float);
} else {
gpu!(scope, n_loops = dim_k / block_size_k);
}
Tiling2dState {
n_loops,
k,
lhs,
rhs,
out,
offset_lhs,
offset_rhs,
offset_output,
row,
col,
dim_m,
dim_k,
dim_n,
thread_col,
thread_row,
shared_lhs,
shared_rhs,
register_m,
register_n,
results,
lhs_stride_col,
lhs_stride_row,
rhs_stride_col,
rhs_stride_row,
out_stride_row,
out_stride_col,
}
}
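
Thread-position illustration with the default config, not part of this commit: n_threads_per_row = ((64 - 1) / 4) + 1 = 16, so the 256 invocations of a workgroup tile the 64x64 block in 4x4 pieces. Illustrative values only.

let local_idx = 35u32;
let n_threads_per_row = ((64 - 1) / 4) + 1;           // 16
let thread_row = (local_idx / n_threads_per_row) * 4; // (35 / 16) * 4 = 8
let thread_col = (local_idx % n_threads_per_row) * 4; // (35 % 16) * 4 = 12
assert_eq!((thread_row, thread_col), (8, 12));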

View File

@ -0,0 +1,121 @@
use crate::gpu::{gpu, Elem, Scope, Variable};
use super::{MatmulTiling2dShader, Tiling2dState};
#[allow(clippy::too_many_arguments)]
pub fn write_to_output(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
) {
let row = shader_state.row;
let col = shader_state.col;
let row_index = scope.create_local(Elem::UInt);
let col_index = scope.create_local(Elem::UInt);
if shader.bounds_check_required {
let dim_m = shader_state.dim_m;
let dim_n = shader_state.dim_n;
let within_output = scope.create_local(Elem::Bool);
let within_output_tmp = scope.create_local(Elem::Bool);
gpu!(
scope,
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|res_idx_m, scope| {
gpu!(
scope,
range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
|res_idx_n, scope| {
gpu!(scope, row_index = row + res_idx_m);
gpu!(scope, col_index = col + res_idx_n);
gpu!(scope, within_output = row_index < dim_m);
gpu!(scope, within_output_tmp = col_index < dim_n);
gpu!(scope, within_output = within_output && within_output_tmp);
gpu!(scope, if(within_output).then(|scope|{
write_inner(
scope,
shader,
shader_state,
res_idx_m,
res_idx_n,
row_index,
col_index,
);
}));
}
)
);
}
)
);
} else {
gpu!(
scope,
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|res_idx_m, scope| {
gpu!(
scope,
range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
|res_idx_n, scope| {
gpu!(scope, row_index = row + res_idx_m);
gpu!(scope, col_index = col + res_idx_n);
write_inner(
scope,
shader,
shader_state,
res_idx_m,
res_idx_n,
row_index,
col_index,
)
}
)
);
}
)
);
}
}
#[allow(clippy::too_many_arguments)]
fn write_inner(
scope: &mut Scope,
shader: &MatmulTiling2dShader,
shader_state: &Tiling2dState,
res_idx_m: Variable,
res_idx_n: Variable,
row_index: Variable,
col_index: Variable,
) {
let offset_output = shader_state.offset_output;
let out = shader_state.out;
let out_stride_row = shader_state.out_stride_row;
let out_stride_col = shader_state.out_stride_col;
let results = shader_state.results;
let elem = results.item().elem();
let results_position = scope.create_local(Elem::UInt);
let result = scope.create_local(elem);
let output_position = scope.create_local(Elem::UInt);
gpu!(
scope,
results_position = res_idx_m * shader.config.tile_size_n
);
gpu!(scope, results_position += res_idx_n);
gpu!(scope, result = results[results_position]);
gpu!(scope, row_index *= out_stride_row);
gpu!(scope, col_index *= out_stride_col);
gpu!(scope, output_position = row_index + col_index);
gpu!(scope, output_position += offset_output);
gpu!(scope, out[output_position] = result);
}

View File

@ -4,7 +4,10 @@ use burn_tensor::{Element, ElementConversion};
use crate::{
compute::JitAutotuneKey,
element::JitElement,
kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform},
kernel::{
matmul::{utils::init_matmul_output, Tiling2dConfig},
prng::random_like_uniform,
},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
@ -50,22 +53,14 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J
);
vec![
Box::new(MemoryCoalescingMatmulDefault::new(
Box::new(SimpleMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
Box::new(SimpleMatmul16x16::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(MemoryCoalescingMatmulW16x16::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(Vec4TilingMatmulDefault::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(Vec4TilingMatmulUnpaddedDefault::new(
Box::new(Tiling2dMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
Box::new(Tiling2dMatmulPadded::new(
lhs.clone(),
rhs.clone(),
out.clone(),
@ -75,16 +70,10 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
match fastest_index {
0 => Box::new(MemoryCoalescingMatmulDefault::new(
self.lhs, self.rhs, self.out,
)),
1 => Box::new(MemoryCoalescingMatmulW16x16::new(
self.lhs, self.rhs, self.out,
)),
2 => Box::new(Vec4TilingMatmulDefault::new(self.lhs, self.rhs, self.out)),
3 => Box::new(Vec4TilingMatmulUnpaddedDefault::new(
self.lhs, self.rhs, self.out,
)),
0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)),
1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)),
2 => Box::new(Tiling2dMatmul::new(self.lhs, self.rhs, self.out)),
3 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)),
_ => panic!("Fastest index is out of bound"),
}
}
@ -134,23 +123,21 @@ macro_rules! matmul_tune_ops {
// Potentially better for small matrices.
matmul_tune_ops!(
MemoryCoalescingMatmulDefault,
SimpleMatmul,
crate::kernel::matmul::matmul_mem_coalescing_default
);
// Potentially better for small matrices.
matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| {
crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16)
});
// Probably the fastest when fixed sizes.
matmul_tune_ops!(
Vec4TilingMatmulDefault,
crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
);
matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| {
crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default())
});
// Probably the fastest otherwise.
matmul_tune_ops!(
Vec4TilingMatmulUnpaddedDefault,
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
);
// Probably the fastest in the general case
matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| {
crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default())
});

View File

@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(matmul)]
mod tests {
use super::*;
use burn_jit::kernel::matmul::{matmul, MatmulStrategy};
use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig};
use burn_tensor::{Shape, Tensor};
mod simple {
@ -174,7 +174,7 @@ mod tests {
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2dPadded,
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
swap,
swap,
shape_lhs,
@ -189,7 +189,7 @@ mod tests {
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2dPadded,
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
swap_lhs,
swap_rhs,
shape_lhs,
@ -204,7 +204,7 @@ mod tests {
let shape_lhs = [4, 4, 4, 4];
let shape_rhs = [4, 4, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2dPadded,
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
swap_lhs,
swap_rhs,
shape_lhs,
@ -212,10 +212,56 @@ mod tests {
);
}
#[test]
fn stable_test() {
let ref_tensor_device = Default::default();
let x = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device);
let y =
ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device);
let test_tensor_device = Default::default();
let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device);
let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device);
let z_reference = x.matmul(y);
let z = Tensor::<TestBackend, 2>::from_primitive(matmul(
x_jit.into_primitive(),
y_jit.into_primitive(),
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
));
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
#[test]
fn stable_test_2() {
let ref_tensor_device = Default::default();
let x =
ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device);
let y = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device);
let test_tensor_device = Default::default();
let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device);
let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device);
let z_reference = x.matmul(y);
let z = Tensor::<TestBackend, 2>::from_primitive(matmul(
x_jit.into_primitive(),
y_jit.into_primitive(),
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
));
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
let shape_lhs = [batch_1, batch_2, m, k];
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(MatmulStrategy::Tiling2dPadded, shape_lhs, shape_rhs);
same_as_reference(
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
shape_lhs,
shape_rhs,
);
}
}
@ -308,7 +354,7 @@ mod tests {
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2d,
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
swap,
swap,
shape_lhs,
@ -323,7 +369,7 @@ mod tests {
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2d,
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
swap_lhs,
swap_rhs,
shape_lhs,
@ -338,7 +384,7 @@ mod tests {
let shape_lhs = [4, 4, 4, 4];
let shape_rhs = [4, 4, 4, 4];
same_as_reference_swapped_dims(
MatmulStrategy::Tiling2d,
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
swap_lhs,
swap_rhs,
shape_lhs,
@ -349,7 +395,11 @@ mod tests {
fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
let shape_lhs = [batch_1, batch_2, m, k];
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(MatmulStrategy::Tiling2d, shape_lhs, shape_rhs);
same_as_reference(
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
shape_lhs,
shape_rhs,
);
}
}

View File

@ -18,6 +18,7 @@ pub enum Variable {
scope_depth: u8,
},
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
Id,
LocalInvocationIndex,
LocalInvocationIdX,
@ -79,6 +80,7 @@ impl Variable {
Variable::GlobalInputArray(_, _) => false,
Variable::GlobalOutputArray(_, _) => false,
Variable::SharedMemory(_, _, _) => false,
Variable::LocalArray(_, _, _, _) => false,
Variable::Local {
index: _,
item: _,
@ -110,6 +112,7 @@ impl Variable {
Self::GlobalInputArray(_, e) => *e,
Self::GlobalOutputArray(_, e) => *e,
Self::SharedMemory(_, e, _) => *e,
Self::LocalArray(_, e, _, _) => *e,
Self::Local {
index: _,
item,
@ -222,6 +225,9 @@ impl Display for Variable {
Variable::SharedMemory(number, _, _) => {
f.write_fmt(format_args!("shared_memory_{number}"))
}
Variable::LocalArray(number, _, scope_depth, _) => {
f.write_fmt(format_args!("a_{scope_depth}_{number}"))
}
Variable::Id => f.write_str("id"),
Variable::LocalInvocationIndex => f.write_str("local_idx"),
Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"),

View File

@ -1,3 +1,4 @@
use super::LocalArray;
use super::{shader::ComputeShader, Item, SharedMemory};
use crate::compiler::wgsl;
use crate::{FloatElement, IntElement};
@ -19,6 +20,7 @@ pub struct WgslCompiler<F: FloatElement, I: IntElement> {
shape: bool,
num_workgroups: bool,
shared_memories: Vec<SharedMemory>,
local_arrays: Vec<LocalArray>,
_float: PhantomData<F>,
_int: PhantomData<I>,
}
@ -44,6 +46,7 @@ impl<F: FloatElement, I: IntElement> Default for WgslCompiler<F, I> {
shape: false,
num_workgroups: false,
shared_memories: Vec::default(),
local_arrays: Vec::default(),
_float: PhantomData,
_int: PhantomData,
}
@ -64,6 +67,10 @@ impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
fn elem_size(elem: gpu::Elem) -> usize {
Self::compile_elem(elem).size()
}
fn max_shared_memory_size() -> usize {
8192
}
}
impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
@ -98,6 +105,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
.map(|(name, binding)| (name, Self::compile_binding(binding)))
.collect(),
shared_memories: self.shared_memories.clone(),
local_arrays: self.local_arrays.clone(),
workgroup_size: value.workgroup_size,
global_invocation_id: self.global_invocation_id || self.id,
local_invocation_index: self.local_invocation_index,
@ -159,6 +167,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
}
wgsl::Variable::SharedMemory(index, item, size)
}
gpu::Variable::LocalArray(index, item, scope_depth, size) => {
let item = Self::compile_item(item);
if !self.local_arrays.iter().any(|s| s.index == index) {
self.local_arrays
.push(LocalArray::new(index, item, scope_depth, size));
}
wgsl::Variable::LocalArray(index, item, scope_depth, size)
}
gpu::Variable::Id => {
self.id = true;
wgsl::Variable::Id
@ -416,6 +432,10 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Ceil(op) => wgsl::Instruction::Ceil {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
gpu::Operator::Log(op) => wgsl::Instruction::Log {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),

View File

@ -109,6 +109,10 @@ pub enum Instruction {
input: Variable,
out: Variable,
},
Ceil {
input: Variable,
out: Variable,
},
Erf {
input: Variable,
out: Variable,
@ -272,6 +276,9 @@ impl Display for Instruction {
Instruction::Sqrt { input, out } => {
f.write_fmt(format_args!("{out} = sqrt({input});\n"))
}
Instruction::Ceil { input, out } => {
f.write_fmt(format_args!("{out} = ceil({input});\n"))
}
Instruction::Log1p { input, out } => {
f.write_fmt(format_args!("{out} = log({input} + 1.0);\n"))
}

View File

@ -5,7 +5,6 @@ use std::fmt::Display;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Location {
Storage,
#[allow(dead_code)]
Workgroup,
}
@ -42,12 +41,32 @@ impl SharedMemory {
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray {
pub index: u16,
item: Item,
name: u8,
size: u32,
}
impl LocalArray {
pub fn new(index: u16, item: Item, name: u8, size: u32) -> Self {
Self {
index,
item,
name,
size,
}
}
}
#[derive(Debug, Clone)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
pub outputs: Vec<Binding>,
pub named: Vec<(String, Binding)>,
pub shared_memories: Vec<SharedMemory>,
pub local_arrays: Vec<LocalArray>,
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub local_invocation_index: bool,
@ -72,10 +91,10 @@ impl Display for ComputeShader {
)?;
}
for shared_memory in self.shared_memories.iter() {
for array in self.shared_memories.iter() {
f.write_fmt(format_args!(
"var<{}> shared_memory_{}: array<{}, {}>;\n\n",
shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size
array.location, array.index, array.item, array.size
))?;
}
@ -115,12 +134,22 @@ fn main(
f.write_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n")?;
}
f.write_fmt(format_args!(
") {{
{}
}}",
self.body
))?;
// Open body
f.write_fmt(format_args!(") {{"))?;
// Local arrays
for array in self.local_arrays.iter() {
f.write_fmt(format_args!(
"var a_{}_{}: array<{}, {}>;\n\n",
array.name, array.index, array.item, array.size
))?;
}
// Body
f.write_fmt(format_args!("{}", self.body))?;
// Close body
f.write_fmt(format_args!("}}"))?;
for extension in self.extensions.iter() {
f.write_fmt(format_args!("{extension}\n\n"))?;