Matmul 2D Tiling (#442)

Louis Fortier-Dubois 2023-06-28 16:48:15 -04:00 committed by GitHub
parent f42176e93a
commit f99fe0fadd
6 changed files with 259 additions and 23 deletions

View File

@@ -55,4 +55,34 @@ mod tests {
Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]])
);
}
#[test]
fn test_matmul_simple_2() {
let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]);
let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]);
let tensor_3 = tensor_1.matmul(tensor_2);
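// Expected dot product: 1*3 + 2*4 + 3*5 + 4*6 = 50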
assert_eq!(tensor_3.into_data(), Data::from([[50.0]]));
}
#[test]
fn test_matmul_simple_3() {
let tensor_1 =
TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
let tensor_2 =
TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]);
let tensor_3 = tensor_1.matmul(tensor_2);
assert_eq!(
tensor_3.into_data(),
Data::from([
[9., 18., 27., 36.],
[12., 24., 36., 48.],
[15., 30., 45., 60.],
[18., 36., 54., 72.]
])
);
}
}

View File

@@ -27,11 +27,14 @@ spin = {workspace = true}
# WGPU stuff
futures-intrusive = {workspace = true}
pollster = {workspace = true}
text_placeholder = {workspace = true}
serde = {workspace = true}
wgpu = {workspace = true}
# Template
serde = {workspace = true}
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
[dev-dependencies]
burn-ndarray = {path = "../burn-ndarray", version = "0.8.0"}
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
"export_tests",
]}
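The text_placeholder dependency gains the struct_context feature here because the kernel templates below substitute block and tile sizes into the WGSL source. A minimal sketch of that kind of placeholder filling, assuming the crate's Template::new / fill_with_hashmap API (struct_context additionally allows a serde-serializable struct as the context); the snippet is illustrative and not part of this commit:

use std::collections::HashMap;
use text_placeholder::Template;

fn main() {
    // Placeholders use the same {{ name }} syntax as the WGSL template below.
    let template = Template::new("const B_M = {{b_m}}u;");

    // Values are registered as strings, mirroring SourceTemplate::register in the kernel.
    let mut context = HashMap::new();
    context.insert("b_m", "128");

    assert_eq!(template.fill_with_hashmap(&context), "const B_M = 128u;");
}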

View File

@@ -51,15 +51,15 @@ impl Context {
/// Create a new context where computing tasks will be executed on the given
/// [device](WgpuDevice).
pub(crate) fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
let device_wgpu = device.clone();
let (device, queue) = pollster::block_on(select_device::<G>(device));
let device = Arc::new(device);
let client = ContextServerImpl::start(device.clone(), queue);
let (device_wgpu, queue) = pollster::block_on(select_device::<G>(device));
let device = device.clone();
let device_wgpu = Arc::new(device_wgpu);
let client = ContextServerImpl::start(device_wgpu.clone(), queue);
Self {
id: IdGenerator::generate(),
device_wgpu: device,
device: device_wgpu,
device_wgpu,
device,
client,
cache: Mutex::new(HashMap::new()),
}

View File

@@ -130,7 +130,6 @@ pub(crate) fn build_info<E: WgpuElement, const D: usize>(
current += 1;
}
}
info
}
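For context, the info buffer that build_info returns is what the tiled kernel reads its strides, shapes, K, n_rows and n_cols from. A rough summary of that layout for rank-D tensors, inferred from the indexing in the new WGSL template further below (the indices named here are an interpretation, not part of this hunk):

// info[0]              -> D (the rank)
// info[1 ..= D]        -> lhs strides
// info[D+1 ..= 2*D]    -> rhs strides
// info[2*D+1 ..= 3*D]  -> output strides
// info[3*D+1 ..= 4*D]  -> lhs shape  [.., M, K]
// info[4*D+1 ..= 5*D]  -> rhs shape  [.., K, N], so K = info[5*D - 1]
// info[5*D+1 ..= 6*D]  -> output shape, so n_rows = info[6*D - 1] and n_cols = info[6*D]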

View File

@@ -1,16 +1,64 @@
use super::{build_info, KernelSettings};
use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use std::cmp::{max, min};
use super::{build_info, SourceTemplate, StaticKernel};
use crate::{
context::WorkGroup, element::WgpuElement, kernel::KernelSettings, kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
const BLOCK_SIZE: usize = 16;
// Suppose a matmul of m1 of size [M, K] with m2 of size [K, N]
// Block size along dim M
const B_M: usize = 128;
// Block size along dim N
const B_N: usize = 128;
// Block size along dim K
const B_K: usize = 8;
// Tiling size along dim M
const T_M: usize = 8;
// Tiling size along dim N
const T_N: usize = 8;
kernel_wgsl!(MatmulCoalescing, "../template/matmul_mem_coalescing.wgsl");
// WORKGROUP_SIZE_X = ceil(B_M / T_M)
const WORKGROUP_SIZE_X: usize = 16;
// WORKGROUP_SIZE_Y = ceil(B_N / T_N)
const WORKGROUP_SIZE_Y: usize = 16;
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
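// With the defaults above: WORKGROUP_SIZE_X = 128 / 8 = 16 and WORKGROUP_SIZE_Y = 128 / 8 = 16,
// so a workgroup has 16 * 16 = 256 threads, and each shared-memory tile holds
// B_M * B_K = B_K * B_N = 1024 elements, well below MAX_SHARED_MEMORY_SIZE.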
kernel_wgsl!(MatmulTiling2DRaw, "../template/matmul_blocktiling_2d.wgsl");
struct MatmulTiling2D;
impl StaticKernel for MatmulTiling2D {
fn source_template() -> SourceTemplate {
MatmulTiling2DRaw::source_template()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk", (B_M * B_K).to_string())
.register("bk_x_bn", (B_K * B_N).to_string())
.register("t_m", T_M.to_string())
.register("t_n", T_N.to_string())
.register("tm_x_tn", (T_M * T_N).to_string())
}
}
pub fn matmul<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
matmul_tiling_2d(lhs, rhs)
}
pub fn matmul_tiling_2d<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
assert!(B_K <= min(B_M, B_N), "B_K must be smaller than or equal to both B_M and B_N, otherwise there won't be enough threads to fill shared memory.");
assert!(B_K * max(B_M, B_N) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller than or equal to 8192, otherwise the shared memory limit will be exceeded.");
lhs.assert_is_on_save_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
@@ -21,22 +69,32 @@ pub fn matmul<E: WgpuElement, const D: usize>(
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
shape_out[D - 2] = lhs.shape.dims[D - 2];
shape_out[D - 1] = rhs.shape.dims[D - 1];
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
shape_out[D - 2] = num_rows;
shape_out[D - 1] = num_cols;
let shape_out = Shape::new(shape_out);
let buffer = lhs
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
let kernel = lhs
.context
.compile_static::<KernelSettings<MatmulCoalescing, E, i32, BLOCK_SIZE, BLOCK_SIZE, 1>>();
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / (WORKGROUP_SIZE_X * T_M) as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / (WORKGROUP_SIZE_Y * T_N) as f32) as u32;
let kernel = lhs.context.compile_static::<KernelSettings<
MatmulTiling2D,
E,
i32,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
1,
>>();
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
@@ -46,9 +104,7 @@ pub fn matmul<E: WgpuElement, const D: usize>(
num_iter *= output.shape.dims[i];
}
let workgroup_x = f32::ceil(num_rows as f32 / BLOCK_SIZE as f32) as u32;
let workgroup_y = f32::ceil(num_cols as f32 / BLOCK_SIZE as f32) as u32;
let workgroup = WorkGroup::new(workgroup_x, workgroup_y, num_iter as u32);
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
lhs.context.execute(
workgroup,

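To make the dispatch arithmetic above concrete: with the default constants, each workgroup covers a 128 x 128 block of the output, since WORKGROUP_SIZE_X * T_M = WORKGROUP_SIZE_Y * T_N = 16 * 8 = 128. A minimal standalone sketch with hypothetical shapes (not part of this commit):

// Number of blocks needed to cover one dimension of the output.
fn blocks_needed(len: usize, block: usize) -> u32 {
    f32::ceil(len as f32 / block as f32) as u32
}

fn main() {
    // Hypothetical batched matmul: lhs [2, 1000, 500] x rhs [2, 500, 900].
    let (num_batches, num_rows, num_cols) = (2u32, 1000usize, 900usize);

    let blocks_x = blocks_needed(num_rows, 128); // ceil(1000 / 128) = 8
    let blocks_y = blocks_needed(num_cols, 128); // ceil(900 / 128)  = 8

    // The kernel would then be launched as WorkGroup::new(8, 8, 2),
    // i.e. 128 workgroups of 16 x 16 = 256 threads each.
    println!("{blocks_x} x {blocks_y} x {num_batches} workgroups");
}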
View File

@@ -0,0 +1,148 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;
var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
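// With the default sizes (B_N = 128u, T_N = 8u): n_thread_per_row = 16u, so the
// 16 x 16 = 256 threads of a workgroup each own one T_M x T_N (8 x 8) tile of the output block.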
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * K;
var offset_rhs: u32 = skip_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// In case B_M % T_M != 0 or B_N % T_N != 0
// A thread must not read out of its block
let actual_T_M = min(B_M - thread_row, T_M);
let actual_T_N = min(B_N - thread_col, T_N);
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: array<{{ elem }}, T_M>;
var register_N: array<{{ elem }}, T_N>;
for (var k = 0u; k < K; k += B_K) {
// sm_limit ensures that although there are up to B_M x B_N writes to memory,
// shared memories remain B_M x B_K (lhs) or B_K x B_N (rhs);
// it also ensures we do not read out of the matrices if M % B_M != 0 or N % B_N != 0
let sm_limit = min(B_K, K - k);
// Load data into shared memories
// Each thread is responsible for loading T_M x T_N values from both lhs and rhs
for (var i = 0u; i < actual_T_M; i++) {
for (var j = 0u; j < actual_T_N; j++) {
let current_row = thread_row + i;
let current_col = thread_col + j;
if current_col < sm_limit {
let lhs_sm_position = current_row * B_K + current_col;
let lhs_position = offset_lhs + k + current_row * K + current_col;
shared_lhs[lhs_sm_position] = lhs[lhs_position];
}
if current_row < sm_limit {
let rhs_sm_position = current_row * B_N + current_col;
let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
shared_rhs[rhs_sm_position] = rhs[rhs_position];
}
}
}
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
for (var tile_index = 0u; tile_index < actual_T_M; tile_index++) {
let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
register_M[tile_index] = shared_lhs[lhs_sm_position];
}
// Load a subrow of values from rhs
for (var tile_index = 0u; tile_index < actual_T_N; tile_index++) {
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
register_N[tile_index] = shared_rhs[rhs_sm_position];
}
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
results[res_idx_M * actual_T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
let current_row = row + res_idx_M;
let current_col = col + res_idx_N;
// Check that we are within the bounds of output matrix
if current_row < n_rows && current_col < n_cols {
let result_position = res_idx_M * actual_T_N + res_idx_N;
let output_position = offset_output + current_row * n_cols + current_col;
output[output_position] = results[result_position];
}
}
}
}