diff --git a/burn-wgpu/benches/matmul.rs b/burn-wgpu/benches/matmul.rs
index 2e9294a62..586faf0e9 100644
--- a/burn-wgpu/benches/matmul.rs
+++ b/burn-wgpu/benches/matmul.rs
@@ -2,6 +2,7 @@ use burn_common::benchmark::{run_benchmark, Benchmark};
 use burn_tensor::backend::Backend;
 use burn_tensor::{Distribution, Shape, Tensor};
 use burn_wgpu::kernel::matmul::init_matmul_output;
+use burn_wgpu::kernel::matmul::unpadded::matmul_tiling_2d_unpadded;
 use burn_wgpu::kernel::matmul::vec4::matmul_tiling_2d_vec4;
 use burn_wgpu::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs;
 use burn_wgpu::WgpuDevice;
@@ -100,6 +101,11 @@ bench_matmul!(
     Tiling2DMatmulVec4,
     matmul_tiling_2d_vec4
 );
+bench_matmul!(
+    Tiling2DMatmulUnpaddedBenchmark,
+    Tiling2DMatmulUnpadded,
+    matmul_tiling_2d_unpadded
+);
 
 #[allow(dead_code)]
 /// Runs the benchmarks for wgpu matmul implementations
@@ -107,9 +113,9 @@ pub fn bench(device: &WgpuDevice) {
     const D: usize = 3;
     let num_repeats = 3;
     let batch_size = 3;
-    let m = 2048;
-    let k = 2048;
-    let n = 1024;
+    let m = 1007;
+    let k = 1023;
+    let n = 1005;
 
     let shape_lhs = Shape::new([batch_size, m, k]);
     let shape_rhs = Shape::new([batch_size, k, n]);
@@ -125,6 +131,7 @@ pub fn bench(device: &WgpuDevice) {
     }
     run_matmul_benchmark!(NaiveMatmulBenchmark);
     run_matmul_benchmark!(MemCoalescingMatmulBenchmark);
+    run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark);
     run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark);
     run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark);
 }
diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs
index d4b4b1926..0ce7808d5 100644
--- a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs
+++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs
@@ -13,7 +13,7 @@ pub(crate) const B_N: usize = 64;
 pub(crate) const B_K: usize = 32;
 pub(crate) const WORKGROUP_SIZE: usize = 16;
 
-pub(super) fn make_workgroup<const D: usize>(output_shape: Shape<D>) -> WorkGroup {
+pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
     let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32;
     let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32;
     let mut num_blocks_z = 1;
@@ -71,7 +71,7 @@ pub(super) fn matmul_tiling_2d_launch<
         rounded_output_shape.clone(),
     );
 
-    let workgroup = make_workgroup(rounded_output_shape);
+    let workgroup = make_workgroup(&rounded_output_shape);
     let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
 
     lhs.client.execute(
diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs b/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs
index 5537ef934..e58ce1dbd 100644
--- a/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs
+++ b/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs
@@ -1,6 +1,9 @@
 mod base;
 mod padding;
 
+/// WGSL vec4 primitives are used on both the left- and right-hand tensors;
+/// padding is avoided through bounds-checking conditions in the kernel
+pub mod unpadded;
 /// WGSL vec4 primitives are used on left and right hand tensor
 pub mod vec4;
 /// WGSL vec4 primitives are used on left hand tensor
diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs
new file mode 100644
index 000000000..b444bd20b
--- /dev/null
+++ b/burn-wgpu/src/kernel/matmul/tiling2d/unpadded.rs
@@ -0,0 +1,195 @@
+use burn_tensor::Element;
+
+use crate::{
+    compute::DynamicKernel,
+    element::WgpuElement,
+    kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource},
+    tensor::WgpuTensor,
+};
+use std::marker::PhantomData;
+
+use crate::kernel_wgsl;
+
+use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE};
+
+kernel_wgsl!(
+    MatmulTiling2DUnpaddedRaw,
+    "../../../template/matmul/blocktiling_2d/unpadded.wgsl"
+);
+
+#[derive(new, Debug)]
+struct MatmulTiling2DUnpadded<E: WgpuElement> {
+    _elem: PhantomData<E>,
+}
+
+impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
+    fn source(&self) -> SourceTemplate {
+        MatmulTiling2DUnpaddedRaw::source()
+            .register("b_m", B_M.to_string())
+            .register("b_n", B_N.to_string())
+            .register("b_k", B_K.to_string())
+            .register("bm_x_bk_4", (B_M * B_K / 4).to_string())
+            .register("bk_x_bn_4", (B_K * B_N / 4).to_string())
+            .register("workgroup_size_x", WORKGROUP_SIZE.to_string())
+            .register("workgroup_size_y", WORKGROUP_SIZE.to_string())
+            .register("workgroup_size_z", "1".to_string())
+            .register("elem", E::type_name())
+            .register("int", "i32")
+    }
+
+    fn id(&self) -> String {
+        std::format!("{:?}", self)
+    }
+}
+
+/// Matrix multiplication using the tiling 2d algorithm,
+/// with vec4 primitives on both lhs and rhs and no padding required
+pub fn matmul_tiling_2d_unpadded<E: WgpuElement + Element, const D: usize>(
+    lhs: WgpuTensor<E, D>,
+    rhs: WgpuTensor<E, D>,
+    out: WgpuTensor<E, D>,
+) -> WgpuTensor<E, D> {
+    let lhs = match lhs.batch_swapped_with_row_col() {
+        true => into_contiguous(lhs),
+        false => lhs,
+    };
+    let rhs = match rhs.batch_swapped_with_row_col() {
+        true => into_contiguous(rhs),
+        false => rhs,
+    };
+
+    let workgroup = make_workgroup(&out.shape);
+    let info_handle = make_info_handle(&lhs, &rhs, &out);
+
+    lhs.client.execute(
+        Box::new(DynamicKernel::new(
+            MatmulTiling2DUnpadded::<E>::new(),
+            workgroup,
+        )),
+        &[&lhs.handle, &rhs.handle, &out.handle, &info_handle],
+    );
+
+    out
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims};
+
+    #[test]
+    pub fn test_matmul_unpadded_straightforward() {
+        test_with_params(1, 2, 1, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_shapes_smaller_than_blocks() {
+        test_with_params(8, 8, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_shapes_equal_blocks() {
+        test_with_params(64, 32, 64, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_m_exceeds_block() {
+        test_with_params(75, 32, 64, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_exceeds_block() {
+        test_with_params(64, 33, 32, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_irregular_shape() {
+        test_with_params(123, 255, 72, 3, 5);
+    }
+
+    #[test]
+    pub fn test64_matmul_unpadded_n_exceeds_block() {
+        test_with_params(64, 32, 75, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_n_smaller_than_m() {
+        test_with_params(8, 8, 3, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_m_smaller_than_n() {
+        test_with_params(3, 8, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_smaller_than_m_n() {
+        test_with_params(8, 3, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_larger_than_m_n() {
+        test_with_params(8, 48, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_multibatch_1_dim() {
+        test_with_params(8, 8, 8, 3, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_multibatch_2_dims() {
+        test_with_params(8, 8, 8, 3, 4);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() {
+        test_with_params(7, 7, 7, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_medium() {
+        test_with_params(17, 16, 16, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_large() {
+        test_with_params(134, 242, 250, 1, 1);
+    }
+
+    fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
+        let func = matmul_tiling_2d_unpadded;
+        let shape_lhs = [batch_1, batch_2, m, k];
+        let shape_rhs = [batch_1, batch_2, k, n];
+        same_as_reference(func, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap = [0, 1];
+        let shape_lhs = [3, 2, 4, 4];
+        let shape_rhs = [3, 2, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap_lhs = [0, 0];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, 4, 4];
+        let shape_rhs = [3, 2, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap_lhs = [0, 3];
+        let swap_rhs = [0, 2];
+        let shape_lhs = [4, 4, 4, 4];
+        let shape_rhs = [4, 4, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
+    }
+}
diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs
index 80f13c1ec..1130ccd74 100644
--- a/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs
+++ b/burn-wgpu/src/kernel/matmul/tiling2d/vec4.rs
@@ -12,18 +12,18 @@ use crate::kernel_wgsl;
 use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};
 
 kernel_wgsl!(
-    MatmulTiling2Dvec4RHSRaw,
+    MatmulTiling2Dvec4Raw,
     "../../../template/matmul/blocktiling_2d/vec4.wgsl"
 );
 
 #[derive(new, Debug)]
-struct MatmulTiling2Dvec4RHS<E: WgpuElement> {
+struct MatmulTiling2Dvec4<E: WgpuElement> {
     _elem: PhantomData<E>,
 }
 
-impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4RHS<E> {
+impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
     fn source(&self) -> SourceTemplate {
-        MatmulTiling2Dvec4RHSRaw::source()
+        MatmulTiling2Dvec4Raw::source()
             .register("b_m", B_M.to_string())
             .register("b_n", B_N.to_string())
             .register("b_k", B_K.to_string())
@@ -48,7 +48,7 @@ pub fn matmul_tiling_2d_vec4<E: WgpuElement + Element, const D: usize>(
     rhs: WgpuTensor<E, D>,
     out: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
-    let kernel = MatmulTiling2Dvec4RHS::<E>::new();
+    let kernel = MatmulTiling2Dvec4::<E>::new();
 
     matmul_tiling_2d_launch(lhs, rhs, out, kernel)
 }
diff --git a/burn-wgpu/src/kernel/matmul/tune/base.rs b/burn-wgpu/src/kernel/matmul/tune/base.rs
index d393856c0..5f2b11084 100644
--- a/burn-wgpu/src/kernel/matmul/tune/base.rs
+++ b/burn-wgpu/src/kernel/matmul/tune/base.rs
@@ -82,6 +82,11 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<MatmulAutotuneKey>
             rhs.clone(),
             out.clone(),
         )),
+        Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
+            lhs.clone(),
+            rhs.clone(),
+            out.clone(),
+        )),
         Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(lhs, rhs, out)),
     ]
 }
@@ -97,7 +102,10 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<MatmulAutotuneKey>
             2 => Box::new(Vec4TilingMatmulDefault::<E, D>::new(
                 self.lhs, self.rhs, self.out,
             )),
-            3 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
+            3 => Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
+                self.lhs, self.rhs, self.out,
+            )),
+            4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
                 self.lhs, self.rhs, self.out,
             )),
             _ => panic!("Fastest index is out of bound"),
@@ -162,18 +170,24 @@ matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
     crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
 });
 
-// Probably the fastest on MacOS.
+// Maybe the fastest on macOS.
 matmul_tune_ops!(
     Vec4LhsOnlyTilingMatmulDefault,
     crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs
 );
 
-// Probably the fastest.
+// Probably the fastest for fixed sizes.
 matmul_tune_ops!(
     Vec4TilingMatmulDefault,
     crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
 );
 
+// Probably the fastest otherwise.
+matmul_tune_ops!(
+    Vec4TilingMatmulUnpaddedDefault,
+    crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
+);
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/burn-wgpu/src/template/matmul/blocktiling_2d/unpadded.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/unpadded.wgsl
new file mode 100644
index 000000000..9dd3f9407
--- /dev/null
+++ b/burn-wgpu/src/template/matmul/blocktiling_2d/unpadded.wgsl
@@ -0,0 +1,243 @@
+@group(0)
+@binding(0)
+var<storage, read> lhs: array<{{ elem }}>;
+
+@group(0)
+@binding(1)
+var<storage, read> rhs: array<{{ elem }}>;
+
+@group(0)
+@binding(2)
+var<storage, read_write> output: array<{{ elem }}>;
+
+@group(0)
+@binding(3)
+var<storage, read> info: array<u32>;
+
+const B_M = {{b_m}}u;
+const B_N = {{b_n}}u;
+const B_K = {{b_k}}u;
+const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
+const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
+
+const T_M = 4u;
+const T_N = 4u;
+const T_M_X_T_N = 16u;
+
+var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
+var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
+
+@compute
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(local_invocation_index) local_idx: u32,
+    @builtin(workgroup_id) workgroup_id: vec3<u32>,
+) {
+    let skip_row = workgroup_id.x * B_M;
+    let skip_col = workgroup_id.y * B_N;
+
+    let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
+
+    // Position of the first element of the thread, relative to the block
+    let thread_row = (local_idx / n_thread_per_row) * T_M;
+    let thread_col = (local_idx % n_thread_per_row) * T_N;
+
+    // Position of the first element of the thread, in absolute terms (within one batch)
+    let row = skip_row + thread_row;
+    let col = skip_col + thread_col;
+
+    let batch = global_id.z;
+
+    // Basic information
+    let dim = info[0];
+    let n_rows = info[6u * dim - 1u];
+    let n_cols = info[6u * dim];
+    let K = info[5u * dim - 1u];
+
+    // Row / col strides
+    let lhs_stride_row = info[dim - 1u];
+    let lhs_stride_col = info[dim];
+    let rhs_stride_row = info[2u * dim - 1u];
+    let rhs_stride_col = info[2u * dim];
+    let out_stride_row = info[3u * dim - 1u];
+    let out_stride_col = info[3u * dim];
+
+    // Calculate the corresponding offsets with support for broadcasting.
+    let offset_output = batch * n_rows * n_cols;
+    var offset_lhs: u32 = skip_row * lhs_stride_row;
+    var offset_rhs: u32 = skip_col * rhs_stride_col;
+
+    let batch_dims = dim - 2u;
+    for (var b: u32 = 1u; b <= batch_dims; b++) {
+        let stride_lhs = info[b];
+        let stride_rhs = info[b + dim];
+        let stride_output = info[b + 2u * dim];
+        let shape_lhs = info[b + 3u * dim];
+        let shape_rhs = info[b + 4u * dim];
+
+        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
+        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
+    }
+
+    // Registers used in the compute pass
+    var results: array<{{ elem }}, T_M_X_T_N>;
+    var register_M: vec4<{{ elem }}>;
+    var register_N: vec4<{{ elem }}>;
+
+    // How close the thread is to the end of the matrix:
+    // if fewer than 4 rows (lhs) or columns (rhs) remain, this is an edge case
+    let remain_row_lhs = n_rows - row;
+    let remain_col_rhs = n_cols - col;
+
+    for (var k = 0u; k < K; k += B_K) {
+
+        // LHS LOAD PASS
+
+        // For the 4 vec4 columns of this thread
+        for (var j = 0u; j < 4u; j++) {
+
+            // The precise column being loaded by this thread
+            let current_col = thread_col + j;
+
+            // Position of the column vec4 in shared memory
+            let lhs_sm_position = (thread_row/4u) * B_K + current_col;
+
+            // To avoid overwriting the following row in shared memory
+            if current_col < B_K {
+                // To pad with zeros if outside lhs
+                if current_col + k < K && remain_row_lhs >= 1u {
+                    let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
+                    let lhs_position1 = lhs_position0 + lhs_stride_row;
+                    let lhs_position2 = lhs_position1 + lhs_stride_row;
+                    let lhs_position3 = lhs_position2 + lhs_stride_row;
+
+                    if remain_row_lhs >= 4u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            lhs[lhs_position2],
+                            lhs[lhs_position3],
+                        );
+                    } else if remain_row_lhs == 3u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            lhs[lhs_position2],
+                            0.
+                        );
+                    } else if remain_row_lhs == 2u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            0.,
+                            0.
+                        );
+                    } else if remain_row_lhs == 1u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            0.,
+                            0.,
+                            0.
+                        );
+                    }
+                } else {
+                    shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.);
+                }
+            }
+        }
+
+        // RHS LOAD PASS
+
+        for (var i = 0u; i < 4u; i++) {
+            let current_row = thread_row + i;
+
+            let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
+
+            if current_row < B_K {
+                if current_row + k < K && remain_col_rhs >= 1u {
+
+                    let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
+                    let rhs_position1 = rhs_position0 + rhs_stride_col;
+                    let rhs_position2 = rhs_position1 + rhs_stride_col;
+                    let rhs_position3 = rhs_position2 + rhs_stride_col;
+
+                    if remain_col_rhs >= 4u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            rhs[rhs_position2],
+                            rhs[rhs_position3],
+                        );
+                    } else if remain_col_rhs == 3u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            rhs[rhs_position2],
+                            0.
+                        );
+                    } else if remain_col_rhs == 2u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            0.,
+                            0.
+                        );
+                    } else if remain_col_rhs == 1u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            0.,
+                            0.,
+                            0.
+ ); + } + } else { + shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.); + } + } + } + + workgroupBarrier(); + + // COMPUTE PASS + + // Compute intermediate results + // Results are cumulated in results array and updated at each block + // Outer loop indicates which subcolumns/subrows to read from shared memories + for (var dot_index = 0u; dot_index < B_K; dot_index++) { + + // Load a subcolumn of values from lhs + let lhs_sm_position = (thread_row/4u) * B_K + dot_index; + register_M = shared_lhs[lhs_sm_position]; + + // Load a subrow of values from rhs + let rhs_sm_position = (dot_index * B_N + thread_col) / 4u; + register_N = shared_rhs[rhs_sm_position]; + + // Multiply subcolumn and subrow and store results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; + } + } + } + + workgroupBarrier(); + } + + // OUTPUT PASS + + // Write output matrix + // Each thread is responsible of writing T_M x T_N results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + let row_index = row + res_idx_M; + let col_index = col + res_idx_N; + if row_index < n_rows && col_index < n_cols { + let result_position = res_idx_M * T_N + res_idx_N; + let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col; + output[output_position] = results[result_position]; + } + } + } +}