diff --git a/burn-wgpu/benches/matmul.rs b/burn-wgpu/benches/matmul.rs index 85dc8c9df..5cbcb2992 100644 --- a/burn-wgpu/benches/matmul.rs +++ b/burn-wgpu/benches/matmul.rs @@ -1,11 +1,13 @@ -use std::marker::PhantomData; - use burn_tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_wgpu::{ benchmark::Benchmark, - kernel::{matmul_mem_coalescing_default, matmul_naive_default, matmul_tiling_2d_default}, + kernel::matmul::{ + continuous, continuous_vectorized, matmul_mem_coalescing_default, matmul_naive_default, + tile, tile_vectorized, + }, run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice, }; +use std::marker::PhantomData; trait MatmulFunction { fn run(lhs: Tensor, rhs: Tensor) -> Tensor; @@ -37,6 +39,10 @@ where ) } + fn num_samples(&self) -> usize { + 10 + } + fn execute(&self, (lhs, rhs): Self::Args) { for _ in 0..self.num_repeats { F::run(lhs.clone(), rhs.clone()); @@ -51,71 +57,69 @@ where } } -struct Tiling2DMatmul; +macro_rules! benchmark { + ($name:ident, $func:expr) => { + struct $name; -impl MatmulFunction, D> - for Tiling2DMatmul -{ - fn run( - lhs: Tensor, D>, - rhs: Tensor, D>, - ) -> Tensor, D> { - Tensor::from_primitive(matmul_tiling_2d_default( - lhs.into_primitive(), - rhs.into_primitive(), - )) - } + impl MatmulFunction, D> for $name { + fn run( + lhs: Tensor, D>, + rhs: Tensor, D>, + ) -> Tensor, D> { + Tensor::from_primitive($func(lhs.into_primitive(), rhs.into_primitive())) + } + } + }; } -struct NaiveMatmul; - -impl MatmulFunction, D> for NaiveMatmul { - fn run( - lhs: Tensor, D>, - rhs: Tensor, D>, - ) -> Tensor, D> { - Tensor::from_primitive(matmul_naive_default( - lhs.into_primitive(), - rhs.into_primitive(), - )) - } -} - -struct MemCoalescingMatmul; - -impl MatmulFunction, D> - for MemCoalescingMatmul -{ - fn run( - lhs: Tensor, D>, - rhs: Tensor, D>, - ) -> Tensor, D> { - Tensor::from_primitive(matmul_mem_coalescing_default( - lhs.into_primitive(), - rhs.into_primitive(), - )) - } -} +benchmark!(NaiveMatmul, matmul_naive_default); +benchmark!(MemCoalescingMatmul, matmul_mem_coalescing_default); +benchmark!( + Tiling2DMatmulContinuous, + continuous::matmul_tiling_2d_default +); +benchmark!(Tiling2DMatmulTile, tile::matmul_tiling_2d_default); +benchmark!( + Tiling2DMatmulTileVectorized, + tile_vectorized::matmul_tiling_2d_default +); +benchmark!( + Tiling2DMatmulContinuousVectorized, + continuous_vectorized::matmul_tiling_2d_default +); fn main() { - let batch_size = 32; - let matrix_size = 128; - run_benchmark!(MatmulBenchmark:: { - shape_lhs: [batch_size, matrix_size, matrix_size].into(), - shape_rhs: [batch_size, matrix_size, matrix_size].into(), - num_repeats: 10, - matmul: PhantomData::default() - }); + let num_repeats = 3; + let batch_size = 3; + let matrix_size = 1000; run_benchmark!(MatmulBenchmark:: { shape_lhs: [batch_size, matrix_size, matrix_size].into(), shape_rhs: [batch_size, matrix_size, matrix_size].into(), - num_repeats: 10, + num_repeats, matmul: PhantomData::default() }); - run_benchmark!(MatmulBenchmark:: { + run_benchmark!(MatmulBenchmark:: { shape_lhs: [batch_size, matrix_size, matrix_size].into(), shape_rhs: [batch_size, matrix_size, matrix_size].into(), - num_repeats: 10, + num_repeats, + matmul: PhantomData::default() + }); + run_benchmark!(MatmulBenchmark:: { + shape_lhs: [batch_size, matrix_size, matrix_size].into(), + shape_rhs: [batch_size, matrix_size, matrix_size].into(), + num_repeats, + matmul: PhantomData::default() + }); + run_benchmark!(MatmulBenchmark:: { + shape_lhs: [batch_size, matrix_size, 
matrix_size].into(), + shape_rhs: [batch_size, matrix_size, matrix_size].into(), + num_repeats, + matmul: PhantomData::default() + }); + run_benchmark!(MatmulBenchmark:: { + shape_lhs: [batch_size, matrix_size, matrix_size].into(), + shape_rhs: [batch_size, matrix_size, matrix_size].into(), + num_repeats, matmul: PhantomData::default() }); } diff --git a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs index 63768270b..2b53496e6 100644 --- a/burn-wgpu/src/kernel/matmul/mem_coalescing.rs +++ b/burn-wgpu/src/kernel/matmul/mem_coalescing.rs @@ -1,3 +1,4 @@ +use super::utils::shape_out; use crate::{ context::WorkGroup, element::WgpuElement, @@ -5,11 +6,10 @@ use crate::{ kernel_wgsl, tensor::WgpuTensor, }; -use burn_tensor::Shape; kernel_wgsl!( MatmulMemCoalescingRaw, - "../../template/matmul_mem_coalescing.wgsl" + "../../template/matmul/mem_coalescing.wgsl" ); struct MatmulMemCoalescing; @@ -44,21 +44,9 @@ pub fn matmul_mem_coalescing< ) -> WgpuTensor { lhs.assert_is_on_same_device(&rhs); - let mut shape_out = [0; D]; - lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - + let shape_out = shape_out(&lhs, &rhs); let num_rows = lhs.shape.dims[D - 2]; let num_cols = rhs.shape.dims[D - 1]; - shape_out[D - 2] = num_rows; - shape_out[D - 1] = num_cols; - let shape_out = Shape::new(shape_out); let buffer = lhs .context @@ -103,10 +91,7 @@ pub fn matmul_mem_coalescing< #[cfg(test)] mod tests { use super::*; - use crate::tests::TestTensor; - - pub type ReferenceTensor = - burn_tensor::Tensor, D>; + use crate::kernel::matmul::utils::tests::same_as_reference; #[test] pub fn test_matmul_mem_coalescing_straightforward() { @@ -167,25 +152,4 @@ mod tests { let shape_rhs = [batch_1, batch_2, k, n]; same_as_reference(func, shape_lhs, shape_rhs); } - - fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) - where - F: Fn(WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - - let x_wgpu = TestTensor::from_data(x.to_data()); - let y_wgpu = TestTensor::from_data(y.to_data()); - - let z_reference = x.matmul(y); - - let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive()); - let z = TestTensor::from_primitive(z); - - println!("{z}"); - - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } } diff --git a/burn-wgpu/src/kernel/matmul/mod.rs b/burn-wgpu/src/kernel/matmul/mod.rs index 803e89d38..9dfcecc3a 100644 --- a/burn-wgpu/src/kernel/matmul/mod.rs +++ b/burn-wgpu/src/kernel/matmul/mod.rs @@ -1,3 +1,5 @@ +pub(crate) mod utils; + mod mem_coalescing; mod naive; mod tiling2d; diff --git a/burn-wgpu/src/kernel/matmul/naive.rs b/burn-wgpu/src/kernel/matmul/naive.rs index 678c64d04..a12508404 100644 --- a/burn-wgpu/src/kernel/matmul/naive.rs +++ b/burn-wgpu/src/kernel/matmul/naive.rs @@ -1,3 +1,4 @@ +use super::utils::shape_out; use crate::{ context::WorkGroup, element::WgpuElement, @@ -5,9 +6,8 @@ use crate::{ kernel_wgsl, tensor::WgpuTensor, }; -use burn_tensor::Shape; -kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul_naive.wgsl"); +kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl"); struct MatmulNaive; @@ -41,21 +41,10 @@ pub fn matmul_naive< ) -> WgpuTensor { lhs.assert_is_on_same_device(&rhs); - let mut shape_out = [0; D]; 
- lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); + let shape_out = shape_out(&lhs, &rhs); let num_rows = lhs.shape.dims[D - 2]; let num_cols = rhs.shape.dims[D - 1]; - shape_out[D - 2] = num_rows; - shape_out[D - 1] = num_cols; - let shape_out = Shape::new(shape_out); let buffer = lhs .context @@ -100,10 +89,7 @@ pub fn matmul_naive< #[cfg(test)] mod tests { use super::*; - use crate::tests::TestTensor; - - pub type ReferenceTensor = - burn_tensor::Tensor, D>; + use crate::kernel::matmul::utils::tests::same_as_reference; #[test] pub fn test_matmul_naive_straightforward() { @@ -162,25 +148,4 @@ mod tests { let shape_rhs = [batch_1, batch_2, k, n]; same_as_reference(func, shape_lhs, shape_rhs); } - - fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) - where - F: Fn(WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - - let x_wgpu = TestTensor::from_data(x.to_data()); - let y_wgpu = TestTensor::from_data(y.to_data()); - - let z_reference = x.matmul(y); - - let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive()); - let z = TestTensor::from_primitive(z); - - println!("{z}"); - - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } } diff --git a/burn-wgpu/src/kernel/matmul/tiling2d.rs b/burn-wgpu/src/kernel/matmul/tiling2d.rs deleted file mode 100644 index e0ba97050..000000000 --- a/burn-wgpu/src/kernel/matmul/tiling2d.rs +++ /dev/null @@ -1,330 +0,0 @@ -use std::cmp::{max, min}; - -use crate::{ - context::WorkGroup, - element::WgpuElement, - kernel::{build_info, KernelSettings, SourceTemplate, StaticKernel}, - kernel_wgsl, - tensor::WgpuTensor, -}; -use burn_tensor::Shape; - -const MAX_SHARED_MEMORY_SIZE: usize = 8192; - -kernel_wgsl!( - MatmulTiling2DRaw, - "../../template/matmul_blocktiling_2d.wgsl" -); - -struct MatmulTiling2D< - const B_M: usize, - const B_N: usize, - const B_K: usize, - const T_M: usize, - const T_N: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, ->; - -impl< - const B_M: usize, - const B_N: usize, - const B_K: usize, - const T_M: usize, - const T_N: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, - > StaticKernel for MatmulTiling2D -{ - fn source_template() -> SourceTemplate { - MatmulTiling2DRaw::source_template() - .register("b_m", B_M.to_string()) - .register("b_n", B_N.to_string()) - .register("b_k", B_K.to_string()) - .register("bm_x_bk", (B_M * B_K).to_string()) - .register("bk_x_bn", (B_K * B_N).to_string()) - .register("t_m", T_M.to_string()) - .register("t_n", T_N.to_string()) - .register("tm_x_tn", (T_M * T_N).to_string()) - } -} - -/// Matrix multiplication using tiling 2D algorithm with default parameters -pub fn matmul_tiling_2d_default( - lhs: WgpuTensor, - rhs: WgpuTensor, -) -> WgpuTensor { - // Suppose a matmul of m1 of size [M, K] with m2 of size [K, N] - // Block size along dim M - const B_M: usize = 128; - // // Block size along dim N - const B_N: usize = 128; - // // Block size along dim K - const B_K: usize = 8; - // // Tiling size along dim M - const T_M: usize = 8; - // // Tiling size along dim N - const T_N: usize = 8; - // WORKGROUP_SIZE_X = ceil(B_M / T_M) - const WORKGROUP_SIZE_X: usize = 16; - // WORKGROUP_SIZE_Y = ceil(B_N / T_N) - const 
WORKGROUP_SIZE_Y: usize = 16; - - matmul_tiling_2d::(lhs, rhs) -} - -/// Matrix multiplication using tiling 2D algorithm with custom parameters -pub fn matmul_tiling_2d< - E: WgpuElement, - const D: usize, - const B_M: usize, - const B_N: usize, - const B_K: usize, - const T_M: usize, - const T_N: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, ->( - lhs: WgpuTensor, - rhs: WgpuTensor, -) -> WgpuTensor { - assert!(B_K <= min(B_M, B_N), "B_K must be smaller than both B_M and B_M, otherwise there won't be enough threads to fill shared memory. "); - assert!(B_K * max(B_M, B_N) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller or equal than 8192, otherwise shared memory limit will be busted. "); - assert!( - WORKGROUP_SIZE_X == f32::ceil(B_M as f32 / T_M as f32) as usize, - "Workgroup size x must equal ceil(B_M / T_M)" - ); - assert!( - WORKGROUP_SIZE_Y == f32::ceil(B_N as f32 / T_N as f32) as usize, - "Workgroup size y must equal ceil(B_N / T_N)" - ); - lhs.assert_is_on_same_device(&rhs); - - let mut shape_out = [0; D]; - lhs.shape - .dims - .iter() - .zip(rhs.shape.dims.iter()) - .enumerate() - .for_each(|(index, (dim_lhs, dim_rhs))| { - shape_out[index] = usize::max(*dim_lhs, *dim_rhs); - }); - - let num_rows = lhs.shape.dims[D - 2]; - let num_cols = rhs.shape.dims[D - 1]; - shape_out[D - 2] = num_rows; - shape_out[D - 1] = num_cols; - let shape_out = Shape::new(shape_out); - - let buffer = lhs - .context - .create_buffer(shape_out.num_elements() * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer); - - // set number of workgroups - let blocks_needed_in_x = f32::ceil(num_rows as f32 / B_M as f32) as u32; - let blocks_needed_in_y = f32::ceil(num_cols as f32 / B_N as f32) as u32; - - let kernel = lhs.context.compile_static::, - E, - i32, - WORKGROUP_SIZE_X, - WORKGROUP_SIZE_Y, - 1, - >>(); - - let info = build_info(&[&lhs, &rhs, &output]); - - let info_buffers = lhs - .context - .create_buffer_with_data(bytemuck::cast_slice(&info)); - - let mut num_iter = 1; - for i in 0..D - 2 { - num_iter *= output.shape.dims[i]; - } - - let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32); - - lhs.context.execute( - workgroup, - kernel, - &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], - ); - - output -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::TestTensor; - - pub type ReferenceTensor = - burn_tensor::Tensor, D>; - - #[test] - pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() { - test_with_params::<128, 128, 16, 8, 8, 16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_m_not_equals_n() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 8, 3, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_k_smaller_than_m_n() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 3, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_k_larger_than_m_n() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 48, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_t_divides_b_unevenly() { - test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_small_parameters() { - test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_bm_not_equals_bn() { - test_with_params::<32, 128, 8, 8, 8, 4, 16>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_multibatch_1_dim() { - test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 1); - 
} - - #[test] - pub fn test_matmul_tiling_2d_multibatch_2_dims() { - test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 4); - } - - #[test] - #[should_panic] - pub fn test_matmul_tiling_2d_memory_busted_should_panic() { - test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1); - } - - #[test] - #[should_panic] - pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() { - test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1); - } - - #[test] - #[should_panic] - pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() { - test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1); - } - - #[test] - #[should_panic] - pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() { - test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_multiple_blocks() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(32, 32, 32, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_k_bigger_than_bk() { - test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 10, 8, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(31, 23, 17, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_large_parameters() { - test_with_params::<256, 256, 16, 16, 16, 16, 16>(40, 40, 40, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_shapes_slightly_larger_than_blocks() { - test_with_params::<32, 32, 8, 8, 8, 4, 4>(40, 40, 30, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() { - test_with_params::<16, 16, 8, 8, 8, 2, 2>(50, 50, 50, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_tm_larger_than_bm() { - test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_tn_larger_than_bn() { - test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_uneven_parameters() { - test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1); - } - - #[test] - pub fn test_matmul_tiling_2d_uneven_parameters_2() { - test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1); - } - - fn test_with_params< - const B_M: usize, - const B_N: usize, - const B_K: usize, - const T_M: usize, - const T_N: usize, - const WORKGROUP_SIZE_X: usize, - const WORKGROUP_SIZE_Y: usize, - >( - m: usize, - k: usize, - n: usize, - batch_1: usize, - batch_2: usize, - ) { - let func = |lhs, rhs| { - matmul_tiling_2d::( - lhs, rhs, - ) - }; - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference(func, shape_lhs, shape_rhs); - } - - fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) - where - F: Fn(WgpuTensor, WgpuTensor) -> WgpuTensor, - S: Into>, - { - let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); - - let x_wgpu = TestTensor::from_data(x.to_data()); - let y_wgpu = TestTensor::from_data(y.to_data()); - - let z_reference = x.matmul(y); - - let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive()); - let z = TestTensor::from_primitive(z); - - println!("{z}"); - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } -} diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs new file mode 100644 index 000000000..02eda16a6 --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs @@ -0,0 +1,415 @@ +use crate::{ + 
context::{Context, WorkGroup}, + element::WgpuElement, + kernel::{build_info, matmul::utils::shape_out, SourceTemplate}, + tensor::WgpuTensor, +}; +use burn_tensor::Shape; +use std::{ + cmp::{max, min}, + sync::Arc, +}; +use wgpu::ComputePipeline; + +use super::padding::{crop, pad_round}; + +const MAX_SHARED_MEMORY_SIZE: usize = 8192; + +pub(super) fn empty_from_context( + context: Arc, + shape: &Shape, +) -> WgpuTensor { + let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::()); + + WgpuTensor::new(context, shape.clone(), buffer) +} + +/// Create a source template for tile 2d matmul. +#[macro_export(local_inner_macros)] +macro_rules! matmul_tile_2d { + ( + $struct:ident, + $file:expr + ) => { + matmul_tile_2d!( + $struct, + $file, + B_M 64, + B_N 64, + B_K 32, + T_M 4, + T_N 4 + ); + }; + + ( + $struct:ident, + $file:expr, + B_M $bm:expr, + B_N $bn:expr, + B_K $bk:expr, + T_M $tm:expr, + T_N $tn:expr + ) => { + struct $struct< + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, + >; + + impl< + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, + > StaticKernel + for $struct + { + fn source_template() -> SourceTemplate { + kernel_wgsl!(Raw, $file); + + register_template::( + Raw::source_template(), + ) + } + } + + /// Matrix multiplication using tiling 2D algorithm with default parameters + pub fn matmul_tiling_2d_default( + lhs: WgpuTensor, + rhs: WgpuTensor, + ) -> WgpuTensor { + // Suppose a matmul of m1 of size [M, K] with m2 of size [K, N] + // Block size along dim M + const B_M: usize = $bm; + // // Block size along dim N + const B_N: usize = $bn; + // // Block size along dim K + const B_K: usize = $bk; + // // Tiling size along dim M + const T_M: usize = $tm; + // // Tiling size along dim N + const T_N: usize = $tn; + // WORKGROUP_SIZE_X = ceil(B_M / T_M) + const WORKGROUP_SIZE_X: usize = B_M / T_M; + // WORKGROUP_SIZE_Y = ceil(B_N / T_N) + const WORKGROUP_SIZE_Y: usize = B_N / T_N; + + matmul_tiling_2d::( + lhs, rhs, + ) + } + + /// Matrix multiplication using tiling 2D algorithm with custom parameters + pub fn matmul_tiling_2d< + E: WgpuElement, + const D: usize, + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, + >( + lhs: WgpuTensor, + rhs: WgpuTensor, + ) -> WgpuTensor { + let kernel = lhs.context.compile_static::, + E, + i32, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + 1, + >>(); + matmul_tiling_2d_launch::< + E, + D, + B_M, + B_N, + B_K, + T_M, + T_N, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + >(lhs, rhs, kernel) + } + + #[cfg(test)] + mod tests { + use super::*; + use $crate::kernel::matmul::utils::tests::same_as_reference; + + #[test] + pub fn test_matmul_tiling_2d_large_blocks() { + test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() { + test_with_params::<64, 64, 8, 4, 4, 16, 16>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_m_not_equals_n() { + test_with_params::<16, 16, 8, 2, 2, 8, 8>(16, 8, 16, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_k_smaller_than_m_n() { + test_with_params::<16, 16, 4, 2, 2, 8, 8>(16, 4, 16, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_k_larger_than_m_n() { + 
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 48, 8, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_t_divides_b_unevenly_should_panic() { + test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_bm_not_equals_bn() { + test_with_params::<8, 16, 8, 2, 4, 4, 4>(8, 8, 16, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_multibatch_1_dim() { + test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_multibatch_2_dims() { + test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 4); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_memory_busted_should_panic() { + test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() { + test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() { + test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() { + test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_multiple_blocks() { + test_with_params::<16, 16, 8, 2, 2, 8, 8>(32, 32, 32, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_k_bigger_than_bk() { + test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 16, 8, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() { + test_with_params::<16, 16, 8, 2, 2, 8, 8>(31, 23, 17, 1, 1); + } + + #[test] + pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() { + test_with_params::<16, 16, 8, 2, 2, 8, 8>(48, 48, 48, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_tm_larger_than_bm_should_panic() { + test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_tn_larger_than_bn_should_panic() { + test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_uneven_parameters_should_panic() { + test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1); + } + + #[test] + #[should_panic] + pub fn test_matmul_tiling_2d_uneven_parameters_2_should_panic() { + test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1); + } + + fn test_with_params< + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, + >( + m: usize, + k: usize, + n: usize, + batch_1: usize, + batch_2: usize, + ) { + let func = |lhs, rhs| { + matmul_tiling_2d::( + lhs, rhs, + ) + }; + let shape_lhs = [batch_1, batch_2, m, k]; + let shape_rhs = [batch_1, batch_2, k, n]; + same_as_reference(func, shape_lhs, shape_rhs); + } + + + } + }; +} + +pub(super) fn register_template< + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, +>( + template: SourceTemplate, +) -> SourceTemplate { + template + .register("b_m", B_M.to_string()) + .register("b_n", B_N.to_string()) + .register("b_k", B_K.to_string()) + .register("bm_x_bk", (B_M * B_K).to_string()) + .register("bk_x_bn", (B_K * B_N).to_string()) + .register("t_m", T_M.to_string()) + .register("t_n", T_N.to_string()) + .register("tm_x_tn", (T_M * T_N).to_string()) +} + 
+#[allow(clippy::too_many_arguments)] +pub(super) fn matmul_parameter_assertions( + b_m: usize, + b_n: usize, + b_k: usize, + t_m: usize, + t_n: usize, + workgroup_size_x: usize, + workgroup_size_y: usize, + lhs: &WgpuTensor, + rhs: &WgpuTensor, +) { + assert!(b_k <= min(b_m, b_n), "B_K must be smaller than both B_M and B_N, otherwise there won't be enough threads to fill shared memory."); + assert!(b_k * max(b_m, b_n) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must each be at most 8192, otherwise the shared memory limit will be exceeded."); + assert!( + b_m % t_m == 0 && b_n % t_n == 0, + "T_M must divide B_M and T_N must divide B_N in this version" + ); + assert!( + workgroup_size_x == b_m / t_m, + "Workgroup size x must equal B_M / T_M" + ); + assert!( + workgroup_size_y == b_n / t_n, + "Workgroup size y must equal B_N / T_N" + ); + lhs.assert_is_on_same_device(rhs); +} + +pub(super) fn make_workgroup( + output_shape: Shape, + b_m: usize, + b_n: usize, +) -> WorkGroup { + let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / b_m as f32) as u32; + let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / b_n as f32) as u32; + let mut num_blocks_z = 1; + for i in 0..D - 2 { + num_blocks_z *= output_shape.dims[i]; + } + + WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32) +} + +pub(super) fn make_info_buffers( + lhs: &WgpuTensor, + rhs: &WgpuTensor, + output: &WgpuTensor, +) -> Arc { + let info = build_info(&[lhs, rhs, output]); + rhs.context + .create_buffer_with_data(bytemuck::cast_slice(&info)) +} + +pub(super) fn matmul_tiling_2d_launch< + E: WgpuElement, + const D: usize, + const B_M: usize, + const B_N: usize, + const B_K: usize, + const T_M: usize, + const T_N: usize, + const WORKGROUP_SIZE_X: usize, + const WORKGROUP_SIZE_Y: usize, +>( + lhs: WgpuTensor, + rhs: WgpuTensor, + kernel: Arc, +) -> WgpuTensor { + matmul_parameter_assertions::( + B_M, + B_N, + B_K, + T_M, + T_N, + WORKGROUP_SIZE_X, + WORKGROUP_SIZE_Y, + &lhs, + &rhs, + ); + + let final_output_shape = shape_out(&lhs, &rhs); + let lhs = pad_round(lhs, B_M, B_K); + let rhs = pad_round(rhs, B_K, B_N); + let rounded_output_shape = shape_out(&lhs, &rhs); + + let output = empty_from_context::(rhs.context.clone(), &rounded_output_shape); + + let workgroup = make_workgroup(rounded_output_shape, B_M, B_N); + let info_buffers = make_info_buffers(&lhs, &rhs, &output); + + lhs.context.execute( + workgroup, + kernel, + &[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers], + ); + + crop(output, final_output_shape) +} diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/continuous.rs b/burn-wgpu/src/kernel/matmul/tiling2d/continuous.rs new file mode 100644 index 000000000..0e7016633 --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/continuous.rs @@ -0,0 +1,12 @@ +use super::base::{matmul_tiling_2d_launch, register_template}; +use crate::{ + element::WgpuElement, + kernel::{KernelSettings, SourceTemplate, StaticKernel}, + matmul_tile_2d, + tensor::WgpuTensor, +}; + +matmul_tile_2d!( + MatmulTiling2DContinuous, + "../../../template/matmul/blocktiling_2d/continuous.wgsl" +); diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/continuous_vectorized.rs b/burn-wgpu/src/kernel/matmul/tiling2d/continuous_vectorized.rs new file mode 100644 index 000000000..cbc4fae01 --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/continuous_vectorized.rs @@ -0,0 +1,12 @@ +use super::base::{matmul_tiling_2d_launch, register_template}; +use crate::{ + element::WgpuElement, + kernel::{KernelSettings, SourceTemplate, StaticKernel},
+ matmul_tile_2d, + tensor::WgpuTensor, +}; + +matmul_tile_2d!( + MatmulTiling2DContinuousVectorized, + "../../../template/matmul/blocktiling_2d/continuous_vectorized.wgsl" +); diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs b/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs new file mode 100644 index 000000000..3f5c0871a --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/mod.rs @@ -0,0 +1,13 @@ +mod base; +mod padding; + +/// Loading is done in a continuous manner +pub mod continuous; +/// Loading is done in a continuous manner. lhs is transposed +pub mod continuous_vectorized; +/// Loading is done in a tile manner +pub mod tile; +/// Loading is done in a tile manner. lhs is transposed +pub mod tile_vectorized; + +pub use tile_vectorized::*; diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs new file mode 100644 index 000000000..cb5f7426c --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/padding.rs @@ -0,0 +1,241 @@ +use std::ops::Range; + +use burn_tensor::Shape; + +use crate::{ + element::WgpuElement, + kernel::{slice, slice_assign}, + tensor::WgpuTensor, +}; + +use super::base::empty_from_context; + +/// Pads tensor with zeros to make tensor number of rows and columns +/// divisible by some quantity. +/// For instance tensor of shape [1000, 1000] with divisors 64 and 64 +/// will be padded to [1024, 1024] with the last 24 elements being zeros +pub(super) fn pad_round( + tensor: WgpuTensor, + row_divisor: usize, + col_divisor: usize, +) -> WgpuTensor { + let row_modulo = tensor.shape.dims[D - 2] % row_divisor; + let col_modulo = tensor.shape.dims[D - 1] % col_divisor; + if row_modulo == 0 && col_modulo == 0 { + return tensor; + } + let mut padded_shape = Vec::with_capacity(D); + for i in 0..D - 2 { + padded_shape.push(tensor.shape.dims[i]); + } + padded_shape.push(tensor.shape.dims[D - 2] - row_modulo + row_divisor); + padded_shape.push(tensor.shape.dims[D - 1] - col_modulo + col_divisor); + padding::(tensor, padded_shape.into()) +} + +/// Pads tensor by adding zeros when padded dim is larger than tensor dim +fn padding( + tensor: WgpuTensor, + padded_shape: Shape, +) -> WgpuTensor { + let ranges = padded_shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + slice_assign::( + empty_from_context(tensor.context.clone(), &padded_shape), + ranges, + tensor, + ) +} + +/// Crops tensor by deleting values when cropped dim is smaller than tensor dim +pub(super) fn crop( + tensor: WgpuTensor, + cropped_shape: Shape, +) -> WgpuTensor { + let ranges = cropped_shape + .dims + .iter() + .map(|dim| 0..*dim) + .collect::>>() + .try_into() + .unwrap(); + slice::(tensor, ranges) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::TestTensor; + + #[test] + fn padding_already_round_should_have_same_shape() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_already_round_should_have_same_values() { + let row = 10; + let row_divisor = 5; + let col = 12; + let col_divisor = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); + + let padded = 
TestTensor::from_primitive(padded); + padded.into_data().assert_approx_eq(&tensor.into_data(), 3); + } + + #[test] + fn padding_not_round_should_have_rounded_shape() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn padding_not_round_should_have_same_values() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor); + + let padded = TestTensor::from_primitive(padded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]); + } + } + } + + #[test] + fn padding_not_round_should_have_zero_padding() { + let row = 10; + let row_divisor = 6; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor); + let padded = TestTensor::from_primitive(padded).to_data(); + + // check right of matrix + for i in 0..row { + for j in col..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } + } + // check below matrix, including bottom right + for i in row..12 { + for j in 0..15 { + assert!(padded.value[i * 15 + j] == 0.0); + } + } + } + + #[test] + fn padding_works_with_batch() { + let row = 10; + let row_divisor = 4; + let col = 12; + let col_divisor = 5; + let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default); + let expected_shape = [2, 3, 12, 15].into(); + + let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor); + + assert!(padded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_be_unchanged_shape() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [row, col].into(); + + let unpadded = crop(tensor.into_primitive(), [row, col].into()); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_same_shape_should_have_unchanged_values() { + let row = 10; + let col = 12; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop(tensor.clone().into_primitive(), [row, col].into()); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + for i in 0..row { + for j in 0..col { + assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]); + } + } + } + + #[test] + fn crop_should_decrease_shape() { + let row = 10; + let keep_rows = 8; + let col = 12; + let keep_cols = 10; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + let expected_shape = [keep_rows, keep_cols].into(); + + let unpadded = crop(tensor.into_primitive(), [keep_rows, keep_cols].into()); + + assert!(unpadded.shape == expected_shape); + } + + #[test] + fn crop_should_keep_same_values() { + let row = 4; + let keep_rows = 3; + let col = 4; + let keep_cols = 3; + let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default); + + let unpadded = crop( + tensor.clone().into_primitive(), + [keep_rows, keep_cols].into(), + 
); + + let unpadded = TestTensor::from_primitive(unpadded).to_data(); + let tensor = tensor.into_data(); + println!("{:?}\n {:?}", unpadded, tensor); + + for i in 0..keep_rows { + for j in 0..keep_cols { + assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]); + } + } + } +} diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs b/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs new file mode 100644 index 000000000..156d953fa --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/tile.rs @@ -0,0 +1,13 @@ +use crate::{ + element::WgpuElement, + kernel::{KernelSettings, SourceTemplate, StaticKernel}, + matmul_tile_2d, + tensor::WgpuTensor, +}; + +use super::base::{matmul_tiling_2d_launch, register_template}; + +matmul_tile_2d!( + MatmulTiling2DTile, + "../../../template/matmul/blocktiling_2d/tile.wgsl" +); diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs b/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs new file mode 100644 index 000000000..2e05fd5a5 --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/tiling2d/tile_vectorized.rs @@ -0,0 +1,12 @@ +use super::base::{matmul_tiling_2d_launch, register_template}; +use crate::{ + element::WgpuElement, + kernel::{KernelSettings, SourceTemplate, StaticKernel}, + matmul_tile_2d, + tensor::WgpuTensor, +}; + +matmul_tile_2d!( + MatmulTiling2DTileVectorized, + "../../../template/matmul/blocktiling_2d/tile_vectorized.wgsl" +); diff --git a/burn-wgpu/src/kernel/matmul/utils.rs b/burn-wgpu/src/kernel/matmul/utils.rs new file mode 100644 index 000000000..ad62c1a3b --- /dev/null +++ b/burn-wgpu/src/kernel/matmul/utils.rs @@ -0,0 +1,48 @@ +use crate::{element::WgpuElement, tensor::WgpuTensor}; +use burn_tensor::Shape; + +pub(crate) fn shape_out( + lhs: &WgpuTensor, + rhs: &WgpuTensor, +) -> Shape { + let mut shape_out = [0; D]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + shape_out[D - 2] = lhs.shape.dims[D - 2]; + shape_out[D - 1] = rhs.shape.dims[D - 1]; + Shape::new(shape_out) +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::tensor::WgpuTensor; + use crate::tests::{ReferenceTensor, TestTensor}; + use burn_tensor::Shape; + + pub(crate) fn same_as_reference(func: F, shape_lhs: S, shape_rhs: S) + where + F: Fn(WgpuTensor, WgpuTensor) -> WgpuTensor, + S: Into>, + { + let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0)); + + let x_wgpu = TestTensor::from_data(x.to_data()); + let y_wgpu = TestTensor::from_data(y.to_data()); + + let z_reference = x.matmul(y); + + let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive()); + let z = TestTensor::from_primitive(z); + + std::println!("{z}"); + std::println!("{z_reference}"); + z_reference.into_data().assert_approx_eq(&z.into_data(), 3); + } +} diff --git a/burn-wgpu/src/kernel/mod.rs b/burn-wgpu/src/kernel/mod.rs index ad3a7a00e..6b9be7a69 100644 --- a/burn-wgpu/src/kernel/mod.rs +++ b/burn-wgpu/src/kernel/mod.rs @@ -4,7 +4,6 @@ mod cat; mod comparison; mod index; mod mask; -mod matmul; mod reduction; mod source; mod unary; @@ -12,11 +11,13 @@ mod unary_scalar; pub use base::*; pub use binary_elemwise::*; -pub use matmul::*; pub use source::*; pub use unary::*; pub use unary_scalar::*; +/// Matmul kernels +pub mod matmul; + pub(crate) use cat::*; pub(crate) use comparison::*; pub(crate) use index::*; 
diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index eb6b74f9d..4a8ead8f7 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -38,7 +38,9 @@ mod tests { pub type TestBackend = WgpuBackend; pub type ReferenceBackend = burn_ndarray::NdArrayBackend; + pub type TestTensor = burn_tensor::Tensor; + pub type ReferenceTensor = burn_tensor::Tensor; pub type TestTensorInt = burn_tensor::Tensor; burn_tensor::testgen_add!(); diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index e16943679..f7c4ebd5a 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -1,7 +1,6 @@ use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use crate::kernel::{ - self, matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, + self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default, }; use crate::unary_scalar_inplace; use crate::{ @@ -140,7 +139,7 @@ where let lhs = kernel::into_continuous(lhs); let rhs = kernel::into_continuous(rhs); - matmul_tiling_2d_default::, D>(lhs, rhs) + kernel::matmul::tile_vectorized::matmul_tiling_2d_default(lhs, rhs) } fn swap_dims( diff --git a/burn-wgpu/src/template/matmul/blocktiling_2d/continuous.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous.wgsl new file mode 100644 index 000000000..26f002ad3 --- /dev/null +++ b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous.wgsl @@ -0,0 +1,139 @@ +@group(0) +@binding(0) +var lhs: array<{{ elem }}>; + +@group(0) +@binding(1) +var rhs: array<{{ elem }}>; + +@group(0) +@binding(2) +var output: array<{{ elem }}>; + +@group(0) +@binding(3) +var info: array; + +const B_M = {{b_m}}u; +const B_N = {{b_n}}u; +const B_K = {{b_k}}u; +const B_M_X_B_K = {{bm_x_bk}}u; +const B_K_X_B_N = {{bk_x_bn}}u; +const T_M = {{t_m}}u; +const T_N = {{t_n}}u; +const T_M_X_T_N = {{tm_x_tn}}u; + +var shared_lhs: array<{{ elem }}, B_M_X_B_K>; +var shared_rhs: array<{{ elem }}, B_K_X_B_N>; + +@compute +@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + let skip_row = workgroup_id.x * B_M; + let skip_col = workgroup_id.y * B_N; + + let n_thread_per_row = ((B_N - 1u) / T_N) + 1u; + let thread_row = (local_idx / n_thread_per_row) * T_M; + let thread_col = (local_idx % n_thread_per_row) * T_N; + + let row = skip_row + thread_row; + let col = skip_col + thread_col; + + let batch = global_id.z; + + // Basic information + let dim = info[0]; + let n_rows = info[6u * dim - 1u]; + let n_cols = info[6u * dim]; + let K = info[5u * dim - 1u]; + + // Calculate the corresponding offsets with support for broadcasting. 
+ let offset_output = batch * n_rows * n_cols; + var offset_lhs: u32 = skip_row * K; + var offset_rhs: u32 = skip_col; + + let batch_dims = dim - 2u; + for (var b: u32 = 1u; b <= batch_dims; b++) { + let stride_lhs = info[b]; + let stride_rhs = info[b + dim]; + let stride_output = info[b + 2u * dim]; + let shape_lhs = info[b + 3u * dim]; + let shape_rhs = info[b + 4u * dim]; + + offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; + offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; + } + + var results: array<{{ elem }}, T_M_X_T_N>; + var register_M: array<{{ elem }}, T_M>; + var register_N: array<{{ elem }}, T_N>; + + let thread_offset = local_idx * T_M_X_T_N; + + for (var k = 0u; k < K; k += B_K) { + for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) { + let lhs_sm_position = thread_offset + load_index; + let block_row = lhs_sm_position / B_K; + let block_col = lhs_sm_position % B_K; + let lhs_position = offset_lhs + k + block_row * K + block_col; + + if block_row < B_M { + shared_lhs[lhs_sm_position] = lhs[lhs_position]; + } + } + + for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) { + let rhs_sm_position = thread_offset + load_index; + let block_row = rhs_sm_position / B_N; + let block_col = rhs_sm_position % B_N; + let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col; + + if block_row < B_K { + shared_rhs[rhs_sm_position] = rhs[rhs_position]; + } + } + + workgroupBarrier(); + + // Compute intermediate results + // Results are cumulated in results array and updated at each block + // Outer loop indicates which subcolumns/subrows to read from shared memories + for (var dot_index = 0u; dot_index < B_K; dot_index++) { + // Load a subcolumn of values from lhs + for (var tile_index = 0u; tile_index < T_M; tile_index++) { + let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index; + register_M[tile_index] = shared_lhs[lhs_sm_position]; + } + // Load a subrow of values from rhs + for (var tile_index = 0u; tile_index < T_N; tile_index++) { + let rhs_sm_position = thread_col + tile_index + dot_index * B_N; + register_N[tile_index] = shared_rhs[rhs_sm_position]; + } + // Multiply subcolumn and subrow and store results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; + } + } + } + + workgroupBarrier(); + } + + // Write output matrix + // Each thread is responsible of writing T_M x T_N results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + let current_row = row + res_idx_M; + let current_col = col + res_idx_N; + // Check that we are within the bounds of output matrix + let result_position = res_idx_M * T_N + res_idx_N; + let output_position = offset_output + current_row * n_cols + current_col; + output[output_position] = results[result_position]; + } + } +} diff --git a/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl new file mode 100644 index 000000000..fbb34a62d --- /dev/null +++ b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl @@ -0,0 +1,138 @@ +@group(0) +@binding(0) +var lhs: array<{{ elem }}>; + +@group(0) +@binding(1) +var rhs: array<{{ elem }}>; + +@group(0) +@binding(2) +var output: array<{{ elem }}>; + +@group(0) +@binding(3) +var info: array; + +const 
B_M = {{b_m}}u; +const B_N = {{b_n}}u; +const B_K = {{b_k}}u; +const B_M_X_B_K = {{bm_x_bk}}u; +const B_K_X_B_N = {{bk_x_bn}}u; +const T_M = {{t_m}}u; +const T_N = {{t_n}}u; +const T_M_X_T_N = {{tm_x_tn}}u; + +var shared_lhs: array<{{ elem }}, B_M_X_B_K>; +var shared_rhs: array<{{ elem }}, B_K_X_B_N>; + +@compute +@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + let skip_row = workgroup_id.x * B_M; + let skip_col = workgroup_id.y * B_N; + + let n_thread_per_row = ((B_N - 1u) / T_N) + 1u; + let thread_row = (local_idx / n_thread_per_row) * T_M; + let thread_col = (local_idx % n_thread_per_row) * T_N; + + let row = skip_row + thread_row; + let col = skip_col + thread_col; + + let batch = global_id.z; + + // Basic information + let dim = info[0]; + let n_rows = info[6u * dim - 1u]; + let n_cols = info[6u * dim]; + let K = info[5u * dim - 1u]; + + // Calculate the corresponding offsets with support for broadcasting. + let offset_output = batch * n_rows * n_cols; + var offset_lhs: u32 = skip_row * K; + var offset_rhs: u32 = skip_col; + + let batch_dims = dim - 2u; + for (var b: u32 = 1u; b <= batch_dims; b++) { + let stride_lhs = info[b]; + let stride_rhs = info[b + dim]; + let stride_output = info[b + 2u * dim]; + let shape_lhs = info[b + 3u * dim]; + let shape_rhs = info[b + 4u * dim]; + + offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; + offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; + } + + var results: array<{{ elem }}, T_M_X_T_N>; + var register_M: array<{{ elem }}, T_M>; + var register_N: array<{{ elem }}, T_N>; + + let thread_offset = local_idx * T_M_X_T_N; + + for (var k = 0u; k < K; k += B_K) { + // tile: let lhs_sm_position = current_row * B_K + current_col; + // tile_vec: let lhs_sm_position = current_row + current_col * B_M; + for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) { + let lhs_sm_position = thread_offset + load_index; + let block_row = lhs_sm_position % B_M; + let block_col = lhs_sm_position / B_M; + let lhs_position = offset_lhs + k + block_row * K + block_col; + + if block_col < B_K { + shared_lhs[lhs_sm_position] = lhs[lhs_position]; + } + } + + for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) { + let rhs_sm_position = thread_offset + load_index; + let block_row = rhs_sm_position / B_N; + let block_col = rhs_sm_position % B_N; + let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col; + + if block_row < B_K { + shared_rhs[rhs_sm_position] = rhs[rhs_position]; + } + } + + workgroupBarrier(); + + // Compute intermediate results + // Results are cumulated in results array and updated at each block + // Outer loop indicates which subcolumns/subrows to read from shared memories + for (var dot_index = 0u; dot_index < B_K; dot_index++) { + // Load a subcolumn of values from lhs + for (var tile_index = 0u; tile_index < T_M; tile_index++) { + let lhs_sm_position = thread_row + tile_index + dot_index * B_M; + register_M[tile_index] = shared_lhs[lhs_sm_position]; + } + // Load a subrow of values from rhs + for (var tile_index = 0u; tile_index < T_N; tile_index++) { + let rhs_sm_position = thread_col + tile_index + dot_index * B_N; + register_N[tile_index] = shared_rhs[rhs_sm_position]; + } + // Multiply subcolumn and subrow and store results + for (var res_idx_M = 0u; res_idx_M < T_M; 
res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; + } + } + } + + workgroupBarrier(); + } + + // Write output matrix + // Each thread is responsible of writing T_M x T_N results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + let result_position = res_idx_M * T_N + res_idx_N; + let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;; + output[output_position] = results[result_position]; + } + } +} diff --git a/burn-wgpu/src/template/matmul_blocktiling_2d.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/tile.wgsl similarity index 73% rename from burn-wgpu/src/template/matmul_blocktiling_2d.wgsl rename to burn-wgpu/src/template/matmul/blocktiling_2d/tile.wgsl index e6fc013c3..f6a83c659 100644 --- a/burn-wgpu/src/template/matmul_blocktiling_2d.wgsl +++ b/burn-wgpu/src/template/matmul/blocktiling_2d/tile.wgsl @@ -68,35 +68,31 @@ fn main( offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; } - // In case B_M % T_M != 0 or B_N % T_N != 0 - // A thread must not read out of its block - let actual_T_M = min(B_M - thread_row, T_M); - let actual_T_N = min(B_N - thread_col, T_N); - var results: array<{{ elem }}, T_M_X_T_N>; var register_M: array<{{ elem }}, T_M>; var register_N: array<{{ elem }}, T_N>; + let thread_offset = local_idx * T_M_X_T_N; + for (var k = 0u; k < K; k += B_K) { // sm_limit ensures that although there are up to B_M x B_N writes to memory, // shared memories remain B_M x B_K (lhs) or B_K x B_N (rhs) // also ensures we do not read out of matrices if M % B_M != 0 or N % B_N != 0 - let sm_limit = min(B_K, K - k); // Load data into shared memories // Each thread is responsible of loading T_M x T_N values from both lhs and rhs - for (var i = 0u; i < actual_T_M; i++) { - for (var j = 0u; j < actual_T_N; j++) { + for (var i = 0u; i < T_M; i++) { + for (var j = 0u; j < T_N; j++) { let current_row = thread_row + i; let current_col = thread_col + j; - if current_col < sm_limit { + if current_col < B_K { let lhs_sm_position = current_row * B_K + current_col; let lhs_position = offset_lhs + k + current_row * K + current_col; shared_lhs[lhs_sm_position] = lhs[lhs_position]; } - if current_row < sm_limit { + if current_row < B_K { let rhs_sm_position = current_row * B_N + current_col; let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col; shared_rhs[rhs_sm_position] = rhs[rhs_position]; @@ -104,26 +100,27 @@ fn main( } } + workgroupBarrier(); // Compute intermediate results // Results are cumulated in results array and updated at each block // Outer loop indicates which subcolumns/subrows to read from shared memories - for (var dot_index = 0u; dot_index < sm_limit; dot_index++) { + for (var dot_index = 0u; dot_index < B_K; dot_index++) { // Load a subcolumn of values from lhs - for (var tile_index = 0u; tile_index < actual_T_M; tile_index++) { + for (var tile_index = 0u; tile_index < T_M; tile_index++) { let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index; register_M[tile_index] = shared_lhs[lhs_sm_position]; } // Load a subrow of values from rhs - for (var tile_index = 0u; tile_index < actual_T_N; tile_index++) { + for (var tile_index = 0u; tile_index < T_N; tile_index++) { let rhs_sm_position = thread_col + tile_index + dot_index * B_N; register_N[tile_index] = shared_rhs[rhs_sm_position]; } // Multiply subcolumn and subrow and 
store results - for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) { - results[res_idx_M * actual_T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; } } } @@ -133,16 +130,11 @@ fn main( // Write output matrix // Each thread is responsible of writing T_M x T_N results - for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) { - for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) { - let current_row = row + res_idx_M; - let current_col = col + res_idx_N; - // Check that we are within the bounds of output matrix - if current_row < n_rows && current_col < n_cols { - let result_position = res_idx_M * actual_T_N + res_idx_N; - let output_position = offset_output + current_row * n_cols + current_col; - output[output_position] = results[result_position]; - } + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + let result_position = res_idx_M * T_N + res_idx_N; + let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;; + output[output_position] = results[result_position]; } } } diff --git a/burn-wgpu/src/template/matmul/blocktiling_2d/tile_vectorized.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/tile_vectorized.wgsl new file mode 100644 index 000000000..1e3b62389 --- /dev/null +++ b/burn-wgpu/src/template/matmul/blocktiling_2d/tile_vectorized.wgsl @@ -0,0 +1,140 @@ +@group(0) +@binding(0) +var lhs: array<{{ elem }}>; + +@group(0) +@binding(1) +var rhs: array<{{ elem }}>; + +@group(0) +@binding(2) +var output: array<{{ elem }}>; + +@group(0) +@binding(3) +var info: array; + +const B_M = {{b_m}}u; +const B_N = {{b_n}}u; +const B_K = {{b_k}}u; +const B_M_X_B_K = {{bm_x_bk}}u; +const B_K_X_B_N = {{bk_x_bn}}u; +const T_M = {{t_m}}u; +const T_N = {{t_n}}u; +const T_M_X_T_N = {{tm_x_tn}}u; + +var shared_lhs: array<{{ elem }}, B_M_X_B_K>; +var shared_rhs: array<{{ elem }}, B_K_X_B_N>; + +@compute +@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }}) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + let skip_row = workgroup_id.x * B_M; + let skip_col = workgroup_id.y * B_N; + + let n_thread_per_row = ((B_N - 1u) / T_N) + 1u; + let thread_row = (local_idx / n_thread_per_row) * T_M; + let thread_col = (local_idx % n_thread_per_row) * T_N; + + let row = skip_row + thread_row; + let col = skip_col + thread_col; + + let batch = global_id.z; + + // Basic information + let dim = info[0]; + let n_rows = info[6u * dim - 1u]; + let n_cols = info[6u * dim]; + let K = info[5u * dim - 1u]; + + // Calculate the corresponding offsets with support for broadcasting. 
+ let offset_output = batch * n_rows * n_cols; + var offset_lhs: u32 = skip_row * K; + var offset_rhs: u32 = skip_col; + + let batch_dims = dim - 2u; + for (var b: u32 = 1u; b <= batch_dims; b++) { + let stride_lhs = info[b]; + let stride_rhs = info[b + dim]; + let stride_output = info[b + 2u * dim]; + let shape_lhs = info[b + 3u * dim]; + let shape_rhs = info[b + 4u * dim]; + + offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; + offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; + } + + var results: array<{{ elem }}, T_M_X_T_N>; + var register_M: array<{{ elem }}, T_M>; + var register_N: array<{{ elem }}, T_N>; + + let thread_offset = local_idx * T_M_X_T_N; + + for (var k = 0u; k < K; k += B_K) { + // The bounds checks below ensure that although threads collectively cover B_M x B_N positions, + // writes stay within the B_M x B_K (lhs) and B_K x B_N (rhs) shared memories. + // Inputs are padded before launch, so reads never go out of the matrices. + + // Load data into shared memories + // Each thread is responsible for loading T_M x T_N values from both lhs and rhs + for (var i = 0u; i < T_M; i++) { + for (var j = 0u; j < T_N; j++) { + let current_row = thread_row + i; + let current_col = thread_col + j; + + if current_col < B_K { + let lhs_sm_position = current_row + current_col * B_M; + let lhs_position = offset_lhs + k + current_row * K + current_col; + shared_lhs[lhs_sm_position] = lhs[lhs_position]; + } + + if current_row < B_K { + let rhs_sm_position = current_row * B_N + current_col; + let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col; + shared_rhs[rhs_sm_position] = rhs[rhs_position]; + } + } + } + + + workgroupBarrier(); + + // Compute intermediate results + // Results are accumulated in the results array and updated at each block + // Outer loop indicates which subcolumns/subrows to read from shared memories + for (var dot_index = 0u; dot_index < B_K; dot_index++) { + // Load a subcolumn of values from lhs + for (var tile_index = 0u; tile_index < T_M; tile_index++) { + let lhs_sm_position = thread_row + tile_index + dot_index * B_M; + register_M[tile_index] = shared_lhs[lhs_sm_position]; + } + // Load a subrow of values from rhs + for (var tile_index = 0u; tile_index < T_N; tile_index++) { + let rhs_sm_position = thread_col + tile_index + dot_index * B_N; + register_N[tile_index] = shared_rhs[rhs_sm_position]; + } + // Multiply subcolumn and subrow and store results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N]; + } + } + } + + workgroupBarrier(); + } + + // Write output matrix + // Each thread is responsible for writing T_M x T_N results + for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) { + for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) { + let result_position = res_idx_M * T_N + res_idx_N; + let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N; + output[output_position] = results[result_position]; + } + } +} diff --git a/burn-wgpu/src/template/matmul_mem_coalescing.wgsl b/burn-wgpu/src/template/matmul/mem_coalescing.wgsl similarity index 100% rename from burn-wgpu/src/template/matmul_mem_coalescing.wgsl rename to burn-wgpu/src/template/matmul/mem_coalescing.wgsl diff --git a/burn-wgpu/src/template/matmul_naive.wgsl b/burn-wgpu/src/template/matmul/naive.wgsl similarity index 100% rename from burn-wgpu/src/template/matmul_naive.wgsl rename to 
burn-wgpu/src/template/matmul/naive.wgsl
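For reference, the rounding rule that pad_round applies above (each of the last two dimensions is rounded up to the next multiple of its divisor, as in the [1000, 1000] -> [1024, 1024] example from its doc comment) can be sketched in plain Rust. The round_up helper below is hypothetical and for illustration only; it is not part of this PR.

// Sketch of the pad_round rounding rule (hypothetical helper, not in the PR):
// round `dim` up to the next multiple of `divisor`.
fn round_up(dim: usize, divisor: usize) -> usize {
    let rem = dim % divisor;
    if rem == 0 {
        dim
    } else {
        dim - rem + divisor
    }
}

fn main() {
    // Matches the pad_round doc comment: [1000, 1000] with divisors 64 and 64
    // is padded to [1024, 1024]; the extra rows/columns are zero-filled via slice_assign.
    assert_eq!(round_up(1000, 64), 1024);
    // Already-round dimensions are returned unchanged (the early-return path in pad_round).
    assert_eq!(round_up(1024, 64), 1024);
    // With the default macro parameters (B_M = B_N = 64, B_K = 32), the 1000 x 1000
    // benchmark shapes above are padded this way before launch, then cropped back.
    assert_eq!(round_up(1000, 32), 1024);
}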