mirror of https://github.com/tracel-ai/burn.git
Matmul 2D Tiling (#442)
parent f42176e93a
commit f99fe0fadd
@@ -55,4 +55,34 @@ mod tests {
            Data::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]])
        );
    }

    #[test]
    fn test_matmul_simple_2() {
        let tensor_1 = TestTensor::from_floats([[1.0, 2.0, 3.0, 4.0]]);
        let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]]);

        let tensor_3 = tensor_1.matmul(tensor_2);

        assert_eq!(tensor_3.into_data(), Data::from([[50.0]]));
    }

    #[test]
    fn test_matmul_simple_3() {
        let tensor_1 =
            TestTensor::from_floats([[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
        let tensor_2 =
            TestTensor::from_floats([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]);

        let tensor_3 = tensor_1.matmul(tensor_2);

        assert_eq!(
            tensor_3.into_data(),
            Data::from([
                [9., 18., 27., 36.],
                [12., 24., 36., 48.],
                [15., 30., 45., 60.],
                [18., 36., 54., 72.]
            ])
        );
    }
}

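For reference, the new expectations follow directly from the row-by-column dot products. A quick standalone check of the test_matmul_simple_2 value (plain Rust, independent of the TestTensor API used in the diff):

// [1, 2, 3, 4] (1x4) times [[3], [4], [5], [6]] (4x1) reduces to a single dot product.
fn main() {
    let lhs = [1.0_f32, 2.0, 3.0, 4.0];
    let rhs = [3.0_f32, 4.0, 5.0, 6.0];
    let dot: f32 = lhs.iter().zip(rhs.iter()).map(|(a, b)| a * b).sum();
    assert_eq!(dot, 50.0); // 1*3 + 2*4 + 3*5 + 4*6 = 3 + 8 + 15 + 24
}
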
@@ -27,11 +27,14 @@ spin = {workspace = true}
# WGPU stuff
futures-intrusive = {workspace = true}
pollster = {workspace = true}
text_placeholder = {workspace = true}
serde = {workspace = true}
wgpu = {workspace = true}

# Template
serde = {workspace = true}
text_placeholder = { version = "0.5.0", features = ["struct_context"] }

[dev-dependencies]
burn-ndarray = {path = "../burn-ndarray", version = "0.8.0"}
burn-autodiff = {path = "../burn-autodiff", version = "0.8.0", default-features = false, features = [
    "export_tests",
]}

@@ -51,15 +51,15 @@ impl Context {
    /// Create a new context where computing tasks will be executed on the given
    /// [device](WgpuDevice).
    pub(crate) fn new<G: GraphicsApi>(device: &WgpuDevice) -> Self {
        let device_wgpu = device.clone();
        let (device, queue) = pollster::block_on(select_device::<G>(device));
        let device = Arc::new(device);
        let client = ContextServerImpl::start(device.clone(), queue);
        let (device_wgpu, queue) = pollster::block_on(select_device::<G>(device));
        let device = device.clone();
        let device_wgpu = Arc::new(device_wgpu);
        let client = ContextServerImpl::start(device_wgpu.clone(), queue);

        Self {
            id: IdGenerator::generate(),
            device_wgpu: device,
            device: device_wgpu,
            device_wgpu,
            device,
            client,
            cache: Mutex::new(HashMap::new()),
        }

@@ -130,7 +130,6 @@ pub(crate) fn build_info<E: WgpuElement, const D: usize>(
            current += 1;
        }
    }

    info
}

@@ -1,16 +1,64 @@
use super::{build_info, KernelSettings};
use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use std::cmp::{max, min};

use super::{build_info, SourceTemplate, StaticKernel};
use crate::{
    context::WorkGroup, element::WgpuElement, kernel::KernelSettings, kernel_wgsl,
    tensor::WgpuTensor,
};
use burn_tensor::Shape;

const BLOCK_SIZE: usize = 16;
// Suppose a matmul of m1 of size [M, K] with m2 of size [K, N]
// Block size along dim M
const B_M: usize = 128;
// Block size along dim N
const B_N: usize = 128;
// Block size along dim K
const B_K: usize = 8;
// Tiling size along dim M
const T_M: usize = 8;
// Tiling size along dim N
const T_N: usize = 8;

kernel_wgsl!(MatmulCoalescing, "../template/matmul_mem_coalescing.wgsl");
// WORKGROUP_SIZE_X = ceil(B_M / T_M)
const WORKGROUP_SIZE_X: usize = 16;
// WORKGROUP_SIZE_Y = ceil(B_N / T_N)
const WORKGROUP_SIZE_Y: usize = 16;

const MAX_SHARED_MEMORY_SIZE: usize = 8192;

kernel_wgsl!(MatmulTiling2DRaw, "../template/matmul_blocktiling_2d.wgsl");

struct MatmulTiling2D;

impl StaticKernel for MatmulTiling2D {
    fn source_template() -> SourceTemplate {
        MatmulTiling2DRaw::source_template()
            .register("b_m", B_M.to_string())
            .register("b_n", B_N.to_string())
            .register("b_k", B_K.to_string())
            .register("bm_x_bk", (B_M * B_K).to_string())
            .register("bk_x_bn", (B_K * B_N).to_string())
            .register("t_m", T_M.to_string())
            .register("t_n", T_N.to_string())
            .register("tm_x_tn", (T_M * T_N).to_string())
    }
}

pub fn matmul<E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
    matmul_tiling_2d(lhs, rhs)
}

pub fn matmul_tiling_2d<E: WgpuElement, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
    assert!(B_K <= min(B_M, B_N), "B_K must be smaller than both B_M and B_N, otherwise there won't be enough threads to fill shared memory.");
    assert!(B_K * max(B_M, B_N) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller than or equal to 8192, otherwise the shared memory limit will be exceeded.");
    lhs.assert_is_on_save_device(&rhs);

    let mut shape_out = [0; D];
    lhs.shape
        .dims

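In the template above, the `register` calls bind the Rust constants to the `{{ ... }}` placeholders that appear in the WGSL source further down, so e.g. `const B_M = {{b_m}}u;` is rendered as `const B_M = 128u;` before the shader is compiled. A rough standalone illustration of that substitution step (plain string replacement; the crate actually relies on text_placeholder, whose API is not reproduced here):

// Replace each "{{key}}" placeholder with its bound value.
fn render(template: &str, bindings: &[(&str, String)]) -> String {
    let mut out = template.to_string();
    for (key, value) in bindings {
        out = out.replace(&format!("{{{{{}}}}}", key), value);
    }
    out
}

fn main() {
    let source = "const B_M = {{b_m}}u;\nconst B_K = {{b_k}}u;";
    let rendered = render(source, &[("b_m", 128.to_string()), ("b_k", 8.to_string())]);
    assert_eq!(rendered, "const B_M = 128u;\nconst B_K = 8u;");
}
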
@@ -21,22 +69,32 @@ pub fn matmul<E: WgpuElement, const D: usize>(
            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
        });

    shape_out[D - 2] = lhs.shape.dims[D - 2];
    shape_out[D - 1] = rhs.shape.dims[D - 1];
    let num_rows = lhs.shape.dims[D - 2];
    let num_cols = rhs.shape.dims[D - 1];
    shape_out[D - 2] = num_rows;
    shape_out[D - 1] = num_cols;
    let shape_out = Shape::new(shape_out);

    let buffer = lhs
        .context
        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
    let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
    let num_rows = lhs.shape.dims[D - 2];
    let num_cols = rhs.shape.dims[D - 1];

    let kernel = lhs
        .context
        .compile_static::<KernelSettings<MatmulCoalescing, E, i32, BLOCK_SIZE, BLOCK_SIZE, 1>>();
    // set number of workgroups
    let blocks_needed_in_x = f32::ceil(num_rows as f32 / (WORKGROUP_SIZE_X * T_M) as f32) as u32;
    let blocks_needed_in_y = f32::ceil(num_cols as f32 / (WORKGROUP_SIZE_Y * T_N) as f32) as u32;

    let kernel = lhs.context.compile_static::<KernelSettings<
        MatmulTiling2D,
        E,
        i32,
        WORKGROUP_SIZE_X,
        WORKGROUP_SIZE_Y,
        1,
    >>();

    let info = build_info(&[&lhs, &rhs, &output]);

    let info_buffers = lhs
        .context
        .create_buffer_with_data(bytemuck::cast_slice(&info));

@@ -46,9 +104,7 @@ pub fn matmul<E: WgpuElement, const D: usize>(
        num_iter *= output.shape.dims[i];
    }

    let workgroup_x = f32::ceil(num_rows as f32 / BLOCK_SIZE as f32) as u32;
    let workgroup_y = f32::ceil(num_cols as f32 / BLOCK_SIZE as f32) as u32;
    let workgroup = WorkGroup::new(workgroup_x, workgroup_y, num_iter as u32);
    let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);

    lhs.context.execute(
        workgroup,

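The launch geometry above follows from the block and tile constants: each workgroup has ceil(B_M / T_M) x ceil(B_N / T_N) = 16 x 16 threads and covers a B_M x B_N = 128 x 128 block of the output, so the grid is the output shape divided by 128 in each dimension, rounded up. A minimal sketch of that arithmetic (standalone Rust with the constants from this commit; `dispatch` is a hypothetical helper, not part of the crate):

const B_M: usize = 128; // block size along M
const B_N: usize = 128; // block size along N
const B_K: usize = 8;   // block size along K
const T_M: usize = 8;   // tile size along M (per thread)
const T_N: usize = 8;   // tile size along N (per thread)
const MAX_SHARED_MEMORY_SIZE: usize = 8192;

fn dispatch(num_rows: usize, num_cols: usize) -> (u32, u32) {
    // One thread per T_M x T_N output tile inside a block.
    let workgroup_size_x = (B_M + T_M - 1) / T_M; // 16
    let workgroup_size_y = (B_N + T_N - 1) / T_N; // 16

    // Shared memory holds one B_M x B_K lhs tile and one B_K x B_N rhs tile:
    // 8 * 128 = 1024 entries each, well under the limit asserted above.
    assert!(B_K * B_M.max(B_N) <= MAX_SHARED_MEMORY_SIZE);

    // Each workgroup produces a B_M x B_N block of the output.
    let blocks_x = (num_rows as f32 / (workgroup_size_x * T_M) as f32).ceil() as u32;
    let blocks_y = (num_cols as f32 / (workgroup_size_y * T_N) as f32).ceil() as u32;
    (blocks_x, blocks_y)
}

fn main() {
    // A 1000 x 1000 output needs an 8 x 8 grid of 16 x 16-thread workgroups.
    assert_eq!(dispatch(1000, 1000), (8, 8));
}
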
@@ -0,0 +1,148 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;

@group(0)
@binding(3)
var<storage, read> info: array<u32>;

const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;

var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let skip_row = workgroup_id.x * B_M;
    let skip_col = workgroup_id.y * B_N;

    let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
    let thread_row = (local_idx / n_thread_per_row) * T_M;
    let thread_col = (local_idx % n_thread_per_row) * T_N;

    let row = skip_row + thread_row;
    let col = skip_col + thread_col;

    let batch = global_id.z;

    // Basic information
    let dim = info[0];
    let n_rows = info[6u * dim - 1u];
    let n_cols = info[6u * dim];
    let K = info[5u * dim - 1u];

    // Calculate the corresponding offsets with support for broadcasting.
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = skip_row * K;
    var offset_rhs: u32 = skip_col;

    let batch_dims = dim - 2u;
    for (var b: u32 = 1u; b <= batch_dims; b++) {
        let stride_lhs = info[b];
        let stride_rhs = info[b + dim];
        let stride_output = info[b + 2u * dim];
        let shape_lhs = info[b + 3u * dim];
        let shape_rhs = info[b + 4u * dim];

        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    // In case B_M % T_M != 0 or B_N % T_N != 0,
    // a thread must not read outside of its block.
    let actual_T_M = min(B_M - thread_row, T_M);
    let actual_T_N = min(B_N - thread_col, T_N);

    var results: array<{{ elem }}, T_M_X_T_N>;
    var register_M: array<{{ elem }}, T_M>;
    var register_N: array<{{ elem }}, T_N>;

    for (var k = 0u; k < K; k += B_K) {
        // sm_limit ensures that although there are up to B_M x B_N writes to memory,
        // the shared memories remain B_M x B_K (lhs) and B_K x B_N (rhs).
        // It also ensures we do not read out of the matrices along K when K % B_K != 0.
        let sm_limit = min(B_K, K - k);

        // Load data into shared memories.
        // Each thread is responsible for loading T_M x T_N values from both lhs and rhs.
        for (var i = 0u; i < actual_T_M; i++) {
            for (var j = 0u; j < actual_T_N; j++) {
                let current_row = thread_row + i;
                let current_col = thread_col + j;

                if current_col < sm_limit {
                    let lhs_sm_position = current_row * B_K + current_col;
                    let lhs_position = offset_lhs + k + current_row * K + current_col;
                    shared_lhs[lhs_sm_position] = lhs[lhs_position];
                }

                if current_row < sm_limit {
                    let rhs_sm_position = current_row * B_N + current_col;
                    let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
                    shared_rhs[rhs_sm_position] = rhs[rhs_position];
                }
            }
        }

        workgroupBarrier();

        // Compute intermediate results.
        // Results are accumulated in the results array and updated at each block.
        // The outer loop indicates which subcolumns/subrows to read from shared memories.
        for (var dot_index = 0u; dot_index < B_K; dot_index++) {
            // Load a subcolumn of values from lhs.
            for (var tile_index = 0u; tile_index < actual_T_M; tile_index++) {
                let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
                register_M[tile_index] = shared_lhs[lhs_sm_position];
            }
            // Load a subrow of values from rhs.
            for (var tile_index = 0u; tile_index < actual_T_N; tile_index++) {
                let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
                register_N[tile_index] = shared_rhs[rhs_sm_position];
            }
            // Multiply the subcolumn and the subrow and store the results.
            for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
                for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
                    results[res_idx_M * actual_T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
                }
            }
        }

        workgroupBarrier();
    }

    // Write the output matrix.
    // Each thread is responsible for writing T_M x T_N results.
    for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
        for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
            let current_row = row + res_idx_M;
            let current_col = col + res_idx_N;
            // Check that we are within the bounds of the output matrix.
            if current_row < n_rows && current_col < n_cols {
                let result_position = res_idx_M * actual_T_N + res_idx_N;
                let output_position = offset_output + current_row * n_cols + current_col;
                output[output_position] = results[result_position];
            }
        }
    }
}

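Read alongside the shader, the scheme may be easier to follow as a scalar reference. Below is a CPU sketch of the same 2D block-tiling loop structure (plain Rust, row-major inputs, a single thread standing in for every invocation; the shared-memory staging, the register_M/register_N buffers, and the batch/broadcast offsets are omitted; this is an illustration, not part of the commit):

// CPU sketch of the 2D block-tiling scheme: each (thread_row, thread_col) pair
// plays the role of one invocation computing a T_M x T_N tile of a B_M x B_N block.
const B_M: usize = 128;
const B_N: usize = 128;
const B_K: usize = 8;
const T_M: usize = 8;
const T_N: usize = 8;

fn matmul_tiled(lhs: &[f32], rhs: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for block_row in (0..m).step_by(B_M) {
        for block_col in (0..n).step_by(B_N) {
            for thread_row in (0..B_M).step_by(T_M) {
                for thread_col in (0..B_N).step_by(T_N) {
                    // Per-"thread" accumulator for its T_M x T_N output tile.
                    let mut results = [0.0f32; T_M * T_N];
                    // Walk the K dimension in chunks of B_K, as the kernel does per block.
                    for k0 in (0..k).step_by(B_K) {
                        for dot_index in 0..B_K.min(k - k0) {
                            for i in 0..T_M {
                                for j in 0..T_N {
                                    let row = block_row + thread_row + i;
                                    let col = block_col + thread_col + j;
                                    if row < m && col < n {
                                        results[i * T_N + j] += lhs[row * k + (k0 + dot_index)]
                                            * rhs[(k0 + dot_index) * n + col];
                                    }
                                }
                            }
                        }
                    }
                    // Write the T_M x T_N tile back, guarding the output bounds.
                    for i in 0..T_M {
                        for j in 0..T_N {
                            let row = block_row + thread_row + i;
                            let col = block_col + thread_col + j;
                            if row < m && col < n {
                                out[row * n + col] = results[i * T_N + j];
                            }
                        }
                    }
                }
            }
        }
    }
    out
}

fn main() {
    // 2x3 times 3x3, small enough to check by hand.
    let lhs: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs: [f32; 9] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let out = matmul_tiled(&lhs, &rhs, 2, 3, 3);
    assert_eq!(out, vec![30.0, 36.0, 42.0, 66.0, 81.0, 96.0]);
}

The WGSL kernel adds what this sketch leaves out: tiles are staged through workgroup shared memory between barriers, actual_T_M/actual_T_N clamp partial tiles, and sm_limit guards the loads along K.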