Perf/wgpu/matmul unpadded (#922)

Louis Fortier-Dubois 2023-11-01 16:37:33 -04:00 committed by GitHub
parent 64e58b4463
commit 35df31f700
7 changed files with 475 additions and 13 deletions


@@ -2,6 +2,7 @@ use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::kernel::matmul::init_matmul_output;
+use burn_wgpu::kernel::matmul::unpadded::matmul_tiling_2d_unpadded;
use burn_wgpu::kernel::matmul::vec4::matmul_tiling_2d_vec4;
use burn_wgpu::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs;
use burn_wgpu::WgpuDevice;
@@ -100,6 +101,11 @@ bench_matmul!(
Tiling2DMatmulVec4,
matmul_tiling_2d_vec4
);
+bench_matmul!(
+Tiling2DMatmulUnpaddedBenchmark,
+Tiling2DMatmulUnpadded,
+matmul_tiling_2d_unpadded
+);
#[allow(dead_code)]
/// Runs the benchmarks for wgpu matmul implementations
@@ -107,9 +113,9 @@ pub fn bench(device: &WgpuDevice) {
const D: usize = 3;
let num_repeats = 3;
let batch_size = 3;
-let m = 2048;
-let k = 2048;
-let n = 1024;
+let m = 1007;
+let k = 1023;
+let n = 1005;
let shape_lhs = Shape::new([batch_size, m, k]);
let shape_rhs = Shape::new([batch_size, k, n]);
@@ -125,6 +131,7 @@ pub fn bench(device: &WgpuDevice) {
}
run_matmul_benchmark!(NaiveMatmulBenchmark);
run_matmul_benchmark!(MemCoalescingMatmulBenchmark);
+run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark);
run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark);
run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark);
}
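The new 1007 x 1023 x 1005 sizes are deliberately not multiples of the block sizes, so the padded kernels must round up while the unpadded kernel does not. A back-of-envelope sketch of the waste, under the assumption that the padded path rounds each dimension up to the B_M = B_N = 64 and B_K = 32 blocks defined in base.rs:

fn round_up(x: usize, multiple: usize) -> usize {
    ((x + multiple - 1) / multiple) * multiple
}

fn main() {
    let (m, k, n) = (1007usize, 1023, 1005);
    // All three round up to 1024 under the assumed block sizes.
    let (pm, pk, pn) = (round_up(m, 64), round_up(k, 32), round_up(n, 64));
    let overhead = (pm * pk * pn) as f64 / (m * k * n) as f64 - 1.0;
    // ~3.7% extra multiply-adds, plus the cost of the pad and unpad copies.
    println!("padded overhead: {:.1}%", overhead * 100.0);
}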


@@ -13,7 +13,7 @@ pub(crate) const B_N: usize = 64;
pub(crate) const B_K: usize = 32;
pub(crate) const WORKGROUP_SIZE: usize = 16;
-pub(super) fn make_workgroup<const D: usize>(output_shape: Shape<D>) -> WorkGroup {
+pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32;
let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32;
let mut num_blocks_z = 1;
@@ -71,7 +71,7 @@ pub(super) fn matmul_tiling_2d_launch<
rounded_output_shape.clone(),
);
-let workgroup = make_workgroup(rounded_output_shape);
+let workgroup = make_workgroup(&rounded_output_shape);
let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
lhs.client.execute(
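Passing the shape by reference lets the launch path keep using rounded_output_shape after computing the grid. The grid itself is a ceiling division per output dimension; an equivalent integer-only sketch of the arithmetic above:

// Number of B_M x B_N blocks needed to cover an m x n output.
fn grid_dims(m: usize, n: usize, b_m: usize, b_n: usize) -> (u32, u32) {
    let blocks_x = ((m + b_m - 1) / b_m) as u32; // integer ceil-division
    let blocks_y = ((n + b_n - 1) / b_n) as u32;
    (blocks_x, blocks_y)
}
// grid_dims(1007, 1005, 64, 64) == (16, 16), the same grid as a padded 1024 x 1024 output.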


@@ -1,6 +1,9 @@
mod base;
mod padding;
+/// WGSL vec4 primitives are used on the left- and right-hand tensors;
+/// padding is avoided through conditional bounds checks in the kernel
+pub mod unpadded;
/// WGSL vec4 primitives are used on the left- and right-hand tensors
pub mod vec4;
/// WGSL vec4 primitives are used on the left-hand tensor


@@ -0,0 +1,195 @@
use burn_tensor::Element;
use crate::{
compute::DynamicKernel,
element::WgpuElement,
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource},
tensor::WgpuTensor,
};
use std::marker::PhantomData;
use crate::kernel_wgsl;
use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE};
kernel_wgsl!(
MatmulTiling2DUnpaddedRaw,
"../../../template/matmul/blocktiling_2d/unpadded.wgsl"
);
#[derive(new, Debug)]
struct MatmulTiling2DUnpadded<E: WgpuElement> {
_elem: PhantomData<E>,
}
impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
fn source(&self) -> SourceTemplate {
MatmulTiling2DUnpaddedRaw::source()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk_4", (B_M * B_K / 4).to_string())
.register("bk_x_bn_4", (B_K * B_N / 4).to_string())
.register("workgroup_size_x", WORKGROUP_SIZE.to_string())
.register("workgroup_size_y", WORKGROUP_SIZE.to_string())
.register("workgroup_size_z", "1".to_string())
.register("elem", E::type_name())
.register("int", "i32")
}
fn id(&self) -> String {
std::format!("{:?}", self)
}
}
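source() specializes the WGSL template into a concrete kernel: every block size and workgroup dimension is substituted as a literal, so they become compile-time constants in the shader. A minimal stand-in for what register amounts to (assumed behavior; SourceTemplate's real implementation may differ):

// Naive placeholder substitution; the template uses both `{{ name }}` and `{{name}}` forms.
fn register(template: &str, name: &str, value: &str) -> String {
    template
        .replace(&format!("{{{{ {name} }}}}"), value)
        .replace(&format!("{{{{{name}}}}}"), value)
}
// register("const B_M = {{b_m}}u;", "b_m", "64") == "const B_M = 64u;"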
/// Matrix multiplication using the tiling 2d algorithm, with the vec4
/// primitive on both lhs and rhs and no padding required
pub fn matmul_tiling_2d_unpadded<E: WgpuElement + Element, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
out: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let lhs = match lhs.batch_swapped_with_row_col() {
true => into_contiguous(lhs),
false => lhs,
};
let rhs = match rhs.batch_swapped_with_row_col() {
true => into_contiguous(rhs),
false => rhs,
};
let workgroup = make_workgroup(&out.shape);
let info_handle = make_info_handle(&lhs, &rhs, &out);
lhs.client.execute(
Box::new(DynamicKernel::new(
MatmulTiling2DUnpadded::<E>::new(),
workgroup,
)),
&[&lhs.handle, &rhs.handle, &out.handle, &info_handle],
);
out
}
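A usage sketch, mirroring the benchmark imports at the top of this diff (it assumes init_matmul_output allocates an output tensor of the right shape, as its use in the benchmark suggests):

// lhs: [batch, m, k] and rhs: [batch, k, n] WgpuTensors already on the device.
let out = init_matmul_output(&lhs, &rhs);
let out = matmul_tiling_2d_unpadded(lhs, rhs, out); // no rounding, no pad/unpad copies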
#[cfg(test)]
mod tests {
use super::*;
use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims};
#[test]
pub fn test_matmul_unpadded_straightforward() {
test_with_params(1, 2, 1, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_shapes_smaller_than_blocks() {
test_with_params(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_shapes_equal_blocks() {
test_with_params(64, 32, 64, 2, 2);
}
#[test]
pub fn test_matmul_unpadded_m_exceeds_block() {
test_with_params(75, 32, 64, 2, 2);
}
#[test]
pub fn test_matmul_unpadded_k_exceeds_block() {
test_with_params(64, 33, 32, 1, 1);
}
#[test]
pub fn test_matmul_irregular_shape() {
test_with_params(123, 255, 72, 3, 5);
}
#[test]
pub fn test64_matmul_unpadded_n_exceeds_block() {
test_with_params(64, 32, 75, 2, 2);
}
#[test]
pub fn test_matmul_unpadded_n_smaller_than_m() {
test_with_params(8, 8, 3, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_m_smaller_than_n() {
test_with_params(3, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_k_smaller_than_m_n() {
test_with_params(8, 3, 8, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_k_larger_than_m_n() {
test_with_params(8, 48, 8, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_multibatch_1_dim() {
test_with_params(8, 8, 8, 3, 1);
}
#[test]
pub fn test_matmul_unpadded_multibatch_2_dims() {
test_with_params(8, 8, 8, 3, 4);
}
#[test]
pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() {
test_with_params(7, 7, 7, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_medium() {
test_with_params(17, 16, 16, 1, 1);
}
#[test]
pub fn test_matmul_unpadded_large() {
test_with_params(134, 242, 250, 1, 1);
}
fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
let func = matmul_tiling_2d_unpadded;
let shape_lhs = [batch_1, batch_2, m, k];
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(func, shape_lhs, shape_rhs);
}
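// For example, test_matmul_irregular_shape above expands to
// lhs [3, 5, 123, 255] x rhs [3, 5, 255, 72]: none of m, k, n is a
// multiple of B_M = 64, B_K = 32, or B_N = 64, so every edge branch runs.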
#[test]
fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() {
let matmul_func = matmul_tiling_2d_unpadded;
let swap = [0, 1];
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs);
}
#[test]
fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() {
let matmul_func = matmul_tiling_2d_unpadded;
let swap_lhs = [0, 0];
let swap_rhs = [2, 3];
let shape_lhs = [3, 2, 4, 4];
let shape_rhs = [3, 2, 4, 4];
same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
}
#[test]
fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() {
let matmul_func = matmul_tiling_2d_unpadded;
let swap_lhs = [0, 3];
let swap_rhs = [0, 2];
let shape_lhs = [4, 4, 4, 4];
let shape_rhs = [4, 4, 4, 4];
same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
}
}


@@ -12,18 +12,18 @@ use crate::kernel_wgsl;
use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};
kernel_wgsl!(
-MatmulTiling2Dvec4RHSRaw,
+MatmulTiling2Dvec4Raw,
"../../../template/matmul/blocktiling_2d/vec4.wgsl"
);
#[derive(new, Debug)]
-struct MatmulTiling2Dvec4RHS<E: WgpuElement> {
+struct MatmulTiling2Dvec4<E: WgpuElement> {
_elem: PhantomData<E>,
}
-impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4RHS<E> {
+impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
fn source(&self) -> SourceTemplate {
-MatmulTiling2Dvec4RHSRaw::source()
+MatmulTiling2Dvec4Raw::source()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
@@ -48,7 +48,7 @@ pub fn matmul_tiling_2d_vec4<E: WgpuElement + Element, const D: usize>(
rhs: WgpuTensor<E, D>,
out: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
-let kernel = MatmulTiling2Dvec4RHS::<E>::new();
+let kernel = MatmulTiling2Dvec4::<E>::new();
matmul_tiling_2d_launch(lhs, rhs, out, kernel)
}


@@ -82,6 +82,11 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
rhs.clone(),
out.clone(),
)),
+Box::new(Vec4TilingMatmulUnpaddedDefault::<E, 3>::new(
+lhs.clone(),
+rhs.clone(),
+out.clone(),
+)),
Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, 3>::new(lhs, rhs, out)),
]
}
@@ -97,7 +102,10 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
2 => Box::new(Vec4TilingMatmulDefault::<E, D>::new(
self.lhs, self.rhs, self.out,
)),
-3 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
+3 => Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
+self.lhs, self.rhs, self.out,
+)),
+4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
self.lhs, self.rhs, self.out,
)),
_ => panic!("Fastest index is out of bounds"),
@@ -162,18 +170,24 @@ matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
});
-// Probably the fastest on MacOS.
+// Maybe the fastest on MacOS.
matmul_tune_ops!(
Vec4LhsOnlyTilingMatmulDefault,
crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs
);
-// Probably the fastest.
+// Probably the fastest when the shapes fit the blocks exactly.
matmul_tune_ops!(
Vec4TilingMatmulDefault,
crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
);
+// Probably the fastest otherwise.
+matmul_tune_ops!(
+Vec4TilingMatmulUnpaddedDefault,
+crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
+);
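The three candidates now cover complementary regimes. As a rough sketch of the intent behind the comments above (the autotuner itself picks by measuring at runtime, not by rule):

// Hypothetical rule of thumb distilled from the comments; not burn's actual selection logic.
fn predicted_winner(m: usize, k: usize, n: usize, on_macos: bool) -> &'static str {
    if on_macos {
        "vec4_lhs" // maybe the fastest on MacOS
    } else if m % 64 == 0 && n % 64 == 0 && k % 32 == 0 {
        "vec4" // shapes fit the blocks, so padding would be a no-op anyway
    } else {
        "unpadded" // irregular shapes: skip the pad and unpad copies
    }
}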
#[cfg(test)]
mod tests {
use super::*;


@@ -0,0 +1,243 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
const T_M = 4u;
const T_N = 4u;
const T_M_X_T_N = 16u;
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
// Position of the first element of the thread, relative to the block
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
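// Worked instance: with B_M = B_N = 64u and T_M = T_N = 4u, n_thread_per_row = 16u,
// so the 16 x 16 = 256 threads of a workgroup each own one 4 x 4 sub-tile and
// together cover the whole 64 x 64 output block (256 * 16 = 4096 values).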
// Position of the first element of the thread, in absolute coordinates (within one batch)
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Row / col strides
let lhs_stride_row = info[dim - 1u];
let lhs_stride_col = info[dim];
let rhs_stride_row = info[2u * dim - 1u];
let rhs_stride_col = info[2u * dim];
let out_stride_row = info[3u * dim - 1u];
let out_stride_col = info[3u * dim];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * lhs_stride_row;
var offset_rhs: u32 = skip_col * rhs_stride_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
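// Note: a broadcast (size-1) batch axis has shape 1u, so the `% shape` term pins
// that coordinate to 0 and every output batch reads the same input slice.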
// Registers used in the compute pass
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: vec4<{{ elem }}>;
var register_N: vec4<{{ elem }}>;
// How close the thread is to the end of the matrix;
// if < 4, it is an edge case and its vec4 loads must be guarded
let remain_row_lhs = n_rows - row;
let remain_col_rhs = n_cols - col;
for (var k = 0u; k < K; k += B_K) {
// LHS LOAD PASS
// For the 4 vec4 columns of this thread
for (var j = 0u; j < 4u; j++) {
// The column within the block that this iteration loads
let current_col = thread_col + j;
// Position of the column vec4 in shared memory
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
// To avoid overwriting the following row in shared memory
if current_col < B_K {
// To pad with zeros if outside lhs
if current_col + k < K && remain_row_lhs >= 1u {
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
let lhs_position1 = lhs_position0 + lhs_stride_row;
let lhs_position2 = lhs_position1 + lhs_stride_row;
let lhs_position3 = lhs_position2 + lhs_stride_row;
if remain_row_lhs >= 4u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
lhs[lhs_position3],
);
} else if remain_row_lhs == 3u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
0.
);
} else if remain_row_lhs == 2u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
0.,
0.
);
} else if remain_row_lhs == 1u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
0.,
0.,
0.
);
}
} else {
shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
// RHS LOAD PASS
for (var i = 0u; i < 4u; i++) {
let current_row = thread_row + i;
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
if current_row < B_K {
if current_row + k < K && remain_col_rhs >= 1u {
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
let rhs_position1 = rhs_position0 + rhs_stride_col;
let rhs_position2 = rhs_position1 + rhs_stride_col;
let rhs_position3 = rhs_position2 + rhs_stride_col;
if remain_col_rhs >= 4u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
rhs[rhs_position3],
);
} else if remain_col_rhs == 3u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
0.
);
} else if remain_col_rhs == 2u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
0.,
0.
);
} else if remain_col_rhs == 1u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
0.,
0.,
0.
);
}
} else {
shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
workgroupBarrier();
// COMPUTE PASS
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
register_M = shared_lhs[lhs_sm_position];
// Load a subrow of values from rhs
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
register_N = shared_rhs[rhs_sm_position];
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// OUTPUT PASS
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let row_index = row + res_idx_M;
let col_index = col + res_idx_N;
if row_index < n_rows && col_index < n_cols {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col;
output[output_position] = results[result_position];
}
}
}
}
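Reading the index arithmetic back, the info buffer is laid out as [dim, lhs strides, rhs strides, out strides, lhs shape, rhs shape, out shape], 6 * dim + 1 entries in total; make_info_handle in base.rs is the producer. A host-side sketch of that layout, reconstructed from the indices the kernel uses (names are illustrative):

// info[0] = dim; then dim stride entries for each of lhs/rhs/out, then dim shape
// entries for each. E.g. lhs_stride_row = info[dim - 1] and K = info[5 * dim - 1].
fn build_info(strides: [&[u32]; 3], shapes: [&[u32]; 3]) -> Vec<u32> {
    let dim = shapes[0].len() as u32;
    let mut info = vec![dim];
    for s in strides {
        info.extend_from_slice(s);
    }
    for s in shapes {
        info.extend_from_slice(s);
    }
    info
}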